@@ -182,6 +182,15 @@ public class OpenAiChatOptions implements FunctionCallingOptions {
* Developer-defined tags and values used for filtering completions in the <a href="https://platform.openai.com/chat-completions">dashboard</a>.
*/
private @JsonProperty("metadata") Map<String, String> metadata;

/**
* Constrains effort on reasoning for reasoning models. Currently supported values are low, medium, and high.
* Reducing reasoning effort can result in faster responses and fewer tokens used on reasoning in a response.
* Optional. Defaults to medium.
* Only for 'o1' models.
*/
private @JsonProperty("reasoning_effort") String reasoningEffort;

/**
* OpenAI Tool Function Callbacks to register with the ChatModel.
* For Prompt Options the functionCallbacks are automatically enabled for the duration of the prompt execution.
@@ -256,6 +265,7 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) {
.toolContext(fromOptions.getToolContext())
.store(fromOptions.getStore())
.metadata(fromOptions.getMetadata())
.reasoningEffort(fromOptions.getReasoningEffort())
.build();
}

@@ -520,6 +530,14 @@ public void setMetadata(Map<String, String> metadata) {
this.metadata = metadata;
}

public String getReasoningEffort() {
return this.reasoningEffort;
}

public void setReasoningEffort(String reasoningEffort) {
this.reasoningEffort = reasoningEffort;
}

@Override
public OpenAiChatOptions copy() {
return OpenAiChatOptions.fromOptions(this);
@@ -532,7 +550,7 @@ public int hashCode() {
this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice,
this.user, this.parallelToolCalls, this.functionCallbacks, this.functions, this.httpHeaders,
this.proxyToolCalls, this.toolContext, this.outputModalities, this.outputAudio, this.store,
this.metadata);
this.metadata, this.reasoningEffort);
}

@Override
@@ -563,7 +581,8 @@ public boolean equals(Object o) {
&& Objects.equals(this.proxyToolCalls, other.proxyToolCalls)
&& Objects.equals(this.outputModalities, other.outputModalities)
&& Objects.equals(this.outputAudio, other.outputAudio) && Objects.equals(this.store, other.store)
&& Objects.equals(this.metadata, other.metadata);
&& Objects.equals(this.metadata, other.metadata)
&& Objects.equals(this.reasoningEffort, other.reasoningEffort);
}

@Override
@@ -740,6 +759,11 @@ public Builder metadata(Map<String, String> metadata) {
return this;
}

public Builder reasoningEffort(String reasoningEffort) {
this.options.reasoningEffort = reasoningEffort;
return this;
}

public OpenAiChatOptions build() {
return this.options;
}
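For application code, the new option is set through the builder like any other OpenAiChatOptions property. A minimal usage sketch, not part of this diff; the model() builder method is assumed from the existing OpenAiChatOptions API:

import org.springframework.ai.openai.OpenAiChatOptions;

class ReasoningEffortOptionsSketch {

    // Request low reasoning effort for an o1 model via the new builder method.
    static OpenAiChatOptions lowEffort() {
        return OpenAiChatOptions.builder()
                .model("o1")              // assumed existing builder method
                .reasoningEffort("low")   // supported values: low, medium, high
                .build();
    }
}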
@@ -58,6 +58,7 @@
* @author Mariusz Bernacki
* @author Thomas Vitale
* @author David Frizelle
* @author Alexandros Pappas
*/
public class OpenAiApi {

@@ -826,7 +827,8 @@ public record ChatCompletionRequest(// @formatter:off
@JsonProperty("tools") List<FunctionTool> tools,
@JsonProperty("tool_choice") Object toolChoice,
@JsonProperty("parallel_tool_calls") Boolean parallelToolCalls,
@JsonProperty("user") String user) {
@JsonProperty("user") String user,
@JsonProperty("reasoning_effort") String reasoningEffort) {

/**
* Shortcut constructor for a chat completion request with the given messages, model and temperature.
@@ -838,7 +840,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature) {
this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null,
null, null, null, false, null, temperature, null,
null, null, null, null);
null, null, null, null, null);
}

/**
@@ -852,7 +854,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
this(messages, model, null, null, null, null, null, null,
null, null, null, List.of(OutputModality.AUDIO, OutputModality.TEXT), audio, null, null,
null, null, null, stream, null, null, null,
null, null, null, null);
null, null, null, null, null);
}

/**
@@ -867,7 +869,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature, boolean stream) {
this(messages, model, null, null, null, null, null, null, null, null, null,
null, null, null, null, null, null, null, stream, null, temperature, null,
null, null, null, null);
null, null, null, null, null);
}

/**
@@ -883,7 +885,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
List<FunctionTool> tools, Object toolChoice) {
this(messages, model, null, null, null, null, null, null, null, null, null,
null, null, null, null, null, null, null, false, null, 0.8, null,
tools, toolChoice, null, null);
tools, toolChoice, null, null, null);
}

/**
@@ -896,7 +898,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
this(messages, null, null, null, null, null, null, null, null, null, null,
null, null, null, null, null, null, null, stream, null, null, null,
null, null, null, null);
null, null, null, null, null);
}

/**
@@ -909,7 +911,7 @@ public ChatCompletionRequest streamOptions(StreamOptions streamOptions) {
return new ChatCompletionRequest(this.messages, this.model, this.store, this.metadata, this.frequencyPenalty, this.logitBias, this.logprobs,
this.topLogprobs, this.maxTokens, this.maxCompletionTokens, this.n, this.outputModalities, this.audioParameters, this.presencePenalty,
this.responseFormat, this.seed, this.serviceTier, this.stop, this.stream, streamOptions, this.temperature, this.topP,
this.tools, this.toolChoice, this.parallelToolCalls, this.user);
this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort);
}

/**
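Since the new record component is annotated with @JsonProperty("reasoning_effort"), it is serialized straight into the chat completions request body. A quick serialization sketch, for illustration only; it assumes the request record keeps its usual non-null-only JSON serialization, so the property is omitted when unset:

import java.util.List;

import com.fasterxml.jackson.databind.ObjectMapper;

import org.springframework.ai.openai.api.OpenAiApi;

class ReasoningEffortSerializationSketch {

    public static void main(String[] args) throws Exception {
        var userMessage = new OpenAiApi.ChatCompletionMessage(
                "If a train travels 100 miles in 2 hours, what is its average speed?",
                OpenAiApi.ChatCompletionMessage.Role.USER);

        // Shortcut constructor: reasoningEffort stays null, so no reasoning_effort key
        // appears in the payload (assuming the record's non-null serialization).
        var request = new OpenAiApi.ChatCompletionRequest(List.of(userMessage), "o1", 1.0);
        System.out.println(new ObjectMapper().writeValueAsString(request));

        // Passing the value through the canonical constructor, as the new integration
        // test below does, adds "reasoning_effort":"low" to the serialized body.
    }
}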
@@ -40,6 +40,7 @@
/**
* @author Christian Tzolov
* @author Thomas Vitale
* @author Alexandros Pappas
*/
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
public class OpenAiApiIT {
@@ -66,6 +67,25 @@ void chatCompletionStream() {
assertThat(response.collectList().block()).isNotNull();
}

@Test
void validateReasoningTokens() {
ChatCompletionMessage userMessage = new ChatCompletionMessage(
"If a train travels 100 miles in 2 hours, what is its average speed?", ChatCompletionMessage.Role.USER);
ChatCompletionRequest request = new ChatCompletionRequest(List.of(userMessage), "o1", null, null, null, null,
null, null, null, null, null, null, null, null, null, null, null, null, false, null, null, null, null,
null, null, null, "low");
ResponseEntity<ChatCompletion> response = this.openAiApi.chatCompletionEntity(request);

assertThat(response).isNotNull();
assertThat(response.getBody()).isNotNull();

OpenAiApi.Usage.CompletionTokenDetails completionTokenDetails = response.getBody()
.usage()
.completionTokenDetails();
assertThat(completionTokenDetails).isNotNull();
assertThat(completionTokenDetails.reasoningTokens()).isPositive();
}

@Test
void embeddings() {
ResponseEntity<EmbeddingList<Embedding>> response = this.openAiApi