From b7fab1af8a0dad81b030c23a13591f52dae65fb8 Mon Sep 17 00:00:00 2001 From: Icereed Date: Mon, 17 Feb 2025 11:39:45 +0100 Subject: [PATCH] feat: add custom HTTP transport with headers for OpenAI client (#245) * feat: add custom HTTP transport with headers for OpenAI client Closes #237 --- main.go | 38 ++++++++++++++++++++++++++++++++++++-- main_test.go | 22 ++++++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/main.go b/main.go index a300412..d382901 100644 --- a/main.go +++ b/main.go @@ -8,9 +8,9 @@ import ( "paperless-gpt/ocr" "path/filepath" "runtime" - "strconv" "slices" - "strings" + "strconv" + "strings" "sync" "text/template" "time" @@ -639,9 +639,11 @@ func createLLM() (llms.Model, error) { if openaiAPIKey == "" { return nil, fmt.Errorf("OpenAI API key is not set") } + return openai.New( openai.WithModel(llmModel), openai.WithToken(openaiAPIKey), + openai.WithHTTPClient(createCustomHTTPClient()), ) case "ollama": host := os.Getenv("OLLAMA_HOST") @@ -663,9 +665,11 @@ func createVisionLLM() (llms.Model, error) { if openaiAPIKey == "" { return nil, fmt.Errorf("OpenAI API key is not set") } + return openai.New( openai.WithModel(visionLlmModel), openai.WithToken(openaiAPIKey), + openai.WithHTTPClient(createCustomHTTPClient()), ) case "ollama": host := os.Getenv("OLLAMA_HOST") @@ -681,3 +685,33 @@ func createVisionLLM() (llms.Model, error) { return nil, nil } } + +func createCustomHTTPClient() *http.Client { + // Create custom transport that adds headers + customTransport := &headerTransport{ + transport: http.DefaultTransport, + headers: map[string]string{ + "X-Title": "paperless-gpt", + }, + } + + // Create custom client with the transport + httpClient := http.DefaultClient + httpClient.Transport = customTransport + + return httpClient +} + +// headerTransport is a custom http.RoundTripper that adds custom headers to requests +type headerTransport struct { + transport http.RoundTripper + headers map[string]string +} + +// RoundTrip implements the http.RoundTripper interface +func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + for key, value := range t.headers { + req.Header.Add(key, value) + } + return t.transport.RoundTrip(req) +} diff --git a/main_test.go b/main_test.go index b3fd0ad..cf4e3d1 100644 --- a/main_test.go +++ b/main_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "net/http" + "net/http/httptest" "slices" "testing" "text/template" @@ -175,3 +176,24 @@ func TestProcessAutoTagDocuments(t *testing.T) { }) } } + +func TestCreateCustomHTTPClient(t *testing.T) { + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify custom header + assert.Equal(t, "paperless-gpt", r.Header.Get("X-Title"), "Expected X-Title header") + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Get custom client + client := createCustomHTTPClient() + require.NotNil(t, client, "HTTP client should not be nil") + + // Make a request + resp, err := client.Get(server.URL) + require.NoError(t, err, "Request should not fail") + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode, "Expected 200 OK response") +}