mirror of
https://github.com/icereed/paperless-gpt.git
synced 2025-03-12 12:58:02 -05:00
refactor: Improve auto-tagging process to skip OCR-tagged documents (#227)
* refactor: Improve auto-tagging process to skip OCR-tagged documents * Use empty template in test case
This commit is contained in:
parent
8936e5ea89
commit
29c2bb6d04
2 changed files with 190 additions and 4 deletions
17
main.go
17
main.go
|
@ -9,7 +9,8 @@ import (
|
|||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"text/template"
|
||||
"time"
|
||||
|
@ -462,7 +463,14 @@ func (app *App) processAutoTagDocuments() (int, error) {
|
|||
|
||||
log.Debugf("Found at least %d remaining documents with tag %s", len(documents), autoTag)
|
||||
|
||||
processedCount := 0
|
||||
for _, document := range documents {
|
||||
// Skip documents that have the autoOcrTag
|
||||
if slices.Contains(document.Tags, autoOcrTag) {
|
||||
log.Debugf("Skipping document %d as it has the OCR tag %s", document.ID, autoOcrTag)
|
||||
continue
|
||||
}
|
||||
|
||||
docLogger := documentLogger(document.ID)
|
||||
docLogger.Info("Processing document for auto-tagging")
|
||||
|
||||
|
@ -475,17 +483,18 @@ func (app *App) processAutoTagDocuments() (int, error) {
|
|||
|
||||
suggestions, err := app.generateDocumentSuggestions(ctx, suggestionRequest, docLogger)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error generating suggestions for document %d: %w", document.ID, err)
|
||||
return processedCount, fmt.Errorf("error generating suggestions for document %d: %w", document.ID, err)
|
||||
}
|
||||
|
||||
err = app.Client.UpdateDocuments(ctx, suggestions, app.Database, false)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error updating document %d: %w", document.ID, err)
|
||||
return processedCount, fmt.Errorf("error updating document %d: %w", document.ID, err)
|
||||
}
|
||||
|
||||
docLogger.Info("Successfully processed document")
|
||||
processedCount++
|
||||
}
|
||||
return len(documents), nil
|
||||
return processedCount, nil
|
||||
}
|
||||
|
||||
// processAutoOcrTagDocuments handles the background auto-tagging of OCR documents
|
||||
|
|
177
main_test.go
Normal file
177
main_test.go
Normal file
|
@ -0,0 +1,177 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"testing"
|
||||
"text/template"
|
||||
|
||||
"github.com/Masterminds/sprig/v3"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestProcessAutoTagDocuments(t *testing.T) {
|
||||
// Initialize required global variables
|
||||
autoTag = "paperless-gpt-auto"
|
||||
autoOcrTag = "paperless-gpt-ocr-auto"
|
||||
|
||||
// Initialize templates
|
||||
var err error
|
||||
titleTemplate, err = template.New("title").Funcs(sprig.FuncMap()).Parse("")
|
||||
require.NoError(t, err)
|
||||
tagTemplate, err = template.New("tag").Funcs(sprig.FuncMap()).Parse("")
|
||||
require.NoError(t, err)
|
||||
correspondentTemplate, err = template.New("correspondent").Funcs(sprig.FuncMap()).Parse("")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create test environment
|
||||
env := newTestEnv(t)
|
||||
defer env.teardown()
|
||||
|
||||
// Set up test cases
|
||||
testCases := []struct {
|
||||
name string
|
||||
documents []Document
|
||||
expectedCount int
|
||||
expectedError string
|
||||
updateResponse int // HTTP status code for update response
|
||||
}{
|
||||
{
|
||||
name: "Skip document with autoOcrTag",
|
||||
documents: []Document{
|
||||
{
|
||||
ID: 1,
|
||||
Title: "Doc with OCR tag",
|
||||
Tags: []string{autoTag, autoOcrTag},
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Title: "Doc without OCR tag",
|
||||
Tags: []string{autoTag},
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
Title: "Doc with OCR tag",
|
||||
Tags: []string{autoTag, autoOcrTag},
|
||||
},
|
||||
},
|
||||
expectedCount: 1,
|
||||
updateResponse: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "No documents to process",
|
||||
documents: []Document{},
|
||||
expectedCount: 0,
|
||||
updateResponse: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Mock the GetAllTags response
|
||||
env.setMockResponse("/api/tags/", func(w http.ResponseWriter, r *http.Request) {
|
||||
response := map[string]interface{}{
|
||||
"results": []map[string]interface{}{
|
||||
{"id": 1, "name": autoTag},
|
||||
{"id": 2, "name": autoOcrTag},
|
||||
{"id": 3, "name": "other-tag"},
|
||||
},
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(response)
|
||||
})
|
||||
|
||||
// Mock the GetDocumentsByTags response
|
||||
env.setMockResponse("/api/documents/", func(w http.ResponseWriter, r *http.Request) {
|
||||
response := GetDocumentsApiResponse{
|
||||
Results: make([]GetDocumentApiResponseResult, len(tc.documents)),
|
||||
}
|
||||
for i, doc := range tc.documents {
|
||||
tagIds := make([]int, len(doc.Tags))
|
||||
for j, tagName := range doc.Tags {
|
||||
switch tagName {
|
||||
case autoTag:
|
||||
tagIds[j] = 1
|
||||
case autoOcrTag:
|
||||
tagIds[j] = 2
|
||||
default:
|
||||
tagIds[j] = 3
|
||||
}
|
||||
}
|
||||
response.Results[i] = GetDocumentApiResponseResult{
|
||||
ID: doc.ID,
|
||||
Title: doc.Title,
|
||||
Tags: tagIds,
|
||||
Content: "Test content",
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(response)
|
||||
})
|
||||
|
||||
// Mock the correspondent creation endpoint
|
||||
env.setMockResponse("/api/correspondents/", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "POST" {
|
||||
// Mock successful correspondent creation
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"id": 3,
|
||||
"name": "test response",
|
||||
})
|
||||
} else {
|
||||
// Mock GET response for existing correspondents
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"results": []map[string]interface{}{
|
||||
{"id": 1, "name": "Alpha"},
|
||||
{"id": 2, "name": "Beta"},
|
||||
},
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// Create test app
|
||||
app := &App{
|
||||
Client: env.client,
|
||||
Database: env.db,
|
||||
LLM: &mockLLM{}, // Use mock LLM from app_llm_test.go
|
||||
}
|
||||
|
||||
// Set auto-generate flags
|
||||
autoGenerateTitle = "true"
|
||||
autoGenerateTags = "true"
|
||||
autoGenerateCorrespondents = "true"
|
||||
|
||||
// Mock the document update responses
|
||||
for _, doc := range tc.documents {
|
||||
if !slices.Contains(doc.Tags, autoOcrTag) {
|
||||
updatePath := fmt.Sprintf("/api/documents/%d/", doc.ID)
|
||||
env.setMockResponse(updatePath, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(tc.updateResponse)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"id": doc.ID,
|
||||
"title": "Updated " + doc.Title,
|
||||
"tags": []int{1, 3}, // Mock updated tag IDs
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Run the test
|
||||
count, err := app.processAutoTagDocuments()
|
||||
|
||||
// Verify results
|
||||
if tc.expectedError != "" {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tc.expectedError)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.expectedCount, count)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue