Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,8 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
"index", choice.index() != null ? choice.index() : 0,
"finishReason", getFinishReasonJson(choice.finishReason()),
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "",
"annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of());
"annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of(),
"reasoningContent", choice.message().reasoningContent() != null ? choice.message().reasoningContent() : "");
return buildGeneration(choice, metadata, request);
}).toList();
// @formatter:on
Expand Down Expand Up @@ -606,7 +607,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {

}
return List.of(new ChatCompletionMessage(assistantMessage.getText(),
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput, null));
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput, null, null));
}
else if (message.getMessageType() == MessageType.TOOL) {
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
Expand All @@ -616,7 +617,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
return toolMessage.getResponses()
.stream()
.map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(),
tr.id(), null, null, null, null))
tr.id(), null, null, null, null, null))
.toList();
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ public class OpenAiChatOptions implements ToolCallingChatOptions {
* modalities: ["audio"]
* Note: that the audio modality is only available for the gpt-4o-audio-preview model
* and is not supported for streaming completions.

*
*/
private @JsonProperty("audio") AudioParameters outputAudio;

Expand Down Expand Up @@ -264,6 +264,8 @@ public class OpenAiChatOptions implements ToolCallingChatOptions {
@JsonIgnore
private Map<String, Object> toolContext = new HashMap<>();

private @JsonProperty("extra_body") Map<String, Object> extraBody;

// @formatter:on

public static Builder builder() {
Expand Down Expand Up @@ -306,6 +308,7 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) {
.webSearchOptions(fromOptions.getWebSearchOptions())
.verbosity(fromOptions.getVerbosity())
.serviceTier(fromOptions.getServiceTier())
.extraBody(fromOptions.getExtraBody())
.build();
}

Expand Down Expand Up @@ -502,6 +505,14 @@ public void setParallelToolCalls(Boolean parallelToolCalls) {
this.parallelToolCalls = parallelToolCalls;
}

public Map<String, Object> getExtraBody() {
return this.extraBody;
}

public void setExtraBody(Map<String, Object> extraBody) {
this.extraBody = extraBody;
}

@Override
@JsonIgnore
public List<ToolCallback> getToolCallbacks() {
Expand Down Expand Up @@ -630,7 +641,8 @@ public int hashCode() {
this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice,
this.user, this.parallelToolCalls, this.toolCallbacks, this.toolNames, this.httpHeaders,
this.internalToolExecutionEnabled, this.toolContext, this.outputModalities, this.outputAudio,
this.store, this.metadata, this.reasoningEffort, this.webSearchOptions, this.serviceTier);
this.store, this.metadata, this.reasoningEffort, this.webSearchOptions, this.serviceTier,
this.extraBody);
}

@Override
Expand Down Expand Up @@ -665,7 +677,8 @@ public boolean equals(Object o) {
&& Objects.equals(this.reasoningEffort, other.reasoningEffort)
&& Objects.equals(this.webSearchOptions, other.webSearchOptions)
&& Objects.equals(this.verbosity, other.verbosity)
&& Objects.equals(this.serviceTier, other.serviceTier);
&& Objects.equals(this.serviceTier, other.serviceTier)
&& Objects.equals(this.extraBody, other.extraBody);
}

@Override
Expand Down Expand Up @@ -933,6 +946,11 @@ public Builder serviceTier(OpenAiApi.ServiceTier serviceTier) {
return this;
}

public Builder extraBody(Map<String, Object> extraBody) {
this.options.extraBody = extraBody;
return this;
}

public OpenAiChatOptions build() {
return this.options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.node.ObjectNode;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

Expand Down Expand Up @@ -187,14 +190,15 @@ public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest
Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false.");
Assert.notNull(additionalHttpHeader, "The additional HTTP headers can not be null.");

Object dynamicRequestBody = createDynamicRequestBody(chatRequest);
// @formatter:off
return this.restClient.post()
.uri(this.completionsPath)
.headers(headers -> {
headers.addAll(additionalHttpHeader);
addDefaultHeadersIfMissing(headers);
})
.body(chatRequest)
.body(dynamicRequestBody)
.retrieve()
.toEntity(ChatCompletion.class);
// @formatter:on
Expand All @@ -210,6 +214,29 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
return chatCompletionStream(chatRequest, new LinkedMultiValueMap<>());
}

private Object createDynamicRequestBody(ChatCompletionRequest baseRequest) {
ObjectMapper mapper = new ObjectMapper();
ObjectNode requestNode = mapper.valueToTree(baseRequest);
if (null == baseRequest.extraBody) {
return requestNode;
}

// 添加额外字段
baseRequest.extraBody().forEach((key, value) -> {
if (value instanceof Map) {
requestNode.set(key, mapper.valueToTree(value));
}
else if (value instanceof List) {
requestNode.set(key, mapper.valueToTree(value));
}
else {
requestNode.putPOJO(key, value);
}
});

return requestNode;
}

/**
* Creates a streaming chat response for the given chat conversation.
* @param chatRequest The chat completion request. Must have the stream property set
Expand All @@ -226,14 +253,23 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat

AtomicBoolean isInsideTool = new AtomicBoolean(false);

ObjectMapper objectMapper = new ObjectMapper();
try {
var s = objectMapper.writeValueAsString(chatRequest);
System.out.println("aaaaaaaa:" + s);
}
catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
Object dynamicBody = createDynamicRequestBody(chatRequest);
// @formatter:off
return this.webClient.post()
.uri(this.completionsPath)
.headers(headers -> {
headers.addAll(additionalHttpHeader);
addDefaultHeadersIfMissing(headers);
}) // @formatter:on
.body(Mono.just(chatRequest), ChatCompletionRequest.class)
.bodyValue(dynamicBody)
.retrieve()
.bodyToFlux(String.class)
// cancels the flux stream after the "[DONE]" is received.
Expand Down Expand Up @@ -1116,7 +1152,8 @@ public record ChatCompletionRequest(// @formatter:off
@JsonProperty("user") String user,
@JsonProperty("reasoning_effort") String reasoningEffort,
@JsonProperty("web_search_options") WebSearchOptions webSearchOptions,
@JsonProperty("verbosity") String verbosity) {
@JsonProperty("verbosity") String verbosity,
@JsonProperty("extra_body") Map<String, Object> extraBody) {

/**
* Shortcut constructor for a chat completion request with the given messages, model and temperature.
Expand All @@ -1128,7 +1165,7 @@ public record ChatCompletionRequest(// @formatter:off
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature) {
this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null,
null, null, null, false, null, temperature, null,
null, null, null, null, null, null, null);
null, null, null, null, null, null, null, null);
}

/**
Expand All @@ -1142,7 +1179,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
this(messages, model, null, null, null, null, null, null,
null, null, null, List.of(OutputModality.AUDIO, OutputModality.TEXT), audio, null, null,
null, null, null, stream, null, null, null,
null, null, null, null, null, null, null);
null, null, null, null, null, null, null, null);
}

/**
Expand All @@ -1157,7 +1194,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature, boolean stream) {
this(messages, model, null, null, null, null, null, null, null, null, null,
null, null, null, null, null, null, null, stream, null, temperature, null,
null, null, null, null, null, null, null);
null, null, null, null, null, null, null, null);
}

/**
Expand All @@ -1173,7 +1210,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
List<FunctionTool> tools, Object toolChoice) {
this(messages, model, null, null, null, null, null, null, null, null, null,
null, null, null, null, null, null, null, false, null, 0.8, null,
tools, toolChoice, null, null, null, null, null);
tools, toolChoice, null, null, null, null, null, null);
}

/**
Expand All @@ -1184,9 +1221,9 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
* as they become available, with the stream terminated by a data: [DONE] message.
*/
public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
this(messages, null, null, null, null, null, null, null, null, null, null,
null, null, null, null, null, null, null, stream, null, null, null,
null, null, null, null, null, null, null);
this(messages, null, null, null, null, null, null, null, null, null, null, null, null, null,
null, null, null, null, stream, null, null, null, null, null, null, null, null, null,
null, null);
}

/**
Expand All @@ -1197,9 +1234,9 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean strea
*/
public ChatCompletionRequest streamOptions(StreamOptions streamOptions) {
return new ChatCompletionRequest(this.messages, this.model, this.store, this.metadata, this.frequencyPenalty, this.logitBias, this.logprobs,
this.topLogprobs, this.maxTokens, this.maxCompletionTokens, this.n, this.outputModalities, this.audioParameters, this.presencePenalty,
this.responseFormat, this.seed, this.serviceTier, this.stop, this.stream, streamOptions, this.temperature, this.topP,
this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort, this.webSearchOptions, this.verbosity);
this.topLogprobs, this.maxTokens, this.maxCompletionTokens, this.n, this.outputModalities, this.audioParameters, this.presencePenalty,
this.responseFormat, this.seed, this.serviceTier, this.stop, this.stream, streamOptions, this.temperature, this.topP,
this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort, this.webSearchOptions, this.verbosity, this.extraBody);
}

/**
Expand Down Expand Up @@ -1411,7 +1448,8 @@ public record ChatCompletionMessage(// @formatter:off
@JsonProperty("tool_calls") @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) List<ToolCall> toolCalls,
@JsonProperty("refusal") String refusal,
@JsonProperty("audio") AudioOutput audioOutput,
@JsonProperty("annotations") List<Annotation> annotations
@JsonProperty("annotations") List<Annotation> annotations,
@JsonProperty("reasoning_content") String reasoningContent
) { // @formatter:on

/**
Expand All @@ -1421,7 +1459,7 @@ public record ChatCompletionMessage(// @formatter:off
* @param role The role of the author of this message.
*/
public ChatCompletionMessage(Object content, Role role) {
this(content, role, null, null, null, null, null, null);
this(content, role, null, null, null, null, null, null, null);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
/**
* Helper class to support Streaming function calling.
*
* <p>
* It can merge the streamed ChatCompletionChunk in case of function calling message.
*
* @author Christian Tzolov
Expand Down Expand Up @@ -100,6 +101,8 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) {
String content = (current.content() != null ? current.content()
: "" + ((previous.content() != null) ? previous.content() : ""));
String reasoningContent = (current.reasoningContent() != null ? current.reasoningContent()
: "" + ((previous.reasoningContent() != null) ? previous.reasoningContent() : ""));
Role role = (current.role() != null ? current.role() : previous.role());
role = (role != null ? role : Role.ASSISTANT); // default to ASSISTANT (if null
String name = (current.name() != null ? current.name() : previous.name());
Expand Down Expand Up @@ -138,7 +141,8 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti
toolCalls.add(lastPreviousTooCall);
}
}
return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput, annotations);
return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput, annotations,
reasoningContent);
}

private ToolCall merge(ToolCall previous, ToolCall current) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ void dynamicApiKeyWebClient() throws InterruptedException {
"role": "assistant",
"content": "Hello world"
},
"reasoning_content": "test",
"finish_reason": "stop"
}
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ void validateReasoningTokens() {
"If a train travels 100 miles in 2 hours, what is its average speed?", ChatCompletionMessage.Role.USER);
ChatCompletionRequest request = new ChatCompletionRequest(List.of(userMessage), "o1", null, null, null, null,
null, null, null, null, null, null, null, null, null, null, null, null, false, null, null, null, null,
null, null, null, "low", null, null);
null, null, null, "low", null, null, null);
ResponseEntity<ChatCompletion> response = this.openAiApi.chatCompletionEntity(request);

assertThat(response).isNotNull();
Expand Down Expand Up @@ -180,7 +180,7 @@ void chatCompletionEntityWithNewModelsAndLowVerbosity(OpenAiApi.ChatModel modelN

ChatCompletionRequest request = new ChatCompletionRequest(List.of(chatCompletionMessage), // messages
modelName.getValue(), null, null, null, null, null, null, null, null, null, null, null, null, null,
null, null, null, false, null, 1.0, null, null, null, null, null, null, null, "low");
null, null, null, false, null, 1.0, null, null, null, null, null, null, null, "low", null);

ResponseEntity<ChatCompletion> response = this.openAiApi.chatCompletionEntity(request);

Expand Down Expand Up @@ -227,7 +227,7 @@ void chatCompletionEntityWithServiceTier(OpenAiApi.ServiceTier serviceTier) {
ChatCompletionRequest request = new ChatCompletionRequest(List.of(chatCompletionMessage), // messages
OpenAiApi.ChatModel.GPT_4_O.value, null, null, null, null, null, null, null, null, null, null, null,
null, null, null, serviceTier.getValue(), null, false, null, 1.0, null, null, null, null, null, null,
null, null);
null, null, null);

ResponseEntity<ChatCompletion> response = this.openAiApi.chatCompletionEntity(request);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ public void isStreamingToolFunctionCall_whenChatCompletionChunkFirstChoiceDeltaT
// Test for null.
assertion.accept(new OpenAiApi.ChatCompletionMessage(null, null));
// Test for empty.
assertion.accept(
new OpenAiApi.ChatCompletionMessage(null, null, null, null, Collections.emptyList(), null, null, null));
assertion.accept(new OpenAiApi.ChatCompletionMessage(null, null, null, null, Collections.emptyList(), null,
null, null, null));
}

@Test
Expand All @@ -102,7 +102,7 @@ public void isStreamingToolFunctionCall_whenChatCompletionChunkFirstChoiceDeltaT
};
assertion.accept(new OpenAiApi.ChatCompletionMessage(null, null, null, null,
List.of(Mockito.mock(org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall.class)),
null, null, null));
null, null, null, null));
}

@Test
Expand Down Expand Up @@ -191,7 +191,8 @@ public void isStreamingToolFunctionCallReturnsFalseForNullOrEmptyChunks() {
@Test
public void isStreamingToolFunctionCall_returnsTrueForValidToolCalls() {
var toolCall = Mockito.mock(OpenAiApi.ChatCompletionMessage.ToolCall.class);
var delta = new OpenAiApi.ChatCompletionMessage(null, null, null, null, List.of(toolCall), null, null, null);
var delta = new OpenAiApi.ChatCompletionMessage(null, null, null, null, List.of(toolCall), null, null, null,
null);
var choice = new OpenAiApi.ChatCompletionChunk.ChunkChoice(null, null, delta, null);
var chunk = new OpenAiApi.ChatCompletionChunk(null, List.of(choice), null, null, null, null, null, null);

Expand Down Expand Up @@ -255,7 +256,7 @@ public void merge_partialFieldsFromEachChunk() {
public void isStreamingToolFunctionCall_withMultipleChoicesAndOnlyFirstHasToolCalls() {
var toolCall = Mockito.mock(OpenAiApi.ChatCompletionMessage.ToolCall.class);
var deltaWithToolCalls = new OpenAiApi.ChatCompletionMessage(null, null, null, null, List.of(toolCall), null,
null, null);
null, null, null);
var deltaWithoutToolCalls = new OpenAiApi.ChatCompletionMessage(null, null);

var choice1 = new OpenAiApi.ChatCompletionChunk.ChunkChoice(null, null, deltaWithToolCalls, null);
Expand Down Expand Up @@ -327,7 +328,7 @@ public void edgeCases_emptyStringFields() {

@Test
public void isStreamingToolFunctionCall_withNullToolCallsList() {
var delta = new OpenAiApi.ChatCompletionMessage(null, null, null, null, null, null, null, null);
var delta = new OpenAiApi.ChatCompletionMessage(null, null, null, null, null, null, null, null, null);
var choice = new OpenAiApi.ChatCompletionChunk.ChunkChoice(null, null, delta, null);
var chunk = new OpenAiApi.ChatCompletionChunk(null, List.of(choice), null, null, null, null, null, null);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public void toolFunctionCall() {

// extend conversation with function response.
messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL,
functionName, toolCall.id(), null, null, null, null));
functionName, toolCall.id(), null, null, null, null, null));
}
}

Expand Down
Loading