paperless-gpt/main_test.go

200 lines
5.2 KiB
Go
Raw Normal View History

package main
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"slices"
"testing"
"text/template"
"github.com/Masterminds/sprig/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestProcessAutoTagDocuments(t *testing.T) {
// Initialize required global variables
autoTag = "paperless-gpt-auto"
autoOcrTag = "paperless-gpt-ocr-auto"
// Initialize templates
var err error
titleTemplate, err = template.New("title").Funcs(sprig.FuncMap()).Parse("")
require.NoError(t, err)
tagTemplate, err = template.New("tag").Funcs(sprig.FuncMap()).Parse("")
require.NoError(t, err)
correspondentTemplate, err = template.New("correspondent").Funcs(sprig.FuncMap()).Parse("")
require.NoError(t, err)
// Create test environment
env := newTestEnv(t)
defer env.teardown()
// Set up test cases
testCases := []struct {
name string
documents []Document
expectedCount int
expectedError string
updateResponse int // HTTP status code for update response
}{
{
name: "Skip document with autoOcrTag",
documents: []Document{
{
ID: 1,
Title: "Doc with OCR tag",
Tags: []string{autoTag, autoOcrTag},
},
{
ID: 2,
Title: "Doc without OCR tag",
Tags: []string{autoTag},
},
{
ID: 3,
Title: "Doc with OCR tag",
Tags: []string{autoTag, autoOcrTag},
},
},
expectedCount: 1,
updateResponse: http.StatusOK,
},
{
name: "No documents to process",
documents: []Document{},
expectedCount: 0,
updateResponse: http.StatusOK,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Mock the GetAllTags response
env.setMockResponse("/api/tags/", func(w http.ResponseWriter, r *http.Request) {
response := map[string]interface{}{
"results": []map[string]interface{}{
{"id": 1, "name": autoTag},
{"id": 2, "name": autoOcrTag},
{"id": 3, "name": "other-tag"},
},
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
})
// Mock the GetDocumentsByTags response
env.setMockResponse("/api/documents/", func(w http.ResponseWriter, r *http.Request) {
response := GetDocumentsApiResponse{
Results: make([]GetDocumentApiResponseResult, len(tc.documents)),
}
for i, doc := range tc.documents {
tagIds := make([]int, len(doc.Tags))
for j, tagName := range doc.Tags {
switch tagName {
case autoTag:
tagIds[j] = 1
case autoOcrTag:
tagIds[j] = 2
default:
tagIds[j] = 3
}
}
response.Results[i] = GetDocumentApiResponseResult{
ID: doc.ID,
Title: doc.Title,
Tags: tagIds,
Content: "Test content",
}
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
})
// Mock the correspondent creation endpoint
env.setMockResponse("/api/correspondents/", func(w http.ResponseWriter, r *http.Request) {
if r.Method == "POST" {
// Mock successful correspondent creation
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]interface{}{
"id": 3,
"name": "test response",
})
} else {
// Mock GET response for existing correspondents
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]interface{}{
"results": []map[string]interface{}{
{"id": 1, "name": "Alpha"},
{"id": 2, "name": "Beta"},
},
})
}
})
// Create test app
app := &App{
Client: env.client,
Database: env.db,
LLM: &mockLLM{}, // Use mock LLM from app_llm_test.go
}
// Set auto-generate flags
autoGenerateTitle = "true"
autoGenerateTags = "true"
autoGenerateCorrespondents = "true"
// Mock the document update responses
for _, doc := range tc.documents {
if !slices.Contains(doc.Tags, autoOcrTag) {
updatePath := fmt.Sprintf("/api/documents/%d/", doc.ID)
env.setMockResponse(updatePath, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tc.updateResponse)
json.NewEncoder(w).Encode(map[string]interface{}{
"id": doc.ID,
"title": "Updated " + doc.Title,
"tags": []int{1, 3}, // Mock updated tag IDs
})
})
}
}
// Run the test
count, err := app.processAutoTagDocuments()
// Verify results
if tc.expectedError != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.expectedError)
} else {
require.NoError(t, err)
assert.Equal(t, tc.expectedCount, count)
}
})
}
}
func TestCreateCustomHTTPClient(t *testing.T) {
// Create a test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify custom header
assert.Equal(t, "paperless-gpt", r.Header.Get("X-Title"), "Expected X-Title header")
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
// Get custom client
client := createCustomHTTPClient()
require.NotNil(t, client, "HTTP client should not be nil")
// Make a request
resp, err := client.Get(server.URL)
require.NoError(t, err, "Request should not fail")
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode, "Expected 200 OK response")
}