feat(ocr): add support for Azure Document Intelligence provider

This commit is contained in:
Dominik Schröter 2025-03-10 10:41:37 +01:00
parent 360663b05b
commit 2a16d9d95c
8 changed files with 658 additions and 21 deletions

112
README.md
View file

@ -22,7 +22,7 @@ https://github.com/user-attachments/assets/bd5d38b9-9309-40b9-93ca-918dfa4f3fd4
- **LLM OCR**: Use OpenAI or Ollama to extract text from images. - **LLM OCR**: Use OpenAI or Ollama to extract text from images.
- **Google Document AI**: Leverage Google's powerful Document AI for OCR tasks. - **Google Document AI**: Leverage Google's powerful Document AI for OCR tasks.
- **More to come**: Stay tuned for more OCR providers! - **Azure Document Intelligence**: Use Microsoft's enterprise OCR solution.
3. **Automatic Title & Tag Generation** 3. **Automatic Title & Tag Generation**
No more guesswork. Let the AI do the naming and categorizing. You can easily review suggestions and refine them if needed. No more guesswork. Let the AI do the naming and categorizing. You can easily review suggestions and refine them if needed.
@ -39,11 +39,11 @@ https://github.com/user-attachments/assets/bd5d38b9-9309-40b9-93ca-918dfa4f3fd4
- **Tagging**: Decide how documents get tagged—manually, automatically, or via OCR-based flows. - **Tagging**: Decide how documents get tagged—manually, automatically, or via OCR-based flows.
7. **Simple Docker Deployment** 7. **Simple Docker Deployment**
A few environment variables, and youre off! Compose it alongside paperless-ngx with minimal fuss. A few environment variables, and you're off! Compose it alongside paperless-ngx with minimal fuss.
8. **Unified Web UI** 8. **Unified Web UI**
- **Manual Review**: Approve or tweak AIs suggestions. - **Manual Review**: Approve or tweak AI's suggestions.
- **Auto Processing**: Focus only on edge cases while the rest is sorted for you. - **Auto Processing**: Focus only on edge cases while the rest is sorted for you.
--- ---
@ -56,6 +56,12 @@ https://github.com/user-attachments/assets/bd5d38b9-9309-40b9-93ca-918dfa4f3fd4
- [Installation](#installation) - [Installation](#installation)
- [Docker Compose](#docker-compose) - [Docker Compose](#docker-compose)
- [Manual Setup](#manual-setup) - [Manual Setup](#manual-setup)
- [OCR Providers](#ocr-providers)
- [LLM-based OCR](#1-llm-based-ocr-default)
- [Azure Document Intelligence](#2-azure-document-intelligence)
- [Google Document AI](#3-google-document-ai)
- [Comparing OCR Providers](#comparing-ocr-providers)
- [Choosing the Right Provider](#choosing-the-right-provider)
- [Configuration](#configuration) - [Configuration](#configuration)
- [Environment Variables](#environment-variables) - [Environment Variables](#environment-variables)
- [Custom Prompt Templates](#custom-prompt-templates) - [Custom Prompt Templates](#custom-prompt-templates)
@ -86,7 +92,7 @@ https://github.com/user-attachments/assets/bd5d38b9-9309-40b9-93ca-918dfa4f3fd4
#### Docker Compose #### Docker Compose
Heres an example `docker-compose.yml` to spin up **paperless-gpt** alongside paperless-ngx: Here's an example `docker-compose.yml` to spin up **paperless-gpt** alongside paperless-ngx:
```yaml ```yaml
services: services:
@ -124,6 +130,13 @@ services:
# GOOGLE_PROCESSOR_ID: 'processor-id' # Your processor ID # GOOGLE_PROCESSOR_ID: 'processor-id' # Your processor ID
# GOOGLE_APPLICATION_CREDENTIALS: '/app/credentials.json' # Path to service account key # GOOGLE_APPLICATION_CREDENTIALS: '/app/credentials.json' # Path to service account key
# Option 3: Azure Document Intelligence
# OCR_PROVIDER: 'azure' # Use Azure Document Intelligence
# AZURE_DOCAI_ENDPOINT: 'your-endpoint' # Your Azure endpoint URL
# AZURE_DOCAI_KEY: 'your-key' # Your Azure API key
# AZURE_DOCAI_MODEL_ID: 'prebuilt-read' # Optional, defaults to prebuilt-read
# AZURE_DOCAI_TIMEOUT_SECONDS: '120' # Optional, defaults to 120 seconds
AUTO_OCR_TAG: "paperless-gpt-ocr-auto" # Optional, default: paperless-gpt-ocr-auto AUTO_OCR_TAG: "paperless-gpt-ocr-auto" # Optional, default: paperless-gpt-ocr-auto
OCR_LIMIT_PAGES: "5" # Optional, default: 5. Set to 0 for no limit. OCR_LIMIT_PAGES: "5" # Optional, default: 5. Set to 0 for no limit.
LOG_LEVEL: "info" # Optional: debug, warn, error LOG_LEVEL: "info" # Optional: debug, warn, error
@ -172,6 +185,63 @@ services:
``` ```
--- ---
## OCR Providers
paperless-gpt supports three different OCR providers, each with unique strengths and capabilities:
### 1. LLM-based OCR (Default)
- **Key Features**:
- Uses vision-capable LLMs like GPT-4V or MiniCPM-V
- High accuracy with complex layouts and difficult scans
- Context-aware text recognition
- Self-correcting capabilities for OCR errors
- **Best For**:
- Complex or unusual document layouts
- Poor quality scans
- Documents with mixed languages
- **Configuration**:
```yaml
OCR_PROVIDER: "llm"
VISION_LLM_PROVIDER: "openai" # or "ollama"
VISION_LLM_MODEL: "gpt-4v" # or "minicpm-v"
```
### 2. Azure Document Intelligence
- **Key Features**:
- Enterprise-grade OCR solution
- Prebuilt models for common document types
- Layout preservation and table detection
- Fast processing speeds
- **Best For**:
- Business documents and forms
- High-volume processing
- Documents requiring layout analysis
- **Configuration**:
```yaml
OCR_PROVIDER: "azure"
AZURE_DOCAI_ENDPOINT: "https://your-endpoint.cognitiveservices.azure.com/"
AZURE_DOCAI_KEY: "your-key"
AZURE_DOCAI_MODEL_ID: "prebuilt-read" # optional
AZURE_DOCAI_TIMEOUT_SECONDS: "120" # optional
```
### 3. Google Document AI
- **Key Features**:
- Specialized document processors
- Strong form field detection
- Multi-language support
- High accuracy on structured documents
- **Best For**:
- Forms and structured documents
- Documents with tables
- Multi-language documents
- **Configuration**:
```yaml
OCR_PROVIDER: "google_docai"
GOOGLE_PROJECT_ID: "your-project"
GOOGLE_LOCATION: "us"
GOOGLE_PROCESSOR_ID: "processor-id"
```
## Configuration ## Configuration
@ -192,9 +262,13 @@ services:
| `OPENAI_BASE_URL` | OpenAI base URL (optional, if using a custom OpenAI compatible service like LiteLLM). | No | | | `OPENAI_BASE_URL` | OpenAI base URL (optional, if using a custom OpenAI compatible service like LiteLLM). | No | |
| `LLM_LANGUAGE` | Likely language for documents (e.g. `English`). | No | English | | `LLM_LANGUAGE` | Likely language for documents (e.g. `English`). | No | English |
| `OLLAMA_HOST` | Ollama server URL (e.g. `http://host.docker.internal:11434`). | No | | | `OLLAMA_HOST` | Ollama server URL (e.g. `http://host.docker.internal:11434`). | No | |
| `OCR_PROVIDER` | OCR provider to use (`llm` or `google_docai`). | No | llm | | `OCR_PROVIDER` | OCR provider to use (`llm`, `azure`, or `google_docai`). | No | llm |
| `VISION_LLM_PROVIDER` | AI backend for LLM OCR (`openai` or `ollama`). Required if OCR_PROVIDER is `llm`. | Cond. | | | `VISION_LLM_PROVIDER` | AI backend for LLM OCR (`openai` or `ollama`). Required if OCR_PROVIDER is `llm`. | Cond. | |
| `VISION_LLM_MODEL` | Model name for LLM OCR (e.g. `minicpm-v`). Required if OCR_PROVIDER is `llm`. | Cond. | | | `VISION_LLM_MODEL` | Model name for LLM OCR (e.g. `minicpm-v`). Required if OCR_PROVIDER is `llm`. | Cond. | |
| `AZURE_DOCAI_ENDPOINT` | Azure Document Intelligence endpoint. Required if OCR_PROVIDER is `azure`. | Cond. | |
| `AZURE_DOCAI_KEY` | Azure Document Intelligence API key. Required if OCR_PROVIDER is `azure`. | Cond. | |
| `AZURE_DOCAI_MODEL_ID` | Azure Document Intelligence model ID. Optional if using `azure` provider. | No | prebuilt-read |
| `AZURE_DOCAI_TIMEOUT_SECONDS` | Azure Document Intelligence timeout in seconds. | No | 120 |
| `GOOGLE_PROJECT_ID` | Google Cloud project ID. Required if OCR_PROVIDER is `google_docai`. | Cond. | | | `GOOGLE_PROJECT_ID` | Google Cloud project ID. Required if OCR_PROVIDER is `google_docai`. | Cond. | |
| `GOOGLE_LOCATION` | Google Cloud region (e.g. `us`, `eu`). Required if OCR_PROVIDER is `google_docai`. | Cond. | | | `GOOGLE_LOCATION` | Google Cloud region (e.g. `us`, `eu`). Required if OCR_PROVIDER is `google_docai`. | Cond. | |
| `GOOGLE_PROCESSOR_ID` | Document AI processor ID. Required if OCR_PROVIDER is `google_docai`. | Cond. | | | `GOOGLE_PROCESSOR_ID` | Document AI processor ID. Required if OCR_PROVIDER is `google_docai`. | Cond. | |
@ -211,7 +285,7 @@ services:
### Custom Prompt Templates ### Custom Prompt Templates
paperless-gpts flexible **prompt templates** let you shape how AI responds: paperless-gpt's flexible **prompt templates** let you shape how AI responds:
1. **`title_prompt.tmpl`**: For document titles. 1. **`title_prompt.tmpl`**: For document titles.
2. **`tag_prompt.tmpl`**: For tagging logic. 2. **`tag_prompt.tmpl`**: For tagging logic.
@ -232,13 +306,11 @@ Then tweak at will—**paperless-gpt** reloads them automatically on startup!
Each template has access to specific variables: Each template has access to specific variables:
**title_prompt.tmpl**: **title_prompt.tmpl**:
- `{{.Language}}` - Target language (e.g., "English") - `{{.Language}}` - Target language (e.g., "English")
- `{{.Content}}` - Document content text - `{{.Content}}` - Document content text
- `{{.Title}}` - Original document title - `{{.Title}}` - Original document title
**tag_prompt.tmpl**: **tag_prompt.tmpl**:
- `{{.Language}}` - Target language - `{{.Language}}` - Target language
- `{{.AvailableTags}}` - List of existing tags in paperless-ngx - `{{.AvailableTags}}` - List of existing tags in paperless-ngx
- `{{.OriginalTags}}` - Document's current tags - `{{.OriginalTags}}` - Document's current tags
@ -246,11 +318,9 @@ Each template has access to specific variables:
- `{{.Content}}` - Document content text - `{{.Content}}` - Document content text
**ocr_prompt.tmpl**: **ocr_prompt.tmpl**:
- `{{.Language}}` - Target language - `{{.Language}}` - Target language
**correspondent_prompt.tmpl**: **correspondent_prompt.tmpl**:
- `{{.Language}}` - Target language - `{{.Language}}` - Target language
- `{{.AvailableCorrespondents}}` - List of existing correspondents - `{{.AvailableCorrespondents}}` - List of existing correspondents
- `{{.BlackList}}` - List of blacklisted correspondent names - `{{.BlackList}}` - List of blacklisted correspondent names
@ -265,23 +335,25 @@ The templates use Go's text/template syntax. paperless-gpt automatically reloads
1. **Tag Documents** 1. **Tag Documents**
- Add `paperless-gpt` or your custom tag to the docs you want to AI-ify. - Add `paperless-gpt` tag to documents for manual processing
- Add `paperless-gpt-auto` for automatic processing
- Add `paperless-gpt-ocr-auto` for automatic OCR processing
2. **Visit Web UI** 2. **Visit Web UI**
- Go to `http://localhost:8080` (or your host) in your browser. - Go to `http://localhost:8080` (or your host) in your browser
- Review documents tagged for processing
3. **Generate & Apply Suggestions** 3. **Generate & Apply Suggestions**
- Click “Generate Suggestions” to see AI-proposed titles/tags/correspondents. - Click "Generate Suggestions" to see AI-proposed titles/tags/correspondents
- Approve, edit, or discard. Hit “Apply” to finalize in paperless-ngx. - Review and approve or edit suggestions
- Click "Apply" to save changes to paperless-ngx
4. **Try LLM-Based OCR (Experimental)**
- If you enabled `VISION_LLM_PROVIDER` and `VISION_LLM_MODEL`, let AI-based OCR read your scanned PDFs.
- Tag those documents with `paperless-gpt-ocr-auto` (or your custom `AUTO_OCR_TAG`).
**Tip**: The entire pipeline can be **fully automated** if you prefer minimal manual intervention.
4. **OCR Processing**
- Tag documents with appropriate OCR tag to process them
- Monitor progress in the Web UI
- Review results and apply changes
--- ---
## LLM-Based OCR: Compare for Yourself ## LLM-Based OCR: Compare for Yourself

2
go.mod
View file

@ -12,6 +12,7 @@ require (
github.com/gen2brain/go-fitz v1.24.14 github.com/gen2brain/go-fitz v1.24.14
github.com/gin-gonic/gin v1.10.0 github.com/gin-gonic/gin v1.10.0
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/hashicorp/go-retryablehttp v0.7.7
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.10.0 github.com/stretchr/testify v1.10.0
github.com/tmc/langchaingo v0.1.13 github.com/tmc/langchaingo v0.1.13
@ -48,6 +49,7 @@ require (
github.com/google/s2a-go v0.1.9 // indirect github.com/google/s2a-go v0.1.9 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
github.com/googleapis/gax-go/v2 v2.14.1 // indirect github.com/googleapis/gax-go/v2 v2.14.1 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/huandu/xstrings v1.5.0 // indirect github.com/huandu/xstrings v1.5.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect

6
go.sum
View file

@ -75,6 +75,12 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gT
github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA= github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA=
github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrkurSS/Q= github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrkurSS/Q=
github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA= github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA=
github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ=
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k=
github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M=
github.com/hashicorp/go-retryablehttp v0.7.7 h1:C8hUCYzor8PIfXHa4UrZkU4VvK8o9ISHxT2Q8+VepXU=
github.com/hashicorp/go-retryablehttp v0.7.7/go.mod h1:pkQpWZeYWskR+D1tR2O5OcBFOxfA7DoAO6xtkuQnHTk=
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=

27
main.go
View file

@ -36,6 +36,10 @@ var (
correspondentBlackList = strings.Split(os.Getenv("CORRESPONDENT_BLACK_LIST"), ",") correspondentBlackList = strings.Split(os.Getenv("CORRESPONDENT_BLACK_LIST"), ",")
paperlessBaseURL = os.Getenv("PAPERLESS_BASE_URL") paperlessBaseURL = os.Getenv("PAPERLESS_BASE_URL")
paperlessAPIToken = os.Getenv("PAPERLESS_API_TOKEN") paperlessAPIToken = os.Getenv("PAPERLESS_API_TOKEN")
azureDocAIEndpoint = os.Getenv("AZURE_DOCAI_ENDPOINT")
azureDocAIKey = os.Getenv("AZURE_DOCAI_KEY")
azureDocAIModelID = os.Getenv("AZURE_DOCAI_MODEL_ID")
azureDocAITimeout = os.Getenv("AZURE_DOCAI_TIMEOUT_SECONDS")
openaiAPIKey = os.Getenv("OPENAI_API_KEY") openaiAPIKey = os.Getenv("OPENAI_API_KEY")
manualTag = os.Getenv("MANUAL_TAG") manualTag = os.Getenv("MANUAL_TAG")
autoTag = os.Getenv("AUTO_TAG") autoTag = os.Getenv("AUTO_TAG")
@ -167,6 +171,18 @@ func main() {
GoogleProcessorID: os.Getenv("GOOGLE_PROCESSOR_ID"), GoogleProcessorID: os.Getenv("GOOGLE_PROCESSOR_ID"),
VisionLLMProvider: visionLlmProvider, VisionLLMProvider: visionLlmProvider,
VisionLLMModel: visionLlmModel, VisionLLMModel: visionLlmModel,
AzureEndpoint: azureDocAIEndpoint,
AzureAPIKey: azureDocAIKey,
AzureModelID: azureDocAIModelID,
}
// Parse Azure timeout if set
if azureDocAITimeout != "" {
if timeout, err := strconv.Atoi(azureDocAITimeout); err == nil {
ocrConfig.AzureTimeout = timeout
} else {
log.Warnf("Invalid AZURE_DOCAI_TIMEOUT_SECONDS value: %v, using default", err)
}
} }
// If provider is LLM, but no VISION_LLM_PROVIDER is set, don't initialize OCR provider // If provider is LLM, but no VISION_LLM_PROVIDER is set, don't initialize OCR provider
@ -422,6 +438,17 @@ func validateOrDefaultEnvVars() {
log.Fatal("Please set the LLM_PROVIDER environment variable to 'openai' or 'ollama'.") log.Fatal("Please set the LLM_PROVIDER environment variable to 'openai' or 'ollama'.")
} }
// Validate OCR provider if set
ocrProvider := os.Getenv("OCR_PROVIDER")
if ocrProvider == "azure" {
if azureDocAIEndpoint == "" {
log.Fatal("Please set the AZURE_DOCAI_ENDPOINT environment variable for Azure provider")
}
if azureDocAIKey == "" {
log.Fatal("Please set the AZURE_DOCAI_KEY environment variable for Azure provider")
}
}
if llmModel == "" { if llmModel == "" {
log.Fatal("Please set the LLM_MODEL environment variable.") log.Fatal("Please set the LLM_MODEL environment variable.")
} }

224
ocr/azure_provider.go Normal file
View file

@ -0,0 +1,224 @@
package ocr
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/gabriel-vasile/mimetype"
"github.com/hashicorp/go-retryablehttp"
"github.com/sirupsen/logrus"
)
const (
apiVersion = "2024-11-30"
defaultModelID = "prebuilt-read"
defaultTimeout = 120
pollingInterval = 2 * time.Second
)
// AzureProvider implements OCR using Azure Document Intelligence
type AzureProvider struct {
endpoint string
apiKey string
modelID string
timeout time.Duration
httpClient *retryablehttp.Client
}
// Request body for Azure Document Intelligence
type analyzeRequest struct {
Base64Source string `json:"base64Source"`
}
func newAzureProvider(config Config) (*AzureProvider, error) {
logger := log.WithFields(logrus.Fields{
"endpoint": config.AzureEndpoint,
"model_id": config.AzureModelID,
})
logger.Info("Creating new Azure Document Intelligence provider")
// Validate required configuration
if config.AzureEndpoint == "" || config.AzureAPIKey == "" {
logger.Error("Missing required configuration")
return nil, fmt.Errorf("missing required Azure Document Intelligence configuration")
}
// Set defaults and create provider
modelID := defaultModelID
if config.AzureModelID != "" {
modelID = config.AzureModelID
}
timeout := defaultTimeout
if config.AzureTimeout > 0 {
timeout = config.AzureTimeout
}
// Configure retryablehttp client
client := retryablehttp.NewClient()
client.RetryMax = 3
client.RetryWaitMin = 1 * time.Second
client.RetryWaitMax = 5 * time.Second
client.Logger = logger
provider := &AzureProvider{
endpoint: config.AzureEndpoint,
apiKey: config.AzureAPIKey,
modelID: modelID,
timeout: time.Duration(timeout) * time.Second,
httpClient: client,
}
logger.Info("Successfully initialized Azure Document Intelligence provider")
return provider, nil
}
func (p *AzureProvider) ProcessImage(ctx context.Context, imageContent []byte) (*OCRResult, error) {
logger := log.WithFields(logrus.Fields{
"model_id": p.modelID,
})
logger.Debug("Starting Azure Document Intelligence processing")
// Detect MIME type
mtype := mimetype.Detect(imageContent)
logger.WithField("mime_type", mtype.String()).Debug("Detected file type")
if !isImageMIMEType(mtype.String()) {
logger.WithField("mime_type", mtype.String()).Error("Unsupported file type")
return nil, fmt.Errorf("unsupported file type: %s", mtype.String())
}
// Create context with timeout
ctx, cancel := context.WithTimeout(ctx, p.timeout)
defer cancel()
// Submit document for analysis
operationLocation, err := p.submitDocument(ctx, imageContent)
if err != nil {
return nil, fmt.Errorf("error submitting document: %w", err)
}
// Poll for results
result, err := p.pollForResults(ctx, operationLocation)
if err != nil {
return nil, fmt.Errorf("error polling for results: %w", err)
}
// Convert to OCR result
ocrResult := &OCRResult{
Text: result.AnalyzeResult.Content,
Metadata: map[string]string{
"provider": "azure_docai",
"page_count": fmt.Sprintf("%d", len(result.AnalyzeResult.Pages)),
"api_version": result.AnalyzeResult.APIVersion,
},
}
logger.WithFields(logrus.Fields{
"content_length": len(ocrResult.Text),
"page_count": len(result.AnalyzeResult.Pages),
}).Info("Successfully processed document")
return ocrResult, nil
}
func (p *AzureProvider) submitDocument(ctx context.Context, imageContent []byte) (string, error) {
requestURL := fmt.Sprintf("%s/documentintelligence/documentModels/%s:analyze?api-version=%s",
p.endpoint, p.modelID, apiVersion)
// Prepare request body
requestBody := analyzeRequest{
Base64Source: base64.StdEncoding.EncodeToString(imageContent),
}
requestBodyBytes, err := json.Marshal(requestBody)
if err != nil {
return "", fmt.Errorf("error marshaling request body: %w", err)
}
req, err := retryablehttp.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewBuffer(requestBodyBytes))
if err != nil {
return "", fmt.Errorf("error creating HTTP request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Ocp-Apim-Subscription-Key", p.apiKey)
resp, err := p.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("error sending HTTP request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusAccepted {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, string(body))
}
operationLocation := resp.Header.Get("Operation-Location")
if operationLocation == "" {
return "", fmt.Errorf("no Operation-Location header in response")
}
return operationLocation, nil
}
func (p *AzureProvider) pollForResults(ctx context.Context, operationLocation string) (*AzureDocumentResult, error) {
logger := log.WithField("operation_location", operationLocation)
logger.Debug("Starting to poll for results")
ticker := time.NewTicker(pollingInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return nil, fmt.Errorf("operation timed out after %v: %w", p.timeout, ctx.Err())
case <-ticker.C:
req, err := retryablehttp.NewRequestWithContext(ctx, "GET", operationLocation, nil)
if err != nil {
return nil, fmt.Errorf("error creating poll request: %w", err)
}
req.Header.Set("Ocp-Apim-Subscription-Key", p.apiKey)
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("error polling for results: %w", err)
}
var result AzureDocumentResult
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
resp.Body.Close()
logger.WithError(err).Error("Failed to decode response")
return nil, fmt.Errorf("error decoding response: %w", err)
}
defer resp.Body.Close()
logger.WithFields(logrus.Fields{
"status_code": resp.StatusCode,
"content_length": len(result.AnalyzeResult.Content),
"page_count": len(result.AnalyzeResult.Pages),
"status": result.Status,
}).Debug("Poll response received")
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code %d while polling", resp.StatusCode)
}
switch result.Status {
case "succeeded":
return &result, nil
case "failed":
return nil, fmt.Errorf("document processing failed")
case "running":
// Continue polling
default:
return nil, fmt.Errorf("unexpected status: %s", result.Status)
}
}
}
}

222
ocr/azure_provider_test.go Normal file
View file

@ -0,0 +1,222 @@
package ocr
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/hashicorp/go-retryablehttp"
"github.com/stretchr/testify/assert"
)
func TestNewAzureProvider(t *testing.T) {
tests := []struct {
name string
config Config
wantErr bool
errContains string
}{
{
name: "valid config",
config: Config{
AzureEndpoint: "https://test.cognitiveservices.azure.com/",
AzureAPIKey: "test-key",
},
wantErr: false,
},
{
name: "valid config with custom model and timeout",
config: Config{
AzureEndpoint: "https://test.cognitiveservices.azure.com/",
AzureAPIKey: "test-key",
AzureModelID: "custom-model",
AzureTimeout: 60,
},
wantErr: false,
},
{
name: "missing endpoint",
config: Config{
AzureAPIKey: "test-key",
},
wantErr: true,
errContains: "missing required Azure Document Intelligence configuration",
},
{
name: "missing api key",
config: Config{
AzureEndpoint: "https://test.cognitiveservices.azure.com/",
},
wantErr: true,
errContains: "missing required Azure Document Intelligence configuration",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider, err := newAzureProvider(tt.config)
if tt.wantErr {
assert.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
return
}
assert.NoError(t, err)
assert.NotNil(t, provider)
// Verify default values
if tt.config.AzureModelID == "" {
assert.Equal(t, defaultModelID, provider.modelID)
} else {
assert.Equal(t, tt.config.AzureModelID, provider.modelID)
}
if tt.config.AzureTimeout == 0 {
assert.Equal(t, time.Duration(defaultTimeout)*time.Second, provider.timeout)
} else {
assert.Equal(t, time.Duration(tt.config.AzureTimeout)*time.Second, provider.timeout)
}
})
}
}
func TestAzureProvider_ProcessImage(t *testing.T) {
// Sample success response
now := time.Now()
successResult := AzureDocumentResult{
Status: "succeeded",
CreatedDateTime: now,
LastUpdatedDateTime: now,
AnalyzeResult: AzureAnalyzeResult{
APIVersion: apiVersion,
ModelID: defaultModelID,
StringIndexType: "utf-16",
Content: "Test document content",
Pages: []AzurePage{
{
PageNumber: 1,
Angle: 0.0,
Width: 800,
Height: 600,
Unit: "pixel",
Lines: []AzureLine{
{
Content: "Test line",
Polygon: []int{0, 0, 100, 0, 100, 20, 0, 20},
Spans: []AzureSpan{{Offset: 0, Length: 9}},
},
},
Spans: []AzureSpan{{Offset: 0, Length: 9}},
},
},
Paragraphs: []AzureParagraph{
{
Content: "Test document content",
Spans: []AzureSpan{{Offset: 0, Length: 19}},
BoundingRegions: []AzureBoundingBox{
{
PageNumber: 1,
Polygon: []int{0, 0, 100, 0, 100, 20, 0, 20},
},
},
},
},
ContentFormat: "text",
},
}
tests := []struct {
name string
setupServer func() *httptest.Server
imageContent []byte
wantErr bool
errContains string
expectedText string
}{
{
name: "successful processing",
setupServer: func() *httptest.Server {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
mux.HandleFunc("/documentintelligence/documentModels/prebuilt-read:analyze", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Operation-Location", fmt.Sprintf("%s/operations/123", server.URL))
w.WriteHeader(http.StatusAccepted)
})
mux.HandleFunc("/operations/123", func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(successResult)
})
return server
},
// Create minimal JPEG content with magic numbers
imageContent: append([]byte{0xFF, 0xD8, 0xFF, 0xE0}, []byte("JFIF test content")...),
expectedText: "Test document content",
},
{
name: "invalid mime type",
setupServer: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Log("Server should not be called with invalid mime type")
w.WriteHeader(http.StatusBadRequest)
}))
},
imageContent: []byte("invalid content"),
wantErr: true,
errContains: "unsupported file type",
},
{
name: "submission error",
setupServer: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintln(w, "Invalid request")
}))
},
imageContent: []byte{0xFF, 0xD8, 0xFF, 0xE0}, // JPEG magic numbers
wantErr: true,
errContains: "unexpected status code 400",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := tt.setupServer()
defer server.Close()
client := retryablehttp.NewClient()
client.HTTPClient = server.Client()
client.Logger = log
provider := &AzureProvider{
endpoint: server.URL,
apiKey: "test-key",
modelID: defaultModelID,
timeout: 5 * time.Second,
httpClient: client,
}
result, err := provider.ProcessImage(context.Background(), tt.imageContent)
if tt.wantErr {
assert.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
return
}
assert.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, tt.expectedText, result.Text)
assert.Equal(t, "azure_docai", result.Metadata["provider"])
assert.Equal(t, apiVersion, result.Metadata["api_version"])
assert.Equal(t, "1", result.Metadata["page_count"])
})
}
}

72
ocr/azure_types.go Normal file
View file

@ -0,0 +1,72 @@
package ocr
import "time"
// AzureDocumentResult represents the root response from Azure Document Intelligence
type AzureDocumentResult struct {
Status string `json:"status"`
CreatedDateTime time.Time `json:"createdDateTime"`
LastUpdatedDateTime time.Time `json:"lastUpdatedDateTime"`
AnalyzeResult AzureAnalyzeResult `json:"analyzeResult"`
}
// AzureAnalyzeResult represents the analyze result part of the Azure Document Intelligence response
type AzureAnalyzeResult struct {
APIVersion string `json:"apiVersion"`
ModelID string `json:"modelId"`
StringIndexType string `json:"stringIndexType"`
Content string `json:"content"`
Pages []AzurePage `json:"pages"`
Paragraphs []AzureParagraph `json:"paragraphs"`
Styles []interface{} `json:"styles"`
ContentFormat string `json:"contentFormat"`
}
// AzurePage represents a single page in the document
type AzurePage struct {
PageNumber int `json:"pageNumber"`
Angle float64 `json:"angle"`
Width int `json:"width"`
Height int `json:"height"`
Unit string `json:"unit"`
Words []AzureWord `json:"words"`
Lines []AzureLine `json:"lines"`
Spans []AzureSpan `json:"spans"`
}
// AzureWord represents a single word with its properties
type AzureWord struct {
Content string `json:"content"`
Polygon []int `json:"polygon"`
Confidence float64 `json:"confidence"`
Span AzureSpan `json:"span"`
}
// AzureLine represents a line of text
type AzureLine struct {
Content string `json:"content"`
Polygon []int `json:"polygon"`
Spans []AzureSpan `json:"spans"`
}
// AzureSpan represents a span of text with offset and length
type AzureSpan struct {
Offset int `json:"offset"`
Length int `json:"length"`
}
// AzureParagraph represents a paragraph of text
type AzureParagraph struct {
Content string `json:"content"`
Spans []AzureSpan `json:"spans"`
BoundingRegions []AzureBoundingBox `json:"boundingRegions"`
}
// AzureBoundingBox represents the location of content on a page
type AzureBoundingBox struct {
PageNumber int `json:"pageNumber"`
Polygon []int `json:"polygon"`
}
// AzureStyle represents style information for text segments - changed to interface{} as per input
type AzureStyle interface{}

View file

@ -28,7 +28,7 @@ type Provider interface {
// Config holds the OCR provider configuration // Config holds the OCR provider configuration
type Config struct { type Config struct {
// Provider type (e.g., "llm", "google_docai") // Provider type (e.g., "llm", "google_docai", "azure")
Provider string Provider string
// Google Document AI settings // Google Document AI settings
@ -40,6 +40,12 @@ type Config struct {
VisionLLMProvider string VisionLLMProvider string
VisionLLMModel string VisionLLMModel string
// Azure Document Intelligence settings
AzureEndpoint string
AzureAPIKey string
AzureModelID string // Optional, defaults to "prebuilt-read"
AzureTimeout int // Optional, defaults to 120 seconds
// OCR output options // OCR output options
EnableHOCR bool // Whether to request hOCR output if supported by the provider EnableHOCR bool // Whether to request hOCR output if supported by the provider
} }
@ -69,6 +75,12 @@ func NewProvider(config Config) (Provider, error) {
}).Info("Using LLM OCR provider") }).Info("Using LLM OCR provider")
return newLLMProvider(config) return newLLMProvider(config)
case "azure":
if config.AzureEndpoint == "" || config.AzureAPIKey == "" {
return nil, fmt.Errorf("missing required Azure Document Intelligence configuration")
}
return newAzureProvider(config)
default: default:
return nil, fmt.Errorf("unsupported OCR provider: %s", config.Provider) return nil, fmt.Errorf("unsupported OCR provider: %s", config.Provider)
} }