Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
* backed by {@link DeepSeekApi}.
*
* @author Geng Rong
* @last Updated By : Kuntal Maity
*/
public class DeepSeekChatModel implements ChatModel {

Expand Down Expand Up @@ -193,8 +194,11 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
}).toList();

// Current usage
DeepSeekApi.Usage usage = completionEntity.getBody().usage();
Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage();
DeepSeekApi.Usage usage = null;
if (completionEntity != null && completionEntity.getBody() != null) {
usage = chatCompletion.usage();
}
Usage currentChatResponseUsage = toUsageOrEmpty(usage);
Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage,
previousChatResponse);
ChatResponse chatResponse = new ChatResponse(generations,
Expand All @@ -216,6 +220,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
.build();
}
else {
// Reset tool choice to AUTO to prevent forcing repeated tool calls.
resetToolChoiceToAuto(prompt);
// Send the tool execution result back to the model.
return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
response);
Expand Down Expand Up @@ -272,7 +278,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
return buildGeneration(choice, metadata);
}).toList();
DeepSeekApi.Usage usage = chatCompletion2.usage();
Usage currentUsage = (usage != null) ? getDefaultUsage(usage) : new EmptyUsage();
Usage currentUsage = toUsageOrEmpty(usage);
Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(currentUsage, previousChatResponse);

return new ChatResponse(generations, from(chatCompletion2, cumulativeUsage));
Expand Down Expand Up @@ -305,6 +311,8 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
.build());
}
else {
// Reset tool choice to AUTO to prevent forcing repeated tool calls.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This duplicates the code at line 223. You can extract the duplicated lines into a separate method.

resetToolChoiceToAuto(prompt);
// Send the tool execution result back to the model.
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
response);
Expand Down Expand Up @@ -390,6 +398,28 @@ private DefaultUsage getDefaultUsage(DeepSeekApi.Usage usage) {
return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage);
}

/**
 * Adapt a possibly-absent {@link DeepSeekApi.Usage} into a {@link Usage} that callers
 * can use without a null check.
 * <p>
 * A {@code null} argument yields a fresh {@link EmptyUsage}; any other value is
 * delegated to {@code getDefaultUsage}.
 * @param usage the usage reported by the API, may be {@code null}
 * @return a {@link Usage} instance, never {@code null}
 * @author Kuntal Maity
 */
private Usage toUsageOrEmpty(DeepSeekApi.Usage usage) {
	if (usage == null) {
		// Nothing was reported: substitute the empty placeholder.
		return new EmptyUsage();
	}
	return getDefaultUsage(usage);
}

/**
 * Switch the tool choice carried by the prompt's options back to AUTO, so that a
 * follow-up request does not keep forcing the same tool call.
 * <p>
 * The options object is modified in place. When the prompt's options are not a
 * {@link DeepSeekChatOptions} instance (including when they are {@code null}) this
 * method does nothing.
 * @param prompt the prompt whose options should be updated
 * @author Kuntal Maity
 */
private void resetToolChoiceToAuto(Prompt prompt) {
	if (!(prompt.getOptions() instanceof DeepSeekChatOptions deepSeekOptions)) {
		// Other option types carry no DeepSeek tool choice to reset.
		return;
	}
	deepSeekOptions.setToolChoice(ChatCompletionRequest.ToolChoiceBuilder.AUTO);
}

Prompt buildRequestPrompt(Prompt prompt) {
DeepSeekChatOptions runtimeOptions = null;
if (prompt.getOptions() != null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.deepseek;

import java.time.Instant;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.deepseek.api.DeepSeekApi;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletion;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletion.Choice;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionFinishReason;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.ChatCompletionFunction;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.ToolCall;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionRequest;
import org.springframework.ai.tool.function.FunctionToolCallback;
import org.springframework.http.ResponseEntity;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
 * Checks that {@link DeepSeekChatModel} sends tool_choice AUTO on the follow-up request
 * that carries a tool result back to the model (returnDirect=false), so that a forced
 * tool choice cannot cause an endless sequence of tool calls.
 *
 * @author : kuntal maity
 */
class DeepSeekChatModelToolChoiceResetTests {

	/** Name shared by the registered tool, the forced tool choice and the mocked tool call. */
	private static final String TOOL_NAME = "getMarineYetiDescription";

	@Test
	void resetsToolChoiceToAutoOnToolResultPushback() {
		// Arrange: the mocked API answers the first request with a tool call and every
		// later request with a plain assistant message.
		DeepSeekApi api = mock(DeepSeekApi.class);

		// Every request handed to the API is recorded so the second one can be inspected.
		ArgumentCaptor<ChatCompletionRequest> reqCaptor = ArgumentCaptor.forClass(ChatCompletionRequest.class);

		AtomicInteger apiCalls = new AtomicInteger(0);
		when(api.chatCompletionEntity(reqCaptor.capture()))
			.thenAnswer(invocation -> (apiCalls.incrementAndGet() == 1) ? toolCallResponse() : finalTextResponse());

		// The tool counts its own invocations; returnDirect defaults to false.
		AtomicInteger toolInvocations = new AtomicInteger(0);
		var tool = FunctionToolCallback.builder(TOOL_NAME, () -> {
			toolInvocations.incrementAndGet();
			return "Marine yeti is orange";
		}).build();

		// Default options force the model towards the tool on the first request.
		DeepSeekChatOptions options = DeepSeekChatOptions.builder()
			.model(DeepSeekApi.ChatModel.DEEPSEEK_CHAT)
			.toolCallbacks(List.of(tool))
			.toolChoice(ChatCompletionRequest.ToolChoiceBuilder.FUNCTION(TOOL_NAME))
			.build();

		DeepSeekChatModel model = DeepSeekChatModel.builder().deepSeekApi(api).defaultOptions(options).build();

		// Act
		ChatResponse response = model.call(new Prompt("What is the color of a marine yeti?"));

		// Assert: exactly two API round trips (tool call, then final text)
		assertThat(apiCalls.get()).isEqualTo(2);
		// The follow-up request must carry tool_choice AUTO
		List<ChatCompletionRequest> capturedRequests = reqCaptor.getAllValues();
		assertThat(capturedRequests).hasSize(2);
		Object secondToolChoice = capturedRequests.get(1).toolChoice();
		assertThat(secondToolChoice).isEqualTo(ChatCompletionRequest.ToolChoiceBuilder.AUTO);
		// The tool ran once and only once
		assertThat(toolInvocations.get()).isEqualTo(1);
		// The final answer is ordinary text
		assertThat(response.getResult().getOutput().getText()).containsIgnoringCase("orange");
	}

	/**
	 * First canned API response: an assistant message with empty content that requests
	 * a single call of the tool.
	 */
	private static ResponseEntity<ChatCompletion> toolCallResponse() {
		ToolCall toolCall = new ToolCall("call_1", "function", new ChatCompletionFunction(TOOL_NAME, "{}"));
		ChatCompletionMessage message = new ChatCompletionMessage("", // content
				ChatCompletionMessage.Role.ASSISTANT, null, null, List.of(toolCall), null, null);
		return ResponseEntity.ok(completionOf("id-1", ChatCompletionFinishReason.TOOL_CALLS, message));
	}

	/**
	 * Second canned API response: a normal assistant text message with no tool calls.
	 */
	private static ResponseEntity<ChatCompletion> finalTextResponse() {
		ChatCompletionMessage message = new ChatCompletionMessage("Marine yeti is orange.",
				ChatCompletionMessage.Role.ASSISTANT, null, null, null, null, null);
		return ResponseEntity.ok(completionOf("id-2", ChatCompletionFinishReason.STOP, message));
	}

	/**
	 * Wrap one message in a single-choice {@link ChatCompletion} stamped with the
	 * current epoch second and the DEEPSEEK_CHAT model name.
	 */
	private static ChatCompletion completionOf(String id, ChatCompletionFinishReason finishReason,
			ChatCompletionMessage message) {
		Choice onlyChoice = new Choice(finishReason, 0, message, null);
		return new ChatCompletion(id, List.of(onlyChoice), Instant.now().getEpochSecond(),
				DeepSeekApi.ChatModel.DEEPSEEK_CHAT.getName(), null, "chat.completion", null);
	}

}