Skip to content

Commit 56a41e6

Browse files
committed
GH-889: Align AzureOpenAiChatOptions with Azure ChatCompletionsOptions
Add missing options from Azure ChatCompletionsOptions to Spring AI AzureOpenAiChatOptions. The following fields have been added: - seed - logprobs - topLogprobs - enhancements This change ensures better alignment between the two option sets, improving compatibility and feature parity. Resolves #889
1 parent c67442d commit 56a41e6

File tree

3 files changed

+166
-3
lines changed

3 files changed

+166
-3
lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023 - 2024 the original author or authors.
2+
* Copyright 2023-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16+
1617
package org.springframework.ai.azure.openai;
1718

1819
import java.util.ArrayList;
@@ -92,6 +93,7 @@
9293
* @author Thomas Vitale
9394
* @author luocongqiu
9495
* @author timostark
96+
* @author Soby Chacko
9597
* @see ChatModel
9698
* @see com.azure.ai.openai.OpenAIClient
9799
*/
@@ -454,6 +456,18 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,
454456
mergedAzureOptions.setModel(fromAzureOptions.getModel() != null ? fromAzureOptions.getModel()
455457
: toSpringAiOptions.getDeploymentName());
456458

459+
mergedAzureOptions
460+
.setSeed(fromAzureOptions.getSeed() != null ? fromAzureOptions.getSeed() : toSpringAiOptions.getSeed());
461+
462+
mergedAzureOptions.setLogprobs((fromAzureOptions.isLogprobs() != null && fromAzureOptions.isLogprobs())
463+
|| (toSpringAiOptions.isLogprobs() != null && toSpringAiOptions.isLogprobs()));
464+
465+
mergedAzureOptions.setTopLogprobs(fromAzureOptions.getTopLogprobs() != null ? fromAzureOptions.getTopLogprobs()
466+
: toSpringAiOptions.getTopLogProbs());
467+
468+
mergedAzureOptions.setEnhancements(fromAzureOptions.getEnhancements() != null
469+
? fromAzureOptions.getEnhancements() : toSpringAiOptions.getEnhancements());
470+
457471
return mergedAzureOptions;
458472
}
459473

@@ -518,6 +532,22 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions,
518532
mergedAzureOptions.setResponseFormat(toAzureResponseFormat(fromSpringAiOptions.getResponseFormat()));
519533
}
520534

535+
if (fromSpringAiOptions.getSeed() != null) {
536+
mergedAzureOptions.setSeed(fromSpringAiOptions.getSeed());
537+
}
538+
539+
if (fromSpringAiOptions.isLogprobs() != null) {
540+
mergedAzureOptions.setLogprobs(fromSpringAiOptions.isLogprobs());
541+
}
542+
543+
if (fromSpringAiOptions.getTopLogProbs() != null) {
544+
mergedAzureOptions.setTopLogprobs(fromSpringAiOptions.getTopLogProbs());
545+
}
546+
547+
if (fromSpringAiOptions.getEnhancements() != null) {
548+
mergedAzureOptions.setEnhancements(fromSpringAiOptions.getEnhancements());
549+
}
550+
521551
return mergedAzureOptions;
522552
}
523553

@@ -564,6 +594,19 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {
564594
if (fromOptions.getResponseFormat() != null) {
565595
copyOptions.setResponseFormat(fromOptions.getResponseFormat());
566596
}
597+
if (fromOptions.getSeed() != null) {
598+
copyOptions.setSeed(fromOptions.getSeed());
599+
}
600+
601+
copyOptions.setLogprobs(fromOptions.isLogprobs());
602+
603+
if (fromOptions.getTopLogprobs() != null) {
604+
copyOptions.setTopLogprobs(fromOptions.getTopLogprobs());
605+
}
606+
607+
if (fromOptions.getEnhancements() != null) {
608+
copyOptions.setEnhancements(fromOptions.getEnhancements());
609+
}
567610

568611
return copyOptions;
569612
}

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023 - 2024 the original author or authors.
2+
* Copyright 2023-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16+
1617
package org.springframework.ai.azure.openai;
1718

1819
import java.util.ArrayList;
@@ -21,6 +22,7 @@
2122
import java.util.Map;
2223
import java.util.Set;
2324

25+
import com.azure.ai.openai.models.AzureChatEnhancementConfiguration;
2426
import com.fasterxml.jackson.annotation.JsonIgnore;
2527
import com.fasterxml.jackson.annotation.JsonInclude;
2628
import com.fasterxml.jackson.annotation.JsonInclude.Include;
@@ -39,6 +41,7 @@
3941
*
4042
* @author Christian Tzolov
4143
* @author Thomas Vitale
44+
* @author Soby Chacko
4245
*/
4346
@JsonInclude(Include.NON_NULL)
4447
public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptions {
@@ -161,6 +164,37 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio
161164
@JsonIgnore
162165
private Set<String> functions = new HashSet<>();
163166

167+
/**
168+
* Seed value for deterministic sampling such that the same seed and parameters return
169+
* the same result.
170+
*/
171+
@JsonProperty(value = "seed")
172+
private Long seed;
173+
174+
/**
175+
* Whether to return log probabilities of the output tokens or not. If true, returns
176+
* the log probabilities of each output token returned in the `content` of `message`.
177+
* This option is currently not available on the `gpt-4-vision-preview` model.
178+
*/
179+
@JsonProperty(value = "log_probs")
180+
private Boolean logprobs;
181+
182+
/*
183+
* An integer between 0 and 5 specifying the number of most likely tokens to return at
184+
* each token position, each with an associated log probability. `logprobs` must be
185+
* set to `true` if this parameter is used.
186+
*/
187+
@JsonProperty(value = "top_log_probs")
188+
private Integer topLogProbs;
189+
190+
/*
191+
* If provided, the configuration options for available Azure OpenAI chat
192+
* enhancements.
193+
*/
194+
@NestedConfigurationProperty
195+
@JsonIgnore
196+
private AzureChatEnhancementConfiguration enhancements;
197+
164198
public static Builder builder() {
165199
return new Builder();
166200
}
@@ -250,6 +284,30 @@ public Builder withResponseFormat(AzureOpenAiResponseFormat responseFormat) {
250284
return this;
251285
}
252286

287+
public Builder withSeed(Long seed) {
288+
Assert.notNull(seed, "seed must not be null");
289+
this.options.seed = seed;
290+
return this;
291+
}
292+
293+
public Builder withLogprobs(Boolean logprobs) {
294+
Assert.notNull(logprobs, "logprobs must not be null");
295+
this.options.logprobs = logprobs;
296+
return this;
297+
}
298+
299+
public Builder withTopLogprobs(Integer topLogprobs) {
300+
Assert.notNull(topLogprobs, "topLogprobs must not be null");
301+
this.options.topLogProbs = topLogprobs;
302+
return this;
303+
}
304+
305+
public Builder withEnhancements(AzureChatEnhancementConfiguration enhancements) {
306+
Assert.notNull(enhancements, "enhancements must not be null");
307+
this.options.enhancements = enhancements;
308+
return this;
309+
}
310+
253311
public AzureOpenAiChatOptions build() {
254312
return this.options;
255313
}
@@ -395,6 +453,38 @@ public Integer getTopK() {
395453
return null;
396454
}
397455

456+
public Long getSeed() {
457+
return this.seed;
458+
}
459+
460+
public void setSeed(Long seed) {
461+
this.seed = seed;
462+
}
463+
464+
public Boolean isLogprobs() {
465+
return this.logprobs;
466+
}
467+
468+
public void setLogprobs(Boolean logprobs) {
469+
this.logprobs = logprobs;
470+
}
471+
472+
public Integer getTopLogProbs() {
473+
return this.topLogProbs;
474+
}
475+
476+
public void setTopLogProbs(Integer topLogProbs) {
477+
this.topLogProbs = topLogProbs;
478+
}
479+
480+
public AzureChatEnhancementConfiguration getEnhancements() {
481+
return this.enhancements;
482+
}
483+
484+
public void setEnhancements(AzureChatEnhancementConfiguration enhancements) {
485+
this.enhancements = enhancements;
486+
}
487+
398488
@Override
399489
public AzureOpenAiChatOptions copy() {
400490
return fromOptions(this);
@@ -413,6 +503,10 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti
413503
.withUser(fromOptions.getUser())
414504
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
415505
.withFunctions(fromOptions.getFunctions())
506+
.withSeed(fromOptions.getSeed())
507+
.withLogprobs(fromOptions.isLogprobs())
508+
.withTopLogprobs(fromOptions.getTopLogProbs())
509+
.withEnhancements(fromOptions.getEnhancements())
416510
.build();
417511
}
418512

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023 - 2024 the original author or authors.
2+
* Copyright 2023-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -13,9 +13,12 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16+
1617
package org.springframework.ai.azure.openai;
1718

1819
import com.azure.ai.openai.OpenAIClient;
20+
import com.azure.ai.openai.models.AzureChatEnhancementConfiguration;
21+
import com.azure.ai.openai.models.AzureChatOCREnhancementConfiguration;
1922
import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
2023
import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat;
2124
import org.junit.jupiter.api.Test;
@@ -34,6 +37,7 @@
3437

3538
/**
3639
* @author Christian Tzolov
40+
* @author Soby Chacko
3741
*/
3842
public class AzureChatCompletionsOptionsTests {
3943

@@ -42,6 +46,9 @@ public void createRequestWithChatOptions() {
4246

4347
OpenAIClient mockClient = Mockito.mock(OpenAIClient.class);
4448

49+
AzureChatEnhancementConfiguration mockAzureChatEnhancementConfiguration = Mockito
50+
.mock(AzureChatEnhancementConfiguration.class);
51+
4552
var defaultOptions = AzureOpenAiChatOptions.builder()
4653
.withDeploymentName("DEFAULT_MODEL")
4754
.withTemperature(66.6)
@@ -53,6 +60,10 @@ public void createRequestWithChatOptions() {
5360
.withStop(List.of("foo", "bar"))
5461
.withTopP(0.69)
5562
.withUser("user")
63+
.withSeed(123L)
64+
.withLogprobs(true)
65+
.withTopLogprobs(5)
66+
.withEnhancements(mockAzureChatEnhancementConfiguration)
5667
.withResponseFormat(AzureOpenAiResponseFormat.TEXT)
5768
.build();
5869

@@ -72,8 +83,15 @@ public void createRequestWithChatOptions() {
7283
assertThat(requestOptions.getStop()).isEqualTo(List.of("foo", "bar"));
7384
assertThat(requestOptions.getTopP()).isEqualTo(0.69);
7485
assertThat(requestOptions.getUser()).isEqualTo("user");
86+
assertThat(requestOptions.getSeed()).isEqualTo(123L);
87+
assertThat(requestOptions.isLogprobs()).isTrue();
88+
assertThat(requestOptions.getTopLogprobs()).isEqualTo(5);
89+
assertThat(requestOptions.getEnhancements()).isEqualTo(mockAzureChatEnhancementConfiguration);
7590
assertThat(requestOptions.getResponseFormat()).isInstanceOf(ChatCompletionsTextResponseFormat.class);
7691

92+
AzureChatEnhancementConfiguration anotherMockAzureChatEnhancementConfiguration = Mockito
93+
.mock(AzureChatEnhancementConfiguration.class);
94+
7795
var runtimeOptions = AzureOpenAiChatOptions.builder()
7896
.withDeploymentName("PROMPT_MODEL")
7997
.withTemperature(99.9)
@@ -85,6 +103,10 @@ public void createRequestWithChatOptions() {
85103
.withStop(List.of("foo", "bar"))
86104
.withTopP(0.111)
87105
.withUser("user2")
106+
.withSeed(1234L)
107+
.withLogprobs(true)
108+
.withTopLogprobs(4)
109+
.withEnhancements(anotherMockAzureChatEnhancementConfiguration)
88110
.withResponseFormat(AzureOpenAiResponseFormat.JSON)
89111
.build();
90112

@@ -102,6 +124,10 @@ public void createRequestWithChatOptions() {
102124
assertThat(requestOptions.getStop()).isEqualTo(List.of("foo", "bar"));
103125
assertThat(requestOptions.getTopP()).isEqualTo(0.111);
104126
assertThat(requestOptions.getUser()).isEqualTo("user2");
127+
assertThat(requestOptions.getSeed()).isEqualTo(1234L);
128+
assertThat(requestOptions.isLogprobs()).isTrue();
129+
assertThat(requestOptions.getTopLogprobs()).isEqualTo(4);
130+
assertThat(requestOptions.getEnhancements()).isEqualTo(anotherMockAzureChatEnhancementConfiguration);
105131
assertThat(requestOptions.getResponseFormat()).isInstanceOf(ChatCompletionsJsonResponseFormat.class);
106132
}
107133

0 commit comments

Comments
 (0)