diff --git a/main.go b/main.go index a300412..638637e 100644 --- a/main.go +++ b/main.go @@ -8,9 +8,9 @@ import ( "paperless-gpt/ocr" "path/filepath" "runtime" - "strconv" "slices" - "strings" + "strconv" + "strings" "sync" "text/template" "time" @@ -639,9 +639,23 @@ func createLLM() (llms.Model, error) { if openaiAPIKey == "" { return nil, fmt.Errorf("OpenAI API key is not set") } + + // Create custom transport that adds headers + customTransport := &headerTransport{ + transport: http.DefaultTransport, + headers: map[string]string{ + "X-Title": "paperless-gpt", + }, + } + + // Create custom client with the transport + httpClient := http.DefaultClient + httpClient.Transport = customTransport + return openai.New( openai.WithModel(llmModel), openai.WithToken(openaiAPIKey), + openai.WithHTTPClient(httpClient), ) case "ollama": host := os.Getenv("OLLAMA_HOST") @@ -663,9 +677,23 @@ func createVisionLLM() (llms.Model, error) { if openaiAPIKey == "" { return nil, fmt.Errorf("OpenAI API key is not set") } + + // Create custom transport that adds headers + customTransport := &headerTransport{ + transport: http.DefaultTransport, + headers: map[string]string{ + "X-Title": "paperless-gpt", + }, + } + + // Create custom client with the transport + httpClient := http.DefaultClient + httpClient.Transport = customTransport + return openai.New( openai.WithModel(visionLlmModel), openai.WithToken(openaiAPIKey), + openai.WithHTTPClient(httpClient), ) case "ollama": host := os.Getenv("OLLAMA_HOST") @@ -681,3 +709,17 @@ func createVisionLLM() (llms.Model, error) { return nil, nil } } + +// headerTransport is a custom http.RoundTripper that adds custom headers to requests +type headerTransport struct { + transport http.RoundTripper + headers map[string]string +} + +// RoundTrip implements the http.RoundTripper interface +func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + for key, value := range t.headers { + req.Header.Add(key, value) + } + return t.transport.RoundTrip(req) +}