feat: add custom HTTP transport with headers for OpenAI client (#245)

* feat: add custom HTTP transport with headers for OpenAI client

Closes #237
This commit is contained in:
Icereed 2025-02-17 11:39:45 +01:00 committed by GitHub
parent 1647219fa8
commit b7fab1af8a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 58 additions and 2 deletions

38
main.go
View file

@ -8,9 +8,9 @@ import (
"paperless-gpt/ocr" "paperless-gpt/ocr"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv"
"slices" "slices"
"strings" "strconv"
"strings"
"sync" "sync"
"text/template" "text/template"
"time" "time"
@ -639,9 +639,11 @@ func createLLM() (llms.Model, error) {
if openaiAPIKey == "" { if openaiAPIKey == "" {
return nil, fmt.Errorf("OpenAI API key is not set") return nil, fmt.Errorf("OpenAI API key is not set")
} }
return openai.New( return openai.New(
openai.WithModel(llmModel), openai.WithModel(llmModel),
openai.WithToken(openaiAPIKey), openai.WithToken(openaiAPIKey),
openai.WithHTTPClient(createCustomHTTPClient()),
) )
case "ollama": case "ollama":
host := os.Getenv("OLLAMA_HOST") host := os.Getenv("OLLAMA_HOST")
@ -663,9 +665,11 @@ func createVisionLLM() (llms.Model, error) {
if openaiAPIKey == "" { if openaiAPIKey == "" {
return nil, fmt.Errorf("OpenAI API key is not set") return nil, fmt.Errorf("OpenAI API key is not set")
} }
return openai.New( return openai.New(
openai.WithModel(visionLlmModel), openai.WithModel(visionLlmModel),
openai.WithToken(openaiAPIKey), openai.WithToken(openaiAPIKey),
openai.WithHTTPClient(createCustomHTTPClient()),
) )
case "ollama": case "ollama":
host := os.Getenv("OLLAMA_HOST") host := os.Getenv("OLLAMA_HOST")
@ -681,3 +685,33 @@ func createVisionLLM() (llms.Model, error) {
return nil, nil return nil, nil
} }
} }
func createCustomHTTPClient() *http.Client {
// Create custom transport that adds headers
customTransport := &headerTransport{
transport: http.DefaultTransport,
headers: map[string]string{
"X-Title": "paperless-gpt",
},
}
// Create custom client with the transport
httpClient := http.DefaultClient
httpClient.Transport = customTransport
return httpClient
}
// headerTransport is a custom http.RoundTripper that adds custom headers to requests
type headerTransport struct {
transport http.RoundTripper
headers map[string]string
}
// RoundTrip implements the http.RoundTripper interface
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
for key, value := range t.headers {
req.Header.Add(key, value)
}
return t.transport.RoundTrip(req)
}

View file

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest"
"slices" "slices"
"testing" "testing"
"text/template" "text/template"
@ -175,3 +176,24 @@ func TestProcessAutoTagDocuments(t *testing.T) {
}) })
} }
} }
func TestCreateCustomHTTPClient(t *testing.T) {
// Create a test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify custom header
assert.Equal(t, "paperless-gpt", r.Header.Get("X-Title"), "Expected X-Title header")
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
// Get custom client
client := createCustomHTTPClient()
require.NotNil(t, client, "HTTP client should not be nil")
// Make a request
resp, err := client.Get(server.URL)
require.NoError(t, err, "Request should not fail")
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode, "Expected 200 OK response")
}