diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index b5743d80bfd..6dcb48d28f2 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -147,7 +147,11 @@ public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager, ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) { - super(null, defaultOptions, List.of()); + // We do not pass the 'defaultOptions' to the AbstractToolSupport, because it + // modifies them. + // We are not using the AbstractToolSupport class in this path, so we just pass + // empty options. + super(null, OllamaOptions.builder().build(), List.of()); Assert.notNull(ollamaApi, "ollamaApi must not be null"); Assert.notNull(defaultOptions, "defaultOptions must not be null"); Assert.notNull(toolCallingManager, "toolCallingManager must not be null"); @@ -395,17 +399,24 @@ else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOp // Define request options by merging runtime options and default options OllamaOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class); - // Merge tool names and tool callbacks explicitly since they are ignored by + // Merge @JsonIgnore-annotated options explicitly since they are ignored by // Jackson, used by ModelOptionsUtils. if (runtimeOptions != null) { + requestOptions.setInternalToolExecutionEnabled( + ModelOptionsUtils.mergeOption(runtimeOptions.isInternalToolExecutionEnabled(), + this.defaultOptions.isInternalToolExecutionEnabled())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), this.defaultOptions.getToolCallbacks())); + requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(), + this.defaultOptions.getToolContext())); } else { + requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.isInternalToolExecutionEnabled()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); + requestOptions.setToolContext(this.defaultOptions.getToolContext()); } // Validate request options diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index fc815202b3f..c6706e337d0 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -331,7 +332,7 @@ public class OllamaOptions implements ToolCallingChatOptions, EmbeddingOptions { private Set toolNames = new HashSet<>(); @JsonIgnore - private Map toolContext; + private Map toolContext = new HashMap<>(); public static Builder builder() { return new Builder(); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java index 9f38c6fa06a..740c299ab8a 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java @@ -20,8 +20,13 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; + +import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; @@ -36,6 +41,37 @@ class OllamaChatRequestTests { .defaultOptions(OllamaOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build()) .build(); + @Test + void whenToolRuntimeOptionsThenMergeWithDefaults() { + OllamaOptions defaultOptions = OllamaOptions.builder() + .model("MODEL_NAME") + .internalToolExecutionEnabled(true) + .toolCallbacks(new TestToolCallback("tool1"), new TestToolCallback("tool2")) + .toolNames("tool1", "tool2") + .toolContext(Map.of("key1", "value1")) + .build(); + OllamaChatModel chatModel = OllamaChatModel.builder() + .ollamaApi(new OllamaApi()) + .defaultOptions(defaultOptions) + .build(); + + OllamaOptions runtimeOptions = OllamaOptions.builder() + .internalToolExecutionEnabled(false) + .toolCallbacks(new TestToolCallback("tool3"), new TestToolCallback("tool4")) + .toolNames("tool3") + .toolContext(Map.of("key2", "value2")) + .build(); + Prompt prompt = chatModel.buildRequestPrompt(new Prompt("Test message content", runtimeOptions)); + + assertThat(((ToolCallingChatOptions) prompt.getOptions())).isNotNull(); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).isInternalToolExecutionEnabled()).isFalse(); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(4); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolNames()).containsExactlyInAnyOrder("tool1", + "tool2", "tool3"); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolContext()).containsEntry("key1", "value1") + .containsEntry("key2", "value2"); + } + @Test void createRequestWithDefaultOptions() { var prompt = this.chatModel.buildRequestPrompt(new Prompt("Test message content")); @@ -124,4 +160,24 @@ public void createRequestWithDefaultOptionsModelOverride() { assertThat(request.model()).isEqualTo("PROMPT_MODEL"); } + static class TestToolCallback implements ToolCallback { + + private final ToolDefinition toolDefinition; + + public TestToolCallback(String name) { + this.toolDefinition = ToolDefinition.builder().name(name).inputSchema("{}").build(); + } + + @Override + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + + @Override + public String call(String toolInput) { + return "Mission accomplished!"; + } + + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java index 6be2cf37d37..b4c25f91172 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java @@ -24,6 +24,7 @@ import org.springframework.util.Assert; import java.util.ArrayList; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -204,4 +205,15 @@ static List mergeToolCallbacks(List runtimeT return mergedToolCallbacks; } + static Map mergeToolContext(Map runtimeToolContext, + Map defaultToolContext) { + Assert.notNull(runtimeToolContext, "runtimeToolContext cannot be null"); + Assert.noNullElements(runtimeToolContext.keySet(), "runtimeToolContext keys cannot be null"); + Assert.notNull(defaultToolContext, "defaultToolContext cannot be null"); + Assert.noNullElements(defaultToolContext.keySet(), "defaultToolContext keys cannot be null"); + var mergedToolContext = new HashMap<>(defaultToolContext); + mergedToolContext.putAll(runtimeToolContext); + return mergedToolContext; + } + } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java index c3f92df2580..134151ab0b9 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java @@ -22,6 +22,7 @@ import org.springframework.ai.tool.definition.ToolDefinition; import java.util.List; +import java.util.Map; import java.util.Set; import static org.assertj.core.api.Assertions.assertThat; @@ -141,6 +142,47 @@ void whenMergeEmptyRuntimeAndEmptyDefaultToolCallbacks() { assertThat(mergedToolCallbacks).hasSize(0); } + @Test + void whenMergeRuntimeAndDefaultToolContext() { + Map runtimeToolContext = Map.of("key1", "value1", "key2", "value2"); + Map defaultToolContext = Map.of("key1", "valueA", "key3", "value3"); + Map mergedToolContext = ToolCallingChatOptions.mergeToolContext(runtimeToolContext, + defaultToolContext); + assertThat(mergedToolContext).hasSize(3); + assertThat(mergedToolContext).containsEntry("key1", "value1") + .containsEntry("key2", "value2") + .containsEntry("key3", "value3"); + } + + @Test + void whenMergeRuntimeAndEmptyDefaultToolContext() { + Map runtimeToolContext = Map.of("key1", "value1", "key2", "value2"); + Map defaultToolContext = Map.of(); + Map mergedToolContext = ToolCallingChatOptions.mergeToolContext(runtimeToolContext, + defaultToolContext); + assertThat(mergedToolContext).hasSize(2); + assertThat(mergedToolContext).containsEntry("key1", "value1").containsEntry("key2", "value2"); + } + + @Test + void whenMergeEmptyRuntimeAndDefaultToolContext() { + Map runtimeToolContext = Map.of(); + Map defaultToolContext = Map.of("key1", "value1", "key2", "value2"); + Map mergedToolContext = ToolCallingChatOptions.mergeToolContext(runtimeToolContext, + defaultToolContext); + assertThat(mergedToolContext).hasSize(2); + assertThat(mergedToolContext).containsEntry("key1", "value1").containsEntry("key2", "value2"); + } + + @Test + void whenMergeEmptyRuntimeAndEmptyDefaultToolContext() { + Map runtimeToolContext = Map.of(); + Map defaultToolContext = Map.of(); + Map mergedToolContext = ToolCallingChatOptions.mergeToolContext(runtimeToolContext, + defaultToolContext); + assertThat(mergedToolContext).hasSize(0); + } + static class TestToolCallback implements ToolCallback { private final ToolDefinition toolDefinition;