Mirror of https://github.com/icereed/paperless-gpt.git, synced 2025-03-12 12:58:02 -05:00

Commit 03364f2741: OCR via LLM (#29)
Parent: 10df151525

21 changed files with 1662 additions and 359 deletions

.github/workflows/docker-build-and-push.yml (vendored): 6 changes

@@ -22,6 +22,12 @@ jobs:
      with:
        go-version: 1.22

      - name: Install mupdf
        run: sudo apt-get install -y mupdf

      - name: Set library path
        run: echo "/usr/lib" | sudo tee -a /etc/ld.so.conf.d/mupdf.conf && sudo ldconfig

      - name: Install dependencies
        run: go mod download

.gitignore (vendored): 2 changes

@@ -1,3 +1,5 @@
.env
.DS_Store
prompts/
tests/tmp
tmp/

Dockerfile: 29 changes

@@ -1,9 +1,17 @@
# Stage 1: Build the Go binary
FROM golang:1.22 AS builder
FROM golang:1.22-alpine AS builder

# Set the working directory inside the container
WORKDIR /app

# Install necessary packages
RUN apk add --no-cache \
    git \
    gcc \
    musl-dev \
    mupdf \
    mupdf-dev

# Copy go.mod and go.sum files
COPY go.mod go.sum ./

@@ -13,17 +21,19 @@ RUN go mod download
# Copy the rest of the application code
COPY . .

# Build the Go binary
RUN CGO_ENABLED=0 GOOS=linux go build -o paperless-gpt .
# Build the Go binary with the musl build tag
RUN go build -tags musl -o paperless-gpt .

# Stage 2: Build Vite frontend
FROM node:20 AS frontend
FROM node:20-alpine AS frontend

# Set the working directory inside the container
WORKDIR /app

# Copy package.json and package-lock.json
# Install necessary packages
RUN apk add --no-cache git

# Copy package.json and package-lock.json
COPY web-app/package.json web-app/package-lock.json ./

# Install dependencies

@@ -35,11 +45,12 @@ COPY web-app /app/
# Build the frontend
RUN npm run build

# Stage 3: Create a lightweight image with the Go binary
# Stage 3: Create a lightweight image with the Go binary and frontend
FROM alpine:latest

# Install necessary CA certificates
RUN apk --no-cache add ca-certificates
# Install necessary runtime dependencies
RUN apk add --no-cache \
    ca-certificates

# Set the working directory inside the container
WORKDIR /app/

@@ -54,4 +65,4 @@ COPY --from=frontend /app/dist /app/web-app/dist
EXPOSE 8080

# Command to run the binary
CMD ["./paperless-gpt"]
CMD ["/app/paperless-gpt"]

README.md: 18 changes

@@ -18,7 +18,7 @@
- **User-Friendly Interface**: Intuitive web interface for reviewing and applying suggested titles and tags.
- **Dockerized Deployment**: Simple setup using Docker and Docker Compose.
- **Automatic Document Processing**: Automatically apply generated suggestions for documents with the `paperless-gpt-auto` tag.
- **Experimental OCR Feature**: Send documents to a vision LLM for OCR processing.

## Table of Contents

@@ -40,6 +40,7 @@
- [Usage](#usage)
- [Contributing](#contributing)
- [License](#license)
- [Star History](#star-history)

## Getting Started

@@ -74,6 +75,8 @@ services:
      OPENAI_API_KEY: 'your_openai_api_key' # Required if using OpenAI
      LLM_LANGUAGE: 'English' # Optional, default is 'English'
      OLLAMA_HOST: 'http://host.docker.internal:11434' # If using Ollama
      VISION_LLM_PROVIDER: 'ollama' # Optional, for OCR
      VISION_LLM_MODEL: 'minicpm-v' # Optional, for OCR
    volumes:
      - ./prompts:/app/prompts # Mount the prompts directory
    ports:

@@ -117,6 +120,8 @@ If you prefer to run the application manually:
  -e LLM_MODEL='gpt-4o' \
  -e OPENAI_API_KEY='your_openai_api_key' \
  -e LLM_LANGUAGE='English' \
  -e VISION_LLM_PROVIDER='ollama' \
  -e VISION_LLM_MODEL='minicpm-v' \
  -v $(pwd)/prompts:/app/prompts \ # Mount the prompts directory
  -p 8080:8080 \
  paperless-gpt

@@ -135,6 +140,8 @@ If you prefer to run the application manually:
| `OPENAI_API_KEY` | Your OpenAI API key. Required if using OpenAI as the LLM provider. | Cond. |
| `LLM_LANGUAGE` | The likely language of your documents (e.g., `English`, `German`). Default is `English`. | No |
| `OLLAMA_HOST` | The URL of the Ollama server (e.g., `http://host.docker.internal:11434`). Useful if using Ollama. Default is `http://127.0.0.1:11434`. | No |
| `VISION_LLM_PROVIDER` | The vision LLM provider to use for OCR (`openai` or `ollama`). | No |
| `VISION_LLM_MODEL` | The model name to use for OCR (e.g., `minicpm-v`). | No |

**Note:** When using Ollama, ensure that the Ollama server is running and accessible from the paperless-gpt container.

@@ -257,6 +264,15 @@ Be very selective and only choose the most relevant tags since too many tags wil
   - Review the suggested titles. You can edit them if necessary.
   - Click on **"Apply Suggestions"** to update the document titles in paperless-ngx.

5. **Experimental OCR Feature:**

   - Send documents to a vision LLM for OCR processing.
   - Example configuration to enable OCR with Ollama:
     ```env
     VISION_LLM_PROVIDER=ollama
     VISION_LLM_MODEL=minicpm-v
     ```

## Contributing

Contributions are welcome! Please read the [contributing guidelines](CONTRIBUTING.md) before submitting a pull request.

app_http_handlers.go (new file): 245 lines

@@ -0,0 +1,245 @@
package main

import (
	"fmt"
	"log"
	"net/http"
	"os"
	"strconv"
	"text/template"
	"time"

	"github.com/gin-gonic/gin"
)

// getPromptsHandler handles the GET /api/prompts endpoint
func getPromptsHandler(c *gin.Context) {
	templateMutex.RLock()
	defer templateMutex.RUnlock()

	// Read the templates from files or use default content
	titleTemplateContent, err := os.ReadFile("prompts/title_prompt.tmpl")
	if err != nil {
		titleTemplateContent = []byte(defaultTitleTemplate)
	}

	tagTemplateContent, err := os.ReadFile("prompts/tag_prompt.tmpl")
	if err != nil {
		tagTemplateContent = []byte(defaultTagTemplate)
	}

	c.JSON(http.StatusOK, gin.H{
		"title_template": string(titleTemplateContent),
		"tag_template":   string(tagTemplateContent),
	})
}

// updatePromptsHandler handles the POST /api/prompts endpoint
func updatePromptsHandler(c *gin.Context) {
	var req struct {
		TitleTemplate string `json:"title_template"`
		TagTemplate   string `json:"tag_template"`
	}

	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request payload"})
		return
	}

	templateMutex.Lock()
	defer templateMutex.Unlock()

	// Update title template
	if req.TitleTemplate != "" {
		t, err := template.New("title").Parse(req.TitleTemplate)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid title template: %v", err)})
			return
		}
		titleTemplate = t
		err = os.WriteFile("prompts/title_prompt.tmpl", []byte(req.TitleTemplate), 0644)
		if err != nil {
			log.Printf("Failed to write title_prompt.tmpl: %v", err)
		}
	}

	// Update tag template
	if req.TagTemplate != "" {
		t, err := template.New("tag").Parse(req.TagTemplate)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid tag template: %v", err)})
			return
		}
		tagTemplate = t
		err = os.WriteFile("prompts/tag_prompt.tmpl", []byte(req.TagTemplate), 0644)
		if err != nil {
			log.Printf("Failed to write tag_prompt.tmpl: %v", err)
		}
	}

	c.Status(http.StatusOK)
}

// getAllTagsHandler handles the GET /api/tags endpoint
func (app *App) getAllTagsHandler(c *gin.Context) {
	ctx := c.Request.Context()

	tags, err := app.Client.GetAllTags(ctx)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error fetching tags: %v", err)})
		log.Printf("Error fetching tags: %v", err)
		return
	}

	c.JSON(http.StatusOK, tags)
}

// documentsHandler handles the GET /api/documents endpoint
func (app *App) documentsHandler(c *gin.Context) {
	ctx := c.Request.Context()

	documents, err := app.Client.GetDocumentsByTags(ctx, []string{manualTag})
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error fetching documents: %v", err)})
		log.Printf("Error fetching documents: %v", err)
		return
	}

	c.JSON(http.StatusOK, documents)
}

// generateSuggestionsHandler handles the POST /api/generate-suggestions endpoint
func (app *App) generateSuggestionsHandler(c *gin.Context) {
	ctx := c.Request.Context()

	var suggestionRequest GenerateSuggestionsRequest
	if err := c.ShouldBindJSON(&suggestionRequest); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request payload: %v", err)})
		log.Printf("Invalid request payload: %v", err)
		return
	}

	results, err := app.generateDocumentSuggestions(ctx, suggestionRequest)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error processing documents: %v", err)})
		log.Printf("Error processing documents: %v", err)
		return
	}

	c.JSON(http.StatusOK, results)
}

// updateDocumentsHandler handles the PATCH /api/update-documents endpoint
func (app *App) updateDocumentsHandler(c *gin.Context) {
	ctx := c.Request.Context()
	var documents []DocumentSuggestion
	if err := c.ShouldBindJSON(&documents); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request payload: %v", err)})
		log.Printf("Invalid request payload: %v", err)
		return
	}

	err := app.Client.UpdateDocuments(ctx, documents)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error updating documents: %v", err)})
		log.Printf("Error updating documents: %v", err)
		return
	}

	c.Status(http.StatusOK)
}

func (app *App) submitOCRJobHandler(c *gin.Context) {
	documentIDStr := c.Param("id")
	documentID, err := strconv.Atoi(documentIDStr)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
		return
	}

	// Create a new job
	jobID := generateJobID() // Implement a function to generate unique job IDs
	job := &Job{
		ID:         jobID,
		DocumentID: documentID,
		Status:     "pending",
		CreatedAt:  time.Now(),
		UpdatedAt:  time.Now(),
	}

	// Add job to store and queue
	jobStore.addJob(job)
	jobQueue <- job

	// Return the job ID to the client
	c.JSON(http.StatusAccepted, gin.H{"job_id": jobID})
}

func (app *App) getJobStatusHandler(c *gin.Context) {
	jobID := c.Param("job_id")

	job, exists := jobStore.getJob(jobID)
	if !exists {
		c.JSON(http.StatusNotFound, gin.H{"error": "Job not found"})
		return
	}

	response := gin.H{
		"job_id":     job.ID,
		"status":     job.Status,
		"created_at": job.CreatedAt,
		"updated_at": job.UpdatedAt,
		"pages_done": job.PagesDone,
	}

	if job.Status == "completed" {
		response["result"] = job.Result
	} else if job.Status == "failed" {
		response["error"] = job.Result
	}

	c.JSON(http.StatusOK, response)
}

func (app *App) getAllJobsHandler(c *gin.Context) {
	jobs := jobStore.GetAllJobs()

	jobList := make([]gin.H, 0, len(jobs))
	for _, job := range jobs {
		response := gin.H{
			"job_id":     job.ID,
			"status":     job.Status,
			"created_at": job.CreatedAt,
			"updated_at": job.UpdatedAt,
			"pages_done": job.PagesDone,
		}

		if job.Status == "completed" {
			response["result"] = job.Result
		} else if job.Status == "failed" {
			response["error"] = job.Result
		}

		jobList = append(jobList, response)
	}

	c.JSON(http.StatusOK, jobList)
}

// getDocumentHandler handles the retrieval of a document by its ID
func (app *App) getDocumentHandler() gin.HandlerFunc {
	return func(c *gin.Context) {
		id := c.Param("id")
		parsedID, err := strconv.Atoi(id)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
			return
		}
		document, err := app.Client.GetDocument(c, parsedID)
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}
		c.JSON(http.StatusOK, document)
	}
}

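The handlers above expose the OCR feature as an asynchronous job API: POST /api/documents/:id/ocr returns a job_id, and GET /api/jobs/ocr/:job_id reports status, pages_done and, once finished, result. The following is a minimal client-side sketch of that submit-and-poll flow; it is not part of the commit, assumes the server is reachable at http://localhost:8080, and uses only the response fields produced by submitOCRJobHandler and getJobStatusHandler.

```go
// Illustrative sketch only: submit an OCR job and poll until it finishes.
// Assumes paperless-gpt is listening on localhost:8080 (not part of this commit).
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"time"
)

func main() {
	// Submit an OCR job for document 544 (the sample ID used in main.go's comment).
	resp, err := http.Post("http://localhost:8080/api/documents/544/ocr", "application/json", nil)
	if err != nil {
		panic(err)
	}
	var submitted struct {
		JobID string `json:"job_id"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&submitted); err != nil {
		panic(err)
	}
	resp.Body.Close()

	// Poll the job until it is completed or failed.
	for {
		r, err := http.Get("http://localhost:8080/api/jobs/ocr/" + submitted.JobID)
		if err != nil {
			panic(err)
		}
		var status struct {
			Status    string `json:"status"`
			PagesDone int    `json:"pages_done"`
			Result    string `json:"result"`
		}
		if err := json.NewDecoder(r.Body).Decode(&status); err != nil {
			panic(err)
		}
		r.Body.Close()

		fmt.Printf("status=%s pages_done=%d\n", status.Status, status.PagesDone)
		if status.Status == "completed" || status.Status == "failed" {
			break
		}
		time.Sleep(2 * time.Second)
	}
}
```

The polling interval and error handling are deliberately simplistic; the point is only the request/response shape of the two endpoints.
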
app_llm.go (new file): 233 lines

@@ -0,0 +1,233 @@
package main

import (
	"bytes"
	"context"
	"fmt"
	"log"
	"strings"
	"sync"

	"github.com/tmc/langchaingo/llms"
)

// getSuggestedTags generates suggested tags for a document using the LLM
func (app *App) getSuggestedTags(ctx context.Context, content string, suggestedTitle string, availableTags []string) ([]string, error) {
	likelyLanguage := getLikelyLanguage()

	templateMutex.RLock()
	defer templateMutex.RUnlock()

	var promptBuffer bytes.Buffer
	err := tagTemplate.Execute(&promptBuffer, map[string]interface{}{
		"Language":      likelyLanguage,
		"AvailableTags": availableTags,
		"Title":         suggestedTitle,
		"Content":       content,
	})
	if err != nil {
		return nil, fmt.Errorf("error executing tag template: %v", err)
	}

	prompt := promptBuffer.String()
	log.Printf("Tag suggestion prompt: %s", prompt)

	completion, err := app.LLM.GenerateContent(ctx, []llms.MessageContent{
		{
			Parts: []llms.ContentPart{
				llms.TextContent{
					Text: prompt,
				},
			},
			Role: llms.ChatMessageTypeHuman,
		},
	})
	if err != nil {
		return nil, fmt.Errorf("error getting response from LLM: %v", err)
	}

	response := strings.TrimSpace(completion.Choices[0].Content)
	suggestedTags := strings.Split(response, ",")
	for i, tag := range suggestedTags {
		suggestedTags[i] = strings.TrimSpace(tag)
	}

	// Filter out tags that are not in the available tags list
	filteredTags := []string{}
	for _, tag := range suggestedTags {
		for _, availableTag := range availableTags {
			if strings.EqualFold(tag, availableTag) {
				filteredTags = append(filteredTags, availableTag)
				break
			}
		}
	}

	return filteredTags, nil
}

func (app *App) doOCRViaLLM(ctx context.Context, jpegBytes []byte) (string, error) {

	templateMutex.RLock()
	defer templateMutex.RUnlock()
	likelyLanguage := getLikelyLanguage()

	var promptBuffer bytes.Buffer
	err := ocrTemplate.Execute(&promptBuffer, map[string]interface{}{
		"Language": likelyLanguage,
	})
	if err != nil {
		return "", fmt.Errorf("error executing tag template: %v", err)
	}

	prompt := promptBuffer.String()

	// Convert the image to text
	completion, err := app.VisionLLM.GenerateContent(ctx, []llms.MessageContent{
		{
			Parts: []llms.ContentPart{
				llms.BinaryPart("image/jpeg", jpegBytes),
				llms.TextPart(prompt),
			},
			Role: llms.ChatMessageTypeHuman,
		},
	})
	if err != nil {
		return "", fmt.Errorf("error getting response from LLM: %v", err)
	}

	result := completion.Choices[0].Content
	fmt.Println(result)
	return result, nil
}

// getSuggestedTitle generates a suggested title for a document using the LLM
func (app *App) getSuggestedTitle(ctx context.Context, content string) (string, error) {
	likelyLanguage := getLikelyLanguage()

	templateMutex.RLock()
	defer templateMutex.RUnlock()

	var promptBuffer bytes.Buffer
	err := titleTemplate.Execute(&promptBuffer, map[string]interface{}{
		"Language": likelyLanguage,
		"Content":  content,
	})
	if err != nil {
		return "", fmt.Errorf("error executing title template: %v", err)
	}

	prompt := promptBuffer.String()

	log.Printf("Title suggestion prompt: %s", prompt)

	completion, err := app.LLM.GenerateContent(ctx, []llms.MessageContent{
		{
			Parts: []llms.ContentPart{
				llms.TextContent{
					Text: prompt,
				},
			},
			Role: llms.ChatMessageTypeHuman,
		},
	})
	if err != nil {
		return "", fmt.Errorf("error getting response from LLM: %v", err)
	}

	return strings.TrimSpace(strings.Trim(completion.Choices[0].Content, "\"")), nil
}

// generateDocumentSuggestions generates suggestions for a set of documents
func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionRequest GenerateSuggestionsRequest) ([]DocumentSuggestion, error) {
	// Fetch all available tags from paperless-ngx
	availableTagsMap, err := app.Client.GetAllTags(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch available tags: %v", err)
	}

	// Prepare a list of tag names
	availableTagNames := make([]string, 0, len(availableTagsMap))
	for tagName := range availableTagsMap {
		if tagName == manualTag {
			continue
		}
		availableTagNames = append(availableTagNames, tagName)
	}

	documents := suggestionRequest.Documents
	documentSuggestions := []DocumentSuggestion{}

	var wg sync.WaitGroup
	var mu sync.Mutex
	errorsList := make([]error, 0)

	for i := range documents {
		wg.Add(1)
		go func(doc Document) {
			defer wg.Done()
			documentID := doc.ID
			log.Printf("Processing Document ID %d...", documentID)

			content := doc.Content
			if len(content) > 5000 {
				content = content[:5000]
			}

			var suggestedTitle string
			var suggestedTags []string

			if suggestionRequest.GenerateTitles {
				suggestedTitle, err = app.getSuggestedTitle(ctx, content)
				if err != nil {
					mu.Lock()
					errorsList = append(errorsList, fmt.Errorf("Document %d: %v", documentID, err))
					mu.Unlock()
					log.Printf("Error processing document %d: %v", documentID, err)
					return
				}
			}

			if suggestionRequest.GenerateTags {
				suggestedTags, err = app.getSuggestedTags(ctx, content, suggestedTitle, availableTagNames)
				if err != nil {
					mu.Lock()
					errorsList = append(errorsList, fmt.Errorf("Document %d: %v", documentID, err))
					mu.Unlock()
					log.Printf("Error generating tags for document %d: %v", documentID, err)
					return
				}
			}

			mu.Lock()
			suggestion := DocumentSuggestion{
				ID:               documentID,
				OriginalDocument: doc,
			}
			// Titles
			if suggestionRequest.GenerateTitles {
				suggestion.SuggestedTitle = suggestedTitle
			} else {
				suggestion.SuggestedTitle = doc.Title
			}

			// Tags
			if suggestionRequest.GenerateTags {
				suggestion.SuggestedTags = suggestedTags
			} else {
				suggestion.SuggestedTags = removeTagFromList(doc.Tags, manualTag)
			}
			documentSuggestions = append(documentSuggestions, suggestion)
			mu.Unlock()
			log.Printf("Document %d processed successfully.", documentID)
		}(documents[i])
	}

	wg.Wait()

	if len(errorsList) > 0 {
		return nil, errorsList[0] // Return the first error encountered
	}

	return documentSuggestions, nil
}

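doOCRViaLLM sends one page image per request: a single human message whose parts are a BinaryPart with the JPEG bytes followed by the rendered OCR prompt. As a usage illustration only (not part of the commit), a hypothetical helper like ocrSinglePage below could feed one cached page file through it; it assumes an initialized *App whose VisionLLM came from createVisionLLM and that loadTemplates has populated ocrTemplate, and it would live in the same package so the existing imports (context, fmt, os) apply.

```go
// Illustrative sketch: run doOCRViaLLM over one cached page image.
// Assumes `app` was wired up as in main.go; not part of this commit.
func ocrSinglePage(ctx context.Context, app *App, imagePath string) (string, error) {
	// Page images are written as JPEG by DownloadDocumentAsImages in paperless.go.
	jpegBytes, err := os.ReadFile(imagePath)
	if err != nil {
		return "", fmt.Errorf("read page image: %w", err)
	}
	return app.doOCRViaLLM(ctx, jpegBytes)
}
```
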
go.mod: 7 changes

@@ -6,8 +6,11 @@ toolchain go1.22.2

require (
	github.com/Masterminds/sprig/v3 v3.2.3
	github.com/gen2brain/go-fitz v1.24.14
	github.com/gin-gonic/gin v1.10.0
	github.com/stretchr/testify v1.9.0
	github.com/tmc/langchaingo v0.1.12
	golang.org/x/sync v0.7.0
)

require (

@@ -17,7 +20,9 @@ require (
	github.com/bytedance/sonic/loader v0.1.1 // indirect
	github.com/cloudwego/base64x v0.1.4 // indirect
	github.com/cloudwego/iasm v0.2.0 // indirect
	github.com/davecgh/go-spew v1.1.1 // indirect
	github.com/dlclark/regexp2 v1.10.0 // indirect
	github.com/ebitengine/purego v0.8.0 // indirect
	github.com/gabriel-vasile/mimetype v1.4.3 // indirect
	github.com/gin-contrib/sse v0.1.0 // indirect
	github.com/go-playground/locales v0.14.1 // indirect

@@ -28,6 +33,7 @@ require (
	github.com/huandu/xstrings v1.3.3 // indirect
	github.com/imdario/mergo v0.3.13 // indirect
	github.com/json-iterator/go v1.1.12 // indirect
	github.com/jupiterrider/ffi v0.2.0 // indirect
	github.com/klauspost/cpuid/v2 v2.2.7 // indirect
	github.com/leodido/go-urn v1.4.0 // indirect
	github.com/mattn/go-isatty v0.0.20 // indirect

@@ -37,6 +43,7 @@ require (
	github.com/modern-go/reflect2 v1.0.2 // indirect
	github.com/pelletier/go-toml/v2 v2.2.2 // indirect
	github.com/pkoukk/tiktoken-go v0.1.6 // indirect
	github.com/pmezard/go-difflib v1.0.0 // indirect
	github.com/shopspring/decimal v1.2.0 // indirect
	github.com/spf13/cast v1.3.1 // indirect
	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect

go.sum: 8 changes

@@ -17,8 +17,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/ebitengine/purego v0.8.0 h1:JbqvnEzRvPpxhCJzJJ2y0RbiZ8nyjccVUrSM3q+GvvE=
github.com/ebitengine/purego v0.8.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gen2brain/go-fitz v1.24.14 h1:09weRkjVtLYNGo7l0J7DyOwBExbwi8SJ9h8YPhw9WEo=
github.com/gen2brain/go-fitz v1.24.14/go.mod h1:0KaZeQgASc20Yp5R/pFzyy7SmP01XcoHKNF842U2/S4=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=

@@ -46,6 +50,8 @@ github.com/imdario/mergo v0.3.13 h1:lFzP57bqS/wsqKssCGmtLAb8A0wKjLGrve2q3PPVcBk=
github.com/imdario/mergo v0.3.13/go.mod h1:4lJ1jqUDcsbIECGy0RUJAXNIhg+6ocWgb1ALK2O4oXg=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/jupiterrider/ffi v0.2.0 h1:tMM70PexgYNmV+WyaYhJgCvQAvtTCs3wXeILPutihnA=
github.com/jupiterrider/ffi v0.2.0/go.mod h1:yqYqX5DdEccAsHeMn+6owkoI2llBLySVAF8dwCDZPVs=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=

@@ -111,6 +117,8 @@ golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

jobs.go (new file): 150 lines

@@ -0,0 +1,150 @@
package main

import (
	"context"
	"fmt"
	"log"
	"os"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/google/uuid"
)

// Job represents an OCR job
type Job struct {
	ID         string
	DocumentID int
	Status     string // "pending", "in_progress", "completed", "failed"
	Result     string // OCR result or error message
	CreatedAt  time.Time
	UpdatedAt  time.Time
	PagesDone  int // Number of pages processed
}

// JobStore manages jobs and their statuses
type JobStore struct {
	sync.RWMutex
	jobs map[string]*Job
}

var (
	jobStore = &JobStore{
		jobs: make(map[string]*Job),
	}
	jobQueue = make(chan *Job, 100) // Buffered channel with capacity of 100 jobs
	logger   = log.New(os.Stdout, "OCR_JOB: ", log.LstdFlags)
)

func generateJobID() string {
	return uuid.New().String()
}

func (store *JobStore) addJob(job *Job) {
	store.Lock()
	defer store.Unlock()
	job.PagesDone = 0 // Initialize PagesDone to 0
	store.jobs[job.ID] = job
	logger.Printf("Job added: %v", job)
}

func (store *JobStore) getJob(jobID string) (*Job, bool) {
	store.RLock()
	defer store.RUnlock()
	job, exists := store.jobs[jobID]
	return job, exists
}

func (store *JobStore) GetAllJobs() []*Job {
	store.RLock()
	defer store.RUnlock()

	jobs := make([]*Job, 0, len(store.jobs))
	for _, job := range store.jobs {
		jobs = append(jobs, job)
	}

	sort.Slice(jobs, func(i, j int) bool {
		return jobs[i].CreatedAt.After(jobs[j].CreatedAt)
	})

	return jobs
}

func (store *JobStore) updateJobStatus(jobID, status, result string) {
	store.Lock()
	defer store.Unlock()
	if job, exists := store.jobs[jobID]; exists {
		job.Status = status
		if result != "" {
			job.Result = result
		}
		job.UpdatedAt = time.Now()
		logger.Printf("Job status updated: %v", job)
	}
}

func (store *JobStore) updatePagesDone(jobID string, pagesDone int) {
	store.Lock()
	defer store.Unlock()
	if job, exists := store.jobs[jobID]; exists {
		job.PagesDone = pagesDone
		job.UpdatedAt = time.Now()
		logger.Printf("Job pages done updated: %v", job)
	}
}

func startWorkerPool(app *App, numWorkers int) {
	for i := 0; i < numWorkers; i++ {
		go func(workerID int) {
			logger.Printf("Worker %d started", workerID)
			for job := range jobQueue {
				logger.Printf("Worker %d processing job: %s", workerID, job.ID)
				processJob(app, job)
			}
		}(i)
	}
}

func processJob(app *App, job *Job) {
	jobStore.updateJobStatus(job.ID, "in_progress", "")

	ctx := context.Background()

	// Download images of the document
	imagePaths, err := app.Client.DownloadDocumentAsImages(ctx, job.DocumentID)
	if err != nil {
		logger.Printf("Error downloading document images for job %s: %v", job.ID, err)
		jobStore.updateJobStatus(job.ID, "failed", fmt.Sprintf("Error downloading document images: %v", err))
		return
	}

	var ocrTexts []string
	for i, imagePath := range imagePaths {
		imageContent, err := os.ReadFile(imagePath)
		if err != nil {
			logger.Printf("Error reading image file for job %s: %v", job.ID, err)
			jobStore.updateJobStatus(job.ID, "failed", fmt.Sprintf("Error reading image file: %v", err))
			return
		}

		ocrText, err := app.doOCRViaLLM(ctx, imageContent)
		if err != nil {
			logger.Printf("Error performing OCR for job %s: %v", job.ID, err)
			jobStore.updateJobStatus(job.ID, "failed", fmt.Sprintf("Error performing OCR: %v", err))
			return
		}

		ocrTexts = append(ocrTexts, ocrText)
		jobStore.updatePagesDone(job.ID, i+1) // Update PagesDone after each page is processed
	}

	// Combine the OCR texts
	fullOcrText := strings.Join(ocrTexts, "\n\n")

	// Update job status and result
	jobStore.updateJobStatus(job.ID, "completed", fullOcrText)
	logger.Printf("Job completed: %s", job.ID)
}

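The worker pool and JobStore together implement a simple in-memory lifecycle: a job enters jobQueue as "pending", a worker marks it "in_progress", bumps PagesDone after each page, and finally stores the combined OCR text under "completed" (or the error text under "failed"). The sketch below, a hypothetical exampleJobLifecycle function written for the same package and not part of the commit, walks through those transitions using only the helpers defined above.

```go
// Illustrative only: exercises the JobStore helpers defined in jobs.go.
func exampleJobLifecycle() {
	job := &Job{
		ID:         generateJobID(),
		DocumentID: 42,
		Status:     "pending",
		CreatedAt:  time.Now(),
		UpdatedAt:  time.Now(),
	}
	jobStore.addJob(job)

	// A worker would pick the job up and report progress page by page.
	jobStore.updateJobStatus(job.ID, "in_progress", "")
	jobStore.updatePagesDone(job.ID, 1)
	jobStore.updatePagesDone(job.ID, 2)

	// Finished: the combined OCR text is stored in Result.
	jobStore.updateJobStatus(job.ID, "completed", "page 1 text\n\npage 2 text")

	if j, ok := jobStore.getJob(job.ID); ok {
		fmt.Printf("job %s: %s (%d pages)\n", j.ID, j.Status, j.PagesDone)
	}
}
```

Note that Result doubles as the error message for failed jobs, which is why getJobStatusHandler maps it to either "result" or "error" in its JSON response.
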
main.go: 388 changes

@@ -1,7 +1,6 @@
package main

import (
	"bytes"
	"context"
	"fmt"
	"log"

@@ -29,10 +28,13 @@ var (
	autoTag           = "paperless-gpt-auto"
	llmProvider       = os.Getenv("LLM_PROVIDER")
	llmModel          = os.Getenv("LLM_MODEL")
	visionLlmProvider = os.Getenv("VISION_LLM_PROVIDER")
	visionLlmModel    = os.Getenv("VISION_LLM_MODEL")

	// Templates
	titleTemplate *template.Template
	tagTemplate   *template.Template
	ocrTemplate   *template.Template
	templateMutex sync.RWMutex

	// Default templates

@@ -58,12 +60,15 @@ Content:
Please concisely select the {{.Language}} tags from the list above that best describe the document.
Be very selective and only choose the most relevant tags since too many tags will make the document less discoverable.
`

	defaultOcrPrompt = `Just transcribe the text in this image and preserve the formatting and layout (high quality OCR). Do that for ALL the text in the image. Be thorough and pay attention. This is very important. The image is from a text document so be sure to continue until the bottom of the page. Thanks a lot! You tend to forget about some text in the image so please focus! Use markdown format.`
)

// App struct to hold dependencies
type App struct {
	Client    *PaperlessClient
	LLM       llms.Model
	VisionLLM llms.Model
}

func main() {

@@ -82,10 +87,17 @@ func main() {
		log.Fatalf("Failed to create LLM client: %v", err)
	}

	// Initialize Vision LLM
	visionLlm, err := createVisionLLM()
	if err != nil {
		log.Fatalf("Failed to create Vision LLM client: %v", err)
	}

	// Initialize App with dependencies
	app := &App{
		Client:    client,
		LLM:       llm,
		VisionLLM: visionLlm,
	}

	// Start background process for auto-tagging

@@ -119,6 +131,8 @@ func main() {
	api := router.Group("/api")
	{
		api.GET("/documents", app.documentsHandler)
		// http://localhost:8080/api/documents/544
		api.GET("/documents/:id", app.getDocumentHandler())
		api.POST("/generate-suggestions", app.generateSuggestionsHandler)
		api.PATCH("/update-documents", app.updateDocumentsHandler)
		api.GET("/filter-tag", func(c *gin.Context) {

@@ -128,6 +142,17 @@ func main() {
		api.GET("/tags", app.getAllTagsHandler)
		api.GET("/prompts", getPromptsHandler)
		api.POST("/prompts", updatePromptsHandler)

		// OCR endpoints
		api.POST("/documents/:id/ocr", app.submitOCRJobHandler)
		api.GET("/jobs/ocr/:job_id", app.getJobStatusHandler)
		api.GET("/jobs/ocr", app.getAllJobsHandler)

		// Endpoint to see if user enabled OCR
		api.GET("/experimental/ocr", func(c *gin.Context) {
			enabled := isOcrEnabled()
			c.JSON(http.StatusOK, gin.H{"enabled": enabled})
		})
	}

	// Serve static files for the frontend under /assets

@@ -139,12 +164,20 @@ func main() {
		c.File("./web-app/dist/index.html")
	})

	// Start OCR worker pool
	numWorkers := 1 // Number of workers to start
	startWorkerPool(app, numWorkers)

	log.Println("Server started on port :8080")
	if err := router.Run(":8080"); err != nil {
		log.Fatalf("Failed to run server: %v", err)
	}
}

func isOcrEnabled() bool {
	return visionLlmModel != "" && visionLlmProvider != ""
}

// validateEnvVars ensures all necessary environment variables are set
func validateEnvVars() {
	if paperlessBaseURL == "" {

@@ -200,169 +233,6 @@ func (app *App) processAutoTagDocuments() error {
	return nil
}

// getAllTagsHandler handles the GET /api/tags endpoint
func (app *App) getAllTagsHandler(c *gin.Context) {
	ctx := c.Request.Context()

	tags, err := app.Client.GetAllTags(ctx)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error fetching tags: %v", err)})
		log.Printf("Error fetching tags: %v", err)
		return
	}

	c.JSON(http.StatusOK, tags)
}

// documentsHandler handles the GET /api/documents endpoint
func (app *App) documentsHandler(c *gin.Context) {
	ctx := c.Request.Context()

	documents, err := app.Client.GetDocumentsByTags(ctx, []string{manualTag})
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error fetching documents: %v", err)})
		log.Printf("Error fetching documents: %v", err)
		return
	}

	c.JSON(http.StatusOK, documents)
}

// generateSuggestionsHandler handles the POST /api/generate-suggestions endpoint
func (app *App) generateSuggestionsHandler(c *gin.Context) {
	ctx := c.Request.Context()

	var suggestionRequest GenerateSuggestionsRequest
	if err := c.ShouldBindJSON(&suggestionRequest); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request payload: %v", err)})
		log.Printf("Invalid request payload: %v", err)
		return
	}

	results, err := app.generateDocumentSuggestions(ctx, suggestionRequest)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error processing documents: %v", err)})
		log.Printf("Error processing documents: %v", err)
		return
	}

	c.JSON(http.StatusOK, results)
}

// updateDocumentsHandler handles the PATCH /api/update-documents endpoint
func (app *App) updateDocumentsHandler(c *gin.Context) {
	ctx := c.Request.Context()
	var documents []DocumentSuggestion
	if err := c.ShouldBindJSON(&documents); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request payload: %v", err)})
		log.Printf("Invalid request payload: %v", err)
		return
	}

	err := app.Client.UpdateDocuments(ctx, documents)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error updating documents: %v", err)})
		log.Printf("Error updating documents: %v", err)
		return
	}

	c.Status(http.StatusOK)
}

// generateDocumentSuggestions generates suggestions for a set of documents
func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionRequest GenerateSuggestionsRequest) ([]DocumentSuggestion, error) {
	// Fetch all available tags from paperless-ngx
	availableTagsMap, err := app.Client.GetAllTags(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch available tags: %v", err)
	}

	// Prepare a list of tag names
	availableTagNames := make([]string, 0, len(availableTagsMap))
	for tagName := range availableTagsMap {
		if tagName == manualTag {
			continue
		}
		availableTagNames = append(availableTagNames, tagName)
	}

	documents := suggestionRequest.Documents
	documentSuggestions := []DocumentSuggestion{}

	var wg sync.WaitGroup
	var mu sync.Mutex
	errorsList := make([]error, 0)

	for i := range documents {
		wg.Add(1)
		go func(doc Document) {
			defer wg.Done()
			documentID := doc.ID
			log.Printf("Processing Document ID %d...", documentID)

			content := doc.Content
			if len(content) > 5000 {
				content = content[:5000]
			}

			var suggestedTitle string
			var suggestedTags []string

			if suggestionRequest.GenerateTitles {
				suggestedTitle, err = app.getSuggestedTitle(ctx, content)
				if err != nil {
					mu.Lock()
					errorsList = append(errorsList, fmt.Errorf("Document %d: %v", documentID, err))
					mu.Unlock()
					log.Printf("Error processing document %d: %v", documentID, err)
					return
				}
			}

			if suggestionRequest.GenerateTags {
				suggestedTags, err = app.getSuggestedTags(ctx, content, suggestedTitle, availableTagNames)
				if err != nil {
					mu.Lock()
					errorsList = append(errorsList, fmt.Errorf("Document %d: %v", documentID, err))
					mu.Unlock()
					log.Printf("Error generating tags for document %d: %v", documentID, err)
					return
				}
			}

			mu.Lock()
			suggestion := DocumentSuggestion{
				ID:               documentID,
				OriginalDocument: doc,
			}
			// Titles
			if suggestionRequest.GenerateTitles {
				suggestion.SuggestedTitle = suggestedTitle
			} else {
				suggestion.SuggestedTitle = doc.Title
			}

			// Tags
			if suggestionRequest.GenerateTags {
				suggestion.SuggestedTags = suggestedTags
			} else {
				suggestion.SuggestedTags = removeTagFromList(doc.Tags, manualTag)
			}
			documentSuggestions = append(documentSuggestions, suggestion)
			mu.Unlock()
			log.Printf("Document %d processed successfully.", documentID)
		}(documents[i])
	}

	wg.Wait()

	if len(errorsList) > 0 {
		return nil, errorsList[0] // Return the first error encountered
	}

	return documentSuggestions, nil
}

// removeTagFromList removes a specific tag from a list of tags
func removeTagFromList(tags []string, tagToRemove string) []string {
	filteredTags := []string{}

@@ -374,61 +244,6 @@ func removeTagFromList(tags []string, tagToRemove string) []string {
	return filteredTags
}

// getSuggestedTags generates suggested tags for a document using the LLM
func (app *App) getSuggestedTags(ctx context.Context, content string, suggestedTitle string, availableTags []string) ([]string, error) {
	likelyLanguage := getLikelyLanguage()

	templateMutex.RLock()
	defer templateMutex.RUnlock()

	var promptBuffer bytes.Buffer
	err := tagTemplate.Execute(&promptBuffer, map[string]interface{}{
		"Language":      likelyLanguage,
		"AvailableTags": availableTags,
		"Title":         suggestedTitle,
		"Content":       content,
	})
	if err != nil {
		return nil, fmt.Errorf("error executing tag template: %v", err)
	}

	prompt := promptBuffer.String()
	log.Printf("Tag suggestion prompt: %s", prompt)

	completion, err := app.LLM.GenerateContent(ctx, []llms.MessageContent{
		{
			Parts: []llms.ContentPart{
				llms.TextContent{
					Text: prompt,
				},
			},
			Role: llms.ChatMessageTypeHuman,
		},
	})
	if err != nil {
		return nil, fmt.Errorf("error getting response from LLM: %v", err)
	}

	response := strings.TrimSpace(completion.Choices[0].Content)
	suggestedTags := strings.Split(response, ",")
	for i, tag := range suggestedTags {
		suggestedTags[i] = strings.TrimSpace(tag)
	}

	// Filter out tags that are not in the available tags list
	filteredTags := []string{}
	for _, tag := range suggestedTags {
		for _, availableTag := range availableTags {
			if strings.EqualFold(tag, availableTag) {
				filteredTags = append(filteredTags, availableTag)
				break
			}
		}
	}

	return filteredTags, nil
}

// getLikelyLanguage determines the likely language of the document content
func getLikelyLanguage() string {
	likelyLanguage := os.Getenv("LLM_LANGUAGE")

@@ -438,43 +253,6 @@ func getLikelyLanguage() string {
	return strings.Title(strings.ToLower(likelyLanguage))
}

// getSuggestedTitle generates a suggested title for a document using the LLM
func (app *App) getSuggestedTitle(ctx context.Context, content string) (string, error) {
	likelyLanguage := getLikelyLanguage()

	templateMutex.RLock()
	defer templateMutex.RUnlock()

	var promptBuffer bytes.Buffer
	err := titleTemplate.Execute(&promptBuffer, map[string]interface{}{
		"Language": likelyLanguage,
		"Content":  content,
	})
	if err != nil {
		return "", fmt.Errorf("error executing title template: %v", err)
	}

	prompt := promptBuffer.String()

	log.Printf("Title suggestion prompt: %s", prompt)

	completion, err := app.LLM.GenerateContent(ctx, []llms.MessageContent{
		{
			Parts: []llms.ContentPart{
				llms.TextContent{
					Text: prompt,
				},
			},
			Role: llms.ChatMessageTypeHuman,
		},
	})
	if err != nil {
		return "", fmt.Errorf("error getting response from LLM: %v", err)
	}

	return strings.TrimSpace(strings.Trim(completion.Choices[0].Content, "\"")), nil
}

// loadTemplates loads the title and tag templates from files or uses default templates
func loadTemplates() {
	templateMutex.Lock()

@@ -515,6 +293,21 @@ func loadTemplates() {
	if err != nil {
		log.Fatalf("Failed to parse tag template: %v", err)
	}

	// Load OCR template
	ocrTemplatePath := filepath.Join(promptsDir, "ocr_prompt.tmpl")
	ocrTemplateContent, err := os.ReadFile(ocrTemplatePath)
	if err != nil {
		log.Printf("Could not read %s, using default template: %v", ocrTemplatePath, err)
		ocrTemplateContent = []byte(defaultOcrPrompt)
		if err := os.WriteFile(ocrTemplatePath, ocrTemplateContent, os.ModePerm); err != nil {
			log.Fatalf("Failed to write default OCR template to disk: %v", err)
		}
	}
	ocrTemplate, err = template.New("ocr").Funcs(sprig.FuncMap()).Parse(string(ocrTemplateContent))
	if err != nil {
		log.Fatalf("Failed to parse OCR template: %v", err)
	}
}

// createLLM creates the appropriate LLM client based on the provider

@@ -542,70 +335,27 @@ func createLLM() (llms.Model, error) {
	}
}

// getPromptsHandler handles the GET /api/prompts endpoint
func getPromptsHandler(c *gin.Context) {
	templateMutex.RLock()
	defer templateMutex.RUnlock()

	// Read the templates from files or use default content
	titleTemplateContent, err := os.ReadFile("prompts/title_prompt.tmpl")
	if err != nil {
		titleTemplateContent = []byte(defaultTitleTemplate)
func createVisionLLM() (llms.Model, error) {
	switch strings.ToLower(visionLlmProvider) {
	case "openai":
		if openaiAPIKey == "" {
			return nil, fmt.Errorf("OpenAI API key is not set")
		}

	tagTemplateContent, err := os.ReadFile("prompts/tag_prompt.tmpl")
	if err != nil {
		tagTemplateContent = []byte(defaultTagTemplate)
		return openai.New(
			openai.WithModel(visionLlmModel),
			openai.WithToken(openaiAPIKey),
		)
	case "ollama":
		host := os.Getenv("OLLAMA_HOST")
		if host == "" {
			host = "http://127.0.0.1:11434"
		}

	c.JSON(http.StatusOK, gin.H{
		"title_template": string(titleTemplateContent),
		"tag_template":   string(tagTemplateContent),
	})
}

// updatePromptsHandler handles the POST /api/prompts endpoint
func updatePromptsHandler(c *gin.Context) {
	var req struct {
		TitleTemplate string `json:"title_template"`
		TagTemplate   string `json:"tag_template"`
	}

	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request payload"})
		return
	}

	templateMutex.Lock()
	defer templateMutex.Unlock()

	// Update title template
	if req.TitleTemplate != "" {
		t, err := template.New("title").Parse(req.TitleTemplate)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid title template: %v", err)})
			return
		}
		titleTemplate = t
		err = os.WriteFile("prompts/title_prompt.tmpl", []byte(req.TitleTemplate), 0644)
		if err != nil {
			log.Printf("Failed to write title_prompt.tmpl: %v", err)
		return ollama.New(
			ollama.WithModel(visionLlmModel),
			ollama.WithServerURL(host),
		)
	default:
		log.Printf("No Vision LLM provider created: %s", visionLlmProvider)
		return nil, nil
	}
}

	// Update tag template
	if req.TagTemplate != "" {
		t, err := template.New("tag").Parse(req.TagTemplate)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid tag template: %v", err)})
			return
		}
		tagTemplate = t
		err = os.WriteFile("prompts/tag_prompt.tmpl", []byte(req.TagTemplate), 0644)
		if err != nil {
			log.Printf("Failed to write tag_prompt.tmpl: %v", err)
		}
	}

	c.Status(http.StatusOK)
}

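main.go only treats OCR as enabled when both VISION_LLM_PROVIDER and VISION_LLM_MODEL are set (isOcrEnabled), and it surfaces that flag to the frontend at GET /api/experimental/ocr. Below is a minimal sketch of querying that endpoint; it is illustrative only, not part of the commit, and assumes the server listens on localhost:8080.

```go
// Illustrative sketch: check whether the experimental OCR feature is enabled
// by calling the endpoint main.go wires to isOcrEnabled().
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	resp, err := http.Get("http://localhost:8080/api/experimental/ocr")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var body struct {
		Enabled bool `json:"enabled"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
		panic(err)
	}

	// Enabled is true only when both VISION_LLM_PROVIDER and VISION_LLM_MODEL are set.
	fmt.Println("OCR enabled:", body.Enabled)
}
```
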
paperless.go: 183 changes

@@ -5,10 +5,17 @@ import (
	"context"
	"encoding/json"
	"fmt"
	"image/jpeg"
	"io"
	"log"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"sync"

	"github.com/gen2brain/go-fitz"
	"golang.org/x/sync/errgroup"
)

// PaperlessClient struct to interact with the Paperless-NGX API

@@ -16,14 +23,18 @@ type PaperlessClient struct {
	BaseURL     string
	APIToken    string
	HTTPClient  *http.Client
	CacheFolder string
}

// NewPaperlessClient creates a new instance of PaperlessClient with a default HTTP client
func NewPaperlessClient(baseURL, apiToken string) *PaperlessClient {
	cacheFolder := os.Getenv("PAPERLESS_GPT_CACHE_DIR")

	return &PaperlessClient{
		BaseURL:     strings.TrimRight(baseURL, "/"),
		APIToken:    apiToken,
		HTTPClient:  &http.Client{},
		CacheFolder: cacheFolder,
	}
}

@@ -164,6 +175,48 @@ func (c *PaperlessClient) DownloadPDF(ctx context.Context, document Document) ([
	return io.ReadAll(resp.Body)
}

func (c *PaperlessClient) GetDocument(ctx context.Context, documentID int) (Document, error) {
	path := fmt.Sprintf("api/documents/%d/", documentID)
	resp, err := c.Do(ctx, "GET", path, nil)
	if err != nil {
		return Document{}, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		bodyBytes, _ := io.ReadAll(resp.Body)
		return Document{}, fmt.Errorf("error fetching document %d: %d, %s", documentID, resp.StatusCode, string(bodyBytes))
	}

	var documentResponse GetDocumentApiResponse
	err = json.NewDecoder(resp.Body).Decode(&documentResponse)
	if err != nil {
		return Document{}, err
	}

	allTags, err := c.GetAllTags(ctx)
	if err != nil {
		return Document{}, err
	}

	tagNames := make([]string, len(documentResponse.Tags))
	for i, resultTagID := range documentResponse.Tags {
		for tagName, tagID := range allTags {
			if resultTagID == tagID {
				tagNames[i] = tagName
				break
			}
		}
	}

	return Document{
		ID:      documentResponse.ID,
		Title:   documentResponse.Title,
		Content: documentResponse.Content,
		Tags:    tagNames,
	}, nil
}

// UpdateDocuments updates the specified documents with suggested changes
func (c *PaperlessClient) UpdateDocuments(ctx context.Context, documents []DocumentSuggestion) error {
	// Fetch all available tags

@@ -211,6 +264,12 @@ func (c *PaperlessClient) UpdateDocuments(ctx context.Context, documents []Docum
			log.Printf("No valid title found for document %d, skipping.", documentID)
		}

		// Suggested Content
		suggestedContent := document.SuggestedContent
		if suggestedContent != "" {
			updatedFields["content"] = suggestedContent
		}

		// Marshal updated fields to JSON
		jsonData, err := json.Marshal(updatedFields)
		if err != nil {

@@ -239,6 +298,130 @@ func (c *PaperlessClient) UpdateDocuments(ctx context.Context, documents []Docum
	return nil
}

// DownloadDocumentAsImages downloads the PDF file of the specified document and converts it to images
func (c *PaperlessClient) DownloadDocumentAsImages(ctx context.Context, documentId int) ([]string, error) {
	// Create a directory named after the document ID
	docDir := filepath.Join(c.GetCacheFolder(), fmt.Sprintf("/document-%d", documentId))
	if _, err := os.Stat(docDir); os.IsNotExist(err) {
		err = os.MkdirAll(docDir, 0755)
		if err != nil {
			return nil, err
		}
	}

	// Check if images already exist
	var imagePaths []string
	for n := 0; ; n++ {
		imagePath := filepath.Join(docDir, fmt.Sprintf("page%03d.jpg", n))
		if _, err := os.Stat(imagePath); os.IsNotExist(err) {
			break
		}
		imagePaths = append(imagePaths, imagePath)
	}

	// If images exist, return them
	if len(imagePaths) > 0 {
		return imagePaths, nil
	}

	// Proceed with downloading and converting the document to images
	path := fmt.Sprintf("api/documents/%d/download/", documentId)
	resp, err := c.Do(ctx, "GET", path, nil)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		bodyBytes, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("error downloading document %d: %d, %s", documentId, resp.StatusCode, string(bodyBytes))
	}

	pdfData, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	tmpFile, err := os.CreateTemp("", "document-*.pdf")
	if err != nil {
		return nil, err
	}
	defer os.Remove(tmpFile.Name())

	_, err = tmpFile.Write(pdfData)
	if err != nil {
		return nil, err
	}
	tmpFile.Close()

	doc, err := fitz.New(tmpFile.Name())
	if err != nil {
		return nil, err
	}
	defer doc.Close()

	var mu sync.Mutex
	var g errgroup.Group

	for n := 0; n < doc.NumPage(); n++ {
		n := n // capture loop variable
		g.Go(func() error {
			mu.Lock()
			// I assume the libmupdf library is not thread-safe
			img, err := doc.Image(n)
			mu.Unlock()
			if err != nil {
				return err
			}

			imagePath := filepath.Join(docDir, fmt.Sprintf("page%03d.jpg", n))
			f, err := os.Create(imagePath)
			if err != nil {
				return err
			}

			err = jpeg.Encode(f, img, &jpeg.Options{Quality: jpeg.DefaultQuality})
			if err != nil {
				f.Close()
				return err
			}
			f.Close()

			// Verify the JPEG file
			file, err := os.Open(imagePath)
			if err != nil {
				return err
			}
			defer file.Close()

			_, err = jpeg.Decode(file)
			if err != nil {
				return fmt.Errorf("invalid JPEG file: %s", imagePath)
			}

			mu.Lock()
			imagePaths = append(imagePaths, imagePath)
			mu.Unlock()

			return nil
		})
	}

	if err := g.Wait(); err != nil {
		return nil, err
	}

	return imagePaths, nil
}

// GetCacheFolder returns the cache folder for the PaperlessClient
func (c *PaperlessClient) GetCacheFolder() string {
	if c.CacheFolder == "" {
		c.CacheFolder = filepath.Join(os.TempDir(), "paperless-gpt")
	}
	return c.CacheFolder
}

// urlEncode encodes a string for safe URL usage
func urlEncode(s string) string {
	return strings.ReplaceAll(s, " ", "+")

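DownloadDocumentAsImages caches converted pages under GetCacheFolder() in a document-<id> directory using page%03d.jpg file names, and short-circuits to the cached files on later calls. The hypothetical cachedPages helper below, written for the same package and not part of the commit, only re-reads that layout to list which pages of a document are already rendered.

```go
// Illustrative only: list cached page images for a document, following the
// document-%d / page%03d.jpg layout used by DownloadDocumentAsImages.
func cachedPages(c *PaperlessClient, documentID int) []string {
	docDir := filepath.Join(c.GetCacheFolder(), fmt.Sprintf("document-%d", documentID))
	var pages []string
	for n := 0; ; n++ {
		p := filepath.Join(docDir, fmt.Sprintf("page%03d.jpg", n))
		if _, err := os.Stat(p); os.IsNotExist(err) {
			break
		}
		pages = append(pages, p)
	}
	return pages
}
```
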
412
paperless_test.go
Normal file
412
paperless_test.go
Normal file
|
@ -0,0 +1,412 @@
|
|||
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// Helper struct to hold common test data and methods
type testEnv struct {
	t             *testing.T
	server        *httptest.Server
	client        *PaperlessClient
	requestCount  int
	mockResponses map[string]http.HandlerFunc
}

// newTestEnv initializes a new test environment
func newTestEnv(t *testing.T) *testEnv {
	env := &testEnv{
		t:             t,
		mockResponses: make(map[string]http.HandlerFunc),
	}

	// Create a mock server with a handler that dispatches based on URL path
	env.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		env.requestCount++
		handler, exists := env.mockResponses[r.URL.Path]
		if !exists {
			t.Fatalf("Unexpected request URL: %s", r.URL.Path)
		}
		// Set common headers and invoke the handler
		assert.Equal(t, "Token test-token", r.Header.Get("Authorization"))
		handler(w, r)
	}))

	// Initialize the PaperlessClient with the mock server URL
	env.client = NewPaperlessClient(env.server.URL, "test-token")
	env.client.HTTPClient = env.server.Client()

	return env
}

// teardown closes the mock server
func (env *testEnv) teardown() {
	env.server.Close()
}

// Helper method to set a mock response for a specific path
func (env *testEnv) setMockResponse(path string, handler http.HandlerFunc) {
	env.mockResponses[path] = handler
}

// TestNewPaperlessClient tests the creation of a new PaperlessClient instance
func TestNewPaperlessClient(t *testing.T) {
	baseURL := "http://example.com"
	apiToken := "test-token"

	client := NewPaperlessClient(baseURL, apiToken)

	assert.Equal(t, "http://example.com", client.BaseURL)
	assert.Equal(t, apiToken, client.APIToken)
	assert.NotNil(t, client.HTTPClient)
}

// TestDo tests the Do method of PaperlessClient
func TestDo(t *testing.T) {
	env := newTestEnv(t)
	defer env.teardown()

	// Set mock response for "/test-path"
	env.setMockResponse("/test-path", func(w http.ResponseWriter, r *http.Request) {
		// Verify the request method
		assert.Equal(t, "GET", r.Method)
		// Send a mock response
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"message": "success"}`))
	})

	ctx := context.Background()
	resp, err := env.client.Do(ctx, "GET", "/test-path", nil)
	require.NoError(t, err)
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)

	assert.Equal(t, http.StatusOK, resp.StatusCode)
	assert.Equal(t, `{"message": "success"}`, string(body))
}

// TestGetAllTags tests the GetAllTags method, including pagination
func TestGetAllTags(t *testing.T) {
	env := newTestEnv(t)
	defer env.teardown()

	// Mock data for paginated responses
	page1 := map[string]interface{}{
		"results": []map[string]interface{}{
			{"id": 1, "name": "tag1"},
			{"id": 2, "name": "tag2"},
		},
		"next": fmt.Sprintf("%s/api/tags/?page=2", env.server.URL),
	}
	page2 := map[string]interface{}{
		"results": []map[string]interface{}{
			{"id": 3, "name": "tag3"},
		},
		"next": nil,
	}

	// Set mock responses for pagination
	env.setMockResponse("/api/tags/", func(w http.ResponseWriter, r *http.Request) {
		query := r.URL.Query().Get("page")
		if query == "2" {
			w.WriteHeader(http.StatusOK)
			json.NewEncoder(w).Encode(page2)
		} else {
			w.WriteHeader(http.StatusOK)
			json.NewEncoder(w).Encode(page1)
		}
	})

	ctx := context.Background()
	tags, err := env.client.GetAllTags(ctx)
	require.NoError(t, err)

	expectedTags := map[string]int{
		"tag1": 1,
		"tag2": 2,
		"tag3": 3,
	}

	assert.Equal(t, expectedTags, tags)
}

// TestGetDocumentsByTags tests the GetDocumentsByTags method
func TestGetDocumentsByTags(t *testing.T) {
	env := newTestEnv(t)
	defer env.teardown()

	// Mock data for documents
	documentsResponse := GetDocumentsApiResponse{
		Results: []struct {
			ID                  int           `json:"id"`
			Correspondent       interface{}   `json:"correspondent"`
			DocumentType        interface{}   `json:"document_type"`
			StoragePath         interface{}   `json:"storage_path"`
			Title               string        `json:"title"`
			Content             string        `json:"content"`
			Tags                []int         `json:"tags"`
			Created             time.Time     `json:"created"`
			CreatedDate         string        `json:"created_date"`
			Modified            time.Time     `json:"modified"`
			Added               time.Time     `json:"added"`
			ArchiveSerialNumber interface{}   `json:"archive_serial_number"`
			OriginalFileName    string        `json:"original_file_name"`
			ArchivedFileName    string        `json:"archived_file_name"`
			Owner               int           `json:"owner"`
			UserCanChange       bool          `json:"user_can_change"`
			Notes               []interface{} `json:"notes"`
			SearchHit           struct {
				Score          float64 `json:"score"`
				Highlights     string  `json:"highlights"`
				NoteHighlights string  `json:"note_highlights"`
				Rank           int     `json:"rank"`
			} `json:"__search_hit__"`
		}{
			{
				ID:      1,
				Title:   "Document 1",
				Content: "Content 1",
				Tags:    []int{1, 2},
			},
			{
				ID:      2,
				Title:   "Document 2",
				Content: "Content 2",
				Tags:    []int{2, 3},
			},
		},
	}

	// Mock data for tags
	tagsResponse := map[string]interface{}{
		"results": []map[string]interface{}{
			{"id": 1, "name": "tag1"},
			{"id": 2, "name": "tag2"},
			{"id": 3, "name": "tag3"},
		},
		"next": nil,
	}

	// Set mock responses
	env.setMockResponse("/api/documents/", func(w http.ResponseWriter, r *http.Request) {
		// Verify query parameters
		expectedQuery := "query=tag:tag1+tag:tag2"
		assert.Equal(t, expectedQuery, r.URL.RawQuery)
		w.WriteHeader(http.StatusOK)
		json.NewEncoder(w).Encode(documentsResponse)
	})

	env.setMockResponse("/api/tags/", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		json.NewEncoder(w).Encode(tagsResponse)
	})

	ctx := context.Background()
	tags := []string{"tag1", "tag2"}
	documents, err := env.client.GetDocumentsByTags(ctx, tags)
	require.NoError(t, err)

	expectedDocuments := []Document{
		{
			ID:      1,
			Title:   "Document 1",
			Content: "Content 1",
			Tags:    []string{"tag1", "tag2"},
		},
		{
			ID:      2,
			Title:   "Document 2",
			Content: "Content 2",
			Tags:    []string{"tag2", "tag3"},
		},
	}

	assert.Equal(t, expectedDocuments, documents)
}

// TestDownloadPDF tests the DownloadPDF method
func TestDownloadPDF(t *testing.T) {
	env := newTestEnv(t)
	defer env.teardown()

	document := Document{
		ID: 123,
	}

	// Get sample PDF from tests/pdf/sample.pdf
	pdfFile := "tests/pdf/sample.pdf"
	pdfContent, err := os.ReadFile(pdfFile)
	require.NoError(t, err)

	// Set mock response
	downloadPath := fmt.Sprintf("/api/documents/%d/download/", document.ID)
	env.setMockResponse(downloadPath, func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write(pdfContent)
	})

	ctx := context.Background()
	data, err := env.client.DownloadPDF(ctx, document)
	require.NoError(t, err)
	assert.Equal(t, pdfContent, data)
}

// TestUpdateDocuments tests the UpdateDocuments method
func TestUpdateDocuments(t *testing.T) {
	env := newTestEnv(t)
	defer env.teardown()

	// Mock data for documents to update
	documents := []DocumentSuggestion{
		{
			ID: 1,
			OriginalDocument: Document{
				ID:    1,
				Title: "Old Title",
				Tags:  []string{"tag1"},
			},
			SuggestedTitle: "New Title",
			SuggestedTags:  []string{"tag2"},
		},
	}
	// Mock data for tags
	tagsResponse := map[string]interface{}{
		"results": []map[string]interface{}{
			{"id": 1, "name": "tag1"},
			{"id": 2, "name": "tag2"},
			{"id": 3, "name": "manual"},
		},
		"next": nil,
	}

	// Set the manual tag
	manualTag = "manual"

	// Set mock responses
	env.setMockResponse("/api/tags/", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		json.NewEncoder(w).Encode(tagsResponse)
	})

	updatePath := fmt.Sprintf("/api/documents/%d/", documents[0].ID)
	env.setMockResponse(updatePath, func(w http.ResponseWriter, r *http.Request) {
		// Verify the request method
		assert.Equal(t, "PATCH", r.Method)

		// Read and parse the request body
		bodyBytes, err := io.ReadAll(r.Body)
		require.NoError(t, err)
		defer r.Body.Close()

		var updatedFields map[string]interface{}
		err = json.Unmarshal(bodyBytes, &updatedFields)
		require.NoError(t, err)

		// Expected updated fields
		expectedFields := map[string]interface{}{
			"title": "New Title",
			"tags":  []interface{}{float64(2)}, // tag2 ID
		}

		assert.Equal(t, expectedFields, updatedFields)

		w.WriteHeader(http.StatusOK)
	})

	ctx := context.Background()
	err := env.client.UpdateDocuments(ctx, documents)
	require.NoError(t, err)
}

// TestUrlEncode tests the urlEncode function
func TestUrlEncode(t *testing.T) {
	input := "tag:tag1 tag:tag2"
	expected := "tag:tag1+tag:tag2"
	result := urlEncode(input)
	assert.Equal(t, expected, result)
}

// TestDownloadDocumentAsImages tests the DownloadDocumentAsImages method
func TestDownloadDocumentAsImages(t *testing.T) {
	env := newTestEnv(t)
	defer env.teardown()

	document := Document{
		ID: 123,
	}

	// Get sample PDF from tests/pdf/sample.pdf
	pdfFile := "tests/pdf/sample.pdf"
	pdfContent, err := os.ReadFile(pdfFile)
	require.NoError(t, err)

	// Set mock response
	downloadPath := fmt.Sprintf("/api/documents/%d/download/", document.ID)
	env.setMockResponse(downloadPath, func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write(pdfContent)
	})

	ctx := context.Background()
	imagePaths, err := env.client.DownloadDocumentAsImages(ctx, document.ID)
	require.NoError(t, err)

	// Verify that exactly one page was extracted
	assert.Len(t, imagePaths, 1)
	// The path should end with paperless-gpt/document-123/page000.jpg
	assert.Contains(t, imagePaths[0], "paperless-gpt/document-123/page000.jpg")
	for _, imagePath := range imagePaths {
		_, err := os.Stat(imagePath)
		assert.NoError(t, err)
	}
}

func TestDownloadDocumentAsImages_ManyPages(t *testing.T) {
	env := newTestEnv(t)
	defer env.teardown()

	document := Document{
		ID: 321,
	}

	// Get sample PDF from tests/pdf/many-pages.pdf
	pdfFile := "tests/pdf/many-pages.pdf"
	pdfContent, err := os.ReadFile(pdfFile)
	require.NoError(t, err)

	// Set mock response
	downloadPath := fmt.Sprintf("/api/documents/%d/download/", document.ID)
	env.setMockResponse(downloadPath, func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write(pdfContent)
	})

	ctx := context.Background()
	env.client.CacheFolder = "tests/tmp"
	// Clean the cache folder
	os.RemoveAll(env.client.CacheFolder)
	imagePaths, err := env.client.DownloadDocumentAsImages(ctx, document.ID)
	require.NoError(t, err)

	// Verify that exactly 52 pages were extracted
	assert.Len(t, imagePaths, 52)
	// Each path should point to an existing file under tests/tmp/document-321/
	for _, imagePath := range imagePaths {
		_, err := os.Stat(imagePath)
		assert.NoError(t, err)
		assert.Contains(t, imagePath, "tests/tmp/document-321/page")
	}
}

BIN  tests/pdf/many-pages.pdf  Normal file (binary file not shown)
BIN  tests/pdf/sample.pdf  Normal file (binary file not shown)

21  types.go

@@ -36,6 +36,26 @@ type GetDocumentsApiResponse struct {
	} `json:"results"`
}

type GetDocumentApiResponse struct {
	ID                  int           `json:"id"`
	Correspondent       interface{}   `json:"correspondent"`
	DocumentType        interface{}   `json:"document_type"`
	StoragePath         interface{}   `json:"storage_path"`
	Title               string        `json:"title"`
	Content             string        `json:"content"`
	Tags                []int         `json:"tags"`
	Created             time.Time     `json:"created"`
	CreatedDate         string        `json:"created_date"`
	Modified            time.Time     `json:"modified"`
	Added               time.Time     `json:"added"`
	ArchiveSerialNumber interface{}   `json:"archive_serial_number"`
	OriginalFileName    string        `json:"original_file_name"`
	ArchivedFileName    string        `json:"archived_file_name"`
	Owner               int           `json:"owner"`
	UserCanChange       bool          `json:"user_can_change"`
	Notes               []interface{} `json:"notes"`
}

// Document is a stripped-down version of the document object from paperless-ngx.
// Response payload for /documents endpoint and part of request payload for /generate-suggestions endpoint
type Document struct {

@@ -58,4 +78,5 @@ type DocumentSuggestion struct {
	OriginalDocument Document `json:"original_document"`
	SuggestedTitle   string   `json:"suggested_title,omitempty"`
	SuggestedTags    []string `json:"suggested_tags,omitempty"`
	SuggestedContent string   `json:"suggested_content,omitempty"`
}
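For illustration, a minimal sketch (not part of this commit) of how the new GetDocumentApiResponse type decodes a single-document payload with the standard library. It assumes the snippet lives in a *_test.go file in the same package with encoding/json, fmt, and log imported; the JSON sample is made up.

// Illustrative only: decode a single-document payload into GetDocumentApiResponse.
func ExampleGetDocumentApiResponse_decode() {
	payload := []byte(`{"id": 123, "title": "Invoice", "content": "Scanned text", "tags": [1, 2]}`)

	var doc GetDocumentApiResponse
	if err := json.Unmarshal(payload, &doc); err != nil {
		log.Fatal(err)
	}
	fmt.Println(doc.ID, doc.Title, doc.Tags)
	// Output: 123 Invoice [1 2]
}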

48  web-app/package-lock.json  generated

@@ -15,6 +15,8 @@
"prop-types": "^15.8.1",
|
||||
"react": "^18.3.1",
|
||||
"react-dom": "^18.3.1",
|
||||
"react-icons": "^5.3.0",
|
||||
"react-router-dom": "^6.27.0",
|
||||
"react-tag-autocomplete": "^7.3.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
@ -843,6 +845,14 @@
|
|||
"react": "^16.8.0 || ^17.0.0-rc.1 || ^18.0.0 || ^19.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@remix-run/router": {
|
||||
"version": "1.20.0",
|
||||
"resolved": "https://registry.npmjs.org/@remix-run/router/-/router-1.20.0.tgz",
|
||||
"integrity": "sha512-mUnk8rPJBI9loFDZ+YzPGdeniYK+FTmRD1TMCz7ev2SNIozyKKpnGgsxO34u6Z4z/t0ITuu7voi/AshfsGsgFg==",
|
||||
"engines": {
|
||||
"node": ">=14.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@rollup/rollup-android-arm-eabi": {
|
||||
"version": "4.22.4",
|
||||
"resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.22.4.tgz",
|
||||
|
@ -3298,11 +3308,49 @@
|
|||
"react": "^18.3.1"
|
||||
}
|
||||
},
|
||||
"node_modules/react-icons": {
|
||||
"version": "5.3.0",
|
||||
"resolved": "https://registry.npmjs.org/react-icons/-/react-icons-5.3.0.tgz",
|
||||
"integrity": "sha512-DnUk8aFbTyQPSkCfF8dbX6kQjXA9DktMeJqfjrg6cK9vwQVMxmcA3BfP4QoiztVmEHtwlTgLFsPuH2NskKT6eg==",
|
||||
"peerDependencies": {
|
||||
"react": "*"
|
||||
}
|
||||
},
|
||||
"node_modules/react-is": {
|
||||
"version": "16.13.1",
|
||||
"resolved": "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz",
|
||||
"integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ=="
|
||||
},
|
||||
"node_modules/react-router": {
|
||||
"version": "6.27.0",
|
||||
"resolved": "https://registry.npmjs.org/react-router/-/react-router-6.27.0.tgz",
|
||||
"integrity": "sha512-YA+HGZXz4jaAkVoYBE98VQl+nVzI+cVI2Oj/06F5ZM+0u3TgedN9Y9kmMRo2mnkSK2nCpNQn0DVob4HCsY/WLw==",
|
||||
"dependencies": {
|
||||
"@remix-run/router": "1.20.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=14.0.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": ">=16.8"
|
||||
}
|
||||
},
|
||||
"node_modules/react-router-dom": {
|
||||
"version": "6.27.0",
|
||||
"resolved": "https://registry.npmjs.org/react-router-dom/-/react-router-dom-6.27.0.tgz",
|
||||
"integrity": "sha512-+bvtFWMC0DgAFrfKXKG9Fc+BcXWRUO1aJIihbB79xaeq0v5UzfvnM5houGUm1Y461WVRcgAQ+Clh5rdb1eCx4g==",
|
||||
"dependencies": {
|
||||
"@remix-run/router": "1.20.0",
|
||||
"react-router": "6.27.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=14.0.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": ">=16.8",
|
||||
"react-dom": ">=16.8"
|
||||
}
|
||||
},
|
||||
"node_modules/react-tag-autocomplete": {
|
||||
"version": "7.3.0",
|
||||
"resolved": "https://registry.npmjs.org/react-tag-autocomplete/-/react-tag-autocomplete-7.3.0.tgz",
|
||||
|
|
|

web-app/package.json

@@ -18,6 +18,8 @@
    "prop-types": "^15.8.1",
    "react": "^18.3.1",
    "react-dom": "^18.3.1",
    "react-icons": "^5.3.0",
    "react-router-dom": "^6.27.0",
    "react-tag-autocomplete": "^7.3.0"
  },
  "devDependencies": {

web-app/src/App.tsx

@@ -1,12 +1,17 @@
// App.tsx or App.jsx
import React from 'react';
import { Route, BrowserRouter as Router, Routes } from 'react-router-dom';
import DocumentProcessor from './DocumentProcessor';
import './index.css';
import ExperimentalOCR from './ExperimentalOCR'; // New component

const App: React.FC = () => {
  return (
    <div className="App">
      <DocumentProcessor />
    </div>
    <Router>
      <Routes>
        <Route path="/" element={<DocumentProcessor />} />
        <Route path="/experimental-ocr" element={<ExperimentalOCR />} />
      </Routes>
    </Router>
  );
};

web-app/src/DocumentProcessor.tsx

@@ -1,5 +1,6 @@
import axios from "axios";
import React, { useCallback, useEffect, useState } from "react";
import { Link } from "react-router-dom";
import "react-tag-autocomplete/example/src/styles.css"; // Ensure styles are loaded
import DocumentsToProcess from "./components/DocumentsToProcess";
import NoDocuments from "./components/NoDocuments";

@@ -24,6 +25,7 @@ export interface DocumentSuggestion {
  original_document: Document;
  suggested_title?: string;
  suggested_tags?: string[];
  suggested_content?: string;
}

export interface TagOption {

@@ -44,17 +46,22 @@ const DocumentProcessor: React.FC = () => {
  const [generateTags, setGenerateTags] = useState(true);
  const [error, setError] = useState<string | null>(null);

  // Temporary feature flags
  const [ocrEnabled, setOcrEnabled] = useState(false);

  // Custom hook to fetch initial data
  const fetchInitialData = useCallback(async () => {
    try {
      const [filterTagRes, documentsRes, tagsRes] = await Promise.all([
      const [filterTagRes, documentsRes, tagsRes, ocrEnabledRes] = await Promise.all([
        axios.get<{ tag: string }>("/api/filter-tag"),
        axios.get<Document[]>("/api/documents"),
        axios.get<Record<string, number>>("/api/tags"),
        axios.get<{enabled: boolean}>("/api/experimental/ocr"),
      ]);

      setFilterTag(filterTagRes.data.tag);
      setDocuments(documentsRes.data);
      setOcrEnabled(ocrEnabledRes.data.enabled);
      const tags = Object.keys(tagsRes.data).map((tag) => ({
        id: tag,
        name: tag,

@@ -129,9 +136,7 @@ const DocumentProcessor: React.FC = () => {
        doc.id === docId
          ? {
              ...doc,
              suggested_tags: doc.suggested_tags?.filter(
                (_, i) => i !== index
              ),
              suggested_tags: doc.suggested_tags?.filter((_, i) => i !== index),
            }
          : doc
      )

@@ -141,9 +146,7 @@ const DocumentProcessor: React.FC = () => {
  const handleTitleChange = (docId: number, title: string) => {
    setSuggestions((prevSuggestions) =>
      prevSuggestions.map((doc) =>
        doc.id === docId
          ? { ...doc, suggested_title: title }
          : doc
        doc.id === docId ? { ...doc, suggested_title: title } : doc
      )
    );
  };

@@ -182,11 +185,12 @@ const DocumentProcessor: React.FC = () => {
    }
  }, [documents]);

  if (loading) {
    return (
      <div className="flex items-center justify-center min-h-screen bg-white dark:bg-gray-900">
        <div className="text-xl font-semibold text-gray-800 dark:text-gray-200">Loading documents...</div>
        <div className="text-xl font-semibold text-gray-800 dark:text-gray-200">
          Loading documents...
        </div>
      </div>
    );
  }

@@ -195,6 +199,16 @@ const DocumentProcessor: React.FC = () => {
    <div className="max-w-5xl mx-auto p-6 bg-white dark:bg-gray-900 text-gray-800 dark:text-gray-200">
      <header className="text-center">
        <h1 className="text-4xl font-bold mb-8">Paperless GPT</h1>
        {ocrEnabled && (
          <div>
            <Link
              to="/experimental-ocr"
              className="inline-block bg-blue-600 hover:bg-blue-700 text-white font-semibold py-2 px-4 rounded transition duration-200 dark:bg-blue-500 dark:hover:bg-blue-600"
            >
              OCR via LLMs (Experimental)
            </Link>
          </div>
        )}
      </header>

      {error && (

190  web-app/src/ExperimentalOCR.tsx  Normal file

@@ -0,0 +1,190 @@
import axios from 'axios';
import React, { useCallback, useEffect, useState } from 'react';
import { FaSpinner } from 'react-icons/fa';
import { Document, DocumentSuggestion } from './DocumentProcessor';

const ExperimentalOCR: React.FC = () => {
  const refreshInterval = 1000; // Refresh interval in milliseconds
  const [documentId, setDocumentId] = useState(0);
  const [jobId, setJobId] = useState('');
  const [ocrResult, setOcrResult] = useState('');
  const [status, setStatus] = useState('');
  const [error, setError] = useState<string | null>('');
  const [pagesDone, setPagesDone] = useState(0); // New state for pages done
  const [saving, setSaving] = useState(false); // New state for saving
  const [documentDetails, setDocumentDetails] = useState<Document | null>(null); // New state for document details

  const submitOCRJob = async () => {
    setStatus('');
    setError('');
    setJobId('');
    setOcrResult('');
    setPagesDone(0); // Reset pages done

    try {
      setStatus('Submitting OCR job...');
      const response = await axios.post(`/api/documents/${documentId}/ocr`);
      setJobId(response.data.job_id);
      setStatus('Job submitted. Processing...');
    } catch (err) {
      console.error(err);
      setError('Failed to submit OCR job.');
    }
  };

  const checkJobStatus = async () => {
    if (!jobId) return;

    try {
      const response = await axios.get(`/api/jobs/ocr/${jobId}`);
      const jobStatus = response.data.status;
      setPagesDone(response.data.pages_done); // Update pages done
      if (jobStatus === 'completed') {
        setOcrResult(response.data.result);
        setStatus('OCR completed successfully.');
      } else if (jobStatus === 'failed') {
        setError(response.data.error);
        setStatus('OCR failed.');
      } else {
        setStatus(`Job status: ${jobStatus}. This may take a few minutes.`);
        // Automatically check again after a delay
        setTimeout(checkJobStatus, refreshInterval);
      }
    } catch (err) {
      console.error(err);
      setError('Failed to check job status.');
    }
  };

  const handleSaveContent = async () => {
    setSaving(true);
    setError(null);
    try {
      if (!documentDetails) {
        setError('Document details not fetched.');
        throw new Error('Document details not fetched.');
      }
      const requestPayload: DocumentSuggestion = {
        id: documentId,
        original_document: documentDetails, // Use fetched document details
        suggested_content: ocrResult,
      };

      await axios.patch("/api/update-documents", [requestPayload]);
      setStatus('Content saved successfully.');
    } catch (err) {
      console.error("Error saving content:", err);
      setError("Failed to save content.");
    } finally {
      setSaving(false);
    }
  };

  const fetchDocumentDetails = useCallback(async () => {
    if (!documentId) return;

    try {
      const response = await axios.get<Document>(`/api/documents/${documentId}`);
      setDocumentDetails(response.data);
    } catch (err) {
      console.error("Error fetching document details:", err);
      setError("Failed to fetch document details.");
    }
  }, [documentId]);

  // Fetch document details when documentId changes
  useEffect(() => {
    fetchDocumentDetails();
  }, [documentId, fetchDocumentDetails]);

  // Start checking job status when jobId is set
  useEffect(() => {
    if (jobId) {
      checkJobStatus();
    }
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [jobId]);

  return (
    <div className="max-w-3xl mx-auto p-6 bg-white dark:bg-gray-900 text-gray-800 dark:text-gray-200">
      <h1 className="text-4xl font-bold mb-6 text-center">OCR via LLMs (Experimental)</h1>
      <p className="mb-6 text-center text-yellow-600">
        This is an experimental feature. Results may vary, and processing may take some time.
      </p>
      <div className="bg-gray-100 dark:bg-gray-800 p-6 rounded-lg shadow-md">
        <div className="mb-4">
          <label htmlFor="documentId" className="block mb-2 font-semibold">
            Document ID:
          </label>
          <input
            type="number"
            id="documentId"
            value={documentId}
            onChange={(e) => setDocumentId(Number(e.target.value))}
            className="border border-gray-300 dark:border-gray-700 rounded w-full p-2 focus:outline-none focus:ring-2 focus:ring-blue-500"
            placeholder="Enter the document ID"
          />
        </div>
        <button
          onClick={submitOCRJob}
          className="w-full bg-blue-600 hover:bg-blue-700 text-white font-semibold py-2 px-4 rounded transition duration-200"
          disabled={!documentId}
        >
          {status.startsWith('Submitting') ? (
            <span className="flex items-center justify-center">
              <FaSpinner className="animate-spin mr-2" />
              Submitting...
            </span>
          ) : (
            'Submit OCR Job'
          )}
        </button>
        {status && (
          <div className="mt-4 text-center text-gray-700 dark:text-gray-300">
            {status.includes('in_progress') && (
              <span className="flex items-center justify-center">
                <FaSpinner className="animate-spin mr-2" />
                {status}
              </span>
            )}
            {!status.includes('in_progress') && status}
            {pagesDone > 0 && (
              <div className="mt-2">
                Pages processed: {pagesDone}
              </div>
            )}
          </div>
        )}
        {error && (
          <div className="mt-4 p-4 bg-red-100 dark:bg-red-800 text-red-700 dark:text-red-200 rounded">
            {error}
          </div>
        )}
        {ocrResult && (
          <div className="mt-6">
            <h2 className="text-2xl font-bold mb-4">OCR Result:</h2>
            <div className="bg-gray-50 dark:bg-gray-900 p-4 rounded border border-gray-200 dark:border-gray-700 overflow-auto max-h-96">
              <pre className="whitespace-pre-wrap">{ocrResult}</pre>
            </div>
            <button
              onClick={handleSaveContent}
              className="w-full bg-green-600 hover:bg-green-700 text-white font-semibold py-2 px-4 rounded transition duration-200 mt-4"
              disabled={saving}
            >
              {saving ? (
                <span className="flex items-center justify-center">
                  <FaSpinner className="animate-spin mr-2" />
                  Saving...
                </span>
              ) : (
                'Save Content'
              )}
            </button>
          </div>
        )}
      </div>
    </div>
  );
};

export default ExperimentalOCR;
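For reference, the component above implies a small contract for the OCR job endpoints: POST /api/documents/{id}/ocr returns a job_id, and GET /api/jobs/ocr/{jobId} reports status and pages_done plus, on completion, result (or error on failure). Below is a minimal Go sketch of matching response shapes; the type and field names are inferred from this frontend code, not copied from the backend in this commit.

// Sketch only: response shapes implied by ExperimentalOCR.tsx.

// Returned by POST /api/documents/{id}/ocr.
type ocrJobSubmitResponse struct {
	JobID string `json:"job_id"`
}

// Returned by GET /api/jobs/ocr/{jobId}. Status is expected to be "completed",
// "failed", or an in-progress value while pages are still being processed.
type ocrJobStatusResponse struct {
	Status    string `json:"status"`
	PagesDone int    `json:"pages_done"`
	Result    string `json:"result,omitempty"`
	Error     string `json:"error,omitempty"`
}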

@@ -1 +1 @@
{"root":["./src/app.tsx","./src/documentprocessor.tsx","./src/main.tsx","./src/vite-env.d.ts","./src/components/documentcard.tsx","./src/components/documentstoprocess.tsx","./src/components/nodocuments.tsx","./src/components/successmodal.tsx","./src/components/suggestioncard.tsx","./src/components/suggestionsreview.tsx"],"version":"5.6.2"}
{"root":["./src/app.tsx","./src/documentprocessor.tsx","./src/experimentalocr.tsx","./src/main.tsx","./src/vite-env.d.ts","./src/components/documentcard.tsx","./src/components/documentstoprocess.tsx","./src/components/nodocuments.tsx","./src/components/successmodal.tsx","./src/components/suggestioncard.tsx","./src/components/suggestionsreview.tsx"],"version":"5.6.2"}