Support reasoning models (deepseek-r1) in Ollama (#194)

This commit is contained in:
Icereed 2025-02-05 20:59:08 +01:00 committed by GitHub
parent 47b0f07b15
commit caa27c4d53
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 47 additions and 5 deletions

View file

@ -67,7 +67,7 @@ func (app *App) getSuggestedCorrespondent(ctx context.Context, content string, s
 return "", fmt.Errorf("error getting response from LLM: %v", err)
 }
-response := strings.TrimSpace(completion.Choices[0].Content)
+response := stripReasoning(strings.TrimSpace(completion.Choices[0].Content))
 return response, nil
 }
@ -137,7 +137,8 @@ func (app *App) getSuggestedTags(
 return nil, fmt.Errorf("error getting response from LLM: %v", err)
 }
-response := strings.TrimSpace(completion.Choices[0].Content)
+response := stripReasoning(completion.Choices[0].Content)
 suggestedTags := strings.Split(response, ",")
 for i, tag := range suggestedTags {
 suggestedTags[i] = strings.TrimSpace(tag)
@ -252,7 +253,7 @@ func (app *App) getSuggestedTitle(ctx context.Context, content string, originalT
 var promptBuffer bytes.Buffer
 templateData["Content"] = truncatedContent
 err = titleTemplate.Execute(&promptBuffer, templateData)
 if err != nil {
 return "", fmt.Errorf("error executing title template: %v", err)
 }
@ -273,8 +274,8 @@ func (app *App) getSuggestedTitle(ctx context.Context, content string, originalT
 if err != nil {
 return "", fmt.Errorf("error getting response from LLM: %v", err)
 }
-return strings.TrimSpace(strings.Trim(completion.Choices[0].Content, "\"")), nil
+result := stripReasoning(completion.Choices[0].Content)
+return strings.TrimSpace(strings.Trim(result, "\"")), nil
 }
// generateDocumentSuggestions generates suggestions for a set of documents // generateDocumentSuggestions generates suggestions for a set of documents
@ -404,3 +405,19 @@ func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionReque
 return documentSuggestions, nil
 }
// stripReasoning removes a reasoning section delimited by <think> and
// </think> tags (as emitted by reasoning models such as deepseek-r1) from
// content, then trims leading and trailing whitespace.
//
// If the opening tag is absent, or no closing tag follows the opening tag,
// the content is returned with only whitespace trimmed.
func stripReasoning(content string) string {
	const openTag, closeTag = "<think>", "</think>"
	start := strings.Index(content, openTag)
	if start != -1 {
		// Search for the closing tag only AFTER the opening tag: searching
		// from the start of the string would match a stray "</think>" that
		// precedes "<think>" and corrupt the result by duplicating text.
		end := strings.Index(content[start:], closeTag)
		if end != -1 {
			content = content[:start] + content[start+end+len(closeTag):]
		}
	}
	return strings.TrimSpace(content)
}

View file

@ -266,3 +266,28 @@ func TestTokenLimitInTitleGeneration(t *testing.T) {
 // Final prompt should be within token limit
 assert.LessOrEqual(t, len(tokens), 50, "Final prompt should be within token limit")
 }
func TestStripReasoning(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "No reasoning tags",
input: "This is a test content without reasoning tags.",
expected: "This is a test content without reasoning tags.",
},
{
name: "Reasoning tags at the start",
input: "<think>Start reasoning</think>\n\nContent \n\n",
expected: "Content",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := stripReasoning(tc.input)
assert.Equal(t, tc.expected, result)
})
}
}