Skip to content

Commit 93efce5

Browse files
committed
AzureOpenAiChatModel support InternalToolExecutionMaxAttempts
Signed-off-by: lambochen <[email protected]>
1 parent 253d19e commit 93efce5

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@
122122
* @author Berjan Jonker
123123
* @author Andres da Silva Santos
124124
* @author Bart Veenstra
125+
* @author lambochen
125126
* @see ChatModel
126127
* @see com.azure.ai.openai.OpenAIClient
127128
* @since 1.0.0
@@ -247,10 +248,10 @@ public ChatResponse call(Prompt prompt) {
247248
// Before moving any further, build the final request Prompt,
248249
// merging runtime and default options.
249250
Prompt requestPrompt = buildRequestPrompt(prompt);
250-
return this.internalCall(requestPrompt, null);
251+
return this.internalCall(requestPrompt, null, 1);
251252
}
252253

253-
public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
254+
public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse, int attempts) {
254255

255256
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
256257
.prompt(prompt)
@@ -270,7 +271,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
270271
return chatResponse;
271272
});
272273

273-
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
274+
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, attempts)) {
274275
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
275276
if (toolExecutionResult.returnDirect()) {
276277
// Return tool execution result directly to the client.
@@ -282,7 +283,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
282283
else {
283284
// Send the tool execution result back to the model.
284285
return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
285-
response);
286+
response, attempts + 1);
286287
}
287288
}
288289

@@ -294,10 +295,10 @@ public Flux<ChatResponse> stream(Prompt prompt) {
294295
// Before moving any further, build the final request Prompt,
295296
// merging runtime and default options.
296297
Prompt requestPrompt = buildRequestPrompt(prompt);
297-
return this.internalStream(requestPrompt, null);
298+
return this.internalStream(requestPrompt, null, 1);
298299
}
299300

300-
public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
301+
public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse, int attempts) {
301302

302303
return Flux.deferContextual(contextView -> {
303304
ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
@@ -377,7 +378,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
377378
});
378379

379380
return chatResponseFlux.flatMap(chatResponse -> {
380-
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)) {
381+
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse, attempts)) {
381382
// FIXME: bounded elastic needs to be used since tool calling
382383
// is currently only synchronous
383384
return Flux.defer(() -> {
@@ -393,7 +394,8 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
393394
// Send the tool execution result back to the model.
394395
return this.internalStream(
395396
new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
396-
chatResponse);
397+
chatResponse,
398+
attempts + 1);
397399
}
398400
}).subscribeOn(Schedulers.boundedElastic());
399401
}
@@ -666,6 +668,12 @@ Prompt buildRequestPrompt(Prompt prompt) {
666668
requestOptions.setInternalToolExecutionEnabled(
667669
ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(),
668670
this.defaultOptions.getInternalToolExecutionEnabled()));
671+
runtimeOptions.setInternalToolExecutionMaxAttempts(
672+
ModelOptionsUtils.mergeOption(
673+
runtimeOptions.getInternalToolExecutionMaxAttempts(),
674+
this.defaultOptions.getInternalToolExecutionMaxAttempts()
675+
)
676+
);
669677
requestOptions.setStreamUsage(ModelOptionsUtils.mergeOption(runtimeOptions.getStreamUsage(),
670678
this.defaultOptions.getStreamUsage()));
671679
requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),
@@ -677,6 +685,7 @@ Prompt buildRequestPrompt(Prompt prompt) {
677685
}
678686
else {
679687
requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
688+
requestOptions.setInternalToolExecutionMaxAttempts(this.defaultOptions.getInternalToolExecutionMaxAttempts());
680689
requestOptions.setStreamUsage(this.defaultOptions.getStreamUsage());
681690
requestOptions.setToolNames(this.defaultOptions.getToolNames());
682691
requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());

0 commit comments

Comments
 (0)