paperless-gpt/app_llm.go

package main

import (
	"bytes"
	"context"
	"encoding/base64"
	"fmt"
	"image"
	_ "image/jpeg" // register the JPEG decoder for image.Decode
	"strings"
	"sync"

	"github.com/sirupsen/logrus"
	"github.com/tmc/langchaingo/llms"
)

// getSuggestedTags generates suggested tags for a document using the LLM
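// Only tags that already exist in paperless-ngx (matched case-insensitively
// against availableTags) are returned. Hedged usage sketch; the literal
// values below are made up for illustration:
//
//	tags, err := app.getSuggestedTags(ctx, doc.Content, "2024 Tax Return",
//		[]string{"taxes", "insurance"}, logger)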
func (app *App) getSuggestedTags(
	ctx context.Context,
	content string,
	suggestedTitle string,
	availableTags []string,
	logger *logrus.Entry) ([]string, error) {
	likelyLanguage := getLikelyLanguage()

	templateMutex.RLock()
	defer templateMutex.RUnlock()

	var promptBuffer bytes.Buffer
	err := tagTemplate.Execute(&promptBuffer, map[string]interface{}{
		"Language":      likelyLanguage,
		"AvailableTags": availableTags,
		"Title":         suggestedTitle,
		"Content":       content,
	})
	if err != nil {
		logger.Errorf("Error executing tag template: %v", err)
		return nil, fmt.Errorf("error executing tag template: %v", err)
	}

	prompt := promptBuffer.String()
	logger.Debugf("Tag suggestion prompt: %s", prompt)

	completion, err := app.LLM.GenerateContent(ctx, []llms.MessageContent{
		{
			Parts: []llms.ContentPart{
				llms.TextContent{
					Text: prompt,
				},
			},
			Role: llms.ChatMessageTypeHuman,
		},
	})
	if err != nil {
		logger.Errorf("Error getting response from LLM: %v", err)
		return nil, fmt.Errorf("error getting response from LLM: %v", err)
	}

	response := strings.TrimSpace(completion.Choices[0].Content)
	suggestedTags := strings.Split(response, ",")
	for i, tag := range suggestedTags {
		suggestedTags[i] = strings.TrimSpace(tag)
	}

	// Filter out tags that are not in the available tags list
	filteredTags := []string{}
	for _, tag := range suggestedTags {
		for _, availableTag := range availableTags {
			if strings.EqualFold(tag, availableTag) {
				filteredTags = append(filteredTags, availableTag)
				break
			}
		}
	}

	return filteredTags, nil
}
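
// doOCRViaLLM sends a single JPEG page to the configured vision model and
// returns the recognized text. OpenAI providers receive the image as a
// base64 data URL; other providers receive it as a raw binary part.
// Hedged usage sketch (jpegBytes is assumed to hold one rendered page):
//
//	text, err := app.doOCRViaLLM(ctx, jpegBytes, logger)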
func (app *App) doOCRViaLLM(ctx context.Context, jpegBytes []byte, logger *logrus.Entry) (string, error) {
	templateMutex.RLock()
	defer templateMutex.RUnlock()

	likelyLanguage := getLikelyLanguage()

	var promptBuffer bytes.Buffer
	err := ocrTemplate.Execute(&promptBuffer, map[string]interface{}{
		"Language": likelyLanguage,
	})
	if err != nil {
		return "", fmt.Errorf("error executing OCR template: %v", err)
	}
	prompt := promptBuffer.String()

	// Log the image dimensions
	img, _, err := image.Decode(bytes.NewReader(jpegBytes))
	if err != nil {
		return "", fmt.Errorf("error decoding image: %v", err)
	}
	bounds := img.Bounds()
	logger.Debugf("Image dimensions: %dx%d", bounds.Dx(), bounds.Dy())

	// Non-OpenAI providers receive the image as a binary part; OpenAI expects an
	// ImageURL part carrying a base64 data URL, per
	// https://platform.openai.com/docs/guides/vision
	var parts []llms.ContentPart
	if strings.ToLower(visionLlmProvider) != "openai" {
		// Log image size in kilobytes
		logger.Debugf("Image size: %d KB", len(jpegBytes)/1024)
		parts = []llms.ContentPart{
			llms.BinaryPart("image/jpeg", jpegBytes),
			llms.TextPart(prompt),
		}
	} else {
		base64Image := base64.StdEncoding.EncodeToString(jpegBytes)
		// Log encoded image size in kilobytes
		logger.Debugf("Image size: %d KB", len(base64Image)/1024)
		parts = []llms.ContentPart{
			llms.ImageURLPart(fmt.Sprintf("data:image/jpeg;base64,%s", base64Image)),
			llms.TextPart(prompt),
		}
	}

	// Convert the image to text
	completion, err := app.VisionLLM.GenerateContent(ctx, []llms.MessageContent{
		{
			Parts: parts,
			Role:  llms.ChatMessageTypeHuman,
		},
	})
	if err != nil {
		return "", fmt.Errorf("error getting response from LLM: %v", err)
	}

	result := completion.Choices[0].Content
	logger.Debugf("OCR result: %s", result)
	return result, nil
}

// getSuggestedTitle generates a suggested title for a document using the LLM
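// The model's answer is trimmed of whitespace and surrounding quotes before it
// is returned. Hedged usage sketch:
//
//	title, err := app.getSuggestedTitle(ctx, doc.Content, logger)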
func (app *App) getSuggestedTitle(ctx context.Context, content string, logger *logrus.Entry) (string, error) {
	likelyLanguage := getLikelyLanguage()

	templateMutex.RLock()
	defer templateMutex.RUnlock()

	var promptBuffer bytes.Buffer
	err := titleTemplate.Execute(&promptBuffer, map[string]interface{}{
		"Language": likelyLanguage,
		"Content":  content,
	})
	if err != nil {
		return "", fmt.Errorf("error executing title template: %v", err)
	}

	prompt := promptBuffer.String()
	logger.Debugf("Title suggestion prompt: %s", prompt)

	completion, err := app.LLM.GenerateContent(ctx, []llms.MessageContent{
		{
			Parts: []llms.ContentPart{
				llms.TextContent{
					Text: prompt,
				},
			},
			Role: llms.ChatMessageTypeHuman,
		},
	})
	if err != nil {
		return "", fmt.Errorf("error getting response from LLM: %v", err)
	}

	return strings.TrimSpace(strings.Trim(completion.Choices[0].Content, "\"")), nil
}

// generateDocumentSuggestions generates suggestions for a set of documents
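// Each document is processed in its own goroutine; results and errors are
// collected behind a mutex, and the first error (if any) is returned after all
// goroutines finish. Hedged usage sketch (field values are illustrative only):
//
//	suggestions, err := app.generateDocumentSuggestions(ctx, GenerateSuggestionsRequest{
//		Documents:      docs,
//		GenerateTitles: true,
//		GenerateTags:   true,
//	}, logger)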
func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionRequest GenerateSuggestionsRequest, logger *logrus.Entry) ([]DocumentSuggestion, error) {
	// Fetch all available tags from paperless-ngx
	availableTagsMap, err := app.Client.GetAllTags(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch available tags: %v", err)
	}

	// Prepare a list of tag names, excluding the manual-processing marker tag
	availableTagNames := make([]string, 0, len(availableTagsMap))
	for tagName := range availableTagsMap {
		if tagName == manualTag {
			continue
		}
		availableTagNames = append(availableTagNames, tagName)
	}

	documents := suggestionRequest.Documents
	documentSuggestions := []DocumentSuggestion{}

	var wg sync.WaitGroup
	var mu sync.Mutex
	errorsList := make([]error, 0)

	for i := range documents {
		wg.Add(1)
		go func(doc Document) {
			defer wg.Done()
			documentID := doc.ID
			docLogger := documentLogger(documentID)
			docLogger.Printf("Processing Document ID %d...", documentID)

			// Limit the content passed to the LLM to keep the prompt size bounded
			content := doc.Content
			if len(content) > 5000 {
				content = content[:5000]
			}

			var suggestedTitle string
			var suggestedTags []string
			var err error // per-goroutine error; sharing the enclosing function's err across goroutines would be a data race

			if suggestionRequest.GenerateTitles {
				suggestedTitle, err = app.getSuggestedTitle(ctx, content, docLogger)
				if err != nil {
					mu.Lock()
					errorsList = append(errorsList, fmt.Errorf("Document %d: %v", documentID, err))
					mu.Unlock()
					docLogger.Errorf("Error processing document %d: %v", documentID, err)
					return
				}
			}

			if suggestionRequest.GenerateTags {
				suggestedTags, err = app.getSuggestedTags(ctx, content, suggestedTitle, availableTagNames, docLogger)
				if err != nil {
					mu.Lock()
					errorsList = append(errorsList, fmt.Errorf("Document %d: %v", documentID, err))
					mu.Unlock()
					logger.Errorf("Error generating tags for document %d: %v", documentID, err)
					return
				}
			}

			mu.Lock()
			suggestion := DocumentSuggestion{
				ID:               documentID,
				OriginalDocument: doc,
			}
			// Titles
			if suggestionRequest.GenerateTitles {
				docLogger.Printf("Suggested title for document %d: %s", documentID, suggestedTitle)
				suggestion.SuggestedTitle = suggestedTitle
			} else {
				suggestion.SuggestedTitle = doc.Title
			}
			// Tags
			if suggestionRequest.GenerateTags {
				docLogger.Printf("Suggested tags for document %d: %v", documentID, suggestedTags)
				suggestion.SuggestedTags = suggestedTags
			} else {
				suggestion.SuggestedTags = removeTagFromList(doc.Tags, manualTag)
			}
			documentSuggestions = append(documentSuggestions, suggestion)
			mu.Unlock()
			docLogger.Printf("Document %d processed successfully.", documentID)
		}(documents[i])
	}

	wg.Wait()

	if len(errorsList) > 0 {
		return nil, errorsList[0] // Return the first error encountered
	}

	return documentSuggestions, nil
}