From 21d4032f40526ddcd7b50b0c16b6d584c0a3752c Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Sat, 27 Jan 2024 21:48:00 +0100 Subject: [PATCH 1/6] Add OpenAi Chat Options (2) - Add OpenAiChatOptions that implements ChatOptions and exposes all OpenAi request options, except messages and stream. - Add OpenAiChatOptions field (as defaultOptions) to OpenAiChatClient. Implement start-up/runtime options merging on chat request creation. - Add OpenAiChatOptions options field to OpenAiChatProperties. The latter is set as OpenAiChatClient#defaultOptions. Use the spring.ai.openai.chat.options.* prefix to set the options. - Add tests for properties and options merging. Part of #228 --- models/spring-ai-openai/pom.xml | 6 + .../ai/openai/OpenAiChatClient.java | 86 ++-- .../ai/openai/api/OpenAiApi.java | 26 +- .../ai/openai/api/OpenAiChatOptions.java | 454 ++++++++++++++++++ .../ai/openai/ChatCompletionRequestTests.java | 56 +++ .../ai/openai/OpenAiTestConfiguration.java | 1 - .../transformer/MetadataTransformerIT.java | 1 - .../ai/model/ModelOptionsUtils.java | 53 +- .../openai/OpenAiAutoConfiguration.java | 5 +- .../openai/OpenAiChatProperties.java | 26 +- .../openai/OpenAiPropertiesTests.java | 109 ++++- 11 files changed, 754 insertions(+), 69 deletions(-) create mode 100644 models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiChatOptions.java create mode 100644 models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java diff --git a/models/spring-ai-openai/pom.xml b/models/spring-ai-openai/pom.xml index ad238b9a55f..8e0279ff969 100644 --- a/models/spring-ai-openai/pom.xml +++ b/models/spring-ai-openai/pom.xml @@ -34,6 +34,12 @@ 2.0.4 + + + org.springframework.boot + spring-boot + + io.rest-assured json-path diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java index 
3e18671e0ca..6b629806c3b 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java @@ -28,14 +28,16 @@ import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; import org.springframework.ai.chat.StreamingChatClient; -import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.RateLimit; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException; +import org.springframework.ai.openai.api.OpenAiChatOptions; import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata; import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor; import org.springframework.http.ResponseEntity; @@ -57,11 +59,15 @@ */ public class OpenAiChatClient implements ChatClient, StreamingChatClient { - private Double temperature = 0.7; + private final Logger logger = LoggerFactory.getLogger(getClass()); - private String model = "gpt-3.5-turbo"; + private static final List REQUEST_JSON_FIELD_NAMES = ModelOptionsUtils + .getJsonPropertyValues(ChatCompletionRequest.class); - private final Logger logger = LoggerFactory.getLogger(getClass()); + private OpenAiChatOptions defaultOptions = OpenAiChatOptions.builder() + .withModel("gpt-3.5-turbo") + .withTemperature(0.7f) + .build(); public final RetryTemplate retryTemplate = RetryTemplate.builder() .maxAttempts(10) @@ -76,40 +82,23 @@ public 
OpenAiChatClient(OpenAiApi openAiApi) { this.openAiApi = openAiApi; } - public String getModel() { - return this.model; - } - - public void setModel(String model) { - this.model = model; - } - - public Double getTemperature() { - return this.temperature; - } - - public void setTemperature(Double temperature) { - this.temperature = temperature; + public OpenAiChatClient withDefaultOptions(OpenAiChatOptions options) { + this.defaultOptions = options; + return this; } @Override public ChatResponse call(Prompt prompt) { return this.retryTemplate.execute(ctx -> { - List messages = prompt.getInstructions(); - List chatCompletionMessages = messages.stream() - .map(m -> new ChatCompletionMessage(m.getContent(), - ChatCompletionMessage.Role.valueOf(m.getMessageType().name()))) - .toList(); + ChatCompletionRequest request = createRequest(prompt, false); - ResponseEntity completionEntity = this.openAiApi - .chatCompletionEntity(new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, this.model, - this.temperature.floatValue())); + ResponseEntity completionEntity = this.openAiApi.chatCompletionEntity(request); var chatCompletion = completionEntity.getBody(); if (chatCompletion == null) { - logger.warn("No chat completion returned for request: {}", chatCompletionMessages); + logger.warn("No chat completion returned for request: {}", prompt); return new ChatResponse(List.of()); } @@ -128,16 +117,9 @@ public ChatResponse call(Prompt prompt) { @Override public Flux stream(Prompt prompt) { return this.retryTemplate.execute(ctx -> { - List messages = prompt.getInstructions(); - - List chatCompletionMessages = messages.stream() - .map(m -> new ChatCompletionMessage(m.getContent(), - ChatCompletionMessage.Role.valueOf(m.getMessageType().name()))) - .toList(); + ChatCompletionRequest request = createRequest(prompt, true); - Flux completionChunks = this.openAiApi - .chatCompletionStream(new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, this.model, - 
this.temperature.floatValue(), true)); + Flux completionChunks = this.openAiApi.chatCompletionStream(request); // For chunked responses, only the first chunk contains the choice role. // The rest of the chunks with same ID share the same role. @@ -161,4 +143,36 @@ public Flux stream(Prompt prompt) { }); } + /** + * Accessible for testing. + */ + ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { + + List chatCompletionMessages = prompt.getInstructions() + .stream() + .map(m -> new ChatCompletionMessage(m.getContent(), + ChatCompletionMessage.Role.valueOf(m.getMessageType().name()))) + .toList(); + + ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream); + + if (this.defaultOptions != null) { + request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class, + REQUEST_JSON_FIELD_NAMES); + } + + if (prompt.getOptions() != null) { + if (prompt.getOptions() instanceof OpenAiChatOptions runtimeOptions) { + request = ModelOptionsUtils.merge(runtimeOptions, request, ChatCompletionRequest.class, + REQUEST_JSON_FIELD_NAMES); + } + else { + throw new IllegalArgumentException("Prompt options are not of type ChatCompletionRequest:" + + prompt.getOptions().getClass().getSimpleName()); + } + } + + return request; + } + } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 989bd302fea..130647152b3 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -31,6 +31,7 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import org.springframework.boot.context.properties.bind.ConstructorBinding; import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpHeaders; 
import org.springframework.http.MediaType; @@ -183,6 +184,7 @@ public record FunctionTool( * Create a tool of type 'function' and the given function definition. * @param function function definition. */ + @ConstructorBinding public FunctionTool(Function function) { this(Type.FUNCTION, function); } @@ -219,13 +221,14 @@ public record Function( * @param name tool function name. * @param jsonSchema tool function schema as json. */ + @ConstructorBinding public Function(String description, String name, String jsonSchema) { this(description, name, parseJson(jsonSchema)); } } } - /** +/** * Creates a model response for the given chat conversation. * * @param messages A list of messages comprising the conversation so far. @@ -269,17 +272,17 @@ public Function(String description, String name, String jsonSchema) { * */ @JsonInclude(Include.NON_NULL) - public record ChatCompletionRequest( + public record ChatCompletionRequest ( @JsonProperty("messages") List messages, @JsonProperty("model") String model, @JsonProperty("frequency_penalty") Float frequencyPenalty, - @JsonProperty("logit_bias") Map logitBias, + @JsonProperty("logit_bias") Map logitBias, @JsonProperty("max_tokens") Integer maxTokens, @JsonProperty("n") Integer n, @JsonProperty("presence_penalty") Float presencePenalty, @JsonProperty("response_format") ResponseFormat responseFormat, @JsonProperty("seed") Integer seed, - @JsonProperty("stop") String stop, + @JsonProperty("stop") List stop, @JsonProperty("stream") Boolean stream, @JsonProperty("temperature") Float temperature, @JsonProperty("top_p") Float topP, @@ -331,6 +334,20 @@ public ChatCompletionRequest(List messages, String model, tools, toolChoice, null); } + /** + * Shortcut constructor for a chat completion request with the given messages and stream flag. + * All other parameters (including model and temperature) are left null. + * + * @param messages A list of messages comprising the conversation so far. 
+ * @param stream If set, partial message deltas will be sent.Tokens will be sent as data-only server-sent events + * as they become available, with the stream terminated by a data: [DONE] message. + */ + public ChatCompletionRequest(List messages, Boolean stream) { + this(messages, null, null, null, null, null, null, + null, null, null, stream, null, null, + null, null, null); + } + /** * Specifies a tool the model should use. Use to force the model to call a specific function. * @@ -346,6 +363,7 @@ public record ToolChoice( * Create a tool choice of type 'function' and name 'functionName'. * @param functionName Function name of the tool. */ + @ConstructorBinding public ToolChoice(String functionName) { this("function", Map.of("name", functionName)); } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiChatOptions.java new file mode 100644 index 00000000000..34858ac3a37 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiChatOptions.java @@ -0,0 +1,454 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.openai.api; + +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.ai.chat.ChatOptions; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoice; +import org.springframework.ai.openai.api.OpenAiApi.FunctionTool; + +/** + * @author Christian Tzolov + * @since 0.8.0 + */ +@JsonInclude(Include.NON_NULL) +public class OpenAiChatOptions implements ChatOptions { + + // @formatter:off + /** + * ID of the model to use. + */ + @JsonProperty("model") String model; + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing + * frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + */ + @JsonProperty("frequency_penalty") Float frequencyPenalty = 0.0f; + /** + * Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object + * that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. + * Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will + * vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 + * or 100 should result in a ban or exclusive selection of the relevant token. + */ + @JsonProperty("logit_bias") Map logitBias; + /** + * The maximum number of tokens to generate in the chat completion. The total length of input + * tokens and generated tokens is limited by the model's context length. 
+ */ + @JsonProperty("max_tokens") Integer maxTokens; + /** + * How many chat completion choices to generate for each input message. Note that you will be charged based + * on the number of generated tokens across all of the choices. Keep n as 1 to minimize costs. + */ + @JsonProperty("n") Integer n = 1; + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they + * appear in the text so far, increasing the model's likelihood to talk about new topics. + */ + @JsonProperty("presence_penalty") Float presencePenalty; + /** + * An object specifying the format that the model must output. Setting to { "type": + * "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. + */ + @JsonProperty("response_format") ResponseFormat responseFormat; + /** + * This feature is in Beta. If specified, our system will make a best effort to sample + * deterministically, such that repeated requests with the same seed and parameters should return the same result. + * Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor + * changes in the backend. + */ + @JsonProperty("seed") Integer seed; + /** + * Up to 4 sequences where the API will stop generating further tokens. + */ + @JsonProperty("stop") List stop; + /** + * What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output + * more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend + * altering this or top_p but not both. + */ + @JsonProperty("temperature") Float temperature = 0.8f; + /** + * An alternative to sampling with temperature, called nucleus sampling, where the model considers the + * results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + * probability mass are considered. We generally recommend altering this or temperature but not both. 
+ */ + @JsonProperty("top_p") Float topP; + /** + * A list of tools the model may call. Currently, only functions are supported as a tool. Use this to + * provide a list of functions the model may generate JSON inputs for. + */ + @JsonProperty("tools") List tools; + /** + * Controls which (if any) function is called by the model. none means the model will not call a + * function and instead generates a message. auto means the model can pick between generating a message or calling a + * function. Specifying a particular function via {"type: "function", "function": {"name": "my_function"}} forces + * the model to call that function. none is the default when no functions are present. auto is the default if + * functions are present. + */ + @JsonProperty("tool_choice") ToolChoice toolChoice; + /** + * A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. + */ + @JsonProperty("user") String user; + // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + protected OpenAiChatOptions options; + + public Builder() { + this.options = new OpenAiChatOptions(); + } + + public Builder(OpenAiChatOptions options) { + this.options = options; + } + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withFrequencyPenalty(Float frequencyPenalty) { + options.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder withLogitBias(Map logitBias) { + options.logitBias = logitBias; + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + options.maxTokens = maxTokens; + return this; + } + + public Builder withN(Integer n) { + options.n = n; + return this; + } + + public Builder withPresencePenalty(Float presencePenalty) { + options.presencePenalty = presencePenalty; + return this; + } + + public Builder withResponseFormat(ResponseFormat responseFormat) { + options.responseFormat = 
responseFormat; + return this; + } + + public Builder withSeed(Integer seed) { + options.seed = seed; + return this; + } + + public Builder withStop(List stop) { + options.stop = stop; + return this; + } + + public Builder withTemperature(Float temperature) { + options.temperature = temperature; + return this; + } + + public Builder withTopP(Float topP) { + options.topP = topP; + return this; + } + + public Builder withTools(List tools) { + options.tools = tools; + return this; + } + + public Builder withToolChoice(ToolChoice toolChoice) { + options.toolChoice = toolChoice; + return this; + } + + public Builder withUser(String user) { + options.user = user; + return this; + } + + public OpenAiChatOptions build() { + return options; + } + + } + + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + public Float getFrequencyPenalty() { + return frequencyPenalty; + } + + public void setFrequencyPenalty(Float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public Map getLogitBias() { + return logitBias; + } + + public void setLogitBias(Map logitBias) { + this.logitBias = logitBias; + } + + public Integer getMaxTokens() { + return maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public Integer getN() { + return n; + } + + public void setN(Integer n) { + this.n = n; + } + + public Float getPresencePenalty() { + return presencePenalty; + } + + public void setPresencePenalty(Float presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public ResponseFormat getResponseFormat() { + return responseFormat; + } + + public void setResponseFormat(ResponseFormat responseFormat) { + this.responseFormat = responseFormat; + } + + public Integer getSeed() { + return seed; + } + + public void setSeed(Integer seed) { + this.seed = seed; + } + + public List getStop() { + return stop; + } + + public void setStop(List stop) { + 
this.stop = stop; + } + + public Float getTemperature() { + return temperature; + } + + public void setTemperature(Float temperature) { + this.temperature = temperature; + } + + public Float getTopP() { + return topP; + } + + public void setTopP(Float topP) { + this.topP = topP; + } + + public List getTools() { + return tools; + } + + public void setTools(List tools) { + this.tools = tools; + } + + public ToolChoice getToolChoice() { + return toolChoice; + } + + public void setToolChoice(ToolChoice toolChoice) { + this.toolChoice = toolChoice; + } + + public String getUser() { + return user; + } + + public void setUser(String user) { + this.user = user; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((model == null) ? 0 : model.hashCode()); + result = prime * result + ((frequencyPenalty == null) ? 0 : frequencyPenalty.hashCode()); + result = prime * result + ((logitBias == null) ? 0 : logitBias.hashCode()); + result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode()); + result = prime * result + ((n == null) ? 0 : n.hashCode()); + result = prime * result + ((presencePenalty == null) ? 0 : presencePenalty.hashCode()); + result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); + result = prime * result + ((seed == null) ? 0 : seed.hashCode()); + result = prime * result + ((stop == null) ? 0 : stop.hashCode()); + result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); + result = prime * result + ((topP == null) ? 0 : topP.hashCode()); + result = prime * result + ((tools == null) ? 0 : tools.hashCode()); + result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode()); + result = prime * result + ((user == null) ? 
0 : user.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + OpenAiChatOptions other = (OpenAiChatOptions) obj; + if (model == null) { + if (other.model != null) + return false; + } + else if (!model.equals(other.model)) + return false; + if (frequencyPenalty == null) { + if (other.frequencyPenalty != null) + return false; + } + else if (!frequencyPenalty.equals(other.frequencyPenalty)) + return false; + if (logitBias == null) { + if (other.logitBias != null) + return false; + } + else if (!logitBias.equals(other.logitBias)) + return false; + if (maxTokens == null) { + if (other.maxTokens != null) + return false; + } + else if (!maxTokens.equals(other.maxTokens)) + return false; + if (n == null) { + if (other.n != null) + return false; + } + else if (!n.equals(other.n)) + return false; + if (presencePenalty == null) { + if (other.presencePenalty != null) + return false; + } + else if (!presencePenalty.equals(other.presencePenalty)) + return false; + if (responseFormat == null) { + if (other.responseFormat != null) + return false; + } + else if (!responseFormat.equals(other.responseFormat)) + return false; + if (seed == null) { + if (other.seed != null) + return false; + } + else if (!seed.equals(other.seed)) + return false; + if (stop == null) { + if (other.stop != null) + return false; + } + else if (!stop.equals(other.stop)) + return false; + if (temperature == null) { + if (other.temperature != null) + return false; + } + else if (!temperature.equals(other.temperature)) + return false; + if (topP == null) { + if (other.topP != null) + return false; + } + else if (!topP.equals(other.topP)) + return false; + if (tools == null) { + if (other.tools != null) + return false; + } + else if (!tools.equals(other.tools)) + return false; + if (toolChoice == null) { + if (other.toolChoice != null) + return false; + } + 
else if (!toolChoice.equals(other.toolChoice)) + return false; + if (user == null) { + if (other.user != null) + return false; + } + else if (!user.equals(other.user)) + return false; + return true; + } + + @Override + @JsonIgnore + public Integer getTopK() { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'getTopK'"); + } + + @Override + @JsonIgnore + public void setTopK(Integer topK) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'setTopK'"); + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java new file mode 100644 index 00000000000..980ae89986e --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java @@ -0,0 +1,56 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.openai; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiChatOptions; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Christian Tzolov + */ +public class ChatCompletionRequestTests { + + @Test + public void createRequestWithChatOptions() { + + var client = new OpenAiChatClient(new OpenAiApi("TEST")) + .withDefaultOptions(OpenAiChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6f).build()); + + var request = client.createRequest(new Prompt("Test message content"), false); + + assertThat(request.messages()).hasSize(1); + assertThat(request.stream()).isFalse(); + + assertThat(request.model()).isEqualTo("DEFAULT_MODEL"); + assertThat(request.temperature()).isEqualTo(66.6f); + + request = client.createRequest(new Prompt("Test message content", + OpenAiChatOptions.builder().withModel("PROMPT_MODEL").withTemperature(99.9f).build()), true); + + assertThat(request.messages()).hasSize(1); + assertThat(request.stream()).isTrue(); + + assertThat(request.model()).isEqualTo("PROMPT_MODEL"); + assertThat(request.temperature()).isEqualTo(99.9f); + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java index b14b0ba9391..d43a25dde1d 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java @@ -32,7 +32,6 @@ private String getApiKey() { @Bean public OpenAiChatClient openAiChatClient(OpenAiApi api) { OpenAiChatClient openAiChatClient = new OpenAiChatClient(api); - openAiChatClient.setTemperature(0.3); return openAiChatClient; } diff --git 
a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java index 120b32a7be2..4f655320615 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java @@ -166,7 +166,6 @@ public OpenAiApi openAiApi() throws IOException { @Bean public OpenAiChatClient openAiChatClient(OpenAiApi openAiApi) { OpenAiChatClient openAiChatClient = new OpenAiChatClient(openAiApi); - openAiChatClient.setTemperature(0.3); return openAiChatClient; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java index 85c47554ef7..267a1ccd960 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java @@ -16,15 +16,26 @@ package org.springframework.ai.model; +import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.util.CollectionUtils; + +import java.lang.reflect.Field; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.stream.Collectors; -public abstract class ModelOptionsUtils { +/** + * Utility class for manipulating {@link ModelOptions} objects. 
+ * + * @author Christian Tzolov + */ +public final class ModelOptionsUtils { private final static ObjectMapper OBJECT_MAPPER = new ObjectMapper(); @@ -39,9 +50,10 @@ private ModelOptionsUtils() { * @param source the source object to merge. * @param target the target object to merge into. * @param clazz the class to return. + * @param acceptedFieldNames the list of field names accepted for the target object. * @return the merged object represented by the given class. */ - public static T merge(Object source, Object target, Class clazz) { + public static T merge(Object source, Object target, Class clazz, List acceptedFieldNames) { Map sourceMap = objectToMap(source); Map targetMap = objectToMap(target); @@ -50,9 +62,29 @@ public static T merge(Object source, Object target, Class clazz) { .filter(e -> e.getValue() != null) .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue()))); + if (!CollectionUtils.isEmpty(acceptedFieldNames)) { + targetMap = targetMap.entrySet() + .stream() + .filter(e -> acceptedFieldNames.contains(e.getKey())) + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())); + } + return mapToClass(targetMap, clazz); } + /** + * Merges the source object into the target object and returns an object represented + * by the given class. The source null values are ignored. + * @param <T> the type of the class to return. + * @param source the source object to merge. + * @param target the target object to merge into. + * @param clazz the class to return. + * @return the merged object represented by the given class. + */ + public static T merge(Object source, Object target, Class clazz) { + return merge(source, target, clazz, null); + } + /** * Converts the given object to a Map. * @param source the object to convert to a Map. @@ -89,4 +121,21 @@ public static T mapToClass(Map source, Class clazz) { } } + /** + * Returns the list of values of the {@link JsonProperty} annotations. 
+ * @param clazz the class that contains fields annotated with {@link JsonProperty}. + * @return the list of values of the {@link JsonProperty} annotations. + */ + public static List getJsonPropertyValues(Class clazz) { + List values = new ArrayList<>(); + Field[] fields = clazz.getDeclaredFields(); + for (Field field : fields) { + JsonProperty jsonPropertyAnnotation = field.getAnnotation(JsonProperty.class); + if (jsonPropertyAnnotation != null) { + values.add(jsonPropertyAnnotation.value()); + } + } + return values; + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java index 669856a81bc..1e7edfba235 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java @@ -54,9 +54,8 @@ public OpenAiChatClient openAiChatClient(OpenAiConnectionProperties commonProper var openAiApi = new OpenAiApi(baseUrl, apiKey, RestClient.builder()); - OpenAiChatClient openAiChatClient = new OpenAiChatClient(openAiApi); - openAiChatClient.setTemperature(chatProperties.getTemperature()); - openAiChatClient.setModel(chatProperties.getModel()); + OpenAiChatClient openAiChatClient = new OpenAiChatClient(openAiApi) + .withDefaultOptions(chatProperties.getOptions()); return openAiChatClient; } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java index 9f3f387f153..b6e8608f104 100644 --- 
a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java @@ -16,7 +16,9 @@ package org.springframework.ai.autoconfigure.openai; +import org.springframework.ai.openai.api.OpenAiChatOptions; import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; @ConfigurationProperties(OpenAiChatProperties.CONFIG_PREFIX) public class OpenAiChatProperties extends OpenAiParentProperties { @@ -25,24 +27,20 @@ public class OpenAiChatProperties extends OpenAiParentProperties { public static final String DEFAULT_CHAT_MODEL = "gpt-3.5-turbo"; - private Double temperature = 0.7; + private static final Double DEFAULT_TEMPERATURE = 0.7; - private String model = DEFAULT_CHAT_MODEL; + @NestedConfigurationProperty + private OpenAiChatOptions options = OpenAiChatOptions.builder() + .withModel(DEFAULT_CHAT_MODEL) + .withTemperature(DEFAULT_TEMPERATURE.floatValue()) + .build(); - public String getModel() { - return model; + public OpenAiChatOptions getOptions() { + return options; } - public void setModel(String model) { - this.model = model; - } - - public Double getTemperature() { - return temperature; - } - - public void setTemperature(Double temperature) { - this.temperature = temperature; + public void setOptions(OpenAiChatOptions options) { + this.options = options; } } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java index bf8fc8208f4..4ed60cf6a05 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java +++ 
b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java @@ -16,8 +16,13 @@ package org.springframework.ai.autoconfigure.openai; +import java.util.Map; + import org.junit.jupiter.api.Test; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoice; +import org.springframework.ai.openai.api.OpenAiApi.FunctionTool.Type; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -39,8 +44,8 @@ public void chatProperties() { // @formatter:off "spring.ai.openai.base-url=TEST_BASE_URL", "spring.ai.openai.api-key=abc123", - "spring.ai.openai.chat.model=MODEL_XYZ", - "spring.ai.openai.chat.temperature=0.55") + "spring.ai.openai.chat.options.model=MODEL_XYZ", + "spring.ai.openai.chat.options.temperature=0.55") // @formatter:on .withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class)) .run(context -> { @@ -53,8 +58,8 @@ public void chatProperties() { assertThat(chatProperties.getApiKey()).isNull(); assertThat(chatProperties.getBaseUrl()).isNull(); - assertThat(chatProperties.getModel()).isEqualTo("MODEL_XYZ"); - assertThat(chatProperties.getTemperature()).isEqualTo(0.55); + assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f); }); } @@ -67,8 +72,8 @@ public void chatOverrideConnectionProperties() { "spring.ai.openai.api-key=abc123", "spring.ai.openai.chat.base-url=TEST_BASE_URL2", "spring.ai.openai.chat.api-key=456", - "spring.ai.openai.chat.model=MODEL_XYZ", - "spring.ai.openai.chat.temperature=0.55") + "spring.ai.openai.chat.options.model=MODEL_XYZ", + "spring.ai.openai.chat.options.temperature=0.55") // @formatter:on .withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class)) 
.run(context -> { @@ -81,8 +86,8 @@ public void chatOverrideConnectionProperties() { assertThat(chatProperties.getApiKey()).isEqualTo("456"); assertThat(chatProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2"); - assertThat(chatProperties.getModel()).isEqualTo("MODEL_XYZ"); - assertThat(chatProperties.getTemperature()).isEqualTo(0.55); + assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f); }); } @@ -136,4 +141,92 @@ public void embeddingOverrideConnectionProperties() { }); } + @Test + public void optionsTest() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.openai.api-key=API_KEY", + "spring.ai.openai.base-url=TEST_BASE_URL", + + "spring.ai.openai.chat.options.model=MODEL_XYZ", + "spring.ai.openai.chat.options.frequencyPenalty=-1.5", + "spring.ai.openai.chat.options.logitBias.myTokenId=-5", + "spring.ai.openai.chat.options.maxTokens=123", + "spring.ai.openai.chat.options.n=10", + "spring.ai.openai.chat.options.presencePenalty=0", + "spring.ai.openai.chat.options.responseFormat.type=json", + "spring.ai.openai.chat.options.seed=66", + "spring.ai.openai.chat.options.stop=boza,koza", + "spring.ai.openai.chat.options.temperature=0.55", + "spring.ai.openai.chat.options.topP=0.56", + + "spring.ai.openai.chat.options.toolChoice.functionName=toolChoiceFunctionName", + + "spring.ai.openai.chat.options.tools[0].function.name=myFunction1", + "spring.ai.openai.chat.options.tools[0].function.description=function description", + "spring.ai.openai.chat.options.tools[0].function.jsonSchema=" + """ + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. 
San Francisco, CA" + }, + "lat": { + "type": "number", + "description": "The city latitude" + }, + "lon": { + "type": "number", + "description": "The city longitude" + }, + "unit": { + "type": "string", + "enum": ["c", "f"] + } + }, + "required": ["location", "lat", "lon", "unit"] + } + """, + "spring.ai.openai.chat.options.user=userXYZ" + ) + // @formatter:on + .withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class)) + .run(context -> { + var chatProperties = context.getBean(OpenAiChatProperties.class); + var connectionProperties = context.getBean(OpenAiConnectionProperties.class); + var embeddingProperties = context.getBean(OpenAiEmbeddingProperties.class); + + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY"); + + assertThat(embeddingProperties.getModel()).isEqualTo("text-embedding-ada-002"); + + assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + assertThat(chatProperties.getOptions().getFrequencyPenalty()).isEqualTo(-1.5f); + assertThat(chatProperties.getOptions().getLogitBias().get("myTokenId")).isEqualTo(-5); + assertThat(chatProperties.getOptions().getMaxTokens()).isEqualTo(123); + assertThat(chatProperties.getOptions().getN()).isEqualTo(10); + assertThat(chatProperties.getOptions().getPresencePenalty()).isEqualTo(0); + assertThat(chatProperties.getOptions().getResponseFormat()).isEqualTo(new ResponseFormat("json")); + assertThat(chatProperties.getOptions().getSeed()).isEqualTo(66); + assertThat(chatProperties.getOptions().getStop()).contains("boza", "koza"); + assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f); + assertThat(chatProperties.getOptions().getTopP()).isEqualTo(0.56f); + + assertThat(chatProperties.getOptions().getToolChoice()) + .isEqualTo(new ToolChoice("function", Map.of("name", "toolChoiceFunctionName"))); + assertThat(chatProperties.getOptions().getUser()).isEqualTo("userXYZ"); + + 
assertThat(chatProperties.getOptions().getTools()).hasSize(1); + var tool = chatProperties.getOptions().getTools().get(0); + assertThat(tool.type()).isEqualTo(Type.FUNCTION); + var function = tool.function(); + assertThat(function.name()).isEqualTo("myFunction1"); + assertThat(function.description()).isEqualTo("function description"); + assertThat(function.parameters()).isNotEmpty(); + }); + } + } From 443c5c2194125a65eb39da7472fff28238e9b833 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Sun, 28 Jan 2024 16:58:55 +0100 Subject: [PATCH 2/6] move the OpenAiChatOptions.java out of the api package --- .../java/org/springframework/ai/openai/OpenAiChatClient.java | 1 - .../ai/openai/{api => }/OpenAiChatOptions.java | 4 +++- .../springframework/ai/openai/ChatCompletionRequestTests.java | 1 - .../ai/autoconfigure/openai/OpenAiChatProperties.java | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) rename models/spring-ai-openai/src/main/java/org/springframework/ai/openai/{api => }/OpenAiChatOptions.java (98%) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java index 6b629806c3b..0cbe81101f3 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java @@ -37,7 +37,6 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException; -import org.springframework.ai.openai.api.OpenAiChatOptions; import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata; import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor; import org.springframework.http.ResponseEntity; diff --git 
a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java similarity index 98% rename from models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiChatOptions.java rename to models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index 34858ac3a37..e99fdd2fed7 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.openai.api; +package org.springframework.ai.openai; import java.util.List; import java.util.Map; @@ -27,6 +27,8 @@ import org.springframework.ai.chat.ChatOptions; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoice; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.ai.openai.api.OpenAiApi.FunctionTool; /** diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java index 980ae89986e..93b6b63db3c 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java @@ -20,7 +20,6 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.openai.api.OpenAiApi; -import org.springframework.ai.openai.api.OpenAiChatOptions; import static org.assertj.core.api.Assertions.assertThat; diff --git 
a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java index b6e8608f104..aea4c73d3f3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java @@ -16,7 +16,7 @@ package org.springframework.ai.autoconfigure.openai; -import org.springframework.ai.openai.api.OpenAiChatOptions; +import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; From 39fe1c1cfa1d7a156387f7e0ae57d6ffa17e70ea Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Sun, 28 Jan 2024 19:01:50 +0100 Subject: [PATCH 3/6] Add OpenAiEmbeddingOptions - Add OpenAiEmbeddingOptions class implementing the EmbeddingOptions interface. - Add OpenAiEmbeddingClient#defaultOptions - Add request merging with default and prompt options. 
- Add OpenAiEmbeddingProperties#options field of type OpenAiEmbeddingOptions --- .../ai/openai/OpenAiChatOptions.java | 134 +++++++++--------- .../ai/openai/OpenAiEmbeddingClient.java | 43 ++++-- .../ai/openai/OpenAiEmbeddingOptions.java | 104 ++++++++++++++ .../openai/OpenAiAutoConfiguration.java | 2 +- .../openai/OpenAiEmbeddingProperties.java | 15 +- .../openai/OpenAiPropertiesTests.java | 39 ++++- 6 files changed, 242 insertions(+), 95 deletions(-) create mode 100644 models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index e99fdd2fed7..a7b0c7e0d5e 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -27,8 +27,6 @@ import org.springframework.ai.chat.ChatOptions; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoice; -import org.springframework.ai.openai.api.OpenAiApi; -import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.ai.openai.api.OpenAiApi.FunctionTool; /** @@ -42,12 +40,12 @@ public class OpenAiChatOptions implements ChatOptions { /** * ID of the model to use. */ - @JsonProperty("model") String model; + private @JsonProperty("model") String model; /** * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing * frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. 
*/ - @JsonProperty("frequency_penalty") Float frequencyPenalty = 0.0f; + private @JsonProperty("frequency_penalty") Float frequencyPenalty = 0.0f; /** * Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object * that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. @@ -55,55 +53,55 @@ public class OpenAiChatOptions implements ChatOptions { * vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 * or 100 should result in a ban or exclusive selection of the relevant token. */ - @JsonProperty("logit_bias") Map logitBias; + private @JsonProperty("logit_bias") Map logitBias; /** * The maximum number of tokens to generate in the chat completion. The total length of input * tokens and generated tokens is limited by the model's context length. */ - @JsonProperty("max_tokens") Integer maxTokens; + private @JsonProperty("max_tokens") Integer maxTokens; /** * How many chat completion choices to generate for each input message. Note that you will be charged based * on the number of generated tokens across all of the choices. Keep n as 1 to minimize costs. */ - @JsonProperty("n") Integer n = 1; + private @JsonProperty("n") Integer n = 1; /** * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they * appear in the text so far, increasing the model's likelihood to talk about new topics. */ - @JsonProperty("presence_penalty") Float presencePenalty; + private @JsonProperty("presence_penalty") Float presencePenalty; /** * An object specifying the format that the model must output. Setting to { "type": * "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. */ - @JsonProperty("response_format") ResponseFormat responseFormat; + private @JsonProperty("response_format") ResponseFormat responseFormat; /** * This feature is in Beta. 
If specified, our system will make a best effort to sample * deterministically, such that repeated requests with the same seed and parameters should return the same result. * Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor * changes in the backend. */ - @JsonProperty("seed") Integer seed; + private @JsonProperty("seed") Integer seed; /** * Up to 4 sequences where the API will stop generating further tokens. */ - @JsonProperty("stop") List stop; + private @JsonProperty("stop") List stop; /** * What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output * more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend * altering this or top_p but not both. */ - @JsonProperty("temperature") Float temperature = 0.8f; + private @JsonProperty("temperature") Float temperature = 0.8f; /** * An alternative to sampling with temperature, called nucleus sampling, where the model considers the * results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% * probability mass are considered. We generally recommend altering this or temperature but not both. */ - @JsonProperty("top_p") Float topP; + private @JsonProperty("top_p") Float topP; /** * A list of tools the model may call. Currently, only functions are supported as a tool. Use this to * provide a list of functions the model may generate JSON inputs for. */ - @JsonProperty("tools") List tools; + private @JsonProperty("tools") List tools; /** * Controls which (if any) function is called by the model. none means the model will not call a * function and instead generates a message. auto means the model can pick between generating a message or calling a @@ -111,11 +109,11 @@ public class OpenAiChatOptions implements ChatOptions { * the model to call that function. none is the default when no functions are present. 
auto is the default if * functions are present. */ - @JsonProperty("tool_choice") ToolChoice toolChoice; + private @JsonProperty("tool_choice") ToolChoice toolChoice; /** * A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. */ - @JsonProperty("user") String user; + private @JsonProperty("user") String user; // @formatter:on public static Builder builder() { @@ -140,78 +138,78 @@ public Builder withModel(String model) { } public Builder withFrequencyPenalty(Float frequencyPenalty) { - options.frequencyPenalty = frequencyPenalty; + this.options.frequencyPenalty = frequencyPenalty; return this; } public Builder withLogitBias(Map logitBias) { - options.logitBias = logitBias; + this.options.logitBias = logitBias; return this; } public Builder withMaxTokens(Integer maxTokens) { - options.maxTokens = maxTokens; + this.options.maxTokens = maxTokens; return this; } public Builder withN(Integer n) { - options.n = n; + this.options.n = n; return this; } public Builder withPresencePenalty(Float presencePenalty) { - options.presencePenalty = presencePenalty; + this.options.presencePenalty = presencePenalty; return this; } public Builder withResponseFormat(ResponseFormat responseFormat) { - options.responseFormat = responseFormat; + this.options.responseFormat = responseFormat; return this; } public Builder withSeed(Integer seed) { - options.seed = seed; + this.options.seed = seed; return this; } public Builder withStop(List stop) { - options.stop = stop; + this.options.stop = stop; return this; } public Builder withTemperature(Float temperature) { - options.temperature = temperature; + this.options.temperature = temperature; return this; } public Builder withTopP(Float topP) { - options.topP = topP; + this.options.topP = topP; return this; } public Builder withTools(List tools) { - options.tools = tools; + this.options.tools = tools; return this; } public Builder withToolChoice(ToolChoice toolChoice) { - options.toolChoice = 
toolChoice; + this.options.toolChoice = toolChoice; return this; } public Builder withUser(String user) { - options.user = user; + this.options.user = user; return this; } public OpenAiChatOptions build() { - return options; + return this.options; } } public String getModel() { - return model; + return this.model; } public void setModel(String model) { @@ -219,7 +217,7 @@ public void setModel(String model) { } public Float getFrequencyPenalty() { - return frequencyPenalty; + return this.frequencyPenalty; } public void setFrequencyPenalty(Float frequencyPenalty) { @@ -227,7 +225,7 @@ public void setFrequencyPenalty(Float frequencyPenalty) { } public Map getLogitBias() { - return logitBias; + return this.logitBias; } public void setLogitBias(Map logitBias) { @@ -235,7 +233,7 @@ public void setLogitBias(Map logitBias) { } public Integer getMaxTokens() { - return maxTokens; + return this.maxTokens; } public void setMaxTokens(Integer maxTokens) { @@ -243,7 +241,7 @@ public void setMaxTokens(Integer maxTokens) { } public Integer getN() { - return n; + return this.n; } public void setN(Integer n) { @@ -251,7 +249,7 @@ public void setN(Integer n) { } public Float getPresencePenalty() { - return presencePenalty; + return this.presencePenalty; } public void setPresencePenalty(Float presencePenalty) { @@ -259,7 +257,7 @@ public void setPresencePenalty(Float presencePenalty) { } public ResponseFormat getResponseFormat() { - return responseFormat; + return this.responseFormat; } public void setResponseFormat(ResponseFormat responseFormat) { @@ -267,7 +265,7 @@ public void setResponseFormat(ResponseFormat responseFormat) { } public Integer getSeed() { - return seed; + return this.seed; } public void setSeed(Integer seed) { @@ -275,7 +273,7 @@ public void setSeed(Integer seed) { } public List getStop() { - return stop; + return this.stop; } public void setStop(List stop) { @@ -283,7 +281,7 @@ public void setStop(List stop) { } public Float getTemperature() { - return temperature; 
+ return this.temperature; } public void setTemperature(Float temperature) { @@ -291,7 +289,7 @@ public void setTemperature(Float temperature) { } public Float getTopP() { - return topP; + return this.topP; } public void setTopP(Float topP) { @@ -299,7 +297,7 @@ public void setTopP(Float topP) { } public List getTools() { - return tools; + return this.tools; } public void setTools(List tools) { @@ -307,7 +305,7 @@ public void setTools(List tools) { } public ToolChoice getToolChoice() { - return toolChoice; + return this.toolChoice; } public void setToolChoice(ToolChoice toolChoice) { @@ -315,7 +313,7 @@ public void setToolChoice(ToolChoice toolChoice) { } public String getUser() { - return user; + return this.user; } public void setUser(String user) { @@ -352,89 +350,89 @@ public boolean equals(Object obj) { if (getClass() != obj.getClass()) return false; OpenAiChatOptions other = (OpenAiChatOptions) obj; - if (model == null) { + if (this.model == null) { if (other.model != null) return false; } else if (!model.equals(other.model)) return false; - if (frequencyPenalty == null) { + if (this.frequencyPenalty == null) { if (other.frequencyPenalty != null) return false; } - else if (!frequencyPenalty.equals(other.frequencyPenalty)) + else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) return false; - if (logitBias == null) { + if (this.logitBias == null) { if (other.logitBias != null) return false; } - else if (!logitBias.equals(other.logitBias)) + else if (!this.logitBias.equals(other.logitBias)) return false; - if (maxTokens == null) { + if (this.maxTokens == null) { if (other.maxTokens != null) return false; } - else if (!maxTokens.equals(other.maxTokens)) + else if (!this.maxTokens.equals(other.maxTokens)) return false; - if (n == null) { + if (this.n == null) { if (other.n != null) return false; } - else if (!n.equals(other.n)) + else if (!this.n.equals(other.n)) return false; - if (presencePenalty == null) { + if (this.presencePenalty == null) { if 
(other.presencePenalty != null) return false; } - else if (!presencePenalty.equals(other.presencePenalty)) + else if (!this.presencePenalty.equals(other.presencePenalty)) return false; - if (responseFormat == null) { + if (this.responseFormat == null) { if (other.responseFormat != null) return false; } - else if (!responseFormat.equals(other.responseFormat)) + else if (!this.responseFormat.equals(other.responseFormat)) return false; - if (seed == null) { + if (this.seed == null) { if (other.seed != null) return false; } - else if (!seed.equals(other.seed)) + else if (!this.seed.equals(other.seed)) return false; - if (stop == null) { + if (this.stop == null) { if (other.stop != null) return false; } else if (!stop.equals(other.stop)) return false; - if (temperature == null) { + if (this.temperature == null) { if (other.temperature != null) return false; } - else if (!temperature.equals(other.temperature)) + else if (!this.temperature.equals(other.temperature)) return false; - if (topP == null) { + if (this.topP == null) { if (other.topP != null) return false; } else if (!topP.equals(other.topP)) return false; - if (tools == null) { + if (this.tools == null) { if (other.tools != null) return false; } else if (!tools.equals(other.tools)) return false; - if (toolChoice == null) { + if (this.toolChoice == null) { if (other.toolChoice != null) return false; } else if (!toolChoice.equals(other.toolChoice)) return false; - if (user == null) { + if (this.user == null) { if (other.user != null) return false; } - else if (!user.equals(other.user)) + else if (!this.user.equals(other.user)) return false; return true; } @@ -442,14 +440,12 @@ else if (!user.equals(other.user)) @Override @JsonIgnore public Integer getTopK() { - // TODO Auto-generated method stub throw new UnsupportedOperationException("Unimplemented method 'getTopK'"); } @Override @JsonIgnore public void setTopK(Integer topK) { - // TODO Auto-generated method stub throw new 
UnsupportedOperationException("Unimplemented method 'setTopK'"); } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java index eff4690e58e..4611957770f 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java @@ -29,6 +29,7 @@ import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResponseMetadata; +import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList; import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException; @@ -45,9 +46,16 @@ public class OpenAiEmbeddingClient extends AbstractEmbeddingClient { private static final Logger logger = LoggerFactory.getLogger(OpenAiEmbeddingClient.class); + private static final List REQUEST_JSON_FIELD_NAMES = ModelOptionsUtils + .getJsonPropertyValues(org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest.class); + public static final String DEFAULT_OPENAI_EMBEDDING_MODEL = "text-embedding-ada-002"; - public final RetryTemplate retryTemplate = RetryTemplate.builder() + private OpenAiEmbeddingOptions defaultOptions = OpenAiEmbeddingOptions.builder() + .withModel(DEFAULT_OPENAI_EMBEDDING_MODEL) + .build(); + + private final RetryTemplate retryTemplate = RetryTemplate.builder() .maxAttempts(10) .retryOn(OpenAiApiException.class) .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000)) @@ -55,27 +63,24 @@ public class OpenAiEmbeddingClient extends AbstractEmbeddingClient { private final OpenAiApi openAiApi; - private final String embeddingModelName; - private final MetadataMode 
metadataMode; public OpenAiEmbeddingClient(OpenAiApi openAiApi) { - this(openAiApi, DEFAULT_OPENAI_EMBEDDING_MODEL); - } - - public OpenAiEmbeddingClient(OpenAiApi openAiApi, String embeddingModel) { - this(openAiApi, embeddingModel, MetadataMode.EMBED); + this(openAiApi, MetadataMode.EMBED); } - public OpenAiEmbeddingClient(OpenAiApi openAiApi, String model, MetadataMode metadataMode) { + public OpenAiEmbeddingClient(OpenAiApi openAiApi, MetadataMode metadataMode) { Assert.notNull(openAiApi, "OpenAiService must not be null"); - Assert.notNull(model, "Model must not be null"); Assert.notNull(metadataMode, "metadataMode must not be null"); this.openAiApi = openAiApi; - this.embeddingModelName = model; this.metadataMode = metadataMode; } + public OpenAiEmbeddingClient withDefaultOptions(OpenAiEmbeddingOptions options) { + this.defaultOptions = options; + return this; + } + @Override public List embed(Document document) { Assert.notNull(document, "Document must not be null"); @@ -87,7 +92,17 @@ public EmbeddingResponse call(EmbeddingRequest request) { return this.retryTemplate.execute(ctx -> { org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest> apiRequest = new org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<>( - request.getInstructions(), this.embeddingModelName); + request.getInstructions(), DEFAULT_OPENAI_EMBEDDING_MODEL); + + if (this.defaultOptions != null) { + apiRequest = ModelOptionsUtils.merge(apiRequest, this.defaultOptions, + org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest.class, REQUEST_JSON_FIELD_NAMES); + } + + if (request.getOptions() != null) { + apiRequest = ModelOptionsUtils.merge(request.getOptions(), apiRequest, + org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest.class, REQUEST_JSON_FIELD_NAMES); + } EmbeddingList apiEmbeddingResponse = this.openAiApi.embeddings(apiRequest).getBody(); @@ -96,7 +111,7 @@ public EmbeddingResponse call(EmbeddingRequest request) { return new EmbeddingResponse(List.of()); } 
- var metadata = generateMetadata(apiEmbeddingResponse.model(), apiEmbeddingResponse.usage()); + var metadata = generateResponseMetadata(apiEmbeddingResponse.model(), apiEmbeddingResponse.usage()); List embeddings = apiEmbeddingResponse.data() .stream() @@ -108,7 +123,7 @@ public EmbeddingResponse call(EmbeddingRequest request) { }); } - private EmbeddingResponseMetadata generateMetadata(String model, Usage usage) { + private EmbeddingResponseMetadata generateResponseMetadata(String model, Usage usage) { EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata(); metadata.put("model", model); metadata.put("prompt-tokens", usage.promptTokens()); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java new file mode 100644 index 00000000000..e37d08161c9 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java @@ -0,0 +1,104 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.openai; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.ai.embedding.EmbeddingOptions; + +/** + * @author Christian Tzolov + * @since 0.8.0 + */ +@JsonInclude(Include.NON_NULL) +public class OpenAiEmbeddingOptions extends EmbeddingOptions { + + // @formatter:off + /** + * ID of the model to use. + */ + private @JsonProperty("model") String model; + /** + * The format to return the embeddings in. Can be either float or base64. + */ + private @JsonProperty("encoding_format") String encodingFormat; + /** + * A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. + */ + private @JsonProperty("user") String user; + // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + protected OpenAiEmbeddingOptions options; + + public Builder() { + this.options = new OpenAiEmbeddingOptions(); + } + + public Builder withModel(String model) { + this.options.setModel(model); + return this; + } + + public Builder withEncodingFormat(String encodingFormat) { + this.options.setEncodingFormat(encodingFormat); + return this; + } + + public Builder withUser(String user) { + this.options.setUser(user); + return this; + } + + public OpenAiEmbeddingOptions build() { + return this.options; + } + + } + + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + public String getEncodingFormat() { + return encodingFormat; + } + + public void setEncodingFormat(String encodingFormat) { + this.encodingFormat = encodingFormat; + } + + public String getUser() { + return user; + } + + public void setUser(String user) { + this.user = user; + } + +} diff --git 
a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java index 1e7edfba235..7c28e39ea69 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java @@ -75,7 +75,7 @@ public EmbeddingClient openAiEmbeddingClient(OpenAiConnectionProperties commonPr var openAiApi = new OpenAiApi(baseUrl, apiKey, RestClient.builder()); - return new OpenAiEmbeddingClient(openAiApi, embeddingProperties.getModel()); + return new OpenAiEmbeddingClient(openAiApi).withDefaultOptions(embeddingProperties.getOptions()); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingProperties.java index 7e125548d16..b265643d3bc 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingProperties.java @@ -16,7 +16,9 @@ package org.springframework.ai.autoconfigure.openai; +import org.springframework.ai.openai.OpenAiEmbeddingOptions; import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; @ConfigurationProperties(OpenAiEmbeddingProperties.CONFIG_PREFIX) public class OpenAiEmbeddingProperties extends OpenAiParentProperties { @@ -25,14 +27,17 @@ public class OpenAiEmbeddingProperties extends OpenAiParentProperties { public static 
final String DEFAULT_EMBEDDING_MODEL = "text-embedding-ada-002"; - private String model = DEFAULT_EMBEDDING_MODEL; + @NestedConfigurationProperty + private OpenAiEmbeddingOptions options = OpenAiEmbeddingOptions.builder() + .withModel(DEFAULT_EMBEDDING_MODEL) + .build(); - public String getModel() { - return model; + public OpenAiEmbeddingOptions getOptions() { + return options; } - public void setModel(String model) { - this.model = model; + public void setOptions(OpenAiEmbeddingOptions options) { + this.options = options; } } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java index 4ed60cf6a05..7fb7b938eea 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java @@ -98,7 +98,7 @@ public void embeddingProperties() { // @formatter:off "spring.ai.openai.base-url=TEST_BASE_URL", "spring.ai.openai.api-key=abc123", - "spring.ai.openai.embedding.model=MODEL_XYZ") + "spring.ai.openai.embedding.options.model=MODEL_XYZ") // @formatter:on .withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class)) .run(context -> { @@ -111,7 +111,7 @@ public void embeddingProperties() { assertThat(embeddingProperties.getApiKey()).isNull(); assertThat(embeddingProperties.getBaseUrl()).isNull(); - assertThat(embeddingProperties.getModel()).isEqualTo("MODEL_XYZ"); + assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); }); } @@ -124,7 +124,7 @@ public void embeddingOverrideConnectionProperties() { "spring.ai.openai.api-key=abc123", "spring.ai.openai.embedding.base-url=TEST_BASE_URL2", "spring.ai.openai.embedding.api-key=456", - 
"spring.ai.openai.embedding.model=MODEL_XYZ") + "spring.ai.openai.embedding.options.model=MODEL_XYZ") // @formatter:on .withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class)) .run(context -> { @@ -137,12 +137,12 @@ public void embeddingOverrideConnectionProperties() { assertThat(embeddingProperties.getApiKey()).isEqualTo("456"); assertThat(embeddingProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2"); - assertThat(embeddingProperties.getModel()).isEqualTo("MODEL_XYZ"); + assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); }); } @Test - public void optionsTest() { + public void chatOptionsTest() { new ApplicationContextRunner().withPropertyValues( // @formatter:off @@ -201,7 +201,7 @@ public void optionsTest() { assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY"); - assertThat(embeddingProperties.getModel()).isEqualTo("text-embedding-ada-002"); + assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("text-embedding-ada-002"); assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); assertThat(chatProperties.getOptions().getFrequencyPenalty()).isEqualTo(-1.5f); @@ -229,4 +229,31 @@ public void optionsTest() { }); } + @Test + public void embeddingOptionsTest() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.openai.api-key=API_KEY", + "spring.ai.openai.base-url=TEST_BASE_URL", + + "spring.ai.openai.embedding.options.model=MODEL_XYZ", + "spring.ai.openai.embedding.options.encodingFormat=MyEncodingFormat", + "spring.ai.openai.embedding.options.user=userXYZ" + ) + // @formatter:on + .withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class)) + .run(context -> { + var connectionProperties = context.getBean(OpenAiConnectionProperties.class); + var embeddingProperties = context.getBean(OpenAiEmbeddingProperties.class); + + 
assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY"); + + assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + assertThat(embeddingProperties.getOptions().getEncodingFormat()).isEqualTo("MyEncodingFormat"); + assertThat(embeddingProperties.getOptions().getUser()).isEqualTo("userXYZ"); + }); + } + } From 09cc4fdd25130b0daf502f0abbb2d3a1e264c9d3 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Sun, 28 Jan 2024 19:11:50 +0100 Subject: [PATCH 4/6] Rebase --- .../org/springframework/ai/openai/OpenAiEmbeddingOptions.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java index e37d08161c9..39a73a707a6 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java @@ -27,7 +27,7 @@ * @since 0.8.0 */ @JsonInclude(Include.NON_NULL) -public class OpenAiEmbeddingOptions extends EmbeddingOptions { +public class OpenAiEmbeddingOptions implements EmbeddingOptions { // @formatter:off /** From 2c52a501bd9e88e7e6a55f3f5c5955e62fa44d3c Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Sun, 28 Jan 2024 22:14:43 +0100 Subject: [PATCH 5/6] Update OpenAI client docs --- .../ROOT/pages/api/clients/openai.adoc | 29 ++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/openai.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/openai.adoc index 90e538fd959..bcb42937131 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/openai.adoc +++ 
b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/openai.adoc @@ -74,6 +74,13 @@ public class ChatController { } ---- +== OpenAiChatOptions + +The `OpenAiChatOptions` class allows you to configure OpenAI options, such as the model to use, the temperature, the frequency penalty, etc. +You can assign the default options, at startup, to the `OpenAiChatClient` using the `withDefaultOptions()` method. You can also override the default options at runtime by passing in the `OpenAiChatOptions` object to the `Prompt` constructor. + +The default options can be configured using the `spring.ai.openai.chat.options` properties as well. + == OpenAI Properties The prefix `spring.ai.openai` is used as the property prefix that lets you connect to OpenAI. @@ -94,8 +101,20 @@ The prefix `spring.ai.openai.chat` is the property prefix that lets you configur | spring.ai.openai.chat.base-url | Optional overrides the spring.ai.openai.base-url to provide chat specific url | - | spring.ai.openai.chat.api-key | Optional overrides the spring.ai.openai.api-key to provide chat specific api-key | - -| spring.ai.openai.chat.model | This is the OpenAI Chat model to use | `gpt-35-turbo` (the `gpt-3.5-turbo`, `gpt-4`, and `gpt-4-32k` point to the latest model versions) -| spring.ai.openai.chat.temperature | The sampling temperature to use that controls the apparent creativity of generated completions. Higher values will make output more random while lower values will make results more focused and deterministic. It is not recommended to modify temperature and top_p for the same completions request as the interaction of these two settings is difficult to predict.
| 0.7 +| spring.ai.openai.chat.options.model | This is the OpenAI Chat model to use | `gpt-35-turbo` (the `gpt-3.5-turbo`, `gpt-4`, and `gpt-4-32k` point to the latest model versions) +| spring.ai.openai.chat.options.temperature | The sampling temperature to use that controls the apparent creativity of generated completions. Higher values will make output more random while lower values will make results more focused and deterministic. It is not recommended to modify temperature and top_p for the same completions request as the interaction of these two settings is difficult to predict. | 0.8 +| spring.ai.openai.chat.options.frequencyPenalty | Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. | 0.0f +| spring.ai.openai.chat.options.logitBias | Modify the likelihood of specified tokens appearing in the completion. | - +| spring.ai.openai.chat.options.maxTokens | The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length. | - +| spring.ai.openai.chat.options.n | How many chat completion choices to generate for each input message. Note that you will be charged based on the number of generated tokens across all of the choices. Keep n as 1 to minimize costs. | 1 +| spring.ai.openai.chat.options.presencePenalty | Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. | - +| spring.ai.openai.chat.options.responseFormat | An object specifying the format that the model must output. Setting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.| - +| spring.ai.openai.chat.options.seed | This feature is in Beta. 
If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. | - +| spring.ai.openai.chat.options.stop | Up to 4 sequences where the API will stop generating further tokens. | - +| spring.ai.openai.chat.options.topP | An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both. | - +| spring.ai.openai.chat.options.tools | A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. | - +| spring.ai.openai.chat.options.toolChoice | Controls which (if any) function is called by the model. none means the model will not call a function and instead generates a message. auto means the model can pick between generating a message or calling a function. Specifying a particular function via {"type": "function", "function": {"name": "my_function"}} forces the model to call that function. none is the default when no functions are present. auto is the default if functions are present. | - +| spring.ai.openai.chat.options.user | A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. | - |==== The prefix `spring.ai.openai.embedding` is property prefix that configures the `EmbeddingClient` implementation for OpenAI.
@@ -105,7 +124,9 @@ The prefix `spring.ai.openai.embedding` is property prefix that configures the ` | Property | Description | Default | spring.ai.openai.embedding.base-url | Optional overrides the spring.ai.openai.base-url to provide chat specific url | - | spring.ai.openai.embedding.api-key | Optional overrides the spring.ai.openai.api-key to provide chat specific api-key | - -| spring.ai.openai.embedding.model | The model to use | text-embedding-ada-002 +| spring.ai.openai.embedding.options.model | The model to use | text-embedding-ada-002 +| spring.ai.openai.embedding.options.encodingFormat | The format to return the embeddings in. Can be either float or base64. | - +| spring.ai.openai.embedding.options.user | A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. | - |==== NOTE: You can override the common `spring.ai.openai.base-url` and `spring.ai.openai.api-key` for the `ChatClient` and `EmbeddingClient` implementations. @@ -113,4 +134,4 @@ The `spring.ai.openai.chat.base-url` and `spring.ai.openai.chat.api-key` propert Similarly, the `spring.ai.openai.embedding.base-url` and `spring.ai.openai.embedding.api-key` properties if set take precedence over the common properties. This is useful if you want to use different OpenAI accounts for different models and different model endpoints. -Also by default, the `spring.ai.openai.chat.model` is set to `gpt-35-turbo` and the `spring.ai.openai.embedding.model` is set to `text-embedding-ada-002`. +Also by default, the `spring.ai.openai.chat.options.model` is set to `gpt-35-turbo` and the `spring.ai.openai.embedding.options.model` is set to `text-embedding-ada-002`. 
From a457cd53bf448460fe58c130db5b406f572c79f7 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Mon, 29 Jan 2024 19:43:23 +0100 Subject: [PATCH 6/6] Address review comments --- .../ai/openai/OpenAiChatClient.java | 9 ++---- .../ai/openai/OpenAiEmbeddingClient.java | 7 ++--- .../ai/embedding/EmbeddingOptions.java | 7 +++-- .../ai/model/ModelOptionsUtils.java | 30 ++++++++++++------- 4 files changed, 29 insertions(+), 24 deletions(-) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java index 0cbe81101f3..1b57f10717f 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java @@ -60,9 +60,6 @@ public class OpenAiChatClient implements ChatClient, StreamingChatClient { private final Logger logger = LoggerFactory.getLogger(getClass()); - private static final List REQUEST_JSON_FIELD_NAMES = ModelOptionsUtils - .getJsonPropertyValues(ChatCompletionRequest.class); - private OpenAiChatOptions defaultOptions = OpenAiChatOptions.builder() .withModel("gpt-3.5-turbo") .withTemperature(0.7f) @@ -156,14 +153,12 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream); if (this.defaultOptions != null) { - request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class, - REQUEST_JSON_FIELD_NAMES); + request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class); } if (prompt.getOptions() != null) { if (prompt.getOptions() instanceof OpenAiChatOptions runtimeOptions) { - request = ModelOptionsUtils.merge(runtimeOptions, request, ChatCompletionRequest.class, - REQUEST_JSON_FIELD_NAMES); + request = ModelOptionsUtils.merge(runtimeOptions, 
request, ChatCompletionRequest.class); } else { throw new IllegalArgumentException("Prompt options are not of type ChatCompletionRequest:" diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java index 4611957770f..21a657fccba 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java @@ -46,9 +46,6 @@ public class OpenAiEmbeddingClient extends AbstractEmbeddingClient { private static final Logger logger = LoggerFactory.getLogger(OpenAiEmbeddingClient.class); - private static final List REQUEST_JSON_FIELD_NAMES = ModelOptionsUtils - .getJsonPropertyValues(org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest.class); - public static final String DEFAULT_OPENAI_EMBEDDING_MODEL = "text-embedding-ada-002"; private OpenAiEmbeddingOptions defaultOptions = OpenAiEmbeddingOptions.builder() @@ -96,12 +93,12 @@ public EmbeddingResponse call(EmbeddingRequest request) { if (this.defaultOptions != null) { apiRequest = ModelOptionsUtils.merge(apiRequest, this.defaultOptions, - org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest.class, REQUEST_JSON_FIELD_NAMES); + org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest.class); } if (request.getOptions() != null) { apiRequest = ModelOptionsUtils.merge(request.getOptions(), apiRequest, - org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest.class, REQUEST_JSON_FIELD_NAMES); + org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest.class); } EmbeddingList apiEmbeddingResponse = this.openAiApi.embeddings(apiRequest).getBody(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptions.java 
index c4137cb57ef..c4d5beccb77 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptions.java @@ -23,7 +23,10 @@ */ public interface EmbeddingOptions extends ModelOptions { - public static EmbeddingOptions EMPTY = new EmbeddingOptions() { - }; + public static class EmptyEmbeddingOptions implements EmbeddingOptions { + + } + + public static EmbeddingOptions EMPTY = new EmptyEmbeddingOptions(); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java index 267a1ccd960..96ddc1e3fb4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java @@ -16,20 +16,22 @@ package org.springframework.ai.model; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; - -import org.springframework.util.CollectionUtils; - import java.lang.reflect.Field; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; + +import org.springframework.util.CollectionUtils; + /** * Utility class for manipulating {@link ModelOptions} objects. 
* @@ -37,7 +39,8 @@ */ public final class ModelOptionsUtils { - private final static ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private final static ObjectMapper OBJECT_MAPPER = new ObjectMapper() + .disable(SerializationFeature.FAIL_ON_EMPTY_BEANS); private ModelOptionsUtils() { @@ -54,6 +57,11 @@ private ModelOptionsUtils() { * @return the merged object represented by the given class. */ public static T merge(Object source, Object target, Class clazz, List acceptedFieldNames) { + + List requestFieldNames = CollectionUtils.isEmpty(acceptedFieldNames) + ? REQUEST_FIELD_NAMES_PER_CLASS.computeIfAbsent(clazz, ModelOptionsUtils::getJsonPropertyValues) + : acceptedFieldNames; + Map sourceMap = objectToMap(source); Map targetMap = objectToMap(target); @@ -62,16 +70,18 @@ public static T merge(Object source, Object target, Class clazz, List e.getValue() != null) .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue()))); - if (!CollectionUtils.isEmpty(acceptedFieldNames)) { + if (!CollectionUtils.isEmpty(requestFieldNames)) { targetMap = targetMap.entrySet() .stream() - .filter(e -> acceptedFieldNames.contains(e.getKey())) + .filter(e -> requestFieldNames.contains(e.getKey())) .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())); } return mapToClass(targetMap, clazz); } + private static ConcurrentHashMap, List> REQUEST_FIELD_NAMES_PER_CLASS = new ConcurrentHashMap, List>(); + /** * Merges the source object into the target object and returns an object represented * by the given class. The source null values are ignored.