diff --git a/Dockerfile b/Dockerfile index 13b8366..b4b81ec 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,7 +14,7 @@ RUN go mod download COPY . . # Build the Go binary -RUN CGO_ENABLED=0 GOOS=linux go build -o paperless-gpt main.go +RUN CGO_ENABLED=0 GOOS=linux go build -o paperless-gpt . # Stage 2: Build Vite frontend FROM node:20 AS frontend diff --git a/http_client_bearer.go b/http_client_bearer.go new file mode 100644 index 0000000..09c479e --- /dev/null +++ b/http_client_bearer.go @@ -0,0 +1,34 @@ +package main + +import ( + "fmt" + "net/http" +) + +// HttpTransportWithBearer wraps the default RoundTripper to add the Authorization header. +type HttpTransportWithBearer struct { + BaseTransport http.RoundTripper + Token string +} + +// RoundTrip implements the RoundTripper interface to modify the request. +func (t *HttpTransportWithBearer) RoundTrip(req *http.Request) (*http.Response, error) { + // Clone the request to avoid side effects + reqClone := req.Clone(req.Context()) + + // Add the Authorization header + reqClone.Header.Set("Authorization", fmt.Sprintf("Bearer %s", t.Token)) + + // Use the base RoundTripper to perform the request + return t.BaseTransport.RoundTrip(reqClone) +} + +func NewHttpClientWithBearerTransport(token string) *http.Client { + // Create a new HTTP client with the custom transport + return &http.Client{ + Transport: &HttpTransportWithBearer{ + BaseTransport: http.DefaultTransport, + Token: token, + }, + } +} diff --git a/http_client_bearer_test.go b/http_client_bearer_test.go new file mode 100644 index 0000000..2c90081 --- /dev/null +++ b/http_client_bearer_test.go @@ -0,0 +1,54 @@ +package main + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +// TestHttpClientWithBearerTransport tests the addition of the Authorization header. +func TestHttpClientWithBearerTransport(t *testing.T) { + // Define the expected Bearer token + token := "test_bearer_token" + + // Set up a test HTTP server + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Retrieve the Authorization header from the request + authHeader := r.Header.Get("Authorization") + expectedHeader := fmt.Sprintf("Bearer %s", token) + + // Check if the Authorization header matches the expected value + if authHeader != expectedHeader { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + // Return a success response + w.WriteHeader(http.StatusOK) + io.WriteString(w, "Success") + })) + defer testServer.Close() + + // Create an HTTP client with the custom transport + client := NewHttpClientWithBearerTransport(token) + + // Create a new HTTP request to the test server + req, err := http.NewRequest("GET", testServer.URL, nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + // Perform the request using the custom client + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + // Check if the status code is 200 OK + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status code 200 OK, got %d", resp.StatusCode) + } +} diff --git a/main.go b/main.go index b670f46..19703ff 100644 --- a/main.go +++ b/main.go @@ -14,7 +14,6 @@ import ( "time" "github.com/gin-gonic/gin" - retryablehttp "github.com/hashicorp/go-retryablehttp" "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms/ollama" "github.com/tmc/langchaingo/llms/openai" @@ -141,22 +140,23 @@ func createLLM() (llms.Model, error) { if host == "" { host = "http://127.0.0.1:11434" } - // custom http client (retryable http client) if bearer token is wanted - retryClient := retryablehttp.NewClient() - retryClient.RetryMax = 10 + ollamaOptions := []ollama.Option{ + ollama.WithModel(llmModel), + ollama.WithServerURL(host), + } bearerToken := os.Getenv("OLLAMA_BEARER_TOKEN") if bearerToken != "" { - retryClient.RequestLogHook = func(l retryablehttp.Logger, r *http.Request, i int) { - r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", bearerToken)) - shortenedBearerToken := fmt.Sprintf("%s...", r.Header.Get("Authorization")[:5]) - log.Printf("Request with bearer %s token to %s %s", shortenedBearerToken, r.Method, r.URL) - } + log.Println("Using bearer token for OLLAMA authentication") + ollamaOptions = append( + ollamaOptions, + ollama.WithHTTPClient( + NewHttpClientWithBearerTransport(bearerToken), + ), + ) } return ollama.New( - ollama.WithModel(llmModel), - ollama.WithServerURL(host), - ollama.WithHTTPClient(retryClient.StandardClient()), + ollamaOptions..., ) default: return nil, fmt.Errorf("unsupported LLM provider: %s", llmProvider)