Skip to content

Commit b902ca2

Browse files
ThomasVitaletzolov
authored andcommitted
Advancing Tool Support - Part 4
* Adopted new tool calling logic in OllamaChatModel, while maintaining full API backward compatibility thanks to the LegacyToolCallingManager. * Improved efficiency and robustness of merging options in prompts for Ollama. * Update Ollama Autoconfiguration to use the new ToolCallingManager. * Improved troubleshooting for new tool calling APIs and finalised changes for full backward compatibility. * Updated Ollama Testcontainers dependency to 0.5.7. Relates to gh-2049 Signed-off-by: Thomas Vitale <[email protected]>
1 parent 76ab91f commit b902ca2

File tree

31 files changed

+1269
-151
lines changed

31 files changed

+1269
-151
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 141 additions & 70 deletions
Large diffs are not rendered by default.

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java

Lines changed: 117 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
1717
package org.springframework.ai.ollama.api;
1818

1919
import java.util.ArrayList;
20+
import java.util.Arrays;
2021
import java.util.HashSet;
2122
import java.util.List;
2223
import java.util.Map;
@@ -32,7 +33,8 @@
3233
import org.springframework.ai.embedding.EmbeddingOptions;
3334
import org.springframework.ai.model.ModelOptionsUtils;
3435
import org.springframework.ai.model.function.FunctionCallback;
35-
import org.springframework.ai.model.function.FunctionCallingOptions;
36+
import org.springframework.ai.model.tool.ToolCallingChatOptions;
37+
import org.springframework.lang.Nullable;
3638
import org.springframework.util.Assert;
3739

3840
/**
@@ -48,7 +50,7 @@
4850
* @see <a href="https://github.com/ollama/ollama/blob/main/api/types.go">Ollama Types</a>
4951
*/
5052
@JsonInclude(Include.NON_NULL)
51-
public class OllamaOptions implements FunctionCallingOptions, EmbeddingOptions {
53+
public class OllamaOptions implements ToolCallingChatOptions, EmbeddingOptions {
5254

5355
private static final List<String> NON_SUPPORTED_FIELDS = List.of("model", "format", "keep_alive", "truncate");
5456

@@ -305,28 +307,28 @@ public class OllamaOptions implements FunctionCallingOptions, EmbeddingOptions {
305307
@JsonProperty("truncate")
306308
private Boolean truncate;
307309

310+
@JsonIgnore
311+
private Boolean internalToolExecutionEnabled;
312+
308313
/**
309314
* Tool Function Callbacks to register with the ChatModel.
310315
* For Prompt Options the functionCallbacks are automatically enabled for the duration of the prompt execution.
311316
* For Default Options the functionCallbacks are registered but disabled by default. Use the enableFunctions to set the functions
312317
* from the registry to be used by the ChatModel chat completion requests.
313318
*/
314319
@JsonIgnore
315-
private List<FunctionCallback> functionCallbacks = new ArrayList<>();
320+
private List<FunctionCallback> toolCallbacks = new ArrayList<>();
316321

317322
/**
318323
* List of functions, identified by their names, to configure for function calling in
319324
* the chat completion requests.
320325
* Functions with those names must exist in the functionCallbacks registry.
321-
* The {@link #functionCallbacks} from the PromptOptions are automatically enabled for the duration of the prompt execution.
326+
* The {@link #toolCallbacks} from the PromptOptions are automatically enabled for the duration of the prompt execution.
322327
* Note that function enabled with the default options are enabled for all chat completion requests. This could impact the token count and the billing.
323328
* If the functions is set in a prompt options, then the enabled functions are only active for the duration of this prompt execution.
324329
*/
325330
@JsonIgnore
326-
private Set<String> functions = new HashSet<>();
327-
328-
@JsonIgnore
329-
private Boolean proxyToolCalls;
331+
private Set<String> toolNames = new HashSet<>();
330332

331333
@JsonIgnore
332334
private Map<String, Object> toolContext;
@@ -381,9 +383,9 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) {
381383
.mirostatEta(fromOptions.getMirostatEta())
382384
.penalizeNewline(fromOptions.getPenalizeNewline())
383385
.stop(fromOptions.getStop())
384-
.functions(fromOptions.getFunctions())
385-
.proxyToolCalls(fromOptions.getProxyToolCalls())
386-
.functionCallbacks(fromOptions.getFunctionCallbacks())
386+
.tools(fromOptions.getTools())
387+
.internalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled())
388+
.toolCallbacks(fromOptions.getToolCallbacks())
387389
.toolContext(fromOptions.getToolContext()).build();
388390
}
389391

@@ -683,23 +685,73 @@ public void setTruncate(Boolean truncate) {
683685
}
684686

685687
@Override
688+
@JsonIgnore
689+
public List<FunctionCallback> getToolCallbacks() {
690+
return this.toolCallbacks;
691+
}
692+
693+
@Override
694+
@JsonIgnore
695+
public void setToolCallbacks(List<FunctionCallback> toolCallbacks) {
696+
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
697+
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
698+
this.toolCallbacks = toolCallbacks;
699+
}
700+
701+
@Override
702+
@JsonIgnore
703+
public Set<String> getTools() {
704+
return this.toolNames;
705+
}
706+
707+
@Override
708+
@JsonIgnore
709+
public void setTools(Set<String> toolNames) {
710+
Assert.notNull(toolNames, "toolNames cannot be null");
711+
Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
712+
toolNames.forEach(tool -> Assert.hasText(tool, "toolNames cannot contain empty elements"));
713+
this.toolNames = toolNames;
714+
}
715+
716+
@Override
717+
@Nullable
718+
@JsonIgnore
719+
public Boolean isInternalToolExecutionEnabled() {
720+
return internalToolExecutionEnabled;
721+
}
722+
723+
@Override
724+
@JsonIgnore
725+
public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) {
726+
this.internalToolExecutionEnabled = internalToolExecutionEnabled;
727+
}
728+
729+
@Override
730+
@Deprecated
731+
@JsonIgnore
686732
public List<FunctionCallback> getFunctionCallbacks() {
687-
return this.functionCallbacks;
733+
return this.getToolCallbacks();
688734
}
689735

690736
@Override
737+
@Deprecated
738+
@JsonIgnore
691739
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
692-
this.functionCallbacks = functionCallbacks;
740+
this.setToolCallbacks(functionCallbacks);
693741
}
694742

695743
@Override
744+
@Deprecated
745+
@JsonIgnore
696746
public Set<String> getFunctions() {
697-
return this.functions;
747+
return this.getTools();
698748
}
699749

700750
@Override
751+
@Deprecated
752+
@JsonIgnore
701753
public void setFunctions(Set<String> functions) {
702-
this.functions = functions;
754+
this.setTools(functions);
703755
}
704756

705757
@Override
@@ -709,20 +761,26 @@ public Integer getDimensions() {
709761
}
710762

711763
@Override
764+
@Deprecated
765+
@JsonIgnore
712766
public Boolean getProxyToolCalls() {
713-
return this.proxyToolCalls;
767+
return this.internalToolExecutionEnabled != null ? !this.internalToolExecutionEnabled : null;
714768
}
715769

770+
@Deprecated
771+
@JsonIgnore
716772
public void setProxyToolCalls(Boolean proxyToolCalls) {
717-
this.proxyToolCalls = proxyToolCalls;
773+
this.internalToolExecutionEnabled = proxyToolCalls != null ? !proxyToolCalls : null;
718774
}
719775

720776
@Override
777+
@JsonIgnore
721778
public Map<String, Object> getToolContext() {
722779
return this.toolContext;
723780
}
724781

725782
@Override
783+
@JsonIgnore
726784
public void setToolContext(Map<String, Object> toolContext) {
727785
this.toolContext = toolContext;
728786
}
@@ -769,9 +827,9 @@ public boolean equals(Object o) {
769827
&& Objects.equals(this.mirostat, that.mirostat) && Objects.equals(this.mirostatTau, that.mirostatTau)
770828
&& Objects.equals(this.mirostatEta, that.mirostatEta)
771829
&& Objects.equals(this.penalizeNewline, that.penalizeNewline) && Objects.equals(this.stop, that.stop)
772-
&& Objects.equals(this.functionCallbacks, that.functionCallbacks)
773-
&& Objects.equals(this.proxyToolCalls, that.proxyToolCalls)
774-
&& Objects.equals(this.functions, that.functions) && Objects.equals(this.toolContext, that.toolContext);
830+
&& Objects.equals(this.toolCallbacks, that.toolCallbacks)
831+
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
832+
&& Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.toolContext, that.toolContext);
775833
}
776834

777835
@Override
@@ -781,7 +839,7 @@ public int hashCode() {
781839
this.useMMap, this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict, this.topK,
782840
this.topP, this.tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty,
783841
this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta,
784-
this.penalizeNewline, this.stop, this.functionCallbacks, this.functions, this.proxyToolCalls,
842+
this.penalizeNewline, this.stop, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled,
785843
this.toolContext);
786844
}
787845

@@ -959,25 +1017,53 @@ public Builder stop(List<String> stop) {
9591017
return this;
9601018
}
9611019

962-
public Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {
963-
this.options.functionCallbacks = functionCallbacks;
1020+
public Builder toolCallbacks(List<FunctionCallback> toolCallbacks) {
1021+
this.options.setToolCallbacks(toolCallbacks);
9641022
return this;
9651023
}
9661024

967-
public Builder functions(Set<String> functions) {
968-
Assert.notNull(functions, "Function names must not be null");
969-
this.options.functions = functions;
1025+
public Builder toolCallbacks(FunctionCallback... toolCallbacks) {
1026+
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
1027+
this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks));
9701028
return this;
9711029
}
9721030

973-
public Builder function(String functionName) {
974-
Assert.hasText(functionName, "Function name must not be empty");
975-
this.options.functions.add(functionName);
1031+
public Builder tools(Set<String> toolNames) {
1032+
this.options.setTools(toolNames);
1033+
return this;
1034+
}
1035+
1036+
public Builder tools(String... toolNames) {
1037+
Assert.notNull(toolNames, "toolNames cannot be null");
1038+
this.options.toolNames.addAll(Set.of(toolNames));
1039+
return this;
1040+
}
1041+
1042+
public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) {
1043+
this.options.setInternalToolExecutionEnabled(internalToolExecutionEnabled);
9761044
return this;
9771045
}
9781046

1047+
@Deprecated
1048+
public Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {
1049+
return toolCallbacks(functionCallbacks);
1050+
}
1051+
1052+
@Deprecated
1053+
public Builder functions(Set<String> functions) {
1054+
return tools(functions);
1055+
}
1056+
1057+
@Deprecated
1058+
public Builder function(String functionName) {
1059+
return tools(functionName);
1060+
}
1061+
1062+
@Deprecated
9791063
public Builder proxyToolCalls(Boolean proxyToolCalls) {
980-
this.options.proxyToolCalls = proxyToolCalls;
1064+
if (proxyToolCalls != null) {
1065+
this.options.setInternalToolExecutionEnabled(!proxyToolCalls);
1066+
}
9811067
return this;
9821068
}
9831069

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -29,17 +29,18 @@
2929
* @author Christian Tzolov
3030
* @author Thomas Vitale
3131
*/
32-
public class OllamaChatRequestTests {
32+
class OllamaChatRequestTests {
3333

3434
OllamaChatModel chatModel = OllamaChatModel.builder()
3535
.ollamaApi(new OllamaApi())
3636
.defaultOptions(OllamaOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build())
3737
.build();
3838

3939
@Test
40-
public void createRequestWithDefaultOptions() {
40+
void createRequestWithDefaultOptions() {
41+
var prompt = this.chatModel.buildRequestPrompt(new Prompt("Test message content"));
4142

42-
var request = this.chatModel.ollamaChatRequest(new Prompt("Test message content"), false);
43+
var request = this.chatModel.ollamaChatRequest(prompt, false);
4344

4445
assertThat(request.messages()).hasSize(1);
4546
assertThat(request.stream()).isFalse();
@@ -52,12 +53,12 @@ public void createRequestWithDefaultOptions() {
5253
}
5354

5455
@Test
55-
public void createRequestWithPromptOllamaOptions() {
56-
56+
void createRequestWithPromptOllamaOptions() {
5757
// Runtime options should override the default options.
5858
OllamaOptions promptOptions = OllamaOptions.builder().temperature(0.8).topP(0.5).numGPU(2).build();
59+
var prompt = this.chatModel.buildRequestPrompt(new Prompt("Test message content", promptOptions));
5960

60-
var request = this.chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true);
61+
var request = this.chatModel.ollamaChatRequest(prompt, true);
6162

6263
assertThat(request.messages()).hasSize(1);
6364
assertThat(request.stream()).isTrue();
@@ -74,11 +75,11 @@ public void createRequestWithPromptOllamaOptions() {
7475

7576
@Test
7677
public void createRequestWithPromptPortableChatOptions() {
77-
7878
// Ollama runtime options.
7979
ChatOptions portablePromptOptions = ChatOptions.builder().temperature(0.9).topK(100).topP(0.6).build();
80+
var prompt = this.chatModel.buildRequestPrompt(new Prompt("Test message content", portablePromptOptions));
8081

81-
var request = this.chatModel.ollamaChatRequest(new Prompt("Test message content", portablePromptOptions), true);
82+
var request = this.chatModel.ollamaChatRequest(prompt, true);
8283

8384
assertThat(request.messages()).hasSize(1);
8485
assertThat(request.stream()).isTrue();
@@ -92,31 +93,33 @@ public void createRequestWithPromptPortableChatOptions() {
9293

9394
@Test
9495
public void createRequestWithPromptOptionsModelOverride() {
95-
9696
// Ollama runtime options.
9797
OllamaOptions promptOptions = OllamaOptions.builder().model("PROMPT_MODEL").build();
98+
var prompt = this.chatModel.buildRequestPrompt(new Prompt("Test message content", promptOptions));
9899

99-
var request = this.chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true);
100+
var request = this.chatModel.ollamaChatRequest(prompt, true);
100101

101102
assertThat(request.model()).isEqualTo("PROMPT_MODEL");
102103
}
103104

104105
@Test
105106
public void createRequestWithDefaultOptionsModelOverride() {
106-
107107
OllamaChatModel chatModel = OllamaChatModel.builder()
108108
.ollamaApi(new OllamaApi())
109109
.defaultOptions(OllamaOptions.builder().model("DEFAULT_OPTIONS_MODEL").build())
110110
.build();
111111

112-
var request = chatModel.ollamaChatRequest(new Prompt("Test message content"), true);
112+
var prompt1 = chatModel.buildRequestPrompt(new Prompt("Test message content"));
113+
114+
var request = chatModel.ollamaChatRequest(prompt1, true);
113115

114116
assertThat(request.model()).isEqualTo("DEFAULT_OPTIONS_MODEL");
115117

116118
// Prompt options should override the default options.
117119
OllamaOptions promptOptions = OllamaOptions.builder().model("PROMPT_MODEL").build();
120+
var prompt2 = chatModel.buildRequestPrompt(new Prompt("Test message content", promptOptions));
118121

119-
request = chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true);
122+
request = chatModel.ollamaChatRequest(prompt2, true);
120123

121124
assertThat(request.model()).isEqualTo("PROMPT_MODEL");
122125
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class OllamaWithOpenAiChatModelIT {
7676
private static final String DEFAULT_OLLAMA_MODEL = "mistral";
7777

7878
@Container
79-
static OllamaContainer ollamaContainer = new OllamaContainer("ollama/ollama:0.5.1");
79+
static OllamaContainer ollamaContainer = new OllamaContainer("ollama/ollama:0.5.7");
8080

8181
static String baseUrl = "http://localhost:11434";
8282

spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,7 @@ interface Builder {
290290
Builder defaultFunctions(String... functionNames);
291291

292292
/**
293-
* @deprecated in favor of {@link #defaultTools(FunctionCallback...)} or
294-
* {@link #defaultToolCallbacks(FunctionCallback...)}
293+
* @deprecated in favor of {@link #defaultTools(Object...)}
295294
*/
296295
@Deprecated
297296
Builder defaultFunctions(FunctionCallback... functionCallbacks);

0 commit comments

Comments
 (0)