From b0737aab50ad7f84937a13b7584d47d89df628cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dominik=20Schr=C3=B6ter?= Date: Mon, 27 Jan 2025 13:08:37 +0100 Subject: [PATCH] feat: add TOKEN_LIMIT environment variable for controlling maximum tokens in prompts --- README.md | 26 +++++ app_llm.go | 83 +++++++++---- app_llm_test.go | 268 ++++++++++++++++++++++++++++++++++++++++++ go.mod | 5 + go.sum | 13 +++ main.go | 8 ++ tokens.go | 73 ++++++++++++ tokens_test.go | 302 ++++++++++++++++++++++++++++++++++++++++++++++++ 8 files changed, 758 insertions(+), 20 deletions(-) create mode 100644 app_llm_test.go create mode 100644 tokens.go create mode 100644 tokens_test.go diff --git a/README.md b/README.md index 1e8f1fb..a919a81 100644 --- a/README.md +++ b/README.md @@ -175,6 +175,7 @@ services: | `AUTO_GENERATE_TAGS` | Generate tags automatically if `paperless-gpt-auto` is used. Default: `true`. | No | | `AUTO_GENERATE_CORRESPONDENTS` | Generate correspondents automatically if `paperless-gpt-auto` is used. Default: `true`. | No | | `OCR_LIMIT_PAGES` | Limit the number of pages for OCR. Set to `0` for no limit. Default: `5`. | No | +| `TOKEN_LIMIT` | Maximum tokens allowed for prompts/content. Set to `0` to disable limit. Useful for smaller LLMs. | No | | `CORRESPONDENT_BLACK_LIST` | A comma-separated list of names to exclude from the correspondents suggestions. Example: `John Doe, Jane Smith`. ### Custom Prompt Templates @@ -417,6 +418,31 @@ P.O. Box 94515 --- +## Troubleshooting + +### Working with Local LLMs + +When using local LLMs (like those through Ollama), you might need to adjust certain settings to optimize performance: + +#### Token Management +- Use `TOKEN_LIMIT` environment variable to control the maximum number of tokens sent to the LLM +- Smaller models might truncate content unexpectedly if given too much text +- Start with a conservative limit (e.g., 2000 tokens) and adjust based on your model's capabilities +- Set to `0` to disable the limit (use with caution) + +Example configuration for smaller models: +```yaml +environment: + TOKEN_LIMIT: '2000' # Adjust based on your model's context window + LLM_PROVIDER: 'ollama' + LLM_MODEL: 'llama2' # Or other local model +``` + +Common issues and solutions: +- If you see truncated or incomplete responses, try lowering the `TOKEN_LIMIT` +- If processing is too limited, gradually increase the limit while monitoring performance +- For models with larger context windows, you can increase the limit or disable it entirely + ## Contributing **Pull requests** and **issues** are welcome! 
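The flow this patch applies in each suggestion path (title, tags, correspondent) can be condensed to the sketch below. `getAvailableTokensForContent` and `truncateContentByTokens` are the helpers added in `tokens.go`; `buildPrompt` is a hypothetical wrapper shown only for illustration and lives in the same `package main` as the rest of the app.

```go
package main

import (
	"bytes"
	"fmt"
	"text/template"
)

// buildPrompt is a hypothetical condensation of the prompt-budgeting flow used by
// the suggestion paths in app_llm.go. It is an illustrative sketch, not part of the diff.
func buildPrompt(tmpl *template.Template, data map[string]interface{}, content string) (string, error) {
	// 1. Budget: render the template with empty content and subtract its token
	//    count (plus a small safety margin) from TOKEN_LIMIT.
	availableTokens, err := getAvailableTokensForContent(tmpl, data)
	if err != nil {
		return "", fmt.Errorf("error calculating available tokens: %v", err)
	}

	// 2. Truncate: keep only as much document content as the budget allows.
	//    With TOKEN_LIMIT=0 the content is passed through unchanged.
	truncatedContent, err := truncateContentByTokens(content, availableTokens)
	if err != nil {
		return "", fmt.Errorf("error truncating content: %v", err)
	}

	// 3. Render: execute the template with the (possibly truncated) content.
	var promptBuffer bytes.Buffer
	data["Content"] = truncatedContent
	if err := tmpl.Execute(&promptBuffer, data); err != nil {
		return "", fmt.Errorf("error executing template: %v", err)
	}
	return promptBuffer.String(), nil
}
```

With `TOKEN_LIMIT=0` (the default) both helpers are pass-throughs, so behaviour matches the previous code path apart from the old hard 5000-character cut in `generateDocumentSuggestions`, which this patch removes in favour of token-based truncation.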
diff --git a/app_llm.go b/app_llm.go index d1c6310..aeaf0e4 100644 --- a/app_llm.go +++ b/app_llm.go @@ -23,14 +23,29 @@ func (app *App) getSuggestedCorrespondent(ctx context.Context, content string, s templateMutex.RLock() defer templateMutex.RUnlock() - var promptBuffer bytes.Buffer - err := correspondentTemplate.Execute(&promptBuffer, map[string]interface{}{ + // Get available tokens for content + templateData := map[string]interface{}{ "Language": likelyLanguage, "AvailableCorrespondents": availableCorrespondents, "BlackList": correspondentBlackList, "Title": suggestedTitle, - "Content": content, - }) + } + + availableTokens, err := getAvailableTokensForContent(correspondentTemplate, templateData) + if err != nil { + return "", fmt.Errorf("error calculating available tokens: %v", err) + } + + // Truncate content if needed + truncatedContent, err := truncateContentByTokens(content, availableTokens) + if err != nil { + return "", fmt.Errorf("error truncating content: %v", err) + } + + // Execute template with truncated content + var promptBuffer bytes.Buffer + templateData["Content"] = truncatedContent + err = correspondentTemplate.Execute(&promptBuffer, templateData) if err != nil { return "", fmt.Errorf("error executing correspondent template: %v", err) } @@ -74,14 +89,31 @@ func (app *App) getSuggestedTags( availableTags = removeTagFromList(availableTags, autoTag) availableTags = removeTagFromList(availableTags, autoOcrTag) - var promptBuffer bytes.Buffer - err := tagTemplate.Execute(&promptBuffer, map[string]interface{}{ + // Get available tokens for content + templateData := map[string]interface{}{ "Language": likelyLanguage, "AvailableTags": availableTags, "OriginalTags": originalTags, "Title": suggestedTitle, - "Content": content, - }) + } + + availableTokens, err := getAvailableTokensForContent(tagTemplate, templateData) + if err != nil { + logger.Errorf("Error calculating available tokens: %v", err) + return nil, fmt.Errorf("error calculating available tokens: %v", err) + } + + // Truncate content if needed + truncatedContent, err := truncateContentByTokens(content, availableTokens) + if err != nil { + logger.Errorf("Error truncating content: %v", err) + return nil, fmt.Errorf("error truncating content: %v", err) + } + + // Execute template with truncated content + var promptBuffer bytes.Buffer + templateData["Content"] = truncatedContent + err = tagTemplate.Execute(&promptBuffer, templateData) if err != nil { logger.Errorf("Error executing tag template: %v", err) return nil, fmt.Errorf("error executing tag template: %v", err) @@ -132,7 +164,6 @@ func (app *App) getSuggestedTags( } func (app *App) doOCRViaLLM(ctx context.Context, jpegBytes []byte, logger *logrus.Entry) (string, error) { - templateMutex.RLock() defer templateMutex.RUnlock() likelyLanguage := getLikelyLanguage() @@ -191,24 +222,41 @@ func (app *App) doOCRViaLLM(ctx context.Context, jpegBytes []byte, logger *logru } // getSuggestedTitle generates a suggested title for a document using the LLM -func (app *App) getSuggestedTitle(ctx context.Context, content string, suggestedTitle string, logger *logrus.Entry) (string, error) { +func (app *App) getSuggestedTitle(ctx context.Context, content string, originalTitle string, logger *logrus.Entry) (string, error) { likelyLanguage := getLikelyLanguage() templateMutex.RLock() defer templateMutex.RUnlock() - var promptBuffer bytes.Buffer - err := titleTemplate.Execute(&promptBuffer, map[string]interface{}{ + // Get available tokens for content + templateData := 
map[string]interface{}{ "Language": likelyLanguage, "Content": content, - "Title": suggestedTitle, - }) + "Title": originalTitle, + } + + availableTokens, err := getAvailableTokensForContent(titleTemplate, templateData) + if err != nil { + logger.Errorf("Error calculating available tokens: %v", err) + return "", fmt.Errorf("error calculating available tokens: %v", err) + } + + // Truncate content if needed + truncatedContent, err := truncateContentByTokens(content, availableTokens) + if err != nil { + logger.Errorf("Error truncating content: %v", err) + return "", fmt.Errorf("error truncating content: %v", err) + } + + // Execute template with truncated content + var promptBuffer bytes.Buffer + templateData["Content"] = truncatedContent + err = titleTemplate.Execute(&promptBuffer, templateData) if err != nil { return "", fmt.Errorf("error executing title template: %v", err) } prompt := promptBuffer.String() - logger.Debugf("Title suggestion prompt: %s", prompt) completion, err := app.LLM.GenerateContent(ctx, []llms.MessageContent{ @@ -273,10 +321,6 @@ func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionReque docLogger.Printf("Processing Document ID %d...", documentID) content := doc.Content - if len(content) > 5000 { - content = content[:5000] - } - suggestedTitle := doc.Title var suggestedTags []string var suggestedCorrespondent string @@ -312,7 +356,6 @@ func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionReque log.Errorf("Error generating correspondents for document %d: %v", documentID, err) return } - } mu.Lock() diff --git a/app_llm_test.go b/app_llm_test.go new file mode 100644 index 0000000..953204b --- /dev/null +++ b/app_llm_test.go @@ -0,0 +1,268 @@ +package main + +import ( + "context" + "fmt" + "os" + "testing" + "text/template" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/textsplitter" +) + +// Mock LLM for testing +type mockLLM struct { + lastPrompt string +} + +func (m *mockLLM) CreateEmbedding(_ context.Context, texts []string) ([][]float32, error) { + return nil, nil +} + +func (m *mockLLM) Call(_ context.Context, prompt string, _ ...llms.CallOption) (string, error) { + m.lastPrompt = prompt + return "test response", nil +} + +func (m *mockLLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, opts ...llms.CallOption) (*llms.ContentResponse, error) { + m.lastPrompt = messages[0].Parts[0].(llms.TextContent).Text + return &llms.ContentResponse{ + Choices: []*llms.ContentChoice{ + { + Content: "test response", + }, + }, + }, nil +} + +// Mock templates for testing +const ( + testTitleTemplate = ` +Language: {{.Language}} +Title: {{.Title}} +Content: {{.Content}} +` + testTagTemplate = ` +Language: {{.Language}} +Tags: {{.AvailableTags}} +Content: {{.Content}} +` + testCorrespondentTemplate = ` +Language: {{.Language}} +Content: {{.Content}} +` +) + +func TestPromptTokenLimits(t *testing.T) { + testLogger := logrus.WithField("test", "test") + + // Initialize test templates + var err error + titleTemplate, err = template.New("title").Parse(testTitleTemplate) + require.NoError(t, err) + tagTemplate, err = template.New("tag").Parse(testTagTemplate) + require.NoError(t, err) + correspondentTemplate, err = template.New("correspondent").Parse(testCorrespondentTemplate) + require.NoError(t, err) + + // Save current env and restore after test + originalLimit := os.Getenv("TOKEN_LIMIT") + 
defer os.Setenv("TOKEN_LIMIT", originalLimit) + + // Create a test app with mock LLM + mockLLM := &mockLLM{} + app := &App{ + LLM: mockLLM, + } + + // Set up test template + testTemplate := template.Must(template.New("test").Parse(` +Language: {{.Language}} +Content: {{.Content}} +`)) + + tests := []struct { + name string + tokenLimit int + content string + }{ + { + name: "no limit", + tokenLimit: 0, + content: "This is the original content that should not be truncated.", + }, + { + name: "content within limit", + tokenLimit: 100, + content: "Short content", + }, + { + name: "content exceeds limit", + tokenLimit: 50, + content: "This is a much longer content that should definitely be truncated to fit within token limits", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Set token limit for this test + os.Setenv("TOKEN_LIMIT", fmt.Sprintf("%d", tc.tokenLimit)) + resetTokenLimit() + + // Prepare test data + data := map[string]interface{}{ + "Language": "English", + } + + // Calculate available tokens + availableTokens, err := getAvailableTokensForContent(testTemplate, data) + require.NoError(t, err) + + // Truncate content if needed + truncatedContent, err := truncateContentByTokens(tc.content, availableTokens) + require.NoError(t, err) + + // Test with the app's LLM + ctx := context.Background() + _, err = app.getSuggestedTitle(ctx, truncatedContent, "Test Title", testLogger) + require.NoError(t, err) + + // Verify truncation + if tc.tokenLimit > 0 { + // Count tokens in final prompt received by LLM + splitter := textsplitter.NewTokenSplitter() + tokens, err := splitter.SplitText(mockLLM.lastPrompt) + require.NoError(t, err) + + // Verify prompt is within limits + assert.LessOrEqual(t, len(tokens), tc.tokenLimit, + "Final prompt should be within token limit") + + if len(tc.content) > len(truncatedContent) { + // Content was truncated + t.Logf("Content truncated from %d to %d characters", + len(tc.content), len(truncatedContent)) + } + } else { + // No limit set, content should be unchanged + assert.Contains(t, mockLLM.lastPrompt, tc.content, + "Original content should be in prompt when no limit is set") + } + }) + } +} + +func TestTokenLimitInCorrespondentGeneration(t *testing.T) { + // Save current env and restore after test + originalLimit := os.Getenv("TOKEN_LIMIT") + defer os.Setenv("TOKEN_LIMIT", originalLimit) + + // Create a test app with mock LLM + mockLLM := &mockLLM{} + app := &App{ + LLM: mockLLM, + } + + // Test content that would exceed reasonable token limits + longContent := "This is a very long content that would normally exceed token limits. " + + "It contains multiple sentences and should be truncated appropriately " + + "based on the token limit that we set." 
+ + // Set a small token limit + os.Setenv("TOKEN_LIMIT", "50") + resetTokenLimit() + + // Call getSuggestedCorrespondent + ctx := context.Background() + availableCorrespondents := []string{"Test Corp", "Example Inc"} + correspondentBlackList := []string{"Blocked Corp"} + + _, err := app.getSuggestedCorrespondent(ctx, longContent, "Test Title", availableCorrespondents, correspondentBlackList) + require.NoError(t, err) + + // Verify the final prompt size + splitter := textsplitter.NewTokenSplitter() + tokens, err := splitter.SplitText(mockLLM.lastPrompt) + require.NoError(t, err) + + // Final prompt should be within token limit + assert.LessOrEqual(t, len(tokens), 50, "Final prompt should be within token limit") +} + +func TestTokenLimitInTagGeneration(t *testing.T) { + testLogger := logrus.WithField("test", "test") + + // Save current env and restore after test + originalLimit := os.Getenv("TOKEN_LIMIT") + defer os.Setenv("TOKEN_LIMIT", originalLimit) + + // Create a test app with mock LLM + mockLLM := &mockLLM{} + app := &App{ + LLM: mockLLM, + } + + // Test content that would exceed reasonable token limits + longContent := "This is a very long content that would normally exceed token limits. " + + "It contains multiple sentences and should be truncated appropriately." + + // Set a small token limit + os.Setenv("TOKEN_LIMIT", "50") + resetTokenLimit() + + // Call getSuggestedTags + ctx := context.Background() + availableTags := []string{"test", "example"} + originalTags := []string{"original"} + + _, err := app.getSuggestedTags(ctx, longContent, "Test Title", availableTags, originalTags, testLogger) + require.NoError(t, err) + + // Verify the final prompt size + splitter := textsplitter.NewTokenSplitter() + tokens, err := splitter.SplitText(mockLLM.lastPrompt) + require.NoError(t, err) + + // Final prompt should be within token limit + assert.LessOrEqual(t, len(tokens), 50, "Final prompt should be within token limit") +} + +func TestTokenLimitInTitleGeneration(t *testing.T) { + testLogger := logrus.WithField("test", "test") + + // Save current env and restore after test + originalLimit := os.Getenv("TOKEN_LIMIT") + defer os.Setenv("TOKEN_LIMIT", originalLimit) + + // Create a test app with mock LLM + mockLLM := &mockLLM{} + app := &App{ + LLM: mockLLM, + } + + // Test content that would exceed reasonable token limits + longContent := "This is a very long content that would normally exceed token limits. " + + "It contains multiple sentences and should be truncated appropriately." 
+ + // Set a small token limit + os.Setenv("TOKEN_LIMIT", "50") + resetTokenLimit() + + // Call getSuggestedTitle + ctx := context.Background() + + _, err := app.getSuggestedTitle(ctx, longContent, "Original Title", testLogger) + require.NoError(t, err) + + // Verify the final prompt size + splitter := textsplitter.NewTokenSplitter() + tokens, err := splitter.SplitText(mockLLM.lastPrompt) + require.NoError(t, err) + + // Final prompt should be within token limit + assert.LessOrEqual(t, len(tokens), 50, "Final prompt should be within token limit") +} diff --git a/go.mod b/go.mod index b4b0981..4162124 100644 --- a/go.mod +++ b/go.mod @@ -57,6 +57,11 @@ require ( github.com/spf13/cast v1.7.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect + gitlab.com/golang-commonmark/html v0.0.0-20191124015941-a22733972181 // indirect + gitlab.com/golang-commonmark/linkify v0.0.0-20191026162114-a0c2df6c8f82 // indirect + gitlab.com/golang-commonmark/markdown v0.0.0-20211110145824-bf3e522c626a // indirect + gitlab.com/golang-commonmark/mdurl v0.0.0-20191124015652-932350d1cb84 // indirect + gitlab.com/golang-commonmark/puny v0.0.0-20191124015043-9f83538fa04f // indirect golang.org/x/arch v0.8.0 // indirect golang.org/x/crypto v0.26.0 // indirect golang.org/x/net v0.25.0 // indirect diff --git a/go.sum b/go.sum index 0e1f847..6d3c822 100644 --- a/go.sum +++ b/go.sum @@ -98,6 +98,7 @@ github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAc github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= @@ -131,6 +132,17 @@ github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2 github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +gitlab.com/golang-commonmark/html v0.0.0-20191124015941-a22733972181 h1:K+bMSIx9A7mLES1rtG+qKduLIXq40DAzYHtb0XuCukA= +gitlab.com/golang-commonmark/html v0.0.0-20191124015941-a22733972181/go.mod h1:dzYhVIwWCtzPAa4QP98wfB9+mzt33MSmM8wsKiMi2ow= +gitlab.com/golang-commonmark/linkify v0.0.0-20191026162114-a0c2df6c8f82 h1:oYrL81N608MLZhma3ruL8qTM4xcpYECGut8KSxRY59g= +gitlab.com/golang-commonmark/linkify v0.0.0-20191026162114-a0c2df6c8f82/go.mod h1:Gn+LZmCrhPECMD3SOKlE+BOHwhOYD9j7WT9NUtkCrC8= +gitlab.com/golang-commonmark/markdown v0.0.0-20211110145824-bf3e522c626a h1:O85GKETcmnCNAfv4Aym9tepU8OE0NmcZNqPlXcsBKBs= +gitlab.com/golang-commonmark/markdown v0.0.0-20211110145824-bf3e522c626a/go.mod h1:LaSIs30YPGs1H5jwGgPhLzc8vkNc/k0rDX/fEZqiU/M= +gitlab.com/golang-commonmark/mdurl v0.0.0-20191124015652-932350d1cb84 h1:qqjvoVXdWIcZCLPMlzgA7P9FZWdPGPvP/l3ef8GzV6o= +gitlab.com/golang-commonmark/mdurl v0.0.0-20191124015652-932350d1cb84/go.mod h1:IJZ+fdMvbW2qW6htJx7sLJ04FEs4Ldl/MDsJtMKywfw= +gitlab.com/golang-commonmark/puny 
v0.0.0-20191124015043-9f83538fa04f h1:Wku8eEdeJqIOFHtrfkYUByc4bCaTeA6fL0UJgfEiFMI= +gitlab.com/golang-commonmark/puny v0.0.0-20191124015043-9f83538fa04f/go.mod h1:Tiuhl+njh/JIg0uS/sOJVYi0x2HEa5rc1OAaVsb5tAs= +gitlab.com/opennota/wd v0.0.0-20180912061657-c5d65f63c638/go.mod h1:EGRJaqe2eO9XGmFtQCvV3Lm9NLico3UhFwUpCG/+mVU= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= @@ -174,6 +186,7 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= diff --git a/main.go b/main.go index ebfacbc..448757e 100644 --- a/main.go +++ b/main.go @@ -50,6 +50,7 @@ var ( autoGenerateTags = os.Getenv("AUTO_GENERATE_TAGS") autoGenerateCorrespondents = os.Getenv("AUTO_GENERATE_CORRESPONDENTS") limitOcrPages int // Will be read from OCR_LIMIT_PAGES + tokenLimit = 0 // Will be read from TOKEN_LIMIT // Templates titleTemplate *template.Template @@ -382,6 +383,13 @@ func validateOrDefaultEnvVars() { } } } + + // Initialize token limit from environment variable + if limit := os.Getenv("TOKEN_LIMIT"); limit != "" { + if parsed, err := strconv.Atoi(limit); err == nil { + tokenLimit = parsed + } + } } // documentLogger creates a logger with document context diff --git a/tokens.go b/tokens.go new file mode 100644 index 0000000..4b6eed2 --- /dev/null +++ b/tokens.go @@ -0,0 +1,73 @@ +package main + +import ( + "bytes" + "fmt" + "text/template" + + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/textsplitter" +) + +// getAvailableTokensForContent calculates how many tokens are available for content +// by rendering the template with empty content and counting tokens +func getAvailableTokensForContent(template *template.Template, data map[string]interface{}) (int, error) { + if tokenLimit <= 0 { + return 0, nil // No limit when disabled + } + + // Create a copy of data and set Content to empty + templateData := make(map[string]interface{}) + for k, v := range data { + templateData[k] = v + } + templateData["Content"] = "" + + // Execute template with empty content + var promptBuffer bytes.Buffer + if err := template.Execute(&promptBuffer, templateData); err != nil { + return 0, fmt.Errorf("error executing template: %v", err) + } + + // Count tokens in prompt template + promptTokens := getTokenCount(promptBuffer.String()) + log.Debugf("Prompt template uses %d tokens", promptTokens) + + // Add safety margin for prompt tokens + promptTokens += 10 + + // Calculate available tokens for content + availableTokens := tokenLimit - promptTokens + if availableTokens < 0 { + return 0, fmt.Errorf("prompt template exceeds token limit") + } + return availableTokens, nil +} + +func getTokenCount(content string) int { + return llms.CountTokens(llmModel, content) +} + +// truncateContentByTokens truncates the 
content to fit within the specified token limit +func truncateContentByTokens(content string, availableTokens int) (string, error) { + if availableTokens <= 0 || tokenLimit <= 0 { + return content, nil + } + tokenCount := getTokenCount(content) + if tokenCount <= availableTokens { + return content, nil + } + + splitter := textsplitter.NewTokenSplitter( + textsplitter.WithChunkSize(availableTokens), + textsplitter.WithChunkOverlap(0), + // textsplitter.WithModelName(llmModel), + ) + chunks, err := splitter.SplitText(content) + if err != nil { + return "", fmt.Errorf("error splitting content: %v", err) + } + + // return the first chunk + return chunks[0], nil +} diff --git a/tokens_test.go b/tokens_test.go new file mode 100644 index 0000000..03e7d62 --- /dev/null +++ b/tokens_test.go @@ -0,0 +1,302 @@ +package main + +import ( + "bytes" + "fmt" + "os" + "strconv" + "testing" + "text/template" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tmc/langchaingo/textsplitter" +) + +// resetTokenLimit parses TOKEN_LIMIT from environment and sets the tokenLimit variable +func resetTokenLimit() { + // Reset tokenLimit + tokenLimit = 0 + // Parse from environment + if limit := os.Getenv("TOKEN_LIMIT"); limit != "" { + if parsed, err := strconv.Atoi(limit); err == nil { + tokenLimit = parsed + } + } +} + +func TestTokenLimit(t *testing.T) { + // Save current env and restore after test + originalLimit := os.Getenv("TOKEN_LIMIT") + defer os.Setenv("TOKEN_LIMIT", originalLimit) + + tests := []struct { + name string + envValue string + wantLimit int + }{ + { + name: "empty value", + envValue: "", + wantLimit: 0, + }, + { + name: "zero value", + envValue: "0", + wantLimit: 0, + }, + { + name: "positive value", + envValue: "1000", + wantLimit: 1000, + }, + { + name: "invalid value", + envValue: "not-a-number", + wantLimit: 0, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Set environment variable + os.Setenv("TOKEN_LIMIT", tc.envValue) + + // Set tokenLimit based on environment + resetTokenLimit() + + assert.Equal(t, tc.wantLimit, tokenLimit) + }) + } +} + +func TestGetAvailableTokensForContent(t *testing.T) { + // Save current env and restore after test + originalLimit := os.Getenv("TOKEN_LIMIT") + defer os.Setenv("TOKEN_LIMIT", originalLimit) + + // Test template + tmpl := template.Must(template.New("test").Parse("Template with {{.Var1}} and {{.Content}}")) + + tests := []struct { + name string + limit int + data map[string]interface{} + wantCount int + wantErr bool + }{ + { + name: "disabled token limit", + limit: 0, + data: map[string]interface{}{"Var1": "test"}, + wantCount: 0, + wantErr: false, + }, + { + name: "template exceeds limit", + limit: 2, + data: map[string]interface{}{ + "Var1": "test", + }, + wantCount: 0, + wantErr: true, + }, + { + name: "available tokens calculation", + limit: 100, + data: map[string]interface{}{ + "Var1": "test", + }, + wantCount: 85, + wantErr: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Set token limit + os.Setenv("TOKEN_LIMIT", fmt.Sprintf("%d", tc.limit)) + // Set tokenLimit based on environment + resetTokenLimit() + + count, err := getAvailableTokensForContent(tmpl, tc.data) + + if tc.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.wantCount, count) + } + }) + } +} + +func TestTruncateContentByTokens(t *testing.T) { + // Save current env and restore after test + originalLimit := 
os.Getenv("TOKEN_LIMIT") + defer os.Setenv("TOKEN_LIMIT", originalLimit) + + // Set a token limit for testing + os.Setenv("TOKEN_LIMIT", "100") + // Set tokenLimit based on environment + resetTokenLimit() + + tests := []struct { + name string + content string + availableTokens int + wantTruncated bool + wantErr bool + }{ + { + name: "no truncation needed", + content: "short content", + availableTokens: 20, + wantTruncated: false, + wantErr: false, + }, + { + name: "disabled by token limit", + content: "any content", + availableTokens: 0, + wantTruncated: false, + wantErr: false, + }, + { + name: "truncation needed", + content: "This is a much longer content that will definitely need to be truncated because it exceeds the available tokens", + availableTokens: 10, + wantTruncated: true, + wantErr: false, + }, + { + name: "empty content", + content: "", + availableTokens: 10, + wantTruncated: false, + wantErr: false, + }, + { + name: "exact token count", + content: "one two three four five", + availableTokens: 5, + wantTruncated: false, + wantErr: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := truncateContentByTokens(tc.content, tc.availableTokens) + + if tc.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + + if tc.wantTruncated { + assert.True(t, len(result) < len(tc.content), "Content should be truncated") + } else { + assert.Equal(t, tc.content, result, "Content should not be truncated") + } + }) + } +} + +func TestTokenLimitIntegration(t *testing.T) { + // Save current env and restore after test + originalLimit := os.Getenv("TOKEN_LIMIT") + defer os.Setenv("TOKEN_LIMIT", originalLimit) + + // Create a test template + tmpl := template.Must(template.New("test").Parse(` +Template with variables: +Language: {{.Language}} +Title: {{.Title}} +Content: {{.Content}} +`)) + + // Test data + data := map[string]interface{}{ + "Language": "English", + "Title": "Test Document", + } + + // Test with different token limits + tests := []struct { + name string + limit int + content string + wantSize int + wantError bool + }{ + { + name: "no limit", + limit: 0, + content: "original content", + wantSize: len("original content"), + wantError: false, + }, + { + name: "sufficient limit", + limit: 1000, + content: "original content", + wantSize: len("original content"), + wantError: false, + }, + { + name: "tight limit", + limit: 50, + content: "This is a long content that should be truncated to fit within the token limit", + wantSize: 50, + wantError: false, + }, + { + name: "very small limit", + limit: 3, + content: "Content too large for small limit", + wantError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Set token limit + os.Setenv("TOKEN_LIMIT", fmt.Sprintf("%d", tc.limit)) + // Set tokenLimit based on environment + resetTokenLimit() + + // First get available tokens + availableTokens, err := getAvailableTokensForContent(tmpl, data) + if tc.wantError { + require.Error(t, err) + return + } + require.NoError(t, err) + + // Then truncate content + truncated, err := truncateContentByTokens(tc.content, availableTokens) + require.NoError(t, err) + + // Finally execute template with truncated content + data["Content"] = truncated + var result string + { + var buf bytes.Buffer + err = tmpl.Execute(&buf, data) + require.NoError(t, err) + result = buf.String() + } + + // Verify final size is within limit if limit is enabled + if tc.limit > 0 { + splitter := textsplitter.NewTokenSplitter() 
+ tokens, err := splitter.SplitText(result) + require.NoError(t, err) + assert.LessOrEqual(t, len(tokens), tc.limit) + } + }) + } +}
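A quick way to sanity-check a candidate `TOKEN_LIMIT` against your own prompt templates is a small smoke test in the same package. The sketch below is hypothetical (not part of the patch); it assumes the helpers and the `tokenLimit` variable introduced above, and a representative document body supplied by the reader.

```go
package main

import (
	"testing"
	"text/template"
)

// TestTokenBudgetSmoke is a hypothetical smoke test (not part of this patch) that
// checks a candidate TOKEN_LIMIT leaves room for content once the prompt template
// is accounted for, and that truncation does not grow the content.
func TestTokenBudgetSmoke(t *testing.T) {
	tokenLimit = 2000 // candidate value; normally parsed from the TOKEN_LIMIT env var

	tmpl := template.Must(template.New("smoke").Parse(
		"Language: {{.Language}}\nTitle: {{.Title}}\nContent: {{.Content}}\n"))
	data := map[string]interface{}{"Language": "English", "Title": "Sample"}

	// Budget left for content after the template itself is rendered.
	available, err := getAvailableTokensForContent(tmpl, data)
	if err != nil {
		t.Fatalf("template alone exceeds the limit: %v", err)
	}

	longContent := "replace this with a representative document body"
	truncated, err := truncateContentByTokens(longContent, available)
	if err != nil {
		t.Fatalf("truncation failed: %v", err)
	}
	if len(truncated) > len(longContent) {
		t.Errorf("truncated content should not be longer than the original")
	}
	t.Logf("budget %d tokens, truncated content uses ~%d tokens", available, getTokenCount(truncated))
}
```

Note that the chunk size used by the text splitter and the count returned by `getTokenCount` come from different tokenizer configurations (the `WithModelName` option is left commented out in `tokens.go`), so treat the logged numbers as approximate rather than exact.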