diff --git a/models/spring-ai-openai/pom.xml b/models/spring-ai-openai/pom.xml
index ad238b9a55f..8e0279ff969 100644
--- a/models/spring-ai-openai/pom.xml
+++ b/models/spring-ai-openai/pom.xml
@@ -34,6 +34,12 @@
2.0.4
+
+
+ org.springframework.boot
+ spring-boot
+
+
io.rest-assured
json-path
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java
index 3e18671e0ca..15e24f580ba 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java
@@ -28,14 +28,16 @@
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
-import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException;
+import org.springframework.ai.openai.api.OpenAiChatOptions;
import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata;
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
import org.springframework.http.ResponseEntity;
@@ -57,12 +59,13 @@
*/
public class OpenAiChatClient implements ChatClient, StreamingChatClient {
- private Double temperature = 0.7;
-
- private String model = "gpt-3.5-turbo";
-
private final Logger logger = LoggerFactory.getLogger(getClass());
+ private OpenAiChatOptions defaultOptions = OpenAiChatOptions.builder()
+ .withModel("gpt-3.5-turbo")
+ .withTemperature(0.7f)
+ .build();
+
public final RetryTemplate retryTemplate = RetryTemplate.builder()
.maxAttempts(10)
.retryOn(OpenAiApiException.class)
@@ -76,40 +79,46 @@ public OpenAiChatClient(OpenAiApi openAiApi) {
this.openAiApi = openAiApi;
}
- public String getModel() {
- return this.model;
- }
-
- public void setModel(String model) {
- this.model = model;
- }
-
- public Double getTemperature() {
- return this.temperature;
- }
-
- public void setTemperature(Double temperature) {
- this.temperature = temperature;
+ public OpenAiChatClient withDefaultOptions(OpenAiChatOptions options) {
+ this.defaultOptions = options;
+ return this;
}
@Override
public ChatResponse call(Prompt prompt) {
return this.retryTemplate.execute(ctx -> {
- List messages = prompt.getInstructions();
- List chatCompletionMessages = messages.stream()
- .map(m -> new ChatCompletionMessage(m.getContent(),
- ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
- .toList();
+ ChatCompletionRequest request = createRequest(prompt, false);
+
+ // List messages = prompt.getInstructions();
+
+ // List chatCompletionMessages = messages.stream()
+ // .map(m -> new ChatCompletionMessage(m.getContent(),
+ // ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
+ // .toList();
+
+ // ChatCompletionRequest request =
+ // ChatCompletionRequest.from(chatCompletionMessages, this.defaultOptions,
+ // false);
- ResponseEntity completionEntity = this.openAiApi
- .chatCompletionEntity(new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, this.model,
- this.temperature.floatValue()));
+ // if (prompt.getOptions() != null) {
+ // if (prompt.getOptions() instanceof OpenAiChatOptions runtimeOptions) {
+ // request = ModelOptionsUtils.merge(runtimeOptions, request,
+ // ChatCompletionRequest.class);
+ // }
+ // else {
+ // throw new IllegalArgumentException("Prompt options are not of type
+ // ChatCompletionRequest:"
+ // + prompt.getOptions().getClass().getSimpleName());
+ // }
+ // }
+
+ ResponseEntity completionEntity = this.openAiApi.chatCompletionEntity(request);
var chatCompletion = completionEntity.getBody();
if (chatCompletion == null) {
- logger.warn("No chat completion returned for request: {}", chatCompletionMessages);
+ logger.warn("No chat completion returned for request: {}", prompt);
return new ChatResponse(List.of());
}
@@ -128,16 +137,32 @@ public ChatResponse call(Prompt prompt) {
@Override
public Flux stream(Prompt prompt) {
return this.retryTemplate.execute(ctx -> {
- List messages = prompt.getInstructions();
+ ChatCompletionRequest request = createRequest(prompt, true);
+
+ // List messages = prompt.getInstructions();
+
+ // List chatCompletionMessages = messages.stream()
+ // .map(m -> new ChatCompletionMessage(m.getContent(),
+ // ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
+ // .toList();
- List chatCompletionMessages = messages.stream()
- .map(m -> new ChatCompletionMessage(m.getContent(),
- ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
- .toList();
+ // ChatCompletionRequest request =
+ // ChatCompletionRequest.from(chatCompletionMessages, this.defaultOptions,
+ // true);
- Flux completionChunks = this.openAiApi
- .chatCompletionStream(new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, this.model,
- this.temperature.floatValue(), true));
+ // if (prompt.getOptions() != null) {
+ // if (prompt.getOptions() instanceof OpenAiChatOptions runtimeOptions) {
+ // request = ModelOptionsUtils.merge(runtimeOptions, request,
+ // ChatCompletionRequest.class);
+ // }
+ // else {
+ // throw new IllegalArgumentException("Prompt options are not of type
+ // ChatCompletionRequest:"
+ // + prompt.getOptions().getClass().getSimpleName());
+ // }
+ // }
+
+ Flux completionChunks = this.openAiApi.chatCompletionStream(request);
// For chunked responses, only the first chunk contains the choice role.
// The rest of the chunks with same ID share the same role.
@@ -161,4 +186,30 @@ public Flux stream(Prompt prompt) {
});
}
+ /**
+ * Accessible for testing.
+ */
+ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
+
+ List chatCompletionMessages = prompt.getInstructions()
+ .stream()
+ .map(m -> new ChatCompletionMessage(m.getContent(),
+ ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
+ .toList();
+
+ ChatCompletionRequest request = ChatCompletionRequest.from(chatCompletionMessages, this.defaultOptions, stream);
+
+ if (prompt.getOptions() != null) {
+ if (prompt.getOptions() instanceof OpenAiChatOptions runtimeOptions) {
+ request = ModelOptionsUtils.merge(runtimeOptions, request, ChatCompletionRequest.class);
+ }
+ else {
+				throw new IllegalArgumentException("Prompt options are not of type OpenAiChatOptions: "
+ + prompt.getOptions().getClass().getSimpleName());
+ }
+ }
+
+ return request;
+ }
+
}
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
index 989bd302fea..fec77a0107d 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
@@ -31,6 +31,8 @@
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
+import org.springframework.ai.model.ModelOptionsUtils;
+import org.springframework.boot.context.properties.bind.ConstructorBinding;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
@@ -183,6 +185,7 @@ public record FunctionTool(
* Create a tool of type 'function' and the given function definition.
* @param function function definition.
*/
+ @ConstructorBinding
public FunctionTool(Function function) {
this(Type.FUNCTION, function);
}
@@ -219,146 +222,57 @@ public record Function(
* @param name tool function name.
* @param jsonSchema tool function schema as json.
*/
+ @ConstructorBinding
public Function(String description, String name, String jsonSchema) {
this(description, name, parseJson(jsonSchema));
}
}
}
- /**
- * Creates a model response for the given chat conversation.
- *
- * @param messages A list of messages comprising the conversation so far.
- * @param model ID of the model to use.
- * @param frequencyPenalty Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing
- * frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
- * @param logitBias Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object
- * that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100.
- * Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will
- * vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100
- * or 100 should result in a ban or exclusive selection of the relevant token.
- * @param maxTokens The maximum number of tokens to generate in the chat completion. The total length of input
- * tokens and generated tokens is limited by the model's context length.
- * @param n How many chat completion choices to generate for each input message. Note that you will be charged based
- * on the number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
- * @param presencePenalty Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they
- * appear in the text so far, increasing the model's likelihood to talk about new topics.
- * @param responseFormat An object specifying the format that the model must output. Setting to { "type":
- * "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON.
- * @param seed This feature is in Beta. If specified, our system will make a best effort to sample
- * deterministically, such that repeated requests with the same seed and parameters should return the same result.
- * Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor
- * changes in the backend.
- * @param stop Up to 4 sequences where the API will stop generating further tokens.
- * @param stream If set, partial message deltas will be sent.Tokens will be sent as data-only server-sent events as
- * they become available, with the stream terminated by a data: [DONE] message.
- * @param temperature What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output
- * more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend
- * altering this or top_p but not both.
- * @param topP An alternative to sampling with temperature, called nucleus sampling, where the model considers the
- * results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%
- * probability mass are considered. We generally recommend altering this or temperature but not both.
- * @param tools A list of tools the model may call. Currently, only functions are supported as a tool. Use this to
- * provide a list of functions the model may generate JSON inputs for.
- * @param toolChoice Controls which (if any) function is called by the model. none means the model will not call a
- * function and instead generates a message. auto means the model can pick between generating a message or calling a
- * function. Specifying a particular function via {"type: "function", "function": {"name": "my_function"}} forces
- * the model to call that function. none is the default when no functions are present. auto is the default if
- * functions are present.
- * @param user A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
- *
- */
@JsonInclude(Include.NON_NULL)
- public record ChatCompletionRequest(
- @JsonProperty("messages") List messages,
- @JsonProperty("model") String model,
- @JsonProperty("frequency_penalty") Float frequencyPenalty,
- @JsonProperty("logit_bias") Map logitBias,
- @JsonProperty("max_tokens") Integer maxTokens,
- @JsonProperty("n") Integer n,
- @JsonProperty("presence_penalty") Float presencePenalty,
- @JsonProperty("response_format") ResponseFormat responseFormat,
- @JsonProperty("seed") Integer seed,
- @JsonProperty("stop") String stop,
- @JsonProperty("stream") Boolean stream,
- @JsonProperty("temperature") Float temperature,
- @JsonProperty("top_p") Float topP,
- @JsonProperty("tools") List tools,
- @JsonProperty("tool_choice") ToolChoice toolChoice,
- @JsonProperty("user") String user) {
+ public static class ChatCompletionRequest extends OpenAiChatOptions {
/**
- * Shortcut constructor for a chat completion request with the given messages and model.
- *
- * @param messages A list of messages comprising the conversation so far.
- * @param model ID of the model to use.
- * @param temperature What sampling temperature to use, between 0 and 1.
+ * A list of chat completion messages.
*/
- public ChatCompletionRequest(List messages, String model, Float temperature) {
- this(messages, model, 0.0f, null, null, 1, 0.0f,
- null, null, null, false, temperature, null,
- null, null, null);
- }
+ private @JsonProperty("messages") List messages;
/**
- * Shortcut constructor for a chat completion request with the given messages, model and control for streaming.
- *
- * @param messages A list of messages comprising the conversation so far.
- * @param model ID of the model to use.
- * @param temperature What sampling temperature to use, between 0 and 1.
- * @param stream If set, partial message deltas will be sent.Tokens will be sent as data-only server-sent events
- * as they become available, with the stream terminated by a data: [DONE] message.
+	 * If set, partial message deltas will be sent. Tokens will be sent as data-only server-sent events as
+ * they become available, with the stream terminated by a data: [DONE] message.
*/
- public ChatCompletionRequest(List messages, String model, Float temperature, boolean stream) {
- this(messages, model, 0.0f, null, null, 1, 0.0f,
- null, null, null, stream, temperature, null,
- null, null, null);
+ private @JsonProperty("stream") Boolean stream = false;
+
+
+ public static ChatCompletionRequest from(List messages, OpenAiChatOptions options, boolean stream) {
+ ChatCompletionRequest request = new ChatCompletionRequest();
+ request.setMessages(messages);
+ request.setStream(stream);
+
+			if (options != null) {
+ request = ModelOptionsUtils.merge(request, options, ChatCompletionRequest.class);
+ }
+
+ return request;
}
- /**
- * Shortcut constructor for a chat completion request with the given messages, model, tools and tool choice.
- * Streaming is set to false, temperature to 0.8 and all other parameters are null.
- *
- * @param messages A list of messages comprising the conversation so far.
- * @param model ID of the model to use.
- * @param tools A list of tools the model may call. Currently, only functions are supported as a tool.
- * @param toolChoice Controls which (if any) function is called by the model.
- */
- public ChatCompletionRequest(List messages, String model,
- List tools, ToolChoice toolChoice) {
- this(messages, model, 0.0f, null, null, 1, 0.0f,
- null, null, null, false, 0.8f, null,
- tools, toolChoice, null);
+ public List getMessages() {
+ return messages;
}
- /**
- * Specifies a tool the model should use. Use to force the model to call a specific function.
- *
- * @param type The type of the tool. Currently, only 'function' is supported.
- * @param function single field map for type 'name':'your function name'.
- */
- @JsonInclude(Include.NON_NULL)
- public record ToolChoice(
- @JsonProperty("type") String type,
- @JsonProperty("function") Map function) {
+ public void setMessages(List messages) {
+ this.messages = messages;
+ }
- /**
- * Create a tool choice of type 'function' and name 'functionName'.
- * @param functionName Function name of the tool.
- */
- public ToolChoice(String functionName) {
- this("function", Map.of("name", functionName));
- }
+ public Boolean getStream() {
+ return stream;
}
- /**
- * An object specifying the format that the model must output.
- * @param type Must be one of 'text' or 'json_object'.
- */
- @JsonInclude(Include.NON_NULL)
- public record ResponseFormat(
- @JsonProperty("type") String type) {
+ public void setStream(Boolean stream) {
+ this.stream = stream;
}
+
+
}
/**
@@ -621,7 +535,7 @@ public record ChunkChoice(
public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) {
Assert.notNull(chatRequest, "The request body can not be null.");
- Assert.isTrue(!chatRequest.stream(), "Request must set the steam property to false.");
+		Assert.isTrue(!chatRequest.getStream(), "Request must set the stream property to false.");
return this.restClient.post()
.uri("/v1/chat/completions")
@@ -639,7 +553,7 @@ public ResponseEntity chatCompletionEntity(ChatCompletionRequest
public Flux chatCompletionStream(ChatCompletionRequest chatRequest) {
Assert.notNull(chatRequest, "The request body can not be null.");
- Assert.isTrue(chatRequest.stream(), "Request must set the steam property to true.");
+		Assert.isTrue(chatRequest.getStream(), "Request must set the stream property to true.");
return this.webClient.post()
.uri("/v1/chat/completions")
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiChatOptions.java
new file mode 100644
index 00000000000..77e3a3754a5
--- /dev/null
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiChatOptions.java
@@ -0,0 +1,483 @@
+/*
+ * Copyright 2024-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.openai.api;
+
+import java.util.List;
+import java.util.Map;
+
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonInclude.Include;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+import org.springframework.ai.chat.ChatOptions;
+import org.springframework.ai.openai.api.OpenAiApi.FunctionTool;
+import org.springframework.boot.context.properties.bind.ConstructorBinding;
+
+/**
+ * @author Christian Tzolov
+ * @since 0.8.0
+ */
+@JsonInclude(Include.NON_NULL)
+public class OpenAiChatOptions implements ChatOptions {
+
+ // @formatter:off
+ /**
+ * ID of the model to use.
+ */
+ @JsonProperty("model") String model;
+ /**
+ * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing
+ * frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
+ */
+ @JsonProperty("frequency_penalty") Float frequencyPenalty;
+ /**
+ * Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object
+ * that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100.
+ * Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will
+ * vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100
+ * or 100 should result in a ban or exclusive selection of the relevant token.
+ */
+ @JsonProperty("logit_bias") Map logitBias;
+ /**
+ * The maximum number of tokens to generate in the chat completion. The total length of input
+ * tokens and generated tokens is limited by the model's context length.
+ */
+ @JsonProperty("max_tokens") Integer maxTokens;
+ /**
+ * How many chat completion choices to generate for each input message. Note that you will be charged based
+ * on the number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
+ */
+ @JsonProperty("n") Integer n;
+ /**
+ * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they
+ * appear in the text so far, increasing the model's likelihood to talk about new topics.
+ */
+ @JsonProperty("presence_penalty") Float presencePenalty;
+ /**
+ * An object specifying the format that the model must output. Setting to { "type":
+ * "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON.
+ */
+ @JsonProperty("response_format") ResponseFormat responseFormat;
+ /**
+ * This feature is in Beta. If specified, our system will make a best effort to sample
+ * deterministically, such that repeated requests with the same seed and parameters should return the same result.
+ * Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor
+ * changes in the backend.
+ */
+ @JsonProperty("seed") Integer seed;
+ /**
+ * Up to 4 sequences where the API will stop generating further tokens.
+ */
+ @JsonProperty("stop") List stop;
+ /**
+ * What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output
+ * more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend
+ * altering this or top_p but not both.
+ */
+ @JsonProperty("temperature") Float temperature;
+ /**
+ * An alternative to sampling with temperature, called nucleus sampling, where the model considers the
+ * results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%
+ * probability mass are considered. We generally recommend altering this or temperature but not both.
+ */
+ @JsonProperty("top_p") Float topP;
+ /**
+ * A list of tools the model may call. Currently, only functions are supported as a tool. Use this to
+ * provide a list of functions the model may generate JSON inputs for.
+ */
+ @JsonProperty("tools") List tools;
+ /**
+ * Controls which (if any) function is called by the model. none means the model will not call a
+ * function and instead generates a message. auto means the model can pick between generating a message or calling a
+ * function. Specifying a particular function via {"type: "function", "function": {"name": "my_function"}} forces
+ * the model to call that function. none is the default when no functions are present. auto is the default if
+ * functions are present.
+ */
+ @JsonProperty("tool_choice") ToolChoice toolChoice;
+ /**
+ * A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
+ */
+ @JsonProperty("user") String user;
+ // @formatter:on
+
+ /**
+ * Specifies a tool the model should use. Use to force the model to call a specific
+ * function.
+ *
+ * @param type The type of the tool. Currently, only 'function' is supported.
+ * @param function single field map for type 'name':'your function name'.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record ToolChoice(@JsonProperty("type") String type,
+ @JsonProperty("function") Map function) {
+
+ /**
+ * Create a tool choice of type 'function' and name 'functionName'.
+ * @param functionName Function name of the tool.
+ */
+ @ConstructorBinding
+ public ToolChoice(String functionName) {
+ this("function", Map.of("name", functionName));
+ }
+ }
+
+ /**
+ * An object specifying the format that the model must output.
+ *
+ * @param type Must be one of 'text' or 'json_object'.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record ResponseFormat(@JsonProperty("type") String type) {
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ public static class Builder {
+
+ protected OpenAiChatOptions options;
+
+ public Builder() {
+ this.options = new OpenAiChatOptions();
+ }
+
+ public Builder(OpenAiChatOptions options) {
+ this.options = options;
+ }
+
+ public Builder withModel(String model) {
+ this.options.model = model;
+ return this;
+ }
+
+ public Builder withFrequencyPenalty(Float frequencyPenalty) {
+ options.frequencyPenalty = frequencyPenalty;
+ return this;
+ }
+
+ public Builder withLogitBias(Map logitBias) {
+ options.logitBias = logitBias;
+ return this;
+ }
+
+ public Builder withMaxTokens(Integer maxTokens) {
+ options.maxTokens = maxTokens;
+ return this;
+ }
+
+ public Builder withN(Integer n) {
+ options.n = n;
+ return this;
+ }
+
+ public Builder withPresencePenalty(Float presencePenalty) {
+ options.presencePenalty = presencePenalty;
+ return this;
+ }
+
+ public Builder withResponseFormat(ResponseFormat responseFormat) {
+ options.responseFormat = responseFormat;
+ return this;
+ }
+
+ public Builder withSeed(Integer seed) {
+ options.seed = seed;
+ return this;
+ }
+
+ public Builder withStop(List stop) {
+ options.stop = stop;
+ return this;
+ }
+
+ public Builder withTemperature(Float temperature) {
+ options.temperature = temperature;
+ return this;
+ }
+
+ public Builder withTopP(Float topP) {
+ options.topP = topP;
+ return this;
+ }
+
+ public Builder withTools(List tools) {
+ options.tools = tools;
+ return this;
+ }
+
+ public Builder withToolChoice(ToolChoice toolChoice) {
+ options.toolChoice = toolChoice;
+ return this;
+ }
+
+ public Builder withUser(String user) {
+ options.user = user;
+ return this;
+ }
+
+ public OpenAiChatOptions build() {
+ return options;
+ }
+
+ }
+
+ public String getModel() {
+ return model;
+ }
+
+ public void setModel(String model) {
+ this.model = model;
+ }
+
+ public Float getFrequencyPenalty() {
+ return frequencyPenalty;
+ }
+
+ public void setFrequencyPenalty(Float frequencyPenalty) {
+ this.frequencyPenalty = frequencyPenalty;
+ }
+
+ public Map getLogitBias() {
+ return logitBias;
+ }
+
+ public void setLogitBias(Map logitBias) {
+ this.logitBias = logitBias;
+ }
+
+ public Integer getMaxTokens() {
+ return maxTokens;
+ }
+
+ public void setMaxTokens(Integer maxTokens) {
+ this.maxTokens = maxTokens;
+ }
+
+ public Integer getN() {
+ return n;
+ }
+
+ public void setN(Integer n) {
+ this.n = n;
+ }
+
+ public Float getPresencePenalty() {
+ return presencePenalty;
+ }
+
+ public void setPresencePenalty(Float presencePenalty) {
+ this.presencePenalty = presencePenalty;
+ }
+
+ public ResponseFormat getResponseFormat() {
+ return responseFormat;
+ }
+
+ public void setResponseFormat(ResponseFormat responseFormat) {
+ this.responseFormat = responseFormat;
+ }
+
+ public Integer getSeed() {
+ return seed;
+ }
+
+ public void setSeed(Integer seed) {
+ this.seed = seed;
+ }
+
+ public List getStop() {
+ return stop;
+ }
+
+ public void setStop(List stop) {
+ this.stop = stop;
+ }
+
+ public Float getTemperature() {
+ return temperature;
+ }
+
+ public void setTemperature(Float temperature) {
+ this.temperature = temperature;
+ }
+
+ public Float getTopP() {
+ return topP;
+ }
+
+ public void setTopP(Float topP) {
+ this.topP = topP;
+ }
+
+ public List getTools() {
+ return tools;
+ }
+
+ public void setTools(List tools) {
+ this.tools = tools;
+ }
+
+ public ToolChoice getToolChoice() {
+ return toolChoice;
+ }
+
+ public void setToolChoice(ToolChoice toolChoice) {
+ this.toolChoice = toolChoice;
+ }
+
+ public String getUser() {
+ return user;
+ }
+
+ public void setUser(String user) {
+ this.user = user;
+ }
+
+ @Override
+ public int hashCode() {
+ final int prime = 31;
+ int result = 1;
+ result = prime * result + ((model == null) ? 0 : model.hashCode());
+ result = prime * result + ((frequencyPenalty == null) ? 0 : frequencyPenalty.hashCode());
+ result = prime * result + ((logitBias == null) ? 0 : logitBias.hashCode());
+ result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode());
+ result = prime * result + ((n == null) ? 0 : n.hashCode());
+ result = prime * result + ((presencePenalty == null) ? 0 : presencePenalty.hashCode());
+ result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode());
+ result = prime * result + ((seed == null) ? 0 : seed.hashCode());
+ result = prime * result + ((stop == null) ? 0 : stop.hashCode());
+ result = prime * result + ((temperature == null) ? 0 : temperature.hashCode());
+ result = prime * result + ((topP == null) ? 0 : topP.hashCode());
+ result = prime * result + ((tools == null) ? 0 : tools.hashCode());
+ result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode());
+ result = prime * result + ((user == null) ? 0 : user.hashCode());
+ return result;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj)
+ return true;
+ if (obj == null)
+ return false;
+ if (getClass() != obj.getClass())
+ return false;
+ OpenAiChatOptions other = (OpenAiChatOptions) obj;
+ if (model == null) {
+ if (other.model != null)
+ return false;
+ }
+ else if (!model.equals(other.model))
+ return false;
+ if (frequencyPenalty == null) {
+ if (other.frequencyPenalty != null)
+ return false;
+ }
+ else if (!frequencyPenalty.equals(other.frequencyPenalty))
+ return false;
+ if (logitBias == null) {
+ if (other.logitBias != null)
+ return false;
+ }
+ else if (!logitBias.equals(other.logitBias))
+ return false;
+ if (maxTokens == null) {
+ if (other.maxTokens != null)
+ return false;
+ }
+ else if (!maxTokens.equals(other.maxTokens))
+ return false;
+ if (n == null) {
+ if (other.n != null)
+ return false;
+ }
+ else if (!n.equals(other.n))
+ return false;
+ if (presencePenalty == null) {
+ if (other.presencePenalty != null)
+ return false;
+ }
+ else if (!presencePenalty.equals(other.presencePenalty))
+ return false;
+ if (responseFormat == null) {
+ if (other.responseFormat != null)
+ return false;
+ }
+ else if (!responseFormat.equals(other.responseFormat))
+ return false;
+ if (seed == null) {
+ if (other.seed != null)
+ return false;
+ }
+ else if (!seed.equals(other.seed))
+ return false;
+ if (stop == null) {
+ if (other.stop != null)
+ return false;
+ }
+ else if (!stop.equals(other.stop))
+ return false;
+ if (temperature == null) {
+ if (other.temperature != null)
+ return false;
+ }
+ else if (!temperature.equals(other.temperature))
+ return false;
+ if (topP == null) {
+ if (other.topP != null)
+ return false;
+ }
+ else if (!topP.equals(other.topP))
+ return false;
+ if (tools == null) {
+ if (other.tools != null)
+ return false;
+ }
+ else if (!tools.equals(other.tools))
+ return false;
+ if (toolChoice == null) {
+ if (other.toolChoice != null)
+ return false;
+ }
+ else if (!toolChoice.equals(other.toolChoice))
+ return false;
+ if (user == null) {
+ if (other.user != null)
+ return false;
+ }
+ else if (!user.equals(other.user))
+ return false;
+ return true;
+ }
+
+ @Override
+ @JsonIgnore
+ public Integer getTopK() {
+ // TODO Auto-generated method stub
+ throw new UnsupportedOperationException("Unimplemented method 'getTopK'");
+ }
+
+ @Override
+ @JsonIgnore
+ public void setTopK(Integer topK) {
+ // TODO Auto-generated method stub
+ throw new UnsupportedOperationException("Unimplemented method 'setTopK'");
+ }
+
+}
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestIT.java
new file mode 100644
index 00000000000..4e67182f58b
--- /dev/null
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestIT.java
@@ -0,0 +1,56 @@
+/*
+ * Copyright 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.openai;
+
+import org.junit.jupiter.api.Test;
+
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.openai.api.OpenAiApi;
+import org.springframework.ai.openai.api.OpenAiChatOptions;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * @author Christian Tzolov
+ */
+public class ChatCompletionRequestIT {
+
+ @Test
+ public void chatOptions() {
+
+ var client = new OpenAiChatClient(new OpenAiApi("TEST"))
+ .withDefaultOptions(OpenAiChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6f).build());
+
+ var request = client.createRequest(new Prompt("Test message content"), false);
+
+ assertThat(request.getMessages()).hasSize(1);
+ assertThat(request.getStream()).isFalse();
+
+ assertThat(request.getModel()).isEqualTo("DEFAULT_MODEL");
+ assertThat(request.getTemperature()).isEqualTo(66.6f);
+
+ request = client.createRequest(new Prompt("Test message content",
+ OpenAiChatOptions.builder().withModel("PROMPT_MODEL").withTemperature(99.9f).build()), true);
+
+ assertThat(request.getMessages()).hasSize(1);
+ assertThat(request.getStream()).isTrue();
+
+ assertThat(request.getModel()).isEqualTo("PROMPT_MODEL");
+ assertThat(request.getTemperature()).isEqualTo(99.9f);
+ }
+
+}
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java
index b14b0ba9391..d43a25dde1d 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java
@@ -32,7 +32,6 @@ private String getApiKey() {
@Bean
public OpenAiChatClient openAiChatClient(OpenAiApi api) {
OpenAiChatClient openAiChatClient = new OpenAiChatClient(api);
- openAiChatClient.setTemperature(0.3);
return openAiChatClient;
}
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java
index 120b32a7be2..4f655320615 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java
@@ -166,7 +166,6 @@ public OpenAiApi openAiApi() throws IOException {
@Bean
public OpenAiChatClient openAiChatClient(OpenAiApi openAiApi) {
OpenAiChatClient openAiChatClient = new OpenAiChatClient(openAiApi);
- openAiChatClient.setTemperature(0.3);
return openAiChatClient;
}
diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java
index 85c47554ef7..717abbb4001 100644
--- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java
+++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java
@@ -24,7 +24,12 @@
import java.util.Map;
import java.util.stream.Collectors;
-public abstract class ModelOptionsUtils {
+/**
+ * Utility class for manipulating {@link ModelOptions} objects.
+ *
+ * @author Christian Tzolov
+ */
+public final class ModelOptionsUtils {
private final static ObjectMapper OBJECT_MAPPER = new ObjectMapper();
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java
index 669856a81bc..1e7edfba235 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java
@@ -54,9 +54,8 @@ public OpenAiChatClient openAiChatClient(OpenAiConnectionProperties commonProper
var openAiApi = new OpenAiApi(baseUrl, apiKey, RestClient.builder());
- OpenAiChatClient openAiChatClient = new OpenAiChatClient(openAiApi);
- openAiChatClient.setTemperature(chatProperties.getTemperature());
- openAiChatClient.setModel(chatProperties.getModel());
+ OpenAiChatClient openAiChatClient = new OpenAiChatClient(openAiApi)
+ .withDefaultOptions(chatProperties.getOptions());
return openAiChatClient;
}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java
index 9f3f387f153..b6e8608f104 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java
@@ -16,7 +16,9 @@
package org.springframework.ai.autoconfigure.openai;
+import org.springframework.ai.openai.api.OpenAiChatOptions;
import org.springframework.boot.context.properties.ConfigurationProperties;
+import org.springframework.boot.context.properties.NestedConfigurationProperty;
@ConfigurationProperties(OpenAiChatProperties.CONFIG_PREFIX)
public class OpenAiChatProperties extends OpenAiParentProperties {
@@ -25,24 +27,20 @@ public class OpenAiChatProperties extends OpenAiParentProperties {
public static final String DEFAULT_CHAT_MODEL = "gpt-3.5-turbo";
- private Double temperature = 0.7;
+ private static final Double DEFAULT_TEMPERATURE = 0.7;
- private String model = DEFAULT_CHAT_MODEL;
+ @NestedConfigurationProperty
+ private OpenAiChatOptions options = OpenAiChatOptions.builder()
+ .withModel(DEFAULT_CHAT_MODEL)
+ .withTemperature(DEFAULT_TEMPERATURE.floatValue())
+ .build();
- public String getModel() {
- return model;
+ public OpenAiChatOptions getOptions() {
+ return options;
}
- public void setModel(String model) {
- this.model = model;
- }
-
- public Double getTemperature() {
- return temperature;
- }
-
- public void setTemperature(Double temperature) {
- this.temperature = temperature;
+ public void setOptions(OpenAiChatOptions options) {
+ this.options = options;
}
}
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java
index bf8fc8208f4..0c5a51c2d47 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java
@@ -16,8 +16,13 @@
package org.springframework.ai.autoconfigure.openai;
+import java.util.Map;
+
import org.junit.jupiter.api.Test;
+import org.springframework.ai.openai.api.OpenAiApi.FunctionTool.Type;
+import org.springframework.ai.openai.api.OpenAiChatOptions.ResponseFormat;
+import org.springframework.ai.openai.api.OpenAiChatOptions.ToolChoice;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
@@ -39,8 +44,8 @@ public void chatProperties() {
// @formatter:off
"spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.api-key=abc123",
- "spring.ai.openai.chat.model=MODEL_XYZ",
- "spring.ai.openai.chat.temperature=0.55")
+ "spring.ai.openai.chat.options.model=MODEL_XYZ",
+ "spring.ai.openai.chat.options.temperature=0.55")
// @formatter:on
.withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class))
.run(context -> {
@@ -53,8 +58,8 @@ public void chatProperties() {
assertThat(chatProperties.getApiKey()).isNull();
assertThat(chatProperties.getBaseUrl()).isNull();
- assertThat(chatProperties.getModel()).isEqualTo("MODEL_XYZ");
- assertThat(chatProperties.getTemperature()).isEqualTo(0.55);
+ assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
+ assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f);
});
}
@@ -67,8 +72,8 @@ public void chatOverrideConnectionProperties() {
"spring.ai.openai.api-key=abc123",
"spring.ai.openai.chat.base-url=TEST_BASE_URL2",
"spring.ai.openai.chat.api-key=456",
- "spring.ai.openai.chat.model=MODEL_XYZ",
- "spring.ai.openai.chat.temperature=0.55")
+ "spring.ai.openai.chat.options.model=MODEL_XYZ",
+ "spring.ai.openai.chat.options.temperature=0.55")
// @formatter:on
.withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class))
.run(context -> {
@@ -81,8 +86,8 @@ public void chatOverrideConnectionProperties() {
assertThat(chatProperties.getApiKey()).isEqualTo("456");
assertThat(chatProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2");
- assertThat(chatProperties.getModel()).isEqualTo("MODEL_XYZ");
- assertThat(chatProperties.getTemperature()).isEqualTo(0.55);
+ assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
+ assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f);
});
}
@@ -136,4 +141,92 @@ public void embeddingOverrideConnectionProperties() {
});
}
+ @Test
+ public void optionsTest() {
+
+ new ApplicationContextRunner().withPropertyValues(
+ // @formatter:off
+ "spring.ai.openai.api-key=API_KEY",
+ "spring.ai.openai.base-url=TEST_BASE_URL",
+
+ "spring.ai.openai.chat.options.model=MODEL_XYZ",
+ "spring.ai.openai.chat.options.frequencyPenalty=-1.5",
+ "spring.ai.openai.chat.options.logitBias.myTokenId=-5",
+ "spring.ai.openai.chat.options.maxTokens=123",
+ "spring.ai.openai.chat.options.n=10",
+ "spring.ai.openai.chat.options.presencePenalty=0",
+ "spring.ai.openai.chat.options.responseFormat.type=json",
+ "spring.ai.openai.chat.options.seed=66",
+ "spring.ai.openai.chat.options.stop=boza,koza",
+ "spring.ai.openai.chat.options.temperature=0.55",
+ "spring.ai.openai.chat.options.topP=0.56",
+
+ "spring.ai.openai.chat.options.toolChoice.functionName=toolChoiceFunctionName",
+
+ "spring.ai.openai.chat.options.tools[0].function.name=myFunction1",
+ "spring.ai.openai.chat.options.tools[0].function.description=function description",
+ "spring.ai.openai.chat.options.tools[0].function.jsonSchema=" + """
+ {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The city and state e.g. San Francisco, CA"
+ },
+ "lat": {
+ "type": "number",
+ "description": "The city latitude"
+ },
+ "lon": {
+ "type": "number",
+ "description": "The city longitude"
+ },
+ "unit": {
+ "type": "string",
+ "enum": ["c", "f"]
+ }
+ },
+ "required": ["location", "lat", "lon", "unit"]
+ }
+ """,
+ "spring.ai.openai.chat.options.user=userXYZ"
+ )
+ // @formatter:on
+ .withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class))
+ .run(context -> {
+ var chatProperties = context.getBean(OpenAiChatProperties.class);
+ var connectionProperties = context.getBean(OpenAiConnectionProperties.class);
+ var embeddingProperties = context.getBean(OpenAiEmbeddingProperties.class);
+
+ assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
+ assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY");
+
+ assertThat(embeddingProperties.getModel()).isEqualTo("text-embedding-ada-002");
+
+ assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
+ assertThat(chatProperties.getOptions().getFrequencyPenalty()).isEqualTo(-1.5f);
+ assertThat(chatProperties.getOptions().getLogitBias().get("myTokenId")).isEqualTo(-5);
+ assertThat(chatProperties.getOptions().getMaxTokens()).isEqualTo(123);
+ assertThat(chatProperties.getOptions().getN()).isEqualTo(10);
+ assertThat(chatProperties.getOptions().getPresencePenalty()).isEqualTo(0);
+ assertThat(chatProperties.getOptions().getResponseFormat()).isEqualTo(new ResponseFormat("json"));
+ assertThat(chatProperties.getOptions().getSeed()).isEqualTo(66);
+ assertThat(chatProperties.getOptions().getStop()).contains("boza", "koza");
+ assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f);
+ assertThat(chatProperties.getOptions().getTopP()).isEqualTo(0.56f);
+
+ assertThat(chatProperties.getOptions().getToolChoice())
+ .isEqualTo(new ToolChoice("function", Map.of("name", "toolChoiceFunctionName")));
+ assertThat(chatProperties.getOptions().getUser()).isEqualTo("userXYZ");
+
+ assertThat(chatProperties.getOptions().getTools()).hasSize(1);
+ var tool = chatProperties.getOptions().getTools().get(0);
+ assertThat(tool.type()).isEqualTo(Type.FUNCTION);
+ var function = tool.function();
+ assertThat(function.name()).isEqualTo("myFunction1");
+ assertThat(function.description()).isEqualTo("function description");
+ assertThat(function.parameters()).isNotEmpty();
+ });
+ }
+
}