Mirror of https://github.com/icereed/paperless-gpt.git (synced 2025-03-12 21:08:00 -05:00)

Commit b0737aab50 (parent 8b6041a93f)

feat: add TOKEN_LIMIT environment variable for controlling maximum tokens in prompts

8 changed files with 758 additions and 20 deletions
README.md (26 changes)

```diff
@@ -175,6 +175,7 @@ services:
 | `AUTO_GENERATE_TAGS` | Generate tags automatically if `paperless-gpt-auto` is used. Default: `true`. | No |
 | `AUTO_GENERATE_CORRESPONDENTS` | Generate correspondents automatically if `paperless-gpt-auto` is used. Default: `true`. | No |
 | `OCR_LIMIT_PAGES` | Limit the number of pages for OCR. Set to `0` for no limit. Default: `5`. | No |
+| `TOKEN_LIMIT` | Maximum tokens allowed for prompts/content. Set to `0` to disable the limit. Useful for smaller LLMs. | No |
 | `CORRESPONDENT_BLACK_LIST` | A comma-separated list of names to exclude from the correspondent suggestions. Example: `John Doe, Jane Smith`. | No |

 ### Custom Prompt Templates
@@ -417,6 +418,31 @@ P.O. Box 94515
 ---

+## Troubleshooting
+
+### Working with Local LLMs
+
+When using local LLMs (like those running through Ollama), you might need to adjust certain settings to optimize performance:
+
+#### Token Management
+
+- Use the `TOKEN_LIMIT` environment variable to control the maximum number of tokens sent to the LLM.
+- Smaller models might truncate content unexpectedly if given too much text.
+- Start with a conservative limit (e.g., 2000 tokens) and adjust based on your model's capabilities.
+- Set to `0` to disable the limit (use with caution).
+
+Example configuration for smaller models:
+
+```yaml
+environment:
+  TOKEN_LIMIT: '2000' # Adjust based on your model's context window
+  LLM_PROVIDER: 'ollama'
+  LLM_MODEL: 'llama2' # Or another local model
+```
+
+Common issues and solutions:
+
+- If you see truncated or incomplete responses, try lowering `TOKEN_LIMIT`.
+- If processing is too limited, gradually increase the limit while monitoring performance.
+- For models with larger context windows, you can increase the limit or disable it entirely.
+
 ## Contributing

 **Pull requests** and **issues** are welcome!
```
app_llm.go (83 changes)

```diff
@@ -23,14 +23,29 @@ func (app *App) getSuggestedCorrespondent(ctx context.Context, content string, s
 	templateMutex.RLock()
 	defer templateMutex.RUnlock()

-	var promptBuffer bytes.Buffer
-	err := correspondentTemplate.Execute(&promptBuffer, map[string]interface{}{
+	// Get available tokens for content
+	templateData := map[string]interface{}{
 		"Language":                likelyLanguage,
 		"AvailableCorrespondents": availableCorrespondents,
 		"BlackList":               correspondentBlackList,
 		"Title":                   suggestedTitle,
-		"Content":                 content,
-	})
+	}
+
+	availableTokens, err := getAvailableTokensForContent(correspondentTemplate, templateData)
+	if err != nil {
+		return "", fmt.Errorf("error calculating available tokens: %v", err)
+	}
+
+	// Truncate content if needed
+	truncatedContent, err := truncateContentByTokens(content, availableTokens)
+	if err != nil {
+		return "", fmt.Errorf("error truncating content: %v", err)
+	}
+
+	// Execute template with truncated content
+	var promptBuffer bytes.Buffer
+	templateData["Content"] = truncatedContent
+	err = correspondentTemplate.Execute(&promptBuffer, templateData)
 	if err != nil {
 		return "", fmt.Errorf("error executing correspondent template: %v", err)
 	}
@@ -74,14 +89,31 @@ func (app *App) getSuggestedTags(
 	availableTags = removeTagFromList(availableTags, autoTag)
 	availableTags = removeTagFromList(availableTags, autoOcrTag)

-	var promptBuffer bytes.Buffer
-	err := tagTemplate.Execute(&promptBuffer, map[string]interface{}{
+	// Get available tokens for content
+	templateData := map[string]interface{}{
 		"Language":      likelyLanguage,
 		"AvailableTags": availableTags,
 		"OriginalTags":  originalTags,
 		"Title":         suggestedTitle,
-		"Content":       content,
-	})
+	}
+
+	availableTokens, err := getAvailableTokensForContent(tagTemplate, templateData)
+	if err != nil {
+		logger.Errorf("Error calculating available tokens: %v", err)
+		return nil, fmt.Errorf("error calculating available tokens: %v", err)
+	}
+
+	// Truncate content if needed
+	truncatedContent, err := truncateContentByTokens(content, availableTokens)
+	if err != nil {
+		logger.Errorf("Error truncating content: %v", err)
+		return nil, fmt.Errorf("error truncating content: %v", err)
+	}
+
+	// Execute template with truncated content
+	var promptBuffer bytes.Buffer
+	templateData["Content"] = truncatedContent
+	err = tagTemplate.Execute(&promptBuffer, templateData)
 	if err != nil {
 		logger.Errorf("Error executing tag template: %v", err)
 		return nil, fmt.Errorf("error executing tag template: %v", err)
@@ -132,7 +164,6 @@ func (app *App) getSuggestedTags(
 }

 func (app *App) doOCRViaLLM(ctx context.Context, jpegBytes []byte, logger *logrus.Entry) (string, error) {
-
 	templateMutex.RLock()
 	defer templateMutex.RUnlock()
 	likelyLanguage := getLikelyLanguage()
@@ -191,24 +222,41 @@ func (app *App) doOCRViaLLM(ctx context.Context, jpegBytes []byte, logger *logru
 }

 // getSuggestedTitle generates a suggested title for a document using the LLM
-func (app *App) getSuggestedTitle(ctx context.Context, content string, suggestedTitle string, logger *logrus.Entry) (string, error) {
+func (app *App) getSuggestedTitle(ctx context.Context, content string, originalTitle string, logger *logrus.Entry) (string, error) {
 	likelyLanguage := getLikelyLanguage()

 	templateMutex.RLock()
 	defer templateMutex.RUnlock()

-	var promptBuffer bytes.Buffer
-	err := titleTemplate.Execute(&promptBuffer, map[string]interface{}{
+	// Get available tokens for content
+	templateData := map[string]interface{}{
 		"Language": likelyLanguage,
 		"Content":  content,
-		"Title":    suggestedTitle,
-	})
+		"Title":    originalTitle,
+	}
+
+	availableTokens, err := getAvailableTokensForContent(titleTemplate, templateData)
+	if err != nil {
+		logger.Errorf("Error calculating available tokens: %v", err)
+		return "", fmt.Errorf("error calculating available tokens: %v", err)
+	}
+
+	// Truncate content if needed
+	truncatedContent, err := truncateContentByTokens(content, availableTokens)
+	if err != nil {
+		logger.Errorf("Error truncating content: %v", err)
+		return "", fmt.Errorf("error truncating content: %v", err)
+	}
+
+	// Execute template with truncated content
+	var promptBuffer bytes.Buffer
+	templateData["Content"] = truncatedContent
+	err = titleTemplate.Execute(&promptBuffer, templateData)
 	if err != nil {
 		return "", fmt.Errorf("error executing title template: %v", err)
 	}

 	prompt := promptBuffer.String()

 	logger.Debugf("Title suggestion prompt: %s", prompt)

 	completion, err := app.LLM.GenerateContent(ctx, []llms.MessageContent{
@@ -273,10 +321,6 @@ func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionReque
 			docLogger.Printf("Processing Document ID %d...", documentID)

 			content := doc.Content
-			if len(content) > 5000 {
-				content = content[:5000]
-			}
-
 			suggestedTitle := doc.Title
 			var suggestedTags []string
 			var suggestedCorrespondent string
@@ -312,7 +356,6 @@ func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionReque
 					log.Errorf("Error generating correspondents for document %d: %v", documentID, err)
 					return
 				}
-
 			}

 			mu.Lock()
```
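All three suggestion paths now repeat the same three-step sequence: budget tokens for the static template, truncate the document content to fit, then render the final prompt. Below is a minimal sketch of that shared pattern consolidated into one helper; `buildPrompt` is illustrative and not part of the commit, and it assumes the `getAvailableTokensForContent` and `truncateContentByTokens` helpers introduced in tokens.go further down.

```go
package main

import (
	"bytes"
	"fmt"
	"text/template"
)

// buildPrompt is a hypothetical consolidation of the pattern repeated in
// getSuggestedCorrespondent, getSuggestedTags, and getSuggestedTitle.
func buildPrompt(tmpl *template.Template, data map[string]interface{}, content string) (string, error) {
	// 1. Measure the template with empty content to see how many tokens
	//    remain for the document itself.
	availableTokens, err := getAvailableTokensForContent(tmpl, data)
	if err != nil {
		return "", fmt.Errorf("error calculating available tokens: %v", err)
	}

	// 2. Cut the document down to the remaining budget
	//    (a no-op when TOKEN_LIMIT is 0).
	truncated, err := truncateContentByTokens(content, availableTokens)
	if err != nil {
		return "", fmt.Errorf("error truncating content: %v", err)
	}

	// 3. Render the final prompt with the truncated content.
	data["Content"] = truncated
	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, data); err != nil {
		return "", fmt.Errorf("error executing template: %v", err)
	}
	return buf.String(), nil
}
```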
app_llm_test.go (new file, 268 lines)

```go
package main

import (
	"context"
	"fmt"
	"os"
	"testing"
	"text/template"

	"github.com/sirupsen/logrus"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/textsplitter"
)

// Mock LLM for testing
type mockLLM struct {
	lastPrompt string
}

func (m *mockLLM) CreateEmbedding(_ context.Context, texts []string) ([][]float32, error) {
	return nil, nil
}

func (m *mockLLM) Call(_ context.Context, prompt string, _ ...llms.CallOption) (string, error) {
	m.lastPrompt = prompt
	return "test response", nil
}

func (m *mockLLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, opts ...llms.CallOption) (*llms.ContentResponse, error) {
	m.lastPrompt = messages[0].Parts[0].(llms.TextContent).Text
	return &llms.ContentResponse{
		Choices: []*llms.ContentChoice{
			{
				Content: "test response",
			},
		},
	}, nil
}

// Mock templates for testing
const (
	testTitleTemplate = `
Language: {{.Language}}
Title: {{.Title}}
Content: {{.Content}}
`
	testTagTemplate = `
Language: {{.Language}}
Tags: {{.AvailableTags}}
Content: {{.Content}}
`
	testCorrespondentTemplate = `
Language: {{.Language}}
Content: {{.Content}}
`
)

func TestPromptTokenLimits(t *testing.T) {
	testLogger := logrus.WithField("test", "test")

	// Initialize test templates
	var err error
	titleTemplate, err = template.New("title").Parse(testTitleTemplate)
	require.NoError(t, err)
	tagTemplate, err = template.New("tag").Parse(testTagTemplate)
	require.NoError(t, err)
	correspondentTemplate, err = template.New("correspondent").Parse(testCorrespondentTemplate)
	require.NoError(t, err)

	// Save current env and restore after test
	originalLimit := os.Getenv("TOKEN_LIMIT")
	defer os.Setenv("TOKEN_LIMIT", originalLimit)

	// Create a test app with mock LLM
	mockLLM := &mockLLM{}
	app := &App{
		LLM: mockLLM,
	}

	// Set up test template
	testTemplate := template.Must(template.New("test").Parse(`
Language: {{.Language}}
Content: {{.Content}}
`))

	tests := []struct {
		name       string
		tokenLimit int
		content    string
	}{
		{
			name:       "no limit",
			tokenLimit: 0,
			content:    "This is the original content that should not be truncated.",
		},
		{
			name:       "content within limit",
			tokenLimit: 100,
			content:    "Short content",
		},
		{
			name:       "content exceeds limit",
			tokenLimit: 50,
			content:    "This is a much longer content that should definitely be truncated to fit within token limits",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Set token limit for this test
			os.Setenv("TOKEN_LIMIT", fmt.Sprintf("%d", tc.tokenLimit))
			resetTokenLimit()

			// Prepare test data
			data := map[string]interface{}{
				"Language": "English",
			}

			// Calculate available tokens
			availableTokens, err := getAvailableTokensForContent(testTemplate, data)
			require.NoError(t, err)

			// Truncate content if needed
			truncatedContent, err := truncateContentByTokens(tc.content, availableTokens)
			require.NoError(t, err)

			// Test with the app's LLM
			ctx := context.Background()
			_, err = app.getSuggestedTitle(ctx, truncatedContent, "Test Title", testLogger)
			require.NoError(t, err)

			// Verify truncation
			if tc.tokenLimit > 0 {
				// Count tokens in final prompt received by LLM
				splitter := textsplitter.NewTokenSplitter()
				tokens, err := splitter.SplitText(mockLLM.lastPrompt)
				require.NoError(t, err)

				// Verify prompt is within limits
				assert.LessOrEqual(t, len(tokens), tc.tokenLimit,
					"Final prompt should be within token limit")

				if len(tc.content) > len(truncatedContent) {
					// Content was truncated
					t.Logf("Content truncated from %d to %d characters",
						len(tc.content), len(truncatedContent))
				}
			} else {
				// No limit set, content should be unchanged
				assert.Contains(t, mockLLM.lastPrompt, tc.content,
					"Original content should be in prompt when no limit is set")
			}
		})
	}
}

func TestTokenLimitInCorrespondentGeneration(t *testing.T) {
	// Save current env and restore after test
	originalLimit := os.Getenv("TOKEN_LIMIT")
	defer os.Setenv("TOKEN_LIMIT", originalLimit)

	// Create a test app with mock LLM
	mockLLM := &mockLLM{}
	app := &App{
		LLM: mockLLM,
	}

	// Test content that would exceed reasonable token limits
	longContent := "This is a very long content that would normally exceed token limits. " +
		"It contains multiple sentences and should be truncated appropriately " +
		"based on the token limit that we set."

	// Set a small token limit
	os.Setenv("TOKEN_LIMIT", "50")
	resetTokenLimit()

	// Call getSuggestedCorrespondent
	ctx := context.Background()
	availableCorrespondents := []string{"Test Corp", "Example Inc"}
	correspondentBlackList := []string{"Blocked Corp"}

	_, err := app.getSuggestedCorrespondent(ctx, longContent, "Test Title", availableCorrespondents, correspondentBlackList)
	require.NoError(t, err)

	// Verify the final prompt size
	splitter := textsplitter.NewTokenSplitter()
	tokens, err := splitter.SplitText(mockLLM.lastPrompt)
	require.NoError(t, err)

	// Final prompt should be within token limit
	assert.LessOrEqual(t, len(tokens), 50, "Final prompt should be within token limit")
}

func TestTokenLimitInTagGeneration(t *testing.T) {
	testLogger := logrus.WithField("test", "test")

	// Save current env and restore after test
	originalLimit := os.Getenv("TOKEN_LIMIT")
	defer os.Setenv("TOKEN_LIMIT", originalLimit)

	// Create a test app with mock LLM
	mockLLM := &mockLLM{}
	app := &App{
		LLM: mockLLM,
	}

	// Test content that would exceed reasonable token limits
	longContent := "This is a very long content that would normally exceed token limits. " +
		"It contains multiple sentences and should be truncated appropriately."

	// Set a small token limit
	os.Setenv("TOKEN_LIMIT", "50")
	resetTokenLimit()

	// Call getSuggestedTags
	ctx := context.Background()
	availableTags := []string{"test", "example"}
	originalTags := []string{"original"}

	_, err := app.getSuggestedTags(ctx, longContent, "Test Title", availableTags, originalTags, testLogger)
	require.NoError(t, err)

	// Verify the final prompt size
	splitter := textsplitter.NewTokenSplitter()
	tokens, err := splitter.SplitText(mockLLM.lastPrompt)
	require.NoError(t, err)

	// Final prompt should be within token limit
	assert.LessOrEqual(t, len(tokens), 50, "Final prompt should be within token limit")
}

func TestTokenLimitInTitleGeneration(t *testing.T) {
	testLogger := logrus.WithField("test", "test")

	// Save current env and restore after test
	originalLimit := os.Getenv("TOKEN_LIMIT")
	defer os.Setenv("TOKEN_LIMIT", originalLimit)

	// Create a test app with mock LLM
	mockLLM := &mockLLM{}
	app := &App{
		LLM: mockLLM,
	}

	// Test content that would exceed reasonable token limits
	longContent := "This is a very long content that would normally exceed token limits. " +
		"It contains multiple sentences and should be truncated appropriately."

	// Set a small token limit
	os.Setenv("TOKEN_LIMIT", "50")
	resetTokenLimit()

	// Call getSuggestedTitle
	ctx := context.Background()

	_, err := app.getSuggestedTitle(ctx, longContent, "Original Title", testLogger)
	require.NoError(t, err)

	// Verify the final prompt size
	splitter := textsplitter.NewTokenSplitter()
	tokens, err := splitter.SplitText(mockLLM.lastPrompt)
	require.NoError(t, err)

	// Final prompt should be within token limit
	assert.LessOrEqual(t, len(tokens), 50, "Final prompt should be within token limit")
}
```
go.mod (5 changes)

```diff
@@ -57,6 +57,11 @@ require (
 	github.com/spf13/cast v1.7.0 // indirect
 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 	github.com/ugorji/go/codec v1.2.12 // indirect
+	gitlab.com/golang-commonmark/html v0.0.0-20191124015941-a22733972181 // indirect
+	gitlab.com/golang-commonmark/linkify v0.0.0-20191026162114-a0c2df6c8f82 // indirect
+	gitlab.com/golang-commonmark/markdown v0.0.0-20211110145824-bf3e522c626a // indirect
+	gitlab.com/golang-commonmark/mdurl v0.0.0-20191124015652-932350d1cb84 // indirect
+	gitlab.com/golang-commonmark/puny v0.0.0-20191124015043-9f83538fa04f // indirect
 	golang.org/x/arch v0.8.0 // indirect
 	golang.org/x/crypto v0.26.0 // indirect
 	golang.org/x/net v0.25.0 // indirect
```
go.sum (13 changes)

```diff
@@ -98,6 +98,7 @@ github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAc
 github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
 github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ=
 github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
 github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
@@ -131,6 +132,17 @@ github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2
 github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
 github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
 github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
+gitlab.com/golang-commonmark/html v0.0.0-20191124015941-a22733972181 h1:K+bMSIx9A7mLES1rtG+qKduLIXq40DAzYHtb0XuCukA=
+gitlab.com/golang-commonmark/html v0.0.0-20191124015941-a22733972181/go.mod h1:dzYhVIwWCtzPAa4QP98wfB9+mzt33MSmM8wsKiMi2ow=
+gitlab.com/golang-commonmark/linkify v0.0.0-20191026162114-a0c2df6c8f82 h1:oYrL81N608MLZhma3ruL8qTM4xcpYECGut8KSxRY59g=
+gitlab.com/golang-commonmark/linkify v0.0.0-20191026162114-a0c2df6c8f82/go.mod h1:Gn+LZmCrhPECMD3SOKlE+BOHwhOYD9j7WT9NUtkCrC8=
+gitlab.com/golang-commonmark/markdown v0.0.0-20211110145824-bf3e522c626a h1:O85GKETcmnCNAfv4Aym9tepU8OE0NmcZNqPlXcsBKBs=
+gitlab.com/golang-commonmark/markdown v0.0.0-20211110145824-bf3e522c626a/go.mod h1:LaSIs30YPGs1H5jwGgPhLzc8vkNc/k0rDX/fEZqiU/M=
+gitlab.com/golang-commonmark/mdurl v0.0.0-20191124015652-932350d1cb84 h1:qqjvoVXdWIcZCLPMlzgA7P9FZWdPGPvP/l3ef8GzV6o=
+gitlab.com/golang-commonmark/mdurl v0.0.0-20191124015652-932350d1cb84/go.mod h1:IJZ+fdMvbW2qW6htJx7sLJ04FEs4Ldl/MDsJtMKywfw=
+gitlab.com/golang-commonmark/puny v0.0.0-20191124015043-9f83538fa04f h1:Wku8eEdeJqIOFHtrfkYUByc4bCaTeA6fL0UJgfEiFMI=
+gitlab.com/golang-commonmark/puny v0.0.0-20191124015043-9f83538fa04f/go.mod h1:Tiuhl+njh/JIg0uS/sOJVYi0x2HEa5rc1OAaVsb5tAs=
+gitlab.com/opennota/wd v0.0.0-20180912061657-c5d65f63c638/go.mod h1:EGRJaqe2eO9XGmFtQCvV3Lm9NLico3UhFwUpCG/+mVU=
 golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
 golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
 golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
@@ -174,6 +186,7 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn
 golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
 golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
 golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
```
main.go (8 changes)

```diff
@@ -50,6 +50,7 @@ var (
 	autoGenerateTags           = os.Getenv("AUTO_GENERATE_TAGS")
 	autoGenerateCorrespondents = os.Getenv("AUTO_GENERATE_CORRESPONDENTS")
 	limitOcrPages              int // Will be read from OCR_LIMIT_PAGES
+	tokenLimit                 = 0 // Will be read from TOKEN_LIMIT

 	// Templates
 	titleTemplate *template.Template
@@ -382,6 +383,13 @@ func validateOrDefaultEnvVars() {
 			}
 		}
 	}
+
+	// Initialize token limit from environment variable
+	if limit := os.Getenv("TOKEN_LIMIT"); limit != "" {
+		if parsed, err := strconv.Atoi(limit); err == nil {
+			tokenLimit = parsed
+		}
+	}
 }

 // documentLogger creates a logger with document context
```
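One subtlety worth noting: a `strconv.Atoi` failure is silently ignored, so an unparsable `TOKEN_LIMIT` leaves `tokenLimit` at `0` and the limit disabled. A standalone sketch of that behavior follows; the `parseTokenLimit` helper is illustrative, not part of the commit.

```go
package main

import (
	"os"
	"strconv"
)

// parseTokenLimit mirrors the initialization above in isolation: an unset,
// empty, or non-numeric TOKEN_LIMIT leaves the limit at 0, i.e. disabled.
// (Hypothetical helper for illustration only.)
func parseTokenLimit() int {
	limit := 0
	if v := os.Getenv("TOKEN_LIMIT"); v != "" {
		if parsed, err := strconv.Atoi(v); err == nil {
			limit = parsed
		}
	}
	return limit // "2000" -> 2000, "not-a-number" -> 0, "" -> 0
}
```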
tokens.go (new file, 73 lines)

```go
package main

import (
	"bytes"
	"fmt"
	"text/template"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/textsplitter"
)

// getAvailableTokensForContent calculates how many tokens are available for content
// by rendering the template with empty content and counting the tokens it uses.
func getAvailableTokensForContent(template *template.Template, data map[string]interface{}) (int, error) {
	if tokenLimit <= 0 {
		return 0, nil // No limit when disabled
	}

	// Create a copy of data and set Content to empty
	templateData := make(map[string]interface{})
	for k, v := range data {
		templateData[k] = v
	}
	templateData["Content"] = ""

	// Execute template with empty content
	var promptBuffer bytes.Buffer
	if err := template.Execute(&promptBuffer, templateData); err != nil {
		return 0, fmt.Errorf("error executing template: %v", err)
	}

	// Count tokens in prompt template
	promptTokens := getTokenCount(promptBuffer.String())
	log.Debugf("Prompt template uses %d tokens", promptTokens)

	// Add safety margin for prompt tokens
	promptTokens += 10

	// Calculate available tokens for content
	availableTokens := tokenLimit - promptTokens
	if availableTokens < 0 {
		return 0, fmt.Errorf("prompt template exceeds token limit")
	}
	return availableTokens, nil
}

func getTokenCount(content string) int {
	return llms.CountTokens(llmModel, content)
}

// truncateContentByTokens truncates the content to fit within the specified token limit
func truncateContentByTokens(content string, availableTokens int) (string, error) {
	if availableTokens <= 0 || tokenLimit <= 0 {
		return content, nil
	}
	tokenCount := getTokenCount(content)
	if tokenCount <= availableTokens {
		return content, nil
	}

	splitter := textsplitter.NewTokenSplitter(
		textsplitter.WithChunkSize(availableTokens),
		textsplitter.WithChunkOverlap(0),
		// textsplitter.WithModelName(llmModel),
	)
	chunks, err := splitter.SplitText(content)
	if err != nil {
		return "", fmt.Errorf("error splitting content: %v", err)
	}

	// Return the first chunk, which fits within the available token budget
	return chunks[0], nil
}
```
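For illustration, here is a short sketch of how the two helpers compose; the `exampleTruncation` function and its template are hypothetical, and with `TOKEN_LIMIT=0` both calls pass the content through unchanged.

```go
package main

import (
	"fmt"
	"text/template"
)

// exampleTruncation (illustrative, not part of the commit): fit a document
// into the token budget left over after the static parts of a prompt template.
func exampleTruncation() error {
	tmpl := template.Must(template.New("example").Parse(
		"Language: {{.Language}}\nContent: {{.Content}}\n"))
	data := map[string]interface{}{"Language": "English"}

	// How many tokens remain for the document once the template is accounted for?
	available, err := getAvailableTokensForContent(tmpl, data)
	if err != nil {
		return err
	}

	// Trim the document to that budget; returns the input unchanged when it
	// already fits or when the limit is disabled.
	doc := "A long document body ..."
	truncated, err := truncateContentByTokens(doc, available)
	if err != nil {
		return err
	}
	fmt.Printf("kept %d of %d characters\n", len(truncated), len(doc))
	return nil
}
```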
tokens_test.go (new file, 302 lines)

```go
package main

import (
	"bytes"
	"fmt"
	"os"
	"strconv"
	"testing"
	"text/template"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/tmc/langchaingo/textsplitter"
)

// resetTokenLimit parses TOKEN_LIMIT from the environment and sets the tokenLimit variable
func resetTokenLimit() {
	// Reset tokenLimit
	tokenLimit = 0
	// Parse from environment
	if limit := os.Getenv("TOKEN_LIMIT"); limit != "" {
		if parsed, err := strconv.Atoi(limit); err == nil {
			tokenLimit = parsed
		}
	}
}

func TestTokenLimit(t *testing.T) {
	// Save current env and restore after test
	originalLimit := os.Getenv("TOKEN_LIMIT")
	defer os.Setenv("TOKEN_LIMIT", originalLimit)

	tests := []struct {
		name      string
		envValue  string
		wantLimit int
	}{
		{
			name:      "empty value",
			envValue:  "",
			wantLimit: 0,
		},
		{
			name:      "zero value",
			envValue:  "0",
			wantLimit: 0,
		},
		{
			name:      "positive value",
			envValue:  "1000",
			wantLimit: 1000,
		},
		{
			name:      "invalid value",
			envValue:  "not-a-number",
			wantLimit: 0,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Set environment variable
			os.Setenv("TOKEN_LIMIT", tc.envValue)

			// Set tokenLimit based on environment
			resetTokenLimit()

			assert.Equal(t, tc.wantLimit, tokenLimit)
		})
	}
}

func TestGetAvailableTokensForContent(t *testing.T) {
	// Save current env and restore after test
	originalLimit := os.Getenv("TOKEN_LIMIT")
	defer os.Setenv("TOKEN_LIMIT", originalLimit)

	// Test template
	tmpl := template.Must(template.New("test").Parse("Template with {{.Var1}} and {{.Content}}"))

	tests := []struct {
		name      string
		limit     int
		data      map[string]interface{}
		wantCount int
		wantErr   bool
	}{
		{
			name:      "disabled token limit",
			limit:     0,
			data:      map[string]interface{}{"Var1": "test"},
			wantCount: 0,
			wantErr:   false,
		},
		{
			name:  "template exceeds limit",
			limit: 2,
			data: map[string]interface{}{
				"Var1": "test",
			},
			wantCount: 0,
			wantErr:   true,
		},
		{
			name:  "available tokens calculation",
			limit: 100,
			data: map[string]interface{}{
				"Var1": "test",
			},
			wantCount: 85,
			wantErr:   false,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Set token limit
			os.Setenv("TOKEN_LIMIT", fmt.Sprintf("%d", tc.limit))
			// Set tokenLimit based on environment
			resetTokenLimit()

			count, err := getAvailableTokensForContent(tmpl, tc.data)

			if tc.wantErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
				assert.Equal(t, tc.wantCount, count)
			}
		})
	}
}

func TestTruncateContentByTokens(t *testing.T) {
	// Save current env and restore after test
	originalLimit := os.Getenv("TOKEN_LIMIT")
	defer os.Setenv("TOKEN_LIMIT", originalLimit)

	// Set a token limit for testing
	os.Setenv("TOKEN_LIMIT", "100")
	// Set tokenLimit based on environment
	resetTokenLimit()

	tests := []struct {
		name            string
		content         string
		availableTokens int
		wantTruncated   bool
		wantErr         bool
	}{
		{
			name:            "no truncation needed",
			content:         "short content",
			availableTokens: 20,
			wantTruncated:   false,
			wantErr:         false,
		},
		{
			name:            "disabled by token limit",
			content:         "any content",
			availableTokens: 0,
			wantTruncated:   false,
			wantErr:         false,
		},
		{
			name:            "truncation needed",
			content:         "This is a much longer content that will definitely need to be truncated because it exceeds the available tokens",
			availableTokens: 10,
			wantTruncated:   true,
			wantErr:         false,
		},
		{
			name:            "empty content",
			content:         "",
			availableTokens: 10,
			wantTruncated:   false,
			wantErr:         false,
		},
		{
			name:            "exact token count",
			content:         "one two three four five",
			availableTokens: 5,
			wantTruncated:   false,
			wantErr:         false,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			result, err := truncateContentByTokens(tc.content, tc.availableTokens)

			if tc.wantErr {
				require.Error(t, err)
				return
			}

			require.NoError(t, err)

			if tc.wantTruncated {
				assert.True(t, len(result) < len(tc.content), "Content should be truncated")
			} else {
				assert.Equal(t, tc.content, result, "Content should not be truncated")
			}
		})
	}
}

func TestTokenLimitIntegration(t *testing.T) {
	// Save current env and restore after test
	originalLimit := os.Getenv("TOKEN_LIMIT")
	defer os.Setenv("TOKEN_LIMIT", originalLimit)

	// Create a test template
	tmpl := template.Must(template.New("test").Parse(`
Template with variables:
Language: {{.Language}}
Title: {{.Title}}
Content: {{.Content}}
`))

	// Test data
	data := map[string]interface{}{
		"Language": "English",
		"Title":    "Test Document",
	}

	// Test with different token limits
	tests := []struct {
		name      string
		limit     int
		content   string
		wantSize  int
		wantError bool
	}{
		{
			name:      "no limit",
			limit:     0,
			content:   "original content",
			wantSize:  len("original content"),
			wantError: false,
		},
		{
			name:      "sufficient limit",
			limit:     1000,
			content:   "original content",
			wantSize:  len("original content"),
			wantError: false,
		},
		{
			name:      "tight limit",
			limit:     50,
			content:   "This is a long content that should be truncated to fit within the token limit",
			wantSize:  50,
			wantError: false,
		},
		{
			name:      "very small limit",
			limit:     3,
			content:   "Content too large for small limit",
			wantError: true,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Set token limit
			os.Setenv("TOKEN_LIMIT", fmt.Sprintf("%d", tc.limit))
			// Set tokenLimit based on environment
			resetTokenLimit()

			// First get available tokens
			availableTokens, err := getAvailableTokensForContent(tmpl, data)
			if tc.wantError {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)

			// Then truncate content
			truncated, err := truncateContentByTokens(tc.content, availableTokens)
			require.NoError(t, err)

			// Finally execute template with truncated content
			data["Content"] = truncated
			var result string
			{
				var buf bytes.Buffer
				err = tmpl.Execute(&buf, data)
				require.NoError(t, err)
				result = buf.String()
			}

			// Verify final size is within limit if limit is enabled
			if tc.limit > 0 {
				splitter := textsplitter.NewTokenSplitter()
				tokens, err := splitter.SplitText(result)
				require.NoError(t, err)
				assert.LessOrEqual(t, len(tokens), tc.limit)
			}
		})
	}
}
```