
Commit 33c997c

Add reasoningContent support to OpenAiChatModel and related classes (for deepseek-reasoner, https://api-docs.deepseek.com/api/create-chat-completion)

- Added reasoningContent field to metadata in OpenAiChatModel
- Updated ChatCompletionMessage to include reasoningContent
- Modified OpenAiStreamFunctionCallingHelper to handle reasoningContent
- Updated tests to verify reasoningContent functionality

Signed-off-by: Alexandros Pappas <[email protected]>

1 parent 171b758 · commit 33c997c
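As a rough sketch of what this enables for a caller, mirroring the integration tests added in this commit (the "openAiApi" instance is assumed to be configured for the DeepSeek endpoint; imports and error handling are omitted):

    // Sketch only: "openAiApi" stands for an OpenAiApi wired against DeepSeek,
    // the same way DeepSeekWithOpenAiChatModelIT does it below.
    OpenAiChatModel deepReasoner = new OpenAiChatModel(openAiApi,
            OpenAiChatOptions.builder().model("deepseek-reasoner").build());

    ChatResponse chatResponse = deepReasoner.call(new Prompt("Explain the theory of relativity"));

    // With this commit, the reasoning trace is exposed under the "reasoningContent" key
    // of the generation metadata, alongside the final answer text.
    Object reasoning = chatResponse.getResults().get(0).getOutput().getMetadata().get("reasoningContent");
    String answer = chatResponse.getResults().get(0).getOutput().getText();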

File tree: 5 files changed (+74, -11 lines)

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
Lines changed: 6 additions & 4 deletions

@@ -258,7 +258,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
         "role", choice.message().role() != null ? choice.message().role().name() : "",
         "index", choice.index(),
         "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
-        "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");
+        "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "",
+        "reasoningContent", StringUtils.hasText(choice.message().reasoningContent()) ? choice.message().reasoningContent() : "");
     // @formatter:on
     return buildGeneration(choice, metadata, request);
   }).toList();

@@ -346,7 +347,8 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
         "role", roleMap.getOrDefault(id, ""),
         "index", choice.index(),
         "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
-        "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");
+        "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "",
+        "reasoningContent", StringUtils.hasText(choice.message().reasoningContent()) ? choice.message().reasoningContent() : "");

     return buildGeneration(choice, metadata, request);
   }).toList();

@@ -543,7 +545,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {

   }
   return List.of(new ChatCompletionMessage(assistantMessage.getText(),
-      ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput));
+      ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput, null));
 }
 else if (message.getMessageType() == MessageType.TOOL) {
   ToolResponseMessage toolMessage = (ToolResponseMessage) message;

@@ -553,7 +555,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
   return toolMessage.getResponses()
     .stream()
     .map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(),
-        tr.id(), null, null, null))
+        tr.id(), null, null, null, null))
     .toList();
 }
 else {

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
Lines changed: 6 additions & 2 deletions

@@ -1138,8 +1138,11 @@ public record StreamOptions(
  * {@link Role#ASSISTANT} role and null otherwise.
  * @param audioOutput Audio response from the model. >>>>>>> bdb66e577 (OpenAI -
  * Support audio input modality)
+ * @param reasoningContent For deepseek-reasoner model only. The reasoning contents of
+ * the assistant message, before the final answer.
  */
 @JsonInclude(Include.NON_NULL)
+@JsonIgnoreProperties(ignoreUnknown = true)
 public record ChatCompletionMessage(// @formatter:off
   @JsonProperty("content") Object rawContent,
   @JsonProperty("role") Role role,

@@ -1148,7 +1151,8 @@ public record ChatCompletionMessage(// @formatter:off
   @JsonProperty("tool_calls")
   @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) List<ToolCall> toolCalls,
   @JsonProperty("refusal") String refusal,
-  @JsonProperty("audio") AudioOutput audioOutput) { // @formatter:on
+  @JsonProperty("audio") AudioOutput audioOutput,
+  @JsonProperty("reasoning_content") String reasoningContent) { // @formatter:on

 /**
  * Create a chat completion message with the given content and role. All other

@@ -1157,7 +1161,7 @@ public record ChatCompletionMessage(// @formatter:off
  * @param role The role of the author of this message.
  */
 public ChatCompletionMessage(Object content, Role role) {
-  this(content, role, null, null, null, null, null);
+  this(content, role, null, null, null, null, null, null);

 }
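A minimal sketch of how the new reasoning_content mapping surfaces through the low-level API, following the chatCompletionEntityWithReasoning test added in this commit (the pre-configured "openAiApi" instance and DeepSeek credentials are assumptions; imports are omitted):

    // Sketch: the (messages, model, temperature, stream) request constructor is the form used by the new tests.
    ChatCompletionMessage userMessage = new ChatCompletionMessage("Explain the theory of relativity",
            ChatCompletionMessage.Role.USER);
    ChatCompletionRequest request = new ChatCompletionRequest(List.of(userMessage), "deepseek-reasoner", 0.8, false);

    ResponseEntity<ChatCompletion> entity = openAiApi.chatCompletionEntity(request);

    // @JsonProperty("reasoning_content") maps the DeepSeek field onto the new record component,
    // so the reasoning trace is readable next to the final answer content.
    ChatCompletionMessage message = entity.getBody().choices().get(0).message();
    String reasoning = message.reasoningContent();
    Object answer = message.rawContent();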

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java
Lines changed: 4 additions & 1 deletion

@@ -97,6 +97,8 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti
   String refusal = (current.refusal() != null ? current.refusal() : previous.refusal());
   ChatCompletionMessage.AudioOutput audioOutput = (current.audioOutput() != null ? current.audioOutput()
       : previous.audioOutput());
+  String reasoningContent = (current.reasoningContent() != null ? current.reasoningContent()
+      : previous.reasoningContent());

   List<ToolCall> toolCalls = new ArrayList<>();
   ToolCall lastPreviousTooCall = null;

@@ -126,7 +128,8 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti
       toolCalls.add(lastPreviousTooCall);
     }
   }
-  return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput);
+  return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput,
+      reasoningContent);
 }

 private ToolCall merge(ToolCall previous, ToolCall current) {

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java
Lines changed: 1 addition & 1 deletion

@@ -129,7 +129,7 @@ public void toolFunctionCall() {

   // extend conversation with function response.
   messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL,
-      functionName, toolCall.id(), null, null, null));
+      functionName, toolCall.id(), null, null, null, null));
  }
 }

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/DeepSeekWithOpenAiChatModelIT.java
Lines changed: 57 additions & 3 deletions

@@ -56,6 +56,7 @@
 import org.springframework.context.annotation.Bean;
 import org.springframework.core.convert.support.DefaultConversionService;
 import org.springframework.core.io.Resource;
+import org.springframework.http.ResponseEntity;

 import static org.assertj.core.api.Assertions.assertThat;

@@ -82,6 +83,9 @@ class DeepSeekWithOpenAiChatModelIT {
 @Value("classpath:/prompts/system-message.st")
 private Resource systemResource;

+@Autowired
+private OpenAiApi openAiApi;
+
 @Autowired
 private OpenAiChatModel chatModel;

@@ -128,9 +132,9 @@ void streamingWithTokenUsage() {
   var streamingTokenUsage = this.chatModel.stream(prompt).blockLast().getMetadata().getUsage();
   var referenceTokenUsage = this.chatModel.call(prompt).getMetadata().getUsage();

-  assertThat(streamingTokenUsage.getPromptTokens()).isGreaterThan(0);
-  assertThat(streamingTokenUsage.getCompletionTokens()).isGreaterThan(0);
-  assertThat(streamingTokenUsage.getTotalTokens()).isGreaterThan(0);
+  assertThat(streamingTokenUsage.getPromptTokens()).isPositive();
+  assertThat(streamingTokenUsage.getCompletionTokens()).isPositive();
+  assertThat(streamingTokenUsage.getTotalTokens()).isPositive();

   assertThat(streamingTokenUsage.getPromptTokens()).isEqualTo(referenceTokenUsage.getPromptTokens());
   assertThat(streamingTokenUsage.getCompletionTokens()).isEqualTo(referenceTokenUsage.getCompletionTokens());

@@ -325,6 +329,56 @@ record ActorsFilmsRecord(String actor, List<String> movies) {

 }

+@Test
+void chatCompletionEntityWithReasoning() {
+  OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage(
+      "Explain the theory of relativity", OpenAiApi.ChatCompletionMessage.Role.USER);
+  OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(List.of(chatCompletionMessage),
+      "deepseek-reasoner", 0.8, false);
+  ResponseEntity<OpenAiApi.ChatCompletion> response = this.openAiApi.chatCompletionEntity(request);
+
+  assertThat(response).isNotNull();
+  assertThat(response.getBody().choices().get(0).message().reasoningContent()).isNotBlank();
+}
+
+@Test
+void chatCompletionStreamWithReasoning() {
+  OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage(
+      "Explain the theory of relativity", OpenAiApi.ChatCompletionMessage.Role.USER);
+  OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(List.of(chatCompletionMessage),
+      "deepseek-reasoner", 0.8, true);
+  Flux<OpenAiApi.ChatCompletionChunk> response = this.openAiApi.chatCompletionStream(request);
+
+  assertThat(response).isNotNull();
+  List<OpenAiApi.ChatCompletionChunk> chunks = response.collectList().block();
+  assertThat(chunks).isNotNull();
+  assertThat(chunks.stream().anyMatch(chunk -> !chunk.choices().get(0).delta().reasoningContent().isBlank()))
+    .isTrue();
+}
+
+@Test
+void chatModelCallWithReasoning() {
+  OpenAiChatModel deepReasoner = new OpenAiChatModel(this.openAiApi,
+      OpenAiChatOptions.builder().model("deepseek-reasoner").build());
+  ChatResponse chatResponse = deepReasoner.call(new Prompt("Explain the theory of relativity"));
+  assertThat(chatResponse.getResults()).isNotEmpty();
+  assertThat(chatResponse.getResults().get(0).getOutput().getMetadata().get("reasoningContent").toString())
+    .isNotBlank();
+}
+
+@Test
+void chatModelStreamWithReasoning() {
+  OpenAiChatModel deepReasoner = new OpenAiChatModel(this.openAiApi,
+      OpenAiChatOptions.builder().model("deepseek-reasoner").build());
+  Flux<ChatResponse> flux = deepReasoner.stream(new Prompt("Explain the theory of relativity"));
+  List<ChatResponse> responses = flux.collectList().block();
+  assertThat(responses).isNotEmpty();
+  assertThat(responses.stream()
+    .flatMap(response -> response.getResults().stream())
+    .map(result -> result.getOutput().getMetadata().get("reasoningContent").toString())
+    .noneMatch(String::isBlank)).isTrue();
+}
+
 @SpringBootConfiguration
 static class Config {
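As a follow-up usage note, a hedged sketch of how a client might stitch the streamed reasoning trace back together, following the chatModelStreamWithReasoning pattern above ("deepReasoner" is assumed to be configured as in that test; the blocking collect is for illustration only and imports are omitted):

    // Sketch: each streamed ChatResponse carries a fragment of the trace under the
    // "reasoningContent" metadata key; concatenating the fragments rebuilds the full trace.
    Flux<ChatResponse> flux = deepReasoner.stream(new Prompt("Explain the theory of relativity"));

    String reasoningTrace = flux.collectList()
            .block()
            .stream()
            .flatMap(response -> response.getResults().stream())
            .map(result -> String.valueOf(result.getOutput().getMetadata().getOrDefault("reasoningContent", "")))
            .collect(Collectors.joining());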
