Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,17 @@ public class OpenAiChatOptions implements FunctionCallingOptions, ChatOptions {
* or 100 should result in a ban or exclusive selection of the relevant token.
*/
private @JsonProperty("logit_bias") Map<String, Integer> logitBias;
/**
* Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities
* of each output token returned in the 'content' of 'message'. This option is currently not available
* on the 'gpt-4-vision-preview' model.
*/
private @JsonProperty("logprobs") Boolean logprobs;
/**
* An integer between 0 and 5 specifying the number of most likely tokens to return at each token position,
* each with an associated log probability. 'logprobs' must be set to 'true' if this parameter is used.
*/
private @JsonProperty("top_logprobs") Integer topLogprobs;
/**
* The maximum number of tokens to generate in the chat completion. The total length of input
* tokens and generated tokens is limited by the model's context length.
Expand Down Expand Up @@ -177,6 +188,16 @@ public Builder withLogitBias(Map<String, Integer> logitBias) {
return this;
}

/**
 * Sets whether the model should return log probabilities of the output tokens
 * (serialized as the {@code logprobs} request field).
 * @param logprobs {@code true} to request per-token log probabilities; {@code null} leaves it unset
 * @return this builder for chaining
 */
public Builder withLogprobs(Boolean logprobs) {
this.options.logprobs = logprobs;
return this;
}

/**
 * Sets the number of most likely tokens to return at each token position, 0-5
 * (serialized as the {@code top_logprobs} request field).
 * Per the field documentation, {@code logprobs} must be set to {@code true} when this is used.
 * @param topLogprobs number of top alternatives per token position; {@code null} leaves it unset
 * @return this builder for chaining
 */
public Builder withTopLogprobs(Integer topLogprobs) {
this.options.topLogprobs = topLogprobs;
return this;
}

public Builder withMaxTokens(Integer maxTokens) {
this.options.maxTokens = maxTokens;
return this;
Expand Down Expand Up @@ -279,6 +300,22 @@ public void setLogitBias(Map<String, Integer> logitBias) {
this.logitBias = logitBias;
}

/**
 * @return whether per-token log probabilities are requested; {@code null} when unset
 */
public Boolean getLogprobs() {
return this.logprobs;
}

/**
 * @param logprobs whether to request per-token log probabilities; {@code null} leaves it unset
 */
public void setLogprobs(Boolean logprobs) {
this.logprobs = logprobs;
}

/**
 * @return the number of top log-probability alternatives requested per token position; {@code null} when unset
 */
public Integer getTopLogprobs() {
return this.topLogprobs;
}

/**
 * @param topLogprobs number of top log-probability alternatives per token position (0-5 per the
 * field documentation); requires {@code logprobs} to be {@code true} when set
 */
public void setTopLogprobs(Integer topLogprobs) {
this.topLogprobs = topLogprobs;
}

/**
 * @return the maximum number of tokens to generate in the chat completion; {@code null} when unset
 */
public Integer getMaxTokens() {
return this.maxTokens;
}
Expand Down Expand Up @@ -395,6 +432,8 @@ public int hashCode() {
result = prime * result + ((model == null) ? 0 : model.hashCode());
result = prime * result + ((frequencyPenalty == null) ? 0 : frequencyPenalty.hashCode());
result = prime * result + ((logitBias == null) ? 0 : logitBias.hashCode());
result = prime * result + ((logprobs == null) ? 0 : logprobs.hashCode());
result = prime * result + ((topLogprobs == null) ? 0 : topLogprobs.hashCode());
result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode());
result = prime * result + ((n == null) ? 0 : n.hashCode());
result = prime * result + ((presencePenalty == null) ? 0 : presencePenalty.hashCode());
Expand Down Expand Up @@ -436,6 +475,18 @@ else if (!this.frequencyPenalty.equals(other.frequencyPenalty))
}
else if (!this.logitBias.equals(other.logitBias))
return false;
if (this.logprobs == null) {
if (other.logprobs != null)
return false;
}
else if (!this.logprobs.equals(other.logprobs))
return false;
if (this.topLogprobs == null) {
if (other.topLogprobs != null)
return false;
}
else if (!this.topLogprobs.equals(other.topLogprobs))
return false;
if (this.maxTokens == null) {
if (other.maxTokens != null)
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ public EmbeddingResponse call(EmbeddingRequest request) {
org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<List<String>> apiRequest = (this.defaultOptions != null)
? new org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<>(request.getInstructions(),
this.defaultOptions.getModel(), this.defaultOptions.getEncodingFormat(),
this.defaultOptions.getUser())
this.defaultOptions.getDimensions(), this.defaultOptions.getUser())
: new org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<>(request.getInstructions(),
OpenAiApi.DEFAULT_EMBEDDING_MODEL);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ public class OpenAiEmbeddingOptions implements EmbeddingOptions {
* The format to return the embeddings in. Can be either float or base64.
*/
private @JsonProperty("encoding_format") String encodingFormat;
/**
* The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
*/
private @JsonProperty("dimensions") Integer dimensions;
/**
* A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
*/
Expand Down Expand Up @@ -65,6 +69,11 @@ public Builder withEncodingFormat(String encodingFormat) {
return this;
}

/**
 * Sets the number of dimensions the resulting output embeddings should have
 * (serialized as the {@code dimensions} request field; per the field documentation,
 * only supported in text-embedding-3 and later models).
 * @param dimensions desired embedding dimensionality; {@code null} leaves it unset
 * @return this builder for chaining
 */
public Builder withDimensions(Integer dimensions) {
this.options.dimensions = dimensions;
return this;
}

/**
 * Sets the end-user identifier (serialized as the {@code user} request field).
 * NOTE(review): this delegates through {@code setUser(..)} while the sibling builder
 * methods assign the option field directly — presumably equivalent, but worth
 * aligning for consistency; confirm {@code setUser} performs a plain assignment.
 * @param user unique identifier representing the end-user
 * @return this builder for chaining
 */
public Builder withUser(String user) {
this.options.setUser(user);
return this;
}
Expand Down Expand Up @@ -92,6 +101,14 @@ public void setEncodingFormat(String encodingFormat) {
this.encodingFormat = encodingFormat;
}

/**
 * @return the requested embedding dimensionality; {@code null} when unset
 */
public Integer getDimensions() {
return dimensions;
}

/**
 * @param dimensions desired embedding dimensionality; {@code null} leaves it unset
 */
public void setDimensions(Integer dimensions) {
this.dimensions = dimensions;
}

/**
 * @return the end-user identifier; {@code null} when unset
 */
public String getUser() {
return user;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,11 @@ public Function(String description, String name, String jsonSchema) {
* Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will
* vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100
* or 100 should result in a ban or exclusive selection of the relevant token.
* @param logprobs Whether to return log probabilities of the output tokens or not. If true, returns the log
* probabilities of each output token returned in the 'content' of 'message'. This option is currently not available
* on the 'gpt-4-vision-preview' model.
* @param topLogprobs An integer between 0 and 5 specifying the number of most likely tokens to return at each token
* position, each with an associated log probability. 'logprobs' must be set to 'true' if this parameter is used.
* @param maxTokens The maximum number of tokens to generate in the chat completion. The total length of input
* tokens and generated tokens is limited by the model's context length.
* @param n How many chat completion choices to generate for each input message. Note that you will be charged based
Expand Down Expand Up @@ -302,6 +307,8 @@ public record ChatCompletionRequest (
@JsonProperty("model") String model,
@JsonProperty("frequency_penalty") Float frequencyPenalty,
@JsonProperty("logit_bias") Map<String, Integer> logitBias,
@JsonProperty("logprobs") Boolean logprobs,
@JsonProperty("top_logprobs") Integer topLogprobs,
@JsonProperty("max_tokens") Integer maxTokens,
@JsonProperty("n") Integer n,
@JsonProperty("presence_penalty") Float presencePenalty,
Expand All @@ -323,7 +330,7 @@ public record ChatCompletionRequest (
* @param temperature What sampling temperature to use, between 0 and 1.
*/
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Float temperature) {
// Delegates to the canonical record constructor. The argument list carries two extra
// nulls relative to the previous version, matching the newly inserted 'logprobs' and
// 'top_logprobs' record components (positions 5 and 6, after 'logit_bias').
// The boolean literal 'false' occupies the position where the sibling constructor
// passes its 'stream' argument, i.e. streaming disabled; all other options stay null.
this(messages, model, null, null, null, null, null, null, null,
null, null, null, false, temperature, null,
null, null, null);
}
Expand All @@ -338,7 +345,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
* as they become available, with the stream terminated by a data: [DONE] message.
*/
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Float temperature, boolean stream) {
// Delegates to the canonical record constructor. Two extra nulls were added for the
// new 'logprobs' and 'top_logprobs' components (positions 5 and 6); the caller-supplied
// 'stream' flag is forwarded in the streaming position and every other option stays null.
this(messages, model, null, null, null, null, null, null, null,
null, null, null, stream, temperature, null,
null, null, null);
}
Expand All @@ -354,7 +361,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
*/
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
List<FunctionTool> tools, String toolChoice) {
// Delegates to the canonical record constructor with the new 'logprobs'/'top_logprobs'
// positions set to null (like all other options). Note the hard-coded default
// temperature of 0.8f and streaming disabled ('false'); 'tools' and 'toolChoice'
// are forwarded in their component positions.
this(messages, model, null, null, null, null, null, null, null,
null, null, null, false, 0.8f, null,
tools, toolChoice, null);
}
Expand All @@ -368,7 +375,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
* as they become available, with the stream terminated by a data: [DONE] message.
*/
public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
// Delegates to the canonical record constructor with no model ('null' in the model
// position) and two extra nulls for the new 'logprobs'/'top_logprobs' components.
// Only 'messages' and the 'stream' flag are populated; everything else stays null.
this(messages, null, null, null, null, null, null, null, null,
null, null, null, stream, null, null,
null, null, null);
}
Expand Down Expand Up @@ -869,13 +876,15 @@ public Embedding(Integer index, List<Double> embedding) {
* dimensions or less.
* @param model ID of the model to use.
* @param encodingFormat The format to return the embeddings in. Can be either float or base64.
* @param dimensions The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
* @param user A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
*/
@JsonInclude(Include.NON_NULL)
public record EmbeddingRequest<T>(
@JsonProperty("input") T input,
@JsonProperty("model") String model,
@JsonProperty("encoding_format") String encodingFormat,
@JsonProperty("dimensions") Integer dimensions,
@JsonProperty("user") String user) {

/**
Expand All @@ -884,7 +893,7 @@ public record EmbeddingRequest<T>(
* @param model ID of the model to use.
*/
public EmbeddingRequest(T input, String model) {
// Delegates to the canonical record constructor: encoding format defaults to "float",
// and the newly added 'dimensions' component (along with 'user') defaults to null,
// which @JsonInclude(Include.NON_NULL) on the record omits from the serialized request.
this(input, model, "float", null, null);
}

/**
Expand Down