Skip to content

Commit 9d3f39a

Browse files
committed
feat: support reasoning content for deepseek-reasoner and add docs
Signed-off-by: GR <[email protected]>
1 parent ffc7fb3 commit 9d3f39a

File tree

8 files changed

+226
-53
lines changed

8 files changed

+226
-53
lines changed
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package org.springframework.ai.deepseek;
2+
3+
import org.springframework.ai.chat.messages.AssistantMessage;
4+
import org.springframework.ai.content.Media;
5+
6+
import java.util.List;
7+
import java.util.Map;
8+
import java.util.Objects;
9+
10+
public class DeepSeekAssistantMessage extends AssistantMessage {
11+
12+
private Boolean prefix;
13+
14+
private String reasoningContent;
15+
16+
public DeepSeekAssistantMessage(String content) {
17+
super(content);
18+
}
19+
20+
public DeepSeekAssistantMessage(String content, String reasoningContent) {
21+
super(content);
22+
this.reasoningContent = reasoningContent;
23+
}
24+
25+
public DeepSeekAssistantMessage(String content, Map<String, Object> properties) {
26+
super(content, properties);
27+
}
28+
29+
public DeepSeekAssistantMessage(String content, Map<String, Object> properties, List<ToolCall> toolCalls) {
30+
super(content, properties, toolCalls);
31+
}
32+
33+
public DeepSeekAssistantMessage(String content, String reasoningContent, Map<String, Object> properties,
34+
List<ToolCall> toolCalls) {
35+
this(content, reasoningContent, properties, toolCalls, List.of());
36+
}
37+
38+
public DeepSeekAssistantMessage(String content, String reasoningContent, Map<String, Object> properties,
39+
List<ToolCall> toolCalls, List<Media> media) {
40+
super(content, properties, toolCalls, media);
41+
this.reasoningContent = reasoningContent;
42+
}
43+
44+
public static DeepSeekAssistantMessage prefixAssistantMessage(String context) {
45+
return prefixAssistantMessage(context, null);
46+
}
47+
48+
public static DeepSeekAssistantMessage prefixAssistantMessage(String context, String reasoningContent) {
49+
return new DeepSeekAssistantMessage(context, reasoningContent);
50+
}
51+
52+
public Boolean getPrefix() {
53+
return prefix;
54+
}
55+
56+
public void setPrefix(Boolean prefix) {
57+
this.prefix = prefix;
58+
}
59+
60+
public String getReasoningContent() {
61+
return reasoningContent;
62+
}
63+
64+
public void setReasoningContent(String reasoningContent) {
65+
this.reasoningContent = reasoningContent;
66+
}
67+
68+
@Override
69+
public boolean equals(Object o) {
70+
if (this == o) {
71+
return true;
72+
}
73+
if (!(o instanceof DeepSeekAssistantMessage that)) {
74+
return false;
75+
}
76+
if (!super.equals(o)) {
77+
return false;
78+
}
79+
return Objects.equals(this.reasoningContent, that.reasoningContent) && Objects.equals(this.prefix, that.prefix);
80+
}
81+
82+
@Override
83+
public int hashCode() {
84+
return Objects.hash(super.hashCode(), this.prefix, this.reasoningContent);
85+
}
86+
87+
@Override
88+
public String toString() {
89+
return "AssistantMessage [messageType=" + this.messageType + ", toolCalls=" + super.getToolCalls()
90+
+ ", textContent=" + this.textContent + ", reasoningContent=" + this.reasoningContent + ", prefix="
91+
+ this.prefix + ", metadata=" + this.metadata + "]";
92+
}
93+
94+
}

models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,10 @@ private Generation buildGeneration(Choice choice, Map<String, Object> metadata)
314314
var generationMetadataBuilder = ChatGenerationMetadata.builder().finishReason(finishReason);
315315

316316
String textContent = choice.message().content();
317+
String reasoningContent = choice.message().reasoningContent();
317318

318-
AssistantMessage assistantMessage = new AssistantMessage(textContent, metadata, toolCalls);
319+
DeepSeekAssistantMessage assistantMessage = new DeepSeekAssistantMessage(textContent, reasoningContent,
320+
metadata, toolCalls);
319321
return new Generation(assistantMessage, generationMetadataBuilder.build());
320322
}
321323

@@ -416,9 +418,13 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {
416418
return new ToolCall(toolCall.id(), toolCall.type(), function);
417419
}).toList();
418420
}
419-
boolean isPrefixAssistantMessage = message instanceof PrefixCompletionAssistantMessage;
421+
Boolean isPrefixAssistantMessage = null;
422+
if (message instanceof DeepSeekAssistantMessage
423+
&& Boolean.TRUE.equals(((DeepSeekAssistantMessage) message).getPrefix())) {
424+
isPrefixAssistantMessage = true;
425+
}
420426
return List.of(new ChatCompletionMessage(assistantMessage.getText(),
421-
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, isPrefixAssistantMessage));
427+
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, isPrefixAssistantMessage, null));
422428
}
423429
else if (message.getMessageType() == MessageType.TOOL) {
424430
ToolResponseMessage toolMessage = (ToolResponseMessage) message;

models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/PrefixCompletionAssistantMessage.java

Lines changed: 0 additions & 28 deletions
This file was deleted.

models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
import com.fasterxml.jackson.annotation.JsonInclude;
2222
import com.fasterxml.jackson.annotation.JsonInclude.Include;
2323
import com.fasterxml.jackson.annotation.JsonProperty;
24-
import org.springframework.ai.deepseek.PrefixCompletionAssistantMessage;
25-
import org.springframework.ai.deepseek.api.common.DeepSeekConstants;
2624
import org.springframework.ai.model.ApiKey;
2725
import org.springframework.ai.model.ChatModelDescription;
2826
import org.springframework.ai.model.ModelOptionsUtils;
@@ -32,7 +30,6 @@
3230
import org.springframework.http.MediaType;
3331
import org.springframework.http.ResponseEntity;
3432
import org.springframework.util.Assert;
35-
import org.springframework.util.CollectionUtils;
3633
import org.springframework.util.LinkedMultiValueMap;
3734
import org.springframework.util.MultiValueMap;
3835
import org.springframework.web.client.ResponseErrorHandler;
@@ -597,7 +594,8 @@ public record ChatCompletionMessage(// @formatter:off
597594
@JsonProperty("tool_call_id") String toolCallId,
598595
@JsonProperty("tool_calls")
599596
@JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) List<ToolCall> toolCalls,
600-
@JsonProperty("prefix") Boolean prefix
597+
@JsonProperty("prefix") Boolean prefix,
598+
@JsonProperty("reasoning_content") String reasoningContent
601599
) { // @formatter:on
602600

603601
/**
@@ -607,7 +605,7 @@ public record ChatCompletionMessage(// @formatter:off
607605
* @param role The role of the author of this message.
608606
*/
609607
public ChatCompletionMessage(Object content, Role role) {
610-
this(content, role, null, null, null, null);
608+
this(content, role, null, null, null, null, null);
611609
}
612610

613611
/**
@@ -621,7 +619,7 @@ public ChatCompletionMessage(Object content, Role role) {
621619
*/
622620
public ChatCompletionMessage(Object content, Role role, String name, String toolCallId,
623621
List<ToolCall> toolCalls) {
624-
this(content, role, name, toolCallId, toolCalls, null);
622+
this(content, role, name, toolCallId, toolCalls, null, null);
625623
}
626624

627625
/**

models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekChatModelFunctionCallingIT.java

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
package org.springframework.ai.deepseek.chat;
1818

19-
import org.junit.jupiter.api.Disabled;
2019
import org.junit.jupiter.api.Test;
20+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2121
import org.slf4j.Logger;
2222
import org.slf4j.LoggerFactory;
2323
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -48,8 +48,9 @@
4848
* @author Geng Rong
4949
*/
5050
@SpringBootTest(classes = DeepSeekTestConfiguration.class)
51-
@Disabled("the deepseek-chat model's Function Calling capability is unstable see: https://api-docs.deepseek.com/guides/function_calling")
52-
// @EnabledIfEnvironmentVariable(named = "DEEPSEEK_API_KEY", matches = ".+")
51+
// @Disabled("the deepseek-chat model's Function Calling capability is unstable see:
52+
// https://api-docs.deepseek.com/guides/function_calling")
53+
@EnabledIfEnvironmentVariable(named = "DEEPSEEK_API_KEY", matches = ".+")
5354
class DeepSeekChatModelFunctionCallingIT {
5455

5556
private static final Logger logger = LoggerFactory.getLogger(DeepSeekChatModelFunctionCallingIT.class);
@@ -155,8 +156,11 @@ public void toolFunctionCallWithUsage() {
155156
assertThat(chatResponse).isNotNull();
156157
assertThat(chatResponse.getResult().getOutput());
157158
assertThat(chatResponse.getResult().getOutput().getText()).contains("San Francisco");
158-
assertThat(chatResponse.getResult().getOutput().getText()).contains("30.0");
159-
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280);
159+
assertThat(chatResponse.getResult().getOutput().getText()).contains("30");
160+
// 这个 total token 是第一次 chat 以及 tool call 之后的两次请求 token 总和
161+
162+
// the total token is first chat and tool call request
163+
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(700).isGreaterThan(280);
160164
}
161165

162166
@Test
@@ -176,7 +180,7 @@ public void testStreamFunctionCallUsage() {
176180
assertThat(chatResponse).isNotNull();
177181
assertThat(chatResponse.getMetadata()).isNotNull();
178182
assertThat(chatResponse.getMetadata().getUsage()).isNotNull();
179-
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280);
183+
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(700).isGreaterThan(280);
180184
}
181185

182186
}

models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekChatModelIT.java

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,19 @@
3232
import org.springframework.ai.converter.BeanOutputConverter;
3333
import org.springframework.ai.converter.ListOutputConverter;
3434
import org.springframework.ai.converter.MapOutputConverter;
35+
import org.springframework.ai.deepseek.DeepSeekChatOptions;
3536
import org.springframework.ai.deepseek.DeepSeekTestConfiguration;
36-
import org.springframework.ai.deepseek.PrefixCompletionAssistantMessage;
37+
import org.springframework.ai.deepseek.DeepSeekAssistantMessage;
38+
import org.springframework.ai.deepseek.api.DeepSeekApi;
39+
import org.springframework.ai.deepseek.api.MockWeatherService;
40+
import org.springframework.ai.tool.function.FunctionToolCallback;
3741
import org.springframework.beans.factory.annotation.Autowired;
3842
import org.springframework.beans.factory.annotation.Value;
3943
import org.springframework.boot.test.context.SpringBootTest;
4044
import org.springframework.core.convert.support.DefaultConversionService;
4145
import org.springframework.core.io.Resource;
4246

43-
import java.util.Arrays;
44-
import java.util.List;
45-
import java.util.Map;
47+
import java.util.*;
4648
import java.util.stream.Collectors;
4749

4850
import static org.assertj.core.api.Assertions.assertThat;
@@ -71,7 +73,7 @@ void roleTest() {
7173
"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
7274
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
7375
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
74-
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
76+
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
7577
ChatResponse response = chatModel.call(prompt);
7678
assertThat(response.getResults()).hasSize(1);
7779
assertThat(response.getResults().get(0).getOutput().getText()).contains("Blackbeard");
@@ -181,7 +183,7 @@ void beanStreamOutputConverterRecords() {
181183
.map(ChatResponse::getResults)
182184
.flatMap(List::stream)
183185
.map(Generation::getOutput)
184-
.map(AssistantMessage::getText)
186+
.map(m -> m.getText() != null ? m.getText() : "")
185187
.collect(Collectors.joining());
186188

187189
ActorsFilmsRecord actorsFilms = outputConverter.convert(generationTextFromStream);
@@ -207,11 +209,56 @@ void prefixCompletionTest() {
207209
```
208210
""";
209211
UserMessage userMessage = new UserMessage(userMessageContent);
210-
Message assistantMessage = new PrefixCompletionAssistantMessage(
211-
"{\"code\":200,\"result\":{\"total\":1,\"data\":[1");
212+
Message assistantMessage = new DeepSeekAssistantMessage("{\"code\":200,\"result\":{\"total\":1,\"data\":[1");
212213
Prompt prompt = new Prompt(List.of(userMessage, assistantMessage));
213214
ChatResponse response = chatModel.call(prompt);
214215
assertThat(response.getResult().getOutput().getText().equals(",2,3]}}"));
215216
}
216217

218+
/**
219+
* For deepseek-reasoner model only. The reasoning contents of the assistant message,
220+
* before the final answer.
221+
*/
222+
@Test
223+
void reasonerModelTest() {
224+
var promptOptions = DeepSeekChatOptions.builder()
225+
.model(DeepSeekApi.ChatModel.DEEPSEEK_REASONER.getValue())
226+
.build();
227+
Prompt prompt = new Prompt("9.11 and 9.8, which is greater?", promptOptions);
228+
ChatResponse response = chatModel.call(prompt);
229+
230+
DeepSeekAssistantMessage deepSeekAssistantMessage = (DeepSeekAssistantMessage) response.getResult().getOutput();
231+
assertThat(deepSeekAssistantMessage.getReasoningContent()).isNotEmpty();
232+
assertThat(deepSeekAssistantMessage.getText()).isNotEmpty();
233+
}
234+
235+
/**
236+
* the deepseek-reasoner model Multi-round Conversation.
237+
*/
238+
@Test
239+
void reasonerModelMultiRoundTest() {
240+
List<Message> messages = new ArrayList<>();
241+
messages.add(new UserMessage("9.11 and 9.8, which is greater?"));
242+
var promptOptions = DeepSeekChatOptions.builder()
243+
.model(DeepSeekApi.ChatModel.DEEPSEEK_REASONER.getValue())
244+
.build();
245+
246+
Prompt prompt = new Prompt(messages, promptOptions);
247+
ChatResponse response = chatModel.call(prompt);
248+
249+
DeepSeekAssistantMessage deepSeekAssistantMessage = (DeepSeekAssistantMessage) response.getResult().getOutput();
250+
assertThat(deepSeekAssistantMessage.getReasoningContent()).isNotEmpty();
251+
assertThat(deepSeekAssistantMessage.getText()).isNotEmpty();
252+
253+
messages.add(new AssistantMessage(Objects.requireNonNull(deepSeekAssistantMessage.getText())));
254+
messages.add(new UserMessage("How many Rs are there in the word 'strawberry'?"));
255+
Prompt prompt2 = new Prompt(messages, promptOptions);
256+
ChatResponse response2 = chatModel.call(prompt2);
257+
258+
DeepSeekAssistantMessage deepSeekAssistantMessage2 = (DeepSeekAssistantMessage) response2.getResult()
259+
.getOutput();
260+
assertThat(deepSeekAssistantMessage2.getReasoningContent()).isNotEmpty();
261+
assertThat(deepSeekAssistantMessage2.getText()).isNotEmpty();
262+
}
263+
217264
}
50.3 KB
Loading

0 commit comments

Comments
 (0)