Skip to content

Commit 253d19e

Browse files
committed
AnthropicChatModel: support InternalToolExecutionMaxAttempts
Signed-off-by: lambochen <[email protected]>
1 parent 2d128ba commit 253d19e

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
* @author Alexandros Pappas
9191
* @author Jonghoon Park
9292
* @author Soby Chacko
93+
* @author lambochen
9394
* @since 1.0.0
9495
*/
9596
public class AnthropicChatModel implements ChatModel {
@@ -170,10 +171,10 @@ public ChatResponse call(Prompt prompt) {
170171
// Before moving any further, build the final request Prompt,
171172
// merging runtime and default options.
172173
Prompt requestPrompt = buildRequestPrompt(prompt);
173-
return this.internalCall(requestPrompt, null);
174+
return this.internalCall(requestPrompt, null, 1);
174175
}
175176

176-
public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
177+
public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse, int attempts) {
177178
ChatCompletionRequest request = createRequest(prompt, false);
178179

179180
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
@@ -203,7 +204,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
203204
return chatResponse;
204205
});
205206

206-
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
207+
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, attempts)) {
207208
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
208209
if (toolExecutionResult.returnDirect()) {
209210
// Return tool execution result directly to the client.
@@ -215,7 +216,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
215216
else {
216217
// Send the tool execution result back to the model.
217218
return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
218-
response);
219+
response, attempts + 1);
219220
}
220221
}
221222

@@ -232,10 +233,10 @@ public Flux<ChatResponse> stream(Prompt prompt) {
232233
// Before moving any further, build the final request Prompt,
233234
// merging runtime and default options.
234235
Prompt requestPrompt = buildRequestPrompt(prompt);
235-
return this.internalStream(requestPrompt, null);
236+
return this.internalStream(requestPrompt, null, 1);
236237
}
237238

238-
public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
239+
public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse, int attempts) {
239240
return Flux.deferContextual(contextView -> {
240241
ChatCompletionRequest request = createRequest(prompt, true);
241242

@@ -260,7 +261,8 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
260261
Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
261262
ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage);
262263

263-
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse) && chatResponse.hasFinishReasons(Set.of("tool_use"))) {
264+
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse, attempts)
265+
&& chatResponse.hasFinishReasons(Set.of("tool_use"))) {
264266
// FIXME: bounded elastic needs to be used since tool calling
265267
// is currently only synchronous
266268
return Flux.defer(() -> {
@@ -274,7 +276,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
274276
else {
275277
// Send the tool execution result back to the model.
276278
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
277-
chatResponse);
279+
chatResponse, attempts + 1);
278280
}
279281
}).subscribeOn(Schedulers.boundedElastic());
280282
}
@@ -437,6 +439,11 @@ Prompt buildRequestPrompt(Prompt prompt) {
437439
requestOptions.setInternalToolExecutionEnabled(
438440
ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(),
439441
this.defaultOptions.getInternalToolExecutionEnabled()));
442+
requestOptions.setInternalToolExecutionMaxAttempts(
443+
ModelOptionsUtils.mergeOption(
444+
runtimeOptions.getInternalToolExecutionMaxAttempts(),
445+
defaultOptions.getInternalToolExecutionMaxAttempts())
446+
);
440447
requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),
441448
this.defaultOptions.getToolNames()));
442449
requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(),
@@ -447,6 +454,7 @@ Prompt buildRequestPrompt(Prompt prompt) {
447454
else {
448455
requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders());
449456
requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
457+
requestOptions.setInternalToolExecutionMaxAttempts(this.defaultOptions.getInternalToolExecutionMaxAttempts());
450458
requestOptions.setToolNames(this.defaultOptions.getToolNames());
451459
requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
452460
requestOptions.setToolContext(this.defaultOptions.getToolContext());

0 commit comments

Comments
 (0)