Skip to content

Commit 670d691

Browse files
committed
OpenAiChatModel support internalToolExecutionMaxAttempts
Signed-off-by: lambochen <[email protected]>
1 parent e243d42 commit 670d691

File tree

7 files changed

+42
-14
lines changed

7 files changed

+42
-14
lines changed

models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ public static MiniMaxChatOptions fromOptions(MiniMaxChatOptions fromOptions) {
180180
.toolCallbacks(fromOptions.getToolCallbacks())
181181
.toolNames(fromOptions.getToolNames())
182182
.internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled())
183-
.internalToolExecutionMaxAttempts(fromOptions.getInternalToolExecutionMaxAttempts())
183+
.internalToolExecutionMaxAttempts(fromOptions.getInternalToolExecutionMaxAttempts())
184184
.toolContext(fromOptions.getToolContext())
185185
.build();
186186
}
@@ -395,8 +395,8 @@ public int hashCode() {
395395
result = prime * result + ((this.toolNames == null) ? 0 : this.toolNames.hashCode());
396396
result = prime * result
397397
+ ((this.internalToolExecutionEnabled == null) ? 0 : this.internalToolExecutionEnabled.hashCode());
398-
result = prime * result
399-
+ ((this.internalToolExecutionMaxAttempts == null) ? 0 : this.internalToolExecutionMaxAttempts.hashCode());
398+
result = prime * result + ((this.internalToolExecutionMaxAttempts == null) ? 0
399+
: this.internalToolExecutionMaxAttempts.hashCode());
400400
result = prime * result + ((this.toolContext == null) ? 0 : this.toolContext.hashCode());
401401
return result;
402402
}

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,11 @@
105105
* @author Alexandros Pappas
106106
* @author Soby Chacko
107107
* @author Jonghoon Park
108+
* @author lambochen
108109
* @see ChatModel
109110
* @see StreamingChatModel
110111
* @see OpenAiApi
112+
* @see ToolCallingChatOptions
111113
*/
112114
public class OpenAiChatModel implements ChatModel {
113115

@@ -178,10 +180,10 @@ public ChatResponse call(Prompt prompt) {
178180
// Before moving any further, build the final request Prompt,
179181
// merging runtime and default options.
180182
Prompt requestPrompt = buildRequestPrompt(prompt);
181-
return this.internalCall(requestPrompt, null);
183+
return this.internalCall(requestPrompt, null, 0);
182184
}
183185

184-
public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
186+
public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse, int attempts) {
185187

186188
ChatCompletionRequest request = createRequest(prompt, false);
187189

@@ -240,7 +242,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
240242

241243
});
242244

243-
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
245+
attempts++;
246+
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, attempts)) {
244247
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
245248
if (toolExecutionResult.returnDirect()) {
246249
// Return tool execution result directly to the client.
@@ -252,7 +255,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
252255
else {
253256
// Send the tool execution result back to the model.
254257
return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
255-
response);
258+
response, attempts);
256259
}
257260
}
258261

@@ -520,6 +523,9 @@ Prompt buildRequestPrompt(Prompt prompt) {
520523
requestOptions.setInternalToolExecutionEnabled(
521524
ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(),
522525
this.defaultOptions.getInternalToolExecutionEnabled()));
526+
requestOptions.setInternalToolExecutionMaxAttempts(
527+
ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionMaxAttempts(),
528+
this.defaultOptions.getInternalToolExecutionMaxAttempts()));
523529
requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),
524530
this.defaultOptions.getToolNames()));
525531
requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(),
@@ -530,6 +536,8 @@ Prompt buildRequestPrompt(Prompt prompt) {
530536
else {
531537
requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders());
532538
requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
539+
requestOptions
540+
.setInternalToolExecutionMaxAttempts(this.defaultOptions.getInternalToolExecutionMaxAttempts());
533541
requestOptions.setToolNames(this.defaultOptions.getToolNames());
534542
requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
535543
requestOptions.setToolContext(this.defaultOptions.getToolContext());

models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -492,10 +492,10 @@ public ToolCallingChatOptions merge(ChatOptions options) {
492492
builder.internalToolExecutionEnabled(toolCallingChatOptions.getInternalToolExecutionEnabled() != null
493493
? (toolCallingChatOptions).getInternalToolExecutionEnabled()
494494
: this.getInternalToolExecutionEnabled());
495-
builder.internalToolExecutionMaxAttempts(
496-
toolCallingChatOptions.getInternalToolExecutionMaxAttempts() != null
497-
? toolCallingChatOptions.getInternalToolExecutionMaxAttempts()
498-
: this.getInternalToolExecutionMaxAttempts());
495+
builder
496+
.internalToolExecutionMaxAttempts(toolCallingChatOptions.getInternalToolExecutionMaxAttempts() != null
497+
? toolCallingChatOptions.getInternalToolExecutionMaxAttempts()
498+
: this.getInternalToolExecutionMaxAttempts());
499499

500500
Set<String> toolNames = new HashSet<>();
501501
if (this.toolNames != null) {

spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,8 @@ public ToolCallingChatOptions.Builder internalToolExecutionEnabled(
293293
}
294294

295295
@Override
296-
public ToolCallingChatOptions.Builder internalToolExecutionMaxAttempts(Integer internalToolExecutionMaxAttempts) {
296+
public ToolCallingChatOptions.Builder internalToolExecutionMaxAttempts(
297+
Integer internalToolExecutionMaxAttempts) {
297298
this.options.setInternalToolExecutionMaxAttempts(internalToolExecutionMaxAttempts);
298299
return this;
299300
}

spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ public interface ToolCallingChatOptions extends ChatOptions {
4848
* No limit for tool execution attempts.
4949
*/
5050
int TOOL_EXECUTION_NO_LIMIT = Integer.MAX_VALUE;
51+
5152
int DEFAULT_TOOL_EXECUTION_MAX_ATTEMPTS = TOOL_EXECUTION_NO_LIMIT;
5253

5354
/**

spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityChecker.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ default boolean isToolExecutionRequired(ChatOptions promptOptions, ChatResponse
4545
}
4646

4747
/**
48-
* Determines if tool execution should be performed based on the prompt options and chat response and attempts.
48+
* Determines if tool execution should be performed based on the prompt options and
49+
* chat response and attempts.
4950
* @param promptOptions The options from the prompt
5051
* @param chatResponse The response from the chat model
5152
* @param attempts The number of attempts to execute the tool
@@ -104,7 +105,8 @@ default boolean isInternalToolExecutionEnabled(ChatOptions chatOptions, int atte
104105
if (chatOptions instanceof ToolCallingChatOptions toolCallingChatOptions) {
105106
return toolCallingChatOptions.getInternalToolExecutionMaxAttempts() == null
106107
|| attempts <= toolCallingChatOptions.getInternalToolExecutionMaxAttempts();
107-
} else {
108+
}
109+
else {
108110
internalToolExecutionEnabled = true;
109111
}
110112
return internalToolExecutionEnabled;

spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicate.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,20 @@ default boolean isToolExecutionRequired(ChatOptions promptOptions, ChatResponse
4343
return test(promptOptions, chatResponse);
4444
}
4545

46+
default boolean isToolExecutionRequired(ChatOptions promptOptions, ChatResponse chatResponse, int attempts) {
47+
boolean isToolExecutionRequired = isToolExecutionRequired(promptOptions, chatResponse);
48+
if (!isToolExecutionRequired) {
49+
return true;
50+
}
51+
52+
if (promptOptions instanceof ToolCallingChatOptions toolCallingChatOptions) {
53+
return toolCallingChatOptions.getInternalToolExecutionMaxAttempts() == null
54+
|| attempts <= toolCallingChatOptions.getInternalToolExecutionMaxAttempts();
55+
}
56+
else {
57+
isToolExecutionRequired = true;
58+
}
59+
return isToolExecutionRequired;
60+
}
61+
4662
}

0 commit comments

Comments
 (0)