package main

import (
	"bytes"
	"fmt"
	"os"
	"strconv"
	"testing"
	"text/template"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/tmc/langchaingo/textsplitter"
)

// resetTokenLimit re-derives the package-level tokenLimit from the
// TOKEN_LIMIT environment variable. tokenLimit is first cleared to 0 and is
// only overwritten when the variable is non-empty and parses as an integer,
// so an unset, empty, or malformed value all yield 0.
func resetTokenLimit() {
	tokenLimit = 0

	limit := os.Getenv("TOKEN_LIMIT")
	if limit == "" {
		return
	}
	if parsed, err := strconv.Atoi(limit); err == nil {
		tokenLimit = parsed
	}
}

// setTokenLimitEnv sets TOKEN_LIMIT to value for the lifetime of t and reloads
// tokenLimit from it. Both the environment variable (via t.Setenv) and the
// previous tokenLimit value (via t.Cleanup) are restored when t finishes, so
// tests do not leak state into one another.
func setTokenLimitEnv(t *testing.T, value string) {
	t.Helper()

	previous := tokenLimit
	t.Cleanup(func() { tokenLimit = previous })

	// t.Setenv restores the exact prior state, including "unset", which a
	// Getenv/Setenv pair cannot distinguish from "set to empty".
	t.Setenv("TOKEN_LIMIT", value)
	resetTokenLimit()
}

// TestTokenLimit checks that resetTokenLimit maps TOKEN_LIMIT values onto
// tokenLimit, falling back to 0 for empty and non-numeric input.
func TestTokenLimit(t *testing.T) {
	tests := []struct {
		name      string
		envValue  string
		wantLimit int
	}{
		{name: "empty value", envValue: "", wantLimit: 0},
		{name: "zero value", envValue: "0", wantLimit: 0},
		{name: "positive value", envValue: "1000", wantLimit: 1000},
		{name: "invalid value", envValue: "not-a-number", wantLimit: 0},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			setTokenLimitEnv(t, tc.envValue)

			assert.Equal(t, tc.wantLimit, tokenLimit)
		})
	}
}

// TestGetAvailableTokensForContent checks the value and error returned by
// getAvailableTokensForContent for a fixed template under several token
// limits.
func TestGetAvailableTokensForContent(t *testing.T) {
	tmpl := template.Must(template.New("test").Parse("Template with {{.Var1}} and {{.Content}}"))

	tests := []struct {
		name      string
		limit     int
		data      map[string]interface{}
		wantCount int
		wantErr   bool
	}{
		{
			name:      "disabled token limit",
			limit:     0,
			data:      map[string]interface{}{"Var1": "test"},
			wantCount: -1,
			wantErr:   false,
		},
		{
			name:      "template exceeds limit",
			limit:     2,
			data:      map[string]interface{}{"Var1": "test"},
			wantCount: 0,
			wantErr:   true,
		},
		{
			name:      "available tokens calculation",
			limit:     100,
			data:      map[string]interface{}{"Var1": "test"},
			wantCount: 85,
			wantErr:   false,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			setTokenLimitEnv(t, fmt.Sprintf("%d", tc.limit))

			count, err := getAvailableTokensForContent(tmpl, tc.data)
			if tc.wantErr {
				assert.Error(t, err)
				return
			}

			assert.NoError(t, err)
			assert.Equal(t, tc.wantCount, count)
		})
	}
}

// TestTruncateContentByTokens checks that truncateContentByTokens returns the
// content unchanged when no truncation is expected and a strictly shorter
// string when it is.
func TestTruncateContentByTokens(t *testing.T) {
	// A single limit applies to every subtest, so it is set on the parent.
	setTokenLimitEnv(t, "100")

	tests := []struct {
		name            string
		content         string
		availableTokens int
		wantTruncated   bool
		wantErr         bool
	}{
		{
			name:            "no truncation needed",
			content:         "short content",
			availableTokens: 20,
			wantTruncated:   false,
			wantErr:         false,
		},
		{
			name:            "disabled by token limit",
			content:         "any content",
			availableTokens: -1,
			wantTruncated:   false,
			wantErr:         false,
		},
		{
			name:            "truncation needed",
			content:         "This is a much longer content that will definitely need to be truncated because it exceeds the available tokens",
			availableTokens: 10,
			wantTruncated:   true,
			wantErr:         false,
		},
		{
			name:            "empty content",
			content:         "",
			availableTokens: 10,
			wantTruncated:   false,
			wantErr:         false,
		},
		{
			name:            "exact token count",
			content:         "one two three four five",
			availableTokens: 5,
			wantTruncated:   false,
			wantErr:         false,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			result, err := truncateContentByTokens(tc.content, tc.availableTokens)
			if tc.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)

			if tc.wantTruncated {
				assert.Less(t, len(result), len(tc.content), "Content should be truncated")
				return
			}
			assert.Equal(t, tc.content, result, "Content should not be truncated")
		})
	}
}

// TestTokenLimitIntegration runs the full flow — compute available tokens,
// truncate the content, render the template — and, when a limit is enabled,
// verifies that the rendered output splits into no more chunks than the limit.
func TestTokenLimitIntegration(t *testing.T) {
	tmpl := template.Must(template.New("test").Parse(`
Template with variables:
Language: {{.Language}}
Title: {{.Title}}
Content: {{.Content}}
`))

	// Values shared by every subtest; copied per subtest so that writing
	// "Content" in one case cannot influence the next.
	baseData := map[string]interface{}{
		"Language": "English",
		"Title":    "Test Document",
	}

	tests := []struct {
		name    string
		limit   int
		content string
		// NOTE(review): wantSize is populated below but never asserted by
		// this test; confirm the intended unit before adding a check.
		wantSize  int
		wantError bool
	}{
		{
			name:      "no limit",
			limit:     0,
			content:   "original content",
			wantSize:  len("original content"),
			wantError: false,
		},
		{
			name:      "sufficient limit",
			limit:     1000,
			content:   "original content",
			wantSize:  len("original content"),
			wantError: false,
		},
		{
			name:      "tight limit",
			limit:     50,
			content:   "This is a long content that should be truncated to fit within the token limit",
			wantSize:  50,
			wantError: false,
		},
		{
			name:      "very small limit",
			limit:     3,
			content:   "Content too large for small limit",
			wantError: true,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			setTokenLimitEnv(t, fmt.Sprintf("%d", tc.limit))

			data := make(map[string]interface{}, len(baseData)+1)
			for key, value := range baseData {
				data[key] = value
			}

			// Step 1: how many tokens remain for the content.
			availableTokens, err := getAvailableTokensForContent(tmpl, data)
			if tc.wantError {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)

			// Step 2: cut the content down to that budget.
			truncated, err := truncateContentByTokens(tc.content, availableTokens)
			require.NoError(t, err)

			// Step 3: render the template with the truncated content.
			data["Content"] = truncated
			var buf bytes.Buffer
			require.NoError(t, tmpl.Execute(&buf, data))
			result := buf.String()

			// A limit of 0 means the check is skipped entirely.
			if tc.limit <= 0 {
				return
			}
			splitter := textsplitter.NewTokenSplitter()
			tokens, err := splitter.SplitText(result)
			require.NoError(t, err)
			assert.LessOrEqual(t, len(tokens), tc.limit)
		})
	}
}