diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 246b7893c4a..85b06a5a4d6 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -319,7 +319,8 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha "index", choice.index() != null ? choice.index() : 0, "finishReason", getFinishReasonJson(choice.finishReason()), "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "", - "annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of()); + "annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of(), + "reasoningContent", choice.message().reasoningContent() != null ? choice.message().reasoningContent() : ""); return buildGeneration(choice, metadata, request); }).toList(); // @formatter:on @@ -606,7 +607,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) { } return List.of(new ChatCompletionMessage(assistantMessage.getText(), - ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput, null)); + ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput, null, null)); } else if (message.getMessageType() == MessageType.TOOL) { ToolResponseMessage toolMessage = (ToolResponseMessage) message; @@ -616,7 +617,7 @@ else if (message.getMessageType() == MessageType.TOOL) { return toolMessage.getResponses() .stream() .map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(), - tr.id(), null, null, null, null)) + tr.id(), null, null, null, null, null)) .toList(); } else { diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index 14b5ba42536..c7c2159d988 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -137,7 +137,7 @@ public class OpenAiChatOptions implements ToolCallingChatOptions { * modalities: ["audio"] * Note: that the audio modality is only available for the gpt-4o-audio-preview model * and is not supported for streaming completions. - + * */ private @JsonProperty("audio") AudioParameters outputAudio; @@ -264,6 +264,8 @@ public class OpenAiChatOptions implements ToolCallingChatOptions { @JsonIgnore private Map toolContext = new HashMap<>(); + private @JsonProperty("extra_body") Map extraBody; + // @formatter:on public static Builder builder() { @@ -306,6 +308,7 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) { .webSearchOptions(fromOptions.getWebSearchOptions()) .verbosity(fromOptions.getVerbosity()) .serviceTier(fromOptions.getServiceTier()) + .extraBody(fromOptions.getExtraBody()) .build(); } @@ -502,6 +505,14 @@ public void setParallelToolCalls(Boolean parallelToolCalls) { this.parallelToolCalls = parallelToolCalls; } + public Map getExtraBody() { + return this.extraBody; + } + + public void setExtraBody(Map extraBody) { + this.extraBody = extraBody; + } + @Override @JsonIgnore public List getToolCallbacks() { @@ -630,7 +641,8 @@ public int hashCode() { this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice, this.user, this.parallelToolCalls, this.toolCallbacks, this.toolNames, this.httpHeaders, this.internalToolExecutionEnabled, this.toolContext, this.outputModalities, this.outputAudio, - this.store, this.metadata, this.reasoningEffort, this.webSearchOptions, this.serviceTier); + this.store, this.metadata, this.reasoningEffort, this.webSearchOptions, this.serviceTier, + this.extraBody); } @Override @@ -665,7 +677,8 @@ public boolean equals(Object o) { && Objects.equals(this.reasoningEffort, other.reasoningEffort) && Objects.equals(this.webSearchOptions, other.webSearchOptions) && Objects.equals(this.verbosity, other.verbosity) - && Objects.equals(this.serviceTier, other.serviceTier); + && Objects.equals(this.serviceTier, other.serviceTier) + && Objects.equals(this.extraBody, other.extraBody); } @Override @@ -933,6 +946,11 @@ public Builder serviceTier(OpenAiApi.ServiceTier serviceTier) { return this; } + public Builder extraBody(Map extraBody) { + this.options.extraBody = extraBody; + return this; + } + public OpenAiChatOptions build() { return this.options; } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index f7e72246e97..2a58a79c2ef 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -29,7 +29,10 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.node.ObjectNode; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -187,6 +190,7 @@ public ResponseEntity chatCompletionEntity(ChatCompletionRequest Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); Assert.notNull(additionalHttpHeader, "The additional HTTP headers can not be null."); + Object dynamicRequestBody = createDynamicRequestBody(chatRequest); // @formatter:off return this.restClient.post() .uri(this.completionsPath) @@ -194,7 +198,7 @@ public ResponseEntity chatCompletionEntity(ChatCompletionRequest headers.addAll(additionalHttpHeader); addDefaultHeadersIfMissing(headers); }) - .body(chatRequest) + .body(dynamicRequestBody) .retrieve() .toEntity(ChatCompletion.class); // @formatter:on @@ -210,6 +214,29 @@ public Flux chatCompletionStream(ChatCompletionRequest chat return chatCompletionStream(chatRequest, new LinkedMultiValueMap<>()); } + private Object createDynamicRequestBody(ChatCompletionRequest baseRequest) { + ObjectMapper mapper = new ObjectMapper(); + ObjectNode requestNode = mapper.valueToTree(baseRequest); + if (null == baseRequest.extraBody) { + return requestNode; + } + + // 添加额外字段 + baseRequest.extraBody().forEach((key, value) -> { + if (value instanceof Map) { + requestNode.set(key, mapper.valueToTree(value)); + } + else if (value instanceof List) { + requestNode.set(key, mapper.valueToTree(value)); + } + else { + requestNode.putPOJO(key, value); + } + }); + + return requestNode; + } + /** * Creates a streaming chat response for the given chat conversation. * @param chatRequest The chat completion request. Must have the stream property set @@ -226,6 +253,15 @@ public Flux chatCompletionStream(ChatCompletionRequest chat AtomicBoolean isInsideTool = new AtomicBoolean(false); + ObjectMapper objectMapper = new ObjectMapper(); + try { + var s = objectMapper.writeValueAsString(chatRequest); + System.out.println("aaaaaaaa:" + s); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + Object dynamicBody = createDynamicRequestBody(chatRequest); // @formatter:off return this.webClient.post() .uri(this.completionsPath) @@ -233,7 +269,7 @@ public Flux chatCompletionStream(ChatCompletionRequest chat headers.addAll(additionalHttpHeader); addDefaultHeadersIfMissing(headers); }) // @formatter:on - .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .bodyValue(dynamicBody) .retrieve() .bodyToFlux(String.class) // cancels the flux stream after the "[DONE]" is received. @@ -1116,7 +1152,8 @@ public record ChatCompletionRequest(// @formatter:off @JsonProperty("user") String user, @JsonProperty("reasoning_effort") String reasoningEffort, @JsonProperty("web_search_options") WebSearchOptions webSearchOptions, - @JsonProperty("verbosity") String verbosity) { + @JsonProperty("verbosity") String verbosity, + @JsonProperty("extra_body") Map extraBody) { /** * Shortcut constructor for a chat completion request with the given messages, model and temperature. @@ -1128,7 +1165,7 @@ public record ChatCompletionRequest(// @formatter:off public ChatCompletionRequest(List messages, String model, Double temperature) { this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, false, null, temperature, null, - null, null, null, null, null, null, null); + null, null, null, null, null, null, null, null); } /** @@ -1142,7 +1179,7 @@ public ChatCompletionRequest(List messages, String model, this(messages, model, null, null, null, null, null, null, null, null, null, List.of(OutputModality.AUDIO, OutputModality.TEXT), audio, null, null, null, null, null, stream, null, null, null, - null, null, null, null, null, null, null); + null, null, null, null, null, null, null, null); } /** @@ -1157,7 +1194,7 @@ public ChatCompletionRequest(List messages, String model, public ChatCompletionRequest(List messages, String model, Double temperature, boolean stream) { this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, stream, null, temperature, null, - null, null, null, null, null, null, null); + null, null, null, null, null, null, null, null); } /** @@ -1173,7 +1210,7 @@ public ChatCompletionRequest(List messages, String model, List tools, Object toolChoice) { this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, false, null, 0.8, null, - tools, toolChoice, null, null, null, null, null); + tools, toolChoice, null, null, null, null, null, null); } /** @@ -1184,9 +1221,9 @@ public ChatCompletionRequest(List messages, String model, * as they become available, with the stream terminated by a data: [DONE] message. */ public ChatCompletionRequest(List messages, Boolean stream) { - this(messages, null, null, null, null, null, null, null, null, null, null, - null, null, null, null, null, null, null, stream, null, null, null, - null, null, null, null, null, null, null); + this(messages, null, null, null, null, null, null, null, null, null, null, null, null, null, + null, null, null, null, stream, null, null, null, null, null, null, null, null, null, + null, null); } /** @@ -1197,9 +1234,9 @@ public ChatCompletionRequest(List messages, Boolean strea */ public ChatCompletionRequest streamOptions(StreamOptions streamOptions) { return new ChatCompletionRequest(this.messages, this.model, this.store, this.metadata, this.frequencyPenalty, this.logitBias, this.logprobs, - this.topLogprobs, this.maxTokens, this.maxCompletionTokens, this.n, this.outputModalities, this.audioParameters, this.presencePenalty, - this.responseFormat, this.seed, this.serviceTier, this.stop, this.stream, streamOptions, this.temperature, this.topP, - this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort, this.webSearchOptions, this.verbosity); + this.topLogprobs, this.maxTokens, this.maxCompletionTokens, this.n, this.outputModalities, this.audioParameters, this.presencePenalty, + this.responseFormat, this.seed, this.serviceTier, this.stop, this.stream, streamOptions, this.temperature, this.topP, + this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort, this.webSearchOptions, this.verbosity, this.extraBody); } /** @@ -1411,7 +1448,8 @@ public record ChatCompletionMessage(// @formatter:off @JsonProperty("tool_calls") @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) List toolCalls, @JsonProperty("refusal") String refusal, @JsonProperty("audio") AudioOutput audioOutput, - @JsonProperty("annotations") List annotations + @JsonProperty("annotations") List annotations, + @JsonProperty("reasoning_content") String reasoningContent ) { // @formatter:on /** @@ -1421,7 +1459,7 @@ public record ChatCompletionMessage(// @formatter:off * @param role The role of the author of this message. */ public ChatCompletionMessage(Object content, Role role) { - this(content, role, null, null, null, null, null, null); + this(content, role, null, null, null, null, null, null, null); } /** diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java index 55d74fa7b52..3558760a54f 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java @@ -36,6 +36,7 @@ /** * Helper class to support Streaming function calling. * + *

* It can merge the streamed ChatCompletionChunk in case of function calling message. * * @author Christian Tzolov @@ -100,6 +101,8 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) { private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) { String content = (current.content() != null ? current.content() : "" + ((previous.content() != null) ? previous.content() : "")); + String reasoningContent = (current.reasoningContent() != null ? current.reasoningContent() + : "" + ((previous.reasoningContent() != null) ? previous.reasoningContent() : "")); Role role = (current.role() != null ? current.role() : previous.role()); role = (role != null ? role : Role.ASSISTANT); // default to ASSISTANT (if null String name = (current.name() != null ? current.name() : previous.name()); @@ -138,7 +141,8 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti toolCalls.add(lastPreviousTooCall); } } - return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput, annotations); + return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput, annotations, + reasoningContent); } private ToolCall merge(ToolCall previous, ToolCall current) { diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java index e2858c9e46d..01b7635418c 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java @@ -412,6 +412,7 @@ void dynamicApiKeyWebClient() throws InterruptedException { "role": "assistant", "content": "Hello world" }, + "reasoning_content": "test", "finish_reason": "stop" } ], diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java index d050a621034..55c223d818a 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java @@ -77,7 +77,7 @@ void validateReasoningTokens() { "If a train travels 100 miles in 2 hours, what is its average speed?", ChatCompletionMessage.Role.USER); ChatCompletionRequest request = new ChatCompletionRequest(List.of(userMessage), "o1", null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, false, null, null, null, null, - null, null, null, "low", null, null); + null, null, null, "low", null, null, null); ResponseEntity response = this.openAiApi.chatCompletionEntity(request); assertThat(response).isNotNull(); @@ -180,7 +180,7 @@ void chatCompletionEntityWithNewModelsAndLowVerbosity(OpenAiApi.ChatModel modelN ChatCompletionRequest request = new ChatCompletionRequest(List.of(chatCompletionMessage), // messages modelName.getValue(), null, null, null, null, null, null, null, null, null, null, null, null, null, - null, null, null, false, null, 1.0, null, null, null, null, null, null, null, "low"); + null, null, null, false, null, 1.0, null, null, null, null, null, null, null, "low", null); ResponseEntity response = this.openAiApi.chatCompletionEntity(request); @@ -227,7 +227,7 @@ void chatCompletionEntityWithServiceTier(OpenAiApi.ServiceTier serviceTier) { ChatCompletionRequest request = new ChatCompletionRequest(List.of(chatCompletionMessage), // messages OpenAiApi.ChatModel.GPT_4_O.value, null, null, null, null, null, null, null, null, null, null, null, null, null, null, serviceTier.getValue(), null, false, null, 1.0, null, null, null, null, null, null, - null, null); + null, null, null); ResponseEntity response = this.openAiApi.chatCompletionEntity(request); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelperTest.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelperTest.java index d67245dc49e..6e9967908e5 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelperTest.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelperTest.java @@ -88,8 +88,8 @@ public void isStreamingToolFunctionCall_whenChatCompletionChunkFirstChoiceDeltaT // Test for null. assertion.accept(new OpenAiApi.ChatCompletionMessage(null, null)); // Test for empty. - assertion.accept( - new OpenAiApi.ChatCompletionMessage(null, null, null, null, Collections.emptyList(), null, null, null)); + assertion.accept(new OpenAiApi.ChatCompletionMessage(null, null, null, null, Collections.emptyList(), null, + null, null, null)); } @Test @@ -102,7 +102,7 @@ public void isStreamingToolFunctionCall_whenChatCompletionChunkFirstChoiceDeltaT }; assertion.accept(new OpenAiApi.ChatCompletionMessage(null, null, null, null, List.of(Mockito.mock(org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall.class)), - null, null, null)); + null, null, null, null)); } @Test @@ -191,7 +191,8 @@ public void isStreamingToolFunctionCallReturnsFalseForNullOrEmptyChunks() { @Test public void isStreamingToolFunctionCall_returnsTrueForValidToolCalls() { var toolCall = Mockito.mock(OpenAiApi.ChatCompletionMessage.ToolCall.class); - var delta = new OpenAiApi.ChatCompletionMessage(null, null, null, null, List.of(toolCall), null, null, null); + var delta = new OpenAiApi.ChatCompletionMessage(null, null, null, null, List.of(toolCall), null, null, null, + null); var choice = new OpenAiApi.ChatCompletionChunk.ChunkChoice(null, null, delta, null); var chunk = new OpenAiApi.ChatCompletionChunk(null, List.of(choice), null, null, null, null, null, null); @@ -255,7 +256,7 @@ public void merge_partialFieldsFromEachChunk() { public void isStreamingToolFunctionCall_withMultipleChoicesAndOnlyFirstHasToolCalls() { var toolCall = Mockito.mock(OpenAiApi.ChatCompletionMessage.ToolCall.class); var deltaWithToolCalls = new OpenAiApi.ChatCompletionMessage(null, null, null, null, List.of(toolCall), null, - null, null); + null, null, null); var deltaWithoutToolCalls = new OpenAiApi.ChatCompletionMessage(null, null); var choice1 = new OpenAiApi.ChatCompletionChunk.ChunkChoice(null, null, deltaWithToolCalls, null); @@ -327,7 +328,7 @@ public void edgeCases_emptyStringFields() { @Test public void isStreamingToolFunctionCall_withNullToolCallsList() { - var delta = new OpenAiApi.ChatCompletionMessage(null, null, null, null, null, null, null, null); + var delta = new OpenAiApi.ChatCompletionMessage(null, null, null, null, null, null, null, null, null); var choice = new OpenAiApi.ChatCompletionChunk.ChunkChoice(null, null, delta, null); var chunk = new OpenAiApi.ChatCompletionChunk(null, List.of(choice), null, null, null, null, null, null); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java index e655de46421..3ce6e3f24e6 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java @@ -130,7 +130,7 @@ public void toolFunctionCall() { // extend conversation with function response. messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL, - functionName, toolCall.id(), null, null, null, null)); + functionName, toolCall.id(), null, null, null, null, null)); } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiStreamingFinishReasonTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiStreamingFinishReasonTests.java index 3dc59444e82..1664782314b 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiStreamingFinishReasonTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiStreamingFinishReasonTests.java @@ -124,7 +124,8 @@ void testJsonDeserializationWithEmptyStringFinishReason() throws JsonProcessingE "index": 0, "delta": { "role": "assistant", - "content": "" + "content": "", + "reasoning_content": "" }, "finish_reason": "" }] @@ -161,7 +162,8 @@ void testJsonDeserializationWithNullFinishReason() throws JsonProcessingExceptio "index": 0, "delta": { "role": "assistant", - "content": "Hello" + "content": "Hello", + "reasoning_content": "test" }, "finish_reason": null }] @@ -176,6 +178,7 @@ void testJsonDeserializationWithNullFinishReason() throws JsonProcessingExceptio var choice = chunk.choices().get(0); assertThat(choice.finishReason()).isNull(); assertThat(choice.delta().content()).isEqualTo("Hello"); + assertThat(choice.delta().reasoningContent()).isEqualTo("test"); } @Test