Skip to content

Commit 52198ed

Browse files
ThomasVitaletzolov
authored andcommitted
OpenAI - Adopt ToolCallingManager API
- Update OpenAiChatModel to use the new ToolCallingManager API, while ensuring full API backward compatibility. - Introduce Builder to instantiate a new OpenAiChatModel since the number of overloaded constructors is growing too big. - Update documentation about tool calling and OpenAI support for that. - Add extra validation to ensure the uniqueness of tool names when aggregated from different sources. - Ensure consistent merging of options, following Spring Boot strategy. Signed-off-by: Thomas Vitale <[email protected]>
1 parent 9710f78 commit 52198ed

File tree

20 files changed

+613
-235
lines changed

20 files changed

+613
-235
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -129,28 +129,18 @@ public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
129129
@Nullable FunctionCallbackResolver functionCallbackResolver,
130130
@Nullable List<FunctionCallback> toolFunctionCallbacks, ObservationRegistry observationRegistry,
131131
ModelManagementOptions modelManagementOptions) {
132-
super(functionCallbackResolver, defaultOptions, toolFunctionCallbacks);
133-
Assert.notNull(ollamaApi, "ollamaApi must not be null");
134-
Assert.notNull(defaultOptions, "defaultOptions must not be null");
135-
Assert.notNull(observationRegistry, "observationRegistry must not be null");
136-
Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");
137-
this.chatApi = ollamaApi;
138-
this.defaultOptions = defaultOptions;
139-
this.toolCallingManager = new LegacyToolCallingManager(functionCallbackResolver, toolFunctionCallbacks);
140-
this.observationRegistry = observationRegistry;
141-
this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions);
142-
initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
132+
this(ollamaApi, defaultOptions, new LegacyToolCallingManager(functionCallbackResolver, toolFunctionCallbacks),
133+
observationRegistry, modelManagementOptions);
143134

144135
logger.warn("This constructor is deprecated and will be removed in the next milestone. "
145-
+ "Please use the new constructor accepting ToolCallingManager instead.");
136+
+ "Please use the OllamaChatModel.Builder or the new constructor accepting ToolCallingManager instead.");
146137
}
147138

148139
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
149140
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
150-
// We do not pass the 'defaultOptions' to the AbstractToolSupport, because it
151-
// modifies them.
152-
// We are not using the AbstractToolSupport class in this path, so we just pass
153-
// empty options.
141+
// We do not pass the 'defaultOptions' to the AbstractToolSupport,
142+
// because it modifies them. We are using ToolCallingManager instead,
143+
// so we just pass empty options here.
154144
super(null, OllamaOptions.builder().build(), List.of());
155145
Assert.notNull(ollamaApi, "ollamaApi must not be null");
156146
Assert.notNull(defaultOptions, "defaultOptions must not be null");
@@ -424,6 +414,8 @@ else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOp
424414
throw new IllegalArgumentException("model cannot be null or empty");
425415
}
426416

417+
ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());
418+
427419
return new Prompt(prompt.getInstructions(), requestOptions);
428420
}
429421

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import org.springframework.ai.chat.prompt.ChatOptions;
2222
import org.springframework.ai.chat.prompt.Prompt;
23+
import org.springframework.ai.model.function.FunctionCallback;
2324
import org.springframework.ai.model.tool.ToolCallingChatOptions;
2425
import org.springframework.ai.ollama.api.OllamaApi;
2526
import org.springframework.ai.ollama.api.OllamaOptions;
@@ -48,7 +49,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() {
4849
.internalToolExecutionEnabled(true)
4950
.toolCallbacks(new TestToolCallback("tool1"), new TestToolCallback("tool2"))
5051
.toolNames("tool1", "tool2")
51-
.toolContext(Map.of("key1", "value1"))
52+
.toolContext(Map.of("key1", "value1", "key2", "valueA"))
5253
.build();
5354
OllamaChatModel chatModel = OllamaChatModel.builder()
5455
.ollamaApi(new OllamaApi())
@@ -59,17 +60,19 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() {
5960
.internalToolExecutionEnabled(false)
6061
.toolCallbacks(new TestToolCallback("tool3"), new TestToolCallback("tool4"))
6162
.toolNames("tool3")
62-
.toolContext(Map.of("key2", "value2"))
63+
.toolContext(Map.of("key2", "valueB"))
6364
.build();
6465
Prompt prompt = chatModel.buildRequestPrompt(new Prompt("Test message content", runtimeOptions));
6566

6667
assertThat(((ToolCallingChatOptions) prompt.getOptions())).isNotNull();
6768
assertThat(((ToolCallingChatOptions) prompt.getOptions()).isInternalToolExecutionEnabled()).isFalse();
68-
assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(4);
69-
assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolNames()).containsExactlyInAnyOrder("tool1",
70-
"tool2", "tool3");
69+
assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(2);
70+
assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()
71+
.stream()
72+
.map(FunctionCallback::getName)).containsExactlyInAnyOrder("tool3", "tool4");
73+
assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolNames()).containsExactlyInAnyOrder("tool3");
7174
assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolContext()).containsEntry("key1", "value1")
72-
.containsEntry("key2", "value2");
75+
.containsEntry("key2", "valueB");
7376
}
7477

7578
@Test

0 commit comments

Comments
 (0)