Skip to content

Commit 029da4f

Browse files
committed
Fix support for FunctionCallingOptions across AI models
Add support for PortableFunctionCallingOptions across AI models - Modify FunctionCallingOptions interface to extend ChatOptions for better integration - Refactor option handling in chat models to accommodate both ChatOptions and FunctionCallingOptions - Implement handling of FunctionCallingOptions in Anthropic, Azure OpenAI, MistralAI, Ollama, OpenAI, VertexAI Gemini, and other models - Update existing function calling tests to use new FunctionCallingOptions. Resolves #624
1 parent 6fc76b7 commit 029da4f

File tree

17 files changed

+242
-40
lines changed

17 files changed

+242
-40
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import org.springframework.ai.model.ModelOptionsUtils;
5454
import org.springframework.ai.model.function.FunctionCallback;
5555
import org.springframework.ai.model.function.FunctionCallbackContext;
56+
import org.springframework.ai.model.function.FunctionCallingOptions;
5657
import org.springframework.ai.retry.RetryUtils;
5758
import org.springframework.http.ResponseEntity;
5859
import org.springframework.retry.support.RetryTemplate;
@@ -413,8 +414,15 @@ else if (message.getMessageType() == MessageType.TOOL) {
413414
systemPrompt, this.defaultOptions.getMaxTokens(), this.defaultOptions.getTemperature(), stream);
414415

415416
if (prompt.getOptions() != null) {
416-
AnthropicChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
417-
ChatOptions.class, AnthropicChatOptions.class);
417+
AnthropicChatOptions updatedRuntimeOptions;
418+
if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
419+
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions,
420+
FunctionCallingOptions.class, AnthropicChatOptions.class);
421+
}
422+
else {
423+
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
424+
AnthropicChatOptions.class);
425+
}
418426

419427
functionsForThisRequest.addAll(this.runtimeFunctionCallbackConfigurations(updatedRuntimeOptions));
420428

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.Optional;
2626
import java.util.Set;
2727
import java.util.concurrent.atomic.AtomicBoolean;
28+
import java.util.function.Function;
2829

2930
import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage;
3031
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -47,6 +48,7 @@
4748
import org.springframework.ai.model.ModelOptionsUtils;
4849
import org.springframework.ai.model.function.FunctionCallback;
4950
import org.springframework.ai.model.function.FunctionCallbackContext;
51+
import org.springframework.ai.model.function.FunctionCallingOptions;
5052
import org.springframework.util.Assert;
5153
import org.springframework.util.CollectionUtils;
5254

@@ -286,8 +288,15 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
286288
functionsForThisRequest.addAll(this.defaultOptions.getFunctions());
287289

288290
if (prompt.getOptions() != null) {
289-
AzureOpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
290-
ChatOptions.class, AzureOpenAiChatOptions.class);
291+
AzureOpenAiChatOptions updatedRuntimeOptions;
292+
if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
293+
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions,
294+
FunctionCallingOptions.class, AzureOpenAiChatOptions.class);
295+
}
296+
else {
297+
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
298+
AzureOpenAiChatOptions.class);
299+
}
291300
options = this.merge(updatedRuntimeOptions, options);
292301

293302
functionsForThisRequest.addAll(this.runtimeFunctionCallbackConfigurations(updatedRuntimeOptions));

models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.springframework.ai.model.ModelOptionsUtils;
4545
import org.springframework.ai.model.function.FunctionCallback;
4646
import org.springframework.ai.model.function.FunctionCallbackContext;
47+
import org.springframework.ai.model.function.FunctionCallingOptions;
4748
import org.springframework.ai.retry.RetryUtils;
4849
import org.springframework.http.ResponseEntity;
4950
import org.springframework.retry.support.RetryTemplate;
@@ -391,8 +392,16 @@ else if (message.getMessageType() == MessageType.TOOL) {
391392
Set<String> enabledToolsToUse = new HashSet<>();
392393

393394
if (prompt.getOptions() != null) {
394-
MiniMaxChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
395-
ChatOptions.class, MiniMaxChatOptions.class);
395+
MiniMaxChatOptions updatedRuntimeOptions;
396+
397+
if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
398+
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions,
399+
FunctionCallingOptions.class, MiniMaxChatOptions.class);
400+
}
401+
else {
402+
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
403+
MiniMaxChatOptions.class);
404+
}
396405

397406
enabledToolsToUse.addAll(this.runtimeFunctionCallbackConfigurations(updatedRuntimeOptions));
398407

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import org.springframework.ai.model.ModelOptionsUtils;
5353
import org.springframework.ai.model.function.FunctionCallback;
5454
import org.springframework.ai.model.function.FunctionCallbackContext;
55+
import org.springframework.ai.model.function.FunctionCallingOptions;
5556
import org.springframework.ai.retry.RetryUtils;
5657
import org.springframework.http.ResponseEntity;
5758
import org.springframework.retry.support.RetryTemplate;
@@ -367,8 +368,16 @@ else if (message instanceof ToolResponseMessage toolResponseMessage) {
367368
request = ModelOptionsUtils.merge(request, this.defaultOptions, MistralAiApi.ChatCompletionRequest.class);
368369

369370
if (prompt.getOptions() != null) {
370-
var updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
371-
MistralAiChatOptions.class);
371+
MistralAiChatOptions updatedRuntimeOptions;
372+
373+
if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
374+
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions,
375+
FunctionCallingOptions.class, MistralAiChatOptions.class);
376+
}
377+
else {
378+
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
379+
MistralAiChatOptions.class);
380+
}
372381

373382
functionsForThisRequest.addAll(this.runtimeFunctionCallbackConfigurations(updatedRuntimeOptions));
374383

models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.springframework.ai.model.ModelOptionsUtils;
3434
import org.springframework.ai.model.function.FunctionCallback;
3535
import org.springframework.ai.model.function.FunctionCallbackContext;
36+
import org.springframework.ai.model.function.FunctionCallingOptions;
3637
import org.springframework.ai.moonshot.api.MoonshotApi;
3738
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletion;
3839
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletion.Choice;
@@ -341,9 +342,16 @@ else if (message.getMessageType() == MessageType.TOOL) {
341342
Set<String> enabledToolsToUse = new HashSet<>();
342343

343344
if (prompt.getOptions() != null) {
344-
MoonshotChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
345-
ChatOptions.class, MoonshotChatOptions.class);
345+
MoonshotChatOptions updatedRuntimeOptions;
346346

347+
if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
348+
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions,
349+
FunctionCallingOptions.class, MoonshotChatOptions.class);
350+
}
351+
else {
352+
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
353+
MoonshotChatOptions.class);
354+
}
347355
enabledToolsToUse.addAll(this.runtimeFunctionCallbackConfigurations(updatedRuntimeOptions));
348356

349357
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class);

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import org.springframework.ai.model.ModelOptionsUtils;
4242
import org.springframework.ai.model.function.FunctionCallback;
4343
import org.springframework.ai.model.function.FunctionCallbackContext;
44+
import org.springframework.ai.model.function.FunctionCallingOptions;
4445
import org.springframework.ai.ollama.api.OllamaApi;
4546
import org.springframework.ai.ollama.api.OllamaApi.ChatRequest;
4647
import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
@@ -297,8 +298,14 @@ else if (message instanceof ToolResponseMessage toolMessage) {
297298
// runtime options
298299
OllamaOptions runtimeOptions = null;
299300
if (prompt.getOptions() != null) {
300-
runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
301-
OllamaOptions.class);
301+
if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
302+
runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class,
303+
OllamaOptions.class);
304+
}
305+
else {
306+
runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
307+
OllamaOptions.class);
308+
}
302309
functionsForThisRequest.addAll(this.runtimeFunctionCallbackConfigurations(runtimeOptions));
303310
}
304311

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
import org.springframework.ai.model.ModelOptionsUtils;
5252
import org.springframework.ai.model.function.FunctionCallback;
5353
import org.springframework.ai.model.function.FunctionCallbackContext;
54+
import org.springframework.ai.model.function.FunctionCallingOptions;
5455
import org.springframework.ai.openai.api.OpenAiApi;
5556
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
5657
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion.Choice;
@@ -477,8 +478,16 @@ else if (message.getMessageType() == MessageType.TOOL) {
477478
Set<String> enabledToolsToUse = new HashSet<>();
478479

479480
if (prompt.getOptions() != null) {
480-
OpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
481-
ChatOptions.class, OpenAiChatOptions.class);
481+
OpenAiChatOptions updatedRuntimeOptions = null;
482+
483+
if (prompt.getOptions() instanceof FunctionCallingOptions) {
484+
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(((FunctionCallingOptions) prompt.getOptions()),
485+
FunctionCallingOptions.class, OpenAiChatOptions.class);
486+
}
487+
else if (prompt.getOptions() instanceof OpenAiChatOptions) {
488+
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
489+
OpenAiChatOptions.class);
490+
}
482491

483492
enabledToolsToUse.addAll(this.runtimeFunctionCallbackConfigurations(updatedRuntimeOptions));
484493

models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -181,15 +181,9 @@ public ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
181181
}
182182

183183
if (prompt.getOptions() != null) {
184-
if (prompt.getOptions() != null) {
185-
var updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
186-
QianFanChatOptions.class);
187-
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class);
188-
}
189-
else {
190-
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
191-
+ prompt.getOptions().getClass().getSimpleName());
192-
}
184+
var updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
185+
QianFanChatOptions.class);
186+
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class);
193187
}
194188
return request;
195189
}

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import org.springframework.ai.model.ModelOptionsUtils;
4343
import org.springframework.ai.model.function.FunctionCallback;
4444
import org.springframework.ai.model.function.FunctionCallbackContext;
45+
import org.springframework.ai.model.function.FunctionCallingOptions;
4546
import org.springframework.ai.vertexai.gemini.metadata.VertexAiUsage;
4647
import org.springframework.beans.factory.DisposableBean;
4748
import org.springframework.lang.NonNull;
@@ -80,8 +81,6 @@
8081
*/
8182
public class VertexAiGeminiChatModel extends AbstractToolCallSupport implements ChatModel, DisposableBean {
8283

83-
private final static boolean IS_RUNTIME_CALL = true;
84-
8584
private final VertexAI vertexAI;
8685

8786
private final VertexAiGeminiChatOptions defaultOptions;
@@ -292,9 +291,15 @@ GeminiRequest createGeminiRequest(Prompt prompt) {
292291
VertexAiGeminiChatOptions updatedRuntimeOptions = VertexAiGeminiChatOptions.builder().build();
293292

294293
if (prompt.getOptions() != null) {
295-
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
296-
VertexAiGeminiChatOptions.class);
294+
if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
295+
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions,
296+
FunctionCallingOptions.class, VertexAiGeminiChatOptions.class);
297297

298+
}
299+
else {
300+
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
301+
VertexAiGeminiChatOptions.class);
302+
}
298303
functionsForThisRequest.addAll(runtimeFunctionCallbackConfigurations(updatedRuntimeOptions));
299304
}
300305

models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.springframework.ai.model.ModelOptionsUtils;
3434
import org.springframework.ai.model.function.FunctionCallback;
3535
import org.springframework.ai.model.function.FunctionCallbackContext;
36+
import org.springframework.ai.model.function.FunctionCallingOptions;
3637
import org.springframework.ai.retry.RetryUtils;
3738
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
3839
import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletion;
@@ -358,8 +359,15 @@ else if (message.getMessageType() == MessageType.TOOL) {
358359
Set<String> enabledToolsToUse = new HashSet<>();
359360

360361
if (prompt.getOptions() != null) {
361-
ZhiPuAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
362-
ChatOptions.class, ZhiPuAiChatOptions.class);
362+
ZhiPuAiChatOptions updatedRuntimeOptions;
363+
if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
364+
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions,
365+
FunctionCallingOptions.class, ZhiPuAiChatOptions.class);
366+
}
367+
else {
368+
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
369+
ZhiPuAiChatOptions.class);
370+
}
363371

364372
enabledToolsToUse.addAll(this.runtimeFunctionCallbackConfigurations(updatedRuntimeOptions));
365373

0 commit comments

Comments
 (0)