Skip to content

Commit e7bd8d5

Browse files
garethjevansilayaperumalg
authored andcommitted
feat: allow stream usage to be set for azure openai requests
1 parent 863cf38 commit e7bd8d5

File tree

3 files changed

+51
-8
lines changed

3 files changed

+51
-8
lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import com.azure.ai.openai.OpenAIAsyncClient;
3030
import com.azure.ai.openai.OpenAIClient;
3131
import com.azure.ai.openai.OpenAIClientBuilder;
32+
import com.azure.ai.openai.implementation.accesshelpers.ChatCompletionsOptionsAccessHelper;
3233
import com.azure.ai.openai.models.ChatChoice;
3334
import com.azure.ai.openai.models.ChatCompletions;
3435
import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall;
@@ -206,7 +207,7 @@ public ChatResponse call(Prompt prompt) {
206207
this.observationRegistry)
207208
.observe(() -> {
208209
ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
209-
options.setStream(false);
210+
ChatCompletionsOptionsAccessHelper.setStream(options, false);
210211

211212
ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options);
212213
ChatResponse chatResponse = toChatResponse(chatCompletions);
@@ -230,7 +231,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
230231

231232
return Flux.deferContextual(contextView -> {
232233
ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
233-
options.setStream(true);
234+
ChatCompletionsOptionsAccessHelper.setStream(options, true);
234235

235236
Flux<ChatCompletions> chatCompletionsStream = this.openAIAsyncClient
236237
.getChatCompletionsStream(options.getModel(), options);
@@ -252,10 +253,14 @@ public Flux<ChatResponse> stream(Prompt prompt) {
252253
final Flux<ChatCompletions> accessibleChatCompletionsFlux = chatCompletionsStream
253254
// Note: the first chat completions can be ignored when using Azure OpenAI
254255
// service which is a known service bug.
255-
.filter(chatCompletions -> !CollectionUtils.isEmpty(chatCompletions.getChoices()))
256+
// The last element, when using stream_options will contain the usage data
257+
.filter(chatCompletions -> !CollectionUtils.isEmpty(chatCompletions.getChoices())
258+
|| chatCompletions.getUsage() != null)
256259
.map(chatCompletions -> {
257-
final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls();
258-
isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty());
260+
if (!chatCompletions.getChoices().isEmpty()) {
261+
final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls();
262+
isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty());
263+
}
259264
return chatCompletions;
260265
})
261266
.windowUntil(chatCompletions -> {
@@ -493,7 +498,13 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,
493498
}
494499

495500
ChatCompletionsOptions mergedAzureOptions = new ChatCompletionsOptions(fromAzureOptions.getMessages());
496-
mergedAzureOptions.setStream(fromAzureOptions.isStream());
501+
502+
ChatCompletionsOptionsAccessHelper.setStream(mergedAzureOptions,
503+
fromAzureOptions.isStream() != null ? fromAzureOptions.isStream() : false);
504+
505+
ChatCompletionsOptionsAccessHelper.setStreamOptions(mergedAzureOptions,
506+
fromAzureOptions.getStreamOptions() != null ? fromAzureOptions.getStreamOptions()
507+
: toSpringAiOptions.getStreamOptions());
497508

498509
mergedAzureOptions.setMaxTokens((fromAzureOptions.getMaxTokens() != null) ? fromAzureOptions.getMaxTokens()
499510
: toSpringAiOptions.getMaxTokens());
@@ -629,6 +640,15 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions,
629640
mergedAzureOptions.setEnhancements(fromSpringAiOptions.getEnhancements());
630641
}
631642

643+
if (fromSpringAiOptions.getStreamOptions() != null) {
644+
ChatCompletionsOptionsAccessHelper.setStreamOptions(mergedAzureOptions,
645+
fromSpringAiOptions.getStreamOptions());
646+
}
647+
648+
if (fromSpringAiOptions.getEnhancements() != null) {
649+
mergedAzureOptions.setEnhancements(fromSpringAiOptions.getEnhancements());
650+
}
651+
632652
return mergedAzureOptions;
633653
}
634654

@@ -640,8 +660,13 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions,
640660
private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {
641661

642662
ChatCompletionsOptions copyOptions = new ChatCompletionsOptions(fromOptions.getMessages());
643-
copyOptions.setStream(fromOptions.isStream());
644663

664+
if (fromOptions.isStream() != null) {
665+
ChatCompletionsOptionsAccessHelper.setStream(copyOptions, fromOptions.isStream());
666+
}
667+
if (fromOptions.getStreamOptions() != null) {
668+
ChatCompletionsOptionsAccessHelper.setStreamOptions(copyOptions, fromOptions.getStreamOptions());
669+
}
645670
if (fromOptions.getMaxTokens() != null) {
646671
copyOptions.setMaxTokens(fromOptions.getMaxTokens());
647672
}

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.Set;
2424

2525
import com.azure.ai.openai.models.AzureChatEnhancementConfiguration;
26+
import com.azure.ai.openai.models.ChatCompletionStreamOptions;
2627
import com.fasterxml.jackson.annotation.JsonIgnore;
2728
import com.fasterxml.jackson.annotation.JsonInclude;
2829
import com.fasterxml.jackson.annotation.JsonInclude.Include;
@@ -193,6 +194,9 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions {
193194
@JsonIgnore
194195
private AzureChatEnhancementConfiguration enhancements;
195196

197+
@JsonProperty("stream_options")
198+
private ChatCompletionStreamOptions streamOptions;
199+
196200
@JsonIgnore
197201
private Map<String, Object> toolContext;
198202

@@ -219,6 +223,7 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti
219223
.withTopLogprobs(fromOptions.getTopLogProbs())
220224
.withEnhancements(fromOptions.getEnhancements())
221225
.withToolContext(fromOptions.getToolContext())
226+
.withStreamOptions(fromOptions.getStreamOptions())
222227
.build();
223228
}
224229

@@ -412,6 +417,14 @@ public void setToolContext(Map<String, Object> toolContext) {
412417
this.toolContext = toolContext;
413418
}
414419

420+
public ChatCompletionStreamOptions getStreamOptions() {
421+
return this.streamOptions;
422+
}
423+
424+
public void setStreamOptions(ChatCompletionStreamOptions streamOptions) {
425+
this.streamOptions = streamOptions;
426+
}
427+
415428
@Override
416429
public AzureOpenAiChatOptions copy() {
417430
return fromOptions(this);
@@ -536,6 +549,11 @@ public Builder withToolContext(Map<String, Object> toolContext) {
536549
return this;
537550
}
538551

552+
public Builder withStreamOptions(ChatCompletionStreamOptions streamOptions) {
553+
this.options.streamOptions = streamOptions;
554+
return this;
555+
}
556+
539557
public AzureOpenAiChatOptions build() {
540558
return this.options;
541559
}

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@
174174
<!-- production dependencies -->
175175
<spring-boot.version>3.3.6</spring-boot.version>
176176
<ST4.version>4.3.4</ST4.version>
177-
<azure-open-ai-client.version>1.0.0-beta.12</azure-open-ai-client.version>
177+
<azure-open-ai-client.version>1.0.0-beta.13</azure-open-ai-client.version>
178178
<jtokkit.version>1.1.0</jtokkit.version>
179179
<victools.version>4.31.1</victools.version>
180180

0 commit comments

Comments
 (0)