pretty ui

This commit is contained in:
Dominik Schröter 2024-10-28 15:02:18 +01:00
parent de8dd90cbb
commit 2b436a2ab2
8 changed files with 161 additions and 22 deletions

View file

@ -189,6 +189,7 @@ func (app *App) getJobStatusHandler(c *gin.Context) {
"status": job.Status, "status": job.Status,
"created_at": job.CreatedAt, "created_at": job.CreatedAt,
"updated_at": job.UpdatedAt, "updated_at": job.UpdatedAt,
"pages_done": job.PagesDone,
} }
if job.Status == "completed" { if job.Status == "completed" {
@ -210,6 +211,7 @@ func (app *App) getAllJobsHandler(c *gin.Context) {
"status": job.Status, "status": job.Status,
"created_at": job.CreatedAt, "created_at": job.CreatedAt,
"updated_at": job.UpdatedAt, "updated_at": job.UpdatedAt,
"pages_done": job.PagesDone,
} }
if job.Status == "completed" { if job.Status == "completed" {

View file

@ -67,12 +67,27 @@ func (app *App) getSuggestedTags(ctx context.Context, content string, suggestedT
} }
func (app *App) doOCRViaLLM(ctx context.Context, jpegBytes []byte) (string, error) { func (app *App) doOCRViaLLM(ctx context.Context, jpegBytes []byte) (string, error) {
templateMutex.RLock()
defer templateMutex.RUnlock()
likelyLanguage := getLikelyLanguage()
var promptBuffer bytes.Buffer
err := ocrTemplate.Execute(&promptBuffer, map[string]interface{}{
"Language": likelyLanguage,
})
if err != nil {
return "", fmt.Errorf("error executing tag template: %v", err)
}
prompt := promptBuffer.String()
// Convert the image to text // Convert the image to text
completion, err := app.VisionLLM.GenerateContent(ctx, []llms.MessageContent{ completion, err := app.VisionLLM.GenerateContent(ctx, []llms.MessageContent{
{ {
Parts: []llms.ContentPart{ Parts: []llms.ContentPart{
llms.BinaryPart("image/jpeg", jpegBytes), llms.BinaryPart("image/jpeg", jpegBytes),
llms.TextPart("Just transcribe the text in this image and preserve the formatting and layout (high quality OCR). Do that for ALL the text in the image. Be thorough and pay attention. This is very important. The image is from a text document so be sure to continue until the bottom of the page. Thanks a lot! You tend to forget about some text in the image so please focus! Use markdown format."), llms.TextPart(prompt),
}, },
Role: llms.ChatMessageTypeHuman, Role: llms.ChatMessageTypeHuman,
}, },

15
jobs.go
View file

@ -21,6 +21,7 @@ type Job struct {
Result string // OCR result or error message Result string // OCR result or error message
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
PagesDone int // Number of pages processed
} }
// JobStore manages jobs and their statuses // JobStore manages jobs and their statuses
@ -44,6 +45,7 @@ func generateJobID() string {
func (store *JobStore) addJob(job *Job) { func (store *JobStore) addJob(job *Job) {
store.Lock() store.Lock()
defer store.Unlock() defer store.Unlock()
job.PagesDone = 0 // Initialize PagesDone to 0
store.jobs[job.ID] = job store.jobs[job.ID] = job
logger.Printf("Job added: %v", job) logger.Printf("Job added: %v", job)
} }
@ -84,6 +86,16 @@ func (store *JobStore) updateJobStatus(jobID, status, result string) {
} }
} }
// updatePagesDone stores the number of processed pages on the job identified
// by jobID and refreshes its UpdatedAt timestamp. The store lock is held for
// the whole update. If no job with that ID is in the store, nothing happens.
func (store *JobStore) updatePagesDone(jobID string, pagesDone int) {
	store.Lock()
	defer store.Unlock()

	target, found := store.jobs[jobID]
	if !found {
		// Unknown job ID: nothing to update and nothing to log.
		return
	}

	target.PagesDone = pagesDone
	target.UpdatedAt = time.Now()
	logger.Printf("Job pages done updated: %v", target)
}
func startWorkerPool(app *App, numWorkers int) { func startWorkerPool(app *App, numWorkers int) {
for i := 0; i < numWorkers; i++ { for i := 0; i < numWorkers; i++ {
go func(workerID int) { go func(workerID int) {
@ -110,7 +122,7 @@ func processJob(app *App, job *Job) {
} }
var ocrTexts []string var ocrTexts []string
for _, imagePath := range imagePaths { for i, imagePath := range imagePaths {
imageContent, err := os.ReadFile(imagePath) imageContent, err := os.ReadFile(imagePath)
if err != nil { if err != nil {
logger.Printf("Error reading image file for job %s: %v", job.ID, err) logger.Printf("Error reading image file for job %s: %v", job.ID, err)
@ -126,6 +138,7 @@ func processJob(app *App, job *Job) {
} }
ocrTexts = append(ocrTexts, ocrText) ocrTexts = append(ocrTexts, ocrText)
jobStore.updatePagesDone(job.ID, i+1) // Update PagesDone after each page is processed
} }
// Combine the OCR texts // Combine the OCR texts

31
main.go
View file

@ -34,6 +34,7 @@ var (
// Templates // Templates
titleTemplate *template.Template titleTemplate *template.Template
tagTemplate *template.Template tagTemplate *template.Template
ocrTemplate *template.Template
templateMutex sync.RWMutex templateMutex sync.RWMutex
// Default templates // Default templates
@ -59,6 +60,8 @@ Content:
Please concisely select the {{.Language}} tags from the list above that best describe the document. Please concisely select the {{.Language}} tags from the list above that best describe the document.
Be very selective and only choose the most relevant tags since too many tags will make the document less discoverable. Be very selective and only choose the most relevant tags since too many tags will make the document less discoverable.
` `
defaultOcrPrompt = `Just transcribe the text in this image and preserve the formatting and layout (high quality OCR). Do that for ALL the text in the image. Be thorough and pay attention. This is very important. The image is from a text document so be sure to continue until the bottom of the page. Thanks a lot! You tend to forget about some text in the image so please focus! Use markdown format.`
) )
// App struct to hold dependencies // App struct to hold dependencies
@ -142,6 +145,12 @@ func main() {
api.POST("/documents/:id/ocr", app.submitOCRJobHandler) api.POST("/documents/:id/ocr", app.submitOCRJobHandler)
api.GET("/jobs/ocr/:job_id", app.getJobStatusHandler) api.GET("/jobs/ocr/:job_id", app.getJobStatusHandler)
api.GET("/jobs/ocr", app.getAllJobsHandler) api.GET("/jobs/ocr", app.getAllJobsHandler)
// Endpoint to see if user enabled OCR
api.GET("/experimental/ocr", func(c *gin.Context) {
enabled := isOcrEnabled()
c.JSON(http.StatusOK, gin.H{"enabled": enabled})
})
} }
// Serve static files for the frontend under /assets // Serve static files for the frontend under /assets
@ -163,6 +172,10 @@ func main() {
} }
} }
// isOcrEnabled reports whether OCR is configured, i.e. whether both
// visionLlmModel and visionLlmProvider are non-empty.
//
// NOTE(review): this only checks that the values are set; it does not check
// that the provider is one the vision LLM factory supports — confirm whether
// callers need that stronger guarantee.
func isOcrEnabled() bool {
	hasModel := visionLlmModel != ""
	hasProvider := visionLlmProvider != ""
	return hasModel && hasProvider
}
// validateEnvVars ensures all necessary environment variables are set // validateEnvVars ensures all necessary environment variables are set
func validateEnvVars() { func validateEnvVars() {
if paperlessBaseURL == "" { if paperlessBaseURL == "" {
@ -278,6 +291,21 @@ func loadTemplates() {
if err != nil { if err != nil {
log.Fatalf("Failed to parse tag template: %v", err) log.Fatalf("Failed to parse tag template: %v", err)
} }
// Load OCR template
ocrTemplatePath := filepath.Join(promptsDir, "ocr_prompt.tmpl")
ocrTemplateContent, err := os.ReadFile(ocrTemplatePath)
if err != nil {
log.Printf("Could not read %s, using default template: %v", ocrTemplatePath, err)
ocrTemplateContent = []byte(defaultOcrPrompt)
if err := os.WriteFile(ocrTemplatePath, ocrTemplateContent, os.ModePerm); err != nil {
log.Fatalf("Failed to write default OCR template to disk: %v", err)
}
}
ocrTemplate, err = template.New("ocr").Funcs(sprig.FuncMap()).Parse(string(ocrTemplateContent))
if err != nil {
log.Fatalf("Failed to parse OCR template: %v", err)
}
} }
// createLLM creates the appropriate LLM client based on the provider // createLLM creates the appropriate LLM client based on the provider
@ -325,6 +353,7 @@ func createVisionLLM() (llms.Model, error) {
ollama.WithServerURL(host), ollama.WithServerURL(host),
) )
default: default:
return nil, fmt.Errorf("unsupported LLM provider: %s", llmProvider) log.Printf("No Vision LLM provider created: %s", visionLlmProvider)
return nil, nil
} }
} }

View file

@ -222,6 +222,12 @@ func (c *PaperlessClient) UpdateDocuments(ctx context.Context, documents []Docum
log.Printf("No valid title found for document %d, skipping.", documentID) log.Printf("No valid title found for document %d, skipping.", documentID)
} }
// Suggested Content
suggestedContent := document.SuggestedContent
if suggestedContent != "" {
updatedFields["content"] = suggestedContent
}
// Marshal updated fields to JSON // Marshal updated fields to JSON
jsonData, err := json.Marshal(updatedFields) jsonData, err := json.Marshal(updatedFields)
if err != nil { if err != nil {

View file

@ -58,4 +58,5 @@ type DocumentSuggestion struct {
OriginalDocument Document `json:"original_document"` OriginalDocument Document `json:"original_document"`
SuggestedTitle string `json:"suggested_title,omitempty"` SuggestedTitle string `json:"suggested_title,omitempty"`
SuggestedTags []string `json:"suggested_tags,omitempty"` SuggestedTags []string `json:"suggested_tags,omitempty"`
SuggestedContent string `json:"suggested_content,omitempty"`
} }

View file

@ -25,6 +25,7 @@ export interface DocumentSuggestion {
original_document: Document; original_document: Document;
suggested_title?: string; suggested_title?: string;
suggested_tags?: string[]; suggested_tags?: string[];
suggested_content?: string;
} }
export interface TagOption { export interface TagOption {
@ -45,17 +46,22 @@ const DocumentProcessor: React.FC = () => {
const [generateTags, setGenerateTags] = useState(true); const [generateTags, setGenerateTags] = useState(true);
const [error, setError] = useState<string | null>(null); const [error, setError] = useState<string | null>(null);
// Temporary feature flags
const [ocrEnabled, setOcrEnabled] = useState(false);
// Custom hook to fetch initial data // Custom hook to fetch initial data
const fetchInitialData = useCallback(async () => { const fetchInitialData = useCallback(async () => {
try { try {
const [filterTagRes, documentsRes, tagsRes] = await Promise.all([ const [filterTagRes, documentsRes, tagsRes, ocrEnabledRes] = await Promise.all([
axios.get<{ tag: string }>("/api/filter-tag"), axios.get<{ tag: string }>("/api/filter-tag"),
axios.get<Document[]>("/api/documents"), axios.get<Document[]>("/api/documents"),
axios.get<Record<string, number>>("/api/tags"), axios.get<Record<string, number>>("/api/tags"),
axios.get<{enabled: boolean}>("/api/experimental/ocr"),
]); ]);
setFilterTag(filterTagRes.data.tag); setFilterTag(filterTagRes.data.tag);
setDocuments(documentsRes.data); setDocuments(documentsRes.data);
setOcrEnabled(ocrEnabledRes.data.enabled);
const tags = Object.keys(tagsRes.data).map((tag) => ({ const tags = Object.keys(tagsRes.data).map((tag) => ({
id: tag, id: tag,
name: tag, name: tag,
@ -193,14 +199,16 @@ const DocumentProcessor: React.FC = () => {
<div className="max-w-5xl mx-auto p-6 bg-white dark:bg-gray-900 text-gray-800 dark:text-gray-200"> <div className="max-w-5xl mx-auto p-6 bg-white dark:bg-gray-900 text-gray-800 dark:text-gray-200">
<header className="text-center"> <header className="text-center">
<h1 className="text-4xl font-bold mb-8">Paperless GPT</h1> <h1 className="text-4xl font-bold mb-8">Paperless GPT</h1>
<div> {ocrEnabled && (
<Link <div>
to="/experimental-ocr" <Link
className="text-blue-500 hover:underline" to="/experimental-ocr"
> className="inline-block bg-blue-600 hover:bg-blue-700 text-white font-semibold py-2 px-4 rounded transition duration-200 dark:bg-blue-500 dark:hover:bg-blue-600"
OCR via LLMs (Experimental) >
</Link> OCR via LLMs (Experimental)
</div> </Link>
</div>
)}
</header> </header>
{error && ( {error && (

View file

@ -1,21 +1,25 @@
// ExperimentalOCR.tsx
import axios from 'axios'; import axios from 'axios';
import React, { useState } from 'react'; import React, { useCallback, useEffect, useState } from 'react';
import { FaSpinner } from 'react-icons/fa'; import { FaSpinner } from 'react-icons/fa';
import { Document, DocumentSuggestion } from './DocumentProcessor';
const ExperimentalOCR: React.FC = () => { const ExperimentalOCR: React.FC = () => {
const [documentId, setDocumentId] = useState(''); const refreshInterval = 1000; // Refresh interval in milliseconds
const [documentId, setDocumentId] = useState(0);
const [jobId, setJobId] = useState(''); const [jobId, setJobId] = useState('');
const [ocrResult, setOcrResult] = useState(''); const [ocrResult, setOcrResult] = useState('');
const [status, setStatus] = useState(''); const [status, setStatus] = useState('');
const [error, setError] = useState(''); const [error, setError] = useState<string | null>('');
const [pagesDone, setPagesDone] = useState(0); // New state for pages done
const [saving, setSaving] = useState(false); // New state for saving
const [documentDetails, setDocumentDetails] = useState<Document | null>(null); // New state for document details
const submitOCRJob = async () => { const submitOCRJob = async () => {
setStatus(''); setStatus('');
setError(''); setError('');
setJobId(''); setJobId('');
setOcrResult(''); setOcrResult('');
setPagesDone(0); // Reset pages done
try { try {
setStatus('Submitting OCR job...'); setStatus('Submitting OCR job...');
@ -34,6 +38,7 @@ const ExperimentalOCR: React.FC = () => {
try { try {
const response = await axios.get(`/api/jobs/ocr/${jobId}`); const response = await axios.get(`/api/jobs/ocr/${jobId}`);
const jobStatus = response.data.status; const jobStatus = response.data.status;
setPagesDone(response.data.pages_done); // Update pages done
if (jobStatus === 'completed') { if (jobStatus === 'completed') {
setOcrResult(response.data.result); setOcrResult(response.data.result);
setStatus('OCR completed successfully.'); setStatus('OCR completed successfully.');
@ -43,7 +48,7 @@ const ExperimentalOCR: React.FC = () => {
} else { } else {
setStatus(`Job status: ${jobStatus}. This may take a few minutes.`); setStatus(`Job status: ${jobStatus}. This may take a few minutes.`);
// Automatically check again after a delay // Automatically check again after a delay
setTimeout(checkJobStatus, 5000); setTimeout(checkJobStatus, refreshInterval);
} }
} catch (err) { } catch (err) {
console.error(err); console.error(err);
@ -51,8 +56,49 @@ const ExperimentalOCR: React.FC = () => {
} }
}; };
// Saves the current OCR result as the document's content by POSTing a
// DocumentSuggestion to /api/save-content.
//
// The "document details missing" precondition is checked BEFORE the try block:
// previously it was thrown inside the try, so the catch handler immediately
// replaced the specific message with the generic "Failed to save content.".
const handleSaveContent = async () => {
  setError(null);

  if (!documentDetails) {
    // Nothing was sent, so report the precise reason and leave `saving` untouched.
    setError('Document details not fetched.');
    return;
  }

  setSaving(true);
  try {
    const requestPayload: DocumentSuggestion = {
      id: documentId,
      original_document: documentDetails, // Use fetched document details
      suggested_content: ocrResult,
    };

    await axios.post("/api/save-content", requestPayload);
    setStatus('Content saved successfully.');
  } catch (err) {
    console.error("Error saving content:", err);
    setError("Failed to save content.");
  } finally {
    // Always re-enable the save button, whether the request succeeded or not.
    setSaving(false);
  }
};
// Loads the document identified by documentId from /api/documents/:id and
// stores it in documentDetails. Does nothing while documentId is falsy (e.g. 0).
// On failure the error is logged and a user-facing error message is set.
// Memoized on documentId so its identity only changes when the ID changes.
const fetchDocumentDetails = useCallback(async () => {
  if (documentId) {
    try {
      const { data } = await axios.get<Document>(`/api/documents/${documentId}`);
      setDocumentDetails(data);
    } catch (fetchErr) {
      console.error("Error fetching document details:", fetchErr);
      setError("Failed to fetch document details.");
    }
  }
}, [documentId]);
// Fetch document details when documentId changes.
// fetchDocumentDetails is wrapped in useCallback with [documentId] as its
// dependency list, so listing it here keeps the dependency array complete
// without causing extra runs beyond documentId changes.
useEffect(() => {
fetchDocumentDetails();
}, [documentId, fetchDocumentDetails]);
// Start checking job status when jobId is set // Start checking job status when jobId is set
React.useEffect(() => { useEffect(() => {
if (jobId) { if (jobId) {
checkJobStatus(); checkJobStatus();
} }
@ -71,10 +117,10 @@ const ExperimentalOCR: React.FC = () => {
Document ID: Document ID:
</label> </label>
<input <input
type="text" type="number"
id="documentId" id="documentId"
value={documentId} value={documentId}
onChange={(e) => setDocumentId(e.target.value)} onChange={(e) => setDocumentId(Number(e.target.value))}
className="border border-gray-300 dark:border-gray-700 rounded w-full p-2 focus:outline-none focus:ring-2 focus:ring-blue-500" className="border border-gray-300 dark:border-gray-700 rounded w-full p-2 focus:outline-none focus:ring-2 focus:ring-blue-500"
placeholder="Enter the document ID" placeholder="Enter the document ID"
/> />
@ -102,6 +148,11 @@ const ExperimentalOCR: React.FC = () => {
</span> </span>
)} )}
{!status.includes('in_progress') && status} {!status.includes('in_progress') && status}
{pagesDone > 0 && (
<div className="mt-2">
Pages processed: {pagesDone}
</div>
)}
</div> </div>
)} )}
{error && ( {error && (
@ -115,6 +166,20 @@ const ExperimentalOCR: React.FC = () => {
<div className="bg-gray-50 dark:bg-gray-900 p-4 rounded border border-gray-200 dark:border-gray-700 overflow-auto max-h-96"> <div className="bg-gray-50 dark:bg-gray-900 p-4 rounded border border-gray-200 dark:border-gray-700 overflow-auto max-h-96">
<pre className="whitespace-pre-wrap">{ocrResult}</pre> <pre className="whitespace-pre-wrap">{ocrResult}</pre>
</div> </div>
<button
onClick={handleSaveContent}
className="w-full bg-green-600 hover:bg-green-700 text-white font-semibold py-2 px-4 rounded transition duration-200 mt-4"
disabled={saving}
>
{saving ? (
<span className="flex items-center justify-center">
<FaSpinner className="animate-spin mr-2" />
Saving...
</span>
) : (
'Save Content'
)}
</button>
</div> </div>
)} )}
</div> </div>
@ -122,4 +187,4 @@ const ExperimentalOCR: React.FC = () => {
); );
}; };
export default ExperimentalOCR; export default ExperimentalOCR;