diff --git a/README.md b/README.md index 0112328..effeabb 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ https://github.com/user-attachments/assets/bd5d38b9-9309-40b9-93ca-918dfa4f3fd4 - **LLM OCR**: Use OpenAI or Ollama to extract text from images. - **Google Document AI**: Leverage Google's powerful Document AI for OCR tasks. - - **More to come**: Stay tuned for more OCR providers! + - **Azure Document Intelligence**: Use Microsoft's enterprise OCR solution. 3. **Automatic Title & Tag Generation** No more guesswork. Let the AI do the naming and categorizing. You can easily review suggestions and refine them if needed. @@ -39,11 +39,11 @@ https://github.com/user-attachments/assets/bd5d38b9-9309-40b9-93ca-918dfa4f3fd4 - **Tagging**: Decide how documents get tagged—manually, automatically, or via OCR-based flows. 7. **Simple Docker Deployment** - A few environment variables, and you’re off! Compose it alongside paperless-ngx with minimal fuss. + A few environment variables, and you're off! Compose it alongside paperless-ngx with minimal fuss. 8. **Unified Web UI** - - **Manual Review**: Approve or tweak AI’s suggestions. + - **Manual Review**: Approve or tweak AI's suggestions. - **Auto Processing**: Focus only on edge cases while the rest is sorted for you. --- @@ -56,6 +56,12 @@ https://github.com/user-attachments/assets/bd5d38b9-9309-40b9-93ca-918dfa4f3fd4 - [Installation](#installation) - [Docker Compose](#docker-compose) - [Manual Setup](#manual-setup) +- [OCR Providers](#ocr-providers) + - [LLM-based OCR](#1-llm-based-ocr-default) + - [Azure Document Intelligence](#2-azure-document-intelligence) + - [Google Document AI](#3-google-document-ai) + - [Comparing OCR Providers](#comparing-ocr-providers) + - [Choosing the Right Provider](#choosing-the-right-provider) - [Configuration](#configuration) - [Environment Variables](#environment-variables) - [Custom Prompt Templates](#custom-prompt-templates) @@ -86,7 +92,7 @@ https://github.com/user-attachments/assets/bd5d38b9-9309-40b9-93ca-918dfa4f3fd4 #### Docker Compose -Here’s an example `docker-compose.yml` to spin up **paperless-gpt** alongside paperless-ngx: +Here's an example `docker-compose.yml` to spin up **paperless-gpt** alongside paperless-ngx: ```yaml services: @@ -124,6 +130,13 @@ services: # GOOGLE_PROCESSOR_ID: 'processor-id' # Your processor ID # GOOGLE_APPLICATION_CREDENTIALS: '/app/credentials.json' # Path to service account key + # Option 3: Azure Document Intelligence + # OCR_PROVIDER: 'azure' # Use Azure Document Intelligence + # AZURE_DOCAI_ENDPOINT: 'your-endpoint' # Your Azure endpoint URL + # AZURE_DOCAI_KEY: 'your-key' # Your Azure API key + # AZURE_DOCAI_MODEL_ID: 'prebuilt-read' # Optional, defaults to prebuilt-read + # AZURE_DOCAI_TIMEOUT_SECONDS: '120' # Optional, defaults to 120 seconds + AUTO_OCR_TAG: "paperless-gpt-ocr-auto" # Optional, default: paperless-gpt-ocr-auto OCR_LIMIT_PAGES: "5" # Optional, default: 5. Set to 0 for no limit. LOG_LEVEL: "info" # Optional: debug, warn, error @@ -172,6 +185,63 @@ services: ``` --- +## OCR Providers + +paperless-gpt supports three different OCR providers, each with unique strengths and capabilities: + +### 1. 
LLM-based OCR (Default) +- **Key Features**: + - Uses vision-capable LLMs like GPT-4V or MiniCPM-V + - High accuracy with complex layouts and difficult scans + - Context-aware text recognition + - Self-correcting capabilities for OCR errors +- **Best For**: + - Complex or unusual document layouts + - Poor quality scans + - Documents with mixed languages +- **Configuration**: + ```yaml + OCR_PROVIDER: "llm" + VISION_LLM_PROVIDER: "openai" # or "ollama" + VISION_LLM_MODEL: "gpt-4v" # or "minicpm-v" + ``` + +### 2. Azure Document Intelligence +- **Key Features**: + - Enterprise-grade OCR solution + - Prebuilt models for common document types + - Layout preservation and table detection + - Fast processing speeds +- **Best For**: + - Business documents and forms + - High-volume processing + - Documents requiring layout analysis +- **Configuration**: + ```yaml + OCR_PROVIDER: "azure" + AZURE_DOCAI_ENDPOINT: "https://your-endpoint.cognitiveservices.azure.com/" + AZURE_DOCAI_KEY: "your-key" + AZURE_DOCAI_MODEL_ID: "prebuilt-read" # optional + AZURE_DOCAI_TIMEOUT_SECONDS: "120" # optional + ``` + +### 3. Google Document AI +- **Key Features**: + - Specialized document processors + - Strong form field detection + - Multi-language support + - High accuracy on structured documents +- **Best For**: + - Forms and structured documents + - Documents with tables + - Multi-language documents +- **Configuration**: + ```yaml + OCR_PROVIDER: "google_docai" + GOOGLE_PROJECT_ID: "your-project" + GOOGLE_LOCATION: "us" + GOOGLE_PROCESSOR_ID: "processor-id" + ``` ## Configuration @@ -192,9 +262,13 @@ services: | `OPENAI_BASE_URL` | OpenAI base URL (optional, if using a custom OpenAI compatible service like LiteLLM). | No | | | `LLM_LANGUAGE` | Likely language for documents (e.g. `English`). | No | English | | `OLLAMA_HOST` | Ollama server URL (e.g. `http://host.docker.internal:11434`). | No | | -| `OCR_PROVIDER` | OCR provider to use (`llm` or `google_docai`). | No | llm | +| `OCR_PROVIDER` | OCR provider to use (`llm`, `azure`, or `google_docai`). | No | llm | | `VISION_LLM_PROVIDER` | AI backend for LLM OCR (`openai` or `ollama`). Required if OCR_PROVIDER is `llm`. | Cond. | | | `VISION_LLM_MODEL` | Model name for LLM OCR (e.g. `minicpm-v`). Required if OCR_PROVIDER is `llm`. | Cond. | | +| `AZURE_DOCAI_ENDPOINT` | Azure Document Intelligence endpoint. Required if OCR_PROVIDER is `azure`. | Cond. | | +| `AZURE_DOCAI_KEY` | Azure Document Intelligence API key. Required if OCR_PROVIDER is `azure`. | Cond. | | +| `AZURE_DOCAI_MODEL_ID` | Azure Document Intelligence model ID. Optional if using `azure` provider. | No | prebuilt-read | +| `AZURE_DOCAI_TIMEOUT_SECONDS` | Azure Document Intelligence timeout in seconds. | No | 120 | | `GOOGLE_PROJECT_ID` | Google Cloud project ID. Required if OCR_PROVIDER is `google_docai`. | Cond. | | | `GOOGLE_LOCATION` | Google Cloud region (e.g. `us`, `eu`). Required if OCR_PROVIDER is `google_docai`. | Cond. | | | `GOOGLE_PROCESSOR_ID` | Document AI processor ID. Required if OCR_PROVIDER is `google_docai`. | Cond. | | @@ -211,7 +285,7 @@ services: ### Custom Prompt Templates -paperless-gpt’s flexible **prompt templates** let you shape how AI responds: +paperless-gpt's flexible **prompt templates** let you shape how AI responds: 1. **`title_prompt.tmpl`**: For document titles. 2. **`tag_prompt.tmpl`**: For tagging logic. @@ -232,13 +306,11 @@ Then tweak at will—**paperless-gpt** reloads them automatically on startup! 
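+As an illustration, a minimal `title_prompt.tmpl` could look like the sketch below (an assumption for illustration only; the default templates shipped with the image are more detailed). It uses the variables described next:
+
+```
+Suggest a concise title in {{.Language}} for the following document.
+Current title: {{.Title}}
+
+Content:
+{{.Content}}
+```
+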
Each template has access to specific variables:
 
 **title_prompt.tmpl**:
-
 - `{{.Language}}` - Target language (e.g., "English")
 - `{{.Content}}` - Document content text
 - `{{.Title}}` - Original document title
 
 **tag_prompt.tmpl**:
-
 - `{{.Language}}` - Target language
 - `{{.AvailableTags}}` - List of existing tags in paperless-ngx
 - `{{.OriginalTags}}` - Document's current tags
 - `{{.Title}}` - Document title
 - `{{.Content}}` - Document content text
 
 **ocr_prompt.tmpl**:
-
 - `{{.Language}}` - Target language
 
 **correspondent_prompt.tmpl**:
-
 - `{{.Language}}` - Target language
 - `{{.AvailableCorrespondents}}` - List of existing correspondents
 - `{{.BlackList}}` - List of blacklisted correspondent names
 - `{{.Title}}` - Document title
 - `{{.Content}}` - Document content text
 
@@ -265,23 +335,25 @@ The templates use Go's text/template syntax. paperless-gpt automatically reloads
 
 1. **Tag Documents**
 
-   - Add `paperless-gpt` or your custom tag to the docs you want to AI-ify.
+   - Add the `paperless-gpt` tag to documents for manual processing
+   - Add the `paperless-gpt-auto` tag for automatic processing
+   - Add the `paperless-gpt-ocr-auto` tag for automatic OCR processing
 
 2. **Visit Web UI**
 
-   - Go to `http://localhost:8080` (or your host) in your browser.
+   - Go to `http://localhost:8080` (or your host) in your browser
+   - Review documents tagged for processing
 
 3. **Generate & Apply Suggestions**
 
-   - Click “Generate Suggestions” to see AI-proposed titles/tags/correspondents.
-   - Approve, edit, or discard. Hit “Apply” to finalize in paperless-ngx.
-
-4. **Try LLM-Based OCR (Experimental)**
-   - If you enabled `VISION_LLM_PROVIDER` and `VISION_LLM_MODEL`, let AI-based OCR read your scanned PDFs.
-   - Tag those documents with `paperless-gpt-ocr-auto` (or your custom `AUTO_OCR_TAG`).
-
-**Tip**: The entire pipeline can be **fully automated** if you prefer minimal manual intervention.
+   - Click "Generate Suggestions" to see AI-proposed titles/tags/correspondents
+   - Review and approve or edit suggestions
+   - Click "Apply" to save changes to paperless-ngx
+
+4. 
**OCR Processing** + - Tag documents with appropriate OCR tag to process them + - Monitor progress in the Web UI + - Review results and apply changes --- ## LLM-Based OCR: Compare for Yourself diff --git a/go.mod b/go.mod index f5b996d..43f5011 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/gen2brain/go-fitz v1.24.14 github.com/gin-gonic/gin v1.10.0 github.com/google/uuid v1.6.0 + github.com/hashicorp/go-retryablehttp v0.7.7 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.10.0 github.com/tmc/langchaingo v0.1.13 @@ -48,6 +49,7 @@ require ( github.com/google/s2a-go v0.1.9 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect github.com/googleapis/gax-go/v2 v2.14.1 // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/huandu/xstrings v1.5.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect diff --git a/go.sum b/go.sum index 3588b06..6a594f0 100644 --- a/go.sum +++ b/go.sum @@ -75,6 +75,12 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gT github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA= github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrkurSS/Q= github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= +github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= +github.com/hashicorp/go-retryablehttp v0.7.7 h1:C8hUCYzor8PIfXHa4UrZkU4VvK8o9ISHxT2Q8+VepXU= +github.com/hashicorp/go-retryablehttp v0.7.7/go.mod h1:pkQpWZeYWskR+D1tR2O5OcBFOxfA7DoAO6xtkuQnHTk= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= diff --git a/main.go b/main.go index d382901..8ce68b6 100644 --- a/main.go +++ b/main.go @@ -36,6 +36,10 @@ var ( correspondentBlackList = strings.Split(os.Getenv("CORRESPONDENT_BLACK_LIST"), ",") paperlessBaseURL = os.Getenv("PAPERLESS_BASE_URL") paperlessAPIToken = os.Getenv("PAPERLESS_API_TOKEN") + azureDocAIEndpoint = os.Getenv("AZURE_DOCAI_ENDPOINT") + azureDocAIKey = os.Getenv("AZURE_DOCAI_KEY") + azureDocAIModelID = os.Getenv("AZURE_DOCAI_MODEL_ID") + azureDocAITimeout = os.Getenv("AZURE_DOCAI_TIMEOUT_SECONDS") openaiAPIKey = os.Getenv("OPENAI_API_KEY") manualTag = os.Getenv("MANUAL_TAG") autoTag = os.Getenv("AUTO_TAG") @@ -167,6 +171,18 @@ func main() { GoogleProcessorID: os.Getenv("GOOGLE_PROCESSOR_ID"), VisionLLMProvider: visionLlmProvider, VisionLLMModel: visionLlmModel, + AzureEndpoint: azureDocAIEndpoint, + AzureAPIKey: azureDocAIKey, + AzureModelID: azureDocAIModelID, + } + + // Parse Azure timeout if set + if azureDocAITimeout != "" { + if timeout, err := strconv.Atoi(azureDocAITimeout); err == nil { + ocrConfig.AzureTimeout = timeout + } else { + log.Warnf("Invalid AZURE_DOCAI_TIMEOUT_SECONDS value: %v, using default", err) + } } // If provider is LLM, but no VISION_LLM_PROVIDER is set, don't initialize OCR provider @@ -422,6 +438,17 @@ func validateOrDefaultEnvVars() { 
log.Fatal("Please set the LLM_PROVIDER environment variable to 'openai' or 'ollama'.") } + // Validate OCR provider if set + ocrProvider := os.Getenv("OCR_PROVIDER") + if ocrProvider == "azure" { + if azureDocAIEndpoint == "" { + log.Fatal("Please set the AZURE_DOCAI_ENDPOINT environment variable for Azure provider") + } + if azureDocAIKey == "" { + log.Fatal("Please set the AZURE_DOCAI_KEY environment variable for Azure provider") + } + } + if llmModel == "" { log.Fatal("Please set the LLM_MODEL environment variable.") } diff --git a/ocr/azure_provider.go b/ocr/azure_provider.go new file mode 100644 index 0000000..426937f --- /dev/null +++ b/ocr/azure_provider.go @@ -0,0 +1,224 @@ +package ocr + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/gabriel-vasile/mimetype" + "github.com/hashicorp/go-retryablehttp" + "github.com/sirupsen/logrus" +) + +const ( + apiVersion = "2024-11-30" + defaultModelID = "prebuilt-read" + defaultTimeout = 120 + pollingInterval = 2 * time.Second +) + +// AzureProvider implements OCR using Azure Document Intelligence +type AzureProvider struct { + endpoint string + apiKey string + modelID string + timeout time.Duration + httpClient *retryablehttp.Client +} + +// Request body for Azure Document Intelligence +type analyzeRequest struct { + Base64Source string `json:"base64Source"` +} + +func newAzureProvider(config Config) (*AzureProvider, error) { + logger := log.WithFields(logrus.Fields{ + "endpoint": config.AzureEndpoint, + "model_id": config.AzureModelID, + }) + logger.Info("Creating new Azure Document Intelligence provider") + + // Validate required configuration + if config.AzureEndpoint == "" || config.AzureAPIKey == "" { + logger.Error("Missing required configuration") + return nil, fmt.Errorf("missing required Azure Document Intelligence configuration") + } + + // Set defaults and create provider + modelID := defaultModelID + if config.AzureModelID != "" { + modelID = config.AzureModelID + } + + timeout := defaultTimeout + if config.AzureTimeout > 0 { + timeout = config.AzureTimeout + } + + // Configure retryablehttp client + client := retryablehttp.NewClient() + client.RetryMax = 3 + client.RetryWaitMin = 1 * time.Second + client.RetryWaitMax = 5 * time.Second + client.Logger = logger + + provider := &AzureProvider{ + endpoint: config.AzureEndpoint, + apiKey: config.AzureAPIKey, + modelID: modelID, + timeout: time.Duration(timeout) * time.Second, + httpClient: client, + } + + logger.Info("Successfully initialized Azure Document Intelligence provider") + return provider, nil +} + +func (p *AzureProvider) ProcessImage(ctx context.Context, imageContent []byte) (*OCRResult, error) { + logger := log.WithFields(logrus.Fields{ + "model_id": p.modelID, + }) + logger.Debug("Starting Azure Document Intelligence processing") + + // Detect MIME type + mtype := mimetype.Detect(imageContent) + logger.WithField("mime_type", mtype.String()).Debug("Detected file type") + + if !isImageMIMEType(mtype.String()) { + logger.WithField("mime_type", mtype.String()).Error("Unsupported file type") + return nil, fmt.Errorf("unsupported file type: %s", mtype.String()) + } + + // Create context with timeout + ctx, cancel := context.WithTimeout(ctx, p.timeout) + defer cancel() + + // Submit document for analysis + operationLocation, err := p.submitDocument(ctx, imageContent) + if err != nil { + return nil, fmt.Errorf("error submitting document: %w", err) + } + + // Poll for results + result, err := 
p.pollForResults(ctx, operationLocation)
+	if err != nil {
+		return nil, fmt.Errorf("error polling for results: %w", err)
+	}
+
+	// Convert to OCR result
+	ocrResult := &OCRResult{
+		Text: result.AnalyzeResult.Content,
+		Metadata: map[string]string{
+			"provider":    "azure_docai",
+			"page_count":  fmt.Sprintf("%d", len(result.AnalyzeResult.Pages)),
+			"api_version": result.AnalyzeResult.APIVersion,
+		},
+	}
+
+	logger.WithFields(logrus.Fields{
+		"content_length": len(ocrResult.Text),
+		"page_count":     len(result.AnalyzeResult.Pages),
+	}).Info("Successfully processed document")
+	return ocrResult, nil
+}
+
+func (p *AzureProvider) submitDocument(ctx context.Context, imageContent []byte) (string, error) {
+	requestURL := fmt.Sprintf("%s/documentintelligence/documentModels/%s:analyze?api-version=%s",
+		p.endpoint, p.modelID, apiVersion)
+
+	// Prepare request body
+	requestBody := analyzeRequest{
+		Base64Source: base64.StdEncoding.EncodeToString(imageContent),
+	}
+	requestBodyBytes, err := json.Marshal(requestBody)
+	if err != nil {
+		return "", fmt.Errorf("error marshaling request body: %w", err)
+	}
+
+	req, err := retryablehttp.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewBuffer(requestBodyBytes))
+	if err != nil {
+		return "", fmt.Errorf("error creating HTTP request: %w", err)
+	}
+
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Ocp-Apim-Subscription-Key", p.apiKey)
+
+	resp, err := p.httpClient.Do(req)
+	if err != nil {
+		return "", fmt.Errorf("error sending HTTP request: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusAccepted {
+		body, _ := io.ReadAll(resp.Body)
+		return "", fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, string(body))
+	}
+
+	operationLocation := resp.Header.Get("Operation-Location")
+	if operationLocation == "" {
+		return "", fmt.Errorf("no Operation-Location header in response")
+	}
+
+	return operationLocation, nil
+}
+
+func (p *AzureProvider) pollForResults(ctx context.Context, operationLocation string) (*AzureDocumentResult, error) {
+	logger := log.WithField("operation_location", operationLocation)
+	logger.Debug("Starting to poll for results")
+
+	ticker := time.NewTicker(pollingInterval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return nil, fmt.Errorf("operation timed out after %v: %w", p.timeout, ctx.Err())
+		case <-ticker.C:
+			req, err := retryablehttp.NewRequestWithContext(ctx, "GET", operationLocation, nil)
+			if err != nil {
+				return nil, fmt.Errorf("error creating poll request: %w", err)
+			}
+			req.Header.Set("Ocp-Apim-Subscription-Key", p.apiKey)
+
+			resp, err := p.httpClient.Do(req)
+			if err != nil {
+				return nil, fmt.Errorf("error polling for results: %w", err)
+			}
+
+			var result AzureDocumentResult
+			err = json.NewDecoder(resp.Body).Decode(&result)
+			// Close the body right away; a defer inside the polling loop would
+			// keep every response body open until the function returns.
+			resp.Body.Close()
+			if err != nil {
+				logger.WithError(err).Error("Failed to decode response")
+				return nil, fmt.Errorf("error decoding response: %w", err)
+			}
+
+			logger.WithFields(logrus.Fields{
+				"status_code":    resp.StatusCode,
+				"content_length": len(result.AnalyzeResult.Content),
+				"page_count":     len(result.AnalyzeResult.Pages),
+				"status":         result.Status,
+			}).Debug("Poll response received")
+
+			if resp.StatusCode != http.StatusOK {
+				return nil, fmt.Errorf("unexpected status code %d while polling", resp.StatusCode)
+			}
+
+			switch result.Status {
+			case "succeeded":
+				return &result, nil
+			case "failed":
+				return nil, fmt.Errorf("document processing failed")
+			case "running":
+				// 
Continue polling + default: + return nil, fmt.Errorf("unexpected status: %s", result.Status) + } + } + } +} diff --git a/ocr/azure_provider_test.go b/ocr/azure_provider_test.go new file mode 100644 index 0000000..56d6743 --- /dev/null +++ b/ocr/azure_provider_test.go @@ -0,0 +1,222 @@ +package ocr + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/hashicorp/go-retryablehttp" + "github.com/stretchr/testify/assert" +) + +func TestNewAzureProvider(t *testing.T) { + tests := []struct { + name string + config Config + wantErr bool + errContains string + }{ + { + name: "valid config", + config: Config{ + AzureEndpoint: "https://test.cognitiveservices.azure.com/", + AzureAPIKey: "test-key", + }, + wantErr: false, + }, + { + name: "valid config with custom model and timeout", + config: Config{ + AzureEndpoint: "https://test.cognitiveservices.azure.com/", + AzureAPIKey: "test-key", + AzureModelID: "custom-model", + AzureTimeout: 60, + }, + wantErr: false, + }, + { + name: "missing endpoint", + config: Config{ + AzureAPIKey: "test-key", + }, + wantErr: true, + errContains: "missing required Azure Document Intelligence configuration", + }, + { + name: "missing api key", + config: Config{ + AzureEndpoint: "https://test.cognitiveservices.azure.com/", + }, + wantErr: true, + errContains: "missing required Azure Document Intelligence configuration", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := newAzureProvider(tt.config) + if tt.wantErr { + assert.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + assert.NoError(t, err) + assert.NotNil(t, provider) + + // Verify default values + if tt.config.AzureModelID == "" { + assert.Equal(t, defaultModelID, provider.modelID) + } else { + assert.Equal(t, tt.config.AzureModelID, provider.modelID) + } + + if tt.config.AzureTimeout == 0 { + assert.Equal(t, time.Duration(defaultTimeout)*time.Second, provider.timeout) + } else { + assert.Equal(t, time.Duration(tt.config.AzureTimeout)*time.Second, provider.timeout) + } + }) + } +} + +func TestAzureProvider_ProcessImage(t *testing.T) { + // Sample success response + now := time.Now() + successResult := AzureDocumentResult{ + Status: "succeeded", + CreatedDateTime: now, + LastUpdatedDateTime: now, + AnalyzeResult: AzureAnalyzeResult{ + APIVersion: apiVersion, + ModelID: defaultModelID, + StringIndexType: "utf-16", + Content: "Test document content", + Pages: []AzurePage{ + { + PageNumber: 1, + Angle: 0.0, + Width: 800, + Height: 600, + Unit: "pixel", + Lines: []AzureLine{ + { + Content: "Test line", + Polygon: []int{0, 0, 100, 0, 100, 20, 0, 20}, + Spans: []AzureSpan{{Offset: 0, Length: 9}}, + }, + }, + Spans: []AzureSpan{{Offset: 0, Length: 9}}, + }, + }, + Paragraphs: []AzureParagraph{ + { + Content: "Test document content", + Spans: []AzureSpan{{Offset: 0, Length: 19}}, + BoundingRegions: []AzureBoundingBox{ + { + PageNumber: 1, + Polygon: []int{0, 0, 100, 0, 100, 20, 0, 20}, + }, + }, + }, + }, + ContentFormat: "text", + }, + } + + tests := []struct { + name string + setupServer func() *httptest.Server + imageContent []byte + wantErr bool + errContains string + expectedText string + }{ + { + name: "successful processing", + setupServer: func() *httptest.Server { + mux := http.NewServeMux() + server := httptest.NewServer(mux) + + mux.HandleFunc("/documentintelligence/documentModels/prebuilt-read:analyze", func(w http.ResponseWriter, r 
*http.Request) { + w.Header().Set("Operation-Location", fmt.Sprintf("%s/operations/123", server.URL)) + w.WriteHeader(http.StatusAccepted) + }) + + mux.HandleFunc("/operations/123", func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(successResult) + }) + + return server + }, + // Create minimal JPEG content with magic numbers + imageContent: append([]byte{0xFF, 0xD8, 0xFF, 0xE0}, []byte("JFIF test content")...), + expectedText: "Test document content", + }, + { + name: "invalid mime type", + setupServer: func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Log("Server should not be called with invalid mime type") + w.WriteHeader(http.StatusBadRequest) + })) + }, + imageContent: []byte("invalid content"), + wantErr: true, + errContains: "unsupported file type", + }, + { + name: "submission error", + setupServer: func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintln(w, "Invalid request") + })) + }, + imageContent: []byte{0xFF, 0xD8, 0xFF, 0xE0}, // JPEG magic numbers + wantErr: true, + errContains: "unexpected status code 400", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := tt.setupServer() + defer server.Close() + + client := retryablehttp.NewClient() + client.HTTPClient = server.Client() + client.Logger = log + + provider := &AzureProvider{ + endpoint: server.URL, + apiKey: "test-key", + modelID: defaultModelID, + timeout: 5 * time.Second, + httpClient: client, + } + + result, err := provider.ProcessImage(context.Background(), tt.imageContent) + if tt.wantErr { + assert.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, tt.expectedText, result.Text) + assert.Equal(t, "azure_docai", result.Metadata["provider"]) + assert.Equal(t, apiVersion, result.Metadata["api_version"]) + assert.Equal(t, "1", result.Metadata["page_count"]) + }) + } +} diff --git a/ocr/azure_types.go b/ocr/azure_types.go new file mode 100644 index 0000000..71a83ad --- /dev/null +++ b/ocr/azure_types.go @@ -0,0 +1,72 @@ +package ocr + +import "time" + +// AzureDocumentResult represents the root response from Azure Document Intelligence +type AzureDocumentResult struct { + Status string `json:"status"` + CreatedDateTime time.Time `json:"createdDateTime"` + LastUpdatedDateTime time.Time `json:"lastUpdatedDateTime"` + AnalyzeResult AzureAnalyzeResult `json:"analyzeResult"` +} + +// AzureAnalyzeResult represents the analyze result part of the Azure Document Intelligence response +type AzureAnalyzeResult struct { + APIVersion string `json:"apiVersion"` + ModelID string `json:"modelId"` + StringIndexType string `json:"stringIndexType"` + Content string `json:"content"` + Pages []AzurePage `json:"pages"` + Paragraphs []AzureParagraph `json:"paragraphs"` + Styles []interface{} `json:"styles"` + ContentFormat string `json:"contentFormat"` +} + +// AzurePage represents a single page in the document +type AzurePage struct { + PageNumber int `json:"pageNumber"` + Angle float64 `json:"angle"` + Width int `json:"width"` + Height int `json:"height"` + Unit string `json:"unit"` + Words []AzureWord `json:"words"` + Lines []AzureLine `json:"lines"` + Spans []AzureSpan `json:"spans"` +} + +// AzureWord represents a single word with its properties +type AzureWord 
struct {
+	Content    string    `json:"content"`
+	Polygon    []int     `json:"polygon"`
+	Confidence float64   `json:"confidence"`
+	Span       AzureSpan `json:"span"`
+}
+
+// AzureLine represents a line of text
+type AzureLine struct {
+	Content string      `json:"content"`
+	Polygon []int       `json:"polygon"`
+	Spans   []AzureSpan `json:"spans"`
+}
+
+// AzureSpan represents a span of text with offset and length
+type AzureSpan struct {
+	Offset int `json:"offset"`
+	Length int `json:"length"`
+}
+
+// AzureParagraph represents a paragraph of text
+type AzureParagraph struct {
+	Content         string             `json:"content"`
+	Spans           []AzureSpan        `json:"spans"`
+	BoundingRegions []AzureBoundingBox `json:"boundingRegions"`
+}
+
+// AzureBoundingBox represents the location of content on a page
+type AzureBoundingBox struct {
+	PageNumber int   `json:"pageNumber"`
+	Polygon    []int `json:"polygon"`
+}
+
+// AzureStyle represents style information for text segments; modeled as an
+// empty interface because the style structure is not consumed yet.
+type AzureStyle interface{}
diff --git a/ocr/provider.go b/ocr/provider.go
index cfd3b62..1fc9873 100644
--- a/ocr/provider.go
+++ b/ocr/provider.go
@@ -28,7 +28,7 @@ type Provider interface {
 
 // Config holds the OCR provider configuration
 type Config struct {
-	// Provider type (e.g., "llm", "google_docai")
+	// Provider type (e.g., "llm", "google_docai", "azure")
 	Provider string
 
 	// Google Document AI settings
@@ -40,6 +40,12 @@ type Config struct {
 	VisionLLMProvider string
 	VisionLLMModel    string
 
+	// Azure Document Intelligence settings
+	AzureEndpoint string
+	AzureAPIKey   string
+	AzureModelID  string // Optional, defaults to "prebuilt-read"
+	AzureTimeout  int    // Optional, defaults to 120 seconds
+
 	// OCR output options
 	EnableHOCR bool // Whether to request hOCR output if supported by the provider
 }
@@ -69,6 +75,12 @@ func NewProvider(config Config) (Provider, error) {
 		}).Info("Using LLM OCR provider")
 		return newLLMProvider(config)
 
+	case "azure":
+		if config.AzureEndpoint == "" || config.AzureAPIKey == "" {
+			return nil, fmt.Errorf("missing required Azure Document Intelligence configuration")
+		}
+		return newAzureProvider(config)
+
 	default:
 		return nil, fmt.Errorf("unsupported OCR provider: %s", config.Provider)
 	}
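
For reference, here is a minimal sketch of how a caller could wire the new Azure provider through the `NewProvider` factory above. The import path is a placeholder, and it assumes `ProcessImage` is part of the `Provider` interface (as the Azure implementation suggests); `Config`, `NewProvider`, and `OCRResult` are as defined in `ocr/provider.go` and `ocr/azure_provider.go`:

```go
package main

import (
	"context"
	"fmt"
	"os"

	"example.com/paperless-gpt/ocr" // placeholder; use the real module path
)

func main() {
	// Endpoint and key come from the same environment variables main.go reads.
	provider, err := ocr.NewProvider(ocr.Config{
		Provider:      "azure",
		AzureEndpoint: os.Getenv("AZURE_DOCAI_ENDPOINT"),
		AzureAPIKey:   os.Getenv("AZURE_DOCAI_KEY"),
		// AzureModelID and AzureTimeout are optional; newAzureProvider
		// defaults them to "prebuilt-read" and 120 seconds.
	})
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}

	img, err := os.ReadFile("scan.jpg") // any supported page image
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}

	result, err := provider.ProcessImage(context.Background(), img)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	fmt.Println(result.Text)
}
```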