mirror of
https://github.com/icereed/paperless-gpt.git
synced 2025-03-12 12:58:02 -05:00
222 lines
5.6 KiB
Go
222 lines
5.6 KiB
Go
package ocr
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/hashicorp/go-retryablehttp"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func TestNewAzureProvider(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
config Config
|
|
wantErr bool
|
|
errContains string
|
|
}{
|
|
{
|
|
name: "valid config",
|
|
config: Config{
|
|
AzureEndpoint: "https://test.cognitiveservices.azure.com/",
|
|
AzureAPIKey: "test-key",
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "valid config with custom model and timeout",
|
|
config: Config{
|
|
AzureEndpoint: "https://test.cognitiveservices.azure.com/",
|
|
AzureAPIKey: "test-key",
|
|
AzureModelID: "custom-model",
|
|
AzureTimeout: 60,
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "missing endpoint",
|
|
config: Config{
|
|
AzureAPIKey: "test-key",
|
|
},
|
|
wantErr: true,
|
|
errContains: "missing required Azure Document Intelligence configuration",
|
|
},
|
|
{
|
|
name: "missing api key",
|
|
config: Config{
|
|
AzureEndpoint: "https://test.cognitiveservices.azure.com/",
|
|
},
|
|
wantErr: true,
|
|
errContains: "missing required Azure Document Intelligence configuration",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
provider, err := newAzureProvider(tt.config)
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
if tt.errContains != "" {
|
|
assert.Contains(t, err.Error(), tt.errContains)
|
|
}
|
|
return
|
|
}
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, provider)
|
|
|
|
// Verify default values
|
|
if tt.config.AzureModelID == "" {
|
|
assert.Equal(t, defaultModelID, provider.modelID)
|
|
} else {
|
|
assert.Equal(t, tt.config.AzureModelID, provider.modelID)
|
|
}
|
|
|
|
if tt.config.AzureTimeout == 0 {
|
|
assert.Equal(t, time.Duration(defaultTimeout)*time.Second, provider.timeout)
|
|
} else {
|
|
assert.Equal(t, time.Duration(tt.config.AzureTimeout)*time.Second, provider.timeout)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAzureProvider_ProcessImage(t *testing.T) {
|
|
// Sample success response
|
|
now := time.Now()
|
|
successResult := AzureDocumentResult{
|
|
Status: "succeeded",
|
|
CreatedDateTime: now,
|
|
LastUpdatedDateTime: now,
|
|
AnalyzeResult: AzureAnalyzeResult{
|
|
APIVersion: apiVersion,
|
|
ModelID: defaultModelID,
|
|
StringIndexType: "utf-16",
|
|
Content: "Test document content",
|
|
Pages: []AzurePage{
|
|
{
|
|
PageNumber: 1,
|
|
Angle: 0.0,
|
|
Width: 800,
|
|
Height: 600,
|
|
Unit: "pixel",
|
|
Lines: []AzureLine{
|
|
{
|
|
Content: "Test line",
|
|
Polygon: []int{0, 0, 100, 0, 100, 20, 0, 20},
|
|
Spans: []AzureSpan{{Offset: 0, Length: 9}},
|
|
},
|
|
},
|
|
Spans: []AzureSpan{{Offset: 0, Length: 9}},
|
|
},
|
|
},
|
|
Paragraphs: []AzureParagraph{
|
|
{
|
|
Content: "Test document content",
|
|
Spans: []AzureSpan{{Offset: 0, Length: 19}},
|
|
BoundingRegions: []AzureBoundingBox{
|
|
{
|
|
PageNumber: 1,
|
|
Polygon: []int{0, 0, 100, 0, 100, 20, 0, 20},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
ContentFormat: "text",
|
|
},
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
setupServer func() *httptest.Server
|
|
imageContent []byte
|
|
wantErr bool
|
|
errContains string
|
|
expectedText string
|
|
}{
|
|
{
|
|
name: "successful processing",
|
|
setupServer: func() *httptest.Server {
|
|
mux := http.NewServeMux()
|
|
server := httptest.NewServer(mux)
|
|
|
|
mux.HandleFunc("/documentintelligence/documentModels/prebuilt-read:analyze", func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Operation-Location", fmt.Sprintf("%s/operations/123", server.URL))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
})
|
|
|
|
mux.HandleFunc("/operations/123", func(w http.ResponseWriter, r *http.Request) {
|
|
json.NewEncoder(w).Encode(successResult)
|
|
})
|
|
|
|
return server
|
|
},
|
|
// Create minimal JPEG content with magic numbers
|
|
imageContent: append([]byte{0xFF, 0xD8, 0xFF, 0xE0}, []byte("JFIF test content")...),
|
|
expectedText: "Test document content",
|
|
},
|
|
{
|
|
name: "invalid mime type",
|
|
setupServer: func() *httptest.Server {
|
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Log("Server should not be called with invalid mime type")
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
}))
|
|
},
|
|
imageContent: []byte("invalid content"),
|
|
wantErr: true,
|
|
errContains: "unsupported file type",
|
|
},
|
|
{
|
|
name: "submission error",
|
|
setupServer: func() *httptest.Server {
|
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
fmt.Fprintln(w, "Invalid request")
|
|
}))
|
|
},
|
|
imageContent: []byte{0xFF, 0xD8, 0xFF, 0xE0}, // JPEG magic numbers
|
|
wantErr: true,
|
|
errContains: "unexpected status code 400",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
server := tt.setupServer()
|
|
defer server.Close()
|
|
|
|
client := retryablehttp.NewClient()
|
|
client.HTTPClient = server.Client()
|
|
client.Logger = log
|
|
|
|
provider := &AzureProvider{
|
|
endpoint: server.URL,
|
|
apiKey: "test-key",
|
|
modelID: defaultModelID,
|
|
timeout: 5 * time.Second,
|
|
httpClient: client,
|
|
}
|
|
|
|
result, err := provider.ProcessImage(context.Background(), tt.imageContent)
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
if tt.errContains != "" {
|
|
assert.Contains(t, err.Error(), tt.errContains)
|
|
}
|
|
return
|
|
}
|
|
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, result)
|
|
assert.Equal(t, tt.expectedText, result.Text)
|
|
assert.Equal(t, "azure_docai", result.Metadata["provider"])
|
|
assert.Equal(t, apiVersion, result.Metadata["api_version"])
|
|
assert.Equal(t, "1", result.Metadata["page_count"])
|
|
})
|
|
}
|
|
}
|