From caa27c4d53b5c7e4eb07f0a73fca242d4780cbb8 Mon Sep 17 00:00:00 2001 From: Icereed Date: Wed, 5 Feb 2025 20:59:08 +0100 Subject: [PATCH] Support reasoning models (deepseek-r1) in Ollama (#194) --- app_llm.go | 27 ++++++++++++++++++++++----- app_llm_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/app_llm.go b/app_llm.go index 8a56b6d..8a38528 100644 --- a/app_llm.go +++ b/app_llm.go @@ -67,7 +67,7 @@ func (app *App) getSuggestedCorrespondent(ctx context.Context, content string, s return "", fmt.Errorf("error getting response from LLM: %v", err) } - response := strings.TrimSpace(completion.Choices[0].Content) + response := stripReasoning(strings.TrimSpace(completion.Choices[0].Content)) return response, nil } @@ -137,7 +137,8 @@ func (app *App) getSuggestedTags( return nil, fmt.Errorf("error getting response from LLM: %v", err) } - response := strings.TrimSpace(completion.Choices[0].Content) + response := stripReasoning(completion.Choices[0].Content) + suggestedTags := strings.Split(response, ",") for i, tag := range suggestedTags { suggestedTags[i] = strings.TrimSpace(tag) @@ -252,7 +253,7 @@ func (app *App) getSuggestedTitle(ctx context.Context, content string, originalT var promptBuffer bytes.Buffer templateData["Content"] = truncatedContent err = titleTemplate.Execute(&promptBuffer, templateData) - + if err != nil { return "", fmt.Errorf("error executing title template: %v", err) } @@ -273,8 +274,8 @@ func (app *App) getSuggestedTitle(ctx context.Context, content string, originalT if err != nil { return "", fmt.Errorf("error getting response from LLM: %v", err) } - - return strings.TrimSpace(strings.Trim(completion.Choices[0].Content, "\"")), nil + result := stripReasoning(completion.Choices[0].Content) + return strings.TrimSpace(strings.Trim(result, "\"")), nil } // generateDocumentSuggestions generates suggestions for a set of documents @@ -404,3 +405,19 @@ func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionReque return documentSuggestions, nil } + +// stripReasoning removes the reasoning from the content indicated by and tags. +func stripReasoning(content string) string { + // Remove reasoning from the content + reasoningStart := strings.Index(content, "") + if reasoningStart != -1 { + reasoningEnd := strings.Index(content, "") + if reasoningEnd != -1 { + content = content[:reasoningStart] + content[reasoningEnd+len(""):] + } + } + + // Trim whitespace + content = strings.TrimSpace(content) + return content +} diff --git a/app_llm_test.go b/app_llm_test.go index 953204b..da94e11 100644 --- a/app_llm_test.go +++ b/app_llm_test.go @@ -266,3 +266,28 @@ func TestTokenLimitInTitleGeneration(t *testing.T) { // Final prompt should be within token limit assert.LessOrEqual(t, len(tokens), 50, "Final prompt should be within token limit") } +func TestStripReasoning(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "No reasoning tags", + input: "This is a test content without reasoning tags.", + expected: "This is a test content without reasoning tags.", + }, + { + name: "Reasoning tags at the start", + input: "Start reasoning\n\nContent \n\n", + expected: "Content", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := stripReasoning(tc.input) + assert.Equal(t, tc.expected, result) + }) + } +}