diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index 14faaf6f82d..d45d4db18cd 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -59,6 +59,17 @@ public class OpenAiChatOptions implements FunctionCallingOptions, ChatOptions { * or 100 should result in a ban or exclusive selection of the relevant token. */ private @JsonProperty("logit_bias") Map<String, Integer> logitBias; + /** + * Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities + * of each output token returned in the 'content' of 'message'. This option is currently not available + * on the 'gpt-4-vision-preview' model. + */ + private @JsonProperty("logprobs") Boolean logprobs; + /** + * An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, + * each with an associated log probability. 'logprobs' must be set to 'true' if this parameter is used. + */ + private @JsonProperty("top_logprobs") Integer topLogprobs; /** * The maximum number of tokens to generate in the chat completion. The total length of input * tokens and generated tokens is limited by the model's context length. 
@@ -177,6 +188,16 @@ public Builder withLogitBias(Map<String, Integer> logitBias) { return this; } + public Builder withLogprobs(Boolean logprobs) { + this.options.logprobs = logprobs; + return this; + } + + public Builder withTopLogprobs(Integer topLogprobs) { + this.options.topLogprobs = topLogprobs; + return this; + } + public Builder withMaxTokens(Integer maxTokens) { this.options.maxTokens = maxTokens; return this; @@ -279,6 +300,22 @@ public void setLogitBias(Map<String, Integer> logitBias) { this.logitBias = logitBias; } + public Boolean getLogprobs() { + return this.logprobs; + } + + public void setLogprobs(Boolean logprobs) { + this.logprobs = logprobs; + } + + public Integer getTopLogprobs() { + return this.topLogprobs; + } + + public void setTopLogprobs(Integer topLogprobs) { + this.topLogprobs = topLogprobs; + } + public Integer getMaxTokens() { return this.maxTokens; } @@ -395,6 +432,8 @@ public int hashCode() { result = prime * result + ((model == null) ? 0 : model.hashCode()); result = prime * result + ((frequencyPenalty == null) ? 0 : frequencyPenalty.hashCode()); result = prime * result + ((logitBias == null) ? 0 : logitBias.hashCode()); + result = prime * result + ((logprobs == null) ? 0 : logprobs.hashCode()); + result = prime * result + ((topLogprobs == null) ? 0 : topLogprobs.hashCode()); result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode()); result = prime * result + ((n == null) ? 0 : n.hashCode()); result = prime * result + ((presencePenalty == null) ? 
0 : presencePenalty.hashCode()); @@ -436,6 +475,18 @@ else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) } else if (!this.logitBias.equals(other.logitBias)) return false; + if (this.logprobs == null) { + if (other.logprobs != null) + return false; + } + else if (!this.logprobs.equals(other.logprobs)) + return false; + if (this.topLogprobs == null) { + if (other.topLogprobs != null) + return false; + } + else if (!this.topLogprobs.equals(other.topLogprobs)) + return false; if (this.maxTokens == null) { if (other.maxTokens != null) return false; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java index d255853ec55..808c5f3513f 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java @@ -118,7 +118,7 @@ public EmbeddingResponse call(EmbeddingRequest request) { org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<List<String>> apiRequest = (this.defaultOptions != null) ? 
new org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<>(request.getInstructions(), this.defaultOptions.getModel(), this.defaultOptions.getEncodingFormat(), - this.defaultOptions.getUser()) + this.defaultOptions.getDimensions(), this.defaultOptions.getUser()) : new org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<>(request.getInstructions(), OpenAiApi.DEFAULT_EMBEDDING_MODEL); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java index 59a85aacc32..3f91e12f537 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java @@ -37,6 +37,10 @@ public class OpenAiEmbeddingOptions implements EmbeddingOptions { * The format to return the embeddings in. Can be either float or base64. */ private @JsonProperty("encoding_format") String encodingFormat; + /** + * The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models. + */ + private @JsonProperty("dimensions") Integer dimensions; /** * A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. 
*/ @@ -65,6 +69,11 @@ public Builder withEncodingFormat(String encodingFormat) { return this; } + public Builder withDimensions(Integer dimensions) { + this.options.dimensions = dimensions; + return this; + } + public Builder withUser(String user) { this.options.setUser(user); return this; @@ -92,6 +101,14 @@ public void setEncodingFormat(String encodingFormat) { this.encodingFormat = encodingFormat; } + public Integer getDimensions() { + return dimensions; + } + + public void setDimensions(Integer dimensions) { + this.dimensions = dimensions; + } + public String getUser() { return user; } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 5edfdc9e7c9..5bc83eadedf 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -265,6 +265,11 @@ public Function(String description, String name, String jsonSchema) { * Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will * vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 * or 100 should result in a ban or exclusive selection of the relevant token. + * @param logprobs Whether to return log probabilities of the output tokens or not. If true, returns the log + * probabilities of each output token returned in the 'content' of 'message'. This option is currently not available + * on the 'gpt-4-vision-preview' model. + * @param topLogprobs An integer between 0 and 5 specifying the number of most likely tokens to return at each token + * position, each with an associated log probability. 'logprobs' must be set to 'true' if this parameter is used. * @param maxTokens The maximum number of tokens to generate in the chat completion. 
The total length of input * tokens and generated tokens is limited by the model's context length. * @param n How many chat completion choices to generate for each input message. Note that you will be charged based @@ -302,6 +307,8 @@ public record ChatCompletionRequest ( @JsonProperty("model") String model, @JsonProperty("frequency_penalty") Float frequencyPenalty, @JsonProperty("logit_bias") Map<String, Integer> logitBias, + @JsonProperty("logprobs") Boolean logprobs, + @JsonProperty("top_logprobs") Integer topLogprobs, @JsonProperty("max_tokens") Integer maxTokens, @JsonProperty("n") Integer n, @JsonProperty("presence_penalty") Float presencePenalty, @@ -323,7 +330,7 @@ public record ChatCompletionRequest ( * @param temperature What sampling temperature to use, between 0 and 1. */ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Float temperature) { - this(messages, model, null, null, null, null, null, + this(messages, model, null, null, null, null, null, null, null, null, null, null, false, temperature, null, null, null, null); } @@ -338,7 +345,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, * as they become available, with the stream terminated by a data: [DONE] message. 
*/ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Float temperature, boolean stream) { - this(messages, model, null, null, null, null, null, + this(messages, model, null, null, null, null, null, null, null, null, null, null, stream, temperature, null, null, null, null); } @@ -354,7 +361,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, */ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, List<FunctionTool> tools, String toolChoice) { - this(messages, model, null, null, null, null, null, + this(messages, model, null, null, null, null, null, null, null, null, null, null, false, 0.8f, null, tools, toolChoice, null); } @@ -368,7 +375,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, * @param stream If set, partial message deltas will be sent.Tokens will be sent as data-only server-sent events * as they become available, with the stream terminated by a data: [DONE] message. */ public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) { - this(messages, null, null, null, null, null, null, + this(messages, null, null, null, null, null, null, null, null, null, null, null, stream, null, null, null, null, null); } @@ -869,6 +876,7 @@ public Embedding(Integer index, List<Double> embedding) { * dimensions or less. * @param model ID of the model to use. * @param encodingFormat The format to return the embeddings in. Can be either float or base64. + * @param dimensions The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models. * @param user A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. */ @JsonInclude(Include.NON_NULL) @@ -876,6 +884,7 @@ public record EmbeddingRequest<T>( @JsonProperty("input") T input, @JsonProperty("model") String model, @JsonProperty("encoding_format") String encodingFormat, + @JsonProperty("dimensions") Integer dimensions, @JsonProperty("user") String user) { /** @@ -884,7 +893,7 @@ public record EmbeddingRequest<T>( * @param model ID of the model to use. 
*/ public EmbeddingRequest(T input, String model) { - this(input, model, "float", null); + this(input, model, "float", null, null); } /**