diff --git a/models/spring-ai-openai/pom.xml b/models/spring-ai-openai/pom.xml index ad238b9a55f..8e0279ff969 100644 --- a/models/spring-ai-openai/pom.xml +++ b/models/spring-ai-openai/pom.xml @@ -34,6 +34,12 @@ 2.0.4 + + + org.springframework.boot + spring-boot + + io.rest-assured json-path diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java index 3e18671e0ca..15e24f580ba 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java @@ -28,14 +28,16 @@ import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; import org.springframework.ai.chat.StreamingChatClient; -import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.RateLimit; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException; +import org.springframework.ai.openai.api.OpenAiChatOptions; import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata; import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor; import org.springframework.http.ResponseEntity; @@ -57,12 +59,13 @@ */ public class OpenAiChatClient implements ChatClient, StreamingChatClient { - private Double temperature = 0.7; - - private String model = "gpt-3.5-turbo"; - private final Logger logger = 
LoggerFactory.getLogger(getClass()); + private OpenAiChatOptions defaultOptions = OpenAiChatOptions.builder() + .withModel("gpt-3.5-turbo") + .withTemperature(0.7f) + .build(); + public final RetryTemplate retryTemplate = RetryTemplate.builder() .maxAttempts(10) .retryOn(OpenAiApiException.class) @@ -76,40 +79,46 @@ public OpenAiChatClient(OpenAiApi openAiApi) { this.openAiApi = openAiApi; } - public String getModel() { - return this.model; - } - - public void setModel(String model) { - this.model = model; - } - - public Double getTemperature() { - return this.temperature; - } - - public void setTemperature(Double temperature) { - this.temperature = temperature; + public OpenAiChatClient withDefaultOptions(OpenAiChatOptions options) { + this.defaultOptions = options; + return this; } @Override public ChatResponse call(Prompt prompt) { return this.retryTemplate.execute(ctx -> { - List messages = prompt.getInstructions(); - List chatCompletionMessages = messages.stream() - .map(m -> new ChatCompletionMessage(m.getContent(), - ChatCompletionMessage.Role.valueOf(m.getMessageType().name()))) - .toList(); + ChatCompletionRequest request = createRequest(prompt, false); + + // List messages = prompt.getInstructions(); + + // List chatCompletionMessages = messages.stream() + // .map(m -> new ChatCompletionMessage(m.getContent(), + // ChatCompletionMessage.Role.valueOf(m.getMessageType().name()))) + // .toList(); + + // ChatCompletionRequest request = + // ChatCompletionRequest.from(chatCompletionMessages, this.defaultOptions, + // false); - ResponseEntity completionEntity = this.openAiApi - .chatCompletionEntity(new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, this.model, - this.temperature.floatValue())); + // if (prompt.getOptions() != null) { + // if (prompt.getOptions() instanceof OpenAiChatOptions runtimeOptions) { + // request = ModelOptionsUtils.merge(runtimeOptions, request, + // ChatCompletionRequest.class); + // } + // else { + // throw new 
IllegalArgumentException("Prompt options are not of type + // ChatCompletionRequest:" + // + prompt.getOptions().getClass().getSimpleName()); + // } + // } + + ResponseEntity completionEntity = this.openAiApi.chatCompletionEntity(request); var chatCompletion = completionEntity.getBody(); if (chatCompletion == null) { - logger.warn("No chat completion returned for request: {}", chatCompletionMessages); + logger.warn("No chat completion returned for request: {}", prompt); return new ChatResponse(List.of()); } @@ -128,16 +137,32 @@ public ChatResponse call(Prompt prompt) { @Override public Flux stream(Prompt prompt) { return this.retryTemplate.execute(ctx -> { - List messages = prompt.getInstructions(); + ChatCompletionRequest request = createRequest(prompt, true); + + // List messages = prompt.getInstructions(); + + // List chatCompletionMessages = messages.stream() + // .map(m -> new ChatCompletionMessage(m.getContent(), + // ChatCompletionMessage.Role.valueOf(m.getMessageType().name()))) + // .toList(); - List chatCompletionMessages = messages.stream() - .map(m -> new ChatCompletionMessage(m.getContent(), - ChatCompletionMessage.Role.valueOf(m.getMessageType().name()))) - .toList(); + // ChatCompletionRequest request = + // ChatCompletionRequest.from(chatCompletionMessages, this.defaultOptions, + // true); - Flux completionChunks = this.openAiApi - .chatCompletionStream(new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, this.model, - this.temperature.floatValue(), true)); + // if (prompt.getOptions() != null) { + // if (prompt.getOptions() instanceof OpenAiChatOptions runtimeOptions) { + // request = ModelOptionsUtils.merge(runtimeOptions, request, + // ChatCompletionRequest.class); + // } + // else { + // throw new IllegalArgumentException("Prompt options are not of type + // ChatCompletionRequest:" + // + prompt.getOptions().getClass().getSimpleName()); + // } + // } + + Flux completionChunks = this.openAiApi.chatCompletionStream(request); // For chunked 
responses, only the first chunk contains the choice role. // The rest of the chunks with same ID share the same role. @@ -161,4 +186,30 @@ public Flux stream(Prompt prompt) { }); } + /** + * Accessible for testing. + */ + ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { + + List chatCompletionMessages = prompt.getInstructions() + .stream() + .map(m -> new ChatCompletionMessage(m.getContent(), + ChatCompletionMessage.Role.valueOf(m.getMessageType().name()))) + .toList(); + + ChatCompletionRequest request = ChatCompletionRequest.from(chatCompletionMessages, this.defaultOptions, stream); + + if (prompt.getOptions() != null) { + if (prompt.getOptions() instanceof OpenAiChatOptions runtimeOptions) { + request = ModelOptionsUtils.merge(runtimeOptions, request, ChatCompletionRequest.class); + } + else { + throw new IllegalArgumentException("Prompt options are not of type ChatCompletionRequest:" + + prompt.getOptions().getClass().getSimpleName()); + } + } + + return request; + } + } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 989bd302fea..fec77a0107d 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -31,6 +31,8 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.boot.context.properties.bind.ConstructorBinding; import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; @@ -183,6 +185,7 @@ public record FunctionTool( * Create a tool of type 'function' and the given function definition. * @param function function definition. 
*/ + @ConstructorBinding public FunctionTool(Function function) { this(Type.FUNCTION, function); } @@ -219,146 +222,57 @@ public record Function( * @param name tool function name. * @param jsonSchema tool function schema as json. */ + @ConstructorBinding public Function(String description, String name, String jsonSchema) { this(description, name, parseJson(jsonSchema)); } } } - /** - * Creates a model response for the given chat conversation. - * - * @param messages A list of messages comprising the conversation so far. - * @param model ID of the model to use. - * @param frequencyPenalty Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing - * frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. - * @param logitBias Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object - * that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. - * Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will - * vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 - * or 100 should result in a ban or exclusive selection of the relevant token. - * @param maxTokens The maximum number of tokens to generate in the chat completion. The total length of input - * tokens and generated tokens is limited by the model's context length. - * @param n How many chat completion choices to generate for each input message. Note that you will be charged based - * on the number of generated tokens across all of the choices. Keep n as 1 to minimize costs. - * @param presencePenalty Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they - * appear in the text so far, increasing the model's likelihood to talk about new topics. 
- * @param responseFormat An object specifying the format that the model must output. Setting to { "type": - * "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. - * @param seed This feature is in Beta. If specified, our system will make a best effort to sample - * deterministically, such that repeated requests with the same seed and parameters should return the same result. - * Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor - * changes in the backend. - * @param stop Up to 4 sequences where the API will stop generating further tokens. - * @param stream If set, partial message deltas will be sent.Tokens will be sent as data-only server-sent events as - * they become available, with the stream terminated by a data: [DONE] message. - * @param temperature What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output - * more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend - * altering this or top_p but not both. - * @param topP An alternative to sampling with temperature, called nucleus sampling, where the model considers the - * results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% - * probability mass are considered. We generally recommend altering this or temperature but not both. - * @param tools A list of tools the model may call. Currently, only functions are supported as a tool. Use this to - * provide a list of functions the model may generate JSON inputs for. - * @param toolChoice Controls which (if any) function is called by the model. none means the model will not call a - * function and instead generates a message. auto means the model can pick between generating a message or calling a - * function. 
Specifying a particular function via {"type: "function", "function": {"name": "my_function"}} forces - * the model to call that function. none is the default when no functions are present. auto is the default if - * functions are present. - * @param user A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. - * - */ @JsonInclude(Include.NON_NULL) - public record ChatCompletionRequest( - @JsonProperty("messages") List messages, - @JsonProperty("model") String model, - @JsonProperty("frequency_penalty") Float frequencyPenalty, - @JsonProperty("logit_bias") Map logitBias, - @JsonProperty("max_tokens") Integer maxTokens, - @JsonProperty("n") Integer n, - @JsonProperty("presence_penalty") Float presencePenalty, - @JsonProperty("response_format") ResponseFormat responseFormat, - @JsonProperty("seed") Integer seed, - @JsonProperty("stop") String stop, - @JsonProperty("stream") Boolean stream, - @JsonProperty("temperature") Float temperature, - @JsonProperty("top_p") Float topP, - @JsonProperty("tools") List tools, - @JsonProperty("tool_choice") ToolChoice toolChoice, - @JsonProperty("user") String user) { + public static class ChatCompletionRequest extends OpenAiChatOptions { /** - * Shortcut constructor for a chat completion request with the given messages and model. - * - * @param messages A list of messages comprising the conversation so far. - * @param model ID of the model to use. - * @param temperature What sampling temperature to use, between 0 and 1. + * A list of chat completion messages. */ - public ChatCompletionRequest(List messages, String model, Float temperature) { - this(messages, model, 0.0f, null, null, 1, 0.0f, - null, null, null, false, temperature, null, - null, null, null); - } + private @JsonProperty("messages") List messages; /** - * Shortcut constructor for a chat completion request with the given messages, model and control for streaming. 
- * - * @param messages A list of messages comprising the conversation so far. - * @param model ID of the model to use. - * @param temperature What sampling temperature to use, between 0 and 1. - * @param stream If set, partial message deltas will be sent.Tokens will be sent as data-only server-sent events - * as they become available, with the stream terminated by a data: [DONE] message. + * If set, partial message deltas will be sent. Tokens will be sent as data-only server-sent events as + * they become available, with the stream terminated by a data: [DONE] message. */ - public ChatCompletionRequest(List messages, String model, Float temperature, boolean stream) { - this(messages, model, 0.0f, null, null, 1, 0.0f, - null, null, null, stream, temperature, null, - null, null, null); + private @JsonProperty("stream") Boolean stream = false; + + + public static ChatCompletionRequest from(List messages, OpenAiChatOptions options, boolean stream) { + ChatCompletionRequest request = new ChatCompletionRequest(); + request.setMessages(messages); + request.setStream(stream); + + if ( options != null) { + request = ModelOptionsUtils.merge(request, options, ChatCompletionRequest.class); + } + + return request; } - /** - * Shortcut constructor for a chat completion request with the given messages, model, tools and tool choice. - * Streaming is set to false, temperature to 0.8 and all other parameters are null. - * - * @param messages A list of messages comprising the conversation so far. - * @param model ID of the model to use. - * @param tools A list of tools the model may call. Currently, only functions are supported as a tool. - * @param toolChoice Controls which (if any) function is called by the model. 
- */ - public ChatCompletionRequest(List messages, String model, - List tools, ToolChoice toolChoice) { - this(messages, model, 0.0f, null, null, 1, 0.0f, - null, null, null, false, 0.8f, null, - tools, toolChoice, null); + public List getMessages() { + return messages; } - /** - * Specifies a tool the model should use. Use to force the model to call a specific function. - * - * @param type The type of the tool. Currently, only 'function' is supported. - * @param function single field map for type 'name':'your function name'. - */ - @JsonInclude(Include.NON_NULL) - public record ToolChoice( - @JsonProperty("type") String type, - @JsonProperty("function") Map function) { + public void setMessages(List messages) { + this.messages = messages; + } - /** - * Create a tool choice of type 'function' and name 'functionName'. - * @param functionName Function name of the tool. - */ - public ToolChoice(String functionName) { - this("function", Map.of("name", functionName)); - } + public Boolean getStream() { + return stream; } - /** - * An object specifying the format that the model must output. - * @param type Must be one of 'text' or 'json_object'. 
- */ - @JsonInclude(Include.NON_NULL) - public record ResponseFormat( - @JsonProperty("type") String type) { + public void setStream(Boolean stream) { + this.stream = stream; } + + } /** @@ -621,7 +535,7 @@ public record ChunkChoice( public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(!chatRequest.stream(), "Request must set the steam property to false."); + Assert.isTrue(!chatRequest.getStream(), "Request must set the steam property to false."); return this.restClient.post() .uri("/v1/chat/completions") @@ -639,7 +553,7 @@ public ResponseEntity chatCompletionEntity(ChatCompletionRequest public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(chatRequest.stream(), "Request must set the steam property to true."); + Assert.isTrue(chatRequest.getStream(), "Request must set the steam property to true."); return this.webClient.post() .uri("/v1/chat/completions") diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiChatOptions.java new file mode 100644 index 00000000000..77e3a3754a5 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiChatOptions.java @@ -0,0 +1,483 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.api; + +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.ai.chat.ChatOptions; +import org.springframework.ai.openai.api.OpenAiApi.FunctionTool; +import org.springframework.boot.context.properties.bind.ConstructorBinding; + +/** + * @author Christian Tzolov + * @since 0.8.0 + */ +@JsonInclude(Include.NON_NULL) +public class OpenAiChatOptions implements ChatOptions { + + // @formatter:off + /** + * ID of the model to use. + */ + @JsonProperty("model") String model; + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing + * frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + */ + @JsonProperty("frequency_penalty") Float frequencyPenalty; + /** + * Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object + * that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. + * Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will + * vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 + * or 100 should result in a ban or exclusive selection of the relevant token. + */ + @JsonProperty("logit_bias") Map logitBias; + /** + * The maximum number of tokens to generate in the chat completion. The total length of input + * tokens and generated tokens is limited by the model's context length. 
+ */ + @JsonProperty("max_tokens") Integer maxTokens; + /** + * How many chat completion choices to generate for each input message. Note that you will be charged based + * on the number of generated tokens across all of the choices. Keep n as 1 to minimize costs. + */ + @JsonProperty("n") Integer n; + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they + * appear in the text so far, increasing the model's likelihood to talk about new topics. + */ + @JsonProperty("presence_penalty") Float presencePenalty; + /** + * An object specifying the format that the model must output. Setting to { "type": + * "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. + */ + @JsonProperty("response_format") ResponseFormat responseFormat; + /** + * This feature is in Beta. If specified, our system will make a best effort to sample + * deterministically, such that repeated requests with the same seed and parameters should return the same result. + * Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor + * changes in the backend. + */ + @JsonProperty("seed") Integer seed; + /** + * Up to 4 sequences where the API will stop generating further tokens. + */ + @JsonProperty("stop") List stop; + /** + * What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output + * more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend + * altering this or top_p but not both. + */ + @JsonProperty("temperature") Float temperature; + /** + * An alternative to sampling with temperature, called nucleus sampling, where the model considers the + * results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + * probability mass are considered. We generally recommend altering this or temperature but not both. 
+ */ + @JsonProperty("top_p") Float topP; + /** + * A list of tools the model may call. Currently, only functions are supported as a tool. Use this to + * provide a list of functions the model may generate JSON inputs for. + */ + @JsonProperty("tools") List tools; + /** + * Controls which (if any) function is called by the model. none means the model will not call a + * function and instead generates a message. auto means the model can pick between generating a message or calling a + * function. Specifying a particular function via {"type": "function", "function": {"name": "my_function"}} forces + * the model to call that function. none is the default when no functions are present. auto is the default if + * functions are present. + */ + @JsonProperty("tool_choice") ToolChoice toolChoice; + /** + * A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. + */ + @JsonProperty("user") String user; + // @formatter:on + + /** + * Specifies a tool the model should use. Use to force the model to call a specific + * function. + * + * @param type The type of the tool. Currently, only 'function' is supported. + * @param function single field map for type 'name':'your function name'. + */ + @JsonInclude(Include.NON_NULL) + public record ToolChoice(@JsonProperty("type") String type, + @JsonProperty("function") Map function) { + + /** + * Create a tool choice of type 'function' and name 'functionName'. + * @param functionName Function name of the tool. + */ + @ConstructorBinding + public ToolChoice(String functionName) { + this("function", Map.of("name", functionName)); + } + } + + /** + * An object specifying the format that the model must output. + * + * @param type Must be one of 'text' or 'json_object'. 
+ */ + @JsonInclude(Include.NON_NULL) + public record ResponseFormat(@JsonProperty("type") String type) { + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + protected OpenAiChatOptions options; + + public Builder() { + this.options = new OpenAiChatOptions(); + } + + public Builder(OpenAiChatOptions options) { + this.options = options; + } + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withFrequencyPenalty(Float frequencyPenalty) { + options.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder withLogitBias(Map logitBias) { + options.logitBias = logitBias; + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + options.maxTokens = maxTokens; + return this; + } + + public Builder withN(Integer n) { + options.n = n; + return this; + } + + public Builder withPresencePenalty(Float presencePenalty) { + options.presencePenalty = presencePenalty; + return this; + } + + public Builder withResponseFormat(ResponseFormat responseFormat) { + options.responseFormat = responseFormat; + return this; + } + + public Builder withSeed(Integer seed) { + options.seed = seed; + return this; + } + + public Builder withStop(List stop) { + options.stop = stop; + return this; + } + + public Builder withTemperature(Float temperature) { + options.temperature = temperature; + return this; + } + + public Builder withTopP(Float topP) { + options.topP = topP; + return this; + } + + public Builder withTools(List tools) { + options.tools = tools; + return this; + } + + public Builder withToolChoice(ToolChoice toolChoice) { + options.toolChoice = toolChoice; + return this; + } + + public Builder withUser(String user) { + options.user = user; + return this; + } + + public OpenAiChatOptions build() { + return options; + } + + } + + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + 
} + + public Float getFrequencyPenalty() { + return frequencyPenalty; + } + + public void setFrequencyPenalty(Float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public Map getLogitBias() { + return logitBias; + } + + public void setLogitBias(Map logitBias) { + this.logitBias = logitBias; + } + + public Integer getMaxTokens() { + return maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public Integer getN() { + return n; + } + + public void setN(Integer n) { + this.n = n; + } + + public Float getPresencePenalty() { + return presencePenalty; + } + + public void setPresencePenalty(Float presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public ResponseFormat getResponseFormat() { + return responseFormat; + } + + public void setResponseFormat(ResponseFormat responseFormat) { + this.responseFormat = responseFormat; + } + + public Integer getSeed() { + return seed; + } + + public void setSeed(Integer seed) { + this.seed = seed; + } + + public List getStop() { + return stop; + } + + public void setStop(List stop) { + this.stop = stop; + } + + public Float getTemperature() { + return temperature; + } + + public void setTemperature(Float temperature) { + this.temperature = temperature; + } + + public Float getTopP() { + return topP; + } + + public void setTopP(Float topP) { + this.topP = topP; + } + + public List getTools() { + return tools; + } + + public void setTools(List tools) { + this.tools = tools; + } + + public ToolChoice getToolChoice() { + return toolChoice; + } + + public void setToolChoice(ToolChoice toolChoice) { + this.toolChoice = toolChoice; + } + + public String getUser() { + return user; + } + + public void setUser(String user) { + this.user = user; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((model == null) ? 0 : model.hashCode()); + result = prime * result + ((frequencyPenalty == null) ? 
0 : frequencyPenalty.hashCode()); + result = prime * result + ((logitBias == null) ? 0 : logitBias.hashCode()); + result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode()); + result = prime * result + ((n == null) ? 0 : n.hashCode()); + result = prime * result + ((presencePenalty == null) ? 0 : presencePenalty.hashCode()); + result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); + result = prime * result + ((seed == null) ? 0 : seed.hashCode()); + result = prime * result + ((stop == null) ? 0 : stop.hashCode()); + result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); + result = prime * result + ((topP == null) ? 0 : topP.hashCode()); + result = prime * result + ((tools == null) ? 0 : tools.hashCode()); + result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode()); + result = prime * result + ((user == null) ? 0 : user.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + OpenAiChatOptions other = (OpenAiChatOptions) obj; + if (model == null) { + if (other.model != null) + return false; + } + else if (!model.equals(other.model)) + return false; + if (frequencyPenalty == null) { + if (other.frequencyPenalty != null) + return false; + } + else if (!frequencyPenalty.equals(other.frequencyPenalty)) + return false; + if (logitBias == null) { + if (other.logitBias != null) + return false; + } + else if (!logitBias.equals(other.logitBias)) + return false; + if (maxTokens == null) { + if (other.maxTokens != null) + return false; + } + else if (!maxTokens.equals(other.maxTokens)) + return false; + if (n == null) { + if (other.n != null) + return false; + } + else if (!n.equals(other.n)) + return false; + if (presencePenalty == null) { + if (other.presencePenalty != null) + return false; + } + else if 
(!presencePenalty.equals(other.presencePenalty)) + return false; + if (responseFormat == null) { + if (other.responseFormat != null) + return false; + } + else if (!responseFormat.equals(other.responseFormat)) + return false; + if (seed == null) { + if (other.seed != null) + return false; + } + else if (!seed.equals(other.seed)) + return false; + if (stop == null) { + if (other.stop != null) + return false; + } + else if (!stop.equals(other.stop)) + return false; + if (temperature == null) { + if (other.temperature != null) + return false; + } + else if (!temperature.equals(other.temperature)) + return false; + if (topP == null) { + if (other.topP != null) + return false; + } + else if (!topP.equals(other.topP)) + return false; + if (tools == null) { + if (other.tools != null) + return false; + } + else if (!tools.equals(other.tools)) + return false; + if (toolChoice == null) { + if (other.toolChoice != null) + return false; + } + else if (!toolChoice.equals(other.toolChoice)) + return false; + if (user == null) { + if (other.user != null) + return false; + } + else if (!user.equals(other.user)) + return false; + return true; + } + + @Override + @JsonIgnore + public Integer getTopK() { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'getTopK'"); + } + + @Override + @JsonIgnore + public void setTopK(Integer topK) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'setTopK'"); + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestIT.java new file mode 100644 index 00000000000..4e67182f58b --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestIT.java @@ -0,0 +1,56 @@ +/* + * Copyright 2024-2024 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiChatOptions; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Christian Tzolov + */ +public class ChatCompletionRequestIT { + + @Test + public void chatOptions() { + + var client = new OpenAiChatClient(new OpenAiApi("TEST")) + .withDefaultOptions(OpenAiChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6f).build()); + + var request = client.createRequest(new Prompt("Test message content"), false); + + assertThat(request.getMessages()).hasSize(1); + assertThat(request.getStream()).isFalse(); + + assertThat(request.getModel()).isEqualTo("DEFAULT_MODEL"); + assertThat(request.getTemperature()).isEqualTo(66.6f); + + request = client.createRequest(new Prompt("Test message content", + OpenAiChatOptions.builder().withModel("PROMPT_MODEL").withTemperature(99.9f).build()), true); + + assertThat(request.getMessages()).hasSize(1); + assertThat(request.getStream()).isTrue(); + + assertThat(request.getModel()).isEqualTo("PROMPT_MODEL"); + assertThat(request.getTemperature()).isEqualTo(99.9f); + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java 
b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java index b14b0ba9391..d43a25dde1d 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java @@ -32,7 +32,6 @@ private String getApiKey() { @Bean public OpenAiChatClient openAiChatClient(OpenAiApi api) { OpenAiChatClient openAiChatClient = new OpenAiChatClient(api); - openAiChatClient.setTemperature(0.3); return openAiChatClient; } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java index 120b32a7be2..4f655320615 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java @@ -166,7 +166,6 @@ public OpenAiApi openAiApi() throws IOException { @Bean public OpenAiChatClient openAiChatClient(OpenAiApi openAiApi) { OpenAiChatClient openAiChatClient = new OpenAiChatClient(openAiApi); - openAiChatClient.setTemperature(0.3); return openAiChatClient; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java index 85c47554ef7..717abbb4001 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java @@ -24,7 +24,12 @@ import java.util.Map; import java.util.stream.Collectors; -public abstract class ModelOptionsUtils { +/** + * Utility class for manipulating {@link ModelOptions} objects. 
+ * + * @author Christian Tzolov + */ +public final class ModelOptionsUtils { private final static ObjectMapper OBJECT_MAPPER = new ObjectMapper(); diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java index 669856a81bc..1e7edfba235 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java @@ -54,9 +54,8 @@ public OpenAiChatClient openAiChatClient(OpenAiConnectionProperties commonProper var openAiApi = new OpenAiApi(baseUrl, apiKey, RestClient.builder()); - OpenAiChatClient openAiChatClient = new OpenAiChatClient(openAiApi); - openAiChatClient.setTemperature(chatProperties.getTemperature()); - openAiChatClient.setModel(chatProperties.getModel()); + OpenAiChatClient openAiChatClient = new OpenAiChatClient(openAiApi) + .withDefaultOptions(chatProperties.getOptions()); return openAiChatClient; } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java index 9f3f387f153..b6e8608f104 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java @@ -16,7 +16,9 @@ package org.springframework.ai.autoconfigure.openai; +import org.springframework.ai.openai.api.OpenAiChatOptions; import org.springframework.boot.context.properties.ConfigurationProperties; +import 
org.springframework.boot.context.properties.NestedConfigurationProperty; @ConfigurationProperties(OpenAiChatProperties.CONFIG_PREFIX) public class OpenAiChatProperties extends OpenAiParentProperties { @@ -25,24 +27,20 @@ public class OpenAiChatProperties extends OpenAiParentProperties { public static final String DEFAULT_CHAT_MODEL = "gpt-3.5-turbo"; - private Double temperature = 0.7; + private static final Double DEFAULT_TEMPERATURE = 0.7; - private String model = DEFAULT_CHAT_MODEL; + @NestedConfigurationProperty + private OpenAiChatOptions options = OpenAiChatOptions.builder() + .withModel(DEFAULT_CHAT_MODEL) + .withTemperature(DEFAULT_TEMPERATURE.floatValue()) + .build(); - public String getModel() { - return model; + public OpenAiChatOptions getOptions() { + return options; } - public void setModel(String model) { - this.model = model; - } - - public Double getTemperature() { - return temperature; - } - - public void setTemperature(Double temperature) { - this.temperature = temperature; + public void setOptions(OpenAiChatOptions options) { + this.options = options; } } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java index bf8fc8208f4..0c5a51c2d47 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java @@ -16,8 +16,13 @@ package org.springframework.ai.autoconfigure.openai; +import java.util.Map; + import org.junit.jupiter.api.Test; +import org.springframework.ai.openai.api.OpenAiApi.FunctionTool.Type; +import org.springframework.ai.openai.api.OpenAiChatOptions.ResponseFormat; +import org.springframework.ai.openai.api.OpenAiChatOptions.ToolChoice; import 
org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -39,8 +44,8 @@ public void chatProperties() { // @formatter:off "spring.ai.openai.base-url=TEST_BASE_URL", "spring.ai.openai.api-key=abc123", - "spring.ai.openai.chat.model=MODEL_XYZ", - "spring.ai.openai.chat.temperature=0.55") + "spring.ai.openai.chat.options.model=MODEL_XYZ", + "spring.ai.openai.chat.options.temperature=0.55") // @formatter:on .withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class)) .run(context -> { @@ -53,8 +58,8 @@ public void chatProperties() { assertThat(chatProperties.getApiKey()).isNull(); assertThat(chatProperties.getBaseUrl()).isNull(); - assertThat(chatProperties.getModel()).isEqualTo("MODEL_XYZ"); - assertThat(chatProperties.getTemperature()).isEqualTo(0.55); + assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f); }); } @@ -67,8 +72,8 @@ public void chatOverrideConnectionProperties() { "spring.ai.openai.api-key=abc123", "spring.ai.openai.chat.base-url=TEST_BASE_URL2", "spring.ai.openai.chat.api-key=456", - "spring.ai.openai.chat.model=MODEL_XYZ", - "spring.ai.openai.chat.temperature=0.55") + "spring.ai.openai.chat.options.model=MODEL_XYZ", + "spring.ai.openai.chat.options.temperature=0.55") // @formatter:on .withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class)) .run(context -> { @@ -81,8 +86,8 @@ public void chatOverrideConnectionProperties() { assertThat(chatProperties.getApiKey()).isEqualTo("456"); assertThat(chatProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2"); - assertThat(chatProperties.getModel()).isEqualTo("MODEL_XYZ"); - assertThat(chatProperties.getTemperature()).isEqualTo(0.55); + assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f); }); } @@ -136,4 +141,92 
@@ public void embeddingOverrideConnectionProperties() { }); } + @Test + public void optionsTest() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.openai.api-key=API_KEY", + "spring.ai.openai.base-url=TEST_BASE_URL", + + "spring.ai.openai.chat.options.model=MODEL_XYZ", + "spring.ai.openai.chat.options.frequencyPenalty=-1.5", + "spring.ai.openai.chat.options.logitBias.myTokenId=-5", + "spring.ai.openai.chat.options.maxTokens=123", + "spring.ai.openai.chat.options.n=10", + "spring.ai.openai.chat.options.presencePenalty=0", + "spring.ai.openai.chat.options.responseFormat.type=json", + "spring.ai.openai.chat.options.seed=66", + "spring.ai.openai.chat.options.stop=boza,koza", + "spring.ai.openai.chat.options.temperature=0.55", + "spring.ai.openai.chat.options.topP=0.56", + + "spring.ai.openai.chat.options.toolChoice.functionName=toolChoiceFunctionName", + + "spring.ai.openai.chat.options.tools[0].function.name=myFunction1", + "spring.ai.openai.chat.options.tools[0].function.description=function description", + "spring.ai.openai.chat.options.tools[0].function.jsonSchema=" + """ + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. 
San Francisco, CA" + }, + "lat": { + "type": "number", + "description": "The city latitude" + }, + "lon": { + "type": "number", + "description": "The city longitude" + }, + "unit": { + "type": "string", + "enum": ["c", "f"] + } + }, + "required": ["location", "lat", "lon", "unit"] + } + """, + "spring.ai.openai.chat.options.user=userXYZ" + ) + // @formatter:on + .withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class)) + .run(context -> { + var chatProperties = context.getBean(OpenAiChatProperties.class); + var connectionProperties = context.getBean(OpenAiConnectionProperties.class); + var embeddingProperties = context.getBean(OpenAiEmbeddingProperties.class); + + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY"); + + assertThat(embeddingProperties.getModel()).isEqualTo("text-embedding-ada-002"); + + assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + assertThat(chatProperties.getOptions().getFrequencyPenalty()).isEqualTo(-1.5f); + assertThat(chatProperties.getOptions().getLogitBias().get("myTokenId")).isEqualTo(-5); + assertThat(chatProperties.getOptions().getMaxTokens()).isEqualTo(123); + assertThat(chatProperties.getOptions().getN()).isEqualTo(10); + assertThat(chatProperties.getOptions().getPresencePenalty()).isEqualTo(0); + assertThat(chatProperties.getOptions().getResponseFormat()).isEqualTo(new ResponseFormat("json")); + assertThat(chatProperties.getOptions().getSeed()).isEqualTo(66); + assertThat(chatProperties.getOptions().getStop()).contains("boza", "koza"); + assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f); + assertThat(chatProperties.getOptions().getTopP()).isEqualTo(0.56f); + + assertThat(chatProperties.getOptions().getToolChoice()) + .isEqualTo(new ToolChoice("function", Map.of("name", "toolChoiceFunctionName"))); + assertThat(chatProperties.getOptions().getUser()).isEqualTo("userXYZ"); + + 
assertThat(chatProperties.getOptions().getTools()).hasSize(1); + var tool = chatProperties.getOptions().getTools().get(0); + assertThat(tool.type()).isEqualTo(Type.FUNCTION); + var function = tool.function(); + assertThat(function.name()).isEqualTo("myFunction1"); + assertThat(function.description()).isEqualTo("function description"); + assertThat(function.parameters()).isNotEmpty(); + }); + } + }