feat: improve auto throughput & logging

This commit is contained in:
Jonas Hess 2024-10-31 20:00:43 +01:00
parent def89b5aea
commit 5dadbcb53d
8 changed files with 113 additions and 52 deletions

View file

@ -77,6 +77,7 @@ services:
OLLAMA_HOST: 'http://host.docker.internal:11434' # If using Ollama OLLAMA_HOST: 'http://host.docker.internal:11434' # If using Ollama
VISION_LLM_PROVIDER: 'ollama' # Optional, for OCR VISION_LLM_PROVIDER: 'ollama' # Optional, for OCR
VISION_LLM_MODEL: 'minicpm-v' # Optional, for OCR VISION_LLM_MODEL: 'minicpm-v' # Optional, for OCR
LOG_LEVEL: 'info' # Optional: one of 'debug', 'info', 'warn', 'error' (default 'info')
volumes: volumes:
- ./prompts:/app/prompts # Mount the prompts directory - ./prompts:/app/prompts # Mount the prompts directory
ports: ports:
@ -122,6 +123,7 @@ If you prefer to run the application manually:
-e LLM_LANGUAGE='English' \ -e LLM_LANGUAGE='English' \
-e VISION_LLM_PROVIDER='ollama' \ -e VISION_LLM_PROVIDER='ollama' \
-e VISION_LLM_MODEL='minicpm-v' \ -e VISION_LLM_MODEL='minicpm-v' \
-e LOG_LEVEL='info' \
-v $(pwd)/prompts:/app/prompts \ # Mount the prompts directory -v $(pwd)/prompts:/app/prompts \ # Mount the prompts directory
-p 8080:8080 \ -p 8080:8080 \
paperless-gpt paperless-gpt
@ -142,6 +144,7 @@ If you prefer to run the application manually:
| `OLLAMA_HOST` | The URL of the Ollama server (e.g., `http://host.docker.internal:11434`). Useful if using Ollama. Default is `http://127.0.0.1:11434`. | No | | `OLLAMA_HOST` | The URL of the Ollama server (e.g., `http://host.docker.internal:11434`). Useful if using Ollama. Default is `http://127.0.0.1:11434`. | No |
| `VISION_LLM_PROVIDER` | The vision LLM provider to use for OCR (`openai` or `ollama`). | No | | `VISION_LLM_PROVIDER` | The vision LLM provider to use for OCR (`openai` or `ollama`). | No |
| `VISION_LLM_MODEL` | The model name to use for OCR (e.g., `minicpm-v`). | No | | `VISION_LLM_MODEL` | The model name to use for OCR (e.g., `minicpm-v`). | No |
| `LOG_LEVEL` | The log level for the application (`info`, `debug`, `warn`, `error`). Default is `info`. | No |
**Note:** When using Ollama, ensure that the Ollama server is running and accessible from the paperless-gpt container. **Note:** When using Ollama, ensure that the Ollama server is running and accessible from the paperless-gpt container.

View file

@ -2,7 +2,6 @@ package main
import ( import (
"fmt" "fmt"
"log"
"net/http" "net/http"
"os" "os"
"strconv" "strconv"
@ -59,7 +58,7 @@ func updatePromptsHandler(c *gin.Context) {
titleTemplate = t titleTemplate = t
err = os.WriteFile("prompts/title_prompt.tmpl", []byte(req.TitleTemplate), 0644) err = os.WriteFile("prompts/title_prompt.tmpl", []byte(req.TitleTemplate), 0644)
if err != nil { if err != nil {
log.Printf("Failed to write title_prompt.tmpl: %v", err) log.Errorf("Failed to write title_prompt.tmpl: %v", err)
} }
} }
@ -73,7 +72,7 @@ func updatePromptsHandler(c *gin.Context) {
tagTemplate = t tagTemplate = t
err = os.WriteFile("prompts/tag_prompt.tmpl", []byte(req.TagTemplate), 0644) err = os.WriteFile("prompts/tag_prompt.tmpl", []byte(req.TagTemplate), 0644)
if err != nil { if err != nil {
log.Printf("Failed to write tag_prompt.tmpl: %v", err) log.Errorf("Failed to write tag_prompt.tmpl: %v", err)
} }
} }
@ -87,7 +86,7 @@ func (app *App) getAllTagsHandler(c *gin.Context) {
tags, err := app.Client.GetAllTags(ctx) tags, err := app.Client.GetAllTags(ctx)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error fetching tags: %v", err)}) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error fetching tags: %v", err)})
log.Printf("Error fetching tags: %v", err) log.Errorf("Error fetching tags: %v", err)
return return
} }
@ -101,7 +100,7 @@ func (app *App) documentsHandler(c *gin.Context) {
documents, err := app.Client.GetDocumentsByTags(ctx, []string{manualTag}) documents, err := app.Client.GetDocumentsByTags(ctx, []string{manualTag})
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error fetching documents: %v", err)}) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error fetching documents: %v", err)})
log.Printf("Error fetching documents: %v", err) log.Errorf("Error fetching documents: %v", err)
return return
} }
@ -115,14 +114,14 @@ func (app *App) generateSuggestionsHandler(c *gin.Context) {
var suggestionRequest GenerateSuggestionsRequest var suggestionRequest GenerateSuggestionsRequest
if err := c.ShouldBindJSON(&suggestionRequest); err != nil { if err := c.ShouldBindJSON(&suggestionRequest); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request payload: %v", err)}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request payload: %v", err)})
log.Printf("Invalid request payload: %v", err) log.Errorf("Invalid request payload: %v", err)
return return
} }
results, err := app.generateDocumentSuggestions(ctx, suggestionRequest) results, err := app.generateDocumentSuggestions(ctx, suggestionRequest)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error processing documents: %v", err)}) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error processing documents: %v", err)})
log.Printf("Error processing documents: %v", err) log.Errorf("Error processing documents: %v", err)
return return
} }
@ -135,14 +134,14 @@ func (app *App) updateDocumentsHandler(c *gin.Context) {
var documents []DocumentSuggestion var documents []DocumentSuggestion
if err := c.ShouldBindJSON(&documents); err != nil { if err := c.ShouldBindJSON(&documents); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request payload: %v", err)}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request payload: %v", err)})
log.Printf("Invalid request payload: %v", err) log.Errorf("Invalid request payload: %v", err)
return return
} }
err := app.Client.UpdateDocuments(ctx, documents) err := app.Client.UpdateDocuments(ctx, documents)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error updating documents: %v", err)}) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error updating documents: %v", err)})
log.Printf("Error updating documents: %v", err) log.Errorf("Error updating documents: %v", err)
return return
} }

View file

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"log"
"strings" "strings"
"sync" "sync"
@ -30,7 +29,7 @@ func (app *App) getSuggestedTags(ctx context.Context, content string, suggestedT
} }
prompt := promptBuffer.String() prompt := promptBuffer.String()
log.Printf("Tag suggestion prompt: %s", prompt) log.Debugf("Tag suggestion prompt: %s", prompt)
completion, err := app.LLM.GenerateContent(ctx, []llms.MessageContent{ completion, err := app.LLM.GenerateContent(ctx, []llms.MessageContent{
{ {
@ -119,7 +118,7 @@ func (app *App) getSuggestedTitle(ctx context.Context, content string) (string,
prompt := promptBuffer.String() prompt := promptBuffer.String()
log.Printf("Title suggestion prompt: %s", prompt) log.Debugf("Title suggestion prompt: %s", prompt)
completion, err := app.LLM.GenerateContent(ctx, []llms.MessageContent{ completion, err := app.LLM.GenerateContent(ctx, []llms.MessageContent{
{ {
@ -183,7 +182,7 @@ func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionReque
mu.Lock() mu.Lock()
errorsList = append(errorsList, fmt.Errorf("Document %d: %v", documentID, err)) errorsList = append(errorsList, fmt.Errorf("Document %d: %v", documentID, err))
mu.Unlock() mu.Unlock()
log.Printf("Error processing document %d: %v", documentID, err) log.Errorf("Error processing document %d: %v", documentID, err)
return return
} }
} }
@ -194,7 +193,7 @@ func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionReque
mu.Lock() mu.Lock()
errorsList = append(errorsList, fmt.Errorf("Document %d: %v", documentID, err)) errorsList = append(errorsList, fmt.Errorf("Document %d: %v", documentID, err))
mu.Unlock() mu.Unlock()
log.Printf("Error generating tags for document %d: %v", documentID, err) log.Errorf("Error generating tags for document %d: %v", documentID, err)
return return
} }
} }
@ -206,6 +205,7 @@ func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionReque
} }
// Titles // Titles
if suggestionRequest.GenerateTitles { if suggestionRequest.GenerateTitles {
log.Printf("Suggested title for document %d: %s", documentID, suggestedTitle)
suggestion.SuggestedTitle = suggestedTitle suggestion.SuggestedTitle = suggestedTitle
} else { } else {
suggestion.SuggestedTitle = doc.Title suggestion.SuggestedTitle = doc.Title
@ -213,6 +213,7 @@ func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionReque
// Tags // Tags
if suggestionRequest.GenerateTags { if suggestionRequest.GenerateTags {
log.Printf("Suggested tags for document %d: %v", documentID, suggestedTags)
suggestion.SuggestedTags = suggestedTags suggestion.SuggestedTags = suggestedTags
} else { } else {
suggestion.SuggestedTags = removeTagFromList(doc.Tags, manualTag) suggestion.SuggestedTags = removeTagFromList(doc.Tags, manualTag)

3
go.mod
View file

@ -8,6 +8,8 @@ require (
github.com/Masterminds/sprig/v3 v3.2.3 github.com/Masterminds/sprig/v3 v3.2.3
github.com/gen2brain/go-fitz v1.24.14 github.com/gen2brain/go-fitz v1.24.14
github.com/gin-gonic/gin v1.10.0 github.com/gin-gonic/gin v1.10.0
github.com/google/uuid v1.6.0
github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
github.com/tmc/langchaingo v0.1.12 github.com/tmc/langchaingo v0.1.12
golang.org/x/sync v0.7.0 golang.org/x/sync v0.7.0
@ -29,7 +31,6 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/go-playground/validator/v10 v10.20.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/huandu/xstrings v1.3.3 // indirect github.com/huandu/xstrings v1.3.3 // indirect
github.com/imdario/mergo v0.3.13 // indirect github.com/imdario/mergo v0.3.13 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect

3
go.sum
View file

@ -77,6 +77,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ=
github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng=
github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@ -123,6 +125,7 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

34
jobs.go
View file

@ -3,7 +3,6 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"os" "os"
"sort" "sort"
"strings" "strings"
@ -11,6 +10,7 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/sirupsen/logrus"
) )
// Job represents an OCR job // Job represents an OCR job
@ -31,13 +31,25 @@ type JobStore struct {
} }
var ( var (
logger = logrus.New()
jobStore = &JobStore{ jobStore = &JobStore{
jobs: make(map[string]*Job), jobs: make(map[string]*Job),
} }
jobQueue = make(chan *Job, 100) // Buffered channel with capacity of 100 jobs jobQueue = make(chan *Job, 100) // Buffered channel with capacity of 100 jobs
logger = log.New(os.Stdout, "OCR_JOB: ", log.LstdFlags)
) )
func init() {
// Initialize logger
logger.SetOutput(os.Stdout)
logger.SetFormatter(&logrus.TextFormatter{
FullTimestamp: true,
})
logger.SetLevel(logrus.InfoLevel)
logger.WithField("prefix", "OCR_JOB")
}
func generateJobID() string { func generateJobID() string {
return uuid.New().String() return uuid.New().String()
} }
@ -47,7 +59,7 @@ func (store *JobStore) addJob(job *Job) {
defer store.Unlock() defer store.Unlock()
job.PagesDone = 0 // Initialize PagesDone to 0 job.PagesDone = 0 // Initialize PagesDone to 0
store.jobs[job.ID] = job store.jobs[job.ID] = job
logger.Printf("Job added: %v", job) logger.Infof("Job added: %v", job)
} }
func (store *JobStore) getJob(jobID string) (*Job, bool) { func (store *JobStore) getJob(jobID string) (*Job, bool) {
@ -82,7 +94,7 @@ func (store *JobStore) updateJobStatus(jobID, status, result string) {
job.Result = result job.Result = result
} }
job.UpdatedAt = time.Now() job.UpdatedAt = time.Now()
logger.Printf("Job status updated: %v", job) logger.Infof("Job status updated: %v", job)
} }
} }
@ -92,16 +104,16 @@ func (store *JobStore) updatePagesDone(jobID string, pagesDone int) {
if job, exists := store.jobs[jobID]; exists { if job, exists := store.jobs[jobID]; exists {
job.PagesDone = pagesDone job.PagesDone = pagesDone
job.UpdatedAt = time.Now() job.UpdatedAt = time.Now()
logger.Printf("Job pages done updated: %v", job) logger.Infof("Job pages done updated: %v", job)
} }
} }
func startWorkerPool(app *App, numWorkers int) { func startWorkerPool(app *App, numWorkers int) {
for i := 0; i < numWorkers; i++ { for i := 0; i < numWorkers; i++ {
go func(workerID int) { go func(workerID int) {
logger.Printf("Worker %d started", workerID) logger.Infof("Worker %d started", workerID)
for job := range jobQueue { for job := range jobQueue {
logger.Printf("Worker %d processing job: %s", workerID, job.ID) logger.Infof("Worker %d processing job: %s", workerID, job.ID)
processJob(app, job) processJob(app, job)
} }
}(i) }(i)
@ -116,7 +128,7 @@ func processJob(app *App, job *Job) {
// Download images of the document // Download images of the document
imagePaths, err := app.Client.DownloadDocumentAsImages(ctx, job.DocumentID) imagePaths, err := app.Client.DownloadDocumentAsImages(ctx, job.DocumentID)
if err != nil { if err != nil {
logger.Printf("Error downloading document images for job %s: %v", job.ID, err) logger.Infof("Error downloading document images for job %s: %v", job.ID, err)
jobStore.updateJobStatus(job.ID, "failed", fmt.Sprintf("Error downloading document images: %v", err)) jobStore.updateJobStatus(job.ID, "failed", fmt.Sprintf("Error downloading document images: %v", err))
return return
} }
@ -125,14 +137,14 @@ func processJob(app *App, job *Job) {
for i, imagePath := range imagePaths { for i, imagePath := range imagePaths {
imageContent, err := os.ReadFile(imagePath) imageContent, err := os.ReadFile(imagePath)
if err != nil { if err != nil {
logger.Printf("Error reading image file for job %s: %v", job.ID, err) logger.Errorf("Error reading image file for job %s: %v", job.ID, err)
jobStore.updateJobStatus(job.ID, "failed", fmt.Sprintf("Error reading image file: %v", err)) jobStore.updateJobStatus(job.ID, "failed", fmt.Sprintf("Error reading image file: %v", err))
return return
} }
ocrText, err := app.doOCRViaLLM(ctx, imageContent) ocrText, err := app.doOCRViaLLM(ctx, imageContent)
if err != nil { if err != nil {
logger.Printf("Error performing OCR for job %s: %v", job.ID, err) logger.Errorf("Error performing OCR for job %s: %v", job.ID, err)
jobStore.updateJobStatus(job.ID, "failed", fmt.Sprintf("Error performing OCR: %v", err)) jobStore.updateJobStatus(job.ID, "failed", fmt.Sprintf("Error performing OCR: %v", err))
return return
} }
@ -146,5 +158,5 @@ func processJob(app *App, job *Job) {
// Update job status and result // Update job status and result
jobStore.updateJobStatus(job.ID, "completed", fullOcrText) jobStore.updateJobStatus(job.ID, "completed", fullOcrText)
logger.Printf("Job completed: %s", job.ID) logger.Infof("Job completed: %s", job.ID)
} }

81
main.go
View file

@ -3,7 +3,6 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
@ -14,6 +13,7 @@ import (
"github.com/Masterminds/sprig/v3" "github.com/Masterminds/sprig/v3"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/ollama" "github.com/tmc/langchaingo/llms/ollama"
"github.com/tmc/langchaingo/llms/openai" "github.com/tmc/langchaingo/llms/openai"
@ -21,6 +21,11 @@ import (
// Global Variables and Constants // Global Variables and Constants
var ( var (
// Logger
log = logrus.New()
// Environment Variables
paperlessBaseURL = os.Getenv("PAPERLESS_BASE_URL") paperlessBaseURL = os.Getenv("PAPERLESS_BASE_URL")
paperlessAPIToken = os.Getenv("PAPERLESS_API_TOKEN") paperlessAPIToken = os.Getenv("PAPERLESS_API_TOKEN")
openaiAPIKey = os.Getenv("OPENAI_API_KEY") openaiAPIKey = os.Getenv("OPENAI_API_KEY")
@ -30,6 +35,7 @@ var (
llmModel = os.Getenv("LLM_MODEL") llmModel = os.Getenv("LLM_MODEL")
visionLlmProvider = os.Getenv("VISION_LLM_PROVIDER") visionLlmProvider = os.Getenv("VISION_LLM_PROVIDER")
visionLlmModel = os.Getenv("VISION_LLM_MODEL") visionLlmModel = os.Getenv("VISION_LLM_MODEL")
logLevel = strings.ToLower(os.Getenv("LOG_LEVEL"))
// Templates // Templates
titleTemplate *template.Template titleTemplate *template.Template
@ -75,6 +81,9 @@ func main() {
// Validate Environment Variables // Validate Environment Variables
validateEnvVars() validateEnvVars()
// Initialize logrus logger
initLogger()
// Initialize PaperlessClient // Initialize PaperlessClient
client := NewPaperlessClient(paperlessBaseURL, paperlessAPIToken) client := NewPaperlessClient(paperlessBaseURL, paperlessAPIToken)
@ -102,25 +111,28 @@ func main() {
// Start background process for auto-tagging // Start background process for auto-tagging
go func() { go func() {
minBackoffDuration := 10 * time.Second
minBackoffDuration := time.Second
maxBackoffDuration := time.Hour maxBackoffDuration := time.Hour
pollingInterval := 10 * time.Second pollingInterval := 10 * time.Second
backoffDuration := minBackoffDuration backoffDuration := minBackoffDuration
for { for {
if err := app.processAutoTagDocuments(); err != nil { processedCount, err := app.processAutoTagDocuments()
log.Printf("Error in processAutoTagDocuments: %v", err) if err != nil {
log.Errorf("Error in processAutoTagDocuments: %v", err)
time.Sleep(backoffDuration) time.Sleep(backoffDuration)
backoffDuration *= 2 // Exponential backoff backoffDuration *= 2 // Exponential backoff
if backoffDuration > maxBackoffDuration { if backoffDuration > maxBackoffDuration {
log.Printf("Repeated errors in processAutoTagDocuments detected. Setting backoff to %v", maxBackoffDuration) log.Warnf("Repeated errors in processAutoTagDocuments detected. Setting backoff to %v", maxBackoffDuration)
backoffDuration = maxBackoffDuration backoffDuration = maxBackoffDuration
} }
} else { } else {
backoffDuration = minBackoffDuration backoffDuration = minBackoffDuration
} }
time.Sleep(pollingInterval)
if processedCount == 0 {
time.Sleep(pollingInterval)
}
} }
}() }()
@ -168,12 +180,34 @@ func main() {
numWorkers := 1 // Number of workers to start numWorkers := 1 // Number of workers to start
startWorkerPool(app, numWorkers) startWorkerPool(app, numWorkers)
log.Println("Server started on port :8080") log.Infoln("Server started on port :8080")
if err := router.Run(":8080"); err != nil { if err := router.Run(":8080"); err != nil {
log.Fatalf("Failed to run server: %v", err) log.Fatalf("Failed to run server: %v", err)
} }
} }
func initLogger() {
switch logLevel {
case "debug":
log.SetLevel(logrus.DebugLevel)
case "info":
log.SetLevel(logrus.InfoLevel)
case "warn":
log.SetLevel(logrus.WarnLevel)
case "error":
log.SetLevel(logrus.ErrorLevel)
default:
log.SetLevel(logrus.InfoLevel)
if logLevel != "" {
log.Fatalf("Invalid log level: '%s'.", logLevel)
}
}
log.SetFormatter(&logrus.TextFormatter{
FullTimestamp: true,
})
}
func isOcrEnabled() bool { func isOcrEnabled() bool {
return visionLlmModel != "" && visionLlmProvider != "" return visionLlmModel != "" && visionLlmProvider != ""
} }
@ -192,28 +226,37 @@ func validateEnvVars() {
log.Fatal("Please set the LLM_PROVIDER environment variable.") log.Fatal("Please set the LLM_PROVIDER environment variable.")
} }
if visionLlmProvider != "" && visionLlmProvider != "openai" && visionLlmProvider != "ollama" {
log.Fatal("Please set the LLM_PROVIDER environment variable to 'openai' or 'ollama'.")
}
if llmModel == "" { if llmModel == "" {
log.Fatal("Please set the LLM_MODEL environment variable.") log.Fatal("Please set the LLM_MODEL environment variable.")
} }
if llmProvider == "openai" && openaiAPIKey == "" { if (llmProvider == "openai" || visionLlmProvider == "openai") && openaiAPIKey == "" {
log.Fatal("Please set the OPENAI_API_KEY environment variable for OpenAI provider.") log.Fatal("Please set the OPENAI_API_KEY environment variable for OpenAI provider.")
} }
} }
// processAutoTagDocuments handles the background auto-tagging of documents // processAutoTagDocuments handles the background auto-tagging of documents
func (app *App) processAutoTagDocuments() error { func (app *App) processAutoTagDocuments() (int, error) {
ctx := context.Background() ctx := context.Background()
documents, err := app.Client.GetDocumentsByTags(ctx, []string{autoTag}) documents, err := app.Client.GetDocumentsByTags(ctx, []string{autoTag})
if err != nil { if err != nil {
return fmt.Errorf("error fetching documents with autoTag: %w", err) return 0, fmt.Errorf("error fetching documents with autoTag: %w", err)
} }
if len(documents) == 0 { if len(documents) == 0 {
return nil // No documents to process log.Debugf("No documents with tag %s found", autoTag)
return 0, nil // No documents to process
} }
log.Debugf("Found at least %d remaining documents with tag %s", len(documents), autoTag)
documents = documents[:1] // Process only one document at a time
suggestionRequest := GenerateSuggestionsRequest{ suggestionRequest := GenerateSuggestionsRequest{
Documents: documents, Documents: documents,
GenerateTitles: true, GenerateTitles: true,
@ -222,15 +265,15 @@ func (app *App) processAutoTagDocuments() error {
suggestions, err := app.generateDocumentSuggestions(ctx, suggestionRequest) suggestions, err := app.generateDocumentSuggestions(ctx, suggestionRequest)
if err != nil { if err != nil {
return fmt.Errorf("error generating suggestions: %w", err) return 0, fmt.Errorf("error generating suggestions: %w", err)
} }
err = app.Client.UpdateDocuments(ctx, suggestions) err = app.Client.UpdateDocuments(ctx, suggestions)
if err != nil { if err != nil {
return fmt.Errorf("error updating documents: %w", err) return 0, fmt.Errorf("error updating documents: %w", err)
} }
return nil return len(documents), nil
} }
// removeTagFromList removes a specific tag from a list of tags // removeTagFromList removes a specific tag from a list of tags
@ -268,7 +311,7 @@ func loadTemplates() {
titleTemplatePath := filepath.Join(promptsDir, "title_prompt.tmpl") titleTemplatePath := filepath.Join(promptsDir, "title_prompt.tmpl")
titleTemplateContent, err := os.ReadFile(titleTemplatePath) titleTemplateContent, err := os.ReadFile(titleTemplatePath)
if err != nil { if err != nil {
log.Printf("Could not read %s, using default template: %v", titleTemplatePath, err) log.Errorf("Could not read %s, using default template: %v", titleTemplatePath, err)
titleTemplateContent = []byte(defaultTitleTemplate) titleTemplateContent = []byte(defaultTitleTemplate)
if err := os.WriteFile(titleTemplatePath, titleTemplateContent, os.ModePerm); err != nil { if err := os.WriteFile(titleTemplatePath, titleTemplateContent, os.ModePerm); err != nil {
log.Fatalf("Failed to write default title template to disk: %v", err) log.Fatalf("Failed to write default title template to disk: %v", err)
@ -283,7 +326,7 @@ func loadTemplates() {
tagTemplatePath := filepath.Join(promptsDir, "tag_prompt.tmpl") tagTemplatePath := filepath.Join(promptsDir, "tag_prompt.tmpl")
tagTemplateContent, err := os.ReadFile(tagTemplatePath) tagTemplateContent, err := os.ReadFile(tagTemplatePath)
if err != nil { if err != nil {
log.Printf("Could not read %s, using default template: %v", tagTemplatePath, err) log.Errorf("Could not read %s, using default template: %v", tagTemplatePath, err)
tagTemplateContent = []byte(defaultTagTemplate) tagTemplateContent = []byte(defaultTagTemplate)
if err := os.WriteFile(tagTemplatePath, tagTemplateContent, os.ModePerm); err != nil { if err := os.WriteFile(tagTemplatePath, tagTemplateContent, os.ModePerm); err != nil {
log.Fatalf("Failed to write default tag template to disk: %v", err) log.Fatalf("Failed to write default tag template to disk: %v", err)
@ -298,7 +341,7 @@ func loadTemplates() {
ocrTemplatePath := filepath.Join(promptsDir, "ocr_prompt.tmpl") ocrTemplatePath := filepath.Join(promptsDir, "ocr_prompt.tmpl")
ocrTemplateContent, err := os.ReadFile(ocrTemplatePath) ocrTemplateContent, err := os.ReadFile(ocrTemplatePath)
if err != nil { if err != nil {
log.Printf("Could not read %s, using default template: %v", ocrTemplatePath, err) log.Errorf("Could not read %s, using default template: %v", ocrTemplatePath, err)
ocrTemplateContent = []byte(defaultOcrPrompt) ocrTemplateContent = []byte(defaultOcrPrompt)
if err := os.WriteFile(ocrTemplatePath, ocrTemplateContent, os.ModePerm); err != nil { if err := os.WriteFile(ocrTemplatePath, ocrTemplateContent, os.ModePerm); err != nil {
log.Fatalf("Failed to write default OCR template to disk: %v", err) log.Fatalf("Failed to write default OCR template to disk: %v", err)
@ -355,7 +398,7 @@ func createVisionLLM() (llms.Model, error) {
ollama.WithServerURL(host), ollama.WithServerURL(host),
) )
default: default:
log.Printf("No Vision LLM provider created: %s", visionLlmProvider) log.Infoln("Vision LLM not enabled")
return nil, nil return nil, nil
} }
} }

View file

@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"image/jpeg" "image/jpeg"
"io" "io"
"log"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
@ -223,7 +222,7 @@ func (c *PaperlessClient) UpdateDocuments(ctx context.Context, documents []Docum
// Fetch all available tags // Fetch all available tags
availableTags, err := c.GetAllTags(ctx) availableTags, err := c.GetAllTags(ctx)
if err != nil { if err != nil {
log.Printf("Error fetching available tags: %v", err) log.Errorf("Error fetching available tags: %v", err)
return err return err
} }
@ -249,7 +248,7 @@ func (c *PaperlessClient) UpdateDocuments(ctx context.Context, documents []Docum
} }
newTags = append(newTags, tagID) newTags = append(newTags, tagID)
} else { } else {
log.Printf("Tag '%s' does not exist in paperless-ngx, skipping.", tagName) log.Warnf("Tag '%s' does not exist in paperless-ngx, skipping.", tagName)
} }
} }
@ -262,7 +261,7 @@ func (c *PaperlessClient) UpdateDocuments(ctx context.Context, documents []Docum
if suggestedTitle != "" { if suggestedTitle != "" {
updatedFields["title"] = suggestedTitle updatedFields["title"] = suggestedTitle
} else { } else {
log.Printf("No valid title found for document %d, skipping.", documentID) log.Warnf("No valid title found for document %d, skipping.", documentID)
} }
// Suggested Content // Suggested Content
@ -274,7 +273,7 @@ func (c *PaperlessClient) UpdateDocuments(ctx context.Context, documents []Docum
// Marshal updated fields to JSON // Marshal updated fields to JSON
jsonData, err := json.Marshal(updatedFields) jsonData, err := json.Marshal(updatedFields)
if err != nil { if err != nil {
log.Printf("Error marshalling JSON for document %d: %v", documentID, err) log.Errorf("Error marshalling JSON for document %d: %v", documentID, err)
return err return err
} }
@ -282,14 +281,14 @@ func (c *PaperlessClient) UpdateDocuments(ctx context.Context, documents []Docum
path := fmt.Sprintf("api/documents/%d/", documentID) path := fmt.Sprintf("api/documents/%d/", documentID)
resp, err := c.Do(ctx, "PATCH", path, bytes.NewBuffer(jsonData)) resp, err := c.Do(ctx, "PATCH", path, bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
log.Printf("Error updating document %d: %v", documentID, err) log.Errorf("Error updating document %d: %v", documentID, err)
return err return err
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body) bodyBytes, _ := io.ReadAll(resp.Body)
log.Printf("Error updating document %d: %d, %s", documentID, resp.StatusCode, string(bodyBytes)) log.Errorf("Error updating document %d: %d, %s", documentID, resp.StatusCode, string(bodyBytes))
return fmt.Errorf("error updating document %d: %d, %s", documentID, resp.StatusCode, string(bodyBytes)) return fmt.Errorf("error updating document %d: %d, %s", documentID, resp.StatusCode, string(bodyBytes))
} }