Skip to content

Commit bf9df71

Browse files
committed
VertexAiGeminiChatModel: support InternalToolExecutionMaxAttempts
Signed-off-by: lambochen <[email protected]>
1 parent f808731 commit bf9df71

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java

Lines changed: 19 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -136,10 +136,12 @@
136136
* @author Jihoon Kim
137137
* @author Alexandros Pappas
138138
* @author Ilayaperumal Gopinathan
139+
* @author lambochen
139140
* @since 0.8.1
140141
* @see VertexAiGeminiChatOptions
141142
* @see ToolCallingManager
142143
* @see ChatModel
144+
* @see ToolCallingChatOptions
143145
*/
144146
public class VertexAiGeminiChatModel implements ChatModel, DisposableBean {
145147

@@ -389,10 +391,10 @@ private static Schema jsonToSchema(String json) {
389391
@Override
390392
public ChatResponse call(Prompt prompt) {
391393
var requestPrompt = this.buildRequestPrompt(prompt);
392-
return this.internalCall(requestPrompt, null);
394+
return this.internalCall(requestPrompt, null, 1);
393395
}
394396

395-
private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
397+
private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse, int attempts) {
396398

397399
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
398400
.prompt(prompt)
@@ -425,7 +427,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon
425427
return chatResponse;
426428
}));
427429

428-
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
430+
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, attempts)) {
429431
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
430432
if (toolExecutionResult.returnDirect()) {
431433
// Return tool execution result directly to the client.
@@ -437,7 +439,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon
437439
else {
438440
// Send the tool execution result back to the model.
439441
return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
440-
response);
442+
response, attempts + 1);
441443
}
442444
}
443445

@@ -469,6 +471,11 @@ Prompt buildRequestPrompt(Prompt prompt) {
469471
requestOptions.setInternalToolExecutionEnabled(
470472
ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(),
471473
this.defaultOptions.getInternalToolExecutionEnabled()));
474+
requestOptions.setInternalToolExecutionMaxAttempts(
475+
ModelOptionsUtils.mergeOption(
476+
runtimeOptions.getInternalToolExecutionMaxAttempts(),
477+
this.defaultOptions.getInternalToolExecutionMaxAttempts())
478+
);
472479
requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),
473480
this.defaultOptions.getToolNames()));
474481
requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(),
@@ -483,6 +490,7 @@ Prompt buildRequestPrompt(Prompt prompt) {
483490
}
484491
else {
485492
requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
493+
requestOptions.setInternalToolExecutionMaxAttempts(this.defaultOptions.getInternalToolExecutionMaxAttempts());
486494
requestOptions.setToolNames(this.defaultOptions.getToolNames());
487495
requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
488496
requestOptions.setToolContext(this.defaultOptions.getToolContext());
@@ -499,10 +507,10 @@ Prompt buildRequestPrompt(Prompt prompt) {
499507
@Override
500508
public Flux<ChatResponse> stream(Prompt prompt) {
501509
var requestPrompt = this.buildRequestPrompt(prompt);
502-
return this.internalStream(requestPrompt, null);
510+
return this.internalStream(requestPrompt, null, 1);
503511
}
504512

505-
public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
513+
public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse, int attempts) {
506514
return Flux.deferContextual(contextView -> {
507515

508516
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
@@ -538,7 +546,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
538546

539547
// @formatter:off
540548
Flux<ChatResponse> flux = chatResponseFlux.flatMap(response -> {
541-
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
549+
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, attempts)) {
542550
// FIXME: bounded elastic needs to be used since tool calling
543551
// is currently only synchronous
544552
return Flux.defer(() -> {
@@ -551,7 +559,10 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
551559
}
552560
else {
553561
// Send the tool execution result back to the model.
554-
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), response);
562+
return this.internalStream(
563+
new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
564+
response,
565+
attempts + 1);
555566
}
556567
}).subscribeOn(Schedulers.boundedElastic());
557568
}

0 commit comments

Comments
 (0)