diff --git a/main.go b/main.go index 638637e..d382901 100644 --- a/main.go +++ b/main.go @@ -640,22 +640,10 @@ func createLLM() (llms.Model, error) { return nil, fmt.Errorf("OpenAI API key is not set") } - // Create custom transport that adds headers - customTransport := &headerTransport{ - transport: http.DefaultTransport, - headers: map[string]string{ - "X-Title": "paperless-gpt", - }, - } - - // Create custom client with the transport - httpClient := http.DefaultClient - httpClient.Transport = customTransport - return openai.New( openai.WithModel(llmModel), openai.WithToken(openaiAPIKey), - openai.WithHTTPClient(httpClient), + openai.WithHTTPClient(createCustomHTTPClient()), ) case "ollama": host := os.Getenv("OLLAMA_HOST") @@ -678,22 +666,10 @@ func createVisionLLM() (llms.Model, error) { return nil, fmt.Errorf("OpenAI API key is not set") } - // Create custom transport that adds headers - customTransport := &headerTransport{ - transport: http.DefaultTransport, - headers: map[string]string{ - "X-Title": "paperless-gpt", - }, - } - - // Create custom client with the transport - httpClient := http.DefaultClient - httpClient.Transport = customTransport - return openai.New( openai.WithModel(visionLlmModel), openai.WithToken(openaiAPIKey), - openai.WithHTTPClient(httpClient), + openai.WithHTTPClient(createCustomHTTPClient()), ) case "ollama": host := os.Getenv("OLLAMA_HOST") @@ -710,6 +686,22 @@ func createVisionLLM() (llms.Model, error) { } } +func createCustomHTTPClient() *http.Client { + // Create custom transport that adds headers + customTransport := &headerTransport{ + transport: http.DefaultTransport, + headers: map[string]string{ + "X-Title": "paperless-gpt", + }, + } + + // Create custom client with the transport + httpClient := http.DefaultClient + httpClient.Transport = customTransport + + return httpClient +} + // headerTransport is a custom http.RoundTripper that adds custom headers to requests type headerTransport struct { transport http.RoundTripper diff --git a/main_test.go b/main_test.go index b3fd0ad..cf4e3d1 100644 --- a/main_test.go +++ b/main_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "net/http" + "net/http/httptest" "slices" "testing" "text/template" @@ -175,3 +176,24 @@ func TestProcessAutoTagDocuments(t *testing.T) { }) } } + +func TestCreateCustomHTTPClient(t *testing.T) { + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify custom header + assert.Equal(t, "paperless-gpt", r.Header.Get("X-Title"), "Expected X-Title header") + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Get custom client + client := createCustomHTTPClient() + require.NotNil(t, client, "HTTP client should not be nil") + + // Make a request + resp, err := client.Get(server.URL) + require.NoError(t, err, "Request should not fail") + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode, "Expected 200 OK response") +}