Mirror of https://github.com/icereed/paperless-gpt.git (synced 2025-03-12 12:58:02 -05:00)

feat(ocr): add support for Azure Document Intelligence provider (#279)

Parent: 360663b05b
Commit: cbd9c5438c

8 changed files with 658 additions and 21 deletions
112
README.md
|
@ -22,7 +22,7 @@ https://github.com/user-attachments/assets/bd5d38b9-9309-40b9-93ca-918dfa4f3fd4
|
|||
|
||||
- **LLM OCR**: Use OpenAI or Ollama to extract text from images.
|
||||
- **Google Document AI**: Leverage Google's powerful Document AI for OCR tasks.
|
||||
- **More to come**: Stay tuned for more OCR providers!
|
||||
- **Azure Document Intelligence**: Use Microsoft's enterprise OCR solution.
|
||||
|
||||
3. **Automatic Title & Tag Generation**
|
||||
No more guesswork. Let the AI do the naming and categorizing. You can easily review suggestions and refine them if needed.
|
||||
|
@ -39,11 +39,11 @@ https://github.com/user-attachments/assets/bd5d38b9-9309-40b9-93ca-918dfa4f3fd4
|
|||
- **Tagging**: Decide how documents get tagged—manually, automatically, or via OCR-based flows.
|
||||
|
||||
7. **Simple Docker Deployment**
|
||||
A few environment variables, and you’re off! Compose it alongside paperless-ngx with minimal fuss.
|
||||
A few environment variables, and you're off! Compose it alongside paperless-ngx with minimal fuss.
|
||||
|
||||
8. **Unified Web UI**
|
||||
|
||||
- **Manual Review**: Approve or tweak AI’s suggestions.
|
||||
- **Manual Review**: Approve or tweak AI's suggestions.
|
||||
- **Auto Processing**: Focus only on edge cases while the rest is sorted for you.
|
||||
|
||||
---
|
||||
|
@ -56,6 +56,12 @@ https://github.com/user-attachments/assets/bd5d38b9-9309-40b9-93ca-918dfa4f3fd4
|
|||
- [Installation](#installation)
|
||||
- [Docker Compose](#docker-compose)
|
||||
- [Manual Setup](#manual-setup)
|
||||
- [OCR Providers](#ocr-providers)
|
||||
- [LLM-based OCR](#1-llm-based-ocr-default)
|
||||
- [Azure Document Intelligence](#2-azure-document-intelligence)
|
||||
- [Google Document AI](#3-google-document-ai)
|
||||
- [Comparing OCR Providers](#comparing-ocr-providers)
|
||||
- [Choosing the Right Provider](#choosing-the-right-provider)
|
||||
- [Configuration](#configuration)
|
||||
- [Environment Variables](#environment-variables)
|
||||
- [Custom Prompt Templates](#custom-prompt-templates)
|
||||
|
@ -86,7 +92,7 @@ https://github.com/user-attachments/assets/bd5d38b9-9309-40b9-93ca-918dfa4f3fd4
|
|||
|
||||
#### Docker Compose
|
||||
|
||||
Here’s an example `docker-compose.yml` to spin up **paperless-gpt** alongside paperless-ngx:
|
||||
Here's an example `docker-compose.yml` to spin up **paperless-gpt** alongside paperless-ngx:
|
||||
|
||||
```yaml
|
||||
services:
|
||||
|
@ -124,6 +130,13 @@ services:
|
|||
# GOOGLE_PROCESSOR_ID: 'processor-id' # Your processor ID
|
||||
# GOOGLE_APPLICATION_CREDENTIALS: '/app/credentials.json' # Path to service account key
|
||||
|
||||
# Option 3: Azure Document Intelligence
|
||||
# OCR_PROVIDER: 'azure' # Use Azure Document Intelligence
|
||||
# AZURE_DOCAI_ENDPOINT: 'your-endpoint' # Your Azure endpoint URL
|
||||
# AZURE_DOCAI_KEY: 'your-key' # Your Azure API key
|
||||
# AZURE_DOCAI_MODEL_ID: 'prebuilt-read' # Optional, defaults to prebuilt-read
|
||||
# AZURE_DOCAI_TIMEOUT_SECONDS: '120' # Optional, defaults to 120 seconds
|
||||
|
||||
AUTO_OCR_TAG: "paperless-gpt-ocr-auto" # Optional, default: paperless-gpt-ocr-auto
|
||||
OCR_LIMIT_PAGES: "5" # Optional, default: 5. Set to 0 for no limit.
|
||||
LOG_LEVEL: "info" # Optional: debug, warn, error
|
||||
|
@ -172,6 +185,63 @@ services:
|
|||
```
|
||||
|
||||
---
|
||||
## OCR Providers

paperless-gpt supports three different OCR providers, each with unique strengths and capabilities:

### 1. LLM-based OCR (Default)

- **Key Features**:
  - Uses vision-capable LLMs like GPT-4V or MiniCPM-V
  - High accuracy with complex layouts and difficult scans
  - Context-aware text recognition
  - Self-correcting capabilities for OCR errors
- **Best For**:
  - Complex or unusual document layouts
  - Poor quality scans
  - Documents with mixed languages
- **Configuration**:
  ```yaml
  OCR_PROVIDER: "llm"
  VISION_LLM_PROVIDER: "openai" # or "ollama"
  VISION_LLM_MODEL: "gpt-4v" # or "minicpm-v"
  ```

### 2. Azure Document Intelligence

- **Key Features**:
  - Enterprise-grade OCR solution
  - Prebuilt models for common document types
  - Layout preservation and table detection
  - Fast processing speeds
- **Best For**:
  - Business documents and forms
  - High-volume processing
  - Documents requiring layout analysis
- **Configuration**:
  ```yaml
  OCR_PROVIDER: "azure"
  AZURE_DOCAI_ENDPOINT: "https://your-endpoint.cognitiveservices.azure.com/"
  AZURE_DOCAI_KEY: "your-key"
  AZURE_DOCAI_MODEL_ID: "prebuilt-read" # optional
  AZURE_DOCAI_TIMEOUT_SECONDS: "120" # optional
  ```
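
The example endpoint above ends with a slash, while the provider added in this commit joins the endpoint and path with `"%s/documentintelligence/..."`, which yields a double slash for such a value. As a minimal sketch (the helper name `buildAnalyzeURL` is hypothetical and not part of the commit), the endpoint could be normalized before the URL is formatted:

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// buildAnalyzeURL is a hypothetical helper, not part of the commit.
// It trims trailing slashes from the endpoint so the joined path does not
// contain "//", and escapes the model ID and API version.
func buildAnalyzeURL(endpoint, modelID, apiVersion string) string {
	base := strings.TrimRight(endpoint, "/")
	return fmt.Sprintf("%s/documentintelligence/documentModels/%s:analyze?api-version=%s",
		base, url.PathEscape(modelID), url.QueryEscape(apiVersion))
}

func main() {
	// The endpoint value is the placeholder from the README example.
	fmt.Println(buildAnalyzeURL(
		"https://your-endpoint.cognitiveservices.azure.com/",
		"prebuilt-read",
		"2024-11-30",
	))
}
```

Whether the service tolerates a double slash is not stated in the source, so trimming is shown only as a defensive choice.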

### 3. Google Document AI

- **Key Features**:
  - Specialized document processors
  - Strong form field detection
  - Multi-language support
  - High accuracy on structured documents
- **Best For**:
  - Forms and structured documents
  - Documents with tables
  - Multi-language documents
- **Configuration**:
  ```yaml
  OCR_PROVIDER: "google_docai"
  GOOGLE_PROJECT_ID: "your-project"
  GOOGLE_LOCATION: "us"
  GOOGLE_PROCESSOR_ID: "processor-id"
  ```
|
||||
|
||||
## Configuration
|
||||
|
||||
|
@ -192,9 +262,13 @@ services:
|
|||
| `OPENAI_BASE_URL` | OpenAI base URL (optional, if using a custom OpenAI compatible service like LiteLLM). | No | |
| `LLM_LANGUAGE` | Likely language for documents (e.g. `English`). | No | English |
| `OLLAMA_HOST` | Ollama server URL (e.g. `http://host.docker.internal:11434`). | No | |
| `OCR_PROVIDER` | OCR provider to use (`llm` or `google_docai`). | No | llm |
| `OCR_PROVIDER` | OCR provider to use (`llm`, `azure`, or `google_docai`). | No | llm |
| `VISION_LLM_PROVIDER` | AI backend for LLM OCR (`openai` or `ollama`). Required if OCR_PROVIDER is `llm`. | Cond. | |
| `VISION_LLM_MODEL` | Model name for LLM OCR (e.g. `minicpm-v`). Required if OCR_PROVIDER is `llm`. | Cond. | |
| `AZURE_DOCAI_ENDPOINT` | Azure Document Intelligence endpoint. Required if OCR_PROVIDER is `azure`. | Cond. | |
| `AZURE_DOCAI_KEY` | Azure Document Intelligence API key. Required if OCR_PROVIDER is `azure`. | Cond. | |
| `AZURE_DOCAI_MODEL_ID` | Azure Document Intelligence model ID. Optional if using `azure` provider. | No | prebuilt-read |
| `AZURE_DOCAI_TIMEOUT_SECONDS` | Azure Document Intelligence timeout in seconds. | No | 120 |
| `GOOGLE_PROJECT_ID` | Google Cloud project ID. Required if OCR_PROVIDER is `google_docai`. | Cond. | |
| `GOOGLE_LOCATION` | Google Cloud region (e.g. `us`, `eu`). Required if OCR_PROVIDER is `google_docai`. | Cond. | |
| `GOOGLE_PROCESSOR_ID` | Document AI processor ID. Required if OCR_PROVIDER is `google_docai`. | Cond. | |
|
||||
|
@ -211,7 +285,7 @@ services:
|
|||
|
||||
### Custom Prompt Templates
|
||||
|
||||
paperless-gpt’s flexible **prompt templates** let you shape how AI responds:
|
||||
paperless-gpt's flexible **prompt templates** let you shape how AI responds:
|
||||
|
||||
1. **`title_prompt.tmpl`**: For document titles.
|
||||
2. **`tag_prompt.tmpl`**: For tagging logic.
|
||||
|
@ -232,13 +306,11 @@ Then tweak at will—**paperless-gpt** reloads them automatically on startup!
|
|||
Each template has access to specific variables:
|
||||
|
||||
**title_prompt.tmpl**:
|
||||
|
||||
- `{{.Language}}` - Target language (e.g., "English")
|
||||
- `{{.Content}}` - Document content text
|
||||
- `{{.Title}}` - Original document title
|
||||
|
||||
**tag_prompt.tmpl**:
|
||||
|
||||
- `{{.Language}}` - Target language
|
||||
- `{{.AvailableTags}}` - List of existing tags in paperless-ngx
|
||||
- `{{.OriginalTags}}` - Document's current tags
|
||||
|
@ -246,11 +318,9 @@ Each template has access to specific variables:
|
|||
- `{{.Content}}` - Document content text
|
||||
|
||||
**ocr_prompt.tmpl**:
|
||||
|
||||
- `{{.Language}}` - Target language
|
||||
|
||||
**correspondent_prompt.tmpl**:
|
||||
|
||||
- `{{.Language}}` - Target language
|
||||
- `{{.AvailableCorrespondents}}` - List of existing correspondents
|
||||
- `{{.BlackList}}` - List of blacklisted correspondent names
|
||||
|
@ -265,23 +335,25 @@ The templates use Go's text/template syntax. paperless-gpt automatically reloads
|
|||
|
||||
1. **Tag Documents**
|
||||
|
||||
- Add `paperless-gpt` or your custom tag to the docs you want to AI-ify.
|
||||
- Add the `paperless-gpt` tag to documents for manual processing
|
||||
- Add `paperless-gpt-auto` for automatic processing
|
||||
- Add `paperless-gpt-ocr-auto` for automatic OCR processing
|
||||
|
||||
2. **Visit Web UI**
|
||||
|
||||
- Go to `http://localhost:8080` (or your host) in your browser.
|
||||
- Go to `http://localhost:8080` (or your host) in your browser
|
||||
- Review documents tagged for processing
|
||||
|
||||
3. **Generate & Apply Suggestions**
|
||||
|
||||
- Click “Generate Suggestions” to see AI-proposed titles/tags/correspondents.
|
||||
- Approve, edit, or discard. Hit “Apply” to finalize in paperless-ngx.
|
||||
|
||||
4. **Try LLM-Based OCR (Experimental)**
|
||||
- If you enabled `VISION_LLM_PROVIDER` and `VISION_LLM_MODEL`, let AI-based OCR read your scanned PDFs.
|
||||
- Tag those documents with `paperless-gpt-ocr-auto` (or your custom `AUTO_OCR_TAG`).
|
||||
|
||||
**Tip**: The entire pipeline can be **fully automated** if you prefer minimal manual intervention.
|
||||
- Click "Generate Suggestions" to see AI-proposed titles/tags/correspondents
|
||||
- Review and approve or edit suggestions
|
||||
- Click "Apply" to save changes to paperless-ngx
|
||||
|
||||
4. **OCR Processing**
|
||||
- Tag documents with the appropriate OCR tag to process them
|
||||
- Monitor progress in the Web UI
|
||||
- Review results and apply changes
|
||||
---
|
||||
|
||||
## LLM-Based OCR: Compare for Yourself
|
||||
|
|
2
go.mod
|
@ -12,6 +12,7 @@ require (
|
|||
github.com/gen2brain/go-fitz v1.24.14
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/hashicorp/go-retryablehttp v0.7.7
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/tmc/langchaingo v0.1.13
|
||||
|
@ -48,6 +49,7 @@ require (
|
|||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.14.1 // indirect
|
||||
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
|
||||
github.com/huandu/xstrings v1.5.0 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
|
|
6
go.sum
|
@ -75,6 +75,12 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gT
|
|||
github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA=
|
||||
github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrkurSS/Q=
|
||||
github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA=
|
||||
github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ=
|
||||
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
|
||||
github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k=
|
||||
github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M=
|
||||
github.com/hashicorp/go-retryablehttp v0.7.7 h1:C8hUCYzor8PIfXHa4UrZkU4VvK8o9ISHxT2Q8+VepXU=
|
||||
github.com/hashicorp/go-retryablehttp v0.7.7/go.mod h1:pkQpWZeYWskR+D1tR2O5OcBFOxfA7DoAO6xtkuQnHTk=
|
||||
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
|
||||
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
|
|
27
main.go
|
@ -36,6 +36,10 @@ var (
|
|||
correspondentBlackList = strings.Split(os.Getenv("CORRESPONDENT_BLACK_LIST"), ",")
|
||||
paperlessBaseURL = os.Getenv("PAPERLESS_BASE_URL")
|
||||
paperlessAPIToken = os.Getenv("PAPERLESS_API_TOKEN")
|
||||
azureDocAIEndpoint = os.Getenv("AZURE_DOCAI_ENDPOINT")
|
||||
azureDocAIKey = os.Getenv("AZURE_DOCAI_KEY")
|
||||
azureDocAIModelID = os.Getenv("AZURE_DOCAI_MODEL_ID")
|
||||
azureDocAITimeout = os.Getenv("AZURE_DOCAI_TIMEOUT_SECONDS")
|
||||
openaiAPIKey = os.Getenv("OPENAI_API_KEY")
|
||||
manualTag = os.Getenv("MANUAL_TAG")
|
||||
autoTag = os.Getenv("AUTO_TAG")
|
||||
|
@ -167,6 +171,18 @@ func main() {
|
|||
GoogleProcessorID: os.Getenv("GOOGLE_PROCESSOR_ID"),
|
||||
VisionLLMProvider: visionLlmProvider,
|
||||
VisionLLMModel: visionLlmModel,
|
||||
AzureEndpoint: azureDocAIEndpoint,
|
||||
AzureAPIKey: azureDocAIKey,
|
||||
AzureModelID: azureDocAIModelID,
|
||||
}
|
||||
|
||||
// Parse Azure timeout if set
|
||||
if azureDocAITimeout != "" {
|
||||
if timeout, err := strconv.Atoi(azureDocAITimeout); err == nil {
|
||||
ocrConfig.AzureTimeout = timeout
|
||||
} else {
|
||||
log.Warnf("Invalid AZURE_DOCAI_TIMEOUT_SECONDS value: %v, using default", err)
|
||||
}
|
||||
}
|
||||
|
||||
// If provider is LLM, but no VISION_LLM_PROVIDER is set, don't initialize OCR provider
|
||||
|
@ -422,6 +438,17 @@ func validateOrDefaultEnvVars() {
|
|||
log.Fatal("Please set the LLM_PROVIDER environment variable to 'openai' or 'ollama'.")
|
||||
}
|
||||
|
||||
// Validate OCR provider if set
|
||||
ocrProvider := os.Getenv("OCR_PROVIDER")
|
||||
if ocrProvider == "azure" {
|
||||
if azureDocAIEndpoint == "" {
|
||||
log.Fatal("Please set the AZURE_DOCAI_ENDPOINT environment variable for Azure provider")
|
||||
}
|
||||
if azureDocAIKey == "" {
|
||||
log.Fatal("Please set the AZURE_DOCAI_KEY environment variable for Azure provider")
|
||||
}
|
||||
}
|
||||
|
||||
if llmModel == "" {
|
||||
log.Fatal("Please set the LLM_MODEL environment variable.")
|
||||
}
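
The `main.go` hunks above read four `AZURE_DOCAI_*` variables, fall back to defaults for the optional ones, and stop when the endpoint or key is missing. The following is a minimal sketch of the same rules gathered in one function. The names `azureSettings` and `loadAzureSettings` are hypothetical, and the `getenv` parameter is an assumption made so the logic can be exercised without touching the real environment:

```go
package main

import (
	"errors"
	"fmt"
	"strconv"
)

// azureSettings is a hypothetical container for the Azure options.
type azureSettings struct {
	Endpoint       string
	Key            string
	ModelID        string
	TimeoutSeconds int
}

// loadAzureSettings reads the AZURE_DOCAI_* variables through getenv.
// Endpoint and key are required; model ID and timeout fall back to the
// defaults named in the README (prebuilt-read, 120 seconds).
func loadAzureSettings(getenv func(string) string) (azureSettings, error) {
	s := azureSettings{
		Endpoint:       getenv("AZURE_DOCAI_ENDPOINT"),
		Key:            getenv("AZURE_DOCAI_KEY"),
		ModelID:        "prebuilt-read",
		TimeoutSeconds: 120,
	}
	if s.Endpoint == "" {
		return s, errors.New("AZURE_DOCAI_ENDPOINT is required")
	}
	if s.Key == "" {
		return s, errors.New("AZURE_DOCAI_KEY is required")
	}
	if v := getenv("AZURE_DOCAI_MODEL_ID"); v != "" {
		s.ModelID = v
	}
	if v := getenv("AZURE_DOCAI_TIMEOUT_SECONDS"); v != "" {
		// An unparsable or non-positive value keeps the default.
		if n, err := strconv.Atoi(v); err == nil && n > 0 {
			s.TimeoutSeconds = n
		}
	}
	return s, nil
}

func main() {
	// Placeholder values; a real caller would pass os.Getenv instead.
	env := map[string]string{
		"AZURE_DOCAI_ENDPOINT":        "https://example.test",
		"AZURE_DOCAI_KEY":             "placeholder-key",
		"AZURE_DOCAI_TIMEOUT_SECONDS": "not-a-number",
	}
	s, err := loadAzureSettings(func(k string) string { return env[k] })
	fmt.Println(s.ModelID, s.TimeoutSeconds, err)
}
```

Passing the lookup function in, rather than calling `os.Getenv` directly, is a design choice of this sketch and not something the commit does.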
|
||||
|
|
224
ocr/azure_provider.go
Normal file
|
@ -0,0 +1,224 @@
|
|||
package ocr
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gabriel-vasile/mimetype"
|
||||
"github.com/hashicorp/go-retryablehttp"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
apiVersion = "2024-11-30"
|
||||
defaultModelID = "prebuilt-read"
|
||||
defaultTimeout = 120
|
||||
pollingInterval = 2 * time.Second
|
||||
)
|
||||
|
||||
// AzureProvider implements OCR using Azure Document Intelligence
|
||||
type AzureProvider struct {
|
||||
endpoint string
|
||||
apiKey string
|
||||
modelID string
|
||||
timeout time.Duration
|
||||
httpClient *retryablehttp.Client
|
||||
}
|
||||
|
||||
// Request body for Azure Document Intelligence
|
||||
type analyzeRequest struct {
|
||||
Base64Source string `json:"base64Source"`
|
||||
}
|
||||
|
||||
func newAzureProvider(config Config) (*AzureProvider, error) {
|
||||
logger := log.WithFields(logrus.Fields{
|
||||
"endpoint": config.AzureEndpoint,
|
||||
"model_id": config.AzureModelID,
|
||||
})
|
||||
logger.Info("Creating new Azure Document Intelligence provider")
|
||||
|
||||
// Validate required configuration
|
||||
if config.AzureEndpoint == "" || config.AzureAPIKey == "" {
|
||||
logger.Error("Missing required configuration")
|
||||
return nil, fmt.Errorf("missing required Azure Document Intelligence configuration")
|
||||
}
|
||||
|
||||
// Set defaults and create provider
|
||||
modelID := defaultModelID
|
||||
if config.AzureModelID != "" {
|
||||
modelID = config.AzureModelID
|
||||
}
|
||||
|
||||
timeout := defaultTimeout
|
||||
if config.AzureTimeout > 0 {
|
||||
timeout = config.AzureTimeout
|
||||
}
|
||||
|
||||
// Configure retryablehttp client
|
||||
client := retryablehttp.NewClient()
|
||||
client.RetryMax = 3
|
||||
client.RetryWaitMin = 1 * time.Second
|
||||
client.RetryWaitMax = 5 * time.Second
|
||||
client.Logger = logger
|
||||
|
||||
provider := &AzureProvider{
|
||||
endpoint: config.AzureEndpoint,
|
||||
apiKey: config.AzureAPIKey,
|
||||
modelID: modelID,
|
||||
timeout: time.Duration(timeout) * time.Second,
|
||||
httpClient: client,
|
||||
}
|
||||
|
||||
logger.Info("Successfully initialized Azure Document Intelligence provider")
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (p *AzureProvider) ProcessImage(ctx context.Context, imageContent []byte) (*OCRResult, error) {
|
||||
logger := log.WithFields(logrus.Fields{
|
||||
"model_id": p.modelID,
|
||||
})
|
||||
logger.Debug("Starting Azure Document Intelligence processing")
|
||||
|
||||
// Detect MIME type
|
||||
mtype := mimetype.Detect(imageContent)
|
||||
logger.WithField("mime_type", mtype.String()).Debug("Detected file type")
|
||||
|
||||
if !isImageMIMEType(mtype.String()) {
|
||||
logger.WithField("mime_type", mtype.String()).Error("Unsupported file type")
|
||||
return nil, fmt.Errorf("unsupported file type: %s", mtype.String())
|
||||
}
|
||||
|
||||
// Create context with timeout
|
||||
ctx, cancel := context.WithTimeout(ctx, p.timeout)
|
||||
defer cancel()
|
||||
|
||||
// Submit document for analysis
|
||||
operationLocation, err := p.submitDocument(ctx, imageContent)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error submitting document: %w", err)
|
||||
}
|
||||
|
||||
// Poll for results
|
||||
result, err := p.pollForResults(ctx, operationLocation)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error polling for results: %w", err)
|
||||
}
|
||||
|
||||
// Convert to OCR result
|
||||
ocrResult := &OCRResult{
|
||||
Text: result.AnalyzeResult.Content,
|
||||
Metadata: map[string]string{
|
||||
"provider": "azure_docai",
|
||||
"page_count": fmt.Sprintf("%d", len(result.AnalyzeResult.Pages)),
|
||||
"api_version": result.AnalyzeResult.APIVersion,
|
||||
},
|
||||
}
|
||||
|
||||
logger.WithFields(logrus.Fields{
|
||||
"content_length": len(ocrResult.Text),
|
||||
"page_count": len(result.AnalyzeResult.Pages),
|
||||
}).Info("Successfully processed document")
|
||||
return ocrResult, nil
|
||||
}
|
||||
|
||||
func (p *AzureProvider) submitDocument(ctx context.Context, imageContent []byte) (string, error) {
|
||||
requestURL := fmt.Sprintf("%s/documentintelligence/documentModels/%s:analyze?api-version=%s",
|
||||
p.endpoint, p.modelID, apiVersion)
|
||||
|
||||
// Prepare request body
|
||||
requestBody := analyzeRequest{
|
||||
Base64Source: base64.StdEncoding.EncodeToString(imageContent),
|
||||
}
|
||||
requestBodyBytes, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error marshaling request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := retryablehttp.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewBuffer(requestBodyBytes))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error creating HTTP request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Ocp-Apim-Subscription-Key", p.apiKey)
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error sending HTTP request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusAccepted {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
operationLocation := resp.Header.Get("Operation-Location")
|
||||
if operationLocation == "" {
|
||||
return "", fmt.Errorf("no Operation-Location header in response")
|
||||
}
|
||||
|
||||
return operationLocation, nil
|
||||
}
|
||||
|
||||
func (p *AzureProvider) pollForResults(ctx context.Context, operationLocation string) (*AzureDocumentResult, error) {
|
||||
logger := log.WithField("operation_location", operationLocation)
|
||||
logger.Debug("Starting to poll for results")
|
||||
|
||||
ticker := time.NewTicker(pollingInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("operation timed out after %v: %w", p.timeout, ctx.Err())
|
||||
case <-ticker.C:
|
||||
req, err := retryablehttp.NewRequestWithContext(ctx, "GET", operationLocation, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating poll request: %w", err)
|
||||
}
|
||||
req.Header.Set("Ocp-Apim-Subscription-Key", p.apiKey)
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error polling for results: %w", err)
|
||||
}
|
||||
|
||||
var result AzureDocumentResult
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
resp.Body.Close()
|
||||
logger.WithError(err).Error("Failed to decode response")
|
||||
return nil, fmt.Errorf("error decoding response: %w", err)
|
||||
}
|
||||
resp.Body.Close() // close now: a defer inside this loop would keep every polled body open until the function returns
|
||||
|
||||
logger.WithFields(logrus.Fields{
|
||||
"status_code": resp.StatusCode,
|
||||
"content_length": len(result.AnalyzeResult.Content),
|
||||
"page_count": len(result.AnalyzeResult.Pages),
|
||||
"status": result.Status,
|
||||
}).Debug("Poll response received")
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status code %d while polling", resp.StatusCode)
|
||||
}
|
||||
|
||||
switch result.Status {
|
||||
case "succeeded":
|
||||
return &result, nil
|
||||
case "failed":
|
||||
return nil, fmt.Errorf("document processing failed")
|
||||
case "running":
|
||||
// Continue polling
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected status: %s", result.Status)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
222
ocr/azure_provider_test.go
Normal file
|
@ -0,0 +1,222 @@
|
|||
package ocr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-retryablehttp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewAzureProvider(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config Config
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
config: Config{
|
||||
AzureEndpoint: "https://test.cognitiveservices.azure.com/",
|
||||
AzureAPIKey: "test-key",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid config with custom model and timeout",
|
||||
config: Config{
|
||||
AzureEndpoint: "https://test.cognitiveservices.azure.com/",
|
||||
AzureAPIKey: "test-key",
|
||||
AzureModelID: "custom-model",
|
||||
AzureTimeout: 60,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing endpoint",
|
||||
config: Config{
|
||||
AzureAPIKey: "test-key",
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "missing required Azure Document Intelligence configuration",
|
||||
},
|
||||
{
|
||||
name: "missing api key",
|
||||
config: Config{
|
||||
AzureEndpoint: "https://test.cognitiveservices.azure.com/",
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "missing required Azure Document Intelligence configuration",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
provider, err := newAzureProvider(tt.config)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, provider)
|
||||
|
||||
// Verify default values
|
||||
if tt.config.AzureModelID == "" {
|
||||
assert.Equal(t, defaultModelID, provider.modelID)
|
||||
} else {
|
||||
assert.Equal(t, tt.config.AzureModelID, provider.modelID)
|
||||
}
|
||||
|
||||
if tt.config.AzureTimeout == 0 {
|
||||
assert.Equal(t, time.Duration(defaultTimeout)*time.Second, provider.timeout)
|
||||
} else {
|
||||
assert.Equal(t, time.Duration(tt.config.AzureTimeout)*time.Second, provider.timeout)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzureProvider_ProcessImage(t *testing.T) {
|
||||
// Sample success response
|
||||
now := time.Now()
|
||||
successResult := AzureDocumentResult{
|
||||
Status: "succeeded",
|
||||
CreatedDateTime: now,
|
||||
LastUpdatedDateTime: now,
|
||||
AnalyzeResult: AzureAnalyzeResult{
|
||||
APIVersion: apiVersion,
|
||||
ModelID: defaultModelID,
|
||||
StringIndexType: "utf-16",
|
||||
Content: "Test document content",
|
||||
Pages: []AzurePage{
|
||||
{
|
||||
PageNumber: 1,
|
||||
Angle: 0.0,
|
||||
Width: 800,
|
||||
Height: 600,
|
||||
Unit: "pixel",
|
||||
Lines: []AzureLine{
|
||||
{
|
||||
Content: "Test line",
|
||||
Polygon: []int{0, 0, 100, 0, 100, 20, 0, 20},
|
||||
Spans: []AzureSpan{{Offset: 0, Length: 9}},
|
||||
},
|
||||
},
|
||||
Spans: []AzureSpan{{Offset: 0, Length: 9}},
|
||||
},
|
||||
},
|
||||
Paragraphs: []AzureParagraph{
|
||||
{
|
||||
Content: "Test document content",
|
||||
Spans: []AzureSpan{{Offset: 0, Length: 21}},
|
||||
BoundingRegions: []AzureBoundingBox{
|
||||
{
|
||||
PageNumber: 1,
|
||||
Polygon: []int{0, 0, 100, 0, 100, 20, 0, 20},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
ContentFormat: "text",
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupServer func() *httptest.Server
|
||||
imageContent []byte
|
||||
wantErr bool
|
||||
errContains string
|
||||
expectedText string
|
||||
}{
|
||||
{
|
||||
name: "successful processing",
|
||||
setupServer: func() *httptest.Server {
|
||||
mux := http.NewServeMux()
|
||||
server := httptest.NewServer(mux)
|
||||
|
||||
mux.HandleFunc("/documentintelligence/documentModels/prebuilt-read:analyze", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Operation-Location", fmt.Sprintf("%s/operations/123", server.URL))
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
})
|
||||
|
||||
mux.HandleFunc("/operations/123", func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(successResult)
|
||||
})
|
||||
|
||||
return server
|
||||
},
|
||||
// Create minimal JPEG content with magic numbers
|
||||
imageContent: append([]byte{0xFF, 0xD8, 0xFF, 0xE0}, []byte("JFIF test content")...),
|
||||
expectedText: "Test document content",
|
||||
},
|
||||
{
|
||||
name: "invalid mime type",
|
||||
setupServer: func() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Log("Server should not be called with invalid mime type")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
},
|
||||
imageContent: []byte("invalid content"),
|
||||
wantErr: true,
|
||||
errContains: "unsupported file type",
|
||||
},
|
||||
{
|
||||
name: "submission error",
|
||||
setupServer: func() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintln(w, "Invalid request")
|
||||
}))
|
||||
},
|
||||
imageContent: []byte{0xFF, 0xD8, 0xFF, 0xE0}, // JPEG magic numbers
|
||||
wantErr: true,
|
||||
errContains: "unexpected status code 400",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := tt.setupServer()
|
||||
defer server.Close()
|
||||
|
||||
client := retryablehttp.NewClient()
|
||||
client.HTTPClient = server.Client()
|
||||
client.Logger = log
|
||||
|
||||
provider := &AzureProvider{
|
||||
endpoint: server.URL,
|
||||
apiKey: "test-key",
|
||||
modelID: defaultModelID,
|
||||
timeout: 5 * time.Second,
|
||||
httpClient: client,
|
||||
}
|
||||
|
||||
result, err := provider.ProcessImage(context.Background(), tt.imageContent)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, tt.expectedText, result.Text)
|
||||
assert.Equal(t, "azure_docai", result.Metadata["provider"])
|
||||
assert.Equal(t, apiVersion, result.Metadata["api_version"])
|
||||
assert.Equal(t, "1", result.Metadata["page_count"])
|
||||
})
|
||||
}
|
||||
}
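
The tests above stand in for Azure with an `httptest` server that answers the analyze call with `202 Accepted` plus an `Operation-Location` header, then serves the result at that location. A minimal standalone sketch of that two-step exchange follows. The names `newFakeAnalyzeServer` and `submitAndFetch` are hypothetical, the JSON bodies are hand-written placeholders, and plain `net/http` is used instead of the retrying client:

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
)

const analyzePath = "/documentintelligence/documentModels/prebuilt-read:analyze"

// newFakeAnalyzeServer is a hypothetical stand-in for the Azure endpoint.
func newFakeAnalyzeServer() *httptest.Server {
	mux := http.NewServeMux()
	srv := httptest.NewServer(mux)
	mux.HandleFunc(analyzePath, func(w http.ResponseWriter, r *http.Request) {
		// Step 1: accept the document and point at the operation URL.
		w.Header().Set("Operation-Location", srv.URL+"/operations/123")
		w.WriteHeader(http.StatusAccepted)
	})
	mux.HandleFunc("/operations/123", func(w http.ResponseWriter, r *http.Request) {
		// Step 2: report the (placeholder) operation status.
		w.Header().Set("Content-Type", "application/json")
		io.WriteString(w, `{"status":"succeeded"}`)
	})
	return srv
}

// submitAndFetch posts a placeholder body, follows Operation-Location once,
// and returns the raw response body of the status request.
func submitAndFetch(baseURL string) (string, error) {
	resp, err := http.Post(baseURL+analyzePath, "application/json",
		strings.NewReader(`{"base64Source":""}`))
	if err != nil {
		return "", err
	}
	resp.Body.Close()
	if resp.StatusCode != http.StatusAccepted {
		return "", fmt.Errorf("unexpected status code %d", resp.StatusCode)
	}
	location := resp.Header.Get("Operation-Location")
	if location == "" {
		return "", fmt.Errorf("no Operation-Location header in response")
	}
	poll, err := http.Get(location)
	if err != nil {
		return "", err
	}
	defer poll.Body.Close()
	body, err := io.ReadAll(poll.Body)
	return string(body), err
}

func main() {
	srv := newFakeAnalyzeServer()
	defer srv.Close()
	fmt.Println(submitAndFetch(srv.URL))
}
```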
|
72
ocr/azure_types.go
Normal file
|
@ -0,0 +1,72 @@
|
|||
package ocr
|
||||
|
||||
import "time"
|
||||
|
||||
// AzureDocumentResult represents the root response from Azure Document Intelligence
|
||||
type AzureDocumentResult struct {
|
||||
Status string `json:"status"`
|
||||
CreatedDateTime time.Time `json:"createdDateTime"`
|
||||
LastUpdatedDateTime time.Time `json:"lastUpdatedDateTime"`
|
||||
AnalyzeResult AzureAnalyzeResult `json:"analyzeResult"`
|
||||
}
|
||||
|
||||
// AzureAnalyzeResult represents the analyze result part of the Azure Document Intelligence response
|
||||
type AzureAnalyzeResult struct {
|
||||
APIVersion string `json:"apiVersion"`
|
||||
ModelID string `json:"modelId"`
|
||||
StringIndexType string `json:"stringIndexType"`
|
||||
Content string `json:"content"`
|
||||
Pages []AzurePage `json:"pages"`
|
||||
Paragraphs []AzureParagraph `json:"paragraphs"`
|
||||
Styles []interface{} `json:"styles"`
|
||||
ContentFormat string `json:"contentFormat"`
|
||||
}
|
||||
|
||||
// AzurePage represents a single page in the document
|
||||
type AzurePage struct {
|
||||
PageNumber int `json:"pageNumber"`
|
||||
Angle float64 `json:"angle"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
Unit string `json:"unit"`
|
||||
Words []AzureWord `json:"words"`
|
||||
Lines []AzureLine `json:"lines"`
|
||||
Spans []AzureSpan `json:"spans"`
|
||||
}
|
||||
|
||||
// AzureWord represents a single word with its properties
|
||||
type AzureWord struct {
|
||||
Content string `json:"content"`
|
||||
Polygon []int `json:"polygon"`
|
||||
Confidence float64 `json:"confidence"`
|
||||
Span AzureSpan `json:"span"`
|
||||
}
|
||||
|
||||
// AzureLine represents a line of text
|
||||
type AzureLine struct {
|
||||
Content string `json:"content"`
|
||||
Polygon []int `json:"polygon"`
|
||||
Spans []AzureSpan `json:"spans"`
|
||||
}
|
||||
|
||||
// AzureSpan represents a span of text with offset and length
|
||||
type AzureSpan struct {
|
||||
Offset int `json:"offset"`
|
||||
Length int `json:"length"`
|
||||
}
|
||||
|
||||
// AzureParagraph represents a paragraph of text
|
||||
type AzureParagraph struct {
|
||||
Content string `json:"content"`
|
||||
Spans []AzureSpan `json:"spans"`
|
||||
BoundingRegions []AzureBoundingBox `json:"boundingRegions"`
|
||||
}
|
||||
|
||||
// AzureBoundingBox represents the location of content on a page
|
||||
type AzureBoundingBox struct {
|
||||
PageNumber int `json:"pageNumber"`
|
||||
Polygon []int `json:"polygon"`
|
||||
}
|
||||
|
||||
// AzureStyle represents style information for text segments. It is an untyped placeholder (interface{}).
|
||||
type AzureStyle interface{}
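
The types above are plain `encoding/json` targets. As a minimal sketch of how such a response decodes, the example below uses trimmed, hypothetical mirrors of the structs (`documentResult`, `analyzeResult`) and a hand-written sample payload that is not a captured Azure response:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// analyzeResult and documentResult are trimmed, hypothetical mirrors of
// AzureAnalyzeResult and AzureDocumentResult from the commit.
type analyzeResult struct {
	APIVersion string `json:"apiVersion"`
	Content    string `json:"content"`
	Pages      []struct {
		PageNumber int `json:"pageNumber"`
	} `json:"pages"`
}

type documentResult struct {
	Status        string        `json:"status"`
	AnalyzeResult analyzeResult `json:"analyzeResult"`
}

// sample is a hand-written payload for illustration only.
const sample = `{
  "status": "succeeded",
  "analyzeResult": {
    "apiVersion": "2024-11-30",
    "content": "Test document content",
    "pages": [{"pageNumber": 1, "unit": "pixel"}]
  }
}`

// decodeResult unmarshals a response body; JSON fields without a matching
// struct field (such as "unit" here) are ignored by encoding/json.
func decodeResult(data []byte) (documentResult, error) {
	var r documentResult
	err := json.Unmarshal(data, &r)
	return r, err
}

// decodeIntWidth shows that a fractional JSON number cannot be decoded
// into an int field: Unmarshal returns an error in that case.
func decodeIntWidth(data []byte) error {
	var page struct {
		Width int `json:"width"`
	}
	return json.Unmarshal(data, &page)
}

func main() {
	r, err := decodeResult([]byte(sample))
	fmt.Println(r.Status, len(r.AnalyzeResult.Pages), err)
	fmt.Println(decodeIntWidth([]byte(`{"width": 8.5}`)) != nil)
}
```

The second function matters for fields typed as `int` (page width and height, polygon coordinates): if the service were to return a fractional value for one of them, which this commit does not show, decoding would fail rather than round.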
|
|
@ -28,7 +28,7 @@ type Provider interface {
|
|||
|
||||
// Config holds the OCR provider configuration
|
||||
type Config struct {
|
||||
// Provider type (e.g., "llm", "google_docai")
|
||||
// Provider type (e.g., "llm", "google_docai", "azure")
|
||||
Provider string
|
||||
|
||||
// Google Document AI settings
|
||||
|
@ -40,6 +40,12 @@ type Config struct {
|
|||
VisionLLMProvider string
|
||||
VisionLLMModel string
|
||||
|
||||
// Azure Document Intelligence settings
|
||||
AzureEndpoint string
|
||||
AzureAPIKey string
|
||||
AzureModelID string // Optional, defaults to "prebuilt-read"
|
||||
AzureTimeout int // Optional, defaults to 120 seconds
|
||||
|
||||
// OCR output options
|
||||
EnableHOCR bool // Whether to request hOCR output if supported by the provider
|
||||
}
|
||||
|
@ -69,6 +75,12 @@ func NewProvider(config Config) (Provider, error) {
|
|||
}).Info("Using LLM OCR provider")
|
||||
return newLLMProvider(config)
|
||||
|
||||
case "azure":
|
||||
if config.AzureEndpoint == "" || config.AzureAPIKey == "" {
|
||||
return nil, fmt.Errorf("missing required Azure Document Intelligence configuration")
|
||||
}
|
||||
return newAzureProvider(config)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported OCR provider: %s", config.Provider)
|
||||
}
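
The hunk above adds an `azure` case to the provider factory, validating the endpoint and key before construction. A self-contained sketch of that dispatch pattern is shown below with stub types. `provider`, `config`, `llmStub` and `azureStub` are hypothetical stand-ins for the real `Provider`, `Config` and provider implementations:

```go
package main

import (
	"errors"
	"fmt"
)

// provider is a hypothetical, reduced version of the OCR Provider interface.
type provider interface {
	Name() string
}

// config carries only the fields this sketch needs.
type config struct {
	Provider      string
	AzureEndpoint string
	AzureAPIKey   string
}

type llmStub struct{}

func (llmStub) Name() string { return "llm" }

type azureStub struct{}

func (azureStub) Name() string { return "azure" }

// newProvider selects an implementation by name and rejects an Azure
// configuration that lacks the endpoint or the API key.
func newProvider(c config) (provider, error) {
	switch c.Provider {
	case "llm":
		return llmStub{}, nil
	case "azure":
		if c.AzureEndpoint == "" || c.AzureAPIKey == "" {
			return nil, errors.New("missing required Azure Document Intelligence configuration")
		}
		return azureStub{}, nil
	default:
		return nil, fmt.Errorf("unsupported OCR provider: %s", c.Provider)
	}
}

func main() {
	p, err := newProvider(config{
		Provider:      "azure",
		AzureEndpoint: "https://example.test",
		AzureAPIKey:   "placeholder-key",
	})
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(p.Name())
}
```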
|
||||
|
|