Skip to content

Commit e44ef40

Browse files
authored
Merge branch 'main' into chat-options-builder
2 parents 1d9a32d + 85580f8 commit e44ef40

File tree

39 files changed

+452
-106
lines changed

39 files changed

+452
-106
lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import com.azure.ai.openai.OpenAIAsyncClient;
3030
import com.azure.ai.openai.OpenAIClient;
3131
import com.azure.ai.openai.OpenAIClientBuilder;
32+
import com.azure.ai.openai.implementation.accesshelpers.ChatCompletionsOptionsAccessHelper;
3233
import com.azure.ai.openai.models.ChatChoice;
3334
import com.azure.ai.openai.models.ChatCompletions;
3435
import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall;
@@ -206,7 +207,7 @@ public ChatResponse call(Prompt prompt) {
206207
this.observationRegistry)
207208
.observe(() -> {
208209
ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
209-
options.setStream(false);
210+
ChatCompletionsOptionsAccessHelper.setStream(options, false);
210211

211212
ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options);
212213
ChatResponse chatResponse = toChatResponse(chatCompletions);
@@ -230,7 +231,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
230231

231232
return Flux.deferContextual(contextView -> {
232233
ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
233-
options.setStream(true);
234+
ChatCompletionsOptionsAccessHelper.setStream(options, true);
234235

235236
Flux<ChatCompletions> chatCompletionsStream = this.openAIAsyncClient
236237
.getChatCompletionsStream(options.getModel(), options);
@@ -252,10 +253,14 @@ public Flux<ChatResponse> stream(Prompt prompt) {
252253
final Flux<ChatCompletions> accessibleChatCompletionsFlux = chatCompletionsStream
253254
// Note: the first chat completions can be ignored when using Azure OpenAI
254255
// service which is a known service bug.
255-
.filter(chatCompletions -> !CollectionUtils.isEmpty(chatCompletions.getChoices()))
256+
// The last element, when using stream_options will contain the usage data
257+
.filter(chatCompletions -> !CollectionUtils.isEmpty(chatCompletions.getChoices())
258+
|| chatCompletions.getUsage() != null)
256259
.map(chatCompletions -> {
257-
final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls();
258-
isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty());
260+
if (!chatCompletions.getChoices().isEmpty()) {
261+
final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls();
262+
isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty());
263+
}
259264
return chatCompletions;
260265
})
261266
.windowUntil(chatCompletions -> {
@@ -493,7 +498,13 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,
493498
}
494499

495500
ChatCompletionsOptions mergedAzureOptions = new ChatCompletionsOptions(fromAzureOptions.getMessages());
496-
mergedAzureOptions.setStream(fromAzureOptions.isStream());
501+
502+
ChatCompletionsOptionsAccessHelper.setStream(mergedAzureOptions,
503+
fromAzureOptions.isStream() != null ? fromAzureOptions.isStream() : false);
504+
505+
ChatCompletionsOptionsAccessHelper.setStreamOptions(mergedAzureOptions,
506+
fromAzureOptions.getStreamOptions() != null ? fromAzureOptions.getStreamOptions()
507+
: toSpringAiOptions.getStreamOptions());
497508

498509
mergedAzureOptions.setMaxTokens((fromAzureOptions.getMaxTokens() != null) ? fromAzureOptions.getMaxTokens()
499510
: toSpringAiOptions.getMaxTokens());
@@ -629,6 +640,15 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions,
629640
mergedAzureOptions.setEnhancements(fromSpringAiOptions.getEnhancements());
630641
}
631642

643+
if (fromSpringAiOptions.getStreamOptions() != null) {
644+
ChatCompletionsOptionsAccessHelper.setStreamOptions(mergedAzureOptions,
645+
fromSpringAiOptions.getStreamOptions());
646+
}
647+
648+
if (fromSpringAiOptions.getEnhancements() != null) {
649+
mergedAzureOptions.setEnhancements(fromSpringAiOptions.getEnhancements());
650+
}
651+
632652
return mergedAzureOptions;
633653
}
634654

@@ -640,8 +660,13 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions,
640660
private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {
641661

642662
ChatCompletionsOptions copyOptions = new ChatCompletionsOptions(fromOptions.getMessages());
643-
copyOptions.setStream(fromOptions.isStream());
644663

664+
if (fromOptions.isStream() != null) {
665+
ChatCompletionsOptionsAccessHelper.setStream(copyOptions, fromOptions.isStream());
666+
}
667+
if (fromOptions.getStreamOptions() != null) {
668+
ChatCompletionsOptionsAccessHelper.setStreamOptions(copyOptions, fromOptions.getStreamOptions());
669+
}
645670
if (fromOptions.getMaxTokens() != null) {
646671
copyOptions.setMaxTokens(fromOptions.getMaxTokens());
647672
}

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.Set;
2424

2525
import com.azure.ai.openai.models.AzureChatEnhancementConfiguration;
26+
import com.azure.ai.openai.models.ChatCompletionStreamOptions;
2627
import com.fasterxml.jackson.annotation.JsonIgnore;
2728
import com.fasterxml.jackson.annotation.JsonInclude;
2829
import com.fasterxml.jackson.annotation.JsonInclude.Include;
@@ -193,6 +194,9 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions {
193194
@JsonIgnore
194195
private AzureChatEnhancementConfiguration enhancements;
195196

197+
@JsonProperty("stream_options")
198+
private ChatCompletionStreamOptions streamOptions;
199+
196200
@JsonIgnore
197201
private Map<String, Object> toolContext;
198202

@@ -219,6 +223,7 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti
219223
.withTopLogprobs(fromOptions.getTopLogProbs())
220224
.withEnhancements(fromOptions.getEnhancements())
221225
.withToolContext(fromOptions.getToolContext())
226+
.withStreamOptions(fromOptions.getStreamOptions())
222227
.build();
223228
}
224229

@@ -412,6 +417,14 @@ public void setToolContext(Map<String, Object> toolContext) {
412417
this.toolContext = toolContext;
413418
}
414419

420+
public ChatCompletionStreamOptions getStreamOptions() {
421+
return this.streamOptions;
422+
}
423+
424+
public void setStreamOptions(ChatCompletionStreamOptions streamOptions) {
425+
this.streamOptions = streamOptions;
426+
}
427+
415428
@Override
416429
public AzureOpenAiChatOptions copy() {
417430
return fromOptions(this);
@@ -536,6 +549,11 @@ public Builder withToolContext(Map<String, Object> toolContext) {
536549
return this;
537550
}
538551

552+
public Builder withStreamOptions(ChatCompletionStreamOptions streamOptions) {
553+
this.options.streamOptions = streamOptions;
554+
return this;
555+
}
556+
539557
public AzureOpenAiChatOptions build() {
540558
return this.options;
541559
}

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import java.util.stream.Collectors;
2525

2626
import org.junit.jupiter.api.Test;
27-
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2827
import org.junit.jupiter.params.ParameterizedTest;
2928
import org.junit.jupiter.params.provider.ValueSource;
3029
import org.slf4j.Logger;
@@ -51,8 +50,7 @@
5150
import static org.assertj.core.api.Assertions.assertThat;
5251

5352
@SpringBootTest(classes = BedrockConverseTestConfiguration.class)
54-
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
55-
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
53+
@RequiresAwsCredentials
5654
class BedrockConverseChatClientIT {
5755

5856
private static final Logger logger = LoggerFactory.getLogger(BedrockConverseChatClientIT.class);

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
import org.junit.jupiter.api.Disabled;
2727
import org.junit.jupiter.api.Test;
28-
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2928
import org.junit.jupiter.params.ParameterizedTest;
3029
import org.junit.jupiter.params.provider.ValueSource;
3130
import org.slf4j.Logger;
@@ -60,8 +59,7 @@
6059
import static org.assertj.core.api.Assertions.assertThat;
6160

6261
@SpringBootTest(classes = BedrockConverseTestConfiguration.class)
63-
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
64-
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
62+
@RequiresAwsCredentials
6563
class BedrockProxyChatModelIT {
6664

6765
private static final Logger logger = LoggerFactory.getLogger(BedrockProxyChatModelIT.class);
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.bedrock;
18+
19+
import java.lang.annotation.ElementType;
20+
import java.lang.annotation.Retention;
21+
import java.lang.annotation.RetentionPolicy;
22+
import java.lang.annotation.Target;
23+
24+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
25+
26+
@Target({ ElementType.TYPE, ElementType.METHOD })
27+
@Retention(RetentionPolicy.RUNTIME)
28+
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
29+
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
30+
@EnabledIfEnvironmentVariable(named = "AWS_SESSION_TOKEN", matches = ".*")
31+
public @interface RequiresAwsCredentials {
32+
33+
// You can add custom properties here if needed
34+
35+
}

models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModelIT.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@
2525
import com.fasterxml.jackson.databind.ObjectMapper;
2626
import org.junit.jupiter.api.Disabled;
2727
import org.junit.jupiter.api.Test;
28-
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2928
import org.slf4j.Logger;
3029
import org.slf4j.LoggerFactory;
3130
import reactor.core.publisher.Flux;
3231
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
3332
import software.amazon.awssdk.regions.Region;
3433

34+
import org.springframework.ai.bedrock.RequiresAwsCredentials;
3535
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi;
3636
import org.springframework.ai.chat.messages.AssistantMessage;
3737
import org.springframework.ai.chat.messages.Message;
@@ -55,8 +55,7 @@
5555
import static org.assertj.core.api.Assertions.assertThat;
5656

5757
@SpringBootTest
58-
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
59-
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
58+
@RequiresAwsCredentials
6059
class BedrockAnthropicChatModelIT {
6160

6261
private static final Logger logger = LoggerFactory.getLogger(BedrockAnthropicChatModelIT.class);

models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@
2222

2323
import com.fasterxml.jackson.databind.ObjectMapper;
2424
import org.junit.jupiter.api.Test;
25-
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2625
import org.slf4j.Logger;
2726
import org.slf4j.LoggerFactory;
2827
import reactor.core.publisher.Flux;
2928
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
3029
import software.amazon.awssdk.regions.Region;
3130

31+
import org.springframework.ai.bedrock.RequiresAwsCredentials;
3232
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatModel;
3333
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatRequest;
3434
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse;
@@ -38,8 +38,7 @@
3838
/**
3939
* @author Christian Tzolov
4040
*/
41-
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
42-
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
41+
@RequiresAwsCredentials
4342
public class AnthropicChatBedrockApiIT {
4443

4544
private final Logger logger = LoggerFactory.getLogger(AnthropicChatBedrockApiIT.class);

models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@
2525

2626
import com.fasterxml.jackson.databind.ObjectMapper;
2727
import org.junit.jupiter.api.Test;
28-
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2928
import org.slf4j.Logger;
3029
import org.slf4j.LoggerFactory;
3130
import reactor.core.publisher.Flux;
3231
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
3332
import software.amazon.awssdk.regions.Region;
3433

34+
import org.springframework.ai.bedrock.RequiresAwsCredentials;
3535
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi;
3636
import org.springframework.ai.chat.messages.AssistantMessage;
3737
import org.springframework.ai.chat.messages.Message;
@@ -58,8 +58,7 @@
5858
import static org.assertj.core.api.Assertions.assertThat;
5959

6060
@SpringBootTest
61-
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
62-
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
61+
@RequiresAwsCredentials
6362
class BedrockAnthropic3ChatModelIT {
6463

6564
private static final Logger logger = LoggerFactory.getLogger(BedrockAnthropic3ChatModelIT.class);

models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApiIT.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@
2222

2323
import com.fasterxml.jackson.databind.ObjectMapper;
2424
import org.junit.jupiter.api.Test;
25-
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2625
import org.slf4j.Logger;
2726
import org.slf4j.LoggerFactory;
2827
import reactor.core.publisher.Flux;
2928
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
3029
import software.amazon.awssdk.regions.Region;
3130

31+
import org.springframework.ai.bedrock.RequiresAwsCredentials;
3232
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatModel;
3333
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatRequest;
3434
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse;
@@ -42,8 +42,7 @@
4242
/**
4343
* @author Ben Middleton
4444
*/
45-
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
46-
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
45+
@RequiresAwsCredentials
4746
public class Anthropic3ChatBedrockApiIT {
4847

4948
private final Logger logger = LoggerFactory.getLogger(Anthropic3ChatBedrockApiIT.class);

models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModelIT.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@
2525
import com.fasterxml.jackson.databind.ObjectMapper;
2626
import org.junit.jupiter.api.Disabled;
2727
import org.junit.jupiter.api.Test;
28-
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2928
import reactor.core.publisher.Flux;
3029
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
3130
import software.amazon.awssdk.regions.Region;
3231

32+
import org.springframework.ai.bedrock.RequiresAwsCredentials;
3333
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi;
3434
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatModel;
3535
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -54,8 +54,7 @@
5454
import static org.assertj.core.api.Assertions.assertThat;
5555

5656
@SpringBootTest
57-
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
58-
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
57+
@RequiresAwsCredentials
5958
@Disabled("COHERE_COMMAND_V14 is not supported anymore")
6059
class BedrockCohereChatModelIT {
6160

0 commit comments

Comments
 (0)