feat: add custom HTTP transport with headers for OpenAI client

Closes #237
This commit is contained in:
Dominik Schröter 2025-02-17 09:46:24 +01:00 committed by Icereed
parent 1647219fa8
commit 23632f6fcd

46
main.go
View file

@ -8,9 +8,9 @@ import (
"paperless-gpt/ocr"
"path/filepath"
"runtime"
"strconv"
"slices"
"strings"
"strconv"
"strings"
"sync"
"text/template"
"time"
@ -639,9 +639,23 @@ func createLLM() (llms.Model, error) {
if openaiAPIKey == "" {
return nil, fmt.Errorf("OpenAI API key is not set")
}
// Create custom transport that adds headers
customTransport := &headerTransport{
transport: http.DefaultTransport,
headers: map[string]string{
"X-Title": "paperless-gpt",
},
}
// Create custom client with the transport
httpClient := http.DefaultClient
httpClient.Transport = customTransport
return openai.New(
openai.WithModel(llmModel),
openai.WithToken(openaiAPIKey),
openai.WithHTTPClient(httpClient),
)
case "ollama":
host := os.Getenv("OLLAMA_HOST")
@ -663,9 +677,23 @@ func createVisionLLM() (llms.Model, error) {
if openaiAPIKey == "" {
return nil, fmt.Errorf("OpenAI API key is not set")
}
// Create custom transport that adds headers
customTransport := &headerTransport{
transport: http.DefaultTransport,
headers: map[string]string{
"X-Title": "paperless-gpt",
},
}
// Create custom client with the transport
httpClient := http.DefaultClient
httpClient.Transport = customTransport
return openai.New(
openai.WithModel(visionLlmModel),
openai.WithToken(openaiAPIKey),
openai.WithHTTPClient(httpClient),
)
case "ollama":
host := os.Getenv("OLLAMA_HOST")
@ -681,3 +709,17 @@ func createVisionLLM() (llms.Model, error) {
return nil, nil
}
}
// headerTransport is a custom http.RoundTripper that adds custom headers to requests
type headerTransport struct {
transport http.RoundTripper
headers map[string]string
}
// RoundTrip implements the http.RoundTripper interface
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
for key, value := range t.headers {
req.Header.Add(key, value)
}
return t.transport.RoundTrip(req)
}