paperless-gpt/app_llm.go

package main

import (
	"bytes"
	"context"
	"encoding/base64"
	"fmt"
	"image"
	_ "image/jpeg" // register the JPEG decoder for image.Decode
	"slices"
	"strings"
	"sync"

	"github.com/sirupsen/logrus"
	"github.com/tmc/langchaingo/llms"
)
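
// The suggestion helpers below share one pattern: compute the token budget
// left for document content once the prompt template itself is accounted
// for (the budget is derived from the TOKEN_LIMIT environment variable),
// truncate the content to fit, render the template, and send a single
// human-role message to the LLM. A minimal sketch of that flow, using the
// helpers as they are used below (budget and truncated are illustrative
// names):
//
//	budget, err := getAvailableTokensForContent(tmpl, templateData)
//	truncated, err := truncateContentByTokens(content, budget)
//	templateData["Content"] = truncated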

// getSuggestedCorrespondent generates a suggested correspondent for a document using the LLM
func (app *App) getSuggestedCorrespondent(ctx context.Context, content string, suggestedTitle string, availableCorrespondents []string, correspondentBlackList []string) (string, error) {
	likelyLanguage := getLikelyLanguage()

	templateMutex.RLock()
	defer templateMutex.RUnlock()

	// Get available tokens for content
	templateData := map[string]interface{}{
		"Language":                likelyLanguage,
		"AvailableCorrespondents": availableCorrespondents,
		"BlackList":               correspondentBlackList,
		"Title":                   suggestedTitle,
	}

	availableTokens, err := getAvailableTokensForContent(correspondentTemplate, templateData)
	if err != nil {
		return "", fmt.Errorf("error calculating available tokens: %v", err)
	}

	// Truncate content if needed
	truncatedContent, err := truncateContentByTokens(content, availableTokens)
	if err != nil {
		return "", fmt.Errorf("error truncating content: %v", err)
	}

	// Execute template with truncated content
	var promptBuffer bytes.Buffer
	templateData["Content"] = truncatedContent
	err = correspondentTemplate.Execute(&promptBuffer, templateData)
	if err != nil {
		return "", fmt.Errorf("error executing correspondent template: %v", err)
	}

	prompt := promptBuffer.String()
	log.Debugf("Correspondent suggestion prompt: %s", prompt)

	completion, err := app.LLM.GenerateContent(ctx, []llms.MessageContent{
		{
			Parts: []llms.ContentPart{
				llms.TextContent{
					Text: prompt,
				},
			},
			Role: llms.ChatMessageTypeHuman,
		},
	})
	if err != nil {
		return "", fmt.Errorf("error getting response from LLM: %v", err)
	}
	// Guard against an empty choice list before indexing into it
	if len(completion.Choices) == 0 {
		return "", fmt.Errorf("LLM returned no choices")
	}

	response := strings.TrimSpace(completion.Choices[0].Content)
	return response, nil
}

// getSuggestedTags generates suggested tags for a document using the LLM
func (app *App) getSuggestedTags(
	ctx context.Context,
	content string,
	suggestedTitle string,
	availableTags []string,
	originalTags []string,
	logger *logrus.Entry) ([]string, error) {
	likelyLanguage := getLikelyLanguage()

	templateMutex.RLock()
	defer templateMutex.RUnlock()

	// Remove all paperless-gpt related tags from available tags
	availableTags = removeTagFromList(availableTags, manualTag)
	availableTags = removeTagFromList(availableTags, autoTag)
	availableTags = removeTagFromList(availableTags, autoOcrTag)

	// Get available tokens for content
	templateData := map[string]interface{}{
		"Language":      likelyLanguage,
		"AvailableTags": availableTags,
		"OriginalTags":  originalTags,
		"Title":         suggestedTitle,
	}

	availableTokens, err := getAvailableTokensForContent(tagTemplate, templateData)
	if err != nil {
		logger.Errorf("Error calculating available tokens: %v", err)
		return nil, fmt.Errorf("error calculating available tokens: %v", err)
	}

	// Truncate content if needed
	truncatedContent, err := truncateContentByTokens(content, availableTokens)
	if err != nil {
		logger.Errorf("Error truncating content: %v", err)
		return nil, fmt.Errorf("error truncating content: %v", err)
	}

	// Execute template with truncated content
	var promptBuffer bytes.Buffer
	templateData["Content"] = truncatedContent
	err = tagTemplate.Execute(&promptBuffer, templateData)
	if err != nil {
		logger.Errorf("Error executing tag template: %v", err)
		return nil, fmt.Errorf("error executing tag template: %v", err)
	}

	prompt := promptBuffer.String()
	logger.Debugf("Tag suggestion prompt: %s", prompt)

	completion, err := app.LLM.GenerateContent(ctx, []llms.MessageContent{
		{
			Parts: []llms.ContentPart{
				llms.TextContent{
					Text: prompt,
				},
			},
			Role: llms.ChatMessageTypeHuman,
		},
	})
	if err != nil {
		logger.Errorf("Error getting response from LLM: %v", err)
		return nil, fmt.Errorf("error getting response from LLM: %v", err)
	}
	// Guard against an empty choice list before indexing into it
	if len(completion.Choices) == 0 {
		return nil, fmt.Errorf("LLM returned no choices")
	}
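
	// The model is expected to reply with a comma-separated list of tags;
	// split on commas and trim whitespace around each entry before matching.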
	response := strings.TrimSpace(completion.Choices[0].Content)
	suggestedTags := strings.Split(response, ",")
	for i, tag := range suggestedTags {
		suggestedTags[i] = strings.TrimSpace(tag)
	}

	// Append the original tags to the suggested tags
	suggestedTags = append(suggestedTags, originalTags...)

	// Remove duplicates
	slices.Sort(suggestedTags)
	suggestedTags = slices.Compact(suggestedTags)

	// Filter out tags that are not in the available tags list
	filteredTags := []string{}
	for _, tag := range suggestedTags {
		for _, availableTag := range availableTags {
			if strings.EqualFold(tag, availableTag) {
				filteredTags = append(filteredTags, availableTag)
				break
			}
		}
	}

	return filteredTags, nil
}
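
// doOCRViaLLM extracts the text of a single JPEG page image using the
// configured vision LLM provider.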
func (app *App) doOCRViaLLM(ctx context.Context, jpegBytes []byte, logger *logrus.Entry) (string, error) {
	templateMutex.RLock()
	defer templateMutex.RUnlock()

	likelyLanguage := getLikelyLanguage()

	var promptBuffer bytes.Buffer
	err := ocrTemplate.Execute(&promptBuffer, map[string]interface{}{
		"Language": likelyLanguage,
	})
	if err != nil {
		return "", fmt.Errorf("error executing ocr template: %v", err)
	}
	prompt := promptBuffer.String()

	// Log the image dimensions
	img, _, err := image.Decode(bytes.NewReader(jpegBytes))
	if err != nil {
		return "", fmt.Errorf("error decoding image: %v", err)
	}
	bounds := img.Bounds()
	logger.Debugf("Image dimensions: %dx%d", bounds.Dx(), bounds.Dy())

	// OpenAI expects the image as a base64 data URL (see
	// https://platform.openai.com/docs/guides/vision); other providers take
	// the raw image bytes as a binary part.
	var parts []llms.ContentPart
	if strings.ToLower(visionLlmProvider) != "openai" {
		// Log image size in kilobytes
		logger.Debugf("Image size: %d KB", len(jpegBytes)/1024)
		parts = []llms.ContentPart{
			llms.BinaryPart("image/jpeg", jpegBytes),
			llms.TextPart(prompt),
		}
	} else {
		base64Image := base64.StdEncoding.EncodeToString(jpegBytes)
		// Log encoded image size in kilobytes
		logger.Debugf("Image size: %d KB", len(base64Image)/1024)
		parts = []llms.ContentPart{
			llms.ImageURLPart(fmt.Sprintf("data:image/jpeg;base64,%s", base64Image)),
			llms.TextPart(prompt),
		}
	}
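
	// The data URL above has the form "data:image/jpeg;base64,<payload>".
	// Base64 grows the payload by roughly a third over the raw JPEG bytes,
	// which is why the two branches log different sizes for the same image.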

	// Convert the image to text
	completion, err := app.VisionLLM.GenerateContent(ctx, []llms.MessageContent{
		{
			Parts: parts,
			Role:  llms.ChatMessageTypeHuman,
		},
	})
	if err != nil {
		return "", fmt.Errorf("error getting response from LLM: %v", err)
	}
	// Guard against an empty choice list before indexing into it
	if len(completion.Choices) == 0 {
		return "", fmt.Errorf("LLM returned no choices")
	}

	result := completion.Choices[0].Content
	logger.Debugf("OCR result: %s", result)
	return result, nil
}

// getSuggestedTitle generates a suggested title for a document using the LLM
func (app *App) getSuggestedTitle(ctx context.Context, content string, originalTitle string, logger *logrus.Entry) (string, error) {
	likelyLanguage := getLikelyLanguage()

	templateMutex.RLock()
	defer templateMutex.RUnlock()

	// Get available tokens for content
	templateData := map[string]interface{}{
		"Language": likelyLanguage,
		"Content":  content,
		"Title":    originalTitle,
	}

	availableTokens, err := getAvailableTokensForContent(titleTemplate, templateData)
	if err != nil {
		logger.Errorf("Error calculating available tokens: %v", err)
		return "", fmt.Errorf("error calculating available tokens: %v", err)
	}

	// Truncate content if needed
	truncatedContent, err := truncateContentByTokens(content, availableTokens)
	if err != nil {
		logger.Errorf("Error truncating content: %v", err)
		return "", fmt.Errorf("error truncating content: %v", err)
	}

	// Execute template with truncated content
	var promptBuffer bytes.Buffer
	templateData["Content"] = truncatedContent
	err = titleTemplate.Execute(&promptBuffer, templateData)
	if err != nil {
		return "", fmt.Errorf("error executing title template: %v", err)
	}

	prompt := promptBuffer.String()
	logger.Debugf("Title suggestion prompt: %s", prompt)

	completion, err := app.LLM.GenerateContent(ctx, []llms.MessageContent{
		{
			Parts: []llms.ContentPart{
				llms.TextContent{
					Text: prompt,
				},
			},
			Role: llms.ChatMessageTypeHuman,
		},
	})
	if err != nil {
		return "", fmt.Errorf("error getting response from LLM: %v", err)
	}
	// Guard against an empty choice list before indexing into it
	if len(completion.Choices) == 0 {
		return "", fmt.Errorf("LLM returned no choices")
	}

	// Strip any quotes the model wrapped around the title
	return strings.TrimSpace(strings.Trim(completion.Choices[0].Content, "\"")), nil
}

// generateDocumentSuggestions generates suggestions for a set of documents
func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionRequest GenerateSuggestionsRequest, logger *logrus.Entry) ([]DocumentSuggestion, error) {
	// Fetch all available tags from paperless-ngx
	availableTagsMap, err := app.Client.GetAllTags(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch available tags: %v", err)
	}

	// Prepare a list of tag names
	availableTagNames := make([]string, 0, len(availableTagsMap))
	for tagName := range availableTagsMap {
		if tagName == manualTag {
			continue
		}
		availableTagNames = append(availableTagNames, tagName)
	}

	// Prepare a list of document correspondents
	availableCorrespondentsMap, err := app.Client.GetAllCorrespondents(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch available correspondents: %v", err)
	}

	// Prepare a list of correspondent names
	availableCorrespondentNames := make([]string, 0, len(availableCorrespondentsMap))
	for correspondentName := range availableCorrespondentsMap {
		availableCorrespondentNames = append(availableCorrespondentNames, correspondentName)
	}

	documents := suggestionRequest.Documents
	documentSuggestions := []DocumentSuggestion{}

	var wg sync.WaitGroup
	var mu sync.Mutex
	errorsList := make([]error, 0)
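
	// Process documents concurrently: one goroutine per document, with the
	// mutex guarding both the shared suggestions slice and the error list.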
	for i := range documents {
		wg.Add(1)
		go func(doc Document) {
			defer wg.Done()
			documentID := doc.ID
			docLogger := documentLogger(documentID)
			docLogger.Printf("Processing Document ID %d...", documentID)

			content := doc.Content
			suggestedTitle := doc.Title
			var suggestedTags []string
			var suggestedCorrespondent string
			// Use a goroutine-local error so concurrent workers do not race
			// on the enclosing function's err variable.
			var err error

			if suggestionRequest.GenerateTitles {
				suggestedTitle, err = app.getSuggestedTitle(ctx, content, suggestedTitle, docLogger)
				if err != nil {
					mu.Lock()
					errorsList = append(errorsList, fmt.Errorf("document %d: %v", documentID, err))
					mu.Unlock()
					docLogger.Errorf("Error processing document %d: %v", documentID, err)
					return
				}
			}

			if suggestionRequest.GenerateTags {
				suggestedTags, err = app.getSuggestedTags(ctx, content, suggestedTitle, availableTagNames, doc.Tags, docLogger)
				if err != nil {
					mu.Lock()
					errorsList = append(errorsList, fmt.Errorf("document %d: %v", documentID, err))
					mu.Unlock()
					docLogger.Errorf("Error generating tags for document %d: %v", documentID, err)
					return
				}
			}

			if suggestionRequest.GenerateCorrespondents {
				suggestedCorrespondent, err = app.getSuggestedCorrespondent(ctx, content, suggestedTitle, availableCorrespondentNames, correspondentBlackList)
				if err != nil {
					mu.Lock()
					errorsList = append(errorsList, fmt.Errorf("document %d: %v", documentID, err))
					mu.Unlock()
					docLogger.Errorf("Error generating correspondents for document %d: %v", documentID, err)
					return
				}
			}

			mu.Lock()
			suggestion := DocumentSuggestion{
				ID:               documentID,
				OriginalDocument: doc,
			}
			// Titles
			if suggestionRequest.GenerateTitles {
				docLogger.Printf("Suggested title for document %d: %s", documentID, suggestedTitle)
				suggestion.SuggestedTitle = suggestedTitle
			} else {
				suggestion.SuggestedTitle = doc.Title
			}
			// Tags
			if suggestionRequest.GenerateTags {
				docLogger.Printf("Suggested tags for document %d: %v", documentID, suggestedTags)
				suggestion.SuggestedTags = suggestedTags
			} else {
				suggestion.SuggestedTags = doc.Tags
			}
			// Correspondents
			if suggestionRequest.GenerateCorrespondents {
				docLogger.Printf("Suggested correspondent for document %d: %s", documentID, suggestedCorrespondent)
				suggestion.SuggestedCorrespondent = suggestedCorrespondent
			} else {
				suggestion.SuggestedCorrespondent = ""
			}
			// Tell the caller to remove the paperless-gpt workflow tags
			suggestion.RemoveTags = []string{manualTag, autoTag}

			documentSuggestions = append(documentSuggestions, suggestion)
			mu.Unlock()
			docLogger.Printf("Document %d processed successfully.", documentID)
		}(documents[i])
	}

	wg.Wait()

	if len(errorsList) > 0 {
		return nil, errorsList[0] // Return the first error encountered
	}

	return documentSuggestions, nil
}