diff --git a/README.md b/README.md
index 50e4facc713..1c53af2bdc9 100644
--- a/README.md
+++ b/README.md
@@ -84,7 +84,7 @@ You can find more details in the [Reference Documentation](https://docs.spring.i
Spring AI supports many AI models. For an overview see here. Specific models currently supported are
* OpenAI
* Azure OpenAI
-* Amazon Bedrock (Anthropic, Llama2, Cohere, Titan, Jurassic2)
+* Amazon Bedrock (Anthropic, Llama, Cohere, Titan, Jurassic2)
* HuggingFace
* Google VertexAI (PaLM2, Gemini)
* Mistral AI
diff --git a/models/spring-ai-bedrock/README.md b/models/spring-ai-bedrock/README.md
index 94311055fba..19e48518a60 100644
--- a/models/spring-ai-bedrock/README.md
+++ b/models/spring-ai-bedrock/README.md
@@ -4,7 +4,7 @@
- [Anthropic2 Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-anthropic.html)
- [Cohere Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-cohere.html)
- [Cohere Embedding Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/embeddings/bedrock-cohere-embedding.html)
-- [Llama2 Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-llama2.html)
+- [Llama Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-llama.html)
- [Titan Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-titan.html)
- [Titan Embedding Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/embeddings/bedrock-titan-embedding.html)
- [Jurassic2 Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-jurassic2.html)
diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java
index 2437b35eebb..55a2d80af50 100644
--- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java
+++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java
@@ -24,6 +24,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import reactor.core.publisher.Flux;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.regions.Region;
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatRequest;
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse;
@@ -32,6 +33,7 @@
/**
* @author Christian Tzolov
+ * @author Wei Jiang
* @since 0.8.0
*/
// @formatter:off
@@ -92,6 +94,20 @@ public AnthropicChatBedrockApi(String modelId, AwsCredentialsProvider credential
super(modelId, credentialsProvider, region, objectMapper, timeout);
}
+ /**
+ * Create a new AnthropicChatBedrockApi instance using the provided credentials provider, region and object mapper.
+ *
+ * @param modelId The model id to use. See the {@link AnthropicChatModel} for the supported models.
+ * @param credentialsProvider The credentials provider to connect to AWS.
+ * @param region The AWS region to use.
+ * @param objectMapper The object mapper to use for JSON serialization and deserialization.
+ * @param timeout The timeout to use.
+ */
+ public AnthropicChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
+ ObjectMapper objectMapper, Duration timeout) {
+ super(modelId, credentialsProvider, region, objectMapper, timeout);
+ }
+
// https://github.com/build-on-aws/amazon-bedrock-java-examples/blob/main/example_code/bedrock-runtime/src/main/java/aws/community/examples/InvokeBedrockStreamingAsync.java
// https://docs.anthropic.com/claude/reference/complete_post
diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java
index e76bfcbeff3..0148b498373 100644
--- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java
+++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java
@@ -26,6 +26,7 @@
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.regions.Region;
import java.time.Duration;
import java.util.List;
@@ -39,6 +40,7 @@
*
* @author Ben Middleton
* @author Christian Tzolov
+ * @author Wei Jiang
* @since 1.0.0
*/
// @formatter:off
@@ -96,6 +98,20 @@ public Anthropic3ChatBedrockApi(String modelId, AwsCredentialsProvider credentia
super(modelId, credentialsProvider, region, objectMapper, timeout);
}
+ /**
+ * Create a new AnthropicChatBedrockApi instance using the provided credentials provider, region and object mapper.
+ *
+ * @param modelId The model id to use. See the {@link AnthropicChatModel} for the supported models.
+ * @param credentialsProvider The credentials provider to connect to AWS.
+ * @param region The AWS region to use.
+ * @param objectMapper The object mapper to use for JSON serialization and deserialization.
+ * @param timeout The timeout to use.
+ */
+ public Anthropic3ChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
+ ObjectMapper objectMapper, Duration timeout) {
+ super(modelId, credentialsProvider, region, objectMapper, timeout);
+ }
+
// https://github.com/build-on-aws/amazon-bedrock-java-examples/blob/main/example_code/bedrock-runtime/src/main/java/aws/community/examples/InvokeBedrockStreamingAsync.java
// https://docs.anthropic.com/claude/reference/complete_post
@@ -441,7 +457,11 @@ public enum AnthropicChatModel {
/**
* anthropic.claude-3-haiku-20240307-v1:0
*/
- CLAUDE_V3_HAIKU("anthropic.claude-3-haiku-20240307-v1:0");
+ CLAUDE_V3_HAIKU("anthropic.claude-3-haiku-20240307-v1:0"),
+ /**
+ * anthropic.claude-3-opus-20240229-v1:0
+ */
+ CLAUDE_V3_OPUS("anthropic.claude-3-opus-20240229-v1:0");
private final String id;
diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java
index 8ed33139b7d..7db24b3b8c6 100644
--- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java
+++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java
@@ -25,9 +25,10 @@
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi;
import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi;
-import org.springframework.ai.bedrock.llama2.BedrockLlama2ChatOptions;
-import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi;
+import org.springframework.ai.bedrock.llama.BedrockLlamaChatOptions;
+import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi;
import org.springframework.ai.bedrock.titan.BedrockTitanChatOptions;
+import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingOptions;
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi;
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi;
import org.springframework.aot.hint.MemberCategory;
@@ -43,6 +44,7 @@
* @author Josh Long
* @author Christian Tzolov
* @author Mark Pollack
+ * @author Wei Jiang
*/
public class BedrockRuntimeHints implements RuntimeHintsRegistrar {
@@ -63,15 +65,17 @@ public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
for (var tr : findJsonAnnotatedClassesInPackage(BedrockCohereEmbeddingOptions.class))
hints.reflection().registerType(tr, mcs);
- for (var tr : findJsonAnnotatedClassesInPackage(Llama2ChatBedrockApi.class))
+ for (var tr : findJsonAnnotatedClassesInPackage(LlamaChatBedrockApi.class))
hints.reflection().registerType(tr, mcs);
- for (var tr : findJsonAnnotatedClassesInPackage(BedrockLlama2ChatOptions.class))
+ for (var tr : findJsonAnnotatedClassesInPackage(BedrockLlamaChatOptions.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClassesInPackage(TitanChatBedrockApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClassesInPackage(BedrockTitanChatOptions.class))
hints.reflection().registerType(tr, mcs);
+ for (var tr : findJsonAnnotatedClassesInPackage(BedrockTitanEmbeddingOptions.class))
+ hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClassesInPackage(TitanEmbeddingBedrockApi.class))
hints.reflection().registerType(tr, mcs);
diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java
index 74f4249f04f..17e0eed79b7 100644
--- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java
+++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java
@@ -61,6 +61,7 @@
* @see Model Parameters
* @author Christian Tzolov
+ * @author Wei Jiang
* @since 0.8.0
*/
public abstract class AbstractBedrockApi {
@@ -69,7 +70,7 @@ public abstract class AbstractBedrockApi {
private final String modelId;
private final ObjectMapper objectMapper;
- private final String region;
+ private final Region region;
private final BedrockRuntimeClient client;
private final BedrockRuntimeAsyncClient clientStreaming;
@@ -93,7 +94,7 @@ public AbstractBedrockApi(String modelId, String region, Duration timeout) {
this(modelId, ProfileCredentialsProvider.builder().build(), region, ModelOptionsUtils.OBJECT_MAPPER, timeout);
}
- /**
+ /**
* Create a new AbstractBedrockApi instance using the provided credentials provider, region and object mapper.
*
* @param modelId The model id to use.
@@ -105,6 +106,7 @@ public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProv
ObjectMapper objectMapper) {
this(modelId, credentialsProvider, region, objectMapper, Duration.ofMinutes(5));
}
+
/**
* Create a new AbstractBedrockApi instance using the provided credentials provider, region and object mapper.
*
@@ -118,10 +120,26 @@ public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProv
*/
public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region,
ObjectMapper objectMapper, Duration timeout) {
+ this(modelId, credentialsProvider, Region.of(region), objectMapper, timeout);
+ }
+
+ /**
+ * Create a new AbstractBedrockApi instance using the provided credentials provider, region and object mapper.
+ *
+ * @param modelId The model id to use.
+ * @param credentialsProvider The credentials provider to connect to AWS.
+ * @param region The AWS region to use.
+ * @param objectMapper The object mapper to use for JSON serialization and deserialization.
+ * @param timeout Configure the amount of time to allow the client to complete the execution of an API call.
+ * This timeout covers the entire client execution except for marshalling. This includes request handler execution,
+ * all HTTP requests including retries, unmarshalling, etc. This value should always be positive, if present.
+ */
+ public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
+ ObjectMapper objectMapper, Duration timeout) {
Assert.hasText(modelId, "Model id must not be empty");
Assert.notNull(credentialsProvider, "Credentials provider must not be null");
- Assert.hasText(region, "Region must not be empty");
+ Assert.notNull(region, "Region must not be empty");
Assert.notNull(objectMapper, "Object mapper must not be null");
Assert.notNull(timeout, "Timeout must not be null");
@@ -131,13 +149,13 @@ public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProv
this.client = BedrockRuntimeClient.builder()
- .region(Region.of(this.region))
+ .region(this.region)
.credentialsProvider(credentialsProvider)
.overrideConfiguration(c -> c.apiCallTimeout(timeout))
.build();
this.clientStreaming = BedrockRuntimeAsyncClient.builder()
- .region(Region.of(this.region))
+ .region(this.region)
.credentialsProvider(credentialsProvider)
.overrideConfiguration(c -> c.apiCallTimeout(timeout))
.build();
@@ -153,7 +171,7 @@ public String getModelId() {
/**
* @return The AWS region.
*/
- public String getRegion() {
+ public Region getRegion() {
return this.region;
}
diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java
index b3b02b6993f..5b133a9976c 100644
--- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java
+++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java
@@ -25,6 +25,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import reactor.core.publisher.Flux;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.regions.Region;
import org.springframework.ai.bedrock.api.AbstractBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest;
@@ -36,6 +37,7 @@
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere.html
*
* @author Christian Tzolov
+ * @author Wei Jiang
* @since 0.8.0
*/
public class CohereChatBedrockApi extends
@@ -91,6 +93,20 @@ public CohereChatBedrockApi(String modelId, AwsCredentialsProvider credentialsPr
super(modelId, credentialsProvider, region, objectMapper, timeout);
}
+ /**
+ * Create a new CohereChatBedrockApi instance using the provided credentials provider, region and object mapper.
+ *
+ * @param modelId The model id to use. See the {@link CohereChatModel} for the supported models.
+ * @param credentialsProvider The credentials provider to connect to AWS.
+ * @param region The AWS region to use.
+ * @param objectMapper The object mapper to use for JSON serialization and deserialization.
+ * @param timeout The timeout to use.
+ */
+ public CohereChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
+ ObjectMapper objectMapper, Duration timeout) {
+ super(modelId, credentialsProvider, region, objectMapper, timeout);
+ }
+
/**
* CohereChatRequest encapsulates the request parameters for the Cohere command model.
*
diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java
index 7d0fa442cde..13752cc4486 100644
--- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java
+++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java
@@ -24,6 +24,7 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.regions.Region;
import org.springframework.ai.bedrock.api.AbstractBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingRequest;
@@ -34,6 +35,7 @@
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere.html#model-parameters-embed
*
* @author Christian Tzolov
+ * @author Wei Jiang
* @since 0.8.0
*/
public class CohereEmbeddingBedrockApi extends
@@ -91,6 +93,21 @@ public CohereEmbeddingBedrockApi(String modelId, AwsCredentialsProvider credenti
super(modelId, credentialsProvider, region, objectMapper, timeout);
}
+ /**
+ * Create a new CohereEmbeddingBedrockApi instance using the provided credentials provider, region and object
+ * mapper.
+ *
+ * @param modelId The model id to use. See the {@link CohereEmbeddingModel} for the supported models.
+ * @param credentialsProvider The credentials provider to connect to AWS.
+ * @param region The AWS region to use.
+ * @param objectMapper The object mapper to use for JSON serialization and deserialization.
+ * @param timeout The timeout to use.
+ */
+ public CohereEmbeddingBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
+ ObjectMapper objectMapper, Duration timeout) {
+ super(modelId, credentialsProvider, region, objectMapper, timeout);
+ }
+
/**
* The Cohere Embed model request.
*
diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java
index fa505176350..0ec58c8bd2c 100644
--- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java
+++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java
@@ -29,12 +29,14 @@
import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatResponse;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.regions.Region;
/**
* Java client for the Bedrock Jurassic2 chat model.
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html
*
* @author Christian Tzolov
+ * @author Wei Jiang
* @since 0.8.0
*/
public class Ai21Jurassic2ChatBedrockApi extends
@@ -92,6 +94,20 @@ public Ai21Jurassic2ChatBedrockApi(String modelId, AwsCredentialsProvider creden
super(modelId, credentialsProvider, region, objectMapper, timeout);
}
+ /**
+ * Create a new Ai21Jurassic2ChatBedrockApi instance.
+ *
+ * @param modelId The model id to use. See the {@link Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatModel} for the supported models.
+ * @param credentialsProvider The credentials provider to connect to AWS.
+ * @param region The AWS region to use.
+ * @param objectMapper The object mapper to use for JSON serialization and deserialization.
+ * @param timeout The timeout to use.
+ */
+ public Ai21Jurassic2ChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
+ ObjectMapper objectMapper, Duration timeout) {
+ super(modelId, credentialsProvider, region, objectMapper, timeout);
+ }
+
/**
* AI21 Jurassic2 chat request parameters.
*
diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatClient.java
similarity index 67%
rename from models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java
rename to models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatClient.java
index a12fe0b6eb9..c1be58e5a1d 100644
--- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java
+++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatClient.java
@@ -13,16 +13,16 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.springframework.ai.bedrock.llama2;
+package org.springframework.ai.bedrock.llama;
import java.util.List;
import reactor.core.publisher.Flux;
import org.springframework.ai.bedrock.MessageToPromptConverter;
-import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi;
-import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatRequest;
-import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatResponse;
+import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi;
+import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest;
+import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.ChatResponse;
@@ -35,26 +35,27 @@
import org.springframework.util.Assert;
/**
- * Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Llama2 chat
+ * Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Llama chat
* generative.
*
* @author Christian Tzolov
+ * @author Wei Jiang
* @since 0.8.0
*/
-public class BedrockLlama2ChatClient implements ChatClient, StreamingChatClient {
+public class BedrockLlamaChatClient implements ChatClient, StreamingChatClient {
- private final Llama2ChatBedrockApi chatApi;
+ private final LlamaChatBedrockApi chatApi;
- private final BedrockLlama2ChatOptions defaultOptions;
+ private final BedrockLlamaChatOptions defaultOptions;
- public BedrockLlama2ChatClient(Llama2ChatBedrockApi chatApi) {
+ public BedrockLlamaChatClient(LlamaChatBedrockApi chatApi) {
this(chatApi,
- BedrockLlama2ChatOptions.builder().withTemperature(0.8f).withTopP(0.9f).withMaxGenLen(100).build());
+ BedrockLlamaChatOptions.builder().withTemperature(0.8f).withTopP(0.9f).withMaxGenLen(100).build());
}
- public BedrockLlama2ChatClient(Llama2ChatBedrockApi chatApi, BedrockLlama2ChatOptions options) {
- Assert.notNull(chatApi, "Llama2ChatBedrockApi must not be null");
- Assert.notNull(options, "BedrockLlama2ChatOptions must not be null");
+ public BedrockLlamaChatClient(LlamaChatBedrockApi chatApi, BedrockLlamaChatOptions options) {
+ Assert.notNull(chatApi, "LlamaChatBedrockApi must not be null");
+ Assert.notNull(options, "BedrockLlamaChatOptions must not be null");
this.chatApi = chatApi;
this.defaultOptions = options;
@@ -65,7 +66,7 @@ public ChatResponse call(Prompt prompt) {
var request = createRequest(prompt);
- Llama2ChatResponse response = this.chatApi.chatCompletion(request);
+ LlamaChatResponse response = this.chatApi.chatCompletion(request);
return new ChatResponse(List.of(new Generation(response.generation()).withGenerationMetadata(
ChatGenerationMetadata.from(response.stopReason().name(), extractUsage(response)))));
@@ -76,7 +77,7 @@ public Flux stream(Prompt prompt) {
var request = createRequest(prompt);
- Flux fluxResponse = this.chatApi.chatCompletionStream(request);
+ Flux fluxResponse = this.chatApi.chatCompletionStream(request);
return fluxResponse.map(response -> {
String stopReason = response.stopReason() != null ? response.stopReason().name() : null;
@@ -85,7 +86,7 @@ public Flux stream(Prompt prompt) {
});
}
- private Usage extractUsage(Llama2ChatResponse response) {
+ private Usage extractUsage(LlamaChatResponse response) {
return new Usage() {
@Override
@@ -103,22 +104,22 @@ public Long getGenerationTokens() {
/**
* Accessible for testing.
*/
- Llama2ChatRequest createRequest(Prompt prompt) {
+ LlamaChatRequest createRequest(Prompt prompt) {
final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions());
- Llama2ChatRequest request = Llama2ChatRequest.builder(promptValue).build();
+ LlamaChatRequest request = LlamaChatRequest.builder(promptValue).build();
if (this.defaultOptions != null) {
- request = ModelOptionsUtils.merge(request, this.defaultOptions, Llama2ChatRequest.class);
+ request = ModelOptionsUtils.merge(request, this.defaultOptions, LlamaChatRequest.class);
}
if (prompt.getOptions() != null) {
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
- BedrockLlama2ChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
- ChatOptions.class, BedrockLlama2ChatOptions.class);
+ BedrockLlamaChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
+ ChatOptions.class, BedrockLlamaChatOptions.class);
- request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, Llama2ChatRequest.class);
+ request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, LlamaChatRequest.class);
}
else {
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java
similarity index 91%
rename from models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatOptions.java
rename to models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java
index a944b09904a..3502fd4c441 100644
--- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatOptions.java
+++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java
@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.springframework.ai.bedrock.llama2;
+package org.springframework.ai.bedrock.llama;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
@@ -26,7 +26,7 @@
* @author Christian Tzolov
*/
@JsonInclude(Include.NON_NULL)
-public class BedrockLlama2ChatOptions implements ChatOptions {
+public class BedrockLlamaChatOptions implements ChatOptions {
/**
* The temperature value controls the randomness of the generated text. Use a lower
@@ -51,7 +51,7 @@ public static Builder builder() {
public static class Builder {
- private BedrockLlama2ChatOptions options = new BedrockLlama2ChatOptions();
+ private BedrockLlamaChatOptions options = new BedrockLlamaChatOptions();
public Builder withTemperature(Float temperature) {
this.options.setTemperature(temperature);
@@ -68,7 +68,7 @@ public Builder withMaxGenLen(Integer maxGenLen) {
return this;
}
- public BedrockLlama2ChatOptions build() {
+ public BedrockLlamaChatOptions build() {
return this.options;
}
diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java
similarity index 61%
rename from models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApi.java
rename to models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java
index af10d69bdfb..25d71aedeea 100644
--- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApi.java
+++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java
@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.springframework.ai.bedrock.llama2.api;
+package org.springframework.ai.bedrock.llama.api;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
@@ -21,76 +21,92 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import reactor.core.publisher.Flux;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.regions.Region;
import org.springframework.ai.bedrock.api.AbstractBedrockApi;
-import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatRequest;
-import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatResponse;
+import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest;
+import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse;
import java.time.Duration;
// @formatter:off
/**
- * Java client for the Bedrock Llama2 chat model.
+ * Java client for the Bedrock Llama chat model.
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
*
* @author Christian Tzolov
+ * @author Wei Jiang
* @since 0.8.0
*/
-public class Llama2ChatBedrockApi extends
- AbstractBedrockApi {
+public class LlamaChatBedrockApi extends
+ AbstractBedrockApi {
/**
- * Create a new Llama2ChatBedrockApi instance using the default credentials provider chain, the default object
+ * Create a new LlamaChatBedrockApi instance using the default credentials provider chain, the default object
* mapper, default temperature and topP values.
*
- * @param modelId The model id to use. See the {@link Llama2ChatModel} for the supported models.
+ * @param modelId The model id to use. See the {@link LlamaChatModel} for the supported models.
* @param region The AWS region to use.
*/
- public Llama2ChatBedrockApi(String modelId, String region) {
+ public LlamaChatBedrockApi(String modelId, String region) {
super(modelId, region);
}
/**
- * Create a new Llama2ChatBedrockApi instance using the provided credentials provider, region and object mapper.
+ * Create a new LlamaChatBedrockApi instance using the provided credentials provider, region and object mapper.
*
- * @param modelId The model id to use. See the {@link Llama2ChatModel} for the supported models.
+ * @param modelId The model id to use. See the {@link LlamaChatModel} for the supported models.
* @param credentialsProvider The credentials provider to connect to AWS.
* @param region The AWS region to use.
* @param objectMapper The object mapper to use for JSON serialization and deserialization.
*/
- public Llama2ChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region,
+ public LlamaChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region,
ObjectMapper objectMapper) {
super(modelId, credentialsProvider, region, objectMapper);
}
/**
- * Create a new Llama2ChatBedrockApi instance using the default credentials provider chain, the default object
+ * Create a new LlamaChatBedrockApi instance using the default credentials provider chain, the default object
* mapper, default temperature and topP values.
*
- * @param modelId The model id to use. See the {@link Llama2ChatModel} for the supported models.
+ * @param modelId The model id to use. See the {@link LlamaChatModel} for the supported models.
* @param region The AWS region to use.
* @param timeout The timeout to use.
*/
- public Llama2ChatBedrockApi(String modelId, String region, Duration timeout) {
+ public LlamaChatBedrockApi(String modelId, String region, Duration timeout) {
super(modelId, region, timeout);
}
/**
- * Create a new Llama2ChatBedrockApi instance using the provided credentials provider, region and object mapper.
+ * Create a new LlamaChatBedrockApi instance using the provided credentials provider, region and object mapper.
*
- * @param modelId The model id to use. See the {@link Llama2ChatModel} for the supported models.
+ * @param modelId The model id to use. See the {@link LlamaChatModel} for the supported models.
* @param credentialsProvider The credentials provider to connect to AWS.
* @param region The AWS region to use.
* @param objectMapper The object mapper to use for JSON serialization and deserialization.
* @param timeout The timeout to use.
*/
- public Llama2ChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region,
+ public LlamaChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region,
ObjectMapper objectMapper, Duration timeout) {
super(modelId, credentialsProvider, region, objectMapper, timeout);
}
/**
- * Llama2ChatRequest encapsulates the request parameters for the Meta Llama2 chat model.
+ * Create a new LlamaChatBedrockApi instance using the provided credentials provider, region and object mapper.
+ *
+ * @param modelId The model id to use. See the {@link LlamaChatModel} for the supported models.
+ * @param credentialsProvider The credentials provider to connect to AWS.
+ * @param region The AWS region to use.
+ * @param objectMapper The object mapper to use for JSON serialization and deserialization.
+ * @param timeout The timeout to use.
+ */
+ public LlamaChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
+ ObjectMapper objectMapper, Duration timeout) {
+ super(modelId, credentialsProvider, region, objectMapper, timeout);
+ }
+
+ /**
+ * LlamaChatRequest encapsulates the request parameters for the Meta Llama chat model.
*
* @param prompt The prompt to use for the chat.
* @param temperature The temperature value controls the randomness of the generated text. Use a lower value to
@@ -100,16 +116,16 @@ public Llama2ChatBedrockApi(String modelId, AwsCredentialsProvider credentialsPr
* @param maxGenLen The maximum length of the generated text.
*/
@JsonInclude(Include.NON_NULL)
- public record Llama2ChatRequest(
+ public record LlamaChatRequest(
@JsonProperty("prompt") String prompt,
@JsonProperty("temperature") Float temperature,
@JsonProperty("top_p") Float topP,
@JsonProperty("max_gen_len") Integer maxGenLen) {
/**
- * Create a new Llama2ChatRequest builder.
+ * Create a new LlamaChatRequest builder.
* @param prompt compulsory prompt parameter.
- * @return a new Llama2ChatRequest builder.
+ * @return a new LlamaChatRequest builder.
*/
public static Builder builder(String prompt) {
return new Builder(prompt);
@@ -140,8 +156,8 @@ public Builder withMaxGenLen(Integer maxGenLen) {
return this;
}
- public Llama2ChatRequest build() {
- return new Llama2ChatRequest(
+ public LlamaChatRequest build() {
+ return new LlamaChatRequest(
prompt,
temperature,
topP,
@@ -152,7 +168,7 @@ public Llama2ChatRequest build() {
}
/**
- * Llama2ChatResponse encapsulates the response parameters for the Meta Llama2 chat model.
+ * LlamaChatResponse encapsulates the response parameters for the Meta Llama chat model.
*
* @param generation The generated text.
* @param promptTokenCount The number of tokens in the prompt.
@@ -163,7 +179,7 @@ public Llama2ChatRequest build() {
* increasing the value of max_gen_len and trying again.
*/
@JsonInclude(Include.NON_NULL)
- public record Llama2ChatResponse(
+ public record LlamaChatResponse(
@JsonProperty("generation") String generation,
@JsonProperty("prompt_token_count") Integer promptTokenCount,
@JsonProperty("generation_token_count") Integer generationTokenCount,
@@ -186,9 +202,9 @@ public enum StopReason {
}
/**
- * Llama2 models version.
+ * Llama models version.
*/
- public enum Llama2ChatModel {
+ public enum LlamaChatModel {
/**
* meta.llama2-13b-chat-v1
@@ -198,7 +214,17 @@ public enum Llama2ChatModel {
/**
* meta.llama2-70b-chat-v1
*/
- LLAMA2_70B_CHAT_V1("meta.llama2-70b-chat-v1");
+ LLAMA2_70B_CHAT_V1("meta.llama2-70b-chat-v1"),
+
+ /**
+ * meta.llama3-8b-instruct-v1:0
+ */
+ LLAMA3_8B_INSTRUCT_V1("meta.llama3-8b-instruct-v1:0"),
+
+ /**
+ * meta.llama3-70b-instruct-v1:0
+ */
+ LLAMA3_70B_INSTRUCT_V1("meta.llama3-70b-instruct-v1:0");
private final String id;
@@ -209,19 +235,19 @@ public String id() {
return id;
}
- Llama2ChatModel(String value) {
+ LlamaChatModel(String value) {
this.id = value;
}
}
@Override
- public Llama2ChatResponse chatCompletion(Llama2ChatRequest request) {
- return this.internalInvocation(request, Llama2ChatResponse.class);
+ public LlamaChatResponse chatCompletion(LlamaChatRequest request) {
+ return this.internalInvocation(request, LlamaChatResponse.class);
}
@Override
- public Flux chatCompletionStream(Llama2ChatRequest request) {
- return this.internalInvocationStream(request, Llama2ChatResponse.class);
+ public Flux chatCompletionStream(LlamaChatRequest request) {
+ return this.internalInvocationStream(request, LlamaChatResponse.class);
}
}
// @formatter:on
\ No newline at end of file
diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClient.java
index d48135f80ec..1d64f92eff8 100644
--- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClient.java
+++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClient.java
@@ -28,6 +28,7 @@
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.AbstractEmbeddingClient;
import org.springframework.ai.embedding.Embedding;
+import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.util.Assert;
@@ -40,6 +41,7 @@
* Note: Titan Embedding does not support batch embedding.
*
* @author Christian Tzolov
+ * @author Wei Jiang
* @since 0.8.0
*/
public class BedrockTitanEmbeddingClient extends AbstractEmbeddingClient {
@@ -87,9 +89,7 @@ public EmbeddingResponse call(EmbeddingRequest request) {
List> embeddingList = new ArrayList<>();
for (String inputContent : request.getInstructions()) {
- var apiRequest = (this.inputType == InputType.IMAGE)
- ? new TitanEmbeddingRequest.Builder().withInputImage(inputContent).build()
- : new TitanEmbeddingRequest.Builder().withInputText(inputContent).build();
+ var apiRequest = createTitanEmbeddingRequest(inputContent, request.getOptions());
TitanEmbeddingResponse response = this.embeddingApi.embedding(apiRequest);
embeddingList.add(response.embedding());
}
@@ -100,6 +100,18 @@ public EmbeddingResponse call(EmbeddingRequest request) {
return new EmbeddingResponse(embeddings);
}
+ private TitanEmbeddingRequest createTitanEmbeddingRequest(String inputContent, EmbeddingOptions requestOptions) {
+ InputType inputType = this.inputType;
+
+ if (requestOptions != null
+ && requestOptions instanceof BedrockTitanEmbeddingOptions bedrockTitanEmbeddingOptions) {
+ inputType = bedrockTitanEmbeddingOptions.getInputType();
+ }
+
+ return (inputType == InputType.IMAGE) ? new TitanEmbeddingRequest.Builder().withInputImage(inputContent).build()
+ : new TitanEmbeddingRequest.Builder().withInputText(inputContent).build();
+ }
+
@Override
public int dimensions() {
if (this.inputType == InputType.IMAGE) {
diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingOptions.java
new file mode 100644
index 00000000000..fd1c609bf91
--- /dev/null
+++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingOptions.java
@@ -0,0 +1,65 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.bedrock.titan;
+
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonInclude.Include;
+
+import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingClient.InputType;
+import org.springframework.ai.embedding.EmbeddingOptions;
+import org.springframework.util.Assert;
+
+/**
+ * @author Wei Jiang
+ */
+@JsonInclude(Include.NON_NULL)
+public class BedrockTitanEmbeddingOptions implements EmbeddingOptions {
+
+ /**
+ * Titan Embedding API input types. Could be either text or image (encoded in base64).
+ */
+ private InputType inputType;
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ public static class Builder {
+
+ private BedrockTitanEmbeddingOptions options = new BedrockTitanEmbeddingOptions();
+
+ public Builder withInputType(InputType inputType) {
+ Assert.notNull(inputType, "input type can not be null.");
+
+ this.options.setInputType(inputType);
+ return this;
+ }
+
+ public BedrockTitanEmbeddingOptions build() {
+ return this.options;
+ }
+
+ }
+
+ public InputType getInputType() {
+ return this.inputType;
+ }
+
+ public void setInputType(InputType inputType) {
+ this.inputType = inputType;
+ }
+
+}
diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java
index 498b34bf3d8..f4a219a8d99 100644
--- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java
+++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java
@@ -24,6 +24,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import reactor.core.publisher.Flux;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.regions.Region;
import org.springframework.ai.bedrock.api.AbstractBedrockApi;
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatRequest;
@@ -38,6 +39,7 @@
* https://docs.aws.amazon.com/bedrock/latest/userguide/titan-text-models.html
*
* @author Christian Tzolov
+ * @author Wei Jiang
* @since 0.8.0
*/
// @formatter:off
@@ -92,6 +94,20 @@ public TitanChatBedrockApi(String modelId, AwsCredentialsProvider credentialsPro
super(modelId, credentialsProvider, region, objectMapper, timeout);
}
+ /**
+ * Create a new TitanChatBedrockApi instance using the provided credentials provider, region and object mapper.
+ *
+ * @param modelId The model id to use. See the {@link TitanChatModel} for the supported models.
+ * @param credentialsProvider The credentials provider to connect to AWS.
+ * @param region The AWS region to use.
+ * @param objectMapper The object mapper to use for JSON serialization and deserialization.
+ * @param timeout The timeout to use.
+ */
+ public TitanChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
+ ObjectMapper objectMapper, Duration timeout) {
+ super(modelId, credentialsProvider, region, objectMapper, timeout);
+ }
+
/**
* TitanChatRequest encapsulates the request parameters for the Titan chat model.
*
diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java
index 9c1dcb3b267..5901799c32f 100644
--- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java
+++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java
@@ -23,6 +23,7 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.regions.Region;
import org.springframework.ai.bedrock.api.AbstractBedrockApi;
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingRequest;
@@ -34,6 +35,7 @@
* https://docs.aws.amazon.com/bedrock/latest/userguide/titan-multiemb-models.html
*
* @author Christian Tzolov
+ * @author Wei Jiang
* @since 0.8.0
*/
// @formatter:off
@@ -65,6 +67,20 @@ public TitanEmbeddingBedrockApi(String modelId, AwsCredentialsProvider credentia
super(modelId, credentialsProvider, region, objectMapper, timeout);
}
+ /**
+ * Create a new TitanEmbeddingBedrockApi instance.
+ *
+ * @param modelId The model id to use. See the {@link TitanEmbeddingModel} for the supported models.
+ * @param credentialsProvider The credentials provider to connect to AWS.
+ * @param region The AWS region to use.
+ * @param objectMapper The object mapper to use for JSON serialization and deserialization.
+ * @param timeout The timeout to use.
+ */
+ public TitanEmbeddingBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
+ ObjectMapper objectMapper, Duration timeout) {
+ super(modelId, credentialsProvider, region, objectMapper, timeout);
+ }
+
/**
* Titan Embedding request parameters.
*
diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatClientIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatClientIT.java
index c2050f18ca0..4fa59402054 100644
--- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatClientIT.java
+++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatClientIT.java
@@ -161,9 +161,9 @@ void beanOutputParserRecords() {
String format = outputParser.getFormat();
String template = """
Generate the filmography of 5 movies for Tom Hanks.
- Remove non JSON tex blocks from the output.
{format}
Provide your answer in the JSON format with the feature names as the keys.
+ Remove Markdown code blocks from the output.
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java
index a4b51a70750..92060ef5bf0 100644
--- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java
+++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java
@@ -20,7 +20,7 @@
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi;
import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi;
-import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi;
+import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi;
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi;
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi;
import org.springframework.aot.hint.RuntimeHints;
@@ -43,7 +43,7 @@ void registerHints() {
bedrockRuntimeHints.registerHints(runtimeHints, null);
List classList = Arrays.asList(Ai21Jurassic2ChatBedrockApi.class, CohereChatBedrockApi.class,
- CohereEmbeddingBedrockApi.class, Llama2ChatBedrockApi.class, TitanChatBedrockApi.class,
+ CohereEmbeddingBedrockApi.class, LlamaChatBedrockApi.class, TitanChatBedrockApi.class,
TitanEmbeddingBedrockApi.class, AnthropicChatBedrockApi.class);
for (Class aClass : classList) {
diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClientIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatClientIT.java
similarity index 88%
rename from models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClientIT.java
rename to models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatClientIT.java
index 7cbf3f23566..554afcf9fa0 100644
--- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClientIT.java
+++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatClientIT.java
@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.springframework.ai.bedrock.llama2;
+package org.springframework.ai.bedrock.llama;
import java.time.Duration;
import java.util.Arrays;
@@ -22,27 +22,25 @@
import java.util.stream.Collectors;
import com.fasterxml.jackson.databind.ObjectMapper;
-import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import reactor.core.publisher.Flux;
-
-import org.springframework.ai.chat.ChatResponse;
-import org.springframework.ai.chat.messages.AssistantMessage;
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
import software.amazon.awssdk.regions.Region;
-import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi;
-import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatModel;
+import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi;
+import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel;
+import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
-import org.springframework.ai.parser.BeanOutputParser;
-import org.springframework.ai.parser.ListOutputParser;
-import org.springframework.ai.parser.MapOutputParser;
+import org.springframework.ai.chat.messages.AssistantMessage;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
-import org.springframework.ai.chat.messages.Message;
-import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.parser.BeanOutputParser;
+import org.springframework.ai.parser.ListOutputParser;
+import org.springframework.ai.parser.MapOutputParser;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.SpringBootConfiguration;
@@ -56,10 +54,10 @@
@SpringBootTest
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
-class BedrockLlama2ChatClientIT {
+class BedrockLlamaChatClientIT {
@Autowired
- private BedrockLlama2ChatClient client;
+ private BedrockLlamaChatClient client;
@Value("classpath:/prompts/system-message.st")
private Resource systemResource;
@@ -67,8 +65,8 @@ class BedrockLlama2ChatClientIT {
@Test
void multipleStreamAttempts() {
+ Flux joke2Stream = client.stream(new Prompt(new UserMessage("Tell me a Toy joke?")));
Flux joke1Stream = client.stream(new Prompt(new UserMessage("Tell me a joke?")));
- Flux joke2Stream = client.stream(new Prompt(new UserMessage("Tell me a toy joke?")));
String joke1 = joke1Stream.collectList()
.block()
@@ -105,7 +103,6 @@ void roleTest() {
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
}
- @Disabled("TODO: Fix the parser instructions to return the correct format")
@Test
void outputParser() {
DefaultConversionService conversionService = new DefaultConversionService();
@@ -147,7 +144,6 @@ void mapOutputParser() {
record ActorsFilmsRecord(String actor, List movies) {
}
- @Disabled("TODO: Fix the parser instructions to return the correct format")
@Test
void beanOutputParserRecords() {
@@ -169,7 +165,6 @@ void beanOutputParserRecords() {
assertThat(actorsFilms.movies()).hasSize(5);
}
- @Disabled("TODO: Fix the parser instructions to return the correct format")
@Test
void beanStreamOutputParserRecords() {
@@ -204,16 +199,16 @@ void beanStreamOutputParserRecords() {
public static class TestConfiguration {
@Bean
- public Llama2ChatBedrockApi llama2Api() {
- return new Llama2ChatBedrockApi(Llama2ChatModel.LLAMA2_70B_CHAT_V1.id(),
+ public LlamaChatBedrockApi llamaApi() {
+ return new LlamaChatBedrockApi(LlamaChatModel.LLAMA3_70B_INSTRUCT_V1.id(),
EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(),
Duration.ofMinutes(2));
}
@Bean
- public BedrockLlama2ChatClient llama2ChatClient(Llama2ChatBedrockApi llama2Api) {
- return new BedrockLlama2ChatClient(llama2Api,
- BedrockLlama2ChatOptions.builder().withTemperature(0.5f).withMaxGenLen(100).withTopP(0.9f).build());
+ public BedrockLlamaChatClient llamaChatClient(LlamaChatBedrockApi llamaApi) {
+ return new BedrockLlamaChatClient(llamaApi,
+ BedrockLlamaChatOptions.builder().withTemperature(0.5f).withMaxGenLen(100).withTopP(0.9f).build());
}
}
diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2CreateRequestTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaCreateRequestTests.java
similarity index 67%
rename from models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2CreateRequestTests.java
rename to models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaCreateRequestTests.java
index 1a3329016f3..77018af9e6e 100644
--- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2CreateRequestTests.java
+++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaCreateRequestTests.java
@@ -13,15 +13,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.springframework.ai.bedrock.llama2;
+package org.springframework.ai.bedrock.llama;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
import software.amazon.awssdk.regions.Region;
-import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi;
-import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatModel;
+import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi;
+import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel;
import org.springframework.ai.chat.prompt.Prompt;
import java.time.Duration;
@@ -30,18 +32,21 @@
/**
* @author Christian Tzolov
+ * @author Wei Jiang
*/
-public class BedrockLlama2CreateRequestTests {
+@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
+@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
+public class BedrockLlamaCreateRequestTests {
- private Llama2ChatBedrockApi api = new Llama2ChatBedrockApi(Llama2ChatModel.LLAMA2_70B_CHAT_V1.id(),
+ private LlamaChatBedrockApi api = new LlamaChatBedrockApi(LlamaChatModel.LLAMA3_70B_INSTRUCT_V1.id(),
EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(),
Duration.ofMinutes(2));
@Test
public void createRequestWithChatOptions() {
- var client = new BedrockLlama2ChatClient(api,
- BedrockLlama2ChatOptions.builder().withTemperature(66.6f).withMaxGenLen(666).withTopP(0.66f).build());
+ var client = new BedrockLlamaChatClient(api,
+ BedrockLlamaChatOptions.builder().withTemperature(66.6f).withMaxGenLen(666).withTopP(0.66f).build());
var request = client.createRequest(new Prompt("Test message content"));
@@ -51,7 +56,7 @@ public void createRequestWithChatOptions() {
assertThat(request.maxGenLen()).isEqualTo(666);
request = client.createRequest(new Prompt("Test message content",
- BedrockLlama2ChatOptions.builder().withTemperature(99.9f).withMaxGenLen(999).withTopP(0.99f).build()));
+ BedrockLlamaChatOptions.builder().withTemperature(99.9f).withMaxGenLen(999).withTopP(0.99f).build()));
assertThat(request.prompt()).isNotEmpty();
assertThat(request.temperature()).isEqualTo(99.9f);
diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApiIT.java
similarity index 61%
rename from models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApiIT.java
rename to models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApiIT.java
index dc97d8e7b8f..5b4587358fb 100644
--- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApiIT.java
+++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApiIT.java
@@ -13,47 +13,51 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.springframework.ai.bedrock.llama2.api;
+package org.springframework.ai.bedrock.llama.api;
import java.time.Duration;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel;
+import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest;
+import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+
import reactor.core.publisher.Flux;
+import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
import software.amazon.awssdk.regions.Region;
-import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatModel;
-import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatRequest;
-import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatResponse;
-
import static org.assertj.core.api.Assertions.assertThat;
/**
* @author Christian Tzolov
+ * @author Wei Jiang
*/
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
-public class Llama2ChatBedrockApiIT {
+public class LlamaChatBedrockApiIT {
- private Llama2ChatBedrockApi llama2ChatApi = new Llama2ChatBedrockApi(Llama2ChatModel.LLAMA2_70B_CHAT_V1.id(),
- Region.US_EAST_1.id(), Duration.ofMinutes(2));
+ private LlamaChatBedrockApi llamaChatApi = new LlamaChatBedrockApi(LlamaChatModel.LLAMA3_70B_INSTRUCT_V1.id(),
+ EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(),
+ Duration.ofMinutes(2));
@Test
public void chatCompletion() {
- Llama2ChatRequest request = Llama2ChatRequest.builder("Hello, my name is")
+ LlamaChatRequest request = LlamaChatRequest.builder("Hello, my name is")
.withTemperature(0.9f)
.withTopP(0.9f)
.withMaxGenLen(20)
.build();
- Llama2ChatResponse response = llama2ChatApi.chatCompletion(request);
+ LlamaChatResponse response = llamaChatApi.chatCompletion(request);
System.out.println(response.generation());
assertThat(response).isNotNull();
assertThat(response.generation()).isNotEmpty();
- assertThat(response.promptTokenCount()).isEqualTo(6);
assertThat(response.generationTokenCount()).isGreaterThan(10);
assertThat(response.generationTokenCount()).isLessThanOrEqualTo(20);
assertThat(response.stopReason()).isNotNull();
@@ -63,15 +67,15 @@ public void chatCompletion() {
@Test
public void chatCompletionStream() {
- Llama2ChatRequest request = new Llama2ChatRequest("Hello, my name is", 0.9f, 0.9f, 20);
- Flux responseStream = llama2ChatApi.chatCompletionStream(request);
- List responses = responseStream.collectList().block();
+ LlamaChatRequest request = new LlamaChatRequest("Hello, my name is", 0.9f, 0.9f, 20);
+ Flux responseStream = llamaChatApi.chatCompletionStream(request);
+ List responses = responseStream.collectList().block();
assertThat(responses).isNotNull();
assertThat(responses).hasSizeGreaterThan(10);
assertThat(responses.get(0).generation()).isNotEmpty();
- Llama2ChatResponse lastResponse = responses.get(responses.size() - 1);
+ LlamaChatResponse lastResponse = responses.get(responses.size() - 1);
assertThat(lastResponse.stopReason()).isNotNull();
assertThat(lastResponse.amazonBedrockInvocationMetrics()).isNotNull();
}
diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClientIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClientIT.java
index dead759015c..6400a15f010 100644
--- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClientIT.java
+++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClientIT.java
@@ -22,10 +22,14 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+
+import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
import software.amazon.awssdk.regions.Region;
+import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingClient.InputType;
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi;
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingModel;
+import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
@@ -33,6 +37,8 @@
import org.springframework.context.annotation.Bean;
import org.springframework.core.io.DefaultResourceLoader;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest
@@ -46,7 +52,8 @@ class BedrockTitanEmbeddingClientIT {
@Test
void singleEmbedding() {
assertThat(embeddingClient).isNotNull();
- EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World"));
+ EmbeddingResponse embeddingResponse = embeddingClient.call(new EmbeddingRequest(List.of("Hello World"),
+ BedrockTitanEmbeddingOptions.builder().withInputType(InputType.TEXT).build()));
assertThat(embeddingResponse.getResults()).hasSize(1);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
assertThat(embeddingClient.dimensions()).isEqualTo(1024);
@@ -59,7 +66,8 @@ void imageEmbedding() throws IOException {
.getContentAsByteArray();
EmbeddingResponse embeddingResponse = embeddingClient
- .embedForResponse(List.of(Base64.getEncoder().encodeToString(image)));
+ .call(new EmbeddingRequest(List.of(Base64.getEncoder().encodeToString(image)),
+ BedrockTitanEmbeddingOptions.builder().withInputType(InputType.IMAGE).build()));
assertThat(embeddingResponse.getResults()).hasSize(1);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
assertThat(embeddingClient.dimensions()).isEqualTo(1024);
@@ -70,7 +78,8 @@ public static class TestConfiguration {
@Bean
public TitanEmbeddingBedrockApi titanEmbeddingApi() {
- return new TitanEmbeddingBedrockApi(TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1.id(), Region.US_EAST_1.id(),
+ return new TitanEmbeddingBedrockApi(TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1.id(),
+ EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(),
Duration.ofMinutes(2));
}
diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java
index 91aad8f5490..f5c1f4fd5bf 100644
--- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java
+++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java
@@ -177,7 +177,7 @@ public Flux stream(Prompt prompt) {
private ChatCompletion toChatCompletion(ChatCompletionChunk chunk) {
List choices = chunk.choices()
.stream()
- .map(cc -> new Choice(cc.index(), cc.delta(), cc.finishReason()))
+ .map(cc -> new Choice(cc.index(), cc.delta(), cc.finishReason(), cc.logprobs()))
.toList();
return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, null);
@@ -285,7 +285,13 @@ protected List doGetUserMessages(ChatCompletionRequest re
@SuppressWarnings("null")
@Override
protected ChatCompletionMessage doGetToolResponseMessage(ResponseEntity chatCompletion) {
- return chatCompletion.getBody().choices().iterator().next().message();
+ ChatCompletionMessage msg = chatCompletion.getBody().choices().iterator().next().message();
+ if (msg.role() == null) {
+ // add missing role
+ msg = new ChatCompletionMessage(msg.content(), ChatCompletionMessage.Role.ASSISTANT, msg.name(),
+ msg.toolCalls());
+ }
+ return msg;
}
@Override
diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java
index 19ba2ac521f..5732c871a66 100644
--- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java
+++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java
@@ -535,14 +535,12 @@ public enum ChatCompletionFinishReason {
*/
@JsonProperty("model_length") MODEL_LENGTH,
/**
- * The model called a tool.
+ *
*/
- @JsonProperty("tool_call") TOOL_CALL,
-
- // anticipation of future changes. Based on:
- // https://github.com/mistralai/client-python/blob/main/src/mistralai/models/chat_completion.py
@JsonProperty("error") ERROR,
-
+ /**
+ * The model requested a tool call.
+ */
@JsonProperty("tool_calls") TOOL_CALLS
// @formatter:on
@@ -577,17 +575,65 @@ public record ChatCompletion(
* @param index The index of the choice in the list of choices.
* @param message A chat completion message generated by the model.
* @param finishReason The reason the model stopped generating tokens.
+ * @param logprobs Log probability information for the choice.
*/
@JsonInclude(Include.NON_NULL)
public record Choice(
// @formatter:off
@JsonProperty("index") Integer index,
@JsonProperty("message") ChatCompletionMessage message,
- @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason) {
+ @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
+ @JsonProperty("logprobs") LogProbs logprobs) {
// @formatter:on
}
}
+ /**
+ *
+ * Log probability information for the choice. anticipation of future changes.
+ *
+ * @param content A list of message content tokens with log probability information.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record LogProbs(@JsonProperty("content") List content) {
+
+ /**
+ * Message content tokens with log probability information.
+ *
+ * @param token The token.
+ * @param logprob The log probability of the token.
+ * @param probBytes A list of integers representing the UTF-8 bytes representation
+ * of the token. Useful in instances where characters are represented by multiple
+ * tokens and their byte representations must be combined to generate the correct
+ * text representation. Can be null if there is no bytes representation for the
+ * token.
+ * @param topLogprobs List of the most likely tokens and their log probability, at
+ * this token position. In rare cases, there may be fewer than the number of
+ * requested top_logprobs returned.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record Content(@JsonProperty("token") String token, @JsonProperty("logprob") Float logprob,
+ @JsonProperty("bytes") List probBytes,
+ @JsonProperty("top_logprobs") List topLogprobs) {
+
+ /**
+ * The most likely tokens and their log probability, at this token position.
+ *
+ * @param token The token.
+ * @param logprob The log probability of the token.
+ * @param probBytes A list of integers representing the UTF-8 bytes
+ * representation of the token. Useful in instances where characters are
+ * represented by multiple tokens and their byte representations must be
+ * combined to generate the correct text representation. Can be null if there
+ * is no bytes representation for the token.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record TopLogProbs(@JsonProperty("token") String token, @JsonProperty("logprob") Float logprob,
+ @JsonProperty("bytes") List probBytes) {
+ }
+ }
+ }
+
/**
* Represents a streamed chunk of a chat completion response returned by model, based
* on the provided input.
@@ -616,13 +662,15 @@ public record ChatCompletionChunk(
* @param index The index of the choice in the list of choices.
* @param delta A chat completion delta generated by streamed model responses.
* @param finishReason The reason the model stopped generating tokens.
+ * @param logprobs Log probability information for the choice.
*/
@JsonInclude(Include.NON_NULL)
public record ChunkChoice(
// @formatter:off
@JsonProperty("index") Integer index,
@JsonProperty("delta") ChatCompletionMessage delta,
- @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason) {
+ @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
+ @JsonProperty("logprobs") LogProbs logprobs) {
// @formatter:on
}
}
diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java
index 50cd223536c..774bd072934 100644
--- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java
+++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java
@@ -27,6 +27,7 @@
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ChatCompletionFunction;
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.Role;
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ToolCall;
+import org.springframework.ai.mistralai.api.MistralAiApi.LogProbs;
import org.springframework.util.CollectionUtils;
/**
@@ -83,8 +84,10 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
.toList();
var role = current.delta().role() != null ? current.delta().role() : Role.ASSISTANT;
- current = new ChunkChoice(current.index(), new ChatCompletionMessage(current.delta().content(),
- role, current.delta().name(), toolCallsWithID), current.finishReason());
+ current = new ChunkChoice(
+ current.index(), new ChatCompletionMessage(current.delta().content(), role,
+ current.delta().name(), toolCallsWithID),
+ current.finishReason(), current.logprobs());
}
}
return current;
@@ -95,8 +98,9 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
Integer index = (current.index() != null ? current.index() : previous.index());
ChatCompletionMessage message = merge(previous.delta(), current.delta());
+ LogProbs logprobs = (current.logprobs() != null ? current.logprobs() : previous.logprobs());
- return new ChunkChoice(index, message, finishReason);
+ return new ChunkChoice(index, message, finishReason, logprobs);
}
private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) {
@@ -190,8 +194,7 @@ public boolean isStreamingToolFunctionCallFinish(ChatCompletionChunk chatComplet
}
var choice = choices.get(0);
- return choice.finishReason() == ChatCompletionFinishReason.TOOL_CALL
- || choice.finishReason() == ChatCompletionFinishReason.TOOL_CALLS;
+ return choice.finishReason() == ChatCompletionFinishReason.TOOL_CALLS;
}
}
diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java
index 549d788eade..1ca349d21a3 100644
--- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java
+++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java
@@ -109,7 +109,7 @@ public void beforeEach() {
public void mistralAiChatTransientError() {
var choice = new ChatCompletion.Choice(0, new ChatCompletionMessage("Response", Role.ASSISTANT),
- ChatCompletionFinishReason.STOP);
+ ChatCompletionFinishReason.STOP, null);
ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789l, "model",
List.of(choice), new MistralAiApi.Usage(10, 10, 10));
@@ -137,7 +137,7 @@ public void mistralAiChatNonTransientError() {
public void mistralAiChatStreamTransientError() {
var choice = new ChatCompletionChunk.ChunkChoice(0, new ChatCompletionMessage("Response", Role.ASSISTANT),
- ChatCompletionFinishReason.STOP);
+ ChatCompletionFinishReason.STOP, null);
ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion.chunk", 789l,
"model", List.of(choice));
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechClient.java
index 86465687588..fc4575440d0 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechClient.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechClient.java
@@ -133,7 +133,7 @@ private OpenAiAudioApi.SpeechRequest createRequestBody(SpeechPrompt request) {
if (request.getOptions() != null) {
if (request.getOptions() instanceof OpenAiAudioSpeechOptions runtimeOptions) {
- options = this.merge(options, runtimeOptions);
+ options = this.merge(runtimeOptions, options);
}
else {
throw new IllegalArgumentException("Prompt options are not of type SpeechOptions: "
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionClient.java
index b99c7cca6de..e021571e0ea 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionClient.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionClient.java
@@ -178,7 +178,7 @@ OpenAiAudioApi.TranscriptionRequest createRequestBody(AudioTranscriptionPrompt r
if (request.getOptions() != null) {
if (request.getOptions() instanceof OpenAiAudioTranscriptionOptions runtimeOptions) {
- options = this.merge(options, runtimeOptions);
+ options = this.merge(runtimeOptions, options);
}
else {
throw new IllegalArgumentException("Prompt options are not of type TranscriptionOptions: "
diff --git a/pom.xml b/pom.xml
index 1aa98332298..30e19770d96 100644
--- a/pom.xml
+++ b/pom.xml
@@ -74,6 +74,8 @@
vector-stores/spring-ai-elasticsearch-storespring-ai-spring-boot-starters/spring-ai-starter-watsonx-aispring-ai-spring-boot-starters/spring-ai-starter-elasticsearch-store
+ vector-stores/spring-ai-opensearch-store
+ spring-ai-spring-boot-starters/spring-ai-starter-opensearch-store
@@ -151,6 +153,12 @@
11.6.14.5.11.7.1
+ 2.10.1
+ 5.3.1
+
+
+ 1.19.7
+ 2.0.10.0.4
diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml
index d1c857f9da9..43a481c157d 100644
--- a/spring-ai-bom/pom.xml
+++ b/spring-ai-bom/pom.xml
@@ -210,6 +210,12 @@
${project.version}
+
+ org.springframework.ai
+ spring-ai-opensearch-store
+ ${project.version}
+
+
org.springframework.ai
diff --git a/spring-ai-core/src/main/java/org/springframework/ai/aot/SpringAiCoreRuntimeHints.java b/spring-ai-core/src/main/java/org/springframework/ai/aot/SpringAiCoreRuntimeHints.java
index ecf2a3fd94e..2d00f1ea1d3 100644
--- a/spring-ai-core/src/main/java/org/springframework/ai/aot/SpringAiCoreRuntimeHints.java
+++ b/spring-ai-core/src/main/java/org/springframework/ai/aot/SpringAiCoreRuntimeHints.java
@@ -35,8 +35,8 @@ public class SpringAiCoreRuntimeHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) {
- var chatTypes = Set.of(AbstractMessage.class, AssistantMessage.class, ChatMessage.class, FunctionMessage.class,
- Message.class, MessageType.class, UserMessage.class, SystemMessage.class, FunctionCallbackContext.class,
+ var chatTypes = Set.of(AbstractMessage.class, AssistantMessage.class, FunctionMessage.class, Message.class,
+ MessageType.class, UserMessage.class, SystemMessage.class, FunctionCallbackContext.class,
FunctionCallback.class, FunctionCallbackWrapper.class);
for (var c : chatTypes) {
hints.reflection().registerType(c);
diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatClient.java
index 6925c16eedb..cff4f8674b8 100644
--- a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatClient.java
+++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatClient.java
@@ -16,6 +16,10 @@
package org.springframework.ai.chat;
import org.springframework.ai.chat.prompt.Prompt;
+
+import java.util.Arrays;
+
+import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.model.ModelClient;
@@ -28,6 +32,12 @@ default String call(String message) {
return (generation != null) ? generation.getOutput().getContent() : "";
}
+ default String call(Message... messages) {
+ Prompt prompt = new Prompt(Arrays.asList(messages));
+ Generation generation = call(prompt).getResult();
+ return (generation != null) ? generation.getOutput().getContent() : "";
+ }
+
@Override
ChatResponse call(Prompt prompt);
diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/StreamingChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/StreamingChatClient.java
index a6e4a0b7629..69634b1920d 100644
--- a/spring-ai-core/src/main/java/org/springframework/ai/chat/StreamingChatClient.java
+++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/StreamingChatClient.java
@@ -15,8 +15,11 @@
*/
package org.springframework.ai.chat;
+import java.util.Arrays;
+
import reactor.core.publisher.Flux;
+import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.StreamingModelClient;
@@ -30,6 +33,13 @@ default Flux stream(String message) {
: response.getResult().getOutput().getContent());
}
+ default Flux stream(Message... messages) {
+ Prompt prompt = new Prompt(Arrays.asList(messages));
+ return stream(prompt).map(response -> (response.getResult() == null || response.getResult().getOutput() == null
+ || response.getResult().getOutput().getContent() == null) ? ""
+ : response.getResult().getOutput().getContent());
+ }
+
@Override
Flux stream(Prompt prompt);
diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ChatMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ChatMessage.java
deleted file mode 100644
index ea4803a5943..00000000000
--- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ChatMessage.java
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Copyright 2023 - 2024 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.springframework.ai.chat.messages;
-
-import java.util.Map;
-
-/**
- * Represents a chat message in a chat application.
- *
- */
-public class ChatMessage extends AbstractMessage {
-
- public ChatMessage(String role, String content) {
- super(MessageType.valueOf(role), content);
- }
-
- public ChatMessage(String role, String content, Map properties) {
- super(MessageType.valueOf(role), content, properties);
- }
-
- public ChatMessage(MessageType messageType, String content) {
- super(messageType, content);
- }
-
- public ChatMessage(MessageType messageType, String content, Map properties) {
- super(messageType, content, properties);
- }
-
-}
diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java
index d981a67edec..4c9229516b6 100644
--- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java
+++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java
@@ -15,6 +15,7 @@
*/
package org.springframework.ai.chat.messages;
+import java.util.Arrays;
import java.util.List;
import org.springframework.core.io.Resource;
@@ -38,6 +39,10 @@ public UserMessage(String textContent, List mediaList) {
super(MessageType.USER, textContent, mediaList);
}
+ public UserMessage(String textContent, Media... media) {
+ this(textContent, Arrays.asList(media));
+ }
+
@Override
public String toString() {
return "UserMessage{" + "content='" + getContent() + '\'' + ", properties=" + properties + ", messageType="
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-llama2-chat-api.jpg b/spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-llama-chat-api.jpg
similarity index 100%
rename from spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-llama2-chat-api.jpg
rename to spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-llama-chat-api.jpg
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/orbis-sensualium-pictus2.jpg b/spring-ai-docs/src/main/antora/modules/ROOT/images/orbis-sensualium-pictus2.jpg
new file mode 100644
index 00000000000..688c39d41b9
Binary files /dev/null and b/spring-ai-docs/src/main/antora/modules/ROOT/images/orbis-sensualium-pictus2.jpg differ
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-message-api.jpg b/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-message-api.jpg
new file mode 100644
index 00000000000..cbdae70ab2a
Binary files /dev/null and b/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-message-api.jpg differ
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc
index 457f9ed77a4..5c05e9d6277 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc
@@ -11,7 +11,7 @@
*** xref:api/bedrock-chat.adoc[Amazon Bedrock]
**** xref:api/chat/bedrock/bedrock-anthropic3.adoc[Anthropic3]
**** xref:api/chat/bedrock/bedrock-anthropic.adoc[Anthropic2]
-**** xref:api/chat/bedrock/bedrock-llama2.adoc[Llama2]
+**** xref:api/chat/bedrock/bedrock-llama.adoc[Llama]
**** xref:api/chat/bedrock/bedrock-cohere.adoc[Cohere]
**** xref:api/chat/bedrock/bedrock-titan.adoc[Titan]
**** xref:api/chat/bedrock/bedrock-jurassic2.adoc[Jurassic2]
@@ -48,6 +48,7 @@
*** xref:api/vectordbs/azure.adoc[]
*** xref:api/vectordbs/apache-cassandra.adoc[]
*** xref:api/vectordbs/chroma.adoc[]
+*** xref:api/vectordbs/elasticsearch.adoc[]
*** xref:api/vectordbs/gemfire.adoc[GemFire]
*** xref:api/vectordbs/milvus.adoc[]
*** xref:api/vectordbs/mongodb.adoc[]
@@ -61,6 +62,7 @@
** xref:api/functions.adoc[Function Calling]
+** xref:api/multimodality.adoc[Multimodality]
** xref:api/prompt.adoc[]
** xref:api/output-parser.adoc[]
** xref:api/etl-pipeline.adoc[]
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/bedrock.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/bedrock.adoc
index 4da2b06f7d0..3bdb84e102f 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/bedrock.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/bedrock.adoc
@@ -63,6 +63,18 @@ AWS credentials are resolved in the following order:
6. Credentials delivered through the Amazon EC2 container service if the `AWS_CONTAINER_CREDENTIALS_RELATIVE_URI` environment variable is set and the security manager has permission to access the variable.
7. Instance profile credentials delivered through the Amazon EC2 metadata service or set the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables.
+AWS region is resolved in the following order:
+
+1. Spring-AI Bedrock `spring.ai.bedrock.aws.region` property.
+2. Java System Properties - `aws.region`.
+3. Environment Variables - `AWS_REGION`.
+4. Credential profiles file at the default location (`~/.aws/credentials`) shared by all AWS SDKs and the AWS CLI.
+5. Instance profile region delivered through the Amazon EC2 metadata service.
+
+In addition to the standard Spring-AI Bedrock credentials and region properties configuration, Spring-AI provides support for custom `AwsCredentialsProvider` and `AwsRegionProvider` beans.
+
+NOTE: For example, using Spring-AI and https://spring.io/projects/spring-cloud-aws[Spring Cloud for Amazon Web Services] at the same time. Spring-AI is compatible with Spring Cloud for Amazon Web Services credential configuration.
+
=== Enable selected Bedrock model
NOTE: By default, all models are disabled. You have to enable the chosen Bedrock models explicitly using the `spring.ai.bedrock...enabled=true` property.
@@ -73,7 +85,7 @@ Here are the supported `` and `` combinations:
|====
| Model | Chat | Chat Streaming | Embedding
-| llama2 | Yes | Yes | No
+| llama | Yes | Yes | No
| jurassic2 | Yes | No | No
| cohere | Yes | Yes | Yes
| anthropic 2 | Yes | Yes | No
@@ -82,7 +94,7 @@ Here are the supported `` and `` combinations:
| titan | Yes | Yes | Yes (however, no batch support)
|====
-For example, to enable the Bedrock Llama2 Chat client, you need to set `spring.ai.bedrock.llama2.chat.enabled=true`.
+For example, to enable the Bedrock Llama Chat client, you need to set `spring.ai.bedrock.llama.chat.enabled=true`.
Next, you can use the `spring.ai.bedrock...*` properties to configure each model as provided.
@@ -90,7 +102,7 @@ For more information, refer to the documentation below for each supported model.
* xref:api/chat/bedrock/bedrock-anthropic.adoc[Spring AI Bedrock Anthropic 2 Chat]: `spring.ai.bedrock.anthropic.chat.enabled=true`
* xref:api/chat/bedrock/bedrock-anthropic3.adoc[Spring AI Bedrock Anthropic 3 Chat]: `spring.ai.bedrock.anthropic.chat.enabled=true`
-* xref:api/chat/bedrock/bedrock-llama2.adoc[Spring AI Bedrock Llama2 Chat]: `spring.ai.bedrock.llama2.chat.enabled=true`
+* xref:api/chat/bedrock/bedrock-llama.adoc[Spring AI Bedrock Llama Chat]: `spring.ai.bedrock.llama.chat.enabled=true`
* xref:api/chat/bedrock/bedrock-cohere.adoc[Spring AI Bedrock Cohere Chat]: `spring.ai.bedrock.cohere.chat.enabled=true`
* xref:api/embeddings/bedrock-cohere-embedding.adoc[Spring AI Bedrock Cohere Embeddings]: `spring.ai.bedrock.cohere.embedding.enabled=true`
* xref:api/chat/bedrock/bedrock-titan.adoc[Spring AI Bedrock Titan Chat]: `spring.ai.bedrock.titan.chat.enabled=true`
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic3.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic3.adoc
index c0d5f516d82..af1ac613ae2 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic3.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic3.adoc
@@ -264,7 +264,7 @@ Flux response = chatClient.stream(
The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java[Anthropic3ChatBedrockApi] provides is lightweight Java client on top of AWS Bedrock link:https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html[Anthropic Claude models].
-Client supports the `anthropic.claude-3-sonnet-20240229-v1:0`,`anthropic.claude-3-haiku-20240307-v1:0` and the legacy `anthropic.claude-v2`, `anthropic.claude-v2:1` and `anthropic.claude-instant-v1` models for both synchronous (e.g. `chatCompletion()`) and streaming (e.g. `chatCompletionStream()`) responses.
+Client supports the `anthropic.claude-3-opus-20240229-v1:0`,`anthropic.claude-3-sonnet-20240229-v1:0`,`anthropic.claude-3-haiku-20240307-v1:0` and the legacy `anthropic.claude-v2`, `anthropic.claude-v2:1` and `anthropic.claude-instant-v1` models for both synchronous (e.g. `chatCompletion()`) and streaming (e.g. `chatCompletionStream()`) responses.
Here is a simple snippet how to use the api programmatically:
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama2.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama.adoc
similarity index 55%
rename from spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama2.adoc
rename to spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama.adoc
index 7f891e39432..1d04b2b8ffb 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama2.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama.adoc
@@ -1,13 +1,13 @@
-= Llama2 Chat
+= Llama Chat
-https://ai.meta.com/llama/[Meta's Llama 2 Chat] is part of the Llama 2 collection of large language models.
+https://ai.meta.com/llama/[Meta's Llama Chat] is part of the Llama collection of large language models.
It excels in dialogue-based applications with a parameter scale ranging from 7 billion to 70 billion.
Leveraging public datasets and over 1 million human annotations, Llama Chat offers context-aware dialogues.
-Trained on 2 trillion tokens from public data sources, Llama-2-Chat provides extensive knowledge for insightful conversations.
+Trained on 2 trillion tokens from public data sources, Llama-Chat provides extensive knowledge for insightful conversations.
Rigorous testing, including over 1,000 hours of red-teaming and annotation, ensures both performance and safety, making it a reliable choice for AI-driven dialogues.
-The https://aws.amazon.com/bedrock/llama-2/[AWS Llama 2 Model Page] and https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html[Amazon Bedrock User Guide] contains detailed information on how to use the AWS hosted model.
+The https://aws.amazon.com/bedrock/llama/[AWS Llama Model Page] and https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html[Amazon Bedrock User Guide] contains detailed information on how to use the AWS hosted model.
== Prerequisites
@@ -43,15 +43,15 @@ dependencies {
TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file.
-=== Enable Llama2 Chat Support
+=== Enable Llama Chat Support
-By default the Bedrock Llama2 model is disabled.
-To enable it set the `spring.ai.bedrock.llama2.chat.enabled` property to `true`.
+By default the Bedrock Llama model is disabled.
+To enable it set the `spring.ai.bedrock.llama.chat.enabled` property to `true`.
Exporting environment variable is one way to set this configuration property:
[source,shell]
----
-export SPRING_AI_BEDROCK_LLAMA2_CHAT_ENABLED=true
+export SPRING_AI_BEDROCK_LLAMA_CHAT_ENABLED=true
----
=== Chat Properties
@@ -69,29 +69,29 @@ The prefix `spring.ai.bedrock.aws` is the property prefix to configure the conne
|====
-The prefix `spring.ai.bedrock.llama2.chat` is the property prefix that configures the chat client implementation for Llama2.
+The prefix `spring.ai.bedrock.llama.chat` is the property prefix that configures the chat client implementation for Llama.
[cols="2,5,1"]
|====
| Property | Description | Default
-| spring.ai.bedrock.llama2.chat.enabled | Enable or disable support for Llama2 | false
-| spring.ai.bedrock.llama2.chat.model | The model id to use (See Below) | meta.llama2-70b-chat-v1
-| spring.ai.bedrock.llama2.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0], inclusive. A value closer to 1.0 will produce responses that are more varied, while a value closer to 0.0 will typically result in less surprising responses from the model. This value specifies default to be used by the backend while making the call to the model. | 0.7
-| spring.ai.bedrock.llama2.chat.options.top-p | The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and nucleus sampling. Nucleus sampling considers the smallest set of tokens whose probability sum is at least topP. | AWS Bedrock default
-| spring.ai.bedrock.llama2.chat.options.max-gen-len | Specify the maximum number of tokens to use in the generated response. The model truncates the response once the generated text exceeds maxGenLen. | 300
+| spring.ai.bedrock.llama.chat.enabled | Enable or disable support for Llama | false
+| spring.ai.bedrock.llama.chat.model | The model id to use (See Below) | meta.llama3-70b-instruct-v1:0
+| spring.ai.bedrock.llama.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0], inclusive. A value closer to 1.0 will produce responses that are more varied, while a value closer to 0.0 will typically result in less surprising responses from the model. This value specifies default to be used by the backend while making the call to the model. | 0.7
+| spring.ai.bedrock.llama.chat.options.top-p | The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and nucleus sampling. Nucleus sampling considers the smallest set of tokens whose probability sum is at least topP. | AWS Bedrock default
+| spring.ai.bedrock.llama.chat.options.max-gen-len | Specify the maximum number of tokens to use in the generated response. The model truncates the response once the generated text exceeds maxGenLen. | 300
|====
-Look at https://github.com/spring-projects/spring-ai/blob/4ba9a3cd689b9fd3a3805f540debe398a079c6ef/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApi.java#L164[Llama2ChatBedrockApi#Llama2ChatModel] for other model IDs. The other value supported is `meta.llama2-13b-chat-v1`.
+Look at https://github.com/spring-projects/spring-ai/blob/4ba9a3cd689b9fd3a3805f540debe398a079c6ef/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java#L164[LlamaChatBedrockApi#LlamaChatModel] for other model IDs. The other value supported is `meta.llama2-13b-chat-v1`.
Model ID values can also be found in the https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html[AWS Bedrock documentation for base model IDs].
-TIP: All properties prefixed with `spring.ai.bedrock.llama2.chat.options` can be overridden at runtime by adding a request specific <> to the `Prompt` call.
+TIP: All properties prefixed with `spring.ai.bedrock.llama.chat.options` can be overridden at runtime by adding a request specific <> to the `Prompt` call.
== Runtime Options [[chat-options]]
-The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatOptions.java[BedrockLlama2ChatOptions.java] provides model configurations, such as temperature, topK, topP, etc.
+The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java[BedrockLlChatOptions.java] provides model configurations, such as temperature, topK, topP, etc.
-On start-up, the default options can be configured with the `BedrockLlama2ChatClient(api, options)` constructor or the `spring.ai.bedrock.llama2.chat.options.*` properties.
+On start-up, the default options can be configured with the `BedrockLlamaChatClient(api, options)` constructor or the `spring.ai.bedrock.llama.chat.options.*` properties.
At run-time you can override the default options by adding new, request specific, options to the `Prompt` call.
For example to override the default temperature for a specific request:
@@ -101,13 +101,13 @@ For example to override the default temperature for a specific request:
ChatResponse response = chatClient.call(
new Prompt(
"Generate the names of 5 famous pirates.",
- BedrockLlama2ChatOptions.builder()
+ BedrockLlamaChatOptions.builder()
.withTemperature(0.4)
.build()
));
----
-TIP: In addition to the model specific https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatOptions.java[BedrockLlama2ChatOptions] you can use a portable https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java[ChatOptions] instance, created with the https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java[ChatOptionsBuilder#builder()].
+TIP: In addition to the model specific https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java[BedrockLlamaChatOptions] you can use a portable https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java[ChatOptions] instance, created with the https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java[ChatOptionsBuilder#builder()].
== Sample Controller
@@ -122,13 +122,13 @@ spring.ai.bedrock.aws.timeout=1000ms
spring.ai.bedrock.aws.access-key=${AWS_ACCESS_KEY_ID}
spring.ai.bedrock.aws.secret-key=${AWS_SECRET_ACCESS_KEY}
-spring.ai.bedrock.llama2.chat.enabled=true
-spring.ai.bedrock.llama2.chat.options.temperature=0.8
+spring.ai.bedrock.llama.chat.enabled=true
+spring.ai.bedrock.llama.chat.options.temperature=0.8
----
TIP: replace the `regions`, `access-key` and `secret-key` with your AWS credentials.
-This will create a `BedrockLlama2ChatClient` implementation that you can inject into your class.
+This will create a `BedrockLlamaChatClient` implementation that you can inject into your class.
Here is an example of a simple `@Controller` class that uses the chat client for text generations.
[source,java]
@@ -136,10 +136,10 @@ Here is an example of a simple `@Controller` class that uses the chat client for
@RestController
public class ChatController {
- private final BedrockLlama2ChatClient chatClient;
+ private final BedrockLlamaChatClient chatClient;
@Autowired
- public ChatController(BedrockLlama2ChatClient chatClient) {
+ public ChatController(BedrockLlamaChatClient chatClient) {
this.chatClient = chatClient;
}
@@ -158,7 +158,7 @@ public class ChatController {
== Manual Configuration
-The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java[BedrockLlama2ChatClient] implements the `ChatClient` and `StreamingChatClient` and uses the <> to connect to the Bedrock Anthropic service.
+The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatClient.java[BedrockLlamaChatClient] implements the `ChatClient` and `StreamingChatClient` and uses the <> to connect to the Bedrock Anthropic service.
Add the `spring-ai-bedrock` dependency to your project's Maven `pom.xml` file:
@@ -181,18 +181,18 @@ dependencies {
TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file.
-Next, create an https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java[BedrockLlama2ChatClient] and use it for text generations:
+Next, create an https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatClient.java[BedrockLlamaChatClient] and use it for text generations:
[source,java]
----
-Llama2ChatBedrockApi api = new Llama2ChatBedrockApi(Llama2ChatModel.LLAMA2_70B_CHAT_V1.id(),
+LlamaChatBedrockApi api = new LlamaChatBedrockApi(LlamaChatModel.LLAMA2_70B_CHAT_V1.id(),
EnvironmentVariableCredentialsProvider.create(),
Region.US_EAST_1.id(),
new ObjectMapper(),
Duration.ofMillis(1000L));
-BedrockLlama2ChatClient chatClient = new BedrockLlama2ChatClient(api,
- BedrockLlama2ChatOptions.builder()
+BedrockLlamaChatClient chatClient = new BedrockLlamaChatClient(api,
+ BedrockLlamaChatOptions.builder()
.withTemperature(0.5f)
.withMaxGenLen(100)
.withTopP(0.9f).build());
@@ -205,38 +205,38 @@ Flux response = chatClient.stream(
new Prompt("Generate the names of 5 famous pirates."));
----
-== Low-level Llama2ChatBedrockApi Client [[low-level-api]]
+== Low-level LlamaChatBedrockApi Client [[low-level-api]]
-https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApi.java[Llama2ChatBedrockApi] provides is lightweight Java client on top of AWS Bedrock https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html[Meta Llama 2 and Llama 2 Chat models].
+https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java[LlamaChatBedrockApi] provides is lightweight Java client on top of AWS Bedrock https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html[Meta Llama 2 and Llama 2 Chat models].
-Following class diagram illustrates the Llama2ChatBedrockApi interface and building blocks:
+Following class diagram illustrates the LlamaChatBedrockApi interface and building blocks:
-image::bedrock/bedrock-llama2-chat-api.jpg[Llama2ChatBedrockApi Class Diagram]
+image::bedrock/bedrock-llama-chat-api.jpg[LlamaChatBedrockApi Class Diagram]
-The Llama2ChatBedrockApi supports the `meta.llama2-13b-chat-v1` and `meta.llama2-70b-chat-v1` models for both synchronous (e.g. `chatCompletion()`) and streaming (e.g. `chatCompletionStream()`) responses.
+The LlamaChatBedrockApi supports the `meta.llama3-8b-instruct-v1:0`,`meta.llama3-70b-instruct-v1:0`,`meta.llama2-13b-chat-v1` and `meta.llama2-70b-chat-v1` models for both synchronous (e.g. `chatCompletion()`) and streaming (e.g. `chatCompletionStream()`) responses.
Here is a simple snippet how to use the api programmatically:
[source,java]
----
-Llama2ChatBedrockApi llama2ChatApi = new Llama2ChatBedrockApi(
- Llama2ChatModel.LLAMA2_70B_CHAT_V1.id(),
+LlamaChatBedrockApi llamaChatApi = new LlamaChatBedrockApi(
+ LlamaChatModel.LLAMA3_70B_INSTRUCT_V1.id(),
Region.US_EAST_1.id(),
Duration.ofMillis(1000L));
-Llama2ChatRequest request = Llama2ChatRequest.builder("Hello, my name is")
+LlamaChatRequest request = LlamaChatRequest.builder("Hello, my name is")
.withTemperature(0.9f)
.withTopP(0.9f)
.withMaxGenLen(20)
.build();
-Llama2ChatResponse response = llama2ChatApi.chatCompletion(request);
+LlamaChatResponse response = llamaChatApi.chatCompletion(request);
// Streaming response
-Flux responseStream = llama2ChatApi.chatCompletionStream(request);
-List responses = responseStream.collectList().block();
+Flux responseStream = llamaChatApi.chatCompletionStream(request);
+List responses = responseStream.collectList().block();
----
-Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApi.java[Llama2ChatBedrockApi.java]'s JavaDoc for further information.
+Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java[LlamaChatBedrockApi.java]'s JavaDoc for further information.
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc
index d3cf341bf1f..7f74dc4a320 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc
@@ -9,7 +9,7 @@ The OpenAI API does not call the function directly; instead, the model generates
Spring AI provides flexible and user-friendly ways to register and call custom functions.
In general, the custom functions need to provide a function `name`, `description`, and the function call `signature` (as JSON schema) to let the model know what arguments the function expects. The `description` helps the model to understand when to call the function.
-As a developer, you need to implement a functions that takes the function call arguments sent from the AI model, and respond with the result back to the model. Your function can in turn invoke other 3rd party services to provide the results.
+As a developer, you need to implement a function that takes the function call arguments sent from the AI model, and respond with the result back to the model. Your function can in turn invoke other 3rd party services to provide the results.
Spring AI makes this as easy as defining a `@Bean` definition that returns a `java.util.Function` and supplying the bean name as an option when invoking the `ChatClient`.
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc
index e1e36a4597d..4bb9acf2566 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc
@@ -176,9 +176,9 @@ where fruits are being displayed, possibly for convenience or aesthetic purposes
== Sample Controller
-https://start.spring.io/[Create] a new Spring Boot project and add the `spring-ai-openai-spring-boot-starter` to your pom (or gradle) dependencies.
+https://start.spring.io/[Create] a new Spring Boot project and add the `spring-ai-ollama-spring-boot-starter` to your pom (or gradle) dependencies.
-Add a `application.properties` file, under the `src/main/resources` directory, to enable and configure the OpenAi Chat client:
+Add a `application.properties` file, under the `src/main/resources` directory, to enable and configure the Ollama Chat client:
[source,application.properties]
----
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc
index 7bd008b400f..902d0d59b78 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc
@@ -183,7 +183,7 @@ image::spring-ai-chat-completions-clients.jpg[align="center", width="800px"]
* xref:api/chat/vertexai-gemini-chat.adoc[Google Vertex AI Gemini Chat Completion] (streaming, multi-modality & function-calling support)
* xref:api/bedrock.adoc[Amazon Bedrock]
** xref:api/chat/bedrock/bedrock-cohere.adoc[Cohere Chat Completion]
-** xref:api/chat/bedrock/bedrock-llama2.adoc[Llama2 Chat Completion]
+** xref:api/chat/bedrock/bedrock-llama.adoc[Llama Chat Completion]
** xref:api/chat/bedrock/bedrock-titan.adoc[Titan Chat Completion]
** xref:api/chat/bedrock/bedrock-anthropic.adoc[Anthropic Chat Completion]
** xref:api/chat/bedrock/bedrock-jurassic2.adoc[Jurassic2 Chat Completion]
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-titan-embedding.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-titan-embedding.adoc
index 6d5d675a06c..b7bc8a74eb7 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-titan-embedding.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-titan-embedding.adoc
@@ -81,6 +81,22 @@ The prefix `spring.ai.bedrock.titan.embedding` (defined in `BedrockTitanEmbeddin
Supported values are: `amazon.titan-embed-image-v1` and `amazon.titan-embed-text-v1`.
Model ID values can also be found in the https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html[AWS Bedrock documentation for base model IDs].
+== Runtime Options [[embedding-options]]
+
+The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingOptions.java[BedrockTitanEmbeddingOptions.java] provides model configurations, such as `input-type`.
+On start-up, the default options can be configured with the `BedrockTitanEmbeddingClient(api).withInputType(type)` method or the `spring.ai.bedrock.titan.embedding.input-type` properties.
+
+At run-time you can override the default options by adding new, request specific, options to the `EmbeddingRequest` call.
+For example to override the default temperature for a specific request:
+
+[source,java]
+----
+EmbeddingResponse embeddingResponse = embeddingClient.call(
+ new EmbeddingRequest(List.of("Hello World", "World is big and salvation is near"),
+ BedrockTitanEmbeddingOptions.builder()
+ .withInputType(InputType.TEXT)
+ .build()));
+----
== Sample Controller
@@ -154,7 +170,7 @@ Next, create an https://github.com/spring-projects/spring-ai/blob/main/models/sp
var titanEmbeddingApi = new TitanEmbeddingBedrockApi(
TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1.id(), Region.US_EAST_1.id());
-var embeddingClient new BedrockTitanEmbeddingClient(titanEmbeddingApi);
+var embeddingClient = new BedrockTitanEmbeddingClient(titanEmbeddingApi);
EmbeddingResponse embeddingResponse = embeddingClient
.embedForResponse(List.of("Hello World")); // NOTE titan does not support batch embedding.
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc
index f1d9eb04a83..75a0d45f3b5 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc
@@ -187,14 +187,17 @@ Next, create an `OpenAiEmbeddingClient` instance and use it to compute the simil
----
var openAiApi = new OpenAiApi(System.getenv("OPENAI_API_KEY"));
-var embeddingClient = new OpenAiEmbeddingClient(openAiApi)
- .withDefaultOptions(OpenAiChatOptions.build()
- .withModel("text-embedding-ada-002")
- .withUser("user-6")
- .build());
+var embeddingClient = new OpenAiEmbeddingClient(
+ openAiApi,
+ MetadataMode.EMBED,
+ OpenAiEmbeddingOptions.builder()
+ .withModel("text-embedding-ada-002")
+ .withUser("user-6")
+ .build(),
+ RetryUtils.DEFAULT_RETRY_TEMPLATE);
EmbeddingResponse embeddingResponse = embeddingClient
- .embedForResponse(List.of("Hello World", "World is big and salvation is near"));
+ .embedForResponse(List.of("Hello World", "World is big and salvation is near"));
----
The `OpenAiEmbeddingOptions` provides the configuration information for the embedding requests.
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/etl-pipeline.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/etl-pipeline.adoc
index 87b237a4cec..d0656f16084 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/etl-pipeline.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/etl-pipeline.adoc
@@ -170,10 +170,10 @@ Example:
public class MyTikaDocumentReader {
@Value("classpath:/word-sample.docx") // This is the word document to load
- private Resource resource;
+ private Resource resource;
- List loadText() {
- TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(resourceUri);
+ List loadText() {
+ TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(resource);
return tikaDocumentReader.get();
}
}
@@ -229,4 +229,4 @@ See xref:api/vectordbs.adoc[Vector DB Documentation] for a full listing.
The following class diagram illustrates the ETL interfaces and implementations.
// image::etl-class-diagram.jpg[align="center", width="800px"]
-image::etl-class-diagram.jpg[align="center"]
\ No newline at end of file
+image::etl-class-diagram.jpg[align="center"]
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/multimodality.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/multimodality.adoc
new file mode 100644
index 00000000000..37a43ca9182
--- /dev/null
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/multimodality.adoc
@@ -0,0 +1,64 @@
+[[Multimodality]]
+= Multimodality API
+
+Humans process knowledge, simultaneously across multiple modes of data inputs.
+The way we learn, our experiences are all multimodal.
+We don't have just vision, just audio and just text.
+
+These foundational principles of learning were articulated by the father of modern education link:https://en.wikipedia.org/wiki/John_Amos_Comenius[John Amos Comenius], in his work, "Orbis Sensualium Pictus", dating back to 1658.
+
+image::orbis-sensualium-pictus2.jpg[Orbis Sensualium Pictus, align="center"]
+
+> "All things that are naturally connected ought to be taught in combination"
+
+Contrary to those principles, in the past, our approach to Machine Learning was often focused on specialized models tailored to process a single modality.
+For instance, we developed audio models for tasks like text-to-speech or speech-to-text, and computer vision models for tasks such as object detection and classification.
+
+However, a new wave of multimodal large language models starts to emerge.
+Examples include OpenAI's GPT-4 Vision, Google's Vertex AI Gemini Pro Vision, Anthropic's Claude3, and open source offerings LLaVA and balklava are able to accept multiple inputs, including text images, audio and video and generate text responses by integrating these inputs.
+
+The multimodal large language model (LLM) features enable the models to process and generate text in conjunction with other modalities such as images, audio, or video.
+
+== Spring AI Multimodality
+
+Multimodality refers to a model’s ability to simultaneously understand and process information from various sources, including text, images, audio, and other data formats.
+
+The Spring AI Message API provides all necessary abstractions to support multimodal LLMs.
+
+image::spring-ai-message-api.jpg[Spring AI Message API, width=600, align="center"]
+
+The Message’s `content` field is used as primarily text inputs, while the, optional, `media` field allows adding one or more additional content of different modalities such as images, audio and video.
+The `MimeType` specifies the modality type.
+Depending on the used LLMs the Media's data field can be either encoded raw media content or an URI to the content.
+
+NOTE: The media field is currently applicable only for user input messages (e.g., `UserMessage`). It does not hold significance for system messages. The `AssistantMessage`, which includes the LLM response, provides text content only. To generate non-text media outputs, you should utilize one of dedicated, single modality models.*
+
+
+For example we can take the following picture (*multimodal.test.png*) as an input and ask the LLM to explain what it sees in the picture.
+
+image::multimodal.test.png[Multimodal Test Image, 200, 200, align="left"]
+
+From most of the multimodal LLMs, the Spring AI code would look something like this:
+
+[source,java]
+----
+byte[] imageData = new ClassPathResource("/multimodal.test.png").getContentAsByteArray();
+
+var userMessage = new UserMessage(
+ "Explain what do you see in this picture?", // content
+ List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); // media
+
+ChatResponse response = chatClient.call(new Prompt(List.of(userMessage)));
+----
+
+and produce a response like:
+
+> This is an image of a fruit bowl with a simple design. The bowl is made of metal with curved wire edges that create an open structure, allowing the fruit to be visible from all angles. Inside the bowl, there are two yellow bananas resting on top of what appears to be a red apple. The bananas are slightly overripe, as indicated by the brown spots on their peels. The bowl has a metal ring at the top, likely to serve as a handle for carrying. The bowl is placed on a flat surface with a neutral-colored background that provides a clear view of the fruit inside.
+
+Latest version of Spring AI provides multimodal support for the following Chat Clients:
+
+* xref:api/chat/openai-chat.adoc#_multimodal[Open AI - (GPT-4-Vision model)]
+* xref:api/chat/openai-chat.adoc#_multimodal[Ollama - (LlaVa and Baklava models)]
+* xref:api/chat/vertexai-gemini-chat.adoc#_multimodal[Vertex AI Gemini - (gemini-pro-vision model)]
+* xref:api/chat/anthropic-chat.adoc#_multimodal[Anthropic Claude 3]
+* xref:api/chat/bedrock/bedrock-anthropic3.adoc#_multimodal[AWS Bedrock Anthropic Claude 3]
\ No newline at end of file
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/apache-cassandra.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/apache-cassandra.adoc
index a264c08c1a9..51c6110cb0f 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/apache-cassandra.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/apache-cassandra.adoc
@@ -4,9 +4,9 @@ This section walks you through setting up `CassandraVectorStore` to store docume
== What is Apache Cassandra ?
-link:https://cassandra.apache.org[Apache Cassandra] is a true open source distributed database reknown for scalability and high availability without compromising performance.
+link:https://cassandra.apache.org[Apache Cassandra®] is a true open source distributed database reknown for linear scalability, proven fault-tolerance and low latency, making it the perfect platform for mission-critical transactional data.
-Linear scalability, proven fault-tolerance and low latency on commodity hardware makes it the perfect platform for mission-critical data. Its Vector Similarity Search (VSS) is based on the JVector library that ensures best-in-class performance and relevancy.
+Its Vector Similarity Search (VSS) is based on the JVector library that ensures best-in-class performance and relevancy.
A vector search in Apache Cassandra is done as simply as:
```
@@ -15,9 +15,13 @@ SELECT content FROM table ORDER BY content_vector ANN OF query_embedding ;
More docs on this can be read https://cassandra.apache.org/doc/latest/cassandra/getting-started/vector-search-quickstart.html[here].
-The Spring AI Cassandra Vector Store is designed to work for both brand new RAG applications as well as being able to be retrofitted on top of existing data and tables. This vector store may also equally be used for non-RAG non_AI use-cases, e.g. semantic searcing in an existing database. The Vector Store will automatically create, or enhance, the schema as needed according to its configuration. If you don't want the schema modifications, configure the store with `disallowSchemaChanges`.
+This Spring AI Vector Store is designed to work for both brand new RAG applications as well as being able to be retrofitted on top of existing data and tables.
-== What is JVector Vector Search ?
+The store can also be used for non-RAG use-cases in an existing database, e.g. semantic searches, geo-proximity searches, etc.
+
+The store will automatically create, or enhance, the schema as needed according to its configuration. If you don't want the schema modifications, configure the store with `disallowSchemaChanges`.
+
+== What is JVector ?
link:https://github.com/jbellis/jvector[JVector] is a pure Java embedded vector search engine.
@@ -70,13 +74,6 @@ Add these dependencies to your project:
TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file.
-* If for example you want to use the OpenAI modules, remember to provide your OpenAI API Key. Set it as an environment variable like so:
-
-[source,bash]
-----
-export SPRING_AI_OPENAI_API_KEY='Your_OpenAI_API_Key'
-----
-
== Usage
@@ -93,21 +90,14 @@ public VectorStore vectorStore(EmbeddingClient embeddingClient) {
}
----
-NOTE: It is more convenient and preferred to create the `CassandraVectorStore` as a Bean.
-But if you decide you can create it manually.
-
[NOTE]
====
-The default configuration connects to Cassandra at localhost:9042 and will automatically create the default schema at `springframework_ai_vector.springframework_ai_vector_store`.
-
-Please see `CassandraVectorStoreConfig.Builder` for all the configuration options.
+The default configuration connects to Cassandra at `localhost:9042` and will automatically create a default schema in keyspace `springframework`, table `ai_vector_store`.
====
[NOTE]
====
-The Cassandra Java Driver is easiest configured via the `application.conf` file on the classpath.
-
-More info can be found link: https://github.com/apache/cassandra-java-driver/tree/4.x/manual/core/configuration[here].
+The Cassandra Java Driver is easiest configured via an `application.conf` file on the classpath. More info https://github.com/apache/cassandra-java-driver/tree/4.x/manual/core/configuration[here].
====
Then in your main code, create some documents:
@@ -148,7 +138,7 @@ List results = vectorStore.similaritySearch(
=== Metadata filtering
-You can leverage the generic, portable link:https://docs.spring.io/spring-ai/reference/api/vectordbs.html#_metadata_filters[metadata filters] with the CassandraVectorStore as well. Metadata fields must be configured in `CassandraVectorStoreConfig`.
+You can leverage the generic, portable link:https://docs.spring.io/spring-ai/reference/api/vectordbs.html#_metadata_filters[metadata filters] with the CassandraVectorStore as well. Metadata columns must be configured in `CassandraVectorStoreConfig`.
For example, you can use either the text expression language:
@@ -173,7 +163,9 @@ vectorStore.similaritySearch(
The portable filter expressions get automatically converted into link:https://cassandra.apache.org/doc/latest/cassandra/developing/cql/index.html[CQL queries].
-Metadata fields to be searchable need to be either primary key columns or SAI indexed. To do this configure the metadata field with the `SchemaColumnTags.INDEXED`.
+For metadata columns to be searchable they must be either primary keys or SAI indexed. To make non-primary-key columns indexed configure the metadata column with the `SchemaColumnTags.INDEXED`.
+
+
== Advanced Example: Vector Store ontop full Wikipedia dataset
@@ -187,7 +179,8 @@ Create the schema in the Cassandra database first:
[source,bash]
----
-wget https://raw.githubusercontent.com/datastax-labs/colbert-wikipedia-data/main/schema.cql -O colbert-wikipedia-schema.cql
+wget https://s.apache.org/colbert-wikipedia-schema-cql -O colbert-wikipedia-schema.cql
+
cqlsh -f colbert-wikipedia-schema.cql
----
@@ -212,14 +205,14 @@ public CassandraVectorStore store(EmbeddingClient embeddingClient) {
.withTableName("articles")
.withPartitionKeys(partitionColumns)
.withClusteringKeys(clusteringColumns)
- .withContentFieldName("body")
- .withEmbeddingFieldName("all_minilm_l6_v2_embedding")
+ .withContentColumnName("body")
+ .withEmbeddingColumndName("all_minilm_l6_v2_embedding")
.withIndexName("all_minilm_l6_v2_ann")
.disallowSchemaChanges()
- .addMetadataFields(extraColumns)
+ .addMetadataColumns(extraColumns)
.withPrimaryKeyTranslator((List
+
+ org.springframework.ai
+ spring-ai-opensearch-store
+ ${project.parent.version}
+ true
+
+
@@ -354,6 +361,13 @@
test
+
+ org.opensearch
+ opensearch-testcontainers
+ ${testcontainers.opensearch.version}
+ test
+
+
org.skyscreamerjsonassert
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java
index 9cfd634d13b..46ee804056d 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java
@@ -19,6 +19,9 @@
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.regions.providers.AwsRegionProvider;
+import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
@@ -28,6 +31,7 @@
/**
* @author Christian Tzolov
+ * @author Wei Jiang
*/
@Configuration
@EnableConfigurationProperties({ BedrockAwsConnectionProperties.class })
@@ -45,4 +49,38 @@ public AwsCredentialsProvider credentialsProvider(BedrockAwsConnectionProperties
return DefaultCredentialsProvider.create();
}
+ @Bean
+ @ConditionalOnMissingBean
+ public AwsRegionProvider regionProvider(BedrockAwsConnectionProperties properties) {
+
+ if (StringUtils.hasText(properties.getRegion())) {
+ return new StaticRegionProvider(properties.getRegion());
+ }
+
+ return DefaultAwsRegionProviderChain.builder().build();
+ }
+
+ /**
+ * @author Wei Jiang
+ */
+ static class StaticRegionProvider implements AwsRegionProvider {
+
+ private final Region region;
+
+ public StaticRegionProvider(String region) {
+ try {
+ this.region = Region.of(region);
+ }
+ catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException("The region '" + region + "' is not a valid region!", e);
+ }
+ }
+
+ @Override
+ public Region getRegion() {
+ return this.region;
+ }
+
+ }
+
}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java
index 9a74161b050..7b3d4d2d545 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java
@@ -21,6 +21,7 @@
import org.springframework.ai.bedrock.anthropic.BedrockAnthropicChatClient;
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi;
import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
@@ -28,6 +29,7 @@
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.regions.providers.AwsRegionProvider;
/**
* {@link AutoConfiguration Auto-configuration} for Bedrock Anthropic Chat Client.
@@ -35,6 +37,7 @@
* Leverages the Spring Cloud AWS to resolve the {@link AwsCredentialsProvider}.
*
* @author Christian Tzolov
+ * @author Wei Jiang
* @since 0.8.0
*/
@AutoConfiguration
@@ -46,16 +49,18 @@ public class BedrockAnthropicChatAutoConfiguration {
@Bean
@ConditionalOnMissingBean
+ @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class })
public AnthropicChatBedrockApi anthropicApi(AwsCredentialsProvider credentialsProvider,
- BedrockAnthropicChatProperties properties, BedrockAwsConnectionProperties awsProperties) {
- return new AnthropicChatBedrockApi(properties.getModel(), credentialsProvider, awsProperties.getRegion(),
+ AwsRegionProvider regionProvider, BedrockAnthropicChatProperties properties,
+ BedrockAwsConnectionProperties awsProperties) {
+ return new AnthropicChatBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(),
new ObjectMapper(), awsProperties.getTimeout());
}
@Bean
+ @ConditionalOnBean(AnthropicChatBedrockApi.class)
public BedrockAnthropicChatClient anthropicChatClient(AnthropicChatBedrockApi anthropicApi,
BedrockAnthropicChatProperties properties) {
-
return new BedrockAnthropicChatClient(anthropicApi, properties.getOptions());
}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java
index 31f18597ca7..60e5cdce69c 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java
@@ -21,6 +21,7 @@
import org.springframework.ai.bedrock.anthropic3.BedrockAnthropic3ChatClient;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi;
import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
@@ -28,6 +29,7 @@
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.regions.providers.AwsRegionProvider;
/**
* {@link AutoConfiguration Auto-configuration} for Bedrock Anthropic Chat Client.
@@ -35,6 +37,7 @@
* Leverages the Spring Cloud AWS to resolve the {@link AwsCredentialsProvider}.
*
* @author Christian Tzolov
+ * @author Wei Jiang
* @since 0.8.0
*/
@AutoConfiguration
@@ -46,14 +49,17 @@ public class BedrockAnthropic3ChatAutoConfiguration {
@Bean
@ConditionalOnMissingBean
- public Anthropic3ChatBedrockApi anthropicApi(AwsCredentialsProvider credentialsProvider,
- BedrockAnthropic3ChatProperties properties, BedrockAwsConnectionProperties awsProperties) {
- return new Anthropic3ChatBedrockApi(properties.getModel(), credentialsProvider, awsProperties.getRegion(),
+ @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class })
+ public Anthropic3ChatBedrockApi anthropic3Api(AwsCredentialsProvider credentialsProvider,
+ AwsRegionProvider regionProvider, BedrockAnthropic3ChatProperties properties,
+ BedrockAwsConnectionProperties awsProperties) {
+ return new Anthropic3ChatBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(),
new ObjectMapper(), awsProperties.getTimeout());
}
@Bean
- public BedrockAnthropic3ChatClient anthropicChatClient(Anthropic3ChatBedrockApi anthropicApi,
+ @ConditionalOnBean(Anthropic3ChatBedrockApi.class)
+ public BedrockAnthropic3ChatClient anthropic3ChatClient(Anthropic3ChatBedrockApi anthropicApi,
BedrockAnthropic3ChatProperties properties) {
return new BedrockAnthropic3ChatClient(anthropicApi, properties.getOptions());
}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java
index 3ee34f90fa1..66b6d5e5e77 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java
@@ -21,6 +21,7 @@
import org.springframework.ai.bedrock.cohere.BedrockCohereChatClient;
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi;
import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
@@ -28,11 +29,13 @@
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.regions.providers.AwsRegionProvider;
/**
* {@link AutoConfiguration Auto-configuration} for Bedrock Cohere Chat Client.
*
* @author Christian Tzolov
+ * @author Wei Jiang
* @since 0.8.0
*/
@AutoConfiguration
@@ -44,13 +47,16 @@ public class BedrockCohereChatAutoConfiguration {
@Bean
@ConditionalOnMissingBean
+ @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class })
public CohereChatBedrockApi cohereChatApi(AwsCredentialsProvider credentialsProvider,
- BedrockCohereChatProperties properties, BedrockAwsConnectionProperties awsProperties) {
- return new CohereChatBedrockApi(properties.getModel(), credentialsProvider, awsProperties.getRegion(),
+ AwsRegionProvider regionProvider, BedrockCohereChatProperties properties,
+ BedrockAwsConnectionProperties awsProperties) {
+ return new CohereChatBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(),
new ObjectMapper(), awsProperties.getTimeout());
}
@Bean
+ @ConditionalOnBean(CohereChatBedrockApi.class)
public BedrockCohereChatClient cohereChatClient(CohereChatBedrockApi cohereChatApi,
BedrockCohereChatProperties properties) {
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java
index 95ecba88858..76e3ea6033d 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java
@@ -17,12 +17,14 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.regions.providers.AwsRegionProvider;
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration;
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties;
import org.springframework.ai.bedrock.cohere.BedrockCohereEmbeddingClient;
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi;
import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
@@ -34,6 +36,7 @@
* {@link AutoConfiguration Auto-configuration} for Bedrock Cohere Embedding Client.
*
* @author Christian Tzolov
+ * @author Wei Jiang
* @since 0.8.0
*/
@AutoConfiguration
@@ -45,14 +48,17 @@ public class BedrockCohereEmbeddingAutoConfiguration {
@Bean
@ConditionalOnMissingBean
+ @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class })
public CohereEmbeddingBedrockApi cohereEmbeddingApi(AwsCredentialsProvider credentialsProvider,
- BedrockCohereEmbeddingProperties properties, BedrockAwsConnectionProperties awsProperties) {
- return new CohereEmbeddingBedrockApi(properties.getModel(), credentialsProvider, awsProperties.getRegion(),
+ AwsRegionProvider regionProvider, BedrockCohereEmbeddingProperties properties,
+ BedrockAwsConnectionProperties awsProperties) {
+ return new CohereEmbeddingBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(),
new ObjectMapper(), awsProperties.getTimeout());
}
@Bean
@ConditionalOnMissingBean
+ @ConditionalOnBean(CohereEmbeddingBedrockApi.class)
public BedrockCohereEmbeddingClient cohereEmbeddingClient(CohereEmbeddingBedrockApi cohereEmbeddingApi,
BedrockCohereEmbeddingProperties properties) {
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java
index f7c50657238..e8266a0e417 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java
@@ -22,6 +22,7 @@
import org.springframework.ai.bedrock.jurassic2.BedrockAi21Jurassic2ChatClient;
import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi;
import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
@@ -29,11 +30,13 @@
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.regions.providers.AwsRegionProvider;
/**
* {@link AutoConfiguration Auto-configuration} for Bedrock Jurassic2 Chat Client.
*
* @author Ahmed Yousri
+ * @author Wei Jiang
* @since 1.0.0
*/
@AutoConfiguration
@@ -46,13 +49,16 @@ public class BedrockAi21Jurassic2ChatAutoConfiguration {
@Bean
@ConditionalOnMissingBean
+ @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class })
public Ai21Jurassic2ChatBedrockApi ai21Jurassic2ChatBedrockApi(AwsCredentialsProvider credentialsProvider,
- BedrockAi21Jurassic2ChatProperties properties, BedrockAwsConnectionProperties awsProperties) {
- return new Ai21Jurassic2ChatBedrockApi(properties.getModel(), credentialsProvider, awsProperties.getRegion(),
+ AwsRegionProvider regionProvider, BedrockAi21Jurassic2ChatProperties properties,
+ BedrockAwsConnectionProperties awsProperties) {
+ return new Ai21Jurassic2ChatBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(),
new ObjectMapper(), awsProperties.getTimeout());
}
@Bean
+ @ConditionalOnBean(Ai21Jurassic2ChatBedrockApi.class)
public BedrockAi21Jurassic2ChatClient jurassic2ChatClient(Ai21Jurassic2ChatBedrockApi ai21Jurassic2ChatBedrockApi,
BedrockAi21Jurassic2ChatProperties properties) {
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfiguration.java
similarity index 57%
rename from spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfiguration.java
rename to spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfiguration.java
index 314e3671be2..9293acc84f2 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfiguration.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfiguration.java
@@ -13,16 +13,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.springframework.ai.autoconfigure.bedrock.llama2;
+package org.springframework.ai.autoconfigure.bedrock.llama;
import com.fasterxml.jackson.databind.ObjectMapper;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.regions.providers.AwsRegionProvider;
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration;
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties;
-import org.springframework.ai.bedrock.llama2.BedrockLlama2ChatClient;
-import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi;
+import org.springframework.ai.bedrock.llama.BedrockLlamaChatClient;
+import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi;
import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
@@ -31,33 +33,35 @@
import org.springframework.context.annotation.Import;
/**
- * {@link AutoConfiguration Auto-configuration} for Bedrock Llama2 Chat Client.
+ * {@link AutoConfiguration Auto-configuration} for Bedrock Llama Chat Client.
*
* Leverages the Spring Cloud AWS to resolve the {@link AwsCredentialsProvider}.
*
* @author Christian Tzolov
+ * @author Wei Jiang
* @since 0.8.0
*/
@AutoConfiguration
-@ConditionalOnClass(Llama2ChatBedrockApi.class)
-@EnableConfigurationProperties({ BedrockLlama2ChatProperties.class, BedrockAwsConnectionProperties.class })
-@ConditionalOnProperty(prefix = BedrockLlama2ChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true")
+@ConditionalOnClass(LlamaChatBedrockApi.class)
+@EnableConfigurationProperties({ BedrockLlamaChatProperties.class, BedrockAwsConnectionProperties.class })
+@ConditionalOnProperty(prefix = BedrockLlamaChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true")
@Import(BedrockAwsConnectionConfiguration.class)
-public class BedrockLlama2ChatAutoConfiguration {
+public class BedrockLlamaChatAutoConfiguration {
@Bean
@ConditionalOnMissingBean
- public Llama2ChatBedrockApi llama2Api(AwsCredentialsProvider credentialsProvider,
- BedrockLlama2ChatProperties properties, BedrockAwsConnectionProperties awsProperties) {
- return new Llama2ChatBedrockApi(properties.getModel(), credentialsProvider, awsProperties.getRegion(),
+ @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class })
+ public LlamaChatBedrockApi llamaApi(AwsCredentialsProvider credentialsProvider, AwsRegionProvider regionProvider,
+ BedrockLlamaChatProperties properties, BedrockAwsConnectionProperties awsProperties) {
+ return new LlamaChatBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(),
new ObjectMapper(), awsProperties.getTimeout());
}
@Bean
- public BedrockLlama2ChatClient llama2ChatClient(Llama2ChatBedrockApi llama2Api,
- BedrockLlama2ChatProperties properties) {
+ @ConditionalOnBean(LlamaChatBedrockApi.class)
+ public BedrockLlamaChatClient llamaChatClient(LlamaChatBedrockApi llamaApi, BedrockLlamaChatProperties properties) {
- return new BedrockLlama2ChatClient(llama2Api, properties.getOptions());
+ return new BedrockLlamaChatClient(llamaApi, properties.getOptions());
}
}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatProperties.java
similarity index 63%
rename from spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatProperties.java
rename to spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatProperties.java
index a81f3a4af0c..f93b65aca32 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatProperties.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatProperties.java
@@ -13,36 +13,36 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.springframework.ai.autoconfigure.bedrock.llama2;
+package org.springframework.ai.autoconfigure.bedrock.llama;
-import org.springframework.ai.bedrock.llama2.BedrockLlama2ChatOptions;
-import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatModel;
+import org.springframework.ai.bedrock.llama.BedrockLlamaChatOptions;
+import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
/**
- * Configuration properties for Bedrock Llama2.
+ * Configuration properties for Bedrock Llama.
*
* @author Christian Tzolov
* @since 0.8.0
*/
-@ConfigurationProperties(BedrockLlama2ChatProperties.CONFIG_PREFIX)
-public class BedrockLlama2ChatProperties {
+@ConfigurationProperties(BedrockLlamaChatProperties.CONFIG_PREFIX)
+public class BedrockLlamaChatProperties {
- public static final String CONFIG_PREFIX = "spring.ai.bedrock.llama2.chat";
+ public static final String CONFIG_PREFIX = "spring.ai.bedrock.llama.chat";
/**
- * Enable Bedrock Llama2 chat client. Disabled by default.
+ * Enable Bedrock Llama chat client. Disabled by default.
*/
private boolean enabled = false;
/**
- * The generative id to use. See the {@link Llama2ChatModel} for the supported models.
+ * The generative id to use. See the {@link LlamaChatModel} for the supported models.
*/
- private String model = Llama2ChatModel.LLAMA2_70B_CHAT_V1.id();
+ private String model = LlamaChatModel.LLAMA3_70B_INSTRUCT_V1.id();
@NestedConfigurationProperty
- private BedrockLlama2ChatOptions options = BedrockLlama2ChatOptions.builder()
+ private BedrockLlamaChatOptions options = BedrockLlamaChatOptions.builder()
.withTemperature(0.7f)
.withMaxGenLen(300)
.build();
@@ -63,11 +63,11 @@ public void setModel(String model) {
this.model = model;
}
- public BedrockLlama2ChatOptions getOptions() {
+ public BedrockLlamaChatOptions getOptions() {
return this.options;
}
- public void setOptions(BedrockLlama2ChatOptions options) {
+ public void setOptions(BedrockLlamaChatOptions options) {
this.options = options;
}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java
index e24ec3696ab..67995b9e39c 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java
@@ -21,6 +21,7 @@
import org.springframework.ai.bedrock.titan.BedrockTitanChatClient;
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi;
import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
@@ -28,11 +29,13 @@
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.regions.providers.AwsRegionProvider;
/**
* {@link AutoConfiguration Auto-configuration} for Bedrock Titan Chat Client.
*
* @author Christian Tzolov
+ * @author Wei Jiang
* @since 0.8.0
*/
@AutoConfiguration
@@ -44,14 +47,16 @@ public class BedrockTitanChatAutoConfiguration {
@Bean
@ConditionalOnMissingBean
+ @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class })
public TitanChatBedrockApi titanChatBedrockApi(AwsCredentialsProvider credentialsProvider,
- BedrockTitanChatProperties properties, BedrockAwsConnectionProperties awsProperties) {
-
- return new TitanChatBedrockApi(properties.getModel(), credentialsProvider, awsProperties.getRegion(),
+ AwsRegionProvider regionProvider, BedrockTitanChatProperties properties,
+ BedrockAwsConnectionProperties awsProperties) {
+ return new TitanChatBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(),
new ObjectMapper(), awsProperties.getTimeout());
}
@Bean
+ @ConditionalOnBean(TitanChatBedrockApi.class)
public BedrockTitanChatClient titanChatClient(TitanChatBedrockApi titanChatApi,
BedrockTitanChatProperties properties) {
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java
index f36c6e427fd..5ea79d4513a 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java
@@ -17,12 +17,14 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.regions.providers.AwsRegionProvider;
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration;
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties;
import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingClient;
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi;
import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
@@ -34,6 +36,7 @@
* {@link AutoConfiguration Auto-configuration} for Bedrock Titan Embedding Client.
*
* @author Christian Tzolov
+ * @author Wei Jiang
* @since 0.8.0
*/
@AutoConfiguration
@@ -45,14 +48,17 @@ public class BedrockTitanEmbeddingAutoConfiguration {
@Bean
@ConditionalOnMissingBean
+ @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class })
public TitanEmbeddingBedrockApi titanEmbeddingBedrockApi(AwsCredentialsProvider credentialsProvider,
- BedrockTitanEmbeddingProperties properties, BedrockAwsConnectionProperties awsProperties) {
- return new TitanEmbeddingBedrockApi(properties.getModel(), credentialsProvider, awsProperties.getRegion(),
+ AwsRegionProvider regionProvider, BedrockTitanEmbeddingProperties properties,
+ BedrockAwsConnectionProperties awsProperties) {
+ return new TitanEmbeddingBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(),
new ObjectMapper(), awsProperties.getTimeout());
}
@Bean
@ConditionalOnMissingBean
+ @ConditionalOnBean(TitanEmbeddingBedrockApi.class)
public BedrockTitanEmbeddingClient titanEmbeddingClient(TitanEmbeddingBedrockApi titanEmbeddingApi,
BedrockTitanEmbeddingProperties properties) {
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraConnectionDetails.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraConnectionDetails.java
deleted file mode 100644
index b67f90f6ac9..00000000000
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraConnectionDetails.java
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * Copyright 2024 - 2024 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.springframework.ai.autoconfigure.vectorstore.cassandra;
-
-import java.net.InetSocketAddress;
-import java.util.List;
-
-import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails;
-
-/**
- * @author Mick Semb Wever
- * @since 1.0.0
- */
-public interface CassandraConnectionDetails extends ConnectionDetails {
-
- boolean hasCassandraContactPoints();
-
- List getCassandraContactPoints();
-
- boolean hasCassandraLocalDatacenter();
-
- String getCassandraLocalDatacenter();
-
-}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java
index 1bb93f410a3..23a077ee340 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java
@@ -15,16 +15,17 @@
*/
package org.springframework.ai.autoconfigure.vectorstore.cassandra;
-import java.net.InetSocketAddress;
-import java.util.Arrays;
-import java.util.List;
+import java.time.Duration;
-import com.google.common.base.Preconditions;
+import com.datastax.oss.driver.api.core.CqlSession;
+import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.vectorstore.CassandraVectorStore;
import org.springframework.ai.vectorstore.CassandraVectorStoreConfig;
import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration;
+import org.springframework.boot.autoconfigure.cassandra.DriverConfigLoaderBuilderCustomizer;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
@@ -34,37 +35,24 @@
* @author Mick Semb Wever
* @since 1.0.0
*/
-@AutoConfiguration
-@ConditionalOnClass({ CassandraVectorStore.class, EmbeddingClient.class })
+@AutoConfiguration(after = CassandraAutoConfiguration.class)
+@ConditionalOnClass({ CassandraVectorStore.class, CqlSession.class })
@EnableConfigurationProperties(CassandraVectorStoreProperties.class)
public class CassandraVectorStoreAutoConfiguration {
- @Bean
- @ConditionalOnMissingBean(CassandraConnectionDetails.class)
- public PropertiesCassandraConnectionDetails cassandraConnectionDetails(CassandraVectorStoreProperties properties) {
- return new PropertiesCassandraConnectionDetails(properties);
- }
-
@Bean
@ConditionalOnMissingBean
public CassandraVectorStore vectorStore(EmbeddingClient embeddingClient, CassandraVectorStoreProperties properties,
- CassandraConnectionDetails cassandraConnectionDetails) {
+ CqlSession cqlSession) {
- var builder = CassandraVectorStoreConfig.builder();
- if (cassandraConnectionDetails.hasCassandraContactPoints()) {
- for (InetSocketAddress contactPoint : cassandraConnectionDetails.getCassandraContactPoints()) {
- builder = builder.addContactPoint(contactPoint);
- }
- }
- if (cassandraConnectionDetails.hasCassandraLocalDatacenter()) {
- builder = builder.withLocalDatacenter(cassandraConnectionDetails.getCassandraLocalDatacenter());
- }
+ var builder = CassandraVectorStoreConfig.builder().withCqlSession(cqlSession);
builder = builder.withKeyspaceName(properties.getKeyspace())
.withTableName(properties.getTable())
- .withContentColumnName(properties.getContentFieldName())
- .withEmbeddingColumnName(properties.getEmbeddingFieldName())
- .withIndexName(properties.getIndexName());
+ .withContentColumnName(properties.getContentColumnName())
+ .withEmbeddingColumnName(properties.getEmbeddingColumnName())
+ .withIndexName(properties.getIndexName())
+ .withFixedThreadPoolExecutorSize(properties.getFixedThreadPoolExecutorSize());
if (properties.getDisallowSchemaCreation()) {
builder = builder.disallowSchemaChanges();
@@ -73,46 +61,20 @@ public CassandraVectorStore vectorStore(EmbeddingClient embeddingClient, Cassand
return new CassandraVectorStore(builder.build(), embeddingClient);
}
- private static class PropertiesCassandraConnectionDetails implements CassandraConnectionDetails {
-
- private final CassandraVectorStoreProperties properties;
-
- public PropertiesCassandraConnectionDetails(CassandraVectorStoreProperties properties) {
- this.properties = properties;
- }
-
- private String[] getCassandraContactPointHosts() {
- return this.properties.getCassandraContactPointHosts().split("(,| )");
- }
-
- @Override
- public List getCassandraContactPoints() {
-
- Preconditions.checkState(hasCassandraContactPoints(), "cassandraContactPointHosts has not been set");
- final int port = this.properties.getCassandraContactPointPort();
-
- return Arrays.asList(getCassandraContactPointHosts())
- .stream()
- .map((host) -> InetSocketAddress.createUnresolved(host, port))
- .toList();
- }
-
- @Override
- public String getCassandraLocalDatacenter() {
- Preconditions.checkState(hasCassandraLocalDatacenter(), "cassandraLocalDatacenter has not been set");
- return this.properties.getCassandraLocalDatacenter();
- }
-
- @Override
- public boolean hasCassandraContactPoints() {
- return null != this.properties.getCassandraContactPointHosts();
- }
-
- @Override
- public boolean hasCassandraLocalDatacenter() {
- return null != this.properties.getCassandraLocalDatacenter();
- }
-
+ @Bean
+ public DriverConfigLoaderBuilderCustomizer driverConfigLoaderBuilderCustomizer() {
+ // this replaces spring-ai-cassandra-*.jar!application.conf
+ // as spring-boot autoconfigure will not resolve the default driver configs
+ return (builder) -> builder.startProfile(CassandraVectorStore.DRIVER_PROFILE_UPDATES)
+ .withString(DefaultDriverOption.REQUEST_CONSISTENCY, "LOCAL_QUORUM")
+ .withDuration(DefaultDriverOption.REQUEST_TIMEOUT, Duration.ofSeconds(1))
+ .withBoolean(DefaultDriverOption.REQUEST_DEFAULT_IDEMPOTENCE, true)
+ .endProfile()
+ .startProfile(CassandraVectorStore.DRIVER_PROFILE_SEARCH)
+ .withString(DefaultDriverOption.REQUEST_CONSISTENCY, "LOCAL_ONE")
+ .withDuration(DefaultDriverOption.REQUEST_TIMEOUT, Duration.ofSeconds(10))
+ .withBoolean(DefaultDriverOption.REQUEST_DEFAULT_IDEMPOTENCE, true)
+ .endProfile();
}
}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreProperties.java
index 73b014ab6b0..27af7605e38 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreProperties.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreProperties.java
@@ -15,6 +15,8 @@
*/
package org.springframework.ai.autoconfigure.vectorstore.cassandra;
+import com.google.api.client.util.Preconditions;
+
import org.springframework.ai.vectorstore.CassandraVectorStoreConfig;
import org.springframework.boot.context.properties.ConfigurationProperties;
@@ -27,17 +29,11 @@ public class CassandraVectorStoreProperties {
public static final String CONFIG_PREFIX = "spring.ai.vectorstore.cassandra";
- private String cassandraContactPointHosts = null;
-
- private int cassandraContactPointPort = 9042;
-
- private String cassandraLocalDatacenter = null;
-
private String keyspace = CassandraVectorStoreConfig.DEFAULT_KEYSPACE_NAME;
private String table = CassandraVectorStoreConfig.DEFAULT_TABLE_NAME;
- private String indexName = CassandraVectorStoreConfig.DEFAULT_INDEX_NAME;
+ private String indexName = null;
private String contentColumnName = CassandraVectorStoreConfig.DEFAULT_CONTENT_COLUMN_NAME;
@@ -45,30 +41,7 @@ public class CassandraVectorStoreProperties {
private boolean disallowSchemaChanges = false;
- public String getCassandraContactPointHosts() {
- return this.cassandraContactPointHosts;
- }
-
- /** comma or space separated */
- public void setCassandraContactPointHosts(String cassandraContactPointHosts) {
- this.cassandraContactPointHosts = cassandraContactPointHosts;
- }
-
- public int getCassandraContactPointPort() {
- return this.cassandraContactPointPort;
- }
-
- public void setCassandraContactPointPort(int cassandraContactPointPort) {
- this.cassandraContactPointPort = cassandraContactPointPort;
- }
-
- public String getCassandraLocalDatacenter() {
- return this.cassandraLocalDatacenter;
- }
-
- public void setCassandraLocalDatacenter(String cassandraLocalDatacenter) {
- this.cassandraLocalDatacenter = cassandraLocalDatacenter;
- }
+ private int fixedThreadPoolExecutorSize = CassandraVectorStoreConfig.DEFAULT_ADD_CONCURRENCY;
public String getKeyspace() {
return this.keyspace;
@@ -94,20 +67,20 @@ public void setIndexName(String indexName) {
this.indexName = indexName;
}
- public String getContentFieldName() {
+ public String getContentColumnName() {
return this.contentColumnName;
}
- public void setContentFieldName(String contentFieldName) {
- this.contentColumnName = contentFieldName;
+ public void setContentColumnName(String contentColumnName) {
+ this.contentColumnName = contentColumnName;
}
- public String getEmbeddingFieldName() {
+ public String getEmbeddingColumnName() {
return this.embeddingColumnName;
}
- public void setEmbeddingFieldName(String embeddingFieldName) {
- this.embeddingColumnName = embeddingFieldName;
+ public void setEmbeddingColumnName(String embeddingColumnName) {
+ this.embeddingColumnName = embeddingColumnName;
}
public Boolean getDisallowSchemaCreation() {
@@ -118,4 +91,13 @@ public void setDisallowSchemaCreation(boolean disallowSchemaCreation) {
this.disallowSchemaChanges = disallowSchemaCreation;
}
+ public int getFixedThreadPoolExecutorSize() {
+ return this.fixedThreadPoolExecutorSize;
+ }
+
+ public void setFixedThreadPoolExecutorSize(int fixedThreadPoolExecutorSize) {
+ Preconditions.checkArgument(0 < fixedThreadPoolExecutorSize);
+ this.fixedThreadPoolExecutorSize = fixedThreadPoolExecutorSize;
+ }
+
}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfiguration.java
new file mode 100644
index 00000000000..e5c24d4ddf0
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfiguration.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.autoconfigure.vectorstore.opensearch;
+
+import org.apache.hc.client5.http.auth.AuthScope;
+import org.apache.hc.client5.http.auth.UsernamePasswordCredentials;
+import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider;
+import org.apache.hc.core5.http.HttpHost;
+import org.opensearch.client.opensearch.OpenSearchClient;
+import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder;
+import org.springframework.ai.embedding.EmbeddingClient;
+import org.springframework.ai.vectorstore.OpenSearchVectorStore;
+import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
+import org.springframework.boot.context.properties.EnableConfigurationProperties;
+import org.springframework.context.annotation.Bean;
+
+import java.net.URISyntaxException;
+import java.util.Optional;
+
+@AutoConfiguration
+@ConditionalOnClass({OpenSearchVectorStore.class, EmbeddingClient.class, OpenSearchClient.class})
+@EnableConfigurationProperties(OpenSearchVectorStoreProperties.class)
+class OpenSearchVectorStoreAutoConfiguration {
+
+ @Bean
+ @ConditionalOnMissingBean
+ OpenSearchVectorStore vectorStore(OpenSearchVectorStoreProperties properties, OpenSearchClient openSearchClient,
+ EmbeddingClient embeddingClient) {
+ return new OpenSearchVectorStore(
+ Optional.ofNullable(properties.getIndexName()).orElse(OpenSearchVectorStore.DEFAULT_INDEX_NAME),
+ openSearchClient, embeddingClient, Optional.ofNullable(properties.getMappingJson())
+ .orElse(OpenSearchVectorStore.DEFAULT_MAPPING_EMBEDDING_TYPE_KNN_VECTOR_DIMENSION_1536));
+ }
+
+ @Bean
+ @ConditionalOnMissingBean
+ OpenSearchClient openSearchClient(OpenSearchVectorStoreProperties properties) {
+ HttpHost[] httpHosts = properties.getUris().stream().map(s -> createHttpHost(s)).toArray(HttpHost[]::new);
+ ApacheHttpClient5TransportBuilder transportBuilder = ApacheHttpClient5TransportBuilder.builder(httpHosts);
+
+ Optional.ofNullable(properties.getUsername())
+ .map(username -> createBasicCredentialsProvider(httpHosts[0], username, properties.getPassword()))
+ .ifPresent(basicCredentialsProvider -> transportBuilder.setHttpClientConfigCallback(
+ httpAsyncClientBuilder -> httpAsyncClientBuilder.setDefaultCredentialsProvider(
+ basicCredentialsProvider)));
+
+ return new OpenSearchClient(transportBuilder.build());
+ }
+
+ private BasicCredentialsProvider createBasicCredentialsProvider(HttpHost httpHost, String username,
+ String password) {
+ BasicCredentialsProvider basicCredentialsProvider = new BasicCredentialsProvider();
+ basicCredentialsProvider.setCredentials(new AuthScope(httpHost),
+ new UsernamePasswordCredentials(username, password.toCharArray()));
+ return basicCredentialsProvider;
+ }
+
+ private HttpHost createHttpHost(String s) {
+ try {
+ return HttpHost.create(s);
+ } catch (URISyntaxException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreProperties.java
new file mode 100644
index 00000000000..900cdbd233b
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreProperties.java
@@ -0,0 +1,80 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.autoconfigure.vectorstore.opensearch;
+
+import org.springframework.boot.context.properties.ConfigurationProperties;
+
+import java.util.List;
+
+@ConfigurationProperties(prefix = OpenSearchVectorStoreProperties.CONFIG_PREFIX)
+public class OpenSearchVectorStoreProperties {
+
+ public static final String CONFIG_PREFIX = "spring.ai.vectorstore.opensearch";
+
+ /**
+ * Comma-separated list of the OpenSearch instances to use.
+ */
+ private List uris;
+
+ private String indexName;
+
+ private String username;
+
+ private String password;
+
+ private String mappingJson;
+
+ public List getUris() {
+ return uris;
+ }
+
+ public void setUris(List uris) {
+ this.uris = uris;
+ }
+
+ public String getIndexName() {
+ return this.indexName;
+ }
+
+ public void setIndexName(String indexName) {
+ this.indexName = indexName;
+ }
+
+ public String getUsername() {
+ return username;
+ }
+
+ public void setUsername(String username) {
+ this.username = username;
+ }
+
+ public String getPassword() {
+ return password;
+ }
+
+ public void setPassword(String password) {
+ this.password = password;
+ }
+
+ public String getMappingJson() {
+ return mappingJson;
+ }
+
+ public void setMappingJson(String mappingJson) {
+ this.mappingJson = mappingJson;
+ }
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
index 59f104b6825..1d1532873d9 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
+++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
@@ -6,7 +6,7 @@ org.springframework.ai.autoconfigure.huggingface.HuggingfaceChatAutoConfiguratio
org.springframework.ai.autoconfigure.vertexai.palm2.VertexAiPalm2AutoConfiguration
org.springframework.ai.autoconfigure.vertexai.gemini.VertexAiGeminiAutoConfiguration
org.springframework.ai.autoconfigure.bedrock.jurrasic2.BedrockAi21Jurassic2ChatAutoConfiguration
-org.springframework.ai.autoconfigure.bedrock.llama2.BedrockLlama2ChatAutoConfiguration
+org.springframework.ai.autoconfigure.bedrock.llama.BedrockLlamaChatAutoConfiguration
org.springframework.ai.autoconfigure.bedrock.cohere.BedrockCohereChatAutoConfiguration
org.springframework.ai.autoconfigure.bedrock.cohere.BedrockCohereEmbeddingAutoConfiguration
org.springframework.ai.autoconfigure.bedrock.anthropic.BedrockAnthropicChatAutoConfiguration
@@ -32,3 +32,4 @@ org.springframework.ai.autoconfigure.anthropic.AnthropicAutoConfiguration
org.springframework.ai.autoconfigure.watsonxai.WatsonxAiAutoConfiguration
org.springframework.ai.autoconfigure.vectorstore.elasticsearch.ElasticsearchVectorStoreAutoConfiguration
org.springframework.ai.autoconfigure.vectorstore.cassandra.CassandraVectorStoreAutoConfiguration
+org.springframework.ai.autoconfigure.vectorstore.opensearch.OpenSearchVectorStoreAutoConfiguration
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfigurationIT.java
new file mode 100644
index 00000000000..bea58ce80e3
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfigurationIT.java
@@ -0,0 +1,139 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.autoconfigure.bedrock;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.AutoConfigurations;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
+import org.springframework.boot.context.properties.EnableConfigurationProperties;
+import org.springframework.boot.test.context.runner.ApplicationContextRunner;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Import;
+
+import software.amazon.awssdk.auth.credentials.AwsCredentials;
+import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.regions.providers.AwsRegionProvider;
+
+/**
+ * @author Wei Jiang
+ * @since 0.8.1
+ */
+@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
+@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
+public class BedrockAwsConnectionConfigurationIT {
+
+ @Test
+ public void autoConfigureAWSCredentialAndRegionProvider() {
+ new ApplicationContextRunner()
+ .withPropertyValues("spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"),
+ "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"),
+ "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id())
+ .withConfiguration(AutoConfigurations.of(TestAutoConfiguration.class))
+ .run((context) -> {
+ var awsCredentialsProvider = context.getBean(AwsCredentialsProvider.class);
+ var awsRegionProvider = context.getBean(AwsRegionProvider.class);
+
+ assertThat(awsCredentialsProvider).isNotNull();
+ assertThat(awsRegionProvider).isNotNull();
+
+ var credentials = awsCredentialsProvider.resolveCredentials();
+ assertThat(credentials).isNotNull();
+ assertThat(credentials.accessKeyId()).isEqualTo(System.getenv("AWS_ACCESS_KEY_ID"));
+ assertThat(credentials.secretAccessKey()).isEqualTo(System.getenv("AWS_SECRET_ACCESS_KEY"));
+
+ assertThat(awsRegionProvider.getRegion()).isEqualTo(Region.US_EAST_1);
+ });
+ }
+
+ @Test
+ public void autoConfigureWithCustomAWSCredentialAndRegionProvider() {
+ new ApplicationContextRunner()
+ .withPropertyValues("spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"),
+ "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"),
+ "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id())
+ .withConfiguration(AutoConfigurations.of(TestAutoConfiguration.class,
+ CustomAwsCredentialsProviderAndAwsRegionProviderAutoConfiguration.class))
+ .run((context) -> {
+ var awsCredentialsProvider = context.getBean(AwsCredentialsProvider.class);
+ var awsRegionProvider = context.getBean(AwsRegionProvider.class);
+
+ assertThat(awsCredentialsProvider).isNotNull();
+ assertThat(awsRegionProvider).isNotNull();
+
+ var credentials = awsCredentialsProvider.resolveCredentials();
+ assertThat(credentials).isNotNull();
+ assertThat(credentials.accessKeyId()).isEqualTo("CUSTOM_ACCESS_KEY");
+ assertThat(credentials.secretAccessKey()).isEqualTo("CUSTOM_SECRET_ACCESS_KEY");
+
+ assertThat(awsRegionProvider.getRegion()).isEqualTo(Region.AWS_GLOBAL);
+ });
+ }
+
+ @EnableConfigurationProperties({ BedrockAwsConnectionProperties.class })
+ @Import(BedrockAwsConnectionConfiguration.class)
+ static class TestAutoConfiguration {
+
+ }
+
+ @AutoConfiguration
+ static class CustomAwsCredentialsProviderAndAwsRegionProviderAutoConfiguration {
+
+ @Bean
+ @ConditionalOnMissingBean
+ public AwsCredentialsProvider credentialsProvider() {
+ return new AwsCredentialsProvider() {
+
+ @Override
+ public AwsCredentials resolveCredentials() {
+ return new AwsCredentials() {
+
+ @Override
+ public String accessKeyId() {
+ return "CUSTOM_ACCESS_KEY";
+ }
+
+ @Override
+ public String secretAccessKey() {
+ return "CUSTOM_SECRET_ACCESS_KEY";
+ }
+
+ };
+ }
+
+ };
+ }
+
+ @Bean
+ @ConditionalOnMissingBean
+ public AwsRegionProvider regionProvider() {
+ return new AwsRegionProvider() {
+
+ @Override
+ public Region getRegion() {
+ return Region.AWS_GLOBAL;
+ }
+
+ };
+ }
+
+ }
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfigurationIT.java
similarity index 61%
rename from spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfigurationIT.java
rename to spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfigurationIT.java
index 3c55877bcaf..6ca69b65599 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfigurationIT.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfigurationIT.java
@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.springframework.ai.autoconfigure.bedrock.llama2;
+package org.springframework.ai.autoconfigure.bedrock.llama;
import java.util.List;
import java.util.Map;
@@ -27,8 +27,8 @@
import software.amazon.awssdk.regions.Region;
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties;
-import org.springframework.ai.bedrock.llama2.BedrockLlama2ChatClient;
-import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatModel;
+import org.springframework.ai.bedrock.llama.BedrockLlamaChatClient;
+import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
@@ -41,21 +41,22 @@
/**
* @author Christian Tzolov
+ * @author Wei Jiang
* @since 0.8.0
*/
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
-public class BedrockLlama2ChatAutoConfigurationIT {
+public class BedrockLlamaChatAutoConfigurationIT {
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
- .withPropertyValues("spring.ai.bedrock.llama2.chat.enabled=true",
+ .withPropertyValues("spring.ai.bedrock.llama.chat.enabled=true",
"spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"),
"spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"),
"spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(),
- "spring.ai.bedrock.llama2.chat.model=" + Llama2ChatModel.LLAMA2_70B_CHAT_V1.id(),
- "spring.ai.bedrock.llama2.chat.options.temperature=0.5",
- "spring.ai.bedrock.llama2.chat.options.maxGenLen=500")
- .withConfiguration(AutoConfigurations.of(BedrockLlama2ChatAutoConfiguration.class));
+ "spring.ai.bedrock.llama.chat.model=" + LlamaChatModel.LLAMA3_70B_INSTRUCT_V1.id(),
+ "spring.ai.bedrock.llama.chat.options.temperature=0.5",
+ "spring.ai.bedrock.llama.chat.options.maxGenLen=500")
+ .withConfiguration(AutoConfigurations.of(BedrockLlamaChatAutoConfiguration.class));
private final Message systemMessage = new SystemPromptTemplate("""
You are a helpful AI assistant. Your name is {name}.
@@ -70,8 +71,8 @@ public class BedrockLlama2ChatAutoConfigurationIT {
@Test
public void chatCompletion() {
contextRunner.run(context -> {
- BedrockLlama2ChatClient llama2ChatClient = context.getBean(BedrockLlama2ChatClient.class);
- ChatResponse response = llama2ChatClient.call(new Prompt(List.of(userMessage, systemMessage)));
+ BedrockLlamaChatClient llamaChatClient = context.getBean(BedrockLlamaChatClient.class);
+ ChatResponse response = llamaChatClient.call(new Prompt(List.of(userMessage, systemMessage)));
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
});
}
@@ -80,9 +81,9 @@ public void chatCompletion() {
public void chatCompletionStreaming() {
contextRunner.run(context -> {
- BedrockLlama2ChatClient llama2ChatClient = context.getBean(BedrockLlama2ChatClient.class);
+ BedrockLlamaChatClient llamaChatClient = context.getBean(BedrockLlamaChatClient.class);
- Flux response = llama2ChatClient.stream(new Prompt(List.of(userMessage, systemMessage)));
+ Flux response = llamaChatClient.stream(new Prompt(List.of(userMessage, systemMessage)));
List responses = response.collectList().block();
assertThat(responses.size()).isGreaterThan(2);
@@ -102,23 +103,23 @@ public void chatCompletionStreaming() {
public void propertiesTest() {
new ApplicationContextRunner()
- .withPropertyValues("spring.ai.bedrock.llama2.chat.enabled=true",
+ .withPropertyValues("spring.ai.bedrock.llama.chat.enabled=true",
"spring.ai.bedrock.aws.access-key=ACCESS_KEY", "spring.ai.bedrock.aws.secret-key=SECRET_KEY",
- "spring.ai.bedrock.llama2.chat.model=MODEL_XYZ",
+ "spring.ai.bedrock.llama.chat.model=MODEL_XYZ",
"spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(),
- "spring.ai.bedrock.llama2.chat.options.temperature=0.55",
- "spring.ai.bedrock.llama2.chat.options.maxGenLen=123")
- .withConfiguration(AutoConfigurations.of(BedrockLlama2ChatAutoConfiguration.class))
+ "spring.ai.bedrock.llama.chat.options.temperature=0.55",
+ "spring.ai.bedrock.llama.chat.options.maxGenLen=123")
+ .withConfiguration(AutoConfigurations.of(BedrockLlamaChatAutoConfiguration.class))
.run(context -> {
- var llama2ChatProperties = context.getBean(BedrockLlama2ChatProperties.class);
+ var llamaChatProperties = context.getBean(BedrockLlamaChatProperties.class);
var awsProperties = context.getBean(BedrockAwsConnectionProperties.class);
- assertThat(llama2ChatProperties.isEnabled()).isTrue();
+ assertThat(llamaChatProperties.isEnabled()).isTrue();
assertThat(awsProperties.getRegion()).isEqualTo(Region.EU_CENTRAL_1.id());
- assertThat(llama2ChatProperties.getOptions().getTemperature()).isEqualTo(0.55f);
- assertThat(llama2ChatProperties.getOptions().getMaxGenLen()).isEqualTo(123);
- assertThat(llama2ChatProperties.getModel()).isEqualTo("MODEL_XYZ");
+ assertThat(llamaChatProperties.getOptions().getTemperature()).isEqualTo(0.55f);
+ assertThat(llamaChatProperties.getOptions().getMaxGenLen()).isEqualTo(123);
+ assertThat(llamaChatProperties.getModel()).isEqualTo("MODEL_XYZ");
assertThat(awsProperties.getAccessKey()).isEqualTo("ACCESS_KEY");
assertThat(awsProperties.getSecretKey()).isEqualTo("SECRET_KEY");
@@ -129,27 +130,26 @@ public void propertiesTest() {
public void chatCompletionDisabled() {
// It is disabled by default
- new ApplicationContextRunner()
- .withConfiguration(AutoConfigurations.of(BedrockLlama2ChatAutoConfiguration.class))
+ new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(BedrockLlamaChatAutoConfiguration.class))
.run(context -> {
- assertThat(context.getBeansOfType(BedrockLlama2ChatProperties.class)).isEmpty();
- assertThat(context.getBeansOfType(BedrockLlama2ChatClient.class)).isEmpty();
+ assertThat(context.getBeansOfType(BedrockLlamaChatProperties.class)).isEmpty();
+ assertThat(context.getBeansOfType(BedrockLlamaChatClient.class)).isEmpty();
});
// Explicitly enable the chat auto-configuration.
- new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.llama2.chat.enabled=true")
- .withConfiguration(AutoConfigurations.of(BedrockLlama2ChatAutoConfiguration.class))
+ new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.llama.chat.enabled=true")
+ .withConfiguration(AutoConfigurations.of(BedrockLlamaChatAutoConfiguration.class))
.run(context -> {
- assertThat(context.getBeansOfType(BedrockLlama2ChatProperties.class)).isNotEmpty();
- assertThat(context.getBeansOfType(BedrockLlama2ChatClient.class)).isNotEmpty();
+ assertThat(context.getBeansOfType(BedrockLlamaChatProperties.class)).isNotEmpty();
+ assertThat(context.getBeansOfType(BedrockLlamaChatClient.class)).isNotEmpty();
});
// Explicitly disable the chat auto-configuration.
- new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.llama2.chat.enabled=false")
- .withConfiguration(AutoConfigurations.of(BedrockLlama2ChatAutoConfiguration.class))
+ new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.llama.chat.enabled=false")
+ .withConfiguration(AutoConfigurations.of(BedrockLlamaChatAutoConfiguration.class))
.run(context -> {
- assertThat(context.getBeansOfType(BedrockLlama2ChatProperties.class)).isEmpty();
- assertThat(context.getBeansOfType(BedrockLlama2ChatClient.class)).isEmpty();
+ assertThat(context.getBeansOfType(BedrockLlamaChatProperties.class)).isEmpty();
+ assertThat(context.getBeansOfType(BedrockLlamaChatClient.class)).isEmpty();
});
}
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfigurationIT.java
index 15530052f96..119bde94478 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfigurationIT.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfigurationIT.java
@@ -31,6 +31,7 @@
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.boot.autoconfigure.AutoConfigurations;
+import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@@ -55,18 +56,18 @@ class CassandraVectorStoreAutoConfigurationIT {
ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad")));
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
- .withConfiguration(AutoConfigurations.of(CassandraVectorStoreAutoConfiguration.class))
+ .withConfiguration(
+ AutoConfigurations.of(CassandraVectorStoreAutoConfiguration.class, CassandraAutoConfiguration.class))
.withUserConfiguration(Config.class)
.withPropertyValues("spring.ai.vectorstore.cassandra.keyspace=test_autoconfigure")
- .withPropertyValues("spring.ai.vectorstore.cassandra.contentFieldName=doc_chunk");
+ .withPropertyValues("spring.ai.vectorstore.cassandra.contentColumnName=doc_chunk");
@Test
void addAndSearch() {
- contextRunner
- .withPropertyValues("spring.ai.vectorstore.cassandra.cassandraContactPointHosts=" + getContactPointHost())
- .withPropertyValues("spring.ai.vectorstore.cassandra.cassandraContactPointPort=" + getContactPointPort())
- .withPropertyValues("spring.ai.vectorstore.cassandra.cassandraLocalDatacenter="
- + cassandraContainer.getLocalDatacenter())
+ contextRunner.withPropertyValues("spring.cassandra.contactPoints=" + getContactPointHost())
+ .withPropertyValues("spring.cassandra.port=" + getContactPointPort())
+ .withPropertyValues("spring.cassandra.localDatacenter=" + cassandraContainer.getLocalDatacenter())
+ .withPropertyValues("spring.ai.vectorstore.cassandra.fixedThreadPoolExecutorSize=8")
.run(context -> {
VectorStore vectorStore = context.getBean(VectorStore.class);
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStorePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStorePropertiesTests.java
index c5a3e4d0f7c..c0d5ad04aeb 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStorePropertiesTests.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStorePropertiesTests.java
@@ -30,39 +30,34 @@ class CassandraVectorStorePropertiesTests {
@Test
void defaultValues() {
var props = new CassandraVectorStoreProperties();
- assertThat(props.getCassandraContactPointHosts()).isNull();
- assertThat(props.getCassandraContactPointPort()).isEqualTo(9042);
- assertThat(props.getCassandraLocalDatacenter()).isNull();
assertThat(props.getKeyspace()).isEqualTo(CassandraVectorStoreConfig.DEFAULT_KEYSPACE_NAME);
assertThat(props.getTable()).isEqualTo(CassandraVectorStoreConfig.DEFAULT_TABLE_NAME);
- assertThat(props.getContentFieldName()).isEqualTo(CassandraVectorStoreConfig.DEFAULT_CONTENT_COLUMN_NAME);
- assertThat(props.getEmbeddingFieldName()).isEqualTo(CassandraVectorStoreConfig.DEFAULT_EMBEDDING_COLUMN_NAME);
- assertThat(props.getIndexName()).isEqualTo(CassandraVectorStoreConfig.DEFAULT_INDEX_NAME);
+ assertThat(props.getContentColumnName()).isEqualTo(CassandraVectorStoreConfig.DEFAULT_CONTENT_COLUMN_NAME);
+ assertThat(props.getEmbeddingColumnName()).isEqualTo(CassandraVectorStoreConfig.DEFAULT_EMBEDDING_COLUMN_NAME);
+ assertThat(props.getIndexName()).isNull();
assertThat(props.getDisallowSchemaCreation()).isFalse();
+ assertThat(props.getFixedThreadPoolExecutorSize())
+ .isEqualTo(CassandraVectorStoreConfig.DEFAULT_ADD_CONCURRENCY);
}
@Test
void customValues() {
var props = new CassandraVectorStoreProperties();
- props.setCassandraContactPointHosts("127.0.0.1,127.0.0.2");
- props.setCassandraContactPointPort(9043);
- props.setCassandraLocalDatacenter("dc1");
props.setKeyspace("my_keyspace");
props.setTable("my_table");
- props.setContentFieldName("my_content");
- props.setEmbeddingFieldName("my_vector");
+ props.setContentColumnName("my_content");
+ props.setEmbeddingColumnName("my_vector");
props.setIndexName("my_sai");
props.setDisallowSchemaCreation(true);
+ props.setFixedThreadPoolExecutorSize(10);
- assertThat(props.getCassandraContactPointHosts()).isEqualTo("127.0.0.1,127.0.0.2");
- assertThat(props.getCassandraContactPointPort()).isEqualTo(9043);
- assertThat(props.getCassandraLocalDatacenter()).isEqualTo("dc1");
assertThat(props.getKeyspace()).isEqualTo("my_keyspace");
assertThat(props.getTable()).isEqualTo("my_table");
- assertThat(props.getContentFieldName()).isEqualTo("my_content");
- assertThat(props.getEmbeddingFieldName()).isEqualTo("my_vector");
+ assertThat(props.getContentColumnName()).isEqualTo("my_content");
+ assertThat(props.getEmbeddingColumnName()).isEqualTo("my_vector");
assertThat(props.getIndexName()).isEqualTo("my_sai");
assertThat(props.getDisallowSchemaCreation()).isTrue();
+ assertThat(props.getFixedThreadPoolExecutorSize()).isEqualTo(10);
}
}
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfigurationIT.java
new file mode 100644
index 00000000000..bfa3dab48f2
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfigurationIT.java
@@ -0,0 +1,125 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.autoconfigure.vectorstore.opensearch;
+
+import org.awaitility.Awaitility;
+import org.junit.jupiter.api.Test;
+import org.opensearch.testcontainers.OpensearchContainer;
+import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.EmbeddingClient;
+import org.springframework.ai.transformers.TransformersEmbeddingClient;
+import org.springframework.ai.vectorstore.OpenSearchVectorStore;
+import org.springframework.ai.vectorstore.SearchRequest;
+import org.springframework.boot.autoconfigure.AutoConfigurations;
+import org.springframework.boot.test.context.runner.ApplicationContextRunner;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.core.io.DefaultResourceLoader;
+import org.testcontainers.junit.jupiter.Container;
+import org.testcontainers.junit.jupiter.Testcontainers;
+import org.testcontainers.utility.DockerImageName;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.hamcrest.Matchers.hasSize;
+
+@Testcontainers
+class OpenSearchVectorStoreAutoConfigurationIT {
+
+ @Container
+ private static final OpensearchContainer> opensearchContainer =
+ new OpensearchContainer<>(DockerImageName.parse("opensearchproject/opensearch:2.12.0"));
+
+ private static final String DOCUMENT_INDEX = "auto-spring-ai-document-index";
+
+ private List documents = List.of(
+ new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")),
+ new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()),
+ new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2")));
+
+ private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withConfiguration(
+ AutoConfigurations.of(OpenSearchVectorStoreAutoConfiguration.class, SpringAiRetryAutoConfiguration.class))
+ .withUserConfiguration(Config.class)
+ .withPropertyValues(
+ OpenSearchVectorStoreProperties.CONFIG_PREFIX + ".uris=" + opensearchContainer.getHttpHostAddress(),
+ OpenSearchVectorStoreProperties.CONFIG_PREFIX + ".indexName=" + DOCUMENT_INDEX,
+ OpenSearchVectorStoreProperties.CONFIG_PREFIX + ".mappingJson=" + """
+ {
+ "properties":{
+ "embedding":{
+ "type":"knn_vector",
+ "dimension":384
+ }
+ }
+ }
+ """);
+
+ @Test
+ public void addAndSearchTest() {
+
+ this.contextRunner.run(context -> {
+ OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class);
+
+ vectorStore.add(documents);
+
+ Awaitility.await().until(() -> vectorStore.similaritySearch(
+ SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)),
+ hasSize(1));
+
+ List results = vectorStore.similaritySearch(
+ SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0));
+
+ assertThat(results).hasSize(1);
+ Document resultDoc = results.get(0);
+ assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId());
+ assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock");
+ assertThat(resultDoc.getMetadata()).hasSize(2);
+ assertThat(resultDoc.getMetadata()).containsKey("meta2");
+ assertThat(resultDoc.getMetadata()).containsKey("distance");
+
+ // Remove all documents from the store
+ vectorStore.delete(documents.stream().map(Document::getId).toList());
+
+ Awaitility.await().until(() -> vectorStore.similaritySearch(
+ SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)), hasSize(0));
+ });
+ }
+
+ private String getText(String uri) {
+ var resource = new DefaultResourceLoader().getResource(uri);
+ try {
+ return resource.getContentAsString(StandardCharsets.UTF_8);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Configuration(proxyBeanMethods = false)
+ static class Config {
+
+ @Bean
+ public EmbeddingClient embeddingClient() {
+ return new TransformersEmbeddingClient();
+ }
+
+ }
+
+}
diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-opensearch-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-opensearch-store/pom.xml
new file mode 100644
index 00000000000..c97eb81ad68
--- /dev/null
+++ b/spring-ai-spring-boot-starters/spring-ai-starter-opensearch-store/pom.xml
@@ -0,0 +1,42 @@
+
+
+ 4.0.0
+
+ org.springframework.ai
+ spring-ai
+ 1.0.0-SNAPSHOT
+ ../../pom.xml
+
+ spring-ai-opensearch-store-spring-boot-starter
+ jar
+ Spring AI Starter - OpenSearch Store
+ Spring AI OpenSearch Store Auto Configuration
+ https://github.com/spring-projects/spring-ai
+
+
+ https://github.com/spring-projects/spring-ai
+ git://github.com/spring-projects/spring-ai.git
+ git@github.com:spring-projects/spring-ai.git
+
+
+
+
+
+ org.springframework.boot
+ spring-boot-starter
+
+
+
+ org.springframework.ai
+ spring-ai-spring-boot-autoconfigure
+ ${project.parent.version}
+
+
+
+ org.springframework.ai
+ spring-ai-opensearch-store
+ ${project.parent.version}
+
+
+
+
diff --git a/vector-stores/spring-ai-azure/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java b/vector-stores/spring-ai-azure/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java
index ec29ce0bfcf..165b190b06d 100644
--- a/vector-stores/spring-ai-azure/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java
+++ b/vector-stores/spring-ai-azure/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java
@@ -320,7 +320,7 @@ private List toFloatList(List doubleList) {
}
/**
- * Internal data structure for retrieving and and storing documents.
+ * Internal data structure for retrieving and storing documents.
*/
private record AzureSearchDocument(String id, String content, List embedding, String metadata) {
}
diff --git a/vector-stores/spring-ai-cassandra/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java b/vector-stores/spring-ai-cassandra/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java
index 4a8b681a5f0..3efd440341b 100644
--- a/vector-stores/spring-ai-cassandra/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java
+++ b/vector-stores/spring-ai-cassandra/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java
@@ -79,7 +79,7 @@ private static void doOperand(ExpressionType type, StringBuilder context) {
// TODO SAI supports collections
// reach out to mck@apache.org if you'd like these implemented
// case CONTAINS -> context.append(" CONTAINS ");
- // case CONTAINS_KEY -> context.append(" CONTAINS KEY ");
+ // case CONTAINS_KEY -> context.append(" CONTAINS_KEY ");
default -> throw new UnsupportedOperationException(
String.format("Expression type %s not yet implemented. Patches welcome.", type));
}
diff --git a/vector-stores/spring-ai-cassandra/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java b/vector-stores/spring-ai-cassandra/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java
index 6f651c944df..c7422cc5b7e 100644
--- a/vector-stores/spring-ai-cassandra/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java
+++ b/vector-stores/spring-ai-cassandra/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java
@@ -17,16 +17,20 @@
import java.util.ArrayList;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
import com.datastax.oss.driver.api.core.cql.BoundStatement;
import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder;
import com.datastax.oss.driver.api.core.cql.PreparedStatement;
import com.datastax.oss.driver.api.core.cql.Row;
+import com.datastax.oss.driver.api.core.cql.SimpleStatement;
import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata;
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
@@ -40,6 +44,7 @@
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingClient;
+import org.springframework.ai.vectorstore.CassandraVectorStoreConfig;
import org.springframework.ai.vectorstore.CassandraVectorStoreConfig.SchemaColumn;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.beans.factory.InitializingBean;
@@ -53,13 +58,13 @@
* fields in the documents to be stored alongside the vector and content data.
*
* This class requires a CassandraVectorStoreConfig configuration object for
- * initialization, which includes settings like connection details, index name, field
+ * initialization, which includes settings like connection details, index name, column
* names, etc. It also requires an EmbeddingClient to convert documents into embeddings
* before storing them.
*
* A schema matching the configuration is automatically created if it doesn't exist.
* Missing columns and indexes in existing tables will also be automatically created.
- * Disable this with the disallowSchemaCreation.
+ * Disable this with the CassandraVectorStoreConfig#disallowSchemaChanges().
*
* This class is designed to work with brand new tables that it creates for you, or on top
* of existing Cassandra tables. The latter is appropriate when wanting to keep data in
@@ -69,9 +74,20 @@
* Instances of this class are not dynamic against server-side schema changes. If you
* change the schema server-side you need a new CassandraVectorStore instance.
*
+ * When adding documents with the method {@link #add(List)} it first calls
+ * embeddingClient to create the embeddings. This is slow. Configure
+ * {@link CassandraVectorStoreConfig.Builder#withFixedThreadPoolExecutorSize(int)}
+ * accordingly to improve performance so embeddings are created and the documents are
+ * added concurrently. The default concurrency is 16
+ * ({@link CassandraVectorStoreConfig#DEFAULT_ADD_CONCURRENCY}). Remote transformers
+ * probably want higher concurrency, and local transformers may need lower concurrency.
+ * This concurrency limit does not need to be higher than the max parallel calls made to
+ * the {@link #add(List)} method multiplied by the list size. This setting can
+ * also serve as a protecting throttle against your embedding model.
+ *
* @author Mick Semb Wever
* @see VectorStore
- * @see CassandraVectorStoreConfig
+ * @see org.springframework.ai.vectorstore.CassandraVectorStoreConfig
* @see EmbeddingClient
* @since 1.0.0
*/
@@ -87,10 +103,14 @@ public enum Similarity {
}
- private static final String QUERY_FORMAT = "select %s,%s,%s%s from %s.%s ? order by %s ann of ? limit ?";
-
public static final String SIMILARITY_FIELD_NAME = "similarity_score";
+ public static final String DRIVER_PROFILE_UPDATES = "spring-ai-updates";
+
+ public static final String DRIVER_PROFILE_SEARCH = "spring-ai-search";
+
+ private static final String QUERY_FORMAT = "select %s,%s,%s%s from %s.%s ? order by %s ann of ? limit ?";
+
private static final Logger logger = LoggerFactory.getLogger(CassandraVectorStore.class);
private final CassandraVectorStoreConfig conf;
@@ -99,7 +119,7 @@ public enum Similarity {
private final FilterExpressionConverter filterExpressionConverter;
- private final Map, PreparedStatement> addStmts = new HashMap<>();
+ private final ConcurrentMap, PreparedStatement> addStmts = new ConcurrentHashMap<>();
private final PreparedStatement deleteStmt;
@@ -133,30 +153,39 @@ public CassandraVectorStore(CassandraVectorStoreConfig conf, EmbeddingClient emb
@Override
public void add(List documents) {
- CompletableFuture[] futures = new CompletableFuture[documents.size()];
- short i = 0;
- for (Document d : documents) {
- List