diff --git a/tokens.go b/tokens.go index 206cd60..7f1d0a4 100644 --- a/tokens.go +++ b/tokens.go @@ -6,17 +6,16 @@ import ( "text/template" "github.com/tmc/langchaingo/llms" - "github.com/tmc/langchaingo/textsplitter" ) // getAvailableTokensForContent calculates how many tokens are available for content // by rendering the template with empty content and counting tokens -func getAvailableTokensForContent(template *template.Template, data map[string]interface{}) (int, error) { +func getAvailableTokensForContent(tmpl *template.Template, data map[string]interface{}) (int, error) { if tokenLimit <= 0 { return 0, nil // No limit when disabled } - // Create a copy of data and set Content to empty + // Create a copy of data and set "Content" to empty templateData := make(map[string]interface{}) for k, v := range data { templateData[k] = v @@ -25,7 +24,7 @@ func getAvailableTokensForContent(template *template.Template, data map[string]i // Execute template with empty content var promptBuffer bytes.Buffer - if err := template.Execute(&promptBuffer, templateData); err != nil { + if err := tmpl.Execute(&promptBuffer, templateData); err != nil { return 0, fmt.Errorf("error executing template: %v", err) } @@ -51,39 +50,50 @@ func getTokenCount(content string) (int, error) { return llms.CountTokens(llmModel, content), nil } -// truncateContentByTokens truncates the content to fit within the specified token limit +// truncateContentByTokens truncates the content so that its token count does not exceed availableTokens. +// This implementation uses a binary search on runes to find the longest prefix whose token count is within the limit. +// If availableTokens is 0 or negative, the original content is returned. func truncateContentByTokens(content string, availableTokens int) (string, error) { if availableTokens <= 0 || tokenLimit <= 0 { return content, nil } - tokenCount, err := getTokenCount(content) + totalTokens, err := getTokenCount(content) if err != nil { return "", fmt.Errorf("error counting tokens: %v", err) } - if tokenCount <= availableTokens { + if totalTokens <= availableTokens { return content, nil } - splitter := textsplitter.NewTokenSplitter( - textsplitter.WithChunkSize(availableTokens), - textsplitter.WithChunkOverlap(0), - textsplitter.WithModelName(llmModel), - ) - chunks, err := splitter.SplitText(content) - if err != nil { - return "", fmt.Errorf("error splitting content: %v", err) + // Convert content to runes for safe slicing. + runes := []rune(content) + low := 0 + high := len(runes) + validCut := 0 + + for low <= high { + mid := (low + high) / 2 + substr := string(runes[:mid]) + count, err := getTokenCount(substr) + if err != nil { + return "", fmt.Errorf("error counting tokens in substring: %v", err) + } + if count <= availableTokens { + validCut = mid + low = mid + 1 + } else { + high = mid - 1 + } } - // Validate first chunk's token count - firstChunk := chunks[0] - chunkTokens, err := getTokenCount(firstChunk) + truncated := string(runes[:validCut]) + // Final verification + finalTokens, err := getTokenCount(truncated) if err != nil { - return "", fmt.Errorf("error counting tokens in chunk: %v", err) + return "", fmt.Errorf("error counting tokens in final truncated content: %v", err) } - if chunkTokens > availableTokens { - return "", fmt.Errorf("first chunk uses %d tokens which exceeds the limit of %d tokens", chunkTokens, availableTokens) + if finalTokens > availableTokens { + return "", fmt.Errorf("truncated content still exceeds the available token limit") } - - // return the first chunk - return firstChunk, nil + return truncated, nil }