Skip to content

Commit 01632c5

Browse files
ThomasVitaletzolov
authored andcommitted
Add merge for missing Ollama options
When fields in OllamaOptions are marked as ignored in Jackson, they require explicit merge of runtime and default options. Added tests to validate the different merge combinations for all tool-related options. Signed-off-by: Thomas Vitale <[email protected]>
1 parent ecae493 commit 01632c5

File tree

5 files changed

+125
-3
lines changed

5 files changed

+125
-3
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,11 @@ public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
147147

148148
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
149149
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
150-
super(null, defaultOptions, List.of());
150+
// We do not pass the 'defaultOptions' to the AbstractToolSupport, because it
151+
// modifies them.
152+
// We are not using the AbstractToolSupport class in this path, so we just pass
153+
// empty options.
154+
super(null, OllamaOptions.builder().build(), List.of());
151155
Assert.notNull(ollamaApi, "ollamaApi must not be null");
152156
Assert.notNull(defaultOptions, "defaultOptions must not be null");
153157
Assert.notNull(toolCallingManager, "toolCallingManager must not be null");
@@ -395,17 +399,24 @@ else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOp
395399
// Define request options by merging runtime options and default options
396400
OllamaOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
397401
OllamaOptions.class);
398-
// Merge tool names and tool callbacks explicitly since they are ignored by
402+
// Merge @JsonIgnore-annotated options explicitly since they are ignored by
399403
// Jackson, used by ModelOptionsUtils.
400404
if (runtimeOptions != null) {
405+
requestOptions.setInternalToolExecutionEnabled(
406+
ModelOptionsUtils.mergeOption(runtimeOptions.isInternalToolExecutionEnabled(),
407+
this.defaultOptions.isInternalToolExecutionEnabled()));
401408
requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),
402409
this.defaultOptions.getToolNames()));
403410
requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(),
404411
this.defaultOptions.getToolCallbacks()));
412+
requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),
413+
this.defaultOptions.getToolContext()));
405414
}
406415
else {
416+
requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.isInternalToolExecutionEnabled());
407417
requestOptions.setToolNames(this.defaultOptions.getToolNames());
408418
requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
419+
requestOptions.setToolContext(this.defaultOptions.getToolContext());
409420
}
410421

411422
// Validate request options

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.util.ArrayList;
2020
import java.util.Arrays;
21+
import java.util.HashMap;
2122
import java.util.HashSet;
2223
import java.util.List;
2324
import java.util.Map;
@@ -331,7 +332,7 @@ public class OllamaOptions implements ToolCallingChatOptions, EmbeddingOptions {
331332
private Set<String> toolNames = new HashSet<>();
332333

333334
@JsonIgnore
334-
private Map<String, Object> toolContext;
335+
private Map<String, Object> toolContext = new HashMap<>();
335336

336337
public static Builder builder() {
337338
return new Builder();

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,13 @@
2020

2121
import org.springframework.ai.chat.prompt.ChatOptions;
2222
import org.springframework.ai.chat.prompt.Prompt;
23+
import org.springframework.ai.model.tool.ToolCallingChatOptions;
2324
import org.springframework.ai.ollama.api.OllamaApi;
2425
import org.springframework.ai.ollama.api.OllamaOptions;
26+
import org.springframework.ai.tool.ToolCallback;
27+
import org.springframework.ai.tool.definition.ToolDefinition;
28+
29+
import java.util.Map;
2530

2631
import static org.assertj.core.api.Assertions.assertThat;
2732

@@ -36,6 +41,37 @@ class OllamaChatRequestTests {
3641
.defaultOptions(OllamaOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build())
3742
.build();
3843

44+
@Test
45+
void whenToolRuntimeOptionsThenMergeWithDefaults() {
46+
OllamaOptions defaultOptions = OllamaOptions.builder()
47+
.model("MODEL_NAME")
48+
.internalToolExecutionEnabled(true)
49+
.toolCallbacks(new TestToolCallback("tool1"), new TestToolCallback("tool2"))
50+
.toolNames("tool1", "tool2")
51+
.toolContext(Map.of("key1", "value1"))
52+
.build();
53+
OllamaChatModel chatModel = OllamaChatModel.builder()
54+
.ollamaApi(new OllamaApi())
55+
.defaultOptions(defaultOptions)
56+
.build();
57+
58+
OllamaOptions runtimeOptions = OllamaOptions.builder()
59+
.internalToolExecutionEnabled(false)
60+
.toolCallbacks(new TestToolCallback("tool3"), new TestToolCallback("tool4"))
61+
.toolNames("tool3")
62+
.toolContext(Map.of("key2", "value2"))
63+
.build();
64+
Prompt prompt = chatModel.buildRequestPrompt(new Prompt("Test message content", runtimeOptions));
65+
66+
assertThat(((ToolCallingChatOptions) prompt.getOptions())).isNotNull();
67+
assertThat(((ToolCallingChatOptions) prompt.getOptions()).isInternalToolExecutionEnabled()).isFalse();
68+
assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(4);
69+
assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolNames()).containsExactlyInAnyOrder("tool1",
70+
"tool2", "tool3");
71+
assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolContext()).containsEntry("key1", "value1")
72+
.containsEntry("key2", "value2");
73+
}
74+
3975
@Test
4076
void createRequestWithDefaultOptions() {
4177
var prompt = this.chatModel.buildRequestPrompt(new Prompt("Test message content"));
@@ -124,4 +160,24 @@ public void createRequestWithDefaultOptionsModelOverride() {
124160
assertThat(request.model()).isEqualTo("PROMPT_MODEL");
125161
}
126162

163+
static class TestToolCallback implements ToolCallback {
164+
165+
private final ToolDefinition toolDefinition;
166+
167+
public TestToolCallback(String name) {
168+
this.toolDefinition = ToolDefinition.builder().name(name).inputSchema("{}").build();
169+
}
170+
171+
@Override
172+
public ToolDefinition getToolDefinition() {
173+
return toolDefinition;
174+
}
175+
176+
@Override
177+
public String call(String toolInput) {
178+
return "Mission accomplished!";
179+
}
180+
181+
}
182+
127183
}

spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.springframework.util.Assert;
2525

2626
import java.util.ArrayList;
27+
import java.util.HashMap;
2728
import java.util.HashSet;
2829
import java.util.List;
2930
import java.util.Map;
@@ -204,4 +205,15 @@ static List<FunctionCallback> mergeToolCallbacks(List<FunctionCallback> runtimeT
204205
return mergedToolCallbacks;
205206
}
206207

208+
static Map<String, Object> mergeToolContext(Map<String, Object> runtimeToolContext,
209+
Map<String, Object> defaultToolContext) {
210+
Assert.notNull(runtimeToolContext, "runtimeToolContext cannot be null");
211+
Assert.noNullElements(runtimeToolContext.keySet(), "runtimeToolContext keys cannot be null");
212+
Assert.notNull(defaultToolContext, "defaultToolContext cannot be null");
213+
Assert.noNullElements(defaultToolContext.keySet(), "defaultToolContext keys cannot be null");
214+
var mergedToolContext = new HashMap<>(defaultToolContext);
215+
mergedToolContext.putAll(runtimeToolContext);
216+
return mergedToolContext;
217+
}
218+
207219
}

spring-ai-core/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.springframework.ai.tool.definition.ToolDefinition;
2323

2424
import java.util.List;
25+
import java.util.Map;
2526
import java.util.Set;
2627

2728
import static org.assertj.core.api.Assertions.assertThat;
@@ -141,6 +142,47 @@ void whenMergeEmptyRuntimeAndEmptyDefaultToolCallbacks() {
141142
assertThat(mergedToolCallbacks).hasSize(0);
142143
}
143144

145+
@Test
146+
void whenMergeRuntimeAndDefaultToolContext() {
147+
Map<String, Object> runtimeToolContext = Map.of("key1", "value1", "key2", "value2");
148+
Map<String, Object> defaultToolContext = Map.of("key1", "valueA", "key3", "value3");
149+
Map<String, Object> mergedToolContext = ToolCallingChatOptions.mergeToolContext(runtimeToolContext,
150+
defaultToolContext);
151+
assertThat(mergedToolContext).hasSize(3);
152+
assertThat(mergedToolContext).containsEntry("key1", "value1")
153+
.containsEntry("key2", "value2")
154+
.containsEntry("key3", "value3");
155+
}
156+
157+
@Test
158+
void whenMergeRuntimeAndEmptyDefaultToolContext() {
159+
Map<String, Object> runtimeToolContext = Map.of("key1", "value1", "key2", "value2");
160+
Map<String, Object> defaultToolContext = Map.of();
161+
Map<String, Object> mergedToolContext = ToolCallingChatOptions.mergeToolContext(runtimeToolContext,
162+
defaultToolContext);
163+
assertThat(mergedToolContext).hasSize(2);
164+
assertThat(mergedToolContext).containsEntry("key1", "value1").containsEntry("key2", "value2");
165+
}
166+
167+
@Test
168+
void whenMergeEmptyRuntimeAndDefaultToolContext() {
169+
Map<String, Object> runtimeToolContext = Map.of();
170+
Map<String, Object> defaultToolContext = Map.of("key1", "value1", "key2", "value2");
171+
Map<String, Object> mergedToolContext = ToolCallingChatOptions.mergeToolContext(runtimeToolContext,
172+
defaultToolContext);
173+
assertThat(mergedToolContext).hasSize(2);
174+
assertThat(mergedToolContext).containsEntry("key1", "value1").containsEntry("key2", "value2");
175+
}
176+
177+
@Test
178+
void whenMergeEmptyRuntimeAndEmptyDefaultToolContext() {
179+
Map<String, Object> runtimeToolContext = Map.of();
180+
Map<String, Object> defaultToolContext = Map.of();
181+
Map<String, Object> mergedToolContext = ToolCallingChatOptions.mergeToolContext(runtimeToolContext,
182+
defaultToolContext);
183+
assertThat(mergedToolContext).hasSize(0);
184+
}
185+
144186
static class TestToolCallback implements ToolCallback {
145187

146188
private final ToolDefinition toolDefinition;

0 commit comments

Comments
 (0)