feat: add custom HTTP transport with headers for OpenAI client

Closes #237
This commit is contained in:
Dominik Schröter 2025-02-17 09:46:24 +01:00 committed by Icereed
parent 1647219fa8
commit 23632f6fcd

44
main.go
View file

@ -8,8 +8,8 @@ import (
"paperless-gpt/ocr" "paperless-gpt/ocr"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv"
"slices" "slices"
"strconv"
"strings" "strings"
"sync" "sync"
"text/template" "text/template"
@ -639,9 +639,23 @@ func createLLM() (llms.Model, error) {
if openaiAPIKey == "" { if openaiAPIKey == "" {
return nil, fmt.Errorf("OpenAI API key is not set") return nil, fmt.Errorf("OpenAI API key is not set")
} }
// Create custom transport that adds headers
customTransport := &headerTransport{
transport: http.DefaultTransport,
headers: map[string]string{
"X-Title": "paperless-gpt",
},
}
// Create custom client with the transport
httpClient := http.DefaultClient
httpClient.Transport = customTransport
return openai.New( return openai.New(
openai.WithModel(llmModel), openai.WithModel(llmModel),
openai.WithToken(openaiAPIKey), openai.WithToken(openaiAPIKey),
openai.WithHTTPClient(httpClient),
) )
case "ollama": case "ollama":
host := os.Getenv("OLLAMA_HOST") host := os.Getenv("OLLAMA_HOST")
@ -663,9 +677,23 @@ func createVisionLLM() (llms.Model, error) {
if openaiAPIKey == "" { if openaiAPIKey == "" {
return nil, fmt.Errorf("OpenAI API key is not set") return nil, fmt.Errorf("OpenAI API key is not set")
} }
// Create custom transport that adds headers
customTransport := &headerTransport{
transport: http.DefaultTransport,
headers: map[string]string{
"X-Title": "paperless-gpt",
},
}
// Create custom client with the transport
httpClient := http.DefaultClient
httpClient.Transport = customTransport
return openai.New( return openai.New(
openai.WithModel(visionLlmModel), openai.WithModel(visionLlmModel),
openai.WithToken(openaiAPIKey), openai.WithToken(openaiAPIKey),
openai.WithHTTPClient(httpClient),
) )
case "ollama": case "ollama":
host := os.Getenv("OLLAMA_HOST") host := os.Getenv("OLLAMA_HOST")
@ -681,3 +709,17 @@ func createVisionLLM() (llms.Model, error) {
return nil, nil return nil, nil
} }
} }
// headerTransport is a custom http.RoundTripper that adds custom headers to requests
type headerTransport struct {
transport http.RoundTripper
headers map[string]string
}
// RoundTrip implements the http.RoundTripper interface
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
for key, value := range t.headers {
req.Header.Add(key, value)
}
return t.transport.RoundTrip(req)
}