Skip to content

Commit b55b494

Browse files
committed
OpenAI - Adopt ToolCallingManager API
- Update OpenAiChatModel to use the new ToolCallingManager API, while ensuring full API backward compatibility. - Introduce Builder to instantiate a new OpenAiChatModel since the number of overloaded constructors is growing too big. - Update documentation about tool calling and OpenAI support for that. - Add extra validation to ensure the uniqueness of tool names when aggregated from different sources. Signed-off-by: Thomas Vitale <[email protected]>
1 parent 1fee082 commit b55b494

File tree

18 files changed

+582
-191
lines changed

18 files changed

+582
-191
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,10 @@ public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
129129
@Nullable FunctionCallbackResolver functionCallbackResolver,
130130
@Nullable List<FunctionCallback> toolFunctionCallbacks, ObservationRegistry observationRegistry,
131131
ModelManagementOptions modelManagementOptions) {
132-
super(functionCallbackResolver, defaultOptions, toolFunctionCallbacks);
132+
// We do not pass the 'defaultOptions' to the AbstractToolSupport,
133+
// because it modifies them. We are using ToolCallingManager instead,
134+
// so we just pass empty options here.
135+
super(functionCallbackResolver, OllamaOptions.builder().build(), toolFunctionCallbacks);
133136
Assert.notNull(ollamaApi, "ollamaApi must not be null");
134137
Assert.notNull(defaultOptions, "defaultOptions must not be null");
135138
Assert.notNull(observationRegistry, "observationRegistry must not be null");
@@ -147,10 +150,9 @@ public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
147150

148151
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
149152
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
150-
// We do not pass the 'defaultOptions' to the AbstractToolSupport, because it
151-
// modifies them.
152-
// We are not using the AbstractToolSupport class in this path, so we just pass
153-
// empty options.
153+
// We do not pass the 'defaultOptions' to the AbstractToolSupport,
154+
// because it modifies them. We are using ToolCallingManager instead,
155+
// so we just pass empty options here.
154156
super(null, OllamaOptions.builder().build(), List.of());
155157
Assert.notNull(ollamaApi, "ollamaApi must not be null");
156158
Assert.notNull(defaultOptions, "defaultOptions must not be null");
@@ -424,6 +426,8 @@ else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOp
424426
throw new IllegalArgumentException("model cannot be null or empty");
425427
}
426428

429+
ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());
430+
427431
return new Prompt(prompt.getInstructions(), requestOptions);
428432
}
429433

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 243 additions & 67 deletions
Large diffs are not rendered by default.

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java

Lines changed: 122 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.ai.openai;
1818

1919
import java.util.ArrayList;
20+
import java.util.Arrays;
2021
import java.util.HashMap;
2122
import java.util.HashSet;
2223
import java.util.List;
@@ -31,12 +32,14 @@
3132

3233
import org.springframework.ai.model.ModelOptionsUtils;
3334
import org.springframework.ai.model.function.FunctionCallback;
34-
import org.springframework.ai.model.function.FunctionCallingOptions;
35+
import org.springframework.ai.model.tool.ToolCallingChatOptions;
3536
import org.springframework.ai.openai.api.OpenAiApi;
3637
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters;
3738
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.StreamOptions;
3839
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoiceBuilder;
3940
import org.springframework.ai.openai.api.ResponseFormat;
41+
import org.springframework.ai.tool.ToolCallback;
42+
import org.springframework.lang.Nullable;
4043
import org.springframework.util.Assert;
4144

4245
/**
@@ -49,7 +52,7 @@
4952
* @since 0.8.0
5053
*/
5154
@JsonInclude(Include.NON_NULL)
52-
public class OpenAiChatOptions implements FunctionCallingOptions {
55+
public class OpenAiChatOptions implements ToolCallingChatOptions {
5356

5457
// @formatter:off
5558
/**
@@ -192,33 +195,22 @@ public class OpenAiChatOptions implements FunctionCallingOptions {
192195
private @JsonProperty("reasoning_effort") String reasoningEffort;
193196

194197
/**
195-
* OpenAI Tool Function Callbacks to register with the ChatModel.
196-
* For Prompt Options the functionCallbacks are automatically enabled for the duration of the prompt execution.
197-
* For Default Options the functionCallbacks are registered but disabled by default. Use the enableFunctions to set the functions
198-
* from the registry to be used by the ChatModel chat completion requests.
198+
* Collection of {@link ToolCallback}s to be used for tool calling in the chat completion requests.
199199
*/
200200
@JsonIgnore
201-
private List<FunctionCallback> functionCallbacks = new ArrayList<>();
201+
private List<FunctionCallback> toolCallbacks = new ArrayList<>();
202202

203203
/**
204-
* List of functions, identified by their names, to configure for function calling in
205-
* the chat completion requests.
206-
* Functions with those names must exist in the functionCallbacks registry.
207-
* The {@link #functionCallbacks} from the PromptOptions are automatically enabled for the duration of the prompt execution.
208-
*
209-
* Note that function enabled with the default options are enabled for all chat completion requests. This could impact the token count and the billing.
210-
* If the functions is set in a prompt options, then the enabled functions are only active for the duration of this prompt execution.
204+
* Collection of tool names to be resolved at runtime and used for tool calling in the chat completion requests.
211205
*/
212206
@JsonIgnore
213-
private Set<String> functions = new HashSet<>();
207+
private Set<String> toolNames = new HashSet<>();
214208

215209
/**
216-
* If true, the Spring AI will not handle the function calls internally, but will proxy them to the client.
217-
* It is the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results.
218-
* If false, the Spring AI will handle the function calls internally.
210+
* Whether to enable the tool execution lifecycle internally in ChatModel.
219211
*/
220212
@JsonIgnore
221-
private Boolean proxyToolCalls;
213+
private Boolean internalToolExecutionEnabled;
222214

223215
/**
224216
* Optional HTTP headers to be added to the chat completion request.
@@ -227,7 +219,7 @@ public class OpenAiChatOptions implements FunctionCallingOptions {
227219
private Map<String, String> httpHeaders = new HashMap<>();
228220

229221
@JsonIgnore
230-
private Map<String, Object> toolContext;
222+
private Map<String, Object> toolContext = new HashMap<>();
231223

232224
// @formatter:on
233225

@@ -258,10 +250,10 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) {
258250
.toolChoice(fromOptions.getToolChoice())
259251
.user(fromOptions.getUser())
260252
.parallelToolCalls(fromOptions.getParallelToolCalls())
261-
.functionCallbacks(fromOptions.getFunctionCallbacks())
262-
.functions(fromOptions.getFunctions())
253+
.toolCallbacks(fromOptions.getFunctionCallbacks())
254+
.toolNames(fromOptions.getFunctions())
263255
.httpHeaders(fromOptions.getHttpHeaders())
264-
.proxyToolCalls(fromOptions.getProxyToolCalls())
256+
.internalToolExecutionEnabled(fromOptions.getProxyToolCalls())
265257
.toolContext(fromOptions.getToolContext())
266258
.store(fromOptions.getStore())
267259
.metadata(fromOptions.getMetadata())
@@ -447,12 +439,16 @@ public void setToolChoice(Object toolChoice) {
447439
}
448440

449441
@Override
442+
@Deprecated
443+
@JsonIgnore
450444
public Boolean getProxyToolCalls() {
451-
return this.proxyToolCalls;
445+
return this.internalToolExecutionEnabled != null ? !this.internalToolExecutionEnabled : null;
452446
}
453447

448+
@Deprecated
449+
@JsonIgnore
454450
public void setProxyToolCalls(Boolean proxyToolCalls) {
455-
this.proxyToolCalls = proxyToolCalls;
451+
this.internalToolExecutionEnabled = proxyToolCalls != null ? !proxyToolCalls : null;
456452
}
457453

458454
public String getUser() {
@@ -472,22 +468,73 @@ public void setParallelToolCalls(Boolean parallelToolCalls) {
472468
}
473469

474470
@Override
471+
@JsonIgnore
472+
public List<FunctionCallback> getToolCallbacks() {
473+
return this.toolCallbacks;
474+
}
475+
476+
@Override
477+
@JsonIgnore
478+
public void setToolCallbacks(List<FunctionCallback> toolCallbacks) {
479+
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
480+
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
481+
this.toolCallbacks = toolCallbacks;
482+
}
483+
484+
@Override
485+
@JsonIgnore
486+
public Set<String> getToolNames() {
487+
return this.toolNames;
488+
}
489+
490+
@Override
491+
@JsonIgnore
492+
public void setToolNames(Set<String> toolNames) {
493+
Assert.notNull(toolNames, "toolNames cannot be null");
494+
Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
495+
toolNames.forEach(tool -> Assert.hasText(tool, "toolNames cannot contain empty elements"));
496+
this.toolNames = toolNames;
497+
}
498+
499+
@Override
500+
@Nullable
501+
@JsonIgnore
502+
public Boolean isInternalToolExecutionEnabled() {
503+
return internalToolExecutionEnabled;
504+
}
505+
506+
@Override
507+
@JsonIgnore
508+
public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) {
509+
this.internalToolExecutionEnabled = internalToolExecutionEnabled;
510+
}
511+
512+
@Override
513+
@Deprecated
514+
@JsonIgnore
475515
public List<FunctionCallback> getFunctionCallbacks() {
476-
return this.functionCallbacks;
516+
return this.getToolCallbacks();
477517
}
478518

479519
@Override
520+
@Deprecated
521+
@JsonIgnore
480522
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
481-
this.functionCallbacks = functionCallbacks;
523+
this.setToolCallbacks(functionCallbacks);
482524
}
483525

484526
@Override
527+
@Deprecated
528+
@JsonIgnore
485529
public Set<String> getFunctions() {
486-
return this.functions;
530+
return this.getToolNames();
487531
}
488532

533+
@Override
534+
@Deprecated
535+
@JsonIgnore
489536
public void setFunctions(Set<String> functionNames) {
490-
this.functions = functionNames;
537+
this.setToolNames(functionNames);
491538
}
492539

493540
public Map<String, String> getHttpHeaders() {
@@ -505,11 +552,13 @@ public Integer getTopK() {
505552
}
506553

507554
@Override
555+
@JsonIgnore
508556
public Map<String, Object> getToolContext() {
509557
return this.toolContext;
510558
}
511559

512560
@Override
561+
@JsonIgnore
513562
public void setToolContext(Map<String, Object> toolContext) {
514563
this.toolContext = toolContext;
515564
}
@@ -548,9 +597,9 @@ public int hashCode() {
548597
return Objects.hash(this.model, this.frequencyPenalty, this.logitBias, this.logprobs, this.topLogprobs,
549598
this.maxTokens, this.maxCompletionTokens, this.n, this.presencePenalty, this.responseFormat,
550599
this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice,
551-
this.user, this.parallelToolCalls, this.functionCallbacks, this.functions, this.httpHeaders,
552-
this.proxyToolCalls, this.toolContext, this.outputModalities, this.outputAudio, this.store,
553-
this.metadata, this.reasoningEffort);
600+
this.user, this.parallelToolCalls, this.toolCallbacks, this.toolNames, this.httpHeaders,
601+
this.internalToolExecutionEnabled, this.toolContext, this.outputModalities, this.outputAudio,
602+
this.store, this.metadata, this.reasoningEffort);
554603
}
555604

556605
@Override
@@ -574,11 +623,11 @@ public boolean equals(Object o) {
574623
&& Objects.equals(this.topP, other.topP) && Objects.equals(this.tools, other.tools)
575624
&& Objects.equals(this.toolChoice, other.toolChoice) && Objects.equals(this.user, other.user)
576625
&& Objects.equals(this.parallelToolCalls, other.parallelToolCalls)
577-
&& Objects.equals(this.functionCallbacks, other.functionCallbacks)
578-
&& Objects.equals(this.functions, other.functions)
626+
&& Objects.equals(this.toolCallbacks, other.toolCallbacks)
627+
&& Objects.equals(this.toolNames, other.toolNames)
579628
&& Objects.equals(this.httpHeaders, other.httpHeaders)
580629
&& Objects.equals(this.toolContext, other.toolContext)
581-
&& Objects.equals(this.proxyToolCalls, other.proxyToolCalls)
630+
&& Objects.equals(this.internalToolExecutionEnabled, other.internalToolExecutionEnabled)
582631
&& Objects.equals(this.outputModalities, other.outputModalities)
583632
&& Objects.equals(this.outputAudio, other.outputAudio) && Objects.equals(this.store, other.store)
584633
&& Objects.equals(this.metadata, other.metadata)
@@ -712,25 +761,54 @@ public Builder parallelToolCalls(Boolean parallelToolCalls) {
712761
return this;
713762
}
714763

715-
public Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {
716-
this.options.functionCallbacks = functionCallbacks;
764+
public Builder toolCallbacks(List<FunctionCallback> toolCallbacks) {
765+
this.options.setToolCallbacks(toolCallbacks);
717766
return this;
718767
}
719768

720-
public Builder functions(Set<String> functionNames) {
721-
Assert.notNull(functionNames, "Function names must not be null");
722-
this.options.functions = functionNames;
769+
public Builder toolCallbacks(FunctionCallback... toolCallbacks) {
770+
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
771+
this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks));
723772
return this;
724773
}
725774

726-
public Builder function(String functionName) {
727-
Assert.hasText(functionName, "Function name must not be empty");
728-
this.options.functions.add(functionName);
775+
public Builder toolNames(Set<String> toolNames) {
776+
Assert.notNull(toolNames, "toolNames cannot be null");
777+
this.options.setToolNames(toolNames);
778+
return this;
779+
}
780+
781+
public Builder toolNames(String... toolNames) {
782+
Assert.notNull(toolNames, "toolNames cannot be null");
783+
this.options.toolNames.addAll(Set.of(toolNames));
729784
return this;
730785
}
731786

787+
public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) {
788+
this.options.setInternalToolExecutionEnabled(internalToolExecutionEnabled);
789+
return this;
790+
}
791+
792+
@Deprecated
793+
public Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {
794+
return toolCallbacks(functionCallbacks);
795+
}
796+
797+
@Deprecated
798+
public Builder functions(Set<String> functionNames) {
799+
return toolNames(functionNames);
800+
}
801+
802+
@Deprecated
803+
public Builder function(String functionName) {
804+
return toolNames(functionName);
805+
}
806+
807+
@Deprecated
732808
public Builder proxyToolCalls(Boolean proxyToolCalls) {
733-
this.options.proxyToolCalls = proxyToolCalls;
809+
if (proxyToolCalls != null) {
810+
this.options.setInternalToolExecutionEnabled(!proxyToolCalls);
811+
}
734812
return this;
735813
}
736814

0 commit comments

Comments
 (0)