refactor: improve token handling in getAvailableTokensForContent and truncateContentByTokens functions

This commit is contained in:
Dominik Schröter 2025-02-02 15:54:16 +01:00
parent e76eaff4cd
commit 62b58d6e3f

View file

@ -6,17 +6,16 @@ import (
"text/template" "text/template"
"github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/textsplitter"
) )
// getAvailableTokensForContent calculates how many tokens are available for content // getAvailableTokensForContent calculates how many tokens are available for content
// by rendering the template with empty content and counting tokens // by rendering the template with empty content and counting tokens
func getAvailableTokensForContent(template *template.Template, data map[string]interface{}) (int, error) { func getAvailableTokensForContent(tmpl *template.Template, data map[string]interface{}) (int, error) {
if tokenLimit <= 0 { if tokenLimit <= 0 {
return 0, nil // No limit when disabled return 0, nil // No limit when disabled
} }
// Create a copy of data and set Content to empty // Create a copy of data and set "Content" to empty
templateData := make(map[string]interface{}) templateData := make(map[string]interface{})
for k, v := range data { for k, v := range data {
templateData[k] = v templateData[k] = v
@ -25,7 +24,7 @@ func getAvailableTokensForContent(template *template.Template, data map[string]i
// Execute template with empty content // Execute template with empty content
var promptBuffer bytes.Buffer var promptBuffer bytes.Buffer
if err := template.Execute(&promptBuffer, templateData); err != nil { if err := tmpl.Execute(&promptBuffer, templateData); err != nil {
return 0, fmt.Errorf("error executing template: %v", err) return 0, fmt.Errorf("error executing template: %v", err)
} }
@ -51,39 +50,50 @@ func getTokenCount(content string) (int, error) {
return llms.CountTokens(llmModel, content), nil return llms.CountTokens(llmModel, content), nil
} }
// truncateContentByTokens truncates the content to fit within the specified token limit // truncateContentByTokens truncates the content so that its token count does not exceed availableTokens.
// This implementation uses a binary search on runes to find the longest prefix whose token count is within the limit.
// If availableTokens is 0 or negative, the original content is returned.
func truncateContentByTokens(content string, availableTokens int) (string, error) { func truncateContentByTokens(content string, availableTokens int) (string, error) {
if availableTokens <= 0 || tokenLimit <= 0 { if availableTokens <= 0 || tokenLimit <= 0 {
return content, nil return content, nil
} }
tokenCount, err := getTokenCount(content) totalTokens, err := getTokenCount(content)
if err != nil { if err != nil {
return "", fmt.Errorf("error counting tokens: %v", err) return "", fmt.Errorf("error counting tokens: %v", err)
} }
if tokenCount <= availableTokens { if totalTokens <= availableTokens {
return content, nil return content, nil
} }
splitter := textsplitter.NewTokenSplitter( // Convert content to runes for safe slicing.
textsplitter.WithChunkSize(availableTokens), runes := []rune(content)
textsplitter.WithChunkOverlap(0), low := 0
textsplitter.WithModelName(llmModel), high := len(runes)
) validCut := 0
chunks, err := splitter.SplitText(content)
if err != nil { for low <= high {
return "", fmt.Errorf("error splitting content: %v", err) mid := (low + high) / 2
substr := string(runes[:mid])
count, err := getTokenCount(substr)
if err != nil {
return "", fmt.Errorf("error counting tokens in substring: %v", err)
}
if count <= availableTokens {
validCut = mid
low = mid + 1
} else {
high = mid - 1
}
} }
// Validate first chunk's token count truncated := string(runes[:validCut])
firstChunk := chunks[0] // Final verification
chunkTokens, err := getTokenCount(firstChunk) finalTokens, err := getTokenCount(truncated)
if err != nil { if err != nil {
return "", fmt.Errorf("error counting tokens in chunk: %v", err) return "", fmt.Errorf("error counting tokens in final truncated content: %v", err)
} }
if chunkTokens > availableTokens { if finalTokens > availableTokens {
return "", fmt.Errorf("first chunk uses %d tokens which exceeds the limit of %d tokens", chunkTokens, availableTokens) return "", fmt.Errorf("truncated content still exceeds the available token limit")
} }
return truncated, nil
// return the first chunk
return firstChunk, nil
} }