Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.azure.openai;

import java.util.ArrayList;
Expand Down Expand Up @@ -92,6 +93,7 @@
* @author Thomas Vitale
* @author luocongqiu
* @author timostark
* @author Soby Chacko
* @see ChatModel
* @see com.azure.ai.openai.OpenAIClient
*/
Expand Down Expand Up @@ -454,6 +456,18 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,
mergedAzureOptions.setModel(fromAzureOptions.getModel() != null ? fromAzureOptions.getModel()
: toSpringAiOptions.getDeploymentName());

mergedAzureOptions
.setSeed(fromAzureOptions.getSeed() != null ? fromAzureOptions.getSeed() : toSpringAiOptions.getSeed());

mergedAzureOptions.setLogprobs((fromAzureOptions.isLogprobs() != null && fromAzureOptions.isLogprobs())
|| (toSpringAiOptions.isLogprobs() != null && toSpringAiOptions.isLogprobs()));

mergedAzureOptions.setTopLogprobs(fromAzureOptions.getTopLogprobs() != null ? fromAzureOptions.getTopLogprobs()
: toSpringAiOptions.getTopLogProbs());

mergedAzureOptions.setEnhancements(fromAzureOptions.getEnhancements() != null
? fromAzureOptions.getEnhancements() : toSpringAiOptions.getEnhancements());

return mergedAzureOptions;
}

Expand Down Expand Up @@ -518,6 +532,22 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions,
mergedAzureOptions.setResponseFormat(toAzureResponseFormat(fromSpringAiOptions.getResponseFormat()));
}

if (fromSpringAiOptions.getSeed() != null) {
mergedAzureOptions.setSeed(fromSpringAiOptions.getSeed());
}

if (fromSpringAiOptions.isLogprobs() != null) {
mergedAzureOptions.setLogprobs(fromSpringAiOptions.isLogprobs());
}

if (fromSpringAiOptions.getTopLogProbs() != null) {
mergedAzureOptions.setTopLogprobs(fromSpringAiOptions.getTopLogProbs());
}

if (fromSpringAiOptions.getEnhancements() != null) {
mergedAzureOptions.setEnhancements(fromSpringAiOptions.getEnhancements());
}

return mergedAzureOptions;
}

Expand Down Expand Up @@ -564,6 +594,19 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {
if (fromOptions.getResponseFormat() != null) {
copyOptions.setResponseFormat(fromOptions.getResponseFormat());
}
if (fromOptions.getSeed() != null) {
copyOptions.setSeed(fromOptions.getSeed());
}

copyOptions.setLogprobs(fromOptions.isLogprobs());

if (fromOptions.getTopLogprobs() != null) {
copyOptions.setTopLogprobs(fromOptions.getTopLogprobs());
}

if (fromOptions.getEnhancements() != null) {
copyOptions.setEnhancements(fromOptions.getEnhancements());
}

return copyOptions;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.azure.openai;

import java.util.ArrayList;
Expand All @@ -21,6 +22,7 @@
import java.util.Map;
import java.util.Set;

import com.azure.ai.openai.models.AzureChatEnhancementConfiguration;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
Expand All @@ -39,6 +41,7 @@
*
* @author Christian Tzolov
* @author Thomas Vitale
* @author Soby Chacko
*/
@JsonInclude(Include.NON_NULL)
public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptions {
Expand Down Expand Up @@ -161,6 +164,37 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio
@JsonIgnore
private Set<String> functions = new HashSet<>();

/**
* Seed value for deterministic sampling such that the same seed and parameters return
* the same result.
*/
@JsonProperty(value = "seed")
private Long seed;

/**
* Whether to return log probabilities of the output tokens or not. If true, returns
* the log probabilities of each output token returned in the `content` of `message`.
* This option is currently not available on the `gpt-4-vision-preview` model.
*/
@JsonProperty(value = "log_probs")
private Boolean logprobs;

/*
* An integer between 0 and 5 specifying the number of most likely tokens to return at
* each token position, each with an associated log probability. `logprobs` must be
* set to `true` if this parameter is used.
*/
@JsonProperty(value = "top_log_probs")
private Integer topLogProbs;

/*
* If provided, the configuration options for available Azure OpenAI chat
* enhancements.
*/
@NestedConfigurationProperty
@JsonIgnore
private AzureChatEnhancementConfiguration enhancements;

public static Builder builder() {
return new Builder();
}
Expand Down Expand Up @@ -250,6 +284,30 @@ public Builder withResponseFormat(AzureOpenAiResponseFormat responseFormat) {
return this;
}

public Builder withSeed(Long seed) {
Assert.notNull(seed, "seed must not be null");
this.options.seed = seed;
return this;
}

public Builder withLogprobs(Boolean logprobs) {
Assert.notNull(logprobs, "logprobs must not be null");
this.options.logprobs = logprobs;
return this;
}

public Builder withTopLogprobs(Integer topLogprobs) {
Assert.notNull(topLogprobs, "topLogprobs must not be null");
this.options.topLogProbs = topLogprobs;
return this;
}

public Builder withEnhancements(AzureChatEnhancementConfiguration enhancements) {
Assert.notNull(enhancements, "enhancements must not be null");
this.options.enhancements = enhancements;
return this;
}

public AzureOpenAiChatOptions build() {
return this.options;
}
Expand Down Expand Up @@ -395,6 +453,38 @@ public Integer getTopK() {
return null;
}

public Long getSeed() {
return this.seed;
}

public void setSeed(Long seed) {
this.seed = seed;
}

public Boolean isLogprobs() {
return this.logprobs;
}

public void setLogprobs(Boolean logprobs) {
this.logprobs = logprobs;
}

public Integer getTopLogProbs() {
return this.topLogProbs;
}

public void setTopLogProbs(Integer topLogProbs) {
this.topLogProbs = topLogProbs;
}

public AzureChatEnhancementConfiguration getEnhancements() {
return this.enhancements;
}

public void setEnhancements(AzureChatEnhancementConfiguration enhancements) {
this.enhancements = enhancements;
}

@Override
public AzureOpenAiChatOptions copy() {
return fromOptions(this);
Expand All @@ -413,6 +503,10 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti
.withUser(fromOptions.getUser())
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
.withFunctions(fromOptions.getFunctions())
.withSeed(fromOptions.getSeed())
.withLogprobs(fromOptions.isLogprobs())
.withTopLogprobs(fromOptions.getTopLogProbs())
.withEnhancements(fromOptions.getEnhancements())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,9 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.azure.openai;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.AzureChatEnhancementConfiguration;
import com.azure.ai.openai.models.AzureChatOCREnhancementConfiguration;
import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat;
import org.junit.jupiter.api.Test;
Expand All @@ -34,6 +37,7 @@

/**
* @author Christian Tzolov
* @author Soby Chacko
*/
public class AzureChatCompletionsOptionsTests {

Expand All @@ -42,6 +46,9 @@ public void createRequestWithChatOptions() {

OpenAIClient mockClient = Mockito.mock(OpenAIClient.class);

AzureChatEnhancementConfiguration mockAzureChatEnhancementConfiguration = Mockito
.mock(AzureChatEnhancementConfiguration.class);

var defaultOptions = AzureOpenAiChatOptions.builder()
.withDeploymentName("DEFAULT_MODEL")
.withTemperature(66.6)
Expand All @@ -53,6 +60,10 @@ public void createRequestWithChatOptions() {
.withStop(List.of("foo", "bar"))
.withTopP(0.69)
.withUser("user")
.withSeed(123L)
.withLogprobs(true)
.withTopLogprobs(5)
.withEnhancements(mockAzureChatEnhancementConfiguration)
.withResponseFormat(AzureOpenAiResponseFormat.TEXT)
.build();

Expand All @@ -72,8 +83,15 @@ public void createRequestWithChatOptions() {
assertThat(requestOptions.getStop()).isEqualTo(List.of("foo", "bar"));
assertThat(requestOptions.getTopP()).isEqualTo(0.69);
assertThat(requestOptions.getUser()).isEqualTo("user");
assertThat(requestOptions.getSeed()).isEqualTo(123L);
assertThat(requestOptions.isLogprobs()).isTrue();
assertThat(requestOptions.getTopLogprobs()).isEqualTo(5);
assertThat(requestOptions.getEnhancements()).isEqualTo(mockAzureChatEnhancementConfiguration);
assertThat(requestOptions.getResponseFormat()).isInstanceOf(ChatCompletionsTextResponseFormat.class);

AzureChatEnhancementConfiguration anotherMockAzureChatEnhancementConfiguration = Mockito
.mock(AzureChatEnhancementConfiguration.class);

var runtimeOptions = AzureOpenAiChatOptions.builder()
.withDeploymentName("PROMPT_MODEL")
.withTemperature(99.9)
Expand All @@ -85,6 +103,10 @@ public void createRequestWithChatOptions() {
.withStop(List.of("foo", "bar"))
.withTopP(0.111)
.withUser("user2")
.withSeed(1234L)
.withLogprobs(true)
.withTopLogprobs(4)
.withEnhancements(anotherMockAzureChatEnhancementConfiguration)
.withResponseFormat(AzureOpenAiResponseFormat.JSON)
.build();

Expand All @@ -102,6 +124,10 @@ public void createRequestWithChatOptions() {
assertThat(requestOptions.getStop()).isEqualTo(List.of("foo", "bar"));
assertThat(requestOptions.getTopP()).isEqualTo(0.111);
assertThat(requestOptions.getUser()).isEqualTo("user2");
assertThat(requestOptions.getSeed()).isEqualTo(1234L);
assertThat(requestOptions.isLogprobs()).isTrue();
assertThat(requestOptions.getTopLogprobs()).isEqualTo(4);
assertThat(requestOptions.getEnhancements()).isEqualTo(anotherMockAzureChatEnhancementConfiguration);
assertThat(requestOptions.getResponseFormat()).isInstanceOf(ChatCompletionsJsonResponseFormat.class);
}

Expand Down