mirror of
https://github.com/icereed/paperless-gpt.git
synced 2025-03-13 13:18:02 -05:00
refactor: improve token handling in getAvailableTokensForContent and truncateContentByTokens functions
This commit is contained in:
parent
e76eaff4cd
commit
62b58d6e3f
1 changed file with 34 additions and 24 deletions
56
tokens.go
56
tokens.go
|
@@ -6,17 +6,16 @@ import (
|
|||
"text/template"
|
||||
|
||||
"github.com/tmc/langchaingo/llms"
|
||||
"github.com/tmc/langchaingo/textsplitter"
|
||||
)
|
||||
|
||||
// getAvailableTokensForContent calculates how many tokens are available for content
|
||||
// by rendering the template with empty content and counting tokens
|
||||
func getAvailableTokensForContent(template *template.Template, data map[string]interface{}) (int, error) {
|
||||
func getAvailableTokensForContent(tmpl *template.Template, data map[string]interface{}) (int, error) {
|
||||
if tokenLimit <= 0 {
|
||||
return 0, nil // No limit when disabled
|
||||
}
|
||||
|
||||
// Create a copy of data and set Content to empty
|
||||
// Create a copy of data and set "Content" to empty
|
||||
templateData := make(map[string]interface{})
|
||||
for k, v := range data {
|
||||
templateData[k] = v
|
||||
|
@@ -25,7 +24,7 @@ func getAvailableTokensForContent(template *template.Template, data map[string]i
|
|||
|
||||
// Execute template with empty content
|
||||
var promptBuffer bytes.Buffer
|
||||
if err := template.Execute(&promptBuffer, templateData); err != nil {
|
||||
if err := tmpl.Execute(&promptBuffer, templateData); err != nil {
|
||||
return 0, fmt.Errorf("error executing template: %v", err)
|
||||
}
|
||||
|
||||
|
@@ -51,39 +50,50 @@ func getTokenCount(content string) (int, error) {
|
|||
return llms.CountTokens(llmModel, content), nil
|
||||
}
|
||||
|
||||
// truncateContentByTokens truncates the content to fit within the specified token limit
|
||||
// truncateContentByTokens truncates the content so that its token count does not exceed availableTokens.
|
||||
// This implementation uses a binary search on runes to find the longest prefix whose token count is within the limit.
|
||||
// If availableTokens is 0 or negative, the original content is returned.
|
||||
func truncateContentByTokens(content string, availableTokens int) (string, error) {
|
||||
if availableTokens <= 0 || tokenLimit <= 0 {
|
||||
return content, nil
|
||||
}
|
||||
tokenCount, err := getTokenCount(content)
|
||||
totalTokens, err := getTokenCount(content)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error counting tokens: %v", err)
|
||||
}
|
||||
if tokenCount <= availableTokens {
|
||||
if totalTokens <= availableTokens {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
splitter := textsplitter.NewTokenSplitter(
|
||||
textsplitter.WithChunkSize(availableTokens),
|
||||
textsplitter.WithChunkOverlap(0),
|
||||
textsplitter.WithModelName(llmModel),
|
||||
)
|
||||
chunks, err := splitter.SplitText(content)
|
||||
// Convert content to runes for safe slicing.
|
||||
runes := []rune(content)
|
||||
low := 0
|
||||
high := len(runes)
|
||||
validCut := 0
|
||||
|
||||
for low <= high {
|
||||
mid := (low + high) / 2
|
||||
substr := string(runes[:mid])
|
||||
count, err := getTokenCount(substr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error splitting content: %v", err)
|
||||
return "", fmt.Errorf("error counting tokens in substring: %v", err)
|
||||
}
|
||||
if count <= availableTokens {
|
||||
validCut = mid
|
||||
low = mid + 1
|
||||
} else {
|
||||
high = mid - 1
|
||||
}
|
||||
}
|
||||
|
||||
// Validate first chunk's token count
|
||||
firstChunk := chunks[0]
|
||||
chunkTokens, err := getTokenCount(firstChunk)
|
||||
truncated := string(runes[:validCut])
|
||||
// Final verification
|
||||
finalTokens, err := getTokenCount(truncated)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error counting tokens in chunk: %v", err)
|
||||
return "", fmt.Errorf("error counting tokens in final truncated content: %v", err)
|
||||
}
|
||||
if chunkTokens > availableTokens {
|
||||
return "", fmt.Errorf("first chunk uses %d tokens which exceeds the limit of %d tokens", chunkTokens, availableTokens)
|
||||
if finalTokens > availableTokens {
|
||||
return "", fmt.Errorf("truncated content still exceeds the available token limit")
|
||||
}
|
||||
|
||||
// return the first chunk
|
||||
return firstChunk, nil
|
||||
return truncated, nil
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue