Skip to content

Commit 2d128ba

Browse files
committed
OpenAiChatModel support InternalToolExecutionMaxAttempts
Signed-off-by: lambochen <[email protected]>
1 parent 092eb6c commit 2d128ba

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ public ChatResponse call(Prompt prompt) {
180180
// Before moving any further, build the final request Prompt,
181181
// merging runtime and default options.
182182
Prompt requestPrompt = buildRequestPrompt(prompt);
183-
return this.internalCall(requestPrompt, null, 0);
183+
return this.internalCall(requestPrompt, null, 1);
184184
}
185185

186186
public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse, int attempts) {
@@ -242,7 +242,6 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
242242

243243
});
244244

245-
attempts++;
246245
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, attempts)) {
247246
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
248247
if (toolExecutionResult.returnDirect()) {
@@ -255,7 +254,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
255254
else {
256255
// Send the tool execution result back to the model.
257256
return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
258-
response, attempts);
257+
response, attempts + 1);
259258
}
260259
}
261260

@@ -267,10 +266,10 @@ public Flux<ChatResponse> stream(Prompt prompt) {
267266
// Before moving any further, build the final request Prompt,
268267
// merging runtime and default options.
269268
Prompt requestPrompt = buildRequestPrompt(prompt);
270-
return internalStream(requestPrompt, null);
269+
return internalStream(requestPrompt, null, 1);
271270
}
272271

273-
public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
272+
public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse, int attempts) {
274273
return Flux.deferContextual(contextView -> {
275274
ChatCompletionRequest request = createRequest(prompt, true);
276275

@@ -365,7 +364,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
365364

366365
// @formatter:off
367366
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
368-
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
367+
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, attempts)) {
369368
return Flux.defer(() -> {
370369
// FIXME: bounded elastic needs to be used since tool calling
371370
// is currently only synchronous
@@ -379,7 +378,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
379378
else {
380379
// Send the tool execution result back to the model.
381380
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
382-
response);
381+
response, attempts + 1);
383382
}
384383
}).subscribeOn(Schedulers.boundedElastic());
385384
}

0 commit comments

Comments
 (0)