Early working UI protyping + jobs engine

This commit is contained in:
Dominik Schröter 2024-10-28 10:36:13 +01:00
parent b607815803
commit f65241fc8b
12 changed files with 814 additions and 389 deletions

1
.gitignore vendored
View file

@ -2,3 +2,4 @@
.DS_Store
prompts/
tests/tmp
tmp/

225
app_http_handlers.go Normal file
View file

@ -0,0 +1,225 @@
package main
import (
"fmt"
"log"
"net/http"
"os"
"strconv"
"text/template"
"time"
"github.com/gin-gonic/gin"
)
// getPromptsHandler handles the GET /api/prompts endpoint
func getPromptsHandler(c *gin.Context) {
templateMutex.RLock()
defer templateMutex.RUnlock()
// Read the templates from files or use default content
titleTemplateContent, err := os.ReadFile("prompts/title_prompt.tmpl")
if err != nil {
titleTemplateContent = []byte(defaultTitleTemplate)
}
tagTemplateContent, err := os.ReadFile("prompts/tag_prompt.tmpl")
if err != nil {
tagTemplateContent = []byte(defaultTagTemplate)
}
c.JSON(http.StatusOK, gin.H{
"title_template": string(titleTemplateContent),
"tag_template": string(tagTemplateContent),
})
}
// updatePromptsHandler handles the POST /api/prompts endpoint
func updatePromptsHandler(c *gin.Context) {
var req struct {
TitleTemplate string `json:"title_template"`
TagTemplate string `json:"tag_template"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request payload"})
return
}
templateMutex.Lock()
defer templateMutex.Unlock()
// Update title template
if req.TitleTemplate != "" {
t, err := template.New("title").Parse(req.TitleTemplate)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid title template: %v", err)})
return
}
titleTemplate = t
err = os.WriteFile("prompts/title_prompt.tmpl", []byte(req.TitleTemplate), 0644)
if err != nil {
log.Printf("Failed to write title_prompt.tmpl: %v", err)
}
}
// Update tag template
if req.TagTemplate != "" {
t, err := template.New("tag").Parse(req.TagTemplate)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid tag template: %v", err)})
return
}
tagTemplate = t
err = os.WriteFile("prompts/tag_prompt.tmpl", []byte(req.TagTemplate), 0644)
if err != nil {
log.Printf("Failed to write tag_prompt.tmpl: %v", err)
}
}
c.Status(http.StatusOK)
}
// getAllTagsHandler handles the GET /api/tags endpoint
func (app *App) getAllTagsHandler(c *gin.Context) {
ctx := c.Request.Context()
tags, err := app.Client.GetAllTags(ctx)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error fetching tags: %v", err)})
log.Printf("Error fetching tags: %v", err)
return
}
c.JSON(http.StatusOK, tags)
}
// documentsHandler handles the GET /api/documents endpoint
func (app *App) documentsHandler(c *gin.Context) {
ctx := c.Request.Context()
documents, err := app.Client.GetDocumentsByTags(ctx, []string{manualTag})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error fetching documents: %v", err)})
log.Printf("Error fetching documents: %v", err)
return
}
c.JSON(http.StatusOK, documents)
}
// generateSuggestionsHandler handles the POST /api/generate-suggestions endpoint
func (app *App) generateSuggestionsHandler(c *gin.Context) {
ctx := c.Request.Context()
var suggestionRequest GenerateSuggestionsRequest
if err := c.ShouldBindJSON(&suggestionRequest); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request payload: %v", err)})
log.Printf("Invalid request payload: %v", err)
return
}
results, err := app.generateDocumentSuggestions(ctx, suggestionRequest)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error processing documents: %v", err)})
log.Printf("Error processing documents: %v", err)
return
}
c.JSON(http.StatusOK, results)
}
// updateDocumentsHandler handles the PATCH /api/update-documents endpoint
func (app *App) updateDocumentsHandler(c *gin.Context) {
ctx := c.Request.Context()
var documents []DocumentSuggestion
if err := c.ShouldBindJSON(&documents); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request payload: %v", err)})
log.Printf("Invalid request payload: %v", err)
return
}
err := app.Client.UpdateDocuments(ctx, documents)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error updating documents: %v", err)})
log.Printf("Error updating documents: %v", err)
return
}
c.Status(http.StatusOK)
}
func (app *App) submitOCRJobHandler(c *gin.Context) {
documentIDStr := c.Param("id")
documentID, err := strconv.Atoi(documentIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
return
}
// Create a new job
jobID := generateJobID() // Implement a function to generate unique job IDs
job := &Job{
ID: jobID,
DocumentID: documentID,
Status: "pending",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
// Add job to store and queue
jobStore.addJob(job)
jobQueue <- job
// Return the job ID to the client
c.JSON(http.StatusAccepted, gin.H{"job_id": jobID})
}
func (app *App) getJobStatusHandler(c *gin.Context) {
jobID := c.Param("job_id")
job, exists := jobStore.getJob(jobID)
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "Job not found"})
return
}
response := gin.H{
"job_id": job.ID,
"status": job.Status,
"created_at": job.CreatedAt,
"updated_at": job.UpdatedAt,
}
if job.Status == "completed" {
response["result"] = job.Result
} else if job.Status == "failed" {
response["error"] = job.Result
}
c.JSON(http.StatusOK, response)
}
func (app *App) getAllJobsHandler(c *gin.Context) {
jobs := jobStore.GetAllJobs()
jobList := make([]gin.H, 0, len(jobs))
for _, job := range jobs {
response := gin.H{
"job_id": job.ID,
"status": job.Status,
"created_at": job.CreatedAt,
"updated_at": job.UpdatedAt,
}
if job.Status == "completed" {
response["result"] = job.Result
} else if job.Status == "failed" {
response["error"] = job.Result
}
jobList = append(jobList, response)
}
c.JSON(http.StatusOK, jobList)
}

218
app_llm.go Normal file
View file

@ -0,0 +1,218 @@
package main
import (
"bytes"
"context"
"fmt"
"log"
"strings"
"sync"
"github.com/tmc/langchaingo/llms"
)
// getSuggestedTags generates suggested tags for a document using the LLM
func (app *App) getSuggestedTags(ctx context.Context, content string, suggestedTitle string, availableTags []string) ([]string, error) {
likelyLanguage := getLikelyLanguage()
templateMutex.RLock()
defer templateMutex.RUnlock()
var promptBuffer bytes.Buffer
err := tagTemplate.Execute(&promptBuffer, map[string]interface{}{
"Language": likelyLanguage,
"AvailableTags": availableTags,
"Title": suggestedTitle,
"Content": content,
})
if err != nil {
return nil, fmt.Errorf("error executing tag template: %v", err)
}
prompt := promptBuffer.String()
log.Printf("Tag suggestion prompt: %s", prompt)
completion, err := app.LLM.GenerateContent(ctx, []llms.MessageContent{
{
Parts: []llms.ContentPart{
llms.TextContent{
Text: prompt,
},
},
Role: llms.ChatMessageTypeHuman,
},
})
if err != nil {
return nil, fmt.Errorf("error getting response from LLM: %v", err)
}
response := strings.TrimSpace(completion.Choices[0].Content)
suggestedTags := strings.Split(response, ",")
for i, tag := range suggestedTags {
suggestedTags[i] = strings.TrimSpace(tag)
}
// Filter out tags that are not in the available tags list
filteredTags := []string{}
for _, tag := range suggestedTags {
for _, availableTag := range availableTags {
if strings.EqualFold(tag, availableTag) {
filteredTags = append(filteredTags, availableTag)
break
}
}
}
return filteredTags, nil
}
func (app *App) doOCRViaLLM(ctx context.Context, jpegBytes []byte) (string, error) {
// Convert the image to text
completion, err := app.VisionLLM.GenerateContent(ctx, []llms.MessageContent{
{
Parts: []llms.ContentPart{
llms.BinaryPart("image/jpeg", jpegBytes),
llms.TextPart("Just transcribe the text in this image and preserve the formatting and layout (high quality OCR). Do that for ALL the text in the image. Be thorough and pay attention. This is very important. The image is from a text document so be sure to continue until the bottom of the page. Thanks a lot! You tend to forget about some text in the image so please focus! Use markdown format."),
},
Role: llms.ChatMessageTypeHuman,
},
})
if err != nil {
return "", fmt.Errorf("error getting response from LLM: %v", err)
}
result := completion.Choices[0].Content
fmt.Println(result)
return result, nil
}
// getSuggestedTitle generates a suggested title for a document using the LLM
func (app *App) getSuggestedTitle(ctx context.Context, content string) (string, error) {
likelyLanguage := getLikelyLanguage()
templateMutex.RLock()
defer templateMutex.RUnlock()
var promptBuffer bytes.Buffer
err := titleTemplate.Execute(&promptBuffer, map[string]interface{}{
"Language": likelyLanguage,
"Content": content,
})
if err != nil {
return "", fmt.Errorf("error executing title template: %v", err)
}
prompt := promptBuffer.String()
log.Printf("Title suggestion prompt: %s", prompt)
completion, err := app.LLM.GenerateContent(ctx, []llms.MessageContent{
{
Parts: []llms.ContentPart{
llms.TextContent{
Text: prompt,
},
},
Role: llms.ChatMessageTypeHuman,
},
})
if err != nil {
return "", fmt.Errorf("error getting response from LLM: %v", err)
}
return strings.TrimSpace(strings.Trim(completion.Choices[0].Content, "\"")), nil
}
// generateDocumentSuggestions generates suggestions for a set of documents
func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionRequest GenerateSuggestionsRequest) ([]DocumentSuggestion, error) {
// Fetch all available tags from paperless-ngx
availableTagsMap, err := app.Client.GetAllTags(ctx)
if err != nil {
return nil, fmt.Errorf("failed to fetch available tags: %v", err)
}
// Prepare a list of tag names
availableTagNames := make([]string, 0, len(availableTagsMap))
for tagName := range availableTagsMap {
if tagName == manualTag {
continue
}
availableTagNames = append(availableTagNames, tagName)
}
documents := suggestionRequest.Documents
documentSuggestions := []DocumentSuggestion{}
var wg sync.WaitGroup
var mu sync.Mutex
errorsList := make([]error, 0)
for i := range documents {
wg.Add(1)
go func(doc Document) {
defer wg.Done()
documentID := doc.ID
log.Printf("Processing Document ID %d...", documentID)
content := doc.Content
if len(content) > 5000 {
content = content[:5000]
}
var suggestedTitle string
var suggestedTags []string
if suggestionRequest.GenerateTitles {
suggestedTitle, err = app.getSuggestedTitle(ctx, content)
if err != nil {
mu.Lock()
errorsList = append(errorsList, fmt.Errorf("Document %d: %v", documentID, err))
mu.Unlock()
log.Printf("Error processing document %d: %v", documentID, err)
return
}
}
if suggestionRequest.GenerateTags {
suggestedTags, err = app.getSuggestedTags(ctx, content, suggestedTitle, availableTagNames)
if err != nil {
mu.Lock()
errorsList = append(errorsList, fmt.Errorf("Document %d: %v", documentID, err))
mu.Unlock()
log.Printf("Error generating tags for document %d: %v", documentID, err)
return
}
}
mu.Lock()
suggestion := DocumentSuggestion{
ID: documentID,
OriginalDocument: doc,
}
// Titles
if suggestionRequest.GenerateTitles {
suggestion.SuggestedTitle = suggestedTitle
} else {
suggestion.SuggestedTitle = doc.Title
}
// Tags
if suggestionRequest.GenerateTags {
suggestion.SuggestedTags = suggestedTags
} else {
suggestion.SuggestedTags = removeTagFromList(doc.Tags, manualTag)
}
documentSuggestions = append(documentSuggestions, suggestion)
mu.Unlock()
log.Printf("Document %d processed successfully.", documentID)
}(documents[i])
}
wg.Wait()
if len(errorsList) > 0 {
return nil, errorsList[0] // Return the first error encountered
}
return documentSuggestions, nil
}

137
jobs.go Normal file
View file

@ -0,0 +1,137 @@
package main
import (
"context"
"fmt"
"log"
"os"
"sort"
"strings"
"sync"
"time"
"github.com/google/uuid"
)
// Job represents an OCR job
type Job struct {
ID string
DocumentID int
Status string // "pending", "in_progress", "completed", "failed"
Result string // OCR result or error message
CreatedAt time.Time
UpdatedAt time.Time
}
// JobStore manages jobs and their statuses
type JobStore struct {
sync.RWMutex
jobs map[string]*Job
}
var (
jobStore = &JobStore{
jobs: make(map[string]*Job),
}
jobQueue = make(chan *Job, 100) // Buffered channel with capacity of 100 jobs
logger = log.New(os.Stdout, "OCR_JOB: ", log.LstdFlags)
)
func generateJobID() string {
return uuid.New().String()
}
func (store *JobStore) addJob(job *Job) {
store.Lock()
defer store.Unlock()
store.jobs[job.ID] = job
logger.Printf("Job added: %v", job)
}
func (store *JobStore) getJob(jobID string) (*Job, bool) {
store.RLock()
defer store.RUnlock()
job, exists := store.jobs[jobID]
return job, exists
}
func (store *JobStore) GetAllJobs() []*Job {
store.RLock()
defer store.RUnlock()
jobs := make([]*Job, 0, len(store.jobs))
for _, job := range store.jobs {
jobs = append(jobs, job)
}
sort.Slice(jobs, func(i, j int) bool {
return jobs[i].CreatedAt.After(jobs[j].CreatedAt)
})
return jobs
}
func (store *JobStore) updateJobStatus(jobID, status, result string) {
store.Lock()
defer store.Unlock()
if job, exists := store.jobs[jobID]; exists {
job.Status = status
if result != "" {
job.Result = result
}
job.UpdatedAt = time.Now()
logger.Printf("Job status updated: %v", job)
}
}
func startWorkerPool(app *App, numWorkers int) {
for i := 0; i < numWorkers; i++ {
go func(workerID int) {
logger.Printf("Worker %d started", workerID)
for job := range jobQueue {
logger.Printf("Worker %d processing job: %s", workerID, job.ID)
processJob(app, job)
}
}(i)
}
}
func processJob(app *App, job *Job) {
jobStore.updateJobStatus(job.ID, "in_progress", "")
ctx := context.Background()
// Download images of the document
imagePaths, err := app.Client.DownloadDocumentAsImages(ctx, job.DocumentID)
if err != nil {
logger.Printf("Error downloading document images for job %s: %v", job.ID, err)
jobStore.updateJobStatus(job.ID, "failed", fmt.Sprintf("Error downloading document images: %v", err))
return
}
var ocrTexts []string
for _, imagePath := range imagePaths {
imageContent, err := os.ReadFile(imagePath)
if err != nil {
logger.Printf("Error reading image file for job %s: %v", job.ID, err)
jobStore.updateJobStatus(job.ID, "failed", fmt.Sprintf("Error reading image file: %v", err))
return
}
ocrText, err := app.doOCRViaLLM(ctx, imageContent)
if err != nil {
logger.Printf("Error performing OCR for job %s: %v", job.ID, err)
jobStore.updateJobStatus(job.ID, "failed", fmt.Sprintf("Error performing OCR: %v", err))
return
}
ocrTexts = append(ocrTexts, ocrText)
}
// Combine the OCR texts
fullOcrText := strings.Join(ocrTexts, "\n\n")
// Update job status and result
jobStore.updateJobStatus(job.ID, "completed", fullOcrText)
logger.Printf("Job completed: %s", job.ID)
}

376
main.go
View file

@ -1,7 +1,6 @@
package main
import (
"bytes"
"context"
"fmt"
"log"
@ -138,6 +137,11 @@ func main() {
api.GET("/tags", app.getAllTagsHandler)
api.GET("/prompts", getPromptsHandler)
api.POST("/prompts", updatePromptsHandler)
// OCR endpoints
api.POST("/documents/:id/ocr", app.submitOCRJobHandler)
api.GET("/jobs/ocr/:job_id", app.getJobStatusHandler)
api.GET("/jobs/ocr", app.getAllJobsHandler)
}
// Serve static files for the frontend under /assets
@ -149,26 +153,13 @@ func main() {
c.File("./web-app/dist/index.html")
})
// log.Println("Server started on port :8080")
// if err := router.Run(":8080"); err != nil {
// log.Fatalf("Failed to run server: %v", err)
// }
images, err := client.DownloadDocumentAsImages(context.Background(), Document{
// Insert the document ID here to test OCR
ID: 531,
})
if err != nil {
log.Fatalf("Failed to download document: %v", err)
}
for _, image := range images {
content, err := os.ReadFile(image)
if err != nil {
log.Fatalf("Failed to read image: %v", err)
}
_, err = app.doOCRViaLLM(context.Background(), content)
if err != nil {
log.Fatalf("Failed to OCR image: %v", err)
}
// Start OCR worker pool
numWorkers := 1 // Number of workers to start
startWorkerPool(app, numWorkers)
log.Println("Server started on port :8080")
if err := router.Run(":8080"); err != nil {
log.Fatalf("Failed to run server: %v", err)
}
}
@ -227,169 +218,6 @@ func (app *App) processAutoTagDocuments() error {
return nil
}
// getAllTagsHandler handles the GET /api/tags endpoint
func (app *App) getAllTagsHandler(c *gin.Context) {
ctx := c.Request.Context()
tags, err := app.Client.GetAllTags(ctx)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error fetching tags: %v", err)})
log.Printf("Error fetching tags: %v", err)
return
}
c.JSON(http.StatusOK, tags)
}
// documentsHandler handles the GET /api/documents endpoint
func (app *App) documentsHandler(c *gin.Context) {
ctx := c.Request.Context()
documents, err := app.Client.GetDocumentsByTags(ctx, []string{manualTag})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error fetching documents: %v", err)})
log.Printf("Error fetching documents: %v", err)
return
}
c.JSON(http.StatusOK, documents)
}
// generateSuggestionsHandler handles the POST /api/generate-suggestions endpoint
func (app *App) generateSuggestionsHandler(c *gin.Context) {
ctx := c.Request.Context()
var suggestionRequest GenerateSuggestionsRequest
if err := c.ShouldBindJSON(&suggestionRequest); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request payload: %v", err)})
log.Printf("Invalid request payload: %v", err)
return
}
results, err := app.generateDocumentSuggestions(ctx, suggestionRequest)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error processing documents: %v", err)})
log.Printf("Error processing documents: %v", err)
return
}
c.JSON(http.StatusOK, results)
}
// updateDocumentsHandler handles the PATCH /api/update-documents endpoint
func (app *App) updateDocumentsHandler(c *gin.Context) {
ctx := c.Request.Context()
var documents []DocumentSuggestion
if err := c.ShouldBindJSON(&documents); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request payload: %v", err)})
log.Printf("Invalid request payload: %v", err)
return
}
err := app.Client.UpdateDocuments(ctx, documents)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error updating documents: %v", err)})
log.Printf("Error updating documents: %v", err)
return
}
c.Status(http.StatusOK)
}
// generateDocumentSuggestions generates suggestions for a set of documents
func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionRequest GenerateSuggestionsRequest) ([]DocumentSuggestion, error) {
// Fetch all available tags from paperless-ngx
availableTagsMap, err := app.Client.GetAllTags(ctx)
if err != nil {
return nil, fmt.Errorf("failed to fetch available tags: %v", err)
}
// Prepare a list of tag names
availableTagNames := make([]string, 0, len(availableTagsMap))
for tagName := range availableTagsMap {
if tagName == manualTag {
continue
}
availableTagNames = append(availableTagNames, tagName)
}
documents := suggestionRequest.Documents
documentSuggestions := []DocumentSuggestion{}
var wg sync.WaitGroup
var mu sync.Mutex
errorsList := make([]error, 0)
for i := range documents {
wg.Add(1)
go func(doc Document) {
defer wg.Done()
documentID := doc.ID
log.Printf("Processing Document ID %d...", documentID)
content := doc.Content
if len(content) > 5000 {
content = content[:5000]
}
var suggestedTitle string
var suggestedTags []string
if suggestionRequest.GenerateTitles {
suggestedTitle, err = app.getSuggestedTitle(ctx, content)
if err != nil {
mu.Lock()
errorsList = append(errorsList, fmt.Errorf("Document %d: %v", documentID, err))
mu.Unlock()
log.Printf("Error processing document %d: %v", documentID, err)
return
}
}
if suggestionRequest.GenerateTags {
suggestedTags, err = app.getSuggestedTags(ctx, content, suggestedTitle, availableTagNames)
if err != nil {
mu.Lock()
errorsList = append(errorsList, fmt.Errorf("Document %d: %v", documentID, err))
mu.Unlock()
log.Printf("Error generating tags for document %d: %v", documentID, err)
return
}
}
mu.Lock()
suggestion := DocumentSuggestion{
ID: documentID,
OriginalDocument: doc,
}
// Titles
if suggestionRequest.GenerateTitles {
suggestion.SuggestedTitle = suggestedTitle
} else {
suggestion.SuggestedTitle = doc.Title
}
// Tags
if suggestionRequest.GenerateTags {
suggestion.SuggestedTags = suggestedTags
} else {
suggestion.SuggestedTags = removeTagFromList(doc.Tags, manualTag)
}
documentSuggestions = append(documentSuggestions, suggestion)
mu.Unlock()
log.Printf("Document %d processed successfully.", documentID)
}(documents[i])
}
wg.Wait()
if len(errorsList) > 0 {
return nil, errorsList[0] // Return the first error encountered
}
return documentSuggestions, nil
}
// removeTagFromList removes a specific tag from a list of tags
func removeTagFromList(tags []string, tagToRemove string) []string {
filteredTags := []string{}
@ -401,61 +229,6 @@ func removeTagFromList(tags []string, tagToRemove string) []string {
return filteredTags
}
// getSuggestedTags generates suggested tags for a document using the LLM
func (app *App) getSuggestedTags(ctx context.Context, content string, suggestedTitle string, availableTags []string) ([]string, error) {
likelyLanguage := getLikelyLanguage()
templateMutex.RLock()
defer templateMutex.RUnlock()
var promptBuffer bytes.Buffer
err := tagTemplate.Execute(&promptBuffer, map[string]interface{}{
"Language": likelyLanguage,
"AvailableTags": availableTags,
"Title": suggestedTitle,
"Content": content,
})
if err != nil {
return nil, fmt.Errorf("error executing tag template: %v", err)
}
prompt := promptBuffer.String()
log.Printf("Tag suggestion prompt: %s", prompt)
completion, err := app.LLM.GenerateContent(ctx, []llms.MessageContent{
{
Parts: []llms.ContentPart{
llms.TextContent{
Text: prompt,
},
},
Role: llms.ChatMessageTypeHuman,
},
})
if err != nil {
return nil, fmt.Errorf("error getting response from LLM: %v", err)
}
response := strings.TrimSpace(completion.Choices[0].Content)
suggestedTags := strings.Split(response, ",")
for i, tag := range suggestedTags {
suggestedTags[i] = strings.TrimSpace(tag)
}
// Filter out tags that are not in the available tags list
filteredTags := []string{}
for _, tag := range suggestedTags {
for _, availableTag := range availableTags {
if strings.EqualFold(tag, availableTag) {
filteredTags = append(filteredTags, availableTag)
break
}
}
}
return filteredTags, nil
}
// getLikelyLanguage determines the likely language of the document content
func getLikelyLanguage() string {
likelyLanguage := os.Getenv("LLM_LANGUAGE")
@ -465,63 +238,6 @@ func getLikelyLanguage() string {
return strings.Title(strings.ToLower(likelyLanguage))
}
func (app *App) doOCRViaLLM(ctx context.Context, jpegBytes []byte) (string, error) {
// Convert the image to text
completion, err := app.VisionLLM.GenerateContent(ctx, []llms.MessageContent{
{
Parts: []llms.ContentPart{
llms.BinaryPart("image/jpeg", jpegBytes),
llms.TextPart("Just transcribe the text in this image and preserve the formatting and layout (high quality OCR). Do that for ALL the text in the image. Be thorough and pay attention. This is very important. The image is from a text document so be sure to continue until the bottom of the page. Thanks a lot! You tend to forget about some text in the image so please focus! Use markdown format."),
},
Role: llms.ChatMessageTypeHuman,
},
})
if err != nil {
return "", fmt.Errorf("error getting response from LLM: %v", err)
}
result := completion.Choices[0].Content
fmt.Println(result)
return result, nil
}
// getSuggestedTitle generates a suggested title for a document using the LLM
func (app *App) getSuggestedTitle(ctx context.Context, content string) (string, error) {
likelyLanguage := getLikelyLanguage()
templateMutex.RLock()
defer templateMutex.RUnlock()
var promptBuffer bytes.Buffer
err := titleTemplate.Execute(&promptBuffer, map[string]interface{}{
"Language": likelyLanguage,
"Content": content,
})
if err != nil {
return "", fmt.Errorf("error executing title template: %v", err)
}
prompt := promptBuffer.String()
log.Printf("Title suggestion prompt: %s", prompt)
completion, err := app.LLM.GenerateContent(ctx, []llms.MessageContent{
{
Parts: []llms.ContentPart{
llms.TextContent{
Text: prompt,
},
},
Role: llms.ChatMessageTypeHuman,
},
})
if err != nil {
return "", fmt.Errorf("error getting response from LLM: %v", err)
}
return strings.TrimSpace(strings.Trim(completion.Choices[0].Content, "\"")), nil
}
// loadTemplates loads the title and tag templates from files or uses default templates
func loadTemplates() {
templateMutex.Lock()
@ -612,71 +328,3 @@ func createVisionLLM() (llms.Model, error) {
return nil, fmt.Errorf("unsupported LLM provider: %s", llmProvider)
}
}
// getPromptsHandler handles the GET /api/prompts endpoint
func getPromptsHandler(c *gin.Context) {
templateMutex.RLock()
defer templateMutex.RUnlock()
// Read the templates from files or use default content
titleTemplateContent, err := os.ReadFile("prompts/title_prompt.tmpl")
if err != nil {
titleTemplateContent = []byte(defaultTitleTemplate)
}
tagTemplateContent, err := os.ReadFile("prompts/tag_prompt.tmpl")
if err != nil {
tagTemplateContent = []byte(defaultTagTemplate)
}
c.JSON(http.StatusOK, gin.H{
"title_template": string(titleTemplateContent),
"tag_template": string(tagTemplateContent),
})
}
// updatePromptsHandler handles the POST /api/prompts endpoint
func updatePromptsHandler(c *gin.Context) {
var req struct {
TitleTemplate string `json:"title_template"`
TagTemplate string `json:"tag_template"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request payload"})
return
}
templateMutex.Lock()
defer templateMutex.Unlock()
// Update title template
if req.TitleTemplate != "" {
t, err := template.New("title").Parse(req.TitleTemplate)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid title template: %v", err)})
return
}
titleTemplate = t
err = os.WriteFile("prompts/title_prompt.tmpl", []byte(req.TitleTemplate), 0644)
if err != nil {
log.Printf("Failed to write title_prompt.tmpl: %v", err)
}
}
// Update tag template
if req.TagTemplate != "" {
t, err := template.New("tag").Parse(req.TagTemplate)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid tag template: %v", err)})
return
}
tagTemplate = t
err = os.WriteFile("prompts/tag_prompt.tmpl", []byte(req.TagTemplate), 0644)
if err != nil {
log.Printf("Failed to write tag_prompt.tmpl: %v", err)
}
}
c.Status(http.StatusOK)
}

View file

@ -28,10 +28,13 @@ type PaperlessClient struct {
// NewPaperlessClient creates a new instance of PaperlessClient with a default HTTP client
func NewPaperlessClient(baseURL, apiToken string) *PaperlessClient {
cacheFolder := os.Getenv("PAPERLESS_GPT_CACHE_DIR")
return &PaperlessClient{
BaseURL: strings.TrimRight(baseURL, "/"),
APIToken: apiToken,
HTTPClient: &http.Client{},
BaseURL: strings.TrimRight(baseURL, "/"),
APIToken: apiToken,
HTTPClient: &http.Client{},
CacheFolder: cacheFolder,
}
}
@ -248,9 +251,9 @@ func (c *PaperlessClient) UpdateDocuments(ctx context.Context, documents []Docum
}
// DownloadDocumentAsImages downloads the PDF file of the specified document and converts it to images
func (c *PaperlessClient) DownloadDocumentAsImages(ctx context.Context, document Document) ([]string, error) {
func (c *PaperlessClient) DownloadDocumentAsImages(ctx context.Context, documentId int) ([]string, error) {
// Create a directory named after the document ID
docDir := filepath.Join(c.GetCacheFolder(), fmt.Sprintf("/document-%d", document.ID))
docDir := filepath.Join(c.GetCacheFolder(), fmt.Sprintf("/document-%d", documentId))
if _, err := os.Stat(docDir); os.IsNotExist(err) {
err = os.MkdirAll(docDir, 0755)
if err != nil {
@ -274,7 +277,7 @@ func (c *PaperlessClient) DownloadDocumentAsImages(ctx context.Context, document
}
// Proceed with downloading and converting the document to images
path := fmt.Sprintf("api/documents/%d/download/", document.ID)
path := fmt.Sprintf("api/documents/%d/download/", documentId)
resp, err := c.Do(ctx, "GET", path, nil)
if err != nil {
return nil, err
@ -283,7 +286,7 @@ func (c *PaperlessClient) DownloadDocumentAsImages(ctx context.Context, document
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("error downloading document %d: %d, %s", document.ID, resp.StatusCode, string(bodyBytes))
return nil, fmt.Errorf("error downloading document %d: %d, %s", documentId, resp.StatusCode, string(bodyBytes))
}
pdfData, err := io.ReadAll(resp.Body)
@ -315,7 +318,10 @@ func (c *PaperlessClient) DownloadDocumentAsImages(ctx context.Context, document
for n := 0; n < doc.NumPage(); n++ {
n := n // capture loop variable
g.Go(func() error {
mu.Lock()
// I assume the libmupdf library is not thread-safe
img, err := doc.Image(n)
mu.Unlock()
if err != nil {
return err
}

View file

@ -361,7 +361,7 @@ func TestDownloadDocumentAsImages(t *testing.T) {
})
ctx := context.Background()
imagePaths, err := env.client.DownloadDocumentAsImages(ctx, document)
imagePaths, err := env.client.DownloadDocumentAsImages(ctx, document.ID)
require.NoError(t, err)
// Verify that exatly one page was extracted
@ -398,7 +398,7 @@ func TestDownloadDocumentAsImages_ManyPages(t *testing.T) {
env.client.CacheFolder = "tests/tmp"
// Clean the cache folder
os.RemoveAll(env.client.CacheFolder)
imagePaths, err := env.client.DownloadDocumentAsImages(ctx, document)
imagePaths, err := env.client.DownloadDocumentAsImages(ctx, document.ID)
require.NoError(t, err)
// Verify that exatly 52 pages were extracted

View file

@ -15,6 +15,8 @@
"prop-types": "^15.8.1",
"react": "^18.3.1",
"react-dom": "^18.3.1",
"react-icons": "^5.3.0",
"react-router-dom": "^6.27.0",
"react-tag-autocomplete": "^7.3.0"
},
"devDependencies": {
@ -843,6 +845,14 @@
"react": "^16.8.0 || ^17.0.0-rc.1 || ^18.0.0 || ^19.0.0"
}
},
"node_modules/@remix-run/router": {
"version": "1.20.0",
"resolved": "https://registry.npmjs.org/@remix-run/router/-/router-1.20.0.tgz",
"integrity": "sha512-mUnk8rPJBI9loFDZ+YzPGdeniYK+FTmRD1TMCz7ev2SNIozyKKpnGgsxO34u6Z4z/t0ITuu7voi/AshfsGsgFg==",
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/@rollup/rollup-android-arm-eabi": {
"version": "4.22.4",
"resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.22.4.tgz",
@ -3298,11 +3308,49 @@
"react": "^18.3.1"
}
},
"node_modules/react-icons": {
"version": "5.3.0",
"resolved": "https://registry.npmjs.org/react-icons/-/react-icons-5.3.0.tgz",
"integrity": "sha512-DnUk8aFbTyQPSkCfF8dbX6kQjXA9DktMeJqfjrg6cK9vwQVMxmcA3BfP4QoiztVmEHtwlTgLFsPuH2NskKT6eg==",
"peerDependencies": {
"react": "*"
}
},
"node_modules/react-is": {
"version": "16.13.1",
"resolved": "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz",
"integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ=="
},
"node_modules/react-router": {
"version": "6.27.0",
"resolved": "https://registry.npmjs.org/react-router/-/react-router-6.27.0.tgz",
"integrity": "sha512-YA+HGZXz4jaAkVoYBE98VQl+nVzI+cVI2Oj/06F5ZM+0u3TgedN9Y9kmMRo2mnkSK2nCpNQn0DVob4HCsY/WLw==",
"dependencies": {
"@remix-run/router": "1.20.0"
},
"engines": {
"node": ">=14.0.0"
},
"peerDependencies": {
"react": ">=16.8"
}
},
"node_modules/react-router-dom": {
"version": "6.27.0",
"resolved": "https://registry.npmjs.org/react-router-dom/-/react-router-dom-6.27.0.tgz",
"integrity": "sha512-+bvtFWMC0DgAFrfKXKG9Fc+BcXWRUO1aJIihbB79xaeq0v5UzfvnM5houGUm1Y461WVRcgAQ+Clh5rdb1eCx4g==",
"dependencies": {
"@remix-run/router": "1.20.0",
"react-router": "6.27.0"
},
"engines": {
"node": ">=14.0.0"
},
"peerDependencies": {
"react": ">=16.8",
"react-dom": ">=16.8"
}
},
"node_modules/react-tag-autocomplete": {
"version": "7.3.0",
"resolved": "https://registry.npmjs.org/react-tag-autocomplete/-/react-tag-autocomplete-7.3.0.tgz",

View file

@ -18,6 +18,8 @@
"prop-types": "^15.8.1",
"react": "^18.3.1",
"react-dom": "^18.3.1",
"react-icons": "^5.3.0",
"react-router-dom": "^6.27.0",
"react-tag-autocomplete": "^7.3.0"
},
"devDependencies": {

View file

@ -1,13 +1,18 @@
// App.tsx or App.jsx
import React from 'react';
import { Route, BrowserRouter as Router, Routes } from 'react-router-dom';
import DocumentProcessor from './DocumentProcessor';
import './index.css';
import ExperimentalOCR from './ExperimentalOCR'; // New component
const App: React.FC = () => {
return (
<div className="App">
<DocumentProcessor />
</div>
);
return (
<Router>
<Routes>
<Route path="/" element={<DocumentProcessor />} />
<Route path="/experimental-ocr" element={<ExperimentalOCR />} />
</Routes>
</Router>
);
};
export default App;

View file

@ -1,5 +1,6 @@
import axios from "axios";
import React, { useCallback, useEffect, useState } from "react";
import { Link } from "react-router-dom";
import "react-tag-autocomplete/example/src/styles.css"; // Ensure styles are loaded
import DocumentsToProcess from "./components/DocumentsToProcess";
import NoDocuments from "./components/NoDocuments";
@ -129,9 +130,7 @@ const DocumentProcessor: React.FC = () => {
doc.id === docId
? {
...doc,
suggested_tags: doc.suggested_tags?.filter(
(_, i) => i !== index
),
suggested_tags: doc.suggested_tags?.filter((_, i) => i !== index),
}
: doc
)
@ -141,9 +140,7 @@ const DocumentProcessor: React.FC = () => {
const handleTitleChange = (docId: number, title: string) => {
setSuggestions((prevSuggestions) =>
prevSuggestions.map((doc) =>
doc.id === docId
? { ...doc, suggested_title: title }
: doc
doc.id === docId ? { ...doc, suggested_title: title } : doc
)
);
};
@ -182,11 +179,12 @@ const DocumentProcessor: React.FC = () => {
}
}, [documents]);
if (loading) {
return (
<div className="flex items-center justify-center min-h-screen bg-white dark:bg-gray-900">
<div className="text-xl font-semibold text-gray-800 dark:text-gray-200">Loading documents...</div>
<div className="text-xl font-semibold text-gray-800 dark:text-gray-200">
Loading documents...
</div>
</div>
);
}
@ -195,6 +193,14 @@ const DocumentProcessor: React.FC = () => {
<div className="max-w-5xl mx-auto p-6 bg-white dark:bg-gray-900 text-gray-800 dark:text-gray-200">
<header className="text-center">
<h1 className="text-4xl font-bold mb-8">Paperless GPT</h1>
<div>
<Link
to="/experimental-ocr"
className="text-blue-500 hover:underline"
>
OCR via LLMs (Experimental)
</Link>
</div>
</header>
{error && (

View file

@ -0,0 +1,129 @@
// ExperimentalOCR.tsx
import axios from 'axios';
import React, { useState } from 'react';
import { FaSpinner } from 'react-icons/fa';
const ExperimentalOCR: React.FC = () => {
const [documentId, setDocumentId] = useState('');
const [jobId, setJobId] = useState('');
const [ocrResult, setOcrResult] = useState('');
const [status, setStatus] = useState('');
const [error, setError] = useState('');
const [isCheckingStatus, setIsCheckingStatus] = useState(false);
const submitOCRJob = async () => {
setStatus('');
setError('');
setJobId('');
setOcrResult('');
try {
setStatus('Submitting OCR job...');
const response = await axios.post(`/api/documents/${documentId}/ocr`);
setJobId(response.data.job_id);
setStatus('Job submitted. Processing...');
} catch (err) {
console.error(err);
setError('Failed to submit OCR job.');
}
};
const checkJobStatus = async () => {
if (!jobId) return;
setIsCheckingStatus(true);
try {
const response = await axios.get(`/api/jobs/ocr/${jobId}`);
const jobStatus = response.data.status;
if (jobStatus === 'completed') {
setOcrResult(response.data.result);
setStatus('OCR completed successfully.');
} else if (jobStatus === 'failed') {
setError(response.data.error);
setStatus('OCR failed.');
} else {
setStatus(`Job status: ${jobStatus}. This may take a few minutes.`);
// Automatically check again after a delay
setTimeout(checkJobStatus, 5000);
}
} catch (err) {
console.error(err);
setError('Failed to check job status.');
} finally {
setIsCheckingStatus(false);
}
};
// Start checking job status when jobId is set
React.useEffect(() => {
if (jobId) {
checkJobStatus();
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [jobId]);
return (
<div className="max-w-3xl mx-auto p-6 bg-white dark:bg-gray-900 text-gray-800 dark:text-gray-200">
<h1 className="text-4xl font-bold mb-6 text-center">OCR via LLMs (Experimental)</h1>
<p className="mb-6 text-center text-yellow-600">
This is an experimental feature. Results may vary, and processing may take some time.
</p>
<div className="bg-gray-100 dark:bg-gray-800 p-6 rounded-lg shadow-md">
<div className="mb-4">
<label htmlFor="documentId" className="block mb-2 font-semibold">
Document ID:
</label>
<input
type="text"
id="documentId"
value={documentId}
onChange={(e) => setDocumentId(e.target.value)}
className="border border-gray-300 dark:border-gray-700 rounded w-full p-2 focus:outline-none focus:ring-2 focus:ring-blue-500"
placeholder="Enter the document ID"
/>
</div>
<button
onClick={submitOCRJob}
className="w-full bg-blue-600 hover:bg-blue-700 text-white font-semibold py-2 px-4 rounded transition duration-200"
disabled={!documentId}
>
{status.startsWith('Submitting') ? (
<span className="flex items-center justify-center">
<FaSpinner className="animate-spin mr-2" />
Submitting...
</span>
) : (
'Submit OCR Job'
)}
</button>
{status && (
<div className="mt-4 text-center text-gray-700 dark:text-gray-300">
{status.includes('in_progress') && (
<span className="flex items-center justify-center">
<FaSpinner className="animate-spin mr-2" />
{status}
</span>
)}
{!status.includes('in_progress') && status}
</div>
)}
{error && (
<div className="mt-4 p-4 bg-red-100 dark:bg-red-800 text-red-700 dark:text-red-200 rounded">
{error}
</div>
)}
{ocrResult && (
<div className="mt-6">
<h2 className="text-2xl font-bold mb-4">OCR Result:</h2>
<div className="bg-gray-50 dark:bg-gray-900 p-4 rounded border border-gray-200 dark:border-gray-700 overflow-auto max-h-96">
<pre className="whitespace-pre-wrap">{ocrResult}</pre>
</div>
</div>
)}
</div>
</div>
);
};
export default ExperimentalOCR;