diff --git a/README.md b/README.md index 50e4facc713..1c53af2bdc9 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ You can find more details in the [Reference Documentation](https://docs.spring.i Spring AI supports many AI models. For an overview see here. Specific models currently supported are * OpenAI * Azure OpenAI -* Amazon Bedrock (Anthropic, Llama2, Cohere, Titan, Jurassic2) +* Amazon Bedrock (Anthropic, Llama, Cohere, Titan, Jurassic2) * HuggingFace * Google VertexAI (PaLM2, Gemini) * Mistral AI diff --git a/models/spring-ai-bedrock/README.md b/models/spring-ai-bedrock/README.md index 94311055fba..19e48518a60 100644 --- a/models/spring-ai-bedrock/README.md +++ b/models/spring-ai-bedrock/README.md @@ -4,7 +4,7 @@ - [Anthropic2 Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-anthropic.html) - [Cohere Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-cohere.html) - [Cohere Embedding Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/embeddings/bedrock-cohere-embedding.html) -- [Llama2 Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-llama2.html) +- [Llama Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-llama.html) - [Titan Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-titan.html) - [Titan Embedding Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/embeddings/bedrock-titan-embedding.html) - [Jurassic2 Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-jurassic2.html) diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java 
b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java index 2437b35eebb..55a2d80af50 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java @@ -24,6 +24,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import reactor.core.publisher.Flux; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatRequest; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse; @@ -32,6 +33,7 @@ /** * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ // @formatter:off @@ -92,6 +94,20 @@ public AnthropicChatBedrockApi(String modelId, AwsCredentialsProvider credential super(modelId, credentialsProvider, region, objectMapper, timeout); } + /** + * Create a new AnthropicChatBedrockApi instance using the provided credentials provider, region and object mapper. + * + * @param modelId The model id to use. See the {@link AnthropicChatModel} for the supported models. + * @param credentialsProvider The credentials provider to connect to AWS. + * @param region The AWS region to use. + * @param objectMapper The object mapper to use for JSON serialization and deserialization. + * @param timeout The timeout to use. 
+ */ + public AnthropicChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region, + ObjectMapper objectMapper, Duration timeout) { + super(modelId, credentialsProvider, region, objectMapper, timeout); + } + // https://github.com/build-on-aws/amazon-bedrock-java-examples/blob/main/example_code/bedrock-runtime/src/main/java/aws/community/examples/InvokeBedrockStreamingAsync.java // https://docs.anthropic.com/claude/reference/complete_post diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java index e76bfcbeff3..0148b498373 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java @@ -26,6 +26,7 @@ import org.springframework.util.Assert; import reactor.core.publisher.Flux; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; import java.time.Duration; import java.util.List; @@ -39,6 +40,7 @@ * * @author Ben Middleton * @author Christian Tzolov + * @author Wei Jiang * @since 1.0.0 */ // @formatter:off @@ -96,6 +98,20 @@ public Anthropic3ChatBedrockApi(String modelId, AwsCredentialsProvider credentia super(modelId, credentialsProvider, region, objectMapper, timeout); } + /** + * Create a new Anthropic3ChatBedrockApi instance using the provided credentials provider, region and object mapper. + * + * @param modelId The model id to use. See the {@link AnthropicChatModel} for the supported models. + * @param credentialsProvider The credentials provider to connect to AWS. + * @param region The AWS region to use. + * @param objectMapper The object mapper to use for JSON serialization and deserialization. 
+ * @param timeout The timeout to use. + */ + public Anthropic3ChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region, + ObjectMapper objectMapper, Duration timeout) { + super(modelId, credentialsProvider, region, objectMapper, timeout); + } + // https://github.com/build-on-aws/amazon-bedrock-java-examples/blob/main/example_code/bedrock-runtime/src/main/java/aws/community/examples/InvokeBedrockStreamingAsync.java // https://docs.anthropic.com/claude/reference/complete_post @@ -441,7 +457,11 @@ public enum AnthropicChatModel { /** * anthropic.claude-3-haiku-20240307-v1:0 */ - CLAUDE_V3_HAIKU("anthropic.claude-3-haiku-20240307-v1:0"); + CLAUDE_V3_HAIKU("anthropic.claude-3-haiku-20240307-v1:0"), + /** + * anthropic.claude-3-opus-20240229-v1:0 + */ + CLAUDE_V3_OPUS("anthropic.claude-3-opus-20240229-v1:0"); private final String id; diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java index 8ed33139b7d..7db24b3b8c6 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java @@ -25,9 +25,10 @@ import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi; import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi; import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi; -import org.springframework.ai.bedrock.llama2.BedrockLlama2ChatOptions; -import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi; +import org.springframework.ai.bedrock.llama.BedrockLlamaChatOptions; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi; import org.springframework.ai.bedrock.titan.BedrockTitanChatOptions; +import 
org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingOptions; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi; import org.springframework.aot.hint.MemberCategory; @@ -43,6 +44,7 @@ * @author Josh Long * @author Christian Tzolov * @author Mark Pollack + * @author Wei Jiang */ public class BedrockRuntimeHints implements RuntimeHintsRegistrar { @@ -63,15 +65,17 @@ public void registerHints(RuntimeHints hints, ClassLoader classLoader) { for (var tr : findJsonAnnotatedClassesInPackage(BedrockCohereEmbeddingOptions.class)) hints.reflection().registerType(tr, mcs); - for (var tr : findJsonAnnotatedClassesInPackage(Llama2ChatBedrockApi.class)) + for (var tr : findJsonAnnotatedClassesInPackage(LlamaChatBedrockApi.class)) hints.reflection().registerType(tr, mcs); - for (var tr : findJsonAnnotatedClassesInPackage(BedrockLlama2ChatOptions.class)) + for (var tr : findJsonAnnotatedClassesInPackage(BedrockLlamaChatOptions.class)) hints.reflection().registerType(tr, mcs); for (var tr : findJsonAnnotatedClassesInPackage(TitanChatBedrockApi.class)) hints.reflection().registerType(tr, mcs); for (var tr : findJsonAnnotatedClassesInPackage(BedrockTitanChatOptions.class)) hints.reflection().registerType(tr, mcs); + for (var tr : findJsonAnnotatedClassesInPackage(BedrockTitanEmbeddingOptions.class)) + hints.reflection().registerType(tr, mcs); for (var tr : findJsonAnnotatedClassesInPackage(TitanEmbeddingBedrockApi.class)) hints.reflection().registerType(tr, mcs); diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java index 74f4249f04f..17e0eed79b7 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java +++ 
b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java @@ -61,6 +61,7 @@ * @see Model Parameters * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ public abstract class AbstractBedrockApi { @@ -69,7 +70,7 @@ public abstract class AbstractBedrockApi { private final String modelId; private final ObjectMapper objectMapper; - private final String region; + private final Region region; private final BedrockRuntimeClient client; private final BedrockRuntimeAsyncClient clientStreaming; @@ -93,7 +94,7 @@ public AbstractBedrockApi(String modelId, String region, Duration timeout) { this(modelId, ProfileCredentialsProvider.builder().build(), region, ModelOptionsUtils.OBJECT_MAPPER, timeout); } - /** + /** * Create a new AbstractBedrockApi instance using the provided credentials provider, region and object mapper. * * @param modelId The model id to use. @@ -105,6 +106,7 @@ public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProv ObjectMapper objectMapper) { this(modelId, credentialsProvider, region, objectMapper, Duration.ofMinutes(5)); } + /** * Create a new AbstractBedrockApi instance using the provided credentials provider, region and object mapper. * @@ -118,10 +120,26 @@ public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProv */ public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region, ObjectMapper objectMapper, Duration timeout) { + this(modelId, credentialsProvider, Region.of(region), objectMapper, timeout); + } + + /** + * Create a new AbstractBedrockApi instance using the provided credentials provider, region and object mapper. + * + * @param modelId The model id to use. + * @param credentialsProvider The credentials provider to connect to AWS. + * @param region The AWS region to use. + * @param objectMapper The object mapper to use for JSON serialization and deserialization. 
+ * @param timeout Configure the amount of time to allow the client to complete the execution of an API call. + * This timeout covers the entire client execution except for marshalling. This includes request handler execution, + * all HTTP requests including retries, unmarshalling, etc. This value should always be positive, if present. + */ + public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region, + ObjectMapper objectMapper, Duration timeout) { Assert.hasText(modelId, "Model id must not be empty"); Assert.notNull(credentialsProvider, "Credentials provider must not be null"); - Assert.hasText(region, "Region must not be empty"); + Assert.notNull(region, "Region must not be empty"); Assert.notNull(objectMapper, "Object mapper must not be null"); Assert.notNull(timeout, "Timeout must not be null"); @@ -131,13 +149,13 @@ public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProv this.client = BedrockRuntimeClient.builder() - .region(Region.of(this.region)) + .region(this.region) .credentialsProvider(credentialsProvider) .overrideConfiguration(c -> c.apiCallTimeout(timeout)) .build(); this.clientStreaming = BedrockRuntimeAsyncClient.builder() - .region(Region.of(this.region)) + .region(this.region) .credentialsProvider(credentialsProvider) .overrideConfiguration(c -> c.apiCallTimeout(timeout)) .build(); @@ -153,7 +171,7 @@ public String getModelId() { /** * @return The AWS region. 
*/ - public String getRegion() { + public Region getRegion() { return this.region; } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java index b3b02b6993f..5b133a9976c 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java @@ -25,6 +25,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import reactor.core.publisher.Flux; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; import org.springframework.ai.bedrock.api.AbstractBedrockApi; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest; @@ -36,6 +37,7 @@ * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere.html * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ public class CohereChatBedrockApi extends @@ -91,6 +93,20 @@ public CohereChatBedrockApi(String modelId, AwsCredentialsProvider credentialsPr super(modelId, credentialsProvider, region, objectMapper, timeout); } + /** + * Create a new CohereChatBedrockApi instance using the provided credentials provider, region and object mapper. + * + * @param modelId The model id to use. See the {@link CohereChatModel} for the supported models. + * @param credentialsProvider The credentials provider to connect to AWS. + * @param region The AWS region to use. + * @param objectMapper The object mapper to use for JSON serialization and deserialization. + * @param timeout The timeout to use. 
+ */ + public CohereChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region, + ObjectMapper objectMapper, Duration timeout) { + super(modelId, credentialsProvider, region, objectMapper, timeout); + } + /** * CohereChatRequest encapsulates the request parameters for the Cohere command model. * diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java index 7d0fa442cde..13752cc4486 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java @@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.ObjectMapper; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; import org.springframework.ai.bedrock.api.AbstractBedrockApi; import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingRequest; @@ -34,6 +35,7 @@ * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere.html#model-parameters-embed * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ public class CohereEmbeddingBedrockApi extends @@ -91,6 +93,21 @@ public CohereEmbeddingBedrockApi(String modelId, AwsCredentialsProvider credenti super(modelId, credentialsProvider, region, objectMapper, timeout); } + /** + * Create a new CohereEmbeddingBedrockApi instance using the provided credentials provider, region and object + * mapper. + * + * @param modelId The model id to use. See the {@link CohereEmbeddingModel} for the supported models. + * @param credentialsProvider The credentials provider to connect to AWS. 
+ * @param region The AWS region to use. + * @param objectMapper The object mapper to use for JSON serialization and deserialization. + * @param timeout The timeout to use. + */ + public CohereEmbeddingBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region, + ObjectMapper objectMapper, Duration timeout) { + super(modelId, credentialsProvider, region, objectMapper, timeout); + } + /** * The Cohere Embed model request. * diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java index fa505176350..0ec58c8bd2c 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java @@ -29,12 +29,14 @@ import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatResponse; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; /** * Java client for the Bedrock Jurassic2 chat model. * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ public class Ai21Jurassic2ChatBedrockApi extends @@ -92,6 +94,20 @@ public Ai21Jurassic2ChatBedrockApi(String modelId, AwsCredentialsProvider creden super(modelId, credentialsProvider, region, objectMapper, timeout); } + /** + * Create a new Ai21Jurassic2ChatBedrockApi instance. + * + * @param modelId The model id to use. See the {@link Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatModel} for the supported models. + * @param credentialsProvider The credentials provider to connect to AWS. + * @param region The AWS region to use. 
+ * @param objectMapper The object mapper to use for JSON serialization and deserialization. + * @param timeout The timeout to use. + */ + public Ai21Jurassic2ChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region, + ObjectMapper objectMapper, Duration timeout) { + super(modelId, credentialsProvider, region, objectMapper, timeout); + } + /** * AI21 Jurassic2 chat request parameters. * diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatClient.java similarity index 67% rename from models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java rename to models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatClient.java index a12fe0b6eb9..c1be58e5a1d 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatClient.java @@ -13,16 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.springframework.ai.bedrock.llama2; +package org.springframework.ai.bedrock.llama; import java.util.List; import reactor.core.publisher.Flux; import org.springframework.ai.bedrock.MessageToPromptConverter; -import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi; -import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatRequest; -import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatResponse; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse; import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.ChatResponse; @@ -35,26 +35,27 @@ import org.springframework.util.Assert; /** - * Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Llama2 chat + * Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Llama chat * generative. 
* * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ -public class BedrockLlama2ChatClient implements ChatClient, StreamingChatClient { +public class BedrockLlamaChatClient implements ChatClient, StreamingChatClient { - private final Llama2ChatBedrockApi chatApi; + private final LlamaChatBedrockApi chatApi; - private final BedrockLlama2ChatOptions defaultOptions; + private final BedrockLlamaChatOptions defaultOptions; - public BedrockLlama2ChatClient(Llama2ChatBedrockApi chatApi) { + public BedrockLlamaChatClient(LlamaChatBedrockApi chatApi) { this(chatApi, - BedrockLlama2ChatOptions.builder().withTemperature(0.8f).withTopP(0.9f).withMaxGenLen(100).build()); + BedrockLlamaChatOptions.builder().withTemperature(0.8f).withTopP(0.9f).withMaxGenLen(100).build()); } - public BedrockLlama2ChatClient(Llama2ChatBedrockApi chatApi, BedrockLlama2ChatOptions options) { - Assert.notNull(chatApi, "Llama2ChatBedrockApi must not be null"); - Assert.notNull(options, "BedrockLlama2ChatOptions must not be null"); + public BedrockLlamaChatClient(LlamaChatBedrockApi chatApi, BedrockLlamaChatOptions options) { + Assert.notNull(chatApi, "LlamaChatBedrockApi must not be null"); + Assert.notNull(options, "BedrockLlamaChatOptions must not be null"); this.chatApi = chatApi; this.defaultOptions = options; @@ -65,7 +66,7 @@ public ChatResponse call(Prompt prompt) { var request = createRequest(prompt); - Llama2ChatResponse response = this.chatApi.chatCompletion(request); + LlamaChatResponse response = this.chatApi.chatCompletion(request); return new ChatResponse(List.of(new Generation(response.generation()).withGenerationMetadata( ChatGenerationMetadata.from(response.stopReason().name(), extractUsage(response))))); @@ -76,7 +77,7 @@ public Flux stream(Prompt prompt) { var request = createRequest(prompt); - Flux fluxResponse = this.chatApi.chatCompletionStream(request); + Flux fluxResponse = this.chatApi.chatCompletionStream(request); return fluxResponse.map(response -> { String 
stopReason = response.stopReason() != null ? response.stopReason().name() : null; @@ -85,7 +86,7 @@ public Flux stream(Prompt prompt) { }); } - private Usage extractUsage(Llama2ChatResponse response) { + private Usage extractUsage(LlamaChatResponse response) { return new Usage() { @Override @@ -103,22 +104,22 @@ public Long getGenerationTokens() { /** * Accessible for testing. */ - Llama2ChatRequest createRequest(Prompt prompt) { + LlamaChatRequest createRequest(Prompt prompt) { final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions()); - Llama2ChatRequest request = Llama2ChatRequest.builder(promptValue).build(); + LlamaChatRequest request = LlamaChatRequest.builder(promptValue).build(); if (this.defaultOptions != null) { - request = ModelOptionsUtils.merge(request, this.defaultOptions, Llama2ChatRequest.class); + request = ModelOptionsUtils.merge(request, this.defaultOptions, LlamaChatRequest.class); } if (prompt.getOptions() != null) { if (prompt.getOptions() instanceof ChatOptions runtimeOptions) { - BedrockLlama2ChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, - ChatOptions.class, BedrockLlama2ChatOptions.class); + BedrockLlamaChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, + ChatOptions.class, BedrockLlamaChatOptions.class); - request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, Llama2ChatRequest.class); + request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, LlamaChatRequest.class); } else { throw new IllegalArgumentException("Prompt options are not of type ChatOptions: " diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java similarity index 91% rename from 
models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatOptions.java rename to models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java index a944b09904a..3502fd4c441 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.bedrock.llama2; +package org.springframework.ai.bedrock.llama; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; @@ -26,7 +26,7 @@ * @author Christian Tzolov */ @JsonInclude(Include.NON_NULL) -public class BedrockLlama2ChatOptions implements ChatOptions { +public class BedrockLlamaChatOptions implements ChatOptions { /** * The temperature value controls the randomness of the generated text. 
Use a lower @@ -51,7 +51,7 @@ public static Builder builder() { public static class Builder { - private BedrockLlama2ChatOptions options = new BedrockLlama2ChatOptions(); + private BedrockLlamaChatOptions options = new BedrockLlamaChatOptions(); public Builder withTemperature(Float temperature) { this.options.setTemperature(temperature); @@ -68,7 +68,7 @@ public Builder withMaxGenLen(Integer maxGenLen) { return this; } - public BedrockLlama2ChatOptions build() { + public BedrockLlamaChatOptions build() { return this.options; } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java similarity index 61% rename from models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApi.java rename to models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java index af10d69bdfb..25d71aedeea 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.springframework.ai.bedrock.llama2.api; +package org.springframework.ai.bedrock.llama.api; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; @@ -21,76 +21,92 @@ import com.fasterxml.jackson.databind.ObjectMapper; import reactor.core.publisher.Flux; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; import org.springframework.ai.bedrock.api.AbstractBedrockApi; -import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatRequest; -import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatResponse; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse; import java.time.Duration; // @formatter:off /** - * Java client for the Bedrock Llama2 chat model. + * Java client for the Bedrock Llama chat model. * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ -public class Llama2ChatBedrockApi extends - AbstractBedrockApi { +public class LlamaChatBedrockApi extends + AbstractBedrockApi { /** - * Create a new Llama2ChatBedrockApi instance using the default credentials provider chain, the default object + * Create a new LlamaChatBedrockApi instance using the default credentials provider chain, the default object * mapper, default temperature and topP values. * - * @param modelId The model id to use. See the {@link Llama2ChatModel} for the supported models. + * @param modelId The model id to use. See the {@link LlamaChatModel} for the supported models. * @param region The AWS region to use. 
*/ - public Llama2ChatBedrockApi(String modelId, String region) { + public LlamaChatBedrockApi(String modelId, String region) { super(modelId, region); } /** - * Create a new Llama2ChatBedrockApi instance using the provided credentials provider, region and object mapper. + * Create a new LlamaChatBedrockApi instance using the provided credentials provider, region and object mapper. * - * @param modelId The model id to use. See the {@link Llama2ChatModel} for the supported models. + * @param modelId The model id to use. See the {@link LlamaChatModel} for the supported models. * @param credentialsProvider The credentials provider to connect to AWS. * @param region The AWS region to use. * @param objectMapper The object mapper to use for JSON serialization and deserialization. */ - public Llama2ChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region, + public LlamaChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region, ObjectMapper objectMapper) { super(modelId, credentialsProvider, region, objectMapper); } /** - * Create a new Llama2ChatBedrockApi instance using the default credentials provider chain, the default object + * Create a new LlamaChatBedrockApi instance using the default credentials provider chain, the default object * mapper, default temperature and topP values. * - * @param modelId The model id to use. See the {@link Llama2ChatModel} for the supported models. + * @param modelId The model id to use. See the {@link LlamaChatModel} for the supported models. * @param region The AWS region to use. * @param timeout The timeout to use. */ - public Llama2ChatBedrockApi(String modelId, String region, Duration timeout) { + public LlamaChatBedrockApi(String modelId, String region, Duration timeout) { super(modelId, region, timeout); } /** - * Create a new Llama2ChatBedrockApi instance using the provided credentials provider, region and object mapper. 
+ * Create a new LlamaChatBedrockApi instance using the provided credentials provider, region and object mapper. * - * @param modelId The model id to use. See the {@link Llama2ChatModel} for the supported models. + * @param modelId The model id to use. See the {@link LlamaChatModel} for the supported models. * @param credentialsProvider The credentials provider to connect to AWS. * @param region The AWS region to use. * @param objectMapper The object mapper to use for JSON serialization and deserialization. * @param timeout The timeout to use. */ - public Llama2ChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region, + public LlamaChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region, ObjectMapper objectMapper, Duration timeout) { super(modelId, credentialsProvider, region, objectMapper, timeout); } /** - * Llama2ChatRequest encapsulates the request parameters for the Meta Llama2 chat model. + * Create a new LlamaChatBedrockApi instance using the provided credentials provider, region and object mapper. + * + * @param modelId The model id to use. See the {@link LlamaChatModel} for the supported models. + * @param credentialsProvider The credentials provider to connect to AWS. + * @param region The AWS region to use. + * @param objectMapper The object mapper to use for JSON serialization and deserialization. + * @param timeout The timeout to use. + */ + public LlamaChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region, + ObjectMapper objectMapper, Duration timeout) { + super(modelId, credentialsProvider, region, objectMapper, timeout); + } + + /** + * LlamaChatRequest encapsulates the request parameters for the Meta Llama chat model. * * @param prompt The prompt to use for the chat. * @param temperature The temperature value controls the randomness of the generated text. 
Use a lower value to @@ -100,16 +116,16 @@ public Llama2ChatBedrockApi(String modelId, AwsCredentialsProvider credentialsPr * @param maxGenLen The maximum length of the generated text. */ @JsonInclude(Include.NON_NULL) - public record Llama2ChatRequest( + public record LlamaChatRequest( @JsonProperty("prompt") String prompt, @JsonProperty("temperature") Float temperature, @JsonProperty("top_p") Float topP, @JsonProperty("max_gen_len") Integer maxGenLen) { /** - * Create a new Llama2ChatRequest builder. + * Create a new LlamaChatRequest builder. * @param prompt compulsory prompt parameter. - * @return a new Llama2ChatRequest builder. + * @return a new LlamaChatRequest builder. */ public static Builder builder(String prompt) { return new Builder(prompt); @@ -140,8 +156,8 @@ public Builder withMaxGenLen(Integer maxGenLen) { return this; } - public Llama2ChatRequest build() { - return new Llama2ChatRequest( + public LlamaChatRequest build() { + return new LlamaChatRequest( prompt, temperature, topP, @@ -152,7 +168,7 @@ public Llama2ChatRequest build() { } /** - * Llama2ChatResponse encapsulates the response parameters for the Meta Llama2 chat model. + * LlamaChatResponse encapsulates the response parameters for the Meta Llama chat model. * * @param generation The generated text. * @param promptTokenCount The number of tokens in the prompt. @@ -163,7 +179,7 @@ public Llama2ChatRequest build() { * increasing the value of max_gen_len and trying again. */ @JsonInclude(Include.NON_NULL) - public record Llama2ChatResponse( + public record LlamaChatResponse( @JsonProperty("generation") String generation, @JsonProperty("prompt_token_count") Integer promptTokenCount, @JsonProperty("generation_token_count") Integer generationTokenCount, @@ -186,9 +202,9 @@ public enum StopReason { } /** - * Llama2 models version. + * Llama models version. 
*/ - public enum Llama2ChatModel { + public enum LlamaChatModel { /** * meta.llama2-13b-chat-v1 @@ -198,7 +214,17 @@ public enum Llama2ChatModel { /** * meta.llama2-70b-chat-v1 */ - LLAMA2_70B_CHAT_V1("meta.llama2-70b-chat-v1"); + LLAMA2_70B_CHAT_V1("meta.llama2-70b-chat-v1"), + + /** + * meta.llama3-8b-instruct-v1:0 + */ + LLAMA3_8B_INSTRUCT_V1("meta.llama3-8b-instruct-v1:0"), + + /** + * meta.llama3-70b-instruct-v1:0 + */ + LLAMA3_70B_INSTRUCT_V1("meta.llama3-70b-instruct-v1:0"); private final String id; @@ -209,19 +235,19 @@ public String id() { return id; } - Llama2ChatModel(String value) { + LlamaChatModel(String value) { this.id = value; } } @Override - public Llama2ChatResponse chatCompletion(Llama2ChatRequest request) { - return this.internalInvocation(request, Llama2ChatResponse.class); + public LlamaChatResponse chatCompletion(LlamaChatRequest request) { + return this.internalInvocation(request, LlamaChatResponse.class); } @Override - public Flux chatCompletionStream(Llama2ChatRequest request) { - return this.internalInvocationStream(request, Llama2ChatResponse.class); + public Flux chatCompletionStream(LlamaChatRequest request) { + return this.internalInvocationStream(request, LlamaChatResponse.class); } } // @formatter:on \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClient.java index d48135f80ec..1d64f92eff8 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClient.java @@ -28,6 +28,7 @@ import org.springframework.ai.document.Document; import org.springframework.ai.embedding.AbstractEmbeddingClient; import org.springframework.ai.embedding.Embedding; +import 
org.springframework.ai.embedding.EmbeddingOptions; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.util.Assert; @@ -40,6 +41,7 @@ * Note: Titan Embedding does not support batch embedding. * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ public class BedrockTitanEmbeddingClient extends AbstractEmbeddingClient { @@ -87,9 +89,7 @@ public EmbeddingResponse call(EmbeddingRequest request) { List> embeddingList = new ArrayList<>(); for (String inputContent : request.getInstructions()) { - var apiRequest = (this.inputType == InputType.IMAGE) - ? new TitanEmbeddingRequest.Builder().withInputImage(inputContent).build() - : new TitanEmbeddingRequest.Builder().withInputText(inputContent).build(); + var apiRequest = createTitanEmbeddingRequest(inputContent, request.getOptions()); TitanEmbeddingResponse response = this.embeddingApi.embedding(apiRequest); embeddingList.add(response.embedding()); } @@ -100,6 +100,18 @@ public EmbeddingResponse call(EmbeddingRequest request) { return new EmbeddingResponse(embeddings); } + private TitanEmbeddingRequest createTitanEmbeddingRequest(String inputContent, EmbeddingOptions requestOptions) { + InputType inputType = this.inputType; + + if (requestOptions != null + && requestOptions instanceof BedrockTitanEmbeddingOptions bedrockTitanEmbeddingOptions) { + inputType = bedrockTitanEmbeddingOptions.getInputType(); + } + + return (inputType == InputType.IMAGE) ? 
new TitanEmbeddingRequest.Builder().withInputImage(inputContent).build() + : new TitanEmbeddingRequest.Builder().withInputText(inputContent).build(); + } + @Override public int dimensions() { if (this.inputType == InputType.IMAGE) { diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingOptions.java new file mode 100644 index 00000000000..fd1c609bf91 --- /dev/null +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingOptions.java @@ -0,0 +1,65 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock.titan; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; + +import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingClient.InputType; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.util.Assert; + +/** + * @author Wei Jiang + */ +@JsonInclude(Include.NON_NULL) +public class BedrockTitanEmbeddingOptions implements EmbeddingOptions { + + /** + * Titan Embedding API input types. Could be either text or image (encoded in base64). 
+ */ + private InputType inputType; + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private BedrockTitanEmbeddingOptions options = new BedrockTitanEmbeddingOptions(); + + public Builder withInputType(InputType inputType) { + Assert.notNull(inputType, "input type can not be null."); + + this.options.setInputType(inputType); + return this; + } + + public BedrockTitanEmbeddingOptions build() { + return this.options; + } + + } + + public InputType getInputType() { + return this.inputType; + } + + public void setInputType(InputType inputType) { + this.inputType = inputType; + } + +} diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java index 498b34bf3d8..f4a219a8d99 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java @@ -24,6 +24,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import reactor.core.publisher.Flux; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; import org.springframework.ai.bedrock.api.AbstractBedrockApi; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatRequest; @@ -38,6 +39,7 @@ * https://docs.aws.amazon.com/bedrock/latest/userguide/titan-text-models.html * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ // @formatter:off @@ -92,6 +94,20 @@ public TitanChatBedrockApi(String modelId, AwsCredentialsProvider credentialsPro super(modelId, credentialsProvider, region, objectMapper, timeout); } + /** + * Create a new TitanChatBedrockApi instance using the provided credentials provider, region and object mapper. 
+ * + * @param modelId The model id to use. See the {@link TitanChatModel} for the supported models. + * @param credentialsProvider The credentials provider to connect to AWS. + * @param region The AWS region to use. + * @param objectMapper The object mapper to use for JSON serialization and deserialization. + * @param timeout The timeout to use. + */ + public TitanChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region, + ObjectMapper objectMapper, Duration timeout) { + super(modelId, credentialsProvider, region, objectMapper, timeout); + } + /** * TitanChatRequest encapsulates the request parameters for the Titan chat model. * diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java index 9c1dcb3b267..5901799c32f 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java @@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.ObjectMapper; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; import org.springframework.ai.bedrock.api.AbstractBedrockApi; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingRequest; @@ -34,6 +35,7 @@ * https://docs.aws.amazon.com/bedrock/latest/userguide/titan-multiemb-models.html * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ // @formatter:off @@ -65,6 +67,20 @@ public TitanEmbeddingBedrockApi(String modelId, AwsCredentialsProvider credentia super(modelId, credentialsProvider, region, objectMapper, timeout); } + /** + * Create a new TitanEmbeddingBedrockApi instance. 
+ * + * @param modelId The model id to use. See the {@link TitanEmbeddingModel} for the supported models. + * @param credentialsProvider The credentials provider to connect to AWS. + * @param region The AWS region to use. + * @param objectMapper The object mapper to use for JSON serialization and deserialization. + * @param timeout The timeout to use. + */ + public TitanEmbeddingBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region, + ObjectMapper objectMapper, Duration timeout) { + super(modelId, credentialsProvider, region, objectMapper, timeout); + } + /** * Titan Embedding request parameters. * diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatClientIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatClientIT.java index c2050f18ca0..4fa59402054 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatClientIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatClientIT.java @@ -161,9 +161,9 @@ void beanOutputParserRecords() { String format = outputParser.getFormat(); String template = """ Generate the filmography of 5 movies for Tom Hanks. - Remove non JSON tex blocks from the output. {format} Provide your answer in the JSON format with the feature names as the keys. + Remove Markdown code blocks from the output. 
"""; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java index a4b51a70750..92060ef5bf0 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java @@ -20,7 +20,7 @@ import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi; import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi; import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi; -import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi; import org.springframework.aot.hint.RuntimeHints; @@ -43,7 +43,7 @@ void registerHints() { bedrockRuntimeHints.registerHints(runtimeHints, null); List classList = Arrays.asList(Ai21Jurassic2ChatBedrockApi.class, CohereChatBedrockApi.class, - CohereEmbeddingBedrockApi.class, Llama2ChatBedrockApi.class, TitanChatBedrockApi.class, + CohereEmbeddingBedrockApi.class, LlamaChatBedrockApi.class, TitanChatBedrockApi.class, TitanEmbeddingBedrockApi.class, AnthropicChatBedrockApi.class); for (Class aClass : classList) { diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClientIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatClientIT.java similarity index 88% rename from 
models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClientIT.java rename to models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatClientIT.java index 7cbf3f23566..554afcf9fa0 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClientIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatClientIT.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.bedrock.llama2; +package org.springframework.ai.bedrock.llama; import java.time.Duration; import java.util.Arrays; @@ -22,27 +22,25 @@ import java.util.stream.Collectors; import com.fasterxml.jackson.databind.ObjectMapper; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import reactor.core.publisher.Flux; - -import org.springframework.ai.chat.ChatResponse; -import org.springframework.ai.chat.messages.AssistantMessage; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; -import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi; -import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatModel; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel; +import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; -import org.springframework.ai.parser.BeanOutputParser; -import org.springframework.ai.parser.ListOutputParser; -import org.springframework.ai.parser.MapOutputParser; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import 
org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.parser.BeanOutputParser; +import org.springframework.ai.parser.ListOutputParser; +import org.springframework.ai.parser.MapOutputParser; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -56,10 +54,10 @@ @SpringBootTest @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") @EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") -class BedrockLlama2ChatClientIT { +class BedrockLlamaChatClientIT { @Autowired - private BedrockLlama2ChatClient client; + private BedrockLlamaChatClient client; @Value("classpath:/prompts/system-message.st") private Resource systemResource; @@ -67,8 +65,8 @@ class BedrockLlama2ChatClientIT { @Test void multipleStreamAttempts() { + Flux joke2Stream = client.stream(new Prompt(new UserMessage("Tell me a Toy joke?"))); Flux joke1Stream = client.stream(new Prompt(new UserMessage("Tell me a joke?"))); - Flux joke2Stream = client.stream(new Prompt(new UserMessage("Tell me a toy joke?"))); String joke1 = joke1Stream.collectList() .block() @@ -105,7 +103,6 @@ void roleTest() { assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } - @Disabled("TODO: Fix the parser instructions to return the correct format") @Test void outputParser() { DefaultConversionService conversionService = new DefaultConversionService(); @@ -147,7 +144,6 @@ void mapOutputParser() { record ActorsFilmsRecord(String actor, List movies) { } - @Disabled("TODO: Fix the parser instructions to return the correct format") 
@Test void beanOutputParserRecords() { @@ -169,7 +165,6 @@ void beanOutputParserRecords() { assertThat(actorsFilms.movies()).hasSize(5); } - @Disabled("TODO: Fix the parser instructions to return the correct format") @Test void beanStreamOutputParserRecords() { @@ -204,16 +199,16 @@ void beanStreamOutputParserRecords() { public static class TestConfiguration { @Bean - public Llama2ChatBedrockApi llama2Api() { - return new Llama2ChatBedrockApi(Llama2ChatModel.LLAMA2_70B_CHAT_V1.id(), + public LlamaChatBedrockApi llamaApi() { + return new LlamaChatBedrockApi(LlamaChatModel.LLAMA3_70B_INSTRUCT_V1.id(), EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), Duration.ofMinutes(2)); } @Bean - public BedrockLlama2ChatClient llama2ChatClient(Llama2ChatBedrockApi llama2Api) { - return new BedrockLlama2ChatClient(llama2Api, - BedrockLlama2ChatOptions.builder().withTemperature(0.5f).withMaxGenLen(100).withTopP(0.9f).build()); + public BedrockLlamaChatClient llamaChatClient(LlamaChatBedrockApi llamaApi) { + return new BedrockLlamaChatClient(llamaApi, + BedrockLlamaChatOptions.builder().withTemperature(0.5f).withMaxGenLen(100).withTopP(0.9f).build()); } } diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2CreateRequestTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaCreateRequestTests.java similarity index 67% rename from models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2CreateRequestTests.java rename to models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaCreateRequestTests.java index 1a3329016f3..77018af9e6e 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2CreateRequestTests.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaCreateRequestTests.java @@ -13,15 
+13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.bedrock.llama2; +package org.springframework.ai.bedrock.llama; import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; -import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi; -import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatModel; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel; import org.springframework.ai.chat.prompt.Prompt; import java.time.Duration; @@ -30,18 +32,21 @@ /** * @author Christian Tzolov + * @author Wei Jiang */ -public class BedrockLlama2CreateRequestTests { +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +public class BedrockLlamaCreateRequestTests { - private Llama2ChatBedrockApi api = new Llama2ChatBedrockApi(Llama2ChatModel.LLAMA2_70B_CHAT_V1.id(), + private LlamaChatBedrockApi api = new LlamaChatBedrockApi(LlamaChatModel.LLAMA3_70B_INSTRUCT_V1.id(), EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), Duration.ofMinutes(2)); @Test public void createRequestWithChatOptions() { - var client = new BedrockLlama2ChatClient(api, - BedrockLlama2ChatOptions.builder().withTemperature(66.6f).withMaxGenLen(666).withTopP(0.66f).build()); + var client = new BedrockLlamaChatClient(api, + BedrockLlamaChatOptions.builder().withTemperature(66.6f).withMaxGenLen(666).withTopP(0.66f).build()); var request = client.createRequest(new Prompt("Test message content")); @@ -51,7 +56,7 @@ public 
void createRequestWithChatOptions() { assertThat(request.maxGenLen()).isEqualTo(666); request = client.createRequest(new Prompt("Test message content", - BedrockLlama2ChatOptions.builder().withTemperature(99.9f).withMaxGenLen(999).withTopP(0.99f).build())); + BedrockLlamaChatOptions.builder().withTemperature(99.9f).withMaxGenLen(999).withTopP(0.99f).build())); assertThat(request.prompt()).isNotEmpty(); assertThat(request.temperature()).isEqualTo(99.9f); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApiIT.java similarity index 61% rename from models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApiIT.java rename to models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApiIT.java index dc97d8e7b8f..5b4587358fb 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApiIT.java @@ -13,47 +13,51 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.springframework.ai.bedrock.llama2.api; +package org.springframework.ai.bedrock.llama.api; import java.time.Duration; import java.util.List; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse; + +import com.fasterxml.jackson.databind.ObjectMapper; + import reactor.core.publisher.Flux; +import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; -import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatModel; -import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatRequest; -import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatResponse; - import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov + * @author Wei Jiang */ @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") @EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") -public class Llama2ChatBedrockApiIT { +public class LlamaChatBedrockApiIT { - private Llama2ChatBedrockApi llama2ChatApi = new Llama2ChatBedrockApi(Llama2ChatModel.LLAMA2_70B_CHAT_V1.id(), - Region.US_EAST_1.id(), Duration.ofMinutes(2)); + private LlamaChatBedrockApi llamaChatApi = new LlamaChatBedrockApi(LlamaChatModel.LLAMA3_70B_INSTRUCT_V1.id(), + EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), + Duration.ofMinutes(2)); @Test public void chatCompletion() { - Llama2ChatRequest request = Llama2ChatRequest.builder("Hello, my name is") + LlamaChatRequest request = LlamaChatRequest.builder("Hello, my name is") .withTemperature(0.9f) .withTopP(0.9f) 
.withMaxGenLen(20) .build(); - Llama2ChatResponse response = llama2ChatApi.chatCompletion(request); + LlamaChatResponse response = llamaChatApi.chatCompletion(request); System.out.println(response.generation()); assertThat(response).isNotNull(); assertThat(response.generation()).isNotEmpty(); - assertThat(response.promptTokenCount()).isEqualTo(6); assertThat(response.generationTokenCount()).isGreaterThan(10); assertThat(response.generationTokenCount()).isLessThanOrEqualTo(20); assertThat(response.stopReason()).isNotNull(); @@ -63,15 +67,15 @@ public void chatCompletion() { @Test public void chatCompletionStream() { - Llama2ChatRequest request = new Llama2ChatRequest("Hello, my name is", 0.9f, 0.9f, 20); - Flux responseStream = llama2ChatApi.chatCompletionStream(request); - List responses = responseStream.collectList().block(); + LlamaChatRequest request = new LlamaChatRequest("Hello, my name is", 0.9f, 0.9f, 20); + Flux responseStream = llamaChatApi.chatCompletionStream(request); + List responses = responseStream.collectList().block(); assertThat(responses).isNotNull(); assertThat(responses).hasSizeGreaterThan(10); assertThat(responses.get(0).generation()).isNotEmpty(); - Llama2ChatResponse lastResponse = responses.get(responses.size() - 1); + LlamaChatResponse lastResponse = responses.get(responses.size() - 1); assertThat(lastResponse.stopReason()).isNotNull(); assertThat(lastResponse.amazonBedrockInvocationMetrics()).isNotNull(); } diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClientIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClientIT.java index dead759015c..6400a15f010 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClientIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClientIT.java @@ -22,10 +22,14 @@ import 
org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; +import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingClient.InputType; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingModel; +import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; @@ -33,6 +37,8 @@ import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; +import com.fasterxml.jackson.databind.ObjectMapper; + import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest @@ -46,7 +52,8 @@ class BedrockTitanEmbeddingClientIT { @Test void singleEmbedding() { assertThat(embeddingClient).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World")); + EmbeddingResponse embeddingResponse = embeddingClient.call(new EmbeddingRequest(List.of("Hello World"), + BedrockTitanEmbeddingOptions.builder().withInputType(InputType.TEXT).build())); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); assertThat(embeddingClient.dimensions()).isEqualTo(1024); @@ -59,7 +66,8 @@ void imageEmbedding() throws IOException { .getContentAsByteArray(); EmbeddingResponse embeddingResponse = embeddingClient - .embedForResponse(List.of(Base64.getEncoder().encodeToString(image))); + .call(new EmbeddingRequest(List.of(Base64.getEncoder().encodeToString(image)), + BedrockTitanEmbeddingOptions.builder().withInputType(InputType.IMAGE).build())); 
assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); assertThat(embeddingClient.dimensions()).isEqualTo(1024); @@ -70,7 +78,8 @@ public static class TestConfiguration { @Bean public TitanEmbeddingBedrockApi titanEmbeddingApi() { - return new TitanEmbeddingBedrockApi(TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1.id(), Region.US_EAST_1.id(), + return new TitanEmbeddingBedrockApi(TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1.id(), + EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), Duration.ofMinutes(2)); } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java index 91aad8f5490..f5c1f4fd5bf 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java @@ -177,7 +177,7 @@ public Flux stream(Prompt prompt) { private ChatCompletion toChatCompletion(ChatCompletionChunk chunk) { List choices = chunk.choices() .stream() - .map(cc -> new Choice(cc.index(), cc.delta(), cc.finishReason())) + .map(cc -> new Choice(cc.index(), cc.delta(), cc.finishReason(), cc.logprobs())) .toList(); return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, null); @@ -285,7 +285,13 @@ protected List doGetUserMessages(ChatCompletionRequest re @SuppressWarnings("null") @Override protected ChatCompletionMessage doGetToolResponseMessage(ResponseEntity chatCompletion) { - return chatCompletion.getBody().choices().iterator().next().message(); + ChatCompletionMessage msg = chatCompletion.getBody().choices().iterator().next().message(); + if (msg.role() == null) { + // add missing role + msg = new ChatCompletionMessage(msg.content(), 
ChatCompletionMessage.Role.ASSISTANT, msg.name(), + msg.toolCalls()); + } + return msg; } @Override diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java index 19ba2ac521f..5732c871a66 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java @@ -535,14 +535,12 @@ public enum ChatCompletionFinishReason { */ @JsonProperty("model_length") MODEL_LENGTH, /** - * The model called a tool. + * */ - @JsonProperty("tool_call") TOOL_CALL, - - // anticipation of future changes. Based on: - // https://github.com/mistralai/client-python/blob/main/src/mistralai/models/chat_completion.py @JsonProperty("error") ERROR, - + /** + * The model requested a tool call. + */ @JsonProperty("tool_calls") TOOL_CALLS // @formatter:on @@ -577,17 +575,65 @@ public record ChatCompletion( * @param index The index of the choice in the list of choices. * @param message A chat completion message generated by the model. * @param finishReason The reason the model stopped generating tokens. + * @param logprobs Log probability information for the choice. */ @JsonInclude(Include.NON_NULL) public record Choice( // @formatter:off @JsonProperty("index") Integer index, @JsonProperty("message") ChatCompletionMessage message, - @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason) { + @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, + @JsonProperty("logprobs") LogProbs logprobs) { // @formatter:on } } + /** + * + * Log probability information for the choice. anticipation of future changes. + * + * @param content A list of message content tokens with log probability information. 
+ */ + @JsonInclude(Include.NON_NULL) + public record LogProbs(@JsonProperty("content") List content) { + + /** + * Message content tokens with log probability information. + * + * @param token The token. + * @param logprob The log probability of the token. + * @param probBytes A list of integers representing the UTF-8 bytes representation + * of the token. Useful in instances where characters are represented by multiple + * tokens and their byte representations must be combined to generate the correct + * text representation. Can be null if there is no bytes representation for the + * token. + * @param topLogprobs List of the most likely tokens and their log probability, at + * this token position. In rare cases, there may be fewer than the number of + * requested top_logprobs returned. + */ + @JsonInclude(Include.NON_NULL) + public record Content(@JsonProperty("token") String token, @JsonProperty("logprob") Float logprob, + @JsonProperty("bytes") List probBytes, + @JsonProperty("top_logprobs") List topLogprobs) { + + /** + * The most likely tokens and their log probability, at this token position. + * + * @param token The token. + * @param logprob The log probability of the token. + * @param probBytes A list of integers representing the UTF-8 bytes + * representation of the token. Useful in instances where characters are + * represented by multiple tokens and their byte representations must be + * combined to generate the correct text representation. Can be null if there + * is no bytes representation for the token. + */ + @JsonInclude(Include.NON_NULL) + public record TopLogProbs(@JsonProperty("token") String token, @JsonProperty("logprob") Float logprob, + @JsonProperty("bytes") List probBytes) { + } + } + } + /** * Represents a streamed chunk of a chat completion response returned by model, based * on the provided input. @@ -616,13 +662,15 @@ public record ChatCompletionChunk( * @param index The index of the choice in the list of choices. 
* @param delta A chat completion delta generated by streamed model responses. * @param finishReason The reason the model stopped generating tokens. + * @param logprobs Log probability information for the choice. */ @JsonInclude(Include.NON_NULL) public record ChunkChoice( // @formatter:off @JsonProperty("index") Integer index, @JsonProperty("delta") ChatCompletionMessage delta, - @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason) { + @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, + @JsonProperty("logprobs") LogProbs logprobs) { // @formatter:on } } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java index 50cd223536c..774bd072934 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java @@ -27,6 +27,7 @@ import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ChatCompletionFunction; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.Role; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ToolCall; +import org.springframework.ai.mistralai.api.MistralAiApi.LogProbs; import org.springframework.util.CollectionUtils; /** @@ -83,8 +84,10 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) { .toList(); var role = current.delta().role() != null ? 
current.delta().role() : Role.ASSISTANT; - current = new ChunkChoice(current.index(), new ChatCompletionMessage(current.delta().content(), - role, current.delta().name(), toolCallsWithID), current.finishReason()); + current = new ChunkChoice( + current.index(), new ChatCompletionMessage(current.delta().content(), role, + current.delta().name(), toolCallsWithID), + current.finishReason(), current.logprobs()); } } return current; @@ -95,8 +98,9 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) { Integer index = (current.index() != null ? current.index() : previous.index()); ChatCompletionMessage message = merge(previous.delta(), current.delta()); + LogProbs logprobs = (current.logprobs() != null ? current.logprobs() : previous.logprobs()); - return new ChunkChoice(index, message, finishReason); + return new ChunkChoice(index, message, finishReason, logprobs); } private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) { @@ -190,8 +194,7 @@ public boolean isStreamingToolFunctionCallFinish(ChatCompletionChunk chatComplet } var choice = choices.get(0); - return choice.finishReason() == ChatCompletionFinishReason.TOOL_CALL - || choice.finishReason() == ChatCompletionFinishReason.TOOL_CALLS; + return choice.finishReason() == ChatCompletionFinishReason.TOOL_CALLS; } } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java index 549d788eade..1ca349d21a3 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java @@ -109,7 +109,7 @@ public void beforeEach() { public void mistralAiChatTransientError() { var choice = new ChatCompletion.Choice(0, new ChatCompletionMessage("Response", Role.ASSISTANT), - 
ChatCompletionFinishReason.STOP); + ChatCompletionFinishReason.STOP, null); ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789l, "model", List.of(choice), new MistralAiApi.Usage(10, 10, 10)); @@ -137,7 +137,7 @@ public void mistralAiChatNonTransientError() { public void mistralAiChatStreamTransientError() { var choice = new ChatCompletionChunk.ChunkChoice(0, new ChatCompletionMessage("Response", Role.ASSISTANT), - ChatCompletionFinishReason.STOP); + ChatCompletionFinishReason.STOP, null); ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion.chunk", 789l, "model", List.of(choice)); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechClient.java index 86465687588..fc4575440d0 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechClient.java @@ -133,7 +133,7 @@ private OpenAiAudioApi.SpeechRequest createRequestBody(SpeechPrompt request) { if (request.getOptions() != null) { if (request.getOptions() instanceof OpenAiAudioSpeechOptions runtimeOptions) { - options = this.merge(options, runtimeOptions); + options = this.merge(runtimeOptions, options); } else { throw new IllegalArgumentException("Prompt options are not of type SpeechOptions: " diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionClient.java index b99c7cca6de..e021571e0ea 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionClient.java +++ 
b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionClient.java @@ -178,7 +178,7 @@ OpenAiAudioApi.TranscriptionRequest createRequestBody(AudioTranscriptionPrompt r if (request.getOptions() != null) { if (request.getOptions() instanceof OpenAiAudioTranscriptionOptions runtimeOptions) { - options = this.merge(options, runtimeOptions); + options = this.merge(runtimeOptions, options); } else { throw new IllegalArgumentException("Prompt options are not of type TranscriptionOptions: " diff --git a/pom.xml b/pom.xml index 1aa98332298..30e19770d96 100644 --- a/pom.xml +++ b/pom.xml @@ -74,6 +74,8 @@ vector-stores/spring-ai-elasticsearch-store spring-ai-spring-boot-starters/spring-ai-starter-watsonx-ai spring-ai-spring-boot-starters/spring-ai-starter-elasticsearch-store + vector-stores/spring-ai-opensearch-store + spring-ai-spring-boot-starters/spring-ai-starter-opensearch-store @@ -151,6 +153,12 @@ 11.6.1 4.5.1 1.7.1 + 2.10.1 + 5.3.1 + + + 1.19.7 + 2.0.1 0.0.4 diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml index d1c857f9da9..43a481c157d 100644 --- a/spring-ai-bom/pom.xml +++ b/spring-ai-bom/pom.xml @@ -210,6 +210,12 @@ ${project.version} + + org.springframework.ai + spring-ai-opensearch-store + ${project.version} + + org.springframework.ai diff --git a/spring-ai-core/src/main/java/org/springframework/ai/aot/SpringAiCoreRuntimeHints.java b/spring-ai-core/src/main/java/org/springframework/ai/aot/SpringAiCoreRuntimeHints.java index ecf2a3fd94e..2d00f1ea1d3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/aot/SpringAiCoreRuntimeHints.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/aot/SpringAiCoreRuntimeHints.java @@ -35,8 +35,8 @@ public class SpringAiCoreRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { - var chatTypes = Set.of(AbstractMessage.class, AssistantMessage.class, ChatMessage.class, 
FunctionMessage.class, - Message.class, MessageType.class, UserMessage.class, SystemMessage.class, FunctionCallbackContext.class, + var chatTypes = Set.of(AbstractMessage.class, AssistantMessage.class, FunctionMessage.class, Message.class, + MessageType.class, UserMessage.class, SystemMessage.class, FunctionCallbackContext.class, FunctionCallback.class, FunctionCallbackWrapper.class); for (var c : chatTypes) { hints.reflection().registerType(c); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatClient.java index 6925c16eedb..cff4f8674b8 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatClient.java @@ -16,6 +16,10 @@ package org.springframework.ai.chat; import org.springframework.ai.chat.prompt.Prompt; + +import java.util.Arrays; + +import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.model.ModelClient; @@ -28,6 +32,12 @@ default String call(String message) { return (generation != null) ? generation.getOutput().getContent() : ""; } + default String call(Message... messages) { + Prompt prompt = new Prompt(Arrays.asList(messages)); + Generation generation = call(prompt).getResult(); + return (generation != null) ? 
generation.getOutput().getContent() : ""; + } + @Override ChatResponse call(Prompt prompt); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/StreamingChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/StreamingChatClient.java index a6e4a0b7629..69634b1920d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/StreamingChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/StreamingChatClient.java @@ -15,8 +15,11 @@ */ package org.springframework.ai.chat; +import java.util.Arrays; + import reactor.core.publisher.Flux; +import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.StreamingModelClient; @@ -30,6 +33,13 @@ default Flux stream(String message) { : response.getResult().getOutput().getContent()); } + default Flux stream(Message... messages) { + Prompt prompt = new Prompt(Arrays.asList(messages)); + return stream(prompt).map(response -> (response.getResult() == null || response.getResult().getOutput() == null + || response.getResult().getOutput().getContent() == null) ? "" + : response.getResult().getOutput().getContent()); + } + @Override Flux stream(Prompt prompt); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ChatMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ChatMessage.java deleted file mode 100644 index ea4803a5943..00000000000 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ChatMessage.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.ai.chat.messages; - -import java.util.Map; - -/** - * Represents a chat message in a chat application. - * - */ -public class ChatMessage extends AbstractMessage { - - public ChatMessage(String role, String content) { - super(MessageType.valueOf(role), content); - } - - public ChatMessage(String role, String content, Map properties) { - super(MessageType.valueOf(role), content, properties); - } - - public ChatMessage(MessageType messageType, String content) { - super(messageType, content); - } - - public ChatMessage(MessageType messageType, String content, Map properties) { - super(messageType, content, properties); - } - -} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java index d981a67edec..4c9229516b6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.chat.messages; +import java.util.Arrays; import java.util.List; import org.springframework.core.io.Resource; @@ -38,6 +39,10 @@ public UserMessage(String textContent, List mediaList) { super(MessageType.USER, textContent, mediaList); } + public UserMessage(String textContent, Media... 
media) { + this(textContent, Arrays.asList(media)); + } + @Override public String toString() { return "UserMessage{" + "content='" + getContent() + '\'' + ", properties=" + properties + ", messageType=" diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-llama2-chat-api.jpg b/spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-llama-chat-api.jpg similarity index 100% rename from spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-llama2-chat-api.jpg rename to spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-llama-chat-api.jpg diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/orbis-sensualium-pictus2.jpg b/spring-ai-docs/src/main/antora/modules/ROOT/images/orbis-sensualium-pictus2.jpg new file mode 100644 index 00000000000..688c39d41b9 Binary files /dev/null and b/spring-ai-docs/src/main/antora/modules/ROOT/images/orbis-sensualium-pictus2.jpg differ diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-message-api.jpg b/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-message-api.jpg new file mode 100644 index 00000000000..cbdae70ab2a Binary files /dev/null and b/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-message-api.jpg differ diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index 457f9ed77a4..5c05e9d6277 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -11,7 +11,7 @@ *** xref:api/bedrock-chat.adoc[Amazon Bedrock] **** xref:api/chat/bedrock/bedrock-anthropic3.adoc[Anthropic3] **** xref:api/chat/bedrock/bedrock-anthropic.adoc[Anthropic2] -**** xref:api/chat/bedrock/bedrock-llama2.adoc[Llama2] +**** xref:api/chat/bedrock/bedrock-llama.adoc[Llama] **** xref:api/chat/bedrock/bedrock-cohere.adoc[Cohere] **** xref:api/chat/bedrock/bedrock-titan.adoc[Titan] **** 
xref:api/chat/bedrock/bedrock-jurassic2.adoc[Jurassic2] @@ -48,6 +48,7 @@ *** xref:api/vectordbs/azure.adoc[] *** xref:api/vectordbs/apache-cassandra.adoc[] *** xref:api/vectordbs/chroma.adoc[] +*** xref:api/vectordbs/elasticsearch.adoc[] *** xref:api/vectordbs/gemfire.adoc[GemFire] *** xref:api/vectordbs/milvus.adoc[] *** xref:api/vectordbs/mongodb.adoc[] @@ -61,6 +62,7 @@ ** xref:api/functions.adoc[Function Calling] +** xref:api/multimodality.adoc[Multimodality] ** xref:api/prompt.adoc[] ** xref:api/output-parser.adoc[] ** xref:api/etl-pipeline.adoc[] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/bedrock.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/bedrock.adoc index 4da2b06f7d0..3bdb84e102f 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/bedrock.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/bedrock.adoc @@ -63,6 +63,18 @@ AWS credentials are resolved in the following order: 6. Credentials delivered through the Amazon EC2 container service if the `AWS_CONTAINER_CREDENTIALS_RELATIVE_URI` environment variable is set and the security manager has permission to access the variable. 7. Instance profile credentials delivered through the Amazon EC2 metadata service or set the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables. +AWS region is resolved in the following order: + +1. Spring-AI Bedrock `spring.ai.bedrock.aws.region` property. +2. Java System Properties - `aws.region`. +3. Environment Variables - `AWS_REGION`. +4. Credential profiles file at the default location (`~/.aws/credentials`) shared by all AWS SDKs and the AWS CLI. +5. Instance profile region delivered through the Amazon EC2 metadata service. + +In addition to the standard Spring-AI Bedrock credentials and region properties configuration, Spring-AI provides support for custom `AwsCredentialsProvider` and `AwsRegionProvider` beans. 
+ +NOTE: For example, using Spring-AI and https://spring.io/projects/spring-cloud-aws[Spring Cloud for Amazon Web Services] at the same time. Spring-AI is compatible with Spring Cloud for Amazon Web Services credential configuration. + === Enable selected Bedrock model NOTE: By default, all models are disabled. You have to enable the chosen Bedrock models explicitly using the `spring.ai.bedrock...enabled=true` property. @@ -73,7 +85,7 @@ Here are the supported `` and `` combinations: |==== | Model | Chat | Chat Streaming | Embedding -| llama2 | Yes | Yes | No +| llama | Yes | Yes | No | jurassic2 | Yes | No | No | cohere | Yes | Yes | Yes | anthropic 2 | Yes | Yes | No @@ -82,7 +94,7 @@ Here are the supported `` and `` combinations: | titan | Yes | Yes | Yes (however, no batch support) |==== -For example, to enable the Bedrock Llama2 Chat client, you need to set `spring.ai.bedrock.llama2.chat.enabled=true`. +For example, to enable the Bedrock Llama Chat client, you need to set `spring.ai.bedrock.llama.chat.enabled=true`. Next, you can use the `spring.ai.bedrock...*` properties to configure each model as provided. @@ -90,7 +102,7 @@ For more information, refer to the documentation below for each supported model. 
* xref:api/chat/bedrock/bedrock-anthropic.adoc[Spring AI Bedrock Anthropic 2 Chat]: `spring.ai.bedrock.anthropic.chat.enabled=true` * xref:api/chat/bedrock/bedrock-anthropic3.adoc[Spring AI Bedrock Anthropic 3 Chat]: `spring.ai.bedrock.anthropic.chat.enabled=true` -* xref:api/chat/bedrock/bedrock-llama2.adoc[Spring AI Bedrock Llama2 Chat]: `spring.ai.bedrock.llama2.chat.enabled=true` +* xref:api/chat/bedrock/bedrock-llama.adoc[Spring AI Bedrock Llama Chat]: `spring.ai.bedrock.llama.chat.enabled=true` * xref:api/chat/bedrock/bedrock-cohere.adoc[Spring AI Bedrock Cohere Chat]: `spring.ai.bedrock.cohere.chat.enabled=true` * xref:api/embeddings/bedrock-cohere-embedding.adoc[Spring AI Bedrock Cohere Embeddings]: `spring.ai.bedrock.cohere.embedding.enabled=true` * xref:api/chat/bedrock/bedrock-titan.adoc[Spring AI Bedrock Titan Chat]: `spring.ai.bedrock.titan.chat.enabled=true` diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic3.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic3.adoc index c0d5f516d82..af1ac613ae2 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic3.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic3.adoc @@ -264,7 +264,7 @@ Flux response = chatClient.stream( The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java[Anthropic3ChatBedrockApi] provides is lightweight Java client on top of AWS Bedrock link:https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html[Anthropic Claude models]. -Client supports the `anthropic.claude-3-sonnet-20240229-v1:0`,`anthropic.claude-3-haiku-20240307-v1:0` and the legacy `anthropic.claude-v2`, `anthropic.claude-v2:1` and `anthropic.claude-instant-v1` models for both synchronous (e.g. 
`chatCompletion()`) and streaming (e.g. `chatCompletionStream()`) responses. +Client supports the `anthropic.claude-3-opus-20240229-v1:0`,`anthropic.claude-3-sonnet-20240229-v1:0`,`anthropic.claude-3-haiku-20240307-v1:0` and the legacy `anthropic.claude-v2`, `anthropic.claude-v2:1` and `anthropic.claude-instant-v1` models for both synchronous (e.g. `chatCompletion()`) and streaming (e.g. `chatCompletionStream()`) responses. Here is a simple snippet how to use the api programmatically: diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama2.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama.adoc similarity index 55% rename from spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama2.adoc rename to spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama.adoc index 7f891e39432..1d04b2b8ffb 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama2.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama.adoc @@ -1,13 +1,13 @@ -= Llama2 Chat += Llama Chat -https://ai.meta.com/llama/[Meta's Llama 2 Chat] is part of the Llama 2 collection of large language models. +https://ai.meta.com/llama/[Meta's Llama Chat] is part of the Llama collection of large language models. It excels in dialogue-based applications with a parameter scale ranging from 7 billion to 70 billion. Leveraging public datasets and over 1 million human annotations, Llama Chat offers context-aware dialogues. -Trained on 2 trillion tokens from public data sources, Llama-2-Chat provides extensive knowledge for insightful conversations. +Trained on 2 trillion tokens from public data sources, Llama-Chat provides extensive knowledge for insightful conversations. 
Rigorous testing, including over 1,000 hours of red-teaming and annotation, ensures both performance and safety, making it a reliable choice for AI-driven dialogues. -The https://aws.amazon.com/bedrock/llama-2/[AWS Llama 2 Model Page] and https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html[Amazon Bedrock User Guide] contains detailed information on how to use the AWS hosted model. +The https://aws.amazon.com/bedrock/llama/[AWS Llama Model Page] and https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html[Amazon Bedrock User Guide] contains detailed information on how to use the AWS hosted model. == Prerequisites @@ -43,15 +43,15 @@ dependencies { TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. -=== Enable Llama2 Chat Support +=== Enable Llama Chat Support -By default the Bedrock Llama2 model is disabled. -To enable it set the `spring.ai.bedrock.llama2.chat.enabled` property to `true`. +By default the Bedrock Llama model is disabled. +To enable it set the `spring.ai.bedrock.llama.chat.enabled` property to `true`. Exporting environment variable is one way to set this configuration property: [source,shell] ---- -export SPRING_AI_BEDROCK_LLAMA2_CHAT_ENABLED=true +export SPRING_AI_BEDROCK_LLAMA_CHAT_ENABLED=true ---- === Chat Properties @@ -69,29 +69,29 @@ The prefix `spring.ai.bedrock.aws` is the property prefix to configure the conne |==== -The prefix `spring.ai.bedrock.llama2.chat` is the property prefix that configures the chat client implementation for Llama2. +The prefix `spring.ai.bedrock.llama.chat` is the property prefix that configures the chat client implementation for Llama. 
[cols="2,5,1"] |==== | Property | Description | Default -| spring.ai.bedrock.llama2.chat.enabled | Enable or disable support for Llama2 | false -| spring.ai.bedrock.llama2.chat.model | The model id to use (See Below) | meta.llama2-70b-chat-v1 -| spring.ai.bedrock.llama2.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0], inclusive. A value closer to 1.0 will produce responses that are more varied, while a value closer to 0.0 will typically result in less surprising responses from the model. This value specifies default to be used by the backend while making the call to the model. | 0.7 -| spring.ai.bedrock.llama2.chat.options.top-p | The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and nucleus sampling. Nucleus sampling considers the smallest set of tokens whose probability sum is at least topP. | AWS Bedrock default -| spring.ai.bedrock.llama2.chat.options.max-gen-len | Specify the maximum number of tokens to use in the generated response. The model truncates the response once the generated text exceeds maxGenLen. | 300 +| spring.ai.bedrock.llama.chat.enabled | Enable or disable support for Llama | false +| spring.ai.bedrock.llama.chat.model | The model id to use (See Below) | meta.llama3-70b-instruct-v1:0 +| spring.ai.bedrock.llama.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0], inclusive. A value closer to 1.0 will produce responses that are more varied, while a value closer to 0.0 will typically result in less surprising responses from the model. This value specifies default to be used by the backend while making the call to the model. | 0.7 +| spring.ai.bedrock.llama.chat.options.top-p | The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and nucleus sampling. Nucleus sampling considers the smallest set of tokens whose probability sum is at least topP. 
| AWS Bedrock default +| spring.ai.bedrock.llama.chat.options.max-gen-len | Specify the maximum number of tokens to use in the generated response. The model truncates the response once the generated text exceeds maxGenLen. | 300 |==== -Look at https://github.com/spring-projects/spring-ai/blob/4ba9a3cd689b9fd3a3805f540debe398a079c6ef/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApi.java#L164[Llama2ChatBedrockApi#Llama2ChatModel] for other model IDs. The other value supported is `meta.llama2-13b-chat-v1`. +Look at https://github.com/spring-projects/spring-ai/blob/4ba9a3cd689b9fd3a3805f540debe398a079c6ef/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java#L164[LlamaChatBedrockApi#LlamaChatModel] for other model IDs. The other value supported is `meta.llama2-13b-chat-v1`. Model ID values can also be found in the https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html[AWS Bedrock documentation for base model IDs]. -TIP: All properties prefixed with `spring.ai.bedrock.llama2.chat.options` can be overridden at runtime by adding a request specific <> to the `Prompt` call. +TIP: All properties prefixed with `spring.ai.bedrock.llama.chat.options` can be overridden at runtime by adding a request specific <<chat-options>> to the `Prompt` call. == Runtime Options [[chat-options]] -The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatOptions.java[BedrockLlama2ChatOptions.java] provides model configurations, such as temperature, topK, topP, etc. +The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java[BedrockLlamaChatOptions.java] provides model configurations, such as temperature, topK, topP, etc. 
-On start-up, the default options can be configured with the `BedrockLlama2ChatClient(api, options)` constructor or the `spring.ai.bedrock.llama2.chat.options.*` properties. +On start-up, the default options can be configured with the `BedrockLlamaChatClient(api, options)` constructor or the `spring.ai.bedrock.llama.chat.options.*` properties. At run-time you can override the default options by adding new, request specific, options to the `Prompt` call. For example to override the default temperature for a specific request: @@ -101,13 +101,13 @@ For example to override the default temperature for a specific request: ChatResponse response = chatClient.call( new Prompt( "Generate the names of 5 famous pirates.", - BedrockLlama2ChatOptions.builder() + BedrockLlamaChatOptions.builder() .withTemperature(0.4) .build() )); ---- -TIP: In addition to the model specific https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatOptions.java[BedrockLlama2ChatOptions] you can use a portable https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java[ChatOptions] instance, created with the https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java[ChatOptionsBuilder#builder()]. 
+TIP: In addition to the model specific https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java[BedrockLlamaChatOptions] you can use a portable https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java[ChatOptions] instance, created with the https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java[ChatOptionsBuilder#builder()]. == Sample Controller @@ -122,13 +122,13 @@ spring.ai.bedrock.aws.timeout=1000ms spring.ai.bedrock.aws.access-key=${AWS_ACCESS_KEY_ID} spring.ai.bedrock.aws.secret-key=${AWS_SECRET_ACCESS_KEY} -spring.ai.bedrock.llama2.chat.enabled=true -spring.ai.bedrock.llama2.chat.options.temperature=0.8 +spring.ai.bedrock.llama.chat.enabled=true +spring.ai.bedrock.llama.chat.options.temperature=0.8 ---- TIP: replace the `regions`, `access-key` and `secret-key` with your AWS credentials. -This will create a `BedrockLlama2ChatClient` implementation that you can inject into your class. +This will create a `BedrockLlamaChatClient` implementation that you can inject into your class. Here is an example of a simple `@Controller` class that uses the chat client for text generations. 
[source,java] @@ -136,10 +136,10 @@ Here is an example of a simple `@Controller` class that uses the chat client for @RestController public class ChatController { - private final BedrockLlama2ChatClient chatClient; + private final BedrockLlamaChatClient chatClient; @Autowired - public ChatController(BedrockLlama2ChatClient chatClient) { + public ChatController(BedrockLlamaChatClient chatClient) { this.chatClient = chatClient; } @@ -158,7 +158,7 @@ public class ChatController { == Manual Configuration -The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java[BedrockLlama2ChatClient] implements the `ChatClient` and `StreamingChatClient` and uses the <> to connect to the Bedrock Anthropic service. +The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatClient.java[BedrockLlamaChatClient] implements the `ChatClient` and `StreamingChatClient` and uses the <<low-level-api>> to connect to the Bedrock Llama service. Add the `spring-ai-bedrock` dependency to your project's Maven `pom.xml` file: @@ -181,18 +181,18 @@ dependencies { TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. 
-Next, create an https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java[BedrockLlama2ChatClient] and use it for text generations: +Next, create an https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatClient.java[BedrockLlamaChatClient] and use it for text generations: [source,java] ---- -Llama2ChatBedrockApi api = new Llama2ChatBedrockApi(Llama2ChatModel.LLAMA2_70B_CHAT_V1.id(), +LlamaChatBedrockApi api = new LlamaChatBedrockApi(LlamaChatModel.LLAMA2_70B_CHAT_V1.id(), EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), Duration.ofMillis(1000L)); -BedrockLlama2ChatClient chatClient = new BedrockLlama2ChatClient(api, - BedrockLlama2ChatOptions.builder() +BedrockLlamaChatClient chatClient = new BedrockLlamaChatClient(api, + BedrockLlamaChatOptions.builder() .withTemperature(0.5f) .withMaxGenLen(100) .withTopP(0.9f).build()); @@ -205,38 +205,38 @@ Flux response = chatClient.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- -== Low-level Llama2ChatBedrockApi Client [[low-level-api]] +== Low-level LlamaChatBedrockApi Client [[low-level-api]] -https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApi.java[Llama2ChatBedrockApi] provides is lightweight Java client on top of AWS Bedrock https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html[Meta Llama 2 and Llama 2 Chat models]. 
+https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java[LlamaChatBedrockApi] provides a lightweight Java client on top of AWS Bedrock https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html[Meta Llama models]. -Following class diagram illustrates the Llama2ChatBedrockApi interface and building blocks: +The following class diagram illustrates the LlamaChatBedrockApi interface and building blocks: -image::bedrock/bedrock-llama2-chat-api.jpg[Llama2ChatBedrockApi Class Diagram] +image::bedrock/bedrock-llama-chat-api.jpg[LlamaChatBedrockApi Class Diagram] -The Llama2ChatBedrockApi supports the `meta.llama2-13b-chat-v1` and `meta.llama2-70b-chat-v1` models for both synchronous (e.g. `chatCompletion()`) and streaming (e.g. `chatCompletionStream()`) responses. +The LlamaChatBedrockApi supports the `meta.llama3-8b-instruct-v1:0`, `meta.llama3-70b-instruct-v1:0`, `meta.llama2-13b-chat-v1` and `meta.llama2-70b-chat-v1` models for both synchronous (e.g. `chatCompletion()`) and streaming (e.g. `chatCompletionStream()`) responses. 
Here is a simple snippet how to use the api programmatically: [source,java] ---- -Llama2ChatBedrockApi llama2ChatApi = new Llama2ChatBedrockApi( - Llama2ChatModel.LLAMA2_70B_CHAT_V1.id(), +LlamaChatBedrockApi llamaChatApi = new LlamaChatBedrockApi( + LlamaChatModel.LLAMA3_70B_INSTRUCT_V1.id(), Region.US_EAST_1.id(), Duration.ofMillis(1000L)); -Llama2ChatRequest request = Llama2ChatRequest.builder("Hello, my name is") +LlamaChatRequest request = LlamaChatRequest.builder("Hello, my name is") .withTemperature(0.9f) .withTopP(0.9f) .withMaxGenLen(20) .build(); -Llama2ChatResponse response = llama2ChatApi.chatCompletion(request); +LlamaChatResponse response = llamaChatApi.chatCompletion(request); // Streaming response -Flux responseStream = llama2ChatApi.chatCompletionStream(request); -List responses = responseStream.collectList().block(); +Flux responseStream = llamaChatApi.chatCompletionStream(request); +List responses = responseStream.collectList().block(); ---- -Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApi.java[Llama2ChatBedrockApi.java]'s JavaDoc for further information. +Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java[LlamaChatBedrockApi.java]'s JavaDoc for further information. 
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc index d3cf341bf1f..7f74dc4a320 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc @@ -9,7 +9,7 @@ The OpenAI API does not call the function directly; instead, the model generates Spring AI provides flexible and user-friendly ways to register and call custom functions. In general, the custom functions need to provide a function `name`, `description`, and the function call `signature` (as JSON schema) to let the model know what arguments the function expects. The `description` helps the model to understand when to call the function. -As a developer, you need to implement a functions that takes the function call arguments sent from the AI model, and respond with the result back to the model. Your function can in turn invoke other 3rd party services to provide the results. +As a developer, you need to implement a function that takes the function call arguments sent from the AI model, and respond with the result back to the model. Your function can in turn invoke other 3rd party services to provide the results. Spring AI makes this as easy as defining a `@Bean` definition that returns a `java.util.Function` and supplying the bean name as an option when invoking the `ChatClient`. 
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc index e1e36a4597d..4bb9acf2566 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc @@ -176,9 +176,9 @@ where fruits are being displayed, possibly for convenience or aesthetic purposes == Sample Controller -https://start.spring.io/[Create] a new Spring Boot project and add the `spring-ai-openai-spring-boot-starter` to your pom (or gradle) dependencies. +https://start.spring.io/[Create] a new Spring Boot project and add the `spring-ai-ollama-spring-boot-starter` to your pom (or gradle) dependencies. -Add a `application.properties` file, under the `src/main/resources` directory, to enable and configure the OpenAi Chat client: +Add a `application.properties` file, under the `src/main/resources` directory, to enable and configure the Ollama Chat client: [source,application.properties] ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc index 7bd008b400f..902d0d59b78 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc @@ -183,7 +183,7 @@ image::spring-ai-chat-completions-clients.jpg[align="center", width="800px"] * xref:api/chat/vertexai-gemini-chat.adoc[Google Vertex AI Gemini Chat Completion] (streaming, multi-modality & function-calling support) * xref:api/bedrock.adoc[Amazon Bedrock] ** xref:api/chat/bedrock/bedrock-cohere.adoc[Cohere Chat Completion] -** xref:api/chat/bedrock/bedrock-llama2.adoc[Llama2 Chat Completion] +** xref:api/chat/bedrock/bedrock-llama.adoc[Llama Chat Completion] ** xref:api/chat/bedrock/bedrock-titan.adoc[Titan Chat Completion] ** 
xref:api/chat/bedrock/bedrock-anthropic.adoc[Anthropic Chat Completion] ** xref:api/chat/bedrock/bedrock-jurassic2.adoc[Jurassic2 Chat Completion] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-titan-embedding.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-titan-embedding.adoc index 6d5d675a06c..b7bc8a74eb7 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-titan-embedding.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-titan-embedding.adoc @@ -81,6 +81,22 @@ The prefix `spring.ai.bedrock.titan.embedding` (defined in `BedrockTitanEmbeddin Supported values are: `amazon.titan-embed-image-v1` and `amazon.titan-embed-text-v1`. Model ID values can also be found in the https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html[AWS Bedrock documentation for base model IDs]. +== Runtime Options [[embedding-options]] + +The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingOptions.java[BedrockTitanEmbeddingOptions.java] provides model configurations, such as `input-type`. +On start-up, the default options can be configured with the `BedrockTitanEmbeddingClient(api).withInputType(type)` method or the `spring.ai.bedrock.titan.embedding.input-type` properties. + +At run-time you can override the default options by adding new, request specific, options to the `EmbeddingRequest` call. 
+For example, to override the default input type for a specific request: + +[source,java] +---- +EmbeddingResponse embeddingResponse = embeddingClient.call( + new EmbeddingRequest(List.of("Hello World", "World is big and salvation is near"), + BedrockTitanEmbeddingOptions.builder() + .withInputType(InputType.TEXT) + .build())); +---- == Sample Controller @@ -154,7 +170,7 @@ Next, create an https://github.com/spring-projects/spring-ai/blob/main/models/sp var titanEmbeddingApi = new TitanEmbeddingBedrockApi( TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1.id(), Region.US_EAST_1.id()); -var embeddingClient new BedrockTitanEmbeddingClient(titanEmbeddingApi); +var embeddingClient = new BedrockTitanEmbeddingClient(titanEmbeddingApi); EmbeddingResponse embeddingResponse = embeddingClient .embedForResponse(List.of("Hello World")); // NOTE titan does not support batch embedding. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc index f1d9eb04a83..75a0d45f3b5 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc @@ -187,14 +187,17 @@ Next, create an `OpenAiEmbeddingClient` instance and use it to compute the simil ---- var openAiApi = new OpenAiApi(System.getenv("OPENAI_API_KEY")); -var embeddingClient = new OpenAiEmbeddingClient(openAiApi) - .withDefaultOptions(OpenAiChatOptions.build() - .withModel("text-embedding-ada-002") - .withUser("user-6") - .build()); +var embeddingClient = new OpenAiEmbeddingClient( + openAiApi, + MetadataMode.EMBED, + OpenAiEmbeddingOptions.builder() + .withModel("text-embedding-ada-002") + .withUser("user-6") + .build(), + RetryUtils.DEFAULT_RETRY_TEMPLATE); EmbeddingResponse embeddingResponse = embeddingClient - .embedForResponse(List.of("Hello World", "World is big and 
salvation is near")); + .embedForResponse(List.of("Hello World", "World is big and salvation is near")); ---- The `OpenAiEmbeddingOptions` provides the configuration information for the embedding requests. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/etl-pipeline.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/etl-pipeline.adoc index 87b237a4cec..d0656f16084 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/etl-pipeline.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/etl-pipeline.adoc @@ -170,10 +170,10 @@ Example: public class MyTikaDocumentReader { @Value("classpath:/word-sample.docx") // This is the word document to load - private Resource resource; + private Resource resource; - List loadText() { - TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(resourceUri); + List loadText() { + TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(resource); return tikaDocumentReader.get(); } } @@ -229,4 +229,4 @@ See xref:api/vectordbs.adoc[Vector DB Documentation] for a full listing. The following class diagram illustrates the ETL interfaces and implementations. // image::etl-class-diagram.jpg[align="center", width="800px"] -image::etl-class-diagram.jpg[align="center"] \ No newline at end of file +image::etl-class-diagram.jpg[align="center"] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/multimodality.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/multimodality.adoc new file mode 100644 index 00000000000..37a43ca9182 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/multimodality.adoc @@ -0,0 +1,64 @@ +[[Multimodality]] += Multimodality API + +Humans process knowledge, simultaneously across multiple modes of data inputs. +The way we learn, our experiences are all multimodal. +We don't have just vision, just audio and just text. 
+ +These foundational principles of learning were articulated by the father of modern education link:https://en.wikipedia.org/wiki/John_Amos_Comenius[John Amos Comenius], in his work, "Orbis Sensualium Pictus", dating back to 1658. + +image::orbis-sensualium-pictus2.jpg[Orbis Sensualium Pictus, align="center"] + +> "All things that are naturally connected ought to be taught in combination" + +Contrary to those principles, in the past, our approach to Machine Learning was often focused on specialized models tailored to process a single modality. +For instance, we developed audio models for tasks like text-to-speech or speech-to-text, and computer vision models for tasks such as object detection and classification. + +However, a new wave of multimodal large language models is starting to emerge. +Models such as OpenAI's GPT-4 Vision, Google's Vertex AI Gemini Pro Vision, Anthropic's Claude3, and the open source offerings LLaVA and BakLLaVA are able to accept multiple inputs, including text, images, audio and video, and generate text responses by integrating these inputs. + +The multimodal large language model (LLM) features enable the models to process and generate text in conjunction with other modalities such as images, audio, or video. + +== Spring AI Multimodality + +Multimodality refers to a model’s ability to simultaneously understand and process information from various sources, including text, images, audio, and other data formats. + +The Spring AI Message API provides all necessary abstractions to support multimodal LLMs. + +image::spring-ai-message-api.jpg[Spring AI Message API, width=600, align="center"] + +The Message’s `content` field is used primarily for text inputs, while the optional `media` field allows adding one or more additional pieces of content of different modalities, such as images, audio and video. +The `MimeType` specifies the modality type. +Depending on the LLM used, the Media's data field can be either encoded raw media content or a URI to the content. 
+ +NOTE: The media field is currently applicable only for user input messages (e.g., `UserMessage`). It does not hold significance for system messages. The `AssistantMessage`, which includes the LLM response, provides text content only. To generate non-text media outputs, you should utilize one of the dedicated, single-modality models. + + +For example, we can take the following picture (*multimodal.test.png*) as an input and ask the LLM to explain what it sees in the picture. + +image::multimodal.test.png[Multimodal Test Image, 200, 200, align="left"] + +For most of the multimodal LLMs, the Spring AI code would look something like this: + +[source,java] +---- +byte[] imageData = new ClassPathResource("/multimodal.test.png").getContentAsByteArray(); + +var userMessage = new UserMessage( + "Explain what do you see in this picture?", // content + List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); // media + +ChatResponse response = chatClient.call(new Prompt(List.of(userMessage))); +---- + +and produce a response like: + +> This is an image of a fruit bowl with a simple design. The bowl is made of metal with curved wire edges that create an open structure, allowing the fruit to be visible from all angles. Inside the bowl, there are two yellow bananas resting on top of what appears to be a red apple. The bananas are slightly overripe, as indicated by the brown spots on their peels. The bowl has a metal ring at the top, likely to serve as a handle for carrying. The bowl is placed on a flat surface with a neutral-colored background that provides a clear view of the fruit inside. 
+ +The latest version of Spring AI provides multimodal support for the following Chat Clients: + +* xref:api/chat/openai-chat.adoc#_multimodal[Open AI - (GPT-4-Vision model)] +* xref:api/chat/ollama-chat.adoc#_multimodal[Ollama - (LLaVA and BakLLaVA models)] +* xref:api/chat/vertexai-gemini-chat.adoc#_multimodal[Vertex AI Gemini - (gemini-pro-vision model)] +* xref:api/chat/anthropic-chat.adoc#_multimodal[Anthropic Claude 3] +* xref:api/chat/bedrock/bedrock-anthropic3.adoc#_multimodal[AWS Bedrock Anthropic Claude 3] \ No newline at end of file diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/apache-cassandra.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/apache-cassandra.adoc index a264c08c1a9..51c6110cb0f 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/apache-cassandra.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/apache-cassandra.adoc @@ -4,9 +4,9 @@ This section walks you through setting up `CassandraVectorStore` to store docume == What is Apache Cassandra ? -link:https://cassandra.apache.org[Apache Cassandra] is a true open source distributed database reknown for scalability and high availability without compromising performance. +link:https://cassandra.apache.org[Apache Cassandra®] is a true open source distributed database renowned for linear scalability, proven fault-tolerance and low latency, making it the perfect platform for mission-critical transactional data. -Linear scalability, proven fault-tolerance and low latency on commodity hardware makes it the perfect platform for mission-critical data. Its Vector Similarity Search (VSS) is based on the JVector library that ensures best-in-class performance and relevancy. +Its Vector Similarity Search (VSS) is based on the JVector library that ensures best-in-class performance and relevancy. 
A vector search in Apache Cassandra is done as simply as: ``` @@ -15,9 +15,13 @@ SELECT content FROM table ORDER BY content_vector ANN OF query_embedding ; More docs on this can be read https://cassandra.apache.org/doc/latest/cassandra/getting-started/vector-search-quickstart.html[here]. -The Spring AI Cassandra Vector Store is designed to work for both brand new RAG applications as well as being able to be retrofitted on top of existing data and tables. This vector store may also equally be used for non-RAG non_AI use-cases, e.g. semantic searcing in an existing database. The Vector Store will automatically create, or enhance, the schema as needed according to its configuration. If you don't want the schema modifications, configure the store with `disallowSchemaChanges`. +This Spring AI Vector Store is designed to work for both brand new RAG applications as well as being able to be retrofitted on top of existing data and tables. -== What is JVector Vector Search ? +The store can also be used for non-RAG use-cases in an existing database, e.g. semantic searches, geo-proximity searches, etc. + +The store will automatically create, or enhance, the schema as needed according to its configuration. If you don't want the schema modifications, configure the store with `disallowSchemaChanges`. + +== What is JVector ? link:https://github.com/jbellis/jvector[JVector] is a pure Java embedded vector search engine. @@ -70,13 +74,6 @@ Add these dependencies to your project: TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. -* If for example you want to use the OpenAI modules, remember to provide your OpenAI API Key. 
Set it as an environment variable like so: - -[source,bash] ----- -export SPRING_AI_OPENAI_API_KEY='Your_OpenAI_API_Key' ----- - == Usage @@ -93,21 +90,14 @@ public VectorStore vectorStore(EmbeddingClient embeddingClient) { } ---- -NOTE: It is more convenient and preferred to create the `CassandraVectorStore` as a Bean. -But if you decide you can create it manually. - [NOTE] ==== -The default configuration connects to Cassandra at localhost:9042 and will automatically create the default schema at `springframework_ai_vector.springframework_ai_vector_store`. - -Please see `CassandraVectorStoreConfig.Builder` for all the configuration options. +The default configuration connects to Cassandra at `localhost:9042` and will automatically create a default schema in keyspace `springframework`, table `ai_vector_store`. ==== [NOTE] ==== -The Cassandra Java Driver is easiest configured via the `application.conf` file on the classpath. - -More info can be found link: https://github.com/apache/cassandra-java-driver/tree/4.x/manual/core/configuration[here]. +The Cassandra Java Driver is easiest configured via an `application.conf` file on the classpath. More info https://github.com/apache/cassandra-java-driver/tree/4.x/manual/core/configuration[here]. ==== Then in your main code, create some documents: @@ -148,7 +138,7 @@ List results = vectorStore.similaritySearch( === Metadata filtering -You can leverage the generic, portable link:https://docs.spring.io/spring-ai/reference/api/vectordbs.html#_metadata_filters[metadata filters] with the CassandraVectorStore as well. Metadata fields must be configured in `CassandraVectorStoreConfig`. +You can leverage the generic, portable link:https://docs.spring.io/spring-ai/reference/api/vectordbs.html#_metadata_filters[metadata filters] with the CassandraVectorStore as well. Metadata columns must be configured in `CassandraVectorStoreConfig`. 
For example, you can use either the text expression language: @@ -173,7 +163,9 @@ vectorStore.similaritySearch( The portable filter expressions get automatically converted into link:https://cassandra.apache.org/doc/latest/cassandra/developing/cql/index.html[CQL queries]. -Metadata fields to be searchable need to be either primary key columns or SAI indexed. To do this configure the metadata field with the `SchemaColumnTags.INDEXED`. +For metadata columns to be searchable, they must be either primary keys or SAI indexed. To make non-primary-key columns indexed, configure the metadata column with `SchemaColumnTags.INDEXED`. + + == Advanced Example: Vector Store ontop full Wikipedia dataset @@ -187,7 +179,8 @@ Create the schema in the Cassandra database first: [source,bash] ---- -wget https://raw.githubusercontent.com/datastax-labs/colbert-wikipedia-data/main/schema.cql -O colbert-wikipedia-schema.cql +wget https://s.apache.org/colbert-wikipedia-schema-cql -O colbert-wikipedia-schema.cql + cqlsh -f colbert-wikipedia-schema.cql ---- @@ -212,14 +205,14 @@ public CassandraVectorStore store(EmbeddingClient embeddingClient) { .withTableName("articles") .withPartitionKeys(partitionColumns) .withClusteringKeys(clusteringColumns) - .withContentFieldName("body") - .withEmbeddingFieldName("all_minilm_l6_v2_embedding") + .withContentColumnName("body") + .withEmbeddingColumnName("all_minilm_l6_v2_embedding") .withIndexName("all_minilm_l6_v2_ann") .disallowSchemaChanges() - .addMetadataFields(extraColumns) + .addMetadataColumns(extraColumns) .withPrimaryKeyTranslator((List primaryKeys) -> { - // the deliminator used to join fields together into the document's id - // is arbitary, here "§¶" is used + // the delimiter used to join fields together into the document's id is arbitrary + // here "§¶" is used if (primaryKeys.isEmpty()) { return "test§¶0"; } @@ -243,8 +236,11 @@ public EmbeddingClient embeddingClient() { } ---- + +== Complete wikipedia dataset + And, if you would like 
to load the full wikipedia dataset. -First download the `simplewiki-sstable.tar` from this link https://drive.google.com/file/d/1CcMMsj8jTKRVGep4A7hmOSvaPepsaKYP/view?usp=share_link . This will take a while, the file is tens of GBs. +First download the `simplewiki-sstable.tar` from this link https://s.apache.org/simplewiki-sstable-tar . This will take a while, the file is tens of GBs. [source,bash] ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc new file mode 100644 index 00000000000..9bea3eb060e --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc @@ -0,0 +1,185 @@ += Elasticsearch + +This section walks you through setting up the Elasticsearch `VectorStore` to store document embeddings and perform similarity searches. + +link:https://www.elastic.co/elasticsearch[Elasticsearch] is an open source search and analytics engine based on the Apache Lucene library. + +== Prerequisites + +* A running Elasticsearch instance. The following options are available: +** link:https://hub.docker.com/_/elasticsearch/[Docker] +** link:https://www.elastic.co/guide/en/elasticsearch/reference/current/install-elasticsearch.html#elasticsearch-install-packages[Self-Managed Elasticsearch] +** link:https://www.elastic.co/cloud/elasticsearch-service/signup?page=docs&placement=docs-body[Elastic Cloud] + +== Dependencies + +Add the Elasticsearch Vector Store dependency to your project: + +[source,xml] +---- + + org.springframework.ai + spring-ai-elasticsearch-store + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-elasticsearch-store' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. 
+ +== Configuration + +To connect to Elasticsearch and use the `ElasticsearchVectorStore`, you need to provide access details for your instance. +A simple configuration can either be provided via Spring Boot's `application.yml`, + +[source,yaml] +---- +spring: + elasticsearch: + uris: <elasticsearch instance URIs> + username: <elasticsearch username> + password: <elasticsearch password> +# API key if needed, e.g. OpenAI + ai: + openai: + api: + key: <api-key> +---- + +environment variables, + +[source,bash] +---- +export SPRING_ELASTICSEARCH_URIS=<elasticsearch instance URIs> +export SPRING_ELASTICSEARCH_USERNAME=<elasticsearch username> +export SPRING_ELASTICSEARCH_PASSWORD=<elasticsearch password> +# API key if needed, e.g. OpenAI +export SPRING_AI_OPENAI_API_KEY=<api-key> +---- + +or can be a mix of those. +For example, you may want to store your password as an environment variable but keep the rest in the plain `application.yml` file. + +NOTE: If you choose to create a shell script for ease in future work, be sure to run it prior to starting your application by "sourcing" the file, i.e. `source <your_script_name>.sh`. + +Spring Boot's auto-configuration feature for the Elasticsearch RestClient will create a bean instance that will be used by the `ElasticsearchVectorStore`. + +== Auto-configuration + +Spring AI provides Spring Boot auto-configuration for the Elasticsearch Vector Store. +To enable it, add the following dependency to your project's Maven `pom.xml` file: + +[source,xml] +---- +<dependency> + <groupId>org.springframework.ai</groupId> + <artifactId>spring-ai-elasticsearch-store-spring-boot-starter</artifactId> +</dependency> +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-elasticsearch-store-spring-boot-starter' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +Please have a look at the list of <<elasticsearchvector-properties,configuration parameters>> for the vector store to learn about the default values and configuration options. 
+ +Here is an example of the needed bean: + +[source,java] +---- +@Bean +public EmbeddingClient embeddingClient() { + // Can be any other EmbeddingClient implementation + return new OpenAiEmbeddingClient(new OpenAiApi(System.getenv("SPRING_AI_OPENAI_API_KEY"))); +} +---- + +In cases where the Spring Boot auto-configured Elasticsearch `RestClient` bean is not what you want or need, you can still define your own bean. +Please read the link:https://www.elastic.co/guide/en/elasticsearch/client/java-api-client/current/java-rest-low-usage-initialization.html[Elasticsearch Documentation] +for more in-depth information about the configuration of a custom RestClient. + +[source,java] +---- +@Bean +public RestClient restClient() { + RestClientBuilder builder = RestClient.builder(new HttpHost("<host>", 9200, "http")); + Header[] defaultHeaders = new Header[] { new BasicHeader("Authorization", "Basic <encoded username and password>") }; + builder.setDefaultHeaders(defaultHeaders); + return builder.build(); +} +---- + +Now you can auto-wire the `ElasticsearchVectorStore` as a vector store in your application. + +== Metadata Filtering + +You can leverage the generic, portable xref:api/vectordbs.adoc#metadata-filters[metadata filters] with Elasticsearch as well. 
+ +For example, you can use either the text expression language: + +[source,java] +---- +vectorStore.similaritySearch(SearchRequest.defaults() + .withQuery("The World") + .withTopK(TOP_K) + .withSimilarityThreshold(SIMILARITY_THRESHOLD) + .withFilterExpression("author in ['john', 'jill'] && 'article_type' == 'blog'")); +---- + +or programmatically using the `Filter.Expression` DSL: + +[source,java] +---- +FilterExpressionBuilder b = new FilterExpressionBuilder(); + +vectorStore.similaritySearch(SearchRequest.defaults() + .withQuery("The World") + .withTopK(TOP_K) + .withSimilarityThreshold(SIMILARITY_THRESHOLD) + .withFilterExpression(b.and( + b.in("author", "john", "jill"), + b.eq("article_type", "blog")).build())); +---- + +NOTE: Those (portable) filter expressions get automatically converted into the proprietary Elasticsearch `WHERE` link:https://www.elastic.co/guide/en/elasticsearch/reference/current/sql-syntax-select.html#sql-syntax-where[filter expressions]. + +For example, this portable filter expression: + +[source,sql] +---- +author in ['john', 'jill'] && 'article_type' == 'blog' +---- + +is converted into the proprietary Elasticsearch filter format: + +[source,text] +---- +(metadata.author:john OR jill) AND metadata.article_type:blog +---- + +[[elasticsearchvector-properties]] +== ElasticsearchVectorStore Properties + +You can use the following properties in your Spring Boot configuration to customize the Elasticsearch vector store. 
+ +|=== +|Property |Default Value + +|`spring.ai.vectorstore.elasticsearch.index-name` +|spring-ai-document-index +|=== + diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/qdrant.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/qdrant.adoc index 3b8caba0ab8..a5d4b871a16 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/qdrant.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/qdrant.adoc @@ -90,7 +90,7 @@ List documents = List.of( new Document("You walk forward facing the past and you turn back toward the future.", Map.of("meta2", "meta2"))); // Add the documents to Qdrant -vectorStore.add(List.of(document)); +vectorStore.add(documents); // Retrieve documents similar to a query List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/getting-started.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/getting-started.adoc index 486e6e1e4ec..e2f04a0796f 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/getting-started.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/getting-started.adoc @@ -147,7 +147,7 @@ Each of the following sections in the documentation shows which dependencies you ** xref:api/chat/vertexai-gemini-chat.adoc[Google Vertex AI Gemini Chat Completion] (streaming, multi-modality & function-calling support) ** xref:api/bedrock.adoc[Amazon Bedrock] *** xref:api/chat/bedrock/bedrock-cohere.adoc[Cohere Chat Completion] -*** xref:api/chat/bedrock/bedrock-llama2.adoc[Llama2 Chat Completion] +*** xref:api/chat/bedrock/bedrock-llama.adoc[Llama Chat Completion] *** xref:api/chat/bedrock/bedrock-titan.adoc[Titan Chat Completion] *** xref:api/chat/bedrock/bedrock-anthropic.adoc[Anthropic Chat Completion] *** xref:api/chat/bedrock/bedrock-jurassic2.adoc[Jurassic2 Chat Completion] diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml 
b/spring-ai-spring-boot-autoconfigure/pom.xml index d48275f3d83..a01b11affc9 100644 --- a/spring-ai-spring-boot-autoconfigure/pom.xml +++ b/spring-ai-spring-boot-autoconfigure/pom.xml @@ -267,6 +267,13 @@ true + + org.springframework.ai + spring-ai-opensearch-store + ${project.parent.version} + true + + @@ -354,6 +361,13 @@ test + + org.opensearch + opensearch-testcontainers + ${testcontainers.opensearch.version} + test + + org.skyscreamer jsonassert diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java index 9cfd634d13b..46ee804056d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java @@ -19,6 +19,9 @@ import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; +import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.context.properties.EnableConfigurationProperties; @@ -28,6 +31,7 @@ /** * @author Christian Tzolov + * @author Wei Jiang */ @Configuration @EnableConfigurationProperties({ BedrockAwsConnectionProperties.class }) @@ -45,4 +49,38 @@ public AwsCredentialsProvider credentialsProvider(BedrockAwsConnectionProperties return DefaultCredentialsProvider.create(); } + @Bean + @ConditionalOnMissingBean + 
public AwsRegionProvider regionProvider(BedrockAwsConnectionProperties properties) { + + if (StringUtils.hasText(properties.getRegion())) { + return new StaticRegionProvider(properties.getRegion()); + } + + return DefaultAwsRegionProviderChain.builder().build(); + } + + /** + * @author Wei Jiang + */ + static class StaticRegionProvider implements AwsRegionProvider { + + private final Region region; + + public StaticRegionProvider(String region) { + try { + this.region = Region.of(region); + } + catch (IllegalArgumentException e) { + throw new IllegalArgumentException("The region '" + region + "' is not a valid region!", e); + } + } + + @Override + public Region getRegion() { + return this.region; + } + + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java index 9a74161b050..7b3d4d2d545 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java @@ -21,6 +21,7 @@ import org.springframework.ai.bedrock.anthropic.BedrockAnthropicChatClient; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi; import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; @@ -28,6 +29,7 @@ import org.springframework.context.annotation.Bean; 
import org.springframework.context.annotation.Import; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Anthropic Chat Client. @@ -35,6 +37,7 @@ * Leverages the Spring Cloud AWS to resolve the {@link AwsCredentialsProvider}. * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ @AutoConfiguration @@ -46,16 +49,18 @@ public class BedrockAnthropicChatAutoConfiguration { @Bean @ConditionalOnMissingBean + @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class }) public AnthropicChatBedrockApi anthropicApi(AwsCredentialsProvider credentialsProvider, - BedrockAnthropicChatProperties properties, BedrockAwsConnectionProperties awsProperties) { - return new AnthropicChatBedrockApi(properties.getModel(), credentialsProvider, awsProperties.getRegion(), + AwsRegionProvider regionProvider, BedrockAnthropicChatProperties properties, + BedrockAwsConnectionProperties awsProperties) { + return new AnthropicChatBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(), new ObjectMapper(), awsProperties.getTimeout()); } @Bean + @ConditionalOnBean(AnthropicChatBedrockApi.class) public BedrockAnthropicChatClient anthropicChatClient(AnthropicChatBedrockApi anthropicApi, BedrockAnthropicChatProperties properties) { - return new BedrockAnthropicChatClient(anthropicApi, properties.getOptions()); } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java index 31f18597ca7..60e5cdce69c 100644 --- 
a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java @@ -21,6 +21,7 @@ import org.springframework.ai.bedrock.anthropic3.BedrockAnthropic3ChatClient; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; @@ -28,6 +29,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Anthropic Chat Client. @@ -35,6 +37,7 @@ * Leverages the Spring Cloud AWS to resolve the {@link AwsCredentialsProvider}. 
* * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ @AutoConfiguration @@ -46,14 +49,17 @@ public class BedrockAnthropic3ChatAutoConfiguration { @Bean @ConditionalOnMissingBean - public Anthropic3ChatBedrockApi anthropicApi(AwsCredentialsProvider credentialsProvider, - BedrockAnthropic3ChatProperties properties, BedrockAwsConnectionProperties awsProperties) { - return new Anthropic3ChatBedrockApi(properties.getModel(), credentialsProvider, awsProperties.getRegion(), + @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class }) + public Anthropic3ChatBedrockApi anthropic3Api(AwsCredentialsProvider credentialsProvider, + AwsRegionProvider regionProvider, BedrockAnthropic3ChatProperties properties, + BedrockAwsConnectionProperties awsProperties) { + return new Anthropic3ChatBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(), new ObjectMapper(), awsProperties.getTimeout()); } @Bean - public BedrockAnthropic3ChatClient anthropicChatClient(Anthropic3ChatBedrockApi anthropicApi, + @ConditionalOnBean(Anthropic3ChatBedrockApi.class) + public BedrockAnthropic3ChatClient anthropic3ChatClient(Anthropic3ChatBedrockApi anthropicApi, BedrockAnthropic3ChatProperties properties) { return new BedrockAnthropic3ChatClient(anthropicApi, properties.getOptions()); } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java index 3ee34f90fa1..66b6d5e5e77 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java @@ -21,6 +21,7 @@ import 
org.springframework.ai.bedrock.cohere.BedrockCohereChatClient; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi; import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; @@ -28,11 +29,13 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Cohere Chat Client. * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ @AutoConfiguration @@ -44,13 +47,16 @@ public class BedrockCohereChatAutoConfiguration { @Bean @ConditionalOnMissingBean + @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class }) public CohereChatBedrockApi cohereChatApi(AwsCredentialsProvider credentialsProvider, - BedrockCohereChatProperties properties, BedrockAwsConnectionProperties awsProperties) { - return new CohereChatBedrockApi(properties.getModel(), credentialsProvider, awsProperties.getRegion(), + AwsRegionProvider regionProvider, BedrockCohereChatProperties properties, + BedrockAwsConnectionProperties awsProperties) { + return new CohereChatBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(), new ObjectMapper(), awsProperties.getTimeout()); } @Bean + @ConditionalOnBean(CohereChatBedrockApi.class) public BedrockCohereChatClient cohereChatClient(CohereChatBedrockApi cohereChatApi, BedrockCohereChatProperties properties) { diff --git 
a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java index 95ecba88858..76e3ea6033d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java @@ -17,12 +17,14 @@ import com.fasterxml.jackson.databind.ObjectMapper; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; import org.springframework.ai.bedrock.cohere.BedrockCohereEmbeddingClient; import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi; import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; @@ -34,6 +36,7 @@ * {@link AutoConfiguration Auto-configuration} for Bedrock Cohere Embedding Client. 
* * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ @AutoConfiguration @@ -45,14 +48,17 @@ public class BedrockCohereEmbeddingAutoConfiguration { @Bean @ConditionalOnMissingBean + @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class }) public CohereEmbeddingBedrockApi cohereEmbeddingApi(AwsCredentialsProvider credentialsProvider, - BedrockCohereEmbeddingProperties properties, BedrockAwsConnectionProperties awsProperties) { - return new CohereEmbeddingBedrockApi(properties.getModel(), credentialsProvider, awsProperties.getRegion(), + AwsRegionProvider regionProvider, BedrockCohereEmbeddingProperties properties, + BedrockAwsConnectionProperties awsProperties) { + return new CohereEmbeddingBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(), new ObjectMapper(), awsProperties.getTimeout()); } @Bean @ConditionalOnMissingBean + @ConditionalOnBean(CohereEmbeddingBedrockApi.class) public BedrockCohereEmbeddingClient cohereEmbeddingClient(CohereEmbeddingBedrockApi cohereEmbeddingApi, BedrockCohereEmbeddingProperties properties) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java index f7c50657238..e8266a0e417 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java @@ -22,6 +22,7 @@ import org.springframework.ai.bedrock.jurassic2.BedrockAi21Jurassic2ChatClient; import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi; import 
org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; @@ -29,11 +30,13 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Jurassic2 Chat Client. * * @author Ahmed Yousri + * @author Wei Jiang * @since 1.0.0 */ @AutoConfiguration @@ -46,13 +49,16 @@ public class BedrockAi21Jurassic2ChatAutoConfiguration { @Bean @ConditionalOnMissingBean + @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class }) public Ai21Jurassic2ChatBedrockApi ai21Jurassic2ChatBedrockApi(AwsCredentialsProvider credentialsProvider, - BedrockAi21Jurassic2ChatProperties properties, BedrockAwsConnectionProperties awsProperties) { - return new Ai21Jurassic2ChatBedrockApi(properties.getModel(), credentialsProvider, awsProperties.getRegion(), + AwsRegionProvider regionProvider, BedrockAi21Jurassic2ChatProperties properties, + BedrockAwsConnectionProperties awsProperties) { + return new Ai21Jurassic2ChatBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(), new ObjectMapper(), awsProperties.getTimeout()); } @Bean + @ConditionalOnBean(Ai21Jurassic2ChatBedrockApi.class) public BedrockAi21Jurassic2ChatClient jurassic2ChatClient(Ai21Jurassic2ChatBedrockApi ai21Jurassic2ChatBedrockApi, BedrockAi21Jurassic2ChatProperties properties) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfiguration.java 
b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfiguration.java similarity index 57% rename from spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfiguration.java rename to spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfiguration.java index 314e3671be2..9293acc84f2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfiguration.java @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.bedrock.llama2; +package org.springframework.ai.autoconfigure.bedrock.llama; import com.fasterxml.jackson.databind.ObjectMapper; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; -import org.springframework.ai.bedrock.llama2.BedrockLlama2ChatClient; -import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi; +import org.springframework.ai.bedrock.llama.BedrockLlamaChatClient; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi; import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import 
org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; @@ -31,33 +33,35 @@ import org.springframework.context.annotation.Import; /** - * {@link AutoConfiguration Auto-configuration} for Bedrock Llama2 Chat Client. + * {@link AutoConfiguration Auto-configuration} for Bedrock Llama Chat Client. * * Leverages the Spring Cloud AWS to resolve the {@link AwsCredentialsProvider}. * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ @AutoConfiguration -@ConditionalOnClass(Llama2ChatBedrockApi.class) -@EnableConfigurationProperties({ BedrockLlama2ChatProperties.class, BedrockAwsConnectionProperties.class }) -@ConditionalOnProperty(prefix = BedrockLlama2ChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true") +@ConditionalOnClass(LlamaChatBedrockApi.class) +@EnableConfigurationProperties({ BedrockLlamaChatProperties.class, BedrockAwsConnectionProperties.class }) +@ConditionalOnProperty(prefix = BedrockLlamaChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true") @Import(BedrockAwsConnectionConfiguration.class) -public class BedrockLlama2ChatAutoConfiguration { +public class BedrockLlamaChatAutoConfiguration { @Bean @ConditionalOnMissingBean - public Llama2ChatBedrockApi llama2Api(AwsCredentialsProvider credentialsProvider, - BedrockLlama2ChatProperties properties, BedrockAwsConnectionProperties awsProperties) { - return new Llama2ChatBedrockApi(properties.getModel(), credentialsProvider, awsProperties.getRegion(), + @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class }) + public LlamaChatBedrockApi llamaApi(AwsCredentialsProvider credentialsProvider, AwsRegionProvider regionProvider, + BedrockLlamaChatProperties properties, BedrockAwsConnectionProperties awsProperties) { + return new LlamaChatBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(), new ObjectMapper(), awsProperties.getTimeout()); } @Bean - public BedrockLlama2ChatClient 
llama2ChatClient(Llama2ChatBedrockApi llama2Api, - BedrockLlama2ChatProperties properties) { + @ConditionalOnBean(LlamaChatBedrockApi.class) + public BedrockLlamaChatClient llamaChatClient(LlamaChatBedrockApi llamaApi, BedrockLlamaChatProperties properties) { - return new BedrockLlama2ChatClient(llama2Api, properties.getOptions()); + return new BedrockLlamaChatClient(llamaApi, properties.getOptions()); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatProperties.java similarity index 63% rename from spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatProperties.java rename to spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatProperties.java index a81f3a4af0c..f93b65aca32 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatProperties.java @@ -13,36 +13,36 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.springframework.ai.autoconfigure.bedrock.llama2; +package org.springframework.ai.autoconfigure.bedrock.llama; -import org.springframework.ai.bedrock.llama2.BedrockLlama2ChatOptions; -import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatModel; +import org.springframework.ai.bedrock.llama.BedrockLlamaChatOptions; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; /** - * Configuration properties for Bedrock Llama2. + * Configuration properties for Bedrock Llama. * * @author Christian Tzolov * @since 0.8.0 */ -@ConfigurationProperties(BedrockLlama2ChatProperties.CONFIG_PREFIX) -public class BedrockLlama2ChatProperties { +@ConfigurationProperties(BedrockLlamaChatProperties.CONFIG_PREFIX) +public class BedrockLlamaChatProperties { - public static final String CONFIG_PREFIX = "spring.ai.bedrock.llama2.chat"; + public static final String CONFIG_PREFIX = "spring.ai.bedrock.llama.chat"; /** - * Enable Bedrock Llama2 chat client. Disabled by default. + * Enable Bedrock Llama chat client. Disabled by default. */ private boolean enabled = false; /** - * The generative id to use. See the {@link Llama2ChatModel} for the supported models. + * The generative id to use. See the {@link LlamaChatModel} for the supported models. 
*/ - private String model = Llama2ChatModel.LLAMA2_70B_CHAT_V1.id(); + private String model = LlamaChatModel.LLAMA3_70B_INSTRUCT_V1.id(); @NestedConfigurationProperty - private BedrockLlama2ChatOptions options = BedrockLlama2ChatOptions.builder() + private BedrockLlamaChatOptions options = BedrockLlamaChatOptions.builder() .withTemperature(0.7f) .withMaxGenLen(300) .build(); @@ -63,11 +63,11 @@ public void setModel(String model) { this.model = model; } - public BedrockLlama2ChatOptions getOptions() { + public BedrockLlamaChatOptions getOptions() { return this.options; } - public void setOptions(BedrockLlama2ChatOptions options) { + public void setOptions(BedrockLlamaChatOptions options) { this.options = options; } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java index e24ec3696ab..67995b9e39c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java @@ -21,6 +21,7 @@ import org.springframework.ai.bedrock.titan.BedrockTitanChatClient; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi; import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; @@ -28,11 +29,13 @@ import org.springframework.context.annotation.Bean; import 
org.springframework.context.annotation.Import; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Titan Chat Client. * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ @AutoConfiguration @@ -44,14 +47,16 @@ public class BedrockTitanChatAutoConfiguration { @Bean @ConditionalOnMissingBean + @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class }) public TitanChatBedrockApi titanChatBedrockApi(AwsCredentialsProvider credentialsProvider, - BedrockTitanChatProperties properties, BedrockAwsConnectionProperties awsProperties) { - - return new TitanChatBedrockApi(properties.getModel(), credentialsProvider, awsProperties.getRegion(), + AwsRegionProvider regionProvider, BedrockTitanChatProperties properties, + BedrockAwsConnectionProperties awsProperties) { + return new TitanChatBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(), new ObjectMapper(), awsProperties.getTimeout()); } @Bean + @ConditionalOnBean(TitanChatBedrockApi.class) public BedrockTitanChatClient titanChatClient(TitanChatBedrockApi titanChatApi, BedrockTitanChatProperties properties) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java index f36c6e427fd..5ea79d4513a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java @@ -17,12 +17,14 @@ import com.fasterxml.jackson.databind.ObjectMapper; 
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingClient; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi; import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; @@ -34,6 +36,7 @@ * {@link AutoConfiguration Auto-configuration} for Bedrock Titan Embedding Client. * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ @AutoConfiguration @@ -45,14 +48,17 @@ public class BedrockTitanEmbeddingAutoConfiguration { @Bean @ConditionalOnMissingBean + @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class }) public TitanEmbeddingBedrockApi titanEmbeddingBedrockApi(AwsCredentialsProvider credentialsProvider, - BedrockTitanEmbeddingProperties properties, BedrockAwsConnectionProperties awsProperties) { - return new TitanEmbeddingBedrockApi(properties.getModel(), credentialsProvider, awsProperties.getRegion(), + AwsRegionProvider regionProvider, BedrockTitanEmbeddingProperties properties, + BedrockAwsConnectionProperties awsProperties) { + return new TitanEmbeddingBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(), new ObjectMapper(), awsProperties.getTimeout()); } @Bean @ConditionalOnMissingBean + @ConditionalOnBean(TitanEmbeddingBedrockApi.class) public BedrockTitanEmbeddingClient titanEmbeddingClient(TitanEmbeddingBedrockApi titanEmbeddingApi, 
BedrockTitanEmbeddingProperties properties) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraConnectionDetails.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraConnectionDetails.java deleted file mode 100644 index b67f90f6ac9..00000000000 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraConnectionDetails.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright 2024 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.springframework.ai.autoconfigure.vectorstore.cassandra; - -import java.net.InetSocketAddress; -import java.util.List; - -import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; - -/** - * @author Mick Semb Wever - * @since 1.0.0 - */ -public interface CassandraConnectionDetails extends ConnectionDetails { - - boolean hasCassandraContactPoints(); - - List getCassandraContactPoints(); - - boolean hasCassandraLocalDatacenter(); - - String getCassandraLocalDatacenter(); - -} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java index 1bb93f410a3..23a077ee340 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java @@ -15,16 +15,17 @@ */ package org.springframework.ai.autoconfigure.vectorstore.cassandra; -import java.net.InetSocketAddress; -import java.util.Arrays; -import java.util.List; +import java.time.Duration; -import com.google.common.base.Preconditions; +import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import org.springframework.ai.embedding.EmbeddingClient; import org.springframework.ai.vectorstore.CassandraVectorStore; import org.springframework.ai.vectorstore.CassandraVectorStoreConfig; import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration; +import 
org.springframework.boot.autoconfigure.cassandra.DriverConfigLoaderBuilderCustomizer; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.context.properties.EnableConfigurationProperties; @@ -34,37 +35,24 @@ * @author Mick Semb Wever * @since 1.0.0 */ -@AutoConfiguration -@ConditionalOnClass({ CassandraVectorStore.class, EmbeddingClient.class }) +@AutoConfiguration(after = CassandraAutoConfiguration.class) +@ConditionalOnClass({ CassandraVectorStore.class, CqlSession.class }) @EnableConfigurationProperties(CassandraVectorStoreProperties.class) public class CassandraVectorStoreAutoConfiguration { - @Bean - @ConditionalOnMissingBean(CassandraConnectionDetails.class) - public PropertiesCassandraConnectionDetails cassandraConnectionDetails(CassandraVectorStoreProperties properties) { - return new PropertiesCassandraConnectionDetails(properties); - } - @Bean @ConditionalOnMissingBean public CassandraVectorStore vectorStore(EmbeddingClient embeddingClient, CassandraVectorStoreProperties properties, - CassandraConnectionDetails cassandraConnectionDetails) { + CqlSession cqlSession) { - var builder = CassandraVectorStoreConfig.builder(); - if (cassandraConnectionDetails.hasCassandraContactPoints()) { - for (InetSocketAddress contactPoint : cassandraConnectionDetails.getCassandraContactPoints()) { - builder = builder.addContactPoint(contactPoint); - } - } - if (cassandraConnectionDetails.hasCassandraLocalDatacenter()) { - builder = builder.withLocalDatacenter(cassandraConnectionDetails.getCassandraLocalDatacenter()); - } + var builder = CassandraVectorStoreConfig.builder().withCqlSession(cqlSession); builder = builder.withKeyspaceName(properties.getKeyspace()) .withTableName(properties.getTable()) - .withContentColumnName(properties.getContentFieldName()) - .withEmbeddingColumnName(properties.getEmbeddingFieldName()) - 
.withIndexName(properties.getIndexName()); + .withContentColumnName(properties.getContentColumnName()) + .withEmbeddingColumnName(properties.getEmbeddingColumnName()) + .withIndexName(properties.getIndexName()) + .withFixedThreadPoolExecutorSize(properties.getFixedThreadPoolExecutorSize()); if (properties.getDisallowSchemaCreation()) { builder = builder.disallowSchemaChanges(); @@ -73,46 +61,20 @@ public CassandraVectorStore vectorStore(EmbeddingClient embeddingClient, Cassand return new CassandraVectorStore(builder.build(), embeddingClient); } - private static class PropertiesCassandraConnectionDetails implements CassandraConnectionDetails { - - private final CassandraVectorStoreProperties properties; - - public PropertiesCassandraConnectionDetails(CassandraVectorStoreProperties properties) { - this.properties = properties; - } - - private String[] getCassandraContactPointHosts() { - return this.properties.getCassandraContactPointHosts().split("(,| )"); - } - - @Override - public List getCassandraContactPoints() { - - Preconditions.checkState(hasCassandraContactPoints(), "cassandraContactPointHosts has not been set"); - final int port = this.properties.getCassandraContactPointPort(); - - return Arrays.asList(getCassandraContactPointHosts()) - .stream() - .map((host) -> InetSocketAddress.createUnresolved(host, port)) - .toList(); - } - - @Override - public String getCassandraLocalDatacenter() { - Preconditions.checkState(hasCassandraLocalDatacenter(), "cassandraLocalDatacenter has not been set"); - return this.properties.getCassandraLocalDatacenter(); - } - - @Override - public boolean hasCassandraContactPoints() { - return null != this.properties.getCassandraContactPointHosts(); - } - - @Override - public boolean hasCassandraLocalDatacenter() { - return null != this.properties.getCassandraLocalDatacenter(); - } - + @Bean + public DriverConfigLoaderBuilderCustomizer driverConfigLoaderBuilderCustomizer() { + // this replaces 
spring-ai-cassandra-*.jar!application.conf + // as spring-boot autoconfigure will not resolve the default driver configs + return (builder) -> builder.startProfile(CassandraVectorStore.DRIVER_PROFILE_UPDATES) + .withString(DefaultDriverOption.REQUEST_CONSISTENCY, "LOCAL_QUORUM") + .withDuration(DefaultDriverOption.REQUEST_TIMEOUT, Duration.ofSeconds(1)) + .withBoolean(DefaultDriverOption.REQUEST_DEFAULT_IDEMPOTENCE, true) + .endProfile() + .startProfile(CassandraVectorStore.DRIVER_PROFILE_SEARCH) + .withString(DefaultDriverOption.REQUEST_CONSISTENCY, "LOCAL_ONE") + .withDuration(DefaultDriverOption.REQUEST_TIMEOUT, Duration.ofSeconds(10)) + .withBoolean(DefaultDriverOption.REQUEST_DEFAULT_IDEMPOTENCE, true) + .endProfile(); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreProperties.java index 73b014ab6b0..27af7605e38 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreProperties.java @@ -15,6 +15,8 @@ */ package org.springframework.ai.autoconfigure.vectorstore.cassandra; +import com.google.api.client.util.Preconditions; + import org.springframework.ai.vectorstore.CassandraVectorStoreConfig; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -27,17 +29,11 @@ public class CassandraVectorStoreProperties { public static final String CONFIG_PREFIX = "spring.ai.vectorstore.cassandra"; - private String cassandraContactPointHosts = null; - - private int cassandraContactPointPort = 9042; - - private String cassandraLocalDatacenter = null; - private String keyspace = 
CassandraVectorStoreConfig.DEFAULT_KEYSPACE_NAME; private String table = CassandraVectorStoreConfig.DEFAULT_TABLE_NAME; - private String indexName = CassandraVectorStoreConfig.DEFAULT_INDEX_NAME; + private String indexName = null; private String contentColumnName = CassandraVectorStoreConfig.DEFAULT_CONTENT_COLUMN_NAME; @@ -45,30 +41,7 @@ public class CassandraVectorStoreProperties { private boolean disallowSchemaChanges = false; - public String getCassandraContactPointHosts() { - return this.cassandraContactPointHosts; - } - - /** comma or space separated */ - public void setCassandraContactPointHosts(String cassandraContactPointHosts) { - this.cassandraContactPointHosts = cassandraContactPointHosts; - } - - public int getCassandraContactPointPort() { - return this.cassandraContactPointPort; - } - - public void setCassandraContactPointPort(int cassandraContactPointPort) { - this.cassandraContactPointPort = cassandraContactPointPort; - } - - public String getCassandraLocalDatacenter() { - return this.cassandraLocalDatacenter; - } - - public void setCassandraLocalDatacenter(String cassandraLocalDatacenter) { - this.cassandraLocalDatacenter = cassandraLocalDatacenter; - } + private int fixedThreadPoolExecutorSize = CassandraVectorStoreConfig.DEFAULT_ADD_CONCURRENCY; public String getKeyspace() { return this.keyspace; @@ -94,20 +67,20 @@ public void setIndexName(String indexName) { this.indexName = indexName; } - public String getContentFieldName() { + public String getContentColumnName() { return this.contentColumnName; } - public void setContentFieldName(String contentFieldName) { - this.contentColumnName = contentFieldName; + public void setContentColumnName(String contentColumnName) { + this.contentColumnName = contentColumnName; } - public String getEmbeddingFieldName() { + public String getEmbeddingColumnName() { return this.embeddingColumnName; } - public void setEmbeddingFieldName(String embeddingFieldName) { - this.embeddingColumnName = embeddingFieldName; + 
public void setEmbeddingColumnName(String embeddingColumnName) { + this.embeddingColumnName = embeddingColumnName; } public Boolean getDisallowSchemaCreation() { @@ -118,4 +91,13 @@ public void setDisallowSchemaCreation(boolean disallowSchemaCreation) { this.disallowSchemaChanges = disallowSchemaCreation; } + public int getFixedThreadPoolExecutorSize() { + return this.fixedThreadPoolExecutorSize; + } + + public void setFixedThreadPoolExecutorSize(int fixedThreadPoolExecutorSize) { + Preconditions.checkArgument(0 < fixedThreadPoolExecutorSize); + this.fixedThreadPoolExecutorSize = fixedThreadPoolExecutorSize; + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfiguration.java new file mode 100644 index 00000000000..e5c24d4ddf0 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfiguration.java @@ -0,0 +1,81 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.autoconfigure.vectorstore.opensearch; + +import org.apache.hc.client5.http.auth.AuthScope; +import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; +import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; +import org.apache.hc.core5.http.HttpHost; +import org.opensearch.client.opensearch.OpenSearchClient; +import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder; +import org.springframework.ai.embedding.EmbeddingClient; +import org.springframework.ai.vectorstore.OpenSearchVectorStore; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; + +import java.net.URISyntaxException; +import java.util.Optional; + +@AutoConfiguration +@ConditionalOnClass({OpenSearchVectorStore.class, EmbeddingClient.class, OpenSearchClient.class}) +@EnableConfigurationProperties(OpenSearchVectorStoreProperties.class) +class OpenSearchVectorStoreAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + OpenSearchVectorStore vectorStore(OpenSearchVectorStoreProperties properties, OpenSearchClient openSearchClient, + EmbeddingClient embeddingClient) { + return new OpenSearchVectorStore( + Optional.ofNullable(properties.getIndexName()).orElse(OpenSearchVectorStore.DEFAULT_INDEX_NAME), + openSearchClient, embeddingClient, Optional.ofNullable(properties.getMappingJson()) + .orElse(OpenSearchVectorStore.DEFAULT_MAPPING_EMBEDDING_TYPE_KNN_VECTOR_DIMENSION_1536)); + } + + @Bean + @ConditionalOnMissingBean + OpenSearchClient openSearchClient(OpenSearchVectorStoreProperties properties) { + HttpHost[] httpHosts = properties.getUris().stream().map(s -> createHttpHost(s)).toArray(HttpHost[]::new); + 
ApacheHttpClient5TransportBuilder transportBuilder = ApacheHttpClient5TransportBuilder.builder(httpHosts); + + Optional.ofNullable(properties.getUsername()) + .map(username -> createBasicCredentialsProvider(httpHosts[0], username, properties.getPassword())) + .ifPresent(basicCredentialsProvider -> transportBuilder.setHttpClientConfigCallback( + httpAsyncClientBuilder -> httpAsyncClientBuilder.setDefaultCredentialsProvider( + basicCredentialsProvider))); + + return new OpenSearchClient(transportBuilder.build()); + } + + private BasicCredentialsProvider createBasicCredentialsProvider(HttpHost httpHost, String username, + String password) { + BasicCredentialsProvider basicCredentialsProvider = new BasicCredentialsProvider(); + basicCredentialsProvider.setCredentials(new AuthScope(httpHost), + new UsernamePasswordCredentials(username, password.toCharArray())); + return basicCredentialsProvider; + } + + private HttpHost createHttpHost(String s) { + try { + return HttpHost.create(s); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreProperties.java new file mode 100644 index 00000000000..900cdbd233b --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreProperties.java @@ -0,0 +1,80 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.vectorstore.opensearch; + +import org.springframework.boot.context.properties.ConfigurationProperties; + +import java.util.List; + +@ConfigurationProperties(prefix = OpenSearchVectorStoreProperties.CONFIG_PREFIX) +public class OpenSearchVectorStoreProperties { + + public static final String CONFIG_PREFIX = "spring.ai.vectorstore.opensearch"; + + /** + * Comma-separated list of the OpenSearch instances to use. + */ + private List uris; + + private String indexName; + + private String username; + + private String password; + + private String mappingJson; + + public List getUris() { + return uris; + } + + public void setUris(List uris) { + this.uris = uris; + } + + public String getIndexName() { + return this.indexName; + } + + public void setIndexName(String indexName) { + this.indexName = indexName; + } + + public String getUsername() { + return username; + } + + public void setUsername(String username) { + this.username = username; + } + + public String getPassword() { + return password; + } + + public void setPassword(String password) { + this.password = password; + } + + public String getMappingJson() { + return mappingJson; + } + + public void setMappingJson(String mappingJson) { + this.mappingJson = mappingJson; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports 
b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index 59f104b6825..1d1532873d9 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -6,7 +6,7 @@ org.springframework.ai.autoconfigure.huggingface.HuggingfaceChatAutoConfiguratio org.springframework.ai.autoconfigure.vertexai.palm2.VertexAiPalm2AutoConfiguration org.springframework.ai.autoconfigure.vertexai.gemini.VertexAiGeminiAutoConfiguration org.springframework.ai.autoconfigure.bedrock.jurrasic2.BedrockAi21Jurassic2ChatAutoConfiguration -org.springframework.ai.autoconfigure.bedrock.llama2.BedrockLlama2ChatAutoConfiguration +org.springframework.ai.autoconfigure.bedrock.llama.BedrockLlamaChatAutoConfiguration org.springframework.ai.autoconfigure.bedrock.cohere.BedrockCohereChatAutoConfiguration org.springframework.ai.autoconfigure.bedrock.cohere.BedrockCohereEmbeddingAutoConfiguration org.springframework.ai.autoconfigure.bedrock.anthropic.BedrockAnthropicChatAutoConfiguration @@ -32,3 +32,4 @@ org.springframework.ai.autoconfigure.anthropic.AnthropicAutoConfiguration org.springframework.ai.autoconfigure.watsonxai.WatsonxAiAutoConfiguration org.springframework.ai.autoconfigure.vectorstore.elasticsearch.ElasticsearchVectorStoreAutoConfiguration org.springframework.ai.autoconfigure.vectorstore.cassandra.CassandraVectorStoreAutoConfiguration +org.springframework.ai.autoconfigure.vectorstore.opensearch.OpenSearchVectorStoreAutoConfiguration diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfigurationIT.java 
b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfigurationIT.java new file mode 100644 index 00000000000..bea58ce80e3 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfigurationIT.java @@ -0,0 +1,139 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.bedrock; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Import; + +import software.amazon.awssdk.auth.credentials.AwsCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; + +/** + * @author Wei Jiang + * 
@since 0.8.1 + */ +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +public class BedrockAwsConnectionConfigurationIT { + + @Test + public void autoConfigureAWSCredentialAndRegionProvider() { + new ApplicationContextRunner() + .withPropertyValues("spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), + "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), + "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id()) + .withConfiguration(AutoConfigurations.of(TestAutoConfiguration.class)) + .run((context) -> { + var awsCredentialsProvider = context.getBean(AwsCredentialsProvider.class); + var awsRegionProvider = context.getBean(AwsRegionProvider.class); + + assertThat(awsCredentialsProvider).isNotNull(); + assertThat(awsRegionProvider).isNotNull(); + + var credentials = awsCredentialsProvider.resolveCredentials(); + assertThat(credentials).isNotNull(); + assertThat(credentials.accessKeyId()).isEqualTo(System.getenv("AWS_ACCESS_KEY_ID")); + assertThat(credentials.secretAccessKey()).isEqualTo(System.getenv("AWS_SECRET_ACCESS_KEY")); + + assertThat(awsRegionProvider.getRegion()).isEqualTo(Region.US_EAST_1); + }); + } + + @Test + public void autoConfigureWithCustomAWSCredentialAndRegionProvider() { + new ApplicationContextRunner() + .withPropertyValues("spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), + "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), + "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id()) + .withConfiguration(AutoConfigurations.of(TestAutoConfiguration.class, + CustomAwsCredentialsProviderAndAwsRegionProviderAutoConfiguration.class)) + .run((context) -> { + var awsCredentialsProvider = context.getBean(AwsCredentialsProvider.class); + var awsRegionProvider = context.getBean(AwsRegionProvider.class); + + assertThat(awsCredentialsProvider).isNotNull(); + 
assertThat(awsRegionProvider).isNotNull(); + + var credentials = awsCredentialsProvider.resolveCredentials(); + assertThat(credentials).isNotNull(); + assertThat(credentials.accessKeyId()).isEqualTo("CUSTOM_ACCESS_KEY"); + assertThat(credentials.secretAccessKey()).isEqualTo("CUSTOM_SECRET_ACCESS_KEY"); + + assertThat(awsRegionProvider.getRegion()).isEqualTo(Region.AWS_GLOBAL); + }); + } + + @EnableConfigurationProperties({ BedrockAwsConnectionProperties.class }) + @Import(BedrockAwsConnectionConfiguration.class) + static class TestAutoConfiguration { + + } + + @AutoConfiguration + static class CustomAwsCredentialsProviderAndAwsRegionProviderAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + public AwsCredentialsProvider credentialsProvider() { + return new AwsCredentialsProvider() { + + @Override + public AwsCredentials resolveCredentials() { + return new AwsCredentials() { + + @Override + public String accessKeyId() { + return "CUSTOM_ACCESS_KEY"; + } + + @Override + public String secretAccessKey() { + return "CUSTOM_SECRET_ACCESS_KEY"; + } + + }; + } + + }; + } + + @Bean + @ConditionalOnMissingBean + public AwsRegionProvider regionProvider() { + return new AwsRegionProvider() { + + @Override + public Region getRegion() { + return Region.AWS_GLOBAL; + } + + }; + } + + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfigurationIT.java similarity index 61% rename from spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfigurationIT.java rename to spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfigurationIT.java index 3c55877bcaf..6ca69b65599 100644 --- 
a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfigurationIT.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.bedrock.llama2; +package org.springframework.ai.autoconfigure.bedrock.llama; import java.util.List; import java.util.Map; @@ -27,8 +27,8 @@ import software.amazon.awssdk.regions.Region; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; -import org.springframework.ai.bedrock.llama2.BedrockLlama2ChatClient; -import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatModel; +import org.springframework.ai.bedrock.llama.BedrockLlamaChatClient; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel; import org.springframework.ai.chat.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; @@ -41,21 +41,22 @@ /** * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") @EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") -public class BedrockLlama2ChatAutoConfigurationIT { +public class BedrockLlamaChatAutoConfigurationIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withPropertyValues("spring.ai.bedrock.llama2.chat.enabled=true", + .withPropertyValues("spring.ai.bedrock.llama.chat.enabled=true", "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), - 
"spring.ai.bedrock.llama2.chat.model=" + Llama2ChatModel.LLAMA2_70B_CHAT_V1.id(), - "spring.ai.bedrock.llama2.chat.options.temperature=0.5", - "spring.ai.bedrock.llama2.chat.options.maxGenLen=500") - .withConfiguration(AutoConfigurations.of(BedrockLlama2ChatAutoConfiguration.class)); + "spring.ai.bedrock.llama.chat.model=" + LlamaChatModel.LLAMA3_70B_INSTRUCT_V1.id(), + "spring.ai.bedrock.llama.chat.options.temperature=0.5", + "spring.ai.bedrock.llama.chat.options.maxGenLen=500") + .withConfiguration(AutoConfigurations.of(BedrockLlamaChatAutoConfiguration.class)); private final Message systemMessage = new SystemPromptTemplate(""" You are a helpful AI assistant. Your name is {name}. @@ -70,8 +71,8 @@ public class BedrockLlama2ChatAutoConfigurationIT { @Test public void chatCompletion() { contextRunner.run(context -> { - BedrockLlama2ChatClient llama2ChatClient = context.getBean(BedrockLlama2ChatClient.class); - ChatResponse response = llama2ChatClient.call(new Prompt(List.of(userMessage, systemMessage))); + BedrockLlamaChatClient llamaChatClient = context.getBean(BedrockLlamaChatClient.class); + ChatResponse response = llamaChatClient.call(new Prompt(List.of(userMessage, systemMessage))); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @@ -80,9 +81,9 @@ public void chatCompletion() { public void chatCompletionStreaming() { contextRunner.run(context -> { - BedrockLlama2ChatClient llama2ChatClient = context.getBean(BedrockLlama2ChatClient.class); + BedrockLlamaChatClient llamaChatClient = context.getBean(BedrockLlamaChatClient.class); - Flux response = llama2ChatClient.stream(new Prompt(List.of(userMessage, systemMessage))); + Flux response = llamaChatClient.stream(new Prompt(List.of(userMessage, systemMessage))); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(2); @@ -102,23 +103,23 @@ public void chatCompletionStreaming() { public void propertiesTest() { new 
ApplicationContextRunner() - .withPropertyValues("spring.ai.bedrock.llama2.chat.enabled=true", + .withPropertyValues("spring.ai.bedrock.llama.chat.enabled=true", "spring.ai.bedrock.aws.access-key=ACCESS_KEY", "spring.ai.bedrock.aws.secret-key=SECRET_KEY", - "spring.ai.bedrock.llama2.chat.model=MODEL_XYZ", + "spring.ai.bedrock.llama.chat.model=MODEL_XYZ", "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), - "spring.ai.bedrock.llama2.chat.options.temperature=0.55", - "spring.ai.bedrock.llama2.chat.options.maxGenLen=123") - .withConfiguration(AutoConfigurations.of(BedrockLlama2ChatAutoConfiguration.class)) + "spring.ai.bedrock.llama.chat.options.temperature=0.55", + "spring.ai.bedrock.llama.chat.options.maxGenLen=123") + .withConfiguration(AutoConfigurations.of(BedrockLlamaChatAutoConfiguration.class)) .run(context -> { - var llama2ChatProperties = context.getBean(BedrockLlama2ChatProperties.class); + var llamaChatProperties = context.getBean(BedrockLlamaChatProperties.class); var awsProperties = context.getBean(BedrockAwsConnectionProperties.class); - assertThat(llama2ChatProperties.isEnabled()).isTrue(); + assertThat(llamaChatProperties.isEnabled()).isTrue(); assertThat(awsProperties.getRegion()).isEqualTo(Region.EU_CENTRAL_1.id()); - assertThat(llama2ChatProperties.getOptions().getTemperature()).isEqualTo(0.55f); - assertThat(llama2ChatProperties.getOptions().getMaxGenLen()).isEqualTo(123); - assertThat(llama2ChatProperties.getModel()).isEqualTo("MODEL_XYZ"); + assertThat(llamaChatProperties.getOptions().getTemperature()).isEqualTo(0.55f); + assertThat(llamaChatProperties.getOptions().getMaxGenLen()).isEqualTo(123); + assertThat(llamaChatProperties.getModel()).isEqualTo("MODEL_XYZ"); assertThat(awsProperties.getAccessKey()).isEqualTo("ACCESS_KEY"); assertThat(awsProperties.getSecretKey()).isEqualTo("SECRET_KEY"); @@ -129,27 +130,26 @@ public void propertiesTest() { public void chatCompletionDisabled() { // It is disabled by default - new 
ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(BedrockLlama2ChatAutoConfiguration.class)) + new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(BedrockLlamaChatAutoConfiguration.class)) .run(context -> { - assertThat(context.getBeansOfType(BedrockLlama2ChatProperties.class)).isEmpty(); - assertThat(context.getBeansOfType(BedrockLlama2ChatClient.class)).isEmpty(); + assertThat(context.getBeansOfType(BedrockLlamaChatProperties.class)).isEmpty(); + assertThat(context.getBeansOfType(BedrockLlamaChatClient.class)).isEmpty(); }); // Explicitly enable the chat auto-configuration. - new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.llama2.chat.enabled=true") - .withConfiguration(AutoConfigurations.of(BedrockLlama2ChatAutoConfiguration.class)) + new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.llama.chat.enabled=true") + .withConfiguration(AutoConfigurations.of(BedrockLlamaChatAutoConfiguration.class)) .run(context -> { - assertThat(context.getBeansOfType(BedrockLlama2ChatProperties.class)).isNotEmpty(); - assertThat(context.getBeansOfType(BedrockLlama2ChatClient.class)).isNotEmpty(); + assertThat(context.getBeansOfType(BedrockLlamaChatProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(BedrockLlamaChatClient.class)).isNotEmpty(); }); // Explicitly disable the chat auto-configuration. 
- new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.llama2.chat.enabled=false") - .withConfiguration(AutoConfigurations.of(BedrockLlama2ChatAutoConfiguration.class)) + new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.llama.chat.enabled=false") + .withConfiguration(AutoConfigurations.of(BedrockLlamaChatAutoConfiguration.class)) .run(context -> { - assertThat(context.getBeansOfType(BedrockLlama2ChatProperties.class)).isEmpty(); - assertThat(context.getBeansOfType(BedrockLlama2ChatClient.class)).isEmpty(); + assertThat(context.getBeansOfType(BedrockLlamaChatProperties.class)).isEmpty(); + assertThat(context.getBeansOfType(BedrockLlamaChatClient.class)).isEmpty(); }); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfigurationIT.java index 15530052f96..119bde94478 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfigurationIT.java @@ -31,6 +31,7 @@ import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -55,18 +56,18 @@ class CassandraVectorStoreAutoConfigurationIT { 
ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(CassandraVectorStoreAutoConfiguration.class)) + .withConfiguration( + AutoConfigurations.of(CassandraVectorStoreAutoConfiguration.class, CassandraAutoConfiguration.class)) .withUserConfiguration(Config.class) .withPropertyValues("spring.ai.vectorstore.cassandra.keyspace=test_autoconfigure") - .withPropertyValues("spring.ai.vectorstore.cassandra.contentFieldName=doc_chunk"); + .withPropertyValues("spring.ai.vectorstore.cassandra.contentColumnName=doc_chunk"); @Test void addAndSearch() { - contextRunner - .withPropertyValues("spring.ai.vectorstore.cassandra.cassandraContactPointHosts=" + getContactPointHost()) - .withPropertyValues("spring.ai.vectorstore.cassandra.cassandraContactPointPort=" + getContactPointPort()) - .withPropertyValues("spring.ai.vectorstore.cassandra.cassandraLocalDatacenter=" - + cassandraContainer.getLocalDatacenter()) + contextRunner.withPropertyValues("spring.cassandra.contactPoints=" + getContactPointHost()) + .withPropertyValues("spring.cassandra.port=" + getContactPointPort()) + .withPropertyValues("spring.cassandra.localDatacenter=" + cassandraContainer.getLocalDatacenter()) + .withPropertyValues("spring.ai.vectorstore.cassandra.fixedThreadPoolExecutorSize=8") .run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStorePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStorePropertiesTests.java index c5a3e4d0f7c..c0d5ad04aeb 100644 --- 
a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStorePropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStorePropertiesTests.java @@ -30,39 +30,34 @@ class CassandraVectorStorePropertiesTests { @Test void defaultValues() { var props = new CassandraVectorStoreProperties(); - assertThat(props.getCassandraContactPointHosts()).isNull(); - assertThat(props.getCassandraContactPointPort()).isEqualTo(9042); - assertThat(props.getCassandraLocalDatacenter()).isNull(); assertThat(props.getKeyspace()).isEqualTo(CassandraVectorStoreConfig.DEFAULT_KEYSPACE_NAME); assertThat(props.getTable()).isEqualTo(CassandraVectorStoreConfig.DEFAULT_TABLE_NAME); - assertThat(props.getContentFieldName()).isEqualTo(CassandraVectorStoreConfig.DEFAULT_CONTENT_COLUMN_NAME); - assertThat(props.getEmbeddingFieldName()).isEqualTo(CassandraVectorStoreConfig.DEFAULT_EMBEDDING_COLUMN_NAME); - assertThat(props.getIndexName()).isEqualTo(CassandraVectorStoreConfig.DEFAULT_INDEX_NAME); + assertThat(props.getContentColumnName()).isEqualTo(CassandraVectorStoreConfig.DEFAULT_CONTENT_COLUMN_NAME); + assertThat(props.getEmbeddingColumnName()).isEqualTo(CassandraVectorStoreConfig.DEFAULT_EMBEDDING_COLUMN_NAME); + assertThat(props.getIndexName()).isNull(); assertThat(props.getDisallowSchemaCreation()).isFalse(); + assertThat(props.getFixedThreadPoolExecutorSize()) + .isEqualTo(CassandraVectorStoreConfig.DEFAULT_ADD_CONCURRENCY); } @Test void customValues() { var props = new CassandraVectorStoreProperties(); - props.setCassandraContactPointHosts("127.0.0.1,127.0.0.2"); - props.setCassandraContactPointPort(9043); - props.setCassandraLocalDatacenter("dc1"); props.setKeyspace("my_keyspace"); props.setTable("my_table"); - props.setContentFieldName("my_content"); - props.setEmbeddingFieldName("my_vector"); + 
props.setContentColumnName("my_content"); + props.setEmbeddingColumnName("my_vector"); props.setIndexName("my_sai"); props.setDisallowSchemaCreation(true); + props.setFixedThreadPoolExecutorSize(10); - assertThat(props.getCassandraContactPointHosts()).isEqualTo("127.0.0.1,127.0.0.2"); - assertThat(props.getCassandraContactPointPort()).isEqualTo(9043); - assertThat(props.getCassandraLocalDatacenter()).isEqualTo("dc1"); assertThat(props.getKeyspace()).isEqualTo("my_keyspace"); assertThat(props.getTable()).isEqualTo("my_table"); - assertThat(props.getContentFieldName()).isEqualTo("my_content"); - assertThat(props.getEmbeddingFieldName()).isEqualTo("my_vector"); + assertThat(props.getContentColumnName()).isEqualTo("my_content"); + assertThat(props.getEmbeddingColumnName()).isEqualTo("my_vector"); assertThat(props.getIndexName()).isEqualTo("my_sai"); assertThat(props.getDisallowSchemaCreation()).isTrue(); + assertThat(props.getFixedThreadPoolExecutorSize()).isEqualTo(10); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfigurationIT.java new file mode 100644 index 00000000000..bfa3dab48f2 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfigurationIT.java @@ -0,0 +1,125 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.vectorstore.opensearch; + +import org.awaitility.Awaitility; +import org.junit.jupiter.api.Test; +import org.opensearch.testcontainers.OpensearchContainer; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingClient; +import org.springframework.ai.transformers.TransformersEmbeddingClient; +import org.springframework.ai.vectorstore.OpenSearchVectorStore; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.core.io.DefaultResourceLoader; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.hasSize; + +@Testcontainers +class OpenSearchVectorStoreAutoConfigurationIT { + + @Container + private static final OpensearchContainer opensearchContainer = + new OpensearchContainer<>(DockerImageName.parse("opensearchproject/opensearch:2.12.0")); + + private static final 
String DOCUMENT_INDEX = "auto-spring-ai-document-index"; + + private List documents = List.of( + new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), + new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), + new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withConfiguration( + AutoConfigurations.of(OpenSearchVectorStoreAutoConfiguration.class, SpringAiRetryAutoConfiguration.class)) + .withUserConfiguration(Config.class) + .withPropertyValues( + OpenSearchVectorStoreProperties.CONFIG_PREFIX + ".uris=" + opensearchContainer.getHttpHostAddress(), + OpenSearchVectorStoreProperties.CONFIG_PREFIX + ".indexName=" + DOCUMENT_INDEX, + OpenSearchVectorStoreProperties.CONFIG_PREFIX + ".mappingJson=" + """ + { + "properties":{ + "embedding":{ + "type":"knn_vector", + "dimension":384 + } + } + } + """); + + @Test + public void addAndSearchTest() { + + this.contextRunner.run(context -> { + OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class); + + vectorStore.add(documents); + + Awaitility.await().until(() -> vectorStore.similaritySearch( + SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)), + hasSize(1)); + + List results = vectorStore.similaritySearch( + SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); + assertThat(resultDoc.getMetadata()).hasSize(2); + assertThat(resultDoc.getMetadata()).containsKey("meta2"); + assertThat(resultDoc.getMetadata()).containsKey("distance"); + + // Remove all documents from the store + 
vectorStore.delete(documents.stream().map(Document::getId).toList()); + + Awaitility.await().until(() -> vectorStore.similaritySearch( + SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)), hasSize(0)); + }); + } + + private String getText(String uri) { + var resource = new DefaultResourceLoader().getResource(uri); + try { + return resource.getContentAsString(StandardCharsets.UTF_8); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Configuration(proxyBeanMethods = false) + static class Config { + + @Bean + public EmbeddingClient embeddingClient() { + return new TransformersEmbeddingClient(); + } + + } + +} diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-opensearch-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-opensearch-store/pom.xml new file mode 100644 index 00000000000..c97eb81ad68 --- /dev/null +++ b/spring-ai-spring-boot-starters/spring-ai-starter-opensearch-store/pom.xml @@ -0,0 +1,42 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-opensearch-store-spring-boot-starter + jar + Spring AI Starter - OpenSearch Store + Spring AI OpenSearch Store Auto Configuration + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + org.springframework.boot + spring-boot-starter + + + + org.springframework.ai + spring-ai-spring-boot-autoconfigure + ${project.parent.version} + + + + org.springframework.ai + spring-ai-opensearch-store + ${project.parent.version} + + + + diff --git a/vector-stores/spring-ai-azure/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java b/vector-stores/spring-ai-azure/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java index ec29ce0bfcf..165b190b06d 100644 --- 
a/vector-stores/spring-ai-azure/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java +++ b/vector-stores/spring-ai-azure/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java @@ -320,7 +320,7 @@ private List toFloatList(List doubleList) { } /** - * Internal data structure for retrieving and and storing documents. + * Internal data structure for retrieving and storing documents. */ private record AzureSearchDocument(String id, String content, List embedding, String metadata) { } diff --git a/vector-stores/spring-ai-cassandra/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java b/vector-stores/spring-ai-cassandra/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java index 4a8b681a5f0..3efd440341b 100644 --- a/vector-stores/spring-ai-cassandra/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java +++ b/vector-stores/spring-ai-cassandra/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java @@ -79,7 +79,7 @@ private static void doOperand(ExpressionType type, StringBuilder context) { // TODO SAI supports collections // reach out to mck@apache.org if you'd like these implemented // case CONTAINS -> context.append(" CONTAINS "); - // case CONTAINS_KEY -> context.append(" CONTAINS KEY "); + // case CONTAINS_KEY -> context.append(" CONTAINS KEY "); default -> throw new UnsupportedOperationException( String.format("Expression type %s not yet implemented. 
Patches welcome.", type)); } diff --git a/vector-stores/spring-ai-cassandra/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java b/vector-stores/spring-ai-cassandra/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java index 6f651c944df..c7422cc5b7e 100644 --- a/vector-stores/spring-ai-cassandra/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java +++ b/vector-stores/spring-ai-cassandra/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java @@ -17,16 +17,20 @@ import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import com.datastax.oss.driver.api.core.cql.BoundStatement; import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder; import com.datastax.oss.driver.api.core.cql.PreparedStatement; import com.datastax.oss.driver.api.core.cql.Row; +import com.datastax.oss.driver.api.core.cql.SimpleStatement; import com.datastax.oss.driver.api.core.data.CqlVector; import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata; import com.datastax.oss.driver.api.querybuilder.QueryBuilder; @@ -40,6 +44,7 @@ import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingClient; +import org.springframework.ai.vectorstore.CassandraVectorStoreConfig; import org.springframework.ai.vectorstore.CassandraVectorStoreConfig.SchemaColumn; import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; import org.springframework.beans.factory.InitializingBean; @@ -53,13 +58,13 @@ * fields in the documents to be stored alongside the vector and content data. 
* * This class requires a CassandraVectorStoreConfig configuration object for - * initialization, which includes settings like connection details, index name, field + * initialization, which includes settings like connection details, index name, column * names, etc. It also requires an EmbeddingClient to convert documents into embeddings * before storing them. * * A schema matching the configuration is automatically created if it doesn't exist. * Missing columns and indexes in existing tables will also be automatically created. - * Disable this with the disallowSchemaCreation. + * Disable this with the CassandraVectorStoreConfig#disallowSchemaChanges(). * * This class is designed to work with brand new tables that it creates for you, or on top * of existing Cassandra tables. The latter is appropriate when wanting to keep data in @@ -69,9 +74,20 @@ * Instances of this class are not dynamic against server-side schema changes. If you * change the schema server-side you need a new CassandraVectorStore instance. * + * When adding documents with the method {@link #add(List)} it first calls + * embeddingClient to create the embeddings. This is slow. Configure + * {@link CassandraVectorStoreConfig.Builder#withFixedThreadPoolExecutorSize(int)} + * accordingly to improve performance so embeddings are created and the documents are + * added concurrently. The default concurrency is 16 + * ({@link CassandraVectorStoreConfig#DEFAULT_ADD_CONCURRENCY}). Remote transformers + * probably want higher concurrency, and local transformers may need lower concurrency. + * This concurrency limit does not need to be higher than the max parallel calls made to + * the {@link #add(List)} method multiplied by the list size. This setting can + * also serve as a protecting throttle against your embedding model. 
+ * * @author Mick Semb Wever * @see VectorStore - * @see CassandraVectorStoreConfig + * @see org.springframework.ai.vectorstore.CassandraVectorStoreConfig * @see EmbeddingClient * @since 1.0.0 */ @@ -87,10 +103,14 @@ public enum Similarity { } - private static final String QUERY_FORMAT = "select %s,%s,%s%s from %s.%s ? order by %s ann of ? limit ?"; - public static final String SIMILARITY_FIELD_NAME = "similarity_score"; + public static final String DRIVER_PROFILE_UPDATES = "spring-ai-updates"; + + public static final String DRIVER_PROFILE_SEARCH = "spring-ai-search"; + + private static final String QUERY_FORMAT = "select %s,%s,%s%s from %s.%s ? order by %s ann of ? limit ?"; + private static final Logger logger = LoggerFactory.getLogger(CassandraVectorStore.class); private final CassandraVectorStoreConfig conf; @@ -99,7 +119,7 @@ public enum Similarity { private final FilterExpressionConverter filterExpressionConverter; - private final Map, PreparedStatement> addStmts = new HashMap<>(); + private final ConcurrentMap, PreparedStatement> addStmts = new ConcurrentHashMap<>(); private final PreparedStatement deleteStmt; @@ -133,30 +153,39 @@ public CassandraVectorStore(CassandraVectorStoreConfig conf, EmbeddingClient emb @Override public void add(List documents) { - CompletableFuture[] futures = new CompletableFuture[documents.size()]; - short i = 0; - for (Document d : documents) { - List primaryKeyValues = this.conf.documentIdTranslator.apply(d.getId()); - var embedding = this.embeddingClient.embed(d).stream().map(Double::floatValue).toList(); + var futures = new CompletableFuture[documents.size()]; - BoundStatementBuilder builder = prepareAddStatement(d.getMetadata().keySet()).boundStatementBuilder(); - for (int k = 0; k < primaryKeyValues.size(); ++k) { - SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k); - builder = builder.set(keyColumn.name(), primaryKeyValues.get(k), keyColumn.javaType()); - } + int i = 0; + for (Document d : documents) { + 
futures[i++] = CompletableFuture.runAsync(() -> { + List primaryKeyValues = this.conf.documentIdTranslator.apply(d.getId()); + + var embedding = (null != d.getEmbedding() && !d.getEmbedding().isEmpty() ? d.getEmbedding() + : this.embeddingClient.embed(d)) + .stream() + .map(Double::floatValue) + .toList(); + + BoundStatementBuilder builder = prepareAddStatement(d.getMetadata().keySet()).boundStatementBuilder(); + for (int k = 0; k < primaryKeyValues.size(); ++k) { + SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k); + builder = builder.set(keyColumn.name(), primaryKeyValues.get(k), keyColumn.javaType()); + } - builder = builder.setString(this.conf.schema.content(), d.getContent()) - .setVector(this.conf.schema.embedding(), CqlVector.newInstance(embedding), Float.class); + builder = builder.setString(this.conf.schema.content(), d.getContent()) + .setVector(this.conf.schema.embedding(), CqlVector.newInstance(embedding), Float.class); - for (var metadataColumn : this.conf.schema.metadataColumns() - .stream() - .filter((mc) -> d.getMetadata().containsKey(mc.name())) - .toList()) { + for (var metadataColumn : this.conf.schema.metadataColumns() + .stream() + .filter((mc) -> d.getMetadata().containsKey(mc.name())) + .toList()) { - builder = builder.set(metadataColumn.name(), d.getMetadata().get(metadataColumn.name()), - metadataColumn.javaType()); - } - futures[i++] = this.conf.session.executeAsync(builder.build()).toCompletableFuture(); + builder = builder.set(metadataColumn.name(), d.getMetadata().get(metadataColumn.name()), + metadataColumn.javaType()); + } + BoundStatement s = builder.build().setExecutionProfileName(DRIVER_PROFILE_UPDATES); + this.conf.session.execute(s); + }, this.conf.executor); } CompletableFuture.allOf(futures).join(); } @@ -164,7 +193,7 @@ public void add(List documents) { @Override public Optional delete(List idList) { CompletableFuture[] futures = new CompletableFuture[idList.size()]; - short i = 0; + int i = 0; for (String id : idList) 
{ List primaryKeyValues = this.conf.documentIdTranslator.apply(id); BoundStatement s = this.deleteStmt.bind(primaryKeyValues.toArray()); @@ -177,7 +206,7 @@ public Optional delete(List idList) { @Override public List similaritySearch(SearchRequest request) { Preconditions.checkArgument(request.getTopK() <= 1000); - var embedding = this.embeddingClient.embed(request.getQuery()).stream().map(Double::floatValue).toList(); + var embedding = toFloatArray(this.embeddingClient.embed(request.getQuery())); CqlVector cqlVector = CqlVector.newInstance(embedding); String whereClause = ""; @@ -191,8 +220,9 @@ public List similaritySearch(SearchRequest request) { String query = String.format(this.similarityStmt, cqlVector, whereClause, cqlVector, request.getTopK()); List documents = new ArrayList<>(); logger.trace("Executing {}", query); + SimpleStatement s = SimpleStatement.newInstance(query).setExecutionProfileName(DRIVER_PROFILE_SEARCH); - for (Row row : this.conf.session.execute(query)) { + for (Row row : this.conf.session.execute(s)) { float score = row.getFloat(0); if (score < request.getSimilarityThreshold()) { break; @@ -248,7 +278,16 @@ private PreparedStatement prepareDeleteStatement() { } private PreparedStatement prepareAddStatement(Set metadataFields) { - if (!this.addStmts.containsKey(metadataFields)) { + + // metadata fields that are not configured as metadata columns are not added + Set fieldsThatAreColumns = new HashSet<>(this.conf.schema.metadataColumns() + .stream() + .map((mc) -> mc.name()) + .filter((mc) -> metadataFields.contains(mc)) + .toList()); + + return this.addStmts.computeIfAbsent(fieldsThatAreColumns, (fields) -> { + RegularInsert stmt = null; InsertInto stmtStart = QueryBuilder.insertInto(this.conf.schema.keyspace(), this.conf.schema.table()); @@ -262,17 +301,11 @@ private PreparedStatement prepareAddStatement(Set metadataFields) { stmt = stmt.value(this.conf.schema.content(), QueryBuilder.bindMarker(this.conf.schema.content())) 
.value(this.conf.schema.embedding(), QueryBuilder.bindMarker(this.conf.schema.embedding())); - for (String metadataField : this.conf.schema.metadataColumns() - .stream() - .map((mc) -> mc.name()) - .filter((mc) -> metadataFields.contains(mc)) - .toList()) { - + for (String metadataField : fields) { stmt = stmt.value(metadataField, QueryBuilder.bindMarker(metadataField)); } - this.addStmts.putIfAbsent(metadataFields, this.conf.session.prepare(stmt.build())); - } - return this.addStmts.get(metadataFields); + return this.conf.session.prepare(stmt.build()); + }); } private String similaritySearchStatement() { @@ -317,4 +350,13 @@ private String getDocumentId(Row row) { return this.conf.primaryKeyTranslator.apply(primaryKeyValues); } + private static Float[] toFloatArray(List embeddingDouble) { + Float[] embeddingFloat = new Float[embeddingDouble.size()]; + int i = 0; + for (Double d : embeddingDouble) { + embeddingFloat[i++] = d.floatValue(); + } + return embeddingFloat; + } + } diff --git a/vector-stores/spring-ai-cassandra/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java b/vector-stores/spring-ai-cassandra/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java index d85d1ee4bc3..32a76d5087e 100644 --- a/vector-stores/spring-ai-cassandra/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java +++ b/vector-stores/spring-ai-cassandra/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java @@ -22,6 +22,8 @@ import java.util.List; import java.util.Optional; import java.util.Set; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; import java.util.function.Function; import java.util.stream.Stream; @@ -42,21 +44,26 @@ import com.datastax.oss.driver.api.querybuilder.schema.CreateTableStart; import com.datastax.oss.driver.shaded.guava.common.annotations.VisibleForTesting; import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; + 
import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.lang.Nullable; + /** * Configuration for the Cassandra vector store. * - * All metadata fields configured to the store will be fetched and added to all queried + * All metadata columns configured to the store will be fetched and added to all queried * documents. * - * If you wish to metadata search against a field its 'searchable' argument must be true. + * To filter expression search against a metadata column configure it with + * SchemaColumnTags.INDEXED * * The Cassandra Java Driver is configured via the application.conf resource found in the * classpath. See * https://github.com/apache/cassandra-java-driver/tree/4.x/manual/core/configuration * + * @author Mick Semb Wever * @since 1.0.0 */ public final class CassandraVectorStoreConfig implements AutoCloseable { @@ -67,13 +74,15 @@ public final class CassandraVectorStoreConfig implements AutoCloseable { public static final String DEFAULT_ID_NAME = "id"; - public static final String DEFAULT_INDEX_NAME = "embedding_index"; + public static final String DEFAULT_INDEX_SUFFIX = "idx"; public static final String DEFAULT_CONTENT_COLUMN_NAME = "content"; public static final String DEFAULT_EMBEDDING_COLUMN_NAME = "embedding"; - private static final Logger logger = LoggerFactory.getLogger(CassandraVectorStore.class); + public static final int DEFAULT_ADD_CONCURRENCY = 16; + + private static final Logger logger = LoggerFactory.getLogger(CassandraVectorStoreConfig.class); record Schema(String keyspace, String table, List partitionKeys, List clusteringKeys, String content, String embedding, String index, Set metadataColumns) { @@ -127,6 +136,8 @@ public interface PrimaryKeyTranslator extends Function, String> { final PrimaryKeyTranslator primaryKeyTranslator; + final Executor executor; + private final boolean closeSessionOnClose; private CassandraVectorStoreConfig(Builder builder) { @@ -139,6 +150,7 @@ private 
CassandraVectorStoreConfig(Builder builder) { this.disallowSchemaChanges = builder.disallowSchemaCreation; this.documentIdTranslator = builder.documentIdTranslator; this.primaryKeyTranslator = builder.primaryKeyTranslator; + this.executor = Executors.newFixedThreadPool(builder.fixedThreadPoolExecutorSize); } public static Builder builder() { @@ -177,7 +189,7 @@ public static class Builder { private List clusteringKeys = List.of(); - private String indexName = DEFAULT_INDEX_NAME; + private String indexName = null; private String contentColumnName = DEFAULT_CONTENT_COLUMN_NAME; @@ -187,6 +199,8 @@ public static class Builder { private boolean disallowSchemaCreation = false; + private int fixedThreadPoolExecutorSize = DEFAULT_ADD_CONCURRENCY; + private DocumentIdTranslator documentIdTranslator = (String id) -> List.of(id); private PrimaryKeyTranslator primaryKeyTranslator = (List primaryKeyColumns) -> { @@ -246,6 +260,8 @@ public Builder withClusteringKeys(List clusteringKeys) { return this; } + /** defaults (if null) to '__idx' **/ + @Nullable public Builder withIndexName(String name) { this.indexName = name; return this; @@ -261,20 +277,27 @@ public Builder withEmbeddingColumnName(String name) { return this; } - public Builder addMetadataColumn(SchemaColumn... fields) { + public Builder addMetadataColumns(SchemaColumn... 
columns) { Builder builder = this; - for (SchemaColumn f : fields) { + for (SchemaColumn f : columns) { builder = builder.addMetadataColumn(f); } return builder; } - public Builder addMetadataColumn(SchemaColumn field) { + public Builder addMetadataColumns(List columns) { + Builder builder = this; + this.metadataColumns.addAll(columns); + return builder; + } + + public Builder addMetadataColumn(SchemaColumn column) { - Preconditions.checkArgument(this.metadataColumns.stream().noneMatch((sc) -> sc.name().equals(field.name())), - "A metadata field with name %s has already been added", field.name()); + Preconditions.checkArgument( + this.metadataColumns.stream().noneMatch((sc) -> sc.name().equals(column.name())), + "A metadata column with name %s has already been added", column.name()); - this.metadataColumns.add(field); + this.metadataColumns.add(column); return this; } @@ -283,6 +306,18 @@ public Builder disallowSchemaChanges() { return this; } + /** + * Executor to use when adding documents. The hotspot is the call to the + * embeddingClient. For remote transformers you probably want a higher value to + * utilize network. For local transformers you probably want a lower value to + * avoid saturation. 
+ **/ + public Builder withFixedThreadPoolExecutorSize(int threads) { + Preconditions.checkArgument(0 < threads); + this.fixedThreadPoolExecutorSize = threads; + return this; + } + public Builder withDocumentIdTranslator(DocumentIdTranslator documentIdTranslator) { this.documentIdTranslator = documentIdTranslator; return this; @@ -294,6 +329,9 @@ public Builder withPrimaryKeyTranslator(PrimaryKeyTranslator primaryKeyTranslato } public CassandraVectorStoreConfig build() { + if (null == this.indexName) { + this.indexName = String.format("%s_%s_%s", this.table, this.embeddingColumnName, DEFAULT_INDEX_SUFFIX); + } for (SchemaColumn metadata : this.metadataColumns) { Preconditions.checkArgument( @@ -480,7 +518,7 @@ private void ensureTableColumnsExist(int vectorDimension) { if (column.isPresent()) { Preconditions.checkArgument(column.get().getType().equals(metadata.type()), - "Cannot change type on metadata field %s from %s to %s", metadata.name(), + "Cannot change type on metadata column %s from %s to %s", metadata.name(), column.get().getType(), metadata.type()); } else { @@ -500,7 +538,7 @@ private void ensureTableColumnsExist(int vectorDimension) { // special case for embedding column, bc JAVA-3118, as above StringBuilder alterTableStmt = new StringBuilder(((BuildableQuery) alterTable).asCql()); if (newColumns.isEmpty() && !addContent) { - alterTableStmt.append(" ADD "); + alterTableStmt.append(" ADD ("); } else { alterTableStmt.setLength(alterTableStmt.length() - 1); @@ -509,7 +547,7 @@ private void ensureTableColumnsExist(int vectorDimension) { alterTableStmt.append(this.schema.embedding) .append(" vector"); + .append(">)"); logger.debug("Executing {}", alterTableStmt.toString()); this.session.execute(alterTableStmt.toString()); diff --git a/vector-stores/spring-ai-cassandra/src/main/resources/application.conf b/vector-stores/spring-ai-cassandra/src/main/resources/application.conf new file mode 100644 index 00000000000..91b1c800e7d --- /dev/null +++ 
b/vector-stores/spring-ai-cassandra/src/main/resources/application.conf @@ -0,0 +1,24 @@ +# Reference configuration for the DataStax Java driver for Apache Cassandra® +# see https://github.com/apache/cassandra-java-driver/tree/4.x/manual/core/configuration +# +# +# when using spring-boot autoconfigure this will not be used +# instead CassandraVectorStoreAutoConfiguration.driverConfigLoaderBuilderCustomizer() is used +datastax-java-driver { + profiles { + spring-ai-updates { + basic.request { + consistency = LOCAL_QUORUM + timeout = 1 seconds + default-idempotence = true + } + } + spring-ai-search { + basic.request { + consistency = LOCAL_ONE + timeout = 10 seconds + default-idempotence = true + } + } + } +} \ No newline at end of file diff --git a/vector-stores/spring-ai-cassandra/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java b/vector-stores/spring-ai-cassandra/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java index 91ec53e503e..868c79fbb3f 100644 --- a/vector-stores/spring-ai-cassandra/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java +++ b/vector-stores/spring-ai-cassandra/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java @@ -17,23 +17,29 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; -import org.junit.jupiter.api.Assertions; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadLocalRandom; import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.CqlSessionBuilder; import com.datastax.oss.driver.api.core.servererrors.InvalidQueryException; import com.datastax.oss.driver.api.core.servererrors.SyntaxError; import 
com.datastax.oss.driver.api.core.type.DataTypes; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.testcontainers.containers.CassandraContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.shaded.org.apache.commons.lang3.RandomStringUtils; import org.testcontainers.utility.DockerImageName; import org.springframework.ai.document.Document; @@ -128,7 +134,8 @@ void ensureSchemaNoCreation() { @Test void ensureSchemaPartialCreation() { this.contextRunner.run(context -> { - for (int i = 0; i < 4; ++i) { + int PARTIAL_FILES = 5; + for (int i = 0; i < PARTIAL_FILES; ++i) { executeCqlFile(context, format("test_wiki_partial_%d_schema.cql", i)); var wrapper = createStore(context, List.of(), false, false); try { @@ -142,6 +149,10 @@ void ensureSchemaPartialCreation() { wrapper.store().close(); } } + // make sure there's not more files to test + Assertions.assertThrows(IOException.class, () -> { + executeCqlFile(context, format("test_wiki_partial_%d_schema.cql", PARTIAL_FILES)); + }); }); } @@ -174,6 +185,51 @@ void addAndSearch() { }); } + @Test + void addAndSearchPoormansBench() { + // todo – replace with JMH (parameters: nThreads, rounds, runs, docsPerAdd) + int nThreads = CassandraVectorStoreConfig.DEFAULT_ADD_CONCURRENCY; + int runs = 10; // 100; + int docsPerAdd = 12; // 128; + int rounds = 3; + + contextRunner.run(context -> { + + try (CassandraVectorStore store = new CassandraVectorStore( + storeBuilder(context, List.of()).withFixedThreadPoolExecutorSize(nThreads).build(), + context.getBean(EmbeddingClient.class))) { + + var executor = Executors.newFixedThreadPool((int) (nThreads * 1.2)); + for (int k = 0; k < rounds; ++k) { + long start = System.nanoTime(); + var futures = new CompletableFuture[runs]; + for (int j = 0; j < runs; ++j) { + futures[j] = CompletableFuture.runAsync(() -> { + 
List documents = new ArrayList<>(); + for (int i = docsPerAdd; i >= 0; --i) { + + documents.add(new Document( + RandomStringUtils.randomAlphanumeric(4) + "§¶" + + ThreadLocalRandom.current().nextInt(1, 10), + RandomStringUtils.randomAlphanumeric(1024), Map.of("revision", + ThreadLocalRandom.current().nextInt(1, 100000), "id", 1000))); + } + store.add(documents); + + var results = store.similaritySearch( + SearchRequest.query(RandomStringUtils.randomAlphanumeric(20)).withTopK(10)); + + assertThat(results).hasSize(10); + }, executor); + } + CompletableFuture.allOf(futures).join(); + long time = System.nanoTime() - start; + logger.info("add+search took an average of {} ms", Duration.ofNanos(time / runs).toMillis()); + } + } + }); + } + @Test void searchWithPartitionFilter() throws InterruptedException { contextRunner.run(context -> { @@ -456,22 +512,37 @@ private StoreWrapper createSto } private StoreWrapper createStore(ApplicationContext context, - List extraMetadataFields, boolean disallowSchemaCreation, boolean dropKeyspaceFirst) + List columnOverrides, boolean disallowSchemaCreation, boolean dropKeyspaceFirst) throws IOException { - Optional wikiOverride = extraMetadataFields.stream() + CassandraVectorStoreConfig.Builder builder = storeBuilder(context, columnOverrides); + if (disallowSchemaCreation) { + builder = builder.disallowSchemaChanges(); + } + + CassandraVectorStoreConfig conf = builder.build(); + if (dropKeyspaceFirst) { + conf.dropKeyspace(); + } + return new StoreWrapper(new CassandraVectorStore(conf, context.getBean(EmbeddingClient.class)), conf); + } + + static CassandraVectorStoreConfig.Builder storeBuilder(ApplicationContext context, + List columnOverrides) throws IOException { + + Optional wikiOverride = columnOverrides.stream() .filter((f) -> "wiki".equals(f.name())) .findFirst(); - Optional langOverride = extraMetadataFields.stream() + Optional langOverride = columnOverrides.stream() .filter((f) -> "language".equals(f.name())) .findFirst(); - 
Optional titleOverride = extraMetadataFields.stream() + Optional titleOverride = columnOverrides.stream() .filter((f) -> "title".equals(f.name())) .findFirst(); - Optional chunkNoOverride = extraMetadataFields.stream() + Optional chunkNoOverride = columnOverrides.stream() .filter((f) -> "chunk_no".equals(f.name())) .findFirst(); @@ -493,7 +564,7 @@ private StoreWrapper createSto .withEmbeddingColumnName("all_minilm_l6_v2_embedding") .withIndexName("all_minilm_l6_v2_ann") - .addMetadataColumn(new SchemaColumn("revision", DataTypes.INT), + .addMetadataColumns(new SchemaColumn("revision", DataTypes.INT), new SchemaColumn("id", DataTypes.INT, CassandraVectorStoreConfig.SchemaColumnTags.INDEXED)) // this store uses '§¶' as a deliminator in the document id between db columns @@ -511,21 +582,7 @@ private StoreWrapper createSto return List.of("simplewiki", "en", title, chunk_no); }); - for (SchemaColumn cf : extraMetadataFields) { - if (!partitionKeys.contains(cf) && !clusteringKeys.contains(cf)) { - builder = builder.addMetadataColumn(cf); - } - } - - if (disallowSchemaCreation) { - builder = builder.disallowSchemaChanges(); - } - - CassandraVectorStoreConfig conf = builder.build(); - if (dropKeyspaceFirst) { - conf.dropKeyspace(); - } - return new StoreWrapper(new CassandraVectorStore(conf, context.getBean(EmbeddingClient.class)), conf); + return builder; } private void executeCqlFile(ApplicationContext context, String filename) throws IOException { diff --git a/vector-stores/spring-ai-cassandra/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java b/vector-stores/spring-ai-cassandra/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java index 0984a067ff0..d1fac3901b3 100644 --- a/vector-stores/spring-ai-cassandra/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java +++ b/vector-stores/spring-ai-cassandra/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java @@ -345,8 +345,9 @@ 
public static class TestApplication { public CassandraVectorStore store(CqlSession cqlSession, EmbeddingClient embeddingClient) { CassandraVectorStoreConfig conf = storeBuilder(cqlSession) - .addMetadataColumn(new SchemaColumn("meta1", DataTypes.TEXT), new SchemaColumn("meta2", DataTypes.TEXT), - new SchemaColumn("country", DataTypes.TEXT), new SchemaColumn("year", DataTypes.SMALLINT)) + .addMetadataColumns(new SchemaColumn("meta1", DataTypes.TEXT), + new SchemaColumn("meta2", DataTypes.TEXT), new SchemaColumn("country", DataTypes.TEXT), + new SchemaColumn("year", DataTypes.SMALLINT)) .build(); conf.dropKeyspace(); @@ -378,7 +379,7 @@ static CassandraVectorStoreConfig.Builder storeBuilder(CqlSession cqlSession) { private CassandraVectorStore createTestStore(ApplicationContext context, SchemaColumn... metadataFields) { CassandraVectorStoreConfig.Builder builder = storeBuilder(context.getBean(CqlSession.class)) - .addMetadataColumn(metadataFields); + .addMetadataColumns(metadataFields); CassandraVectorStoreConfig conf = builder.build(); conf.dropKeyspace(); diff --git a/vector-stores/spring-ai-cassandra/src/test/java/org/springframework/ai/vectorstore/WikiVectorStoreExample.java b/vector-stores/spring-ai-cassandra/src/test/java/org/springframework/ai/vectorstore/WikiVectorStoreExample.java index 76028e6502c..910dac85a4f 100644 --- a/vector-stores/spring-ai-cassandra/src/test/java/org/springframework/ai/vectorstore/WikiVectorStoreExample.java +++ b/vector-stores/spring-ai-cassandra/src/test/java/org/springframework/ai/vectorstore/WikiVectorStoreExample.java @@ -81,25 +81,30 @@ public static class TestApplication { @Bean public CassandraVectorStore store(CqlSession cqlSession, EmbeddingClient embeddingClient) { + List partitionColumns = List.of(new SchemaColumn("wiki", DataTypes.TEXT), + new SchemaColumn("language", DataTypes.TEXT), new SchemaColumn("title", DataTypes.TEXT)); + + List clusteringColumns = List.of(new SchemaColumn("chunk_no", DataTypes.INT), + new 
SchemaColumn("bert_embedding_no", DataTypes.INT)); + + List extraColumns = List.of(new SchemaColumn("revision", DataTypes.INT), + new SchemaColumn("id", DataTypes.INT)); + CassandraVectorStoreConfig conf = CassandraVectorStoreConfig.builder() .withCqlSession(cqlSession) .withKeyspaceName("wikidata") .withTableName("articles") - - .withPartitionKeys(List.of(new SchemaColumn("wiki", DataTypes.TEXT), - new SchemaColumn("language", DataTypes.TEXT), new SchemaColumn("title", DataTypes.TEXT))) - - .withClusteringKeys(List.of(new SchemaColumn("chunk_no", DataTypes.INT), - new SchemaColumn("bert_embedding_no", DataTypes.INT))) - + .withPartitionKeys(partitionColumns) + .withClusteringKeys(clusteringColumns) .withContentColumnName("body") .withEmbeddingColumnName("all_minilm_l6_v2_embedding") .withIndexName("all_minilm_l6_v2_ann") .disallowSchemaChanges() - - .addMetadataColumn(new SchemaColumn("revision", DataTypes.INT), new SchemaColumn("id", DataTypes.INT)) + .addMetadataColumns(extraColumns) .withPrimaryKeyTranslator((List primaryKeys) -> { + // the deliminator used to join fields together into the document's id + // is arbitary, here "§¶" is used if (primaryKeys.isEmpty()) { return "test§¶0"; } diff --git a/vector-stores/spring-ai-cassandra/src/test/resources/test_wiki_partial_4_schema.cql b/vector-stores/spring-ai-cassandra/src/test/resources/test_wiki_partial_4_schema.cql new file mode 100644 index 00000000000..68b4583c491 --- /dev/null +++ b/vector-stores/spring-ai-cassandra/src/test/resources/test_wiki_partial_4_schema.cql @@ -0,0 +1,10 @@ +CREATE KEYSPACE IF NOT EXISTS test_wikidata WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}; + +CREATE TABLE IF NOT EXISTS test_wikidata.articles ( + wiki text, + language text, + title text, + chunk_no int, + messages text, + PRIMARY KEY ((wiki, language, title), chunk_no) +); \ No newline at end of file diff --git a/vector-stores/spring-ai-opensearch-store/pom.xml 
b/vector-stores/spring-ai-opensearch-store/pom.xml new file mode 100644 index 00000000000..4c11603369d --- /dev/null +++ b/vector-stores/spring-ai-opensearch-store/pom.xml @@ -0,0 +1,85 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-opensearch-store + jar + Spring AI Vector Store - OpenSearch + Spring AI OpenSearch Vector Store + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + 4.0.3 + + + + + org.springframework.ai + spring-ai-core + ${parent.version} + + + + org.opensearch.client + opensearch-java + ${opensearch-client.version} + + + + org.apache.httpcomponents.client5 + httpclient5 + ${httpclient5.version} + + + + + org.springframework.ai + spring-ai-openai + ${parent.version} + test + + + + + org.springframework.ai + spring-ai-test + ${parent.version} + test + + + + org.springframework.boot + spring-boot-starter-test + test + + + + org.opensearch + opensearch-testcontainers + 2.0.1 + test + + + + org.testcontainers + junit-jupiter + ${testcontainers.version} + test + + + + + diff --git a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverter.java b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverter.java new file mode 100644 index 00000000000..9035a86d299 --- /dev/null +++ b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverter.java @@ -0,0 +1,150 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.vectorstore; + +import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.ai.vectorstore.filter.Filter.Expression; +import org.springframework.ai.vectorstore.filter.Filter.Key; +import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter; + +import java.text.ParseException; +import java.text.SimpleDateFormat; +import java.util.Date; +import java.util.List; +import java.util.TimeZone; +import java.util.regex.Pattern; + +/** + * @author Jemin Huh + * @since 1.0.0 + */ +public class OpenSearchAiSearchFilterExpressionConverter extends AbstractFilterExpressionConverter { + + private static final Pattern DATE_FORMAT_PATTERN = Pattern.compile("\\d{4}-\\d{2}-\\d{2}T\\d{2}:\\d{2}:\\d{2}Z"); + + private final SimpleDateFormat dateFormat; + + public OpenSearchAiSearchFilterExpressionConverter() { + this.dateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss'Z'"); + this.dateFormat.setTimeZone(TimeZone.getTimeZone("UTC")); + } + + @Override + protected void doExpression(Expression expression, StringBuilder context) { + if (expression.type() == Filter.ExpressionType.IN || expression.type() == Filter.ExpressionType.NIN) { + context.append(getOperationSymbol(expression)); + context.append("("); + this.convertOperand(expression.left(), context); + this.convertOperand(expression.right(), context); + context.append(")"); + } + else { + this.convertOperand(expression.left(), context); + context.append(getOperationSymbol(expression)); + 
this.convertOperand(expression.right(), context); + } + } + + @Override + protected void doStartValueRange(Filter.Value listValue, StringBuilder context) { + } + + @Override + protected void doEndValueRange(Filter.Value listValue, StringBuilder context) { + } + + @Override + protected void doAddValueRangeSpitter(Filter.Value listValue, StringBuilder context) { + context.append(" OR "); + } + + private String getOperationSymbol(Expression exp) { + return switch (exp.type()) { + case AND -> " AND "; + case OR -> " OR "; + case EQ, IN -> ""; + case NE -> " NOT "; + case LT -> "<"; + case LTE -> "<="; + case GT -> ">"; + case GTE -> ">="; + case NIN -> "NOT "; + default -> throw new RuntimeException("Not supported expression type: " + exp.type()); + }; + } + + @Override + public void doKey(Key key, StringBuilder context) { + var identifier = hasOuterQuotes(key.key()) ? removeOuterQuotes(key.key()) : key.key(); + var prefixedIdentifier = withMetaPrefix(identifier); + context.append(prefixedIdentifier.trim()).append(":"); + } + + public String withMetaPrefix(String identifier) { + return "metadata." 
+ identifier; + } + + @Override + protected void doValue(Filter.Value filterValue, StringBuilder context) { + if (filterValue.value() instanceof List list) { + int c = 0; + for (Object v : list) { + context.append(v); + if (c++ < list.size() - 1) { + this.doAddValueRangeSpitter(filterValue, context); + } + } + } + else { + this.doSingleValue(filterValue.value(), context); + } + } + + @Override + protected void doSingleValue(Object value, StringBuilder context) { + if (value instanceof Date date) { + context.append(this.dateFormat.format(date)); + } + else if (value instanceof String text) { + if (DATE_FORMAT_PATTERN.matcher(text).matches()) { + try { + Date date = this.dateFormat.parse(text); + context.append(this.dateFormat.format(date)); + } + catch (ParseException e) { + throw new IllegalArgumentException("Invalid date type:" + text, e); + } + } + else { + context.append(text); + } + } + else { + context.append(value); + } + } + + @Override + public void doStartGroup(Filter.Group group, StringBuilder context) { + context.append("("); + } + + @Override + public void doEndGroup(Filter.Group group, StringBuilder context) { + context.append(")"); + } + +} \ No newline at end of file diff --git a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java new file mode 100644 index 00000000000..608ce87cf24 --- /dev/null +++ b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java @@ -0,0 +1,225 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.vectorstore; + + +import jakarta.json.stream.JsonParser; +import org.opensearch.client.json.JsonData; +import org.opensearch.client.json.JsonpMapper; +import org.opensearch.client.opensearch.OpenSearchClient; +import org.opensearch.client.opensearch._types.mapping.TypeMapping; +import org.opensearch.client.opensearch._types.query_dsl.Query; +import org.opensearch.client.opensearch.core.BulkRequest; +import org.opensearch.client.opensearch.core.BulkResponse; +import org.opensearch.client.opensearch.core.search.Hit; +import org.opensearch.client.opensearch.indices.CreateIndexRequest; +import org.opensearch.client.opensearch.indices.CreateIndexResponse; +import org.opensearch.client.transport.endpoints.BooleanResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingClient; +import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.util.Assert; + +import java.io.IOException; +import java.io.StringReader; +import java.util.*; +import java.util.stream.Collectors; + +/** + * @author Jemin Huh + * @since 1.0.0 + */ +public class OpenSearchVectorStore implements VectorStore, InitializingBean { + + public static final String COSINE_SIMILARITY_FUNCTION = "cosinesimil"; + + private static final Logger logger = 
LoggerFactory.getLogger(OpenSearchVectorStore.class); + + public static final String DEFAULT_INDEX_NAME = "spring-ai-document-index"; + public static final String DEFAULT_MAPPING_EMBEDDING_TYPE_KNN_VECTOR_DIMENSION_1536 = """ + { + "properties":{ + "embedding":{ + "type":"knn_vector", + "dimension":1536 + } + } + } + """; + + private final EmbeddingClient embeddingClient; + + private final OpenSearchClient openSearchClient; + + private final String index; + + private final FilterExpressionConverter filterExpressionConverter; + + private final String mappingJson; + + private String similarityFunction; + + public OpenSearchVectorStore(OpenSearchClient openSearchClient, EmbeddingClient embeddingClient) { + this(openSearchClient, embeddingClient, DEFAULT_MAPPING_EMBEDDING_TYPE_KNN_VECTOR_DIMENSION_1536); + } + + public OpenSearchVectorStore(OpenSearchClient openSearchClient, EmbeddingClient embeddingClient, + String mappingJson) { + this(DEFAULT_INDEX_NAME, openSearchClient, embeddingClient, mappingJson); + } + + public OpenSearchVectorStore(String index, OpenSearchClient openSearchClient, + EmbeddingClient embeddingClient, String mappingJson) { + Objects.requireNonNull(embeddingClient, "RestClient must not be null"); + Objects.requireNonNull(embeddingClient, "EmbeddingClient must not be null"); + this.openSearchClient = openSearchClient; + this.embeddingClient = embeddingClient; + this.index = index; + this.mappingJson = mappingJson; + this.filterExpressionConverter = new OpenSearchAiSearchFilterExpressionConverter(); + // the potential functions for vector fields at + // https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces + this.similarityFunction = COSINE_SIMILARITY_FUNCTION; + } + + public OpenSearchVectorStore withSimilarityFunction(String similarityFunction) { + this.similarityFunction = similarityFunction; + return this; + } + + @Override + public void add(List documents) { + BulkRequest.Builder builkRequestBuilder = new 
BulkRequest.Builder();
+		for (Document document : documents) {
+			if (Objects.isNull(document.getEmbedding()) || document.getEmbedding().isEmpty()) {
+				logger.debug("Calling EmbeddingClient for document id = " + document.getId());
+				document.setEmbedding(this.embeddingClient.embed(document));
+			}
+			builkRequestBuilder
+				.operations(op -> op.index(idx -> idx.index(this.index).id(document.getId()).document(document)));
+		}
+		bulkRequest(builkRequestBuilder.build());
+	}
+
+	/**
+	 * Deletes the documents with the given ids from this store's index using a single
+	 * bulk request that carries one delete operation per id.
+	 * @param idList ids of the documents to delete
+	 * @return an {@link Optional} wrapping the {@code errors()} flag of the bulk response
+	 */
+	@Override
+	public Optional delete(List idList) {
+		BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder();
+		for (String id : idList) {
+			bulkRequestBuilder.operations(op -> op.delete(idx -> idx.index(this.index).id(id)));
+		}
+		// NOTE(review): errors() presumably is true when at least one bulk item failed, which
+		// would make this return true on failure - confirm against the VectorStore.delete contract.
+		return Optional.of(bulkRequest(bulkRequestBuilder.build()).errors());
+	}
+
+	/**
+	 * Executes the given bulk request through the OpenSearch client.
+	 * @param bulkRequest the bulk request to send
+	 * @return the bulk response returned by the client
+	 * @throws RuntimeException wrapping the {@link IOException} raised by the client call
+	 */
+	private BulkResponse bulkRequest(BulkRequest bulkRequest) {
+		try {
+			return this.openSearchClient.bulk(bulkRequest);
+		} catch (IOException e) {
+			throw new RuntimeException(e);
+		}
+	}
+
+	/**
+	 * Embeds the query text of the request with the embedding client and delegates to the
+	 * embedding-based search, passing on top-k, similarity threshold and filter expression.
+	 * @param searchRequest the search request; rejected by {@code Assert.notNull} when null
+	 * @return the documents produced by the embedding-based search
+	 */
+	@Override
+	public List similaritySearch(SearchRequest searchRequest) {
+		Assert.notNull(searchRequest, "The search request must not be null.");
+		return similaritySearch(this.embeddingClient.embed(searchRequest.getQuery()), searchRequest.getTopK(),
+				searchRequest.getSimilarityThreshold(), searchRequest.getFilterExpression());
+	}
+
+	/**
+	 * Runs a similarity search for an already computed embedding.
+	 * @param embedding the query embedding
+	 * @param topK used as the {@code size} of the OpenSearch search request
+	 * @param similarityThreshold used as the {@code minScore} of the OpenSearch search request
+	 * @param filterExpression optional metadata filter; may be null
+	 * @return the matching documents
+	 */
+	public List similaritySearch(List embedding, int topK, double similarityThreshold,
+			Filter.Expression filterExpression) {
+		return similaritySearch(new org.opensearch.client.opensearch.core.SearchRequest.Builder()
+			.query(getOpenSearchSimilarityQuery(embedding, filterExpression))
+			.size(topK)
+			.minScore(similarityThreshold)
+			.build());
+	}
+
+	/**
+	 * Builds the script-score query used for similarity search: a query-string query derived
+	 * from the filter expression selects the candidates, and an inline {@code knn_score}
+	 * script (lang {@code knn}) scores them against the {@code embedding} field using the
+	 * configured similarity function as {@code space_type}.
+	 * @param embedding the query embedding passed to the script as {@code query_value}
+	 * @param filterExpression optional metadata filter; may be null
+	 * @return the script-score query
+	 */
+	private Query getOpenSearchSimilarityQuery(List embedding, Filter.Expression filterExpression) {
+		return Query.of(queryBuilder -> queryBuilder.scriptScore(scriptScoreQueryBuilder -> {
+			scriptScoreQueryBuilder.query(
+					queryBuilder2 -> queryBuilder2.queryString(queryStringQuerybuilder -> queryStringQuerybuilder
+						.query(getOpenSearchQueryString(filterExpression))))
+				.script(scriptBuilder -> scriptBuilder
+					.inline(inlineScriptBuilder -> inlineScriptBuilder.source("knn_score")
+						.lang("knn")
+						.params("field", JsonData.of("embedding"))
+						.params("query_value", JsonData.of(embedding))
+						.params("space_type", JsonData.of(this.similarityFunction))));
+			// https://opensearch.org/docs/latest/search-plugins/knn/knn-score-script
+			// k-NN ensures non-negative scores by adding 1 to cosine similarity, extending OpenSearch scores to 0-2.
+			// A 0.5 boost normalizes to 0-1.
+			return this.similarityFunction.equals(COSINE_SIMILARITY_FUNCTION) ? scriptScoreQueryBuilder.boost(
+					0.5f) : scriptScoreQueryBuilder;
+		}));
+	}
+
+	/**
+	 * Converts the filter expression into a query string.
+	 * @param filterExpression the filter expression; may be null
+	 * @return {@code "*"} when no filter is given, otherwise the converter's output
+	 */
+	private String getOpenSearchQueryString(Filter.Expression filterExpression) {
+		return Objects.isNull(filterExpression) ? "*"
+				: this.filterExpressionConverter.convertExpression(filterExpression);
+
+	}
+
+	/**
+	 * Executes the given OpenSearch search request and maps every hit to a document.
+	 * @param searchRequest the native OpenSearch search request
+	 * @return the documents built from the hits
+	 * @throws RuntimeException wrapping the {@link IOException} raised by the client call
+	 */
+	private List similaritySearch(org.opensearch.client.opensearch.core.SearchRequest searchRequest) {
+		try {
+			return this.openSearchClient.search(searchRequest, Document.class)
+				.hits()
+				.hits()
+				.stream()
+				.map(this::toDocument)
+				.collect(Collectors.toList());
+		} catch (IOException e) {
+			throw new RuntimeException(e);
+		}
+	}
+
+	/**
+	 * Returns the hit's source document after storing {@code 1 - score} under the
+	 * {@code "distance"} metadata key.
+	 * @param hit the search hit
+	 * @return the source document of the hit
+	 */
+	private Document toDocument(Hit hit) {
+		Document document = hit.source();
+		// NOTE(review): score() is unboxed here - confirm it is always present for these hits.
+		document.getMetadata().put("distance", 1 - hit.score().floatValue());
+		return document;
+	}
+
+	/**
+	 * Checks whether the given index exists.
+	 * @param targetIndex name of the index to look up
+	 * @return the value of the client's index-exists response
+	 * @throws RuntimeException wrapping the {@link IOException} raised by the client call
+	 */
+	public boolean exists(String targetIndex) {
+		try {
+			BooleanResponse response = this.openSearchClient.indices()
+				.exists(existRequestBuilder -> existRequestBuilder.index(targetIndex));
+			return response.value();
+		} catch (IOException e) {
+			throw new RuntimeException(e);
+		}
+	}
+
+	/**
+	 * Creates the index with the {@code knn} setting enabled and the type mapping
+	 * deserialized from the given JSON, using the JSON mapper of the client's transport.
+	 * @param index name of the index to create
+	 * @param mappingJson JSON text of the type mapping
+	 * @return the create-index response
+	 * @throws RuntimeException wrapping the {@link IOException} raised by the client call
+	 */
+	private CreateIndexResponse createIndexMapping(String index, String mappingJson) {
+		JsonpMapper mapper = openSearchClient._transport().jsonpMapper();
+		JsonParser parser = mapper.jsonProvider().createParser(new StringReader(mappingJson));
+		try {
+			return this.openSearchClient.indices().create(new CreateIndexRequest.Builder().index(index)
+				.settings(settingsBuilder -> settingsBuilder.knn(true))
+				.mappings(TypeMapping._DESERIALIZER.deserialize(parser, mapper)).build());
+		} catch (IOException e) {
+			throw new RuntimeException(e);
+		}
+	}
+
+	/**
+	 * Creates the configured index with the configured mapping when it does not exist yet.
+	 */
+	@Override
+	public void afterPropertiesSet() {
+		if (!exists(this.index)) {
+			createIndexMapping(this.index, mappingJson);
+		}
+	}
+}
\ No newline at end of file
diff --git a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverterTest.java b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverterTest.java
new file mode 100644
index 00000000000..274f132730e
--- /dev/null
+++ b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverterTest.java
@@ -0,0 +1,117 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.vectorstore;
+
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.vectorstore.filter.Filter;
+import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
+
+import java.util.Date;
+import java.util.List;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.*;
+
+/**
+ * Tests for {@link OpenSearchAiSearchFilterExpressionConverter}: each test builds a
+ * {@link Filter.Expression} and asserts the exact query string the converter produces.
+ */
+class OpenSearchAiSearchFilterExpressionConverterTest {
+
+	final FilterExpressionConverter converter = new OpenSearchAiSearchFilterExpressionConverter();
+
+	/** A {@link Date} value and an ISO date-time string value are both emitted as ISO date-time text. */
+	@Test
+	public void testDate() {
+		String vectorExpr = converter.convertExpression(new Filter.Expression(EQ, new Filter.Key("activationDate"),
+				new Filter.Value(new Date(1704637752148L))));
+		assertThat(vectorExpr).isEqualTo("metadata.activationDate:2024-01-07T14:29:12Z");
+
+		vectorExpr = converter.convertExpression(
+				new Filter.Expression(EQ, new Filter.Key("activationDate"), new Filter.Value("1970-01-01T00:00:02Z")));
+		assertThat(vectorExpr).isEqualTo("metadata.activationDate:1970-01-01T00:00:02Z");
+	}
+
+	/** An EQ expression becomes {@code metadata.<key>:<value>}. */
+	@Test
+	public void testEQ() {
+		String vectorExpr = converter
+			.convertExpression(new Filter.Expression(EQ, new Filter.Key("country"), new Filter.Value("BG")));
+		assertThat(vectorExpr).isEqualTo("metadata.country:BG");
+	}
+
+	/** EQ and GTE operands joined by AND. */
+	@Test
+	public void testEqAndGte() {
+		String vectorExpr = converter.convertExpression(new Filter.Expression(AND,
+				new Filter.Expression(EQ, new Filter.Key("genre"), new Filter.Value("drama")),
+				new Filter.Expression(GTE, new Filter.Key("year"), new Filter.Value(2020))));
+		assertThat(vectorExpr).isEqualTo("metadata.genre:drama AND metadata.year:>=2020");
+	}
+
+	/** An IN expression becomes a parenthesised OR list of the values. */
+	@Test
+	public void testIn() {
+		String vectorExpr = converter.convertExpression(new Filter.Expression(IN, new Filter.Key("genre"),
+				new Filter.Value(List.of("comedy", "documentary", "drama"))));
+		assertThat(vectorExpr).isEqualTo("(metadata.genre:comedy OR documentary OR drama)");
+	}
+
+	/** An NE operand nested inside an OR/AND combination. */
+	@Test
+	public void testNe() {
+		String vectorExpr = converter.convertExpression(
+				new Filter.Expression(OR, new Filter.Expression(GTE, new Filter.Key("year"), new Filter.Value(2020)),
+						new Filter.Expression(AND,
+								new Filter.Expression(EQ, new Filter.Key("country"), new Filter.Value("BG")),
+								new Filter.Expression(NE, new Filter.Key("city"), new Filter.Value("Sofia")))));
+		assertThat(vectorExpr).isEqualTo("metadata.year:>=2020 OR metadata.country:BG AND metadata.city: NOT Sofia");
+	}
+
+	/** A {@link Filter.Group} is parenthesised and NIN becomes {@code NOT (...)}. */
+	@Test
+	public void testGroup() {
+		String vectorExpr = converter.convertExpression(new Filter.Expression(AND,
+				new Filter.Group(new Filter.Expression(OR,
+						new Filter.Expression(GTE, new Filter.Key("year"), new Filter.Value(2020)),
+						new Filter.Expression(EQ, new Filter.Key("country"), new Filter.Value("BG")))),
+				new Filter.Expression(NIN, new Filter.Key("city"), new Filter.Value(List.of("Sofia", "Plovdiv")))));
+		assertThat(vectorExpr)
+			.isEqualTo("(metadata.year:>=2020 OR metadata.country:BG) AND NOT (metadata.city:Sofia OR Plovdiv)");
+	}
+
+	/** A boolean value combined with numeric and IN operands. */
+	@Test
+	public void testBoolean() {
+		String vectorExpr = converter.convertExpression(new Filter.Expression(AND,
+				new Filter.Expression(AND, new Filter.Expression(EQ, new Filter.Key("isOpen"), new Filter.Value(true)),
+						new Filter.Expression(GTE, new Filter.Key("year"), new Filter.Value(2020))),
+				new Filter.Expression(IN, new Filter.Key("country"), new Filter.Value(List.of("BG", "NL", "US")))));
+
+		assertThat(vectorExpr)
+			.isEqualTo("metadata.isOpen:true AND metadata.year:>=2020 AND (metadata.country:BG OR NL OR US)");
+	}
+
+	/** Negative and fractional numeric values in range comparisons. */
+	@Test
+	public void testDecimal() {
+		String vectorExpr = converter.convertExpression(new Filter.Expression(AND,
+				new Filter.Expression(GTE, new Filter.Key("temperature"), new Filter.Value(-15.6)),
+				new Filter.Expression(LTE, new Filter.Key("temperature"), new Filter.Value(20.13))));
+
+		assertThat(vectorExpr).isEqualTo("metadata.temperature:>=-15.6 AND metadata.temperature:<=20.13");
+	}
+
+	/** Keys wrapped in double or single quotes are emitted without the surrounding quotes. */
+	@Test
+	public void testComplexIdentifiers() {
+		String vectorExpr = converter
+			.convertExpression(new Filter.Expression(EQ, new Filter.Key("\"country 1 2 3\""), new Filter.Value("BG")));
+		assertThat(vectorExpr).isEqualTo("metadata.country 1 2 3:BG");
+
+		vectorExpr = converter
+			.convertExpression(new Filter.Expression(EQ, new Filter.Key("'country 1 2 3'"), new Filter.Value("BG")));
+		assertThat(vectorExpr).isEqualTo("metadata.country 1 2 3:BG");
+	}
+
+}
diff --git a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java
new file mode 100644
index 00000000000..5632ea16454
--- /dev/null
+++ b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java
@@ -0,0 +1,364 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.vectorstore;
+
+import org.apache.hc.core5.http.HttpHost;
+import org.awaitility.Awaitility;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+import org.opensearch.client.opensearch.OpenSearchClient;
+import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder;
+import org.opensearch.testcontainers.OpensearchContainer;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.EmbeddingClient;
+import org.springframework.ai.openai.OpenAiEmbeddingClient;
+import org.springframework.ai.openai.api.OpenAiApi;
+import org.springframework.boot.SpringBootConfiguration;
+import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
+import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
+import org.springframework.boot.test.context.runner.ApplicationContextRunner;
+import org.springframework.context.annotation.Bean;
+import org.springframework.core.io.DefaultResourceLoader;
+import org.testcontainers.junit.jupiter.Container;
+import org.testcontainers.junit.jupiter.Testcontainers;
+import org.testcontainers.utility.DockerImageName;
+
+import java.io.IOException;
+import java.net.URISyntaxException;
+import java.nio.charset.StandardCharsets;
+import java.time.Duration;
+import java.time.ZonedDateTime;
+import java.util.Date;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.TimeUnit;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+
+/**
+ * Integration tests for {@link OpenSearchVectorStore} running against an OpenSearch
+ * Testcontainers instance, with embeddings produced by {@link OpenAiEmbeddingClient}.
+ * The class is enabled only when the {@code OPENAI_API_KEY} environment variable is set.
+ * Every test is parameterized over the similarity functions {@code cosinesimil},
+ * {@code l1}, {@code l2} and {@code linf}.
+ */
+@Testcontainers
+@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
+class OpenSearchVectorStoreIT {
+
+	@Container
+	private static final OpensearchContainer opensearchContainer =
+			new OpensearchContainer<>(DockerImageName.parse("opensearchproject/opensearch:2.13.0"));
+
+	// Similarity function for which the tests leave the store's configuration untouched.
+	private static final String DEFAULT = "cosinesimil";
+
+	// Shared fixture: three documents with ids "1", "2", "3" loaded from classpath text files.
+	private List documents = List.of(
+			new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")),
+			new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()),
+			new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2")));
+
+	/**
+	 * Configures the Awaitility defaults used by all tests: poll every 2 seconds,
+	 * no initial delay, give up after 1 minute.
+	 */
+	@BeforeAll
+	public static void beforeAll() {
+		Awaitility.setDefaultPollInterval(2, TimeUnit.SECONDS);
+		Awaitility.setDefaultPollDelay(Duration.ZERO);
+		Awaitility.setDefaultTimeout(Duration.ofMinutes(1));
+	}
+
+	/**
+	 * Reads the resource at the given location as UTF-8 text.
+	 * @param uri resource location, e.g. a {@code classpath:} URI
+	 * @return the resource content
+	 * @throws RuntimeException wrapping the {@link IOException} raised while reading
+	 */
+	private String getText(String uri) {
+		var resource = new DefaultResourceLoader().getResource(uri);
+		try {
+			return resource.getContentAsString(StandardCharsets.UTF_8);
+		} catch (IOException e) {
+			throw new RuntimeException(e);
+		}
+	}
+
+	/** Creates a context runner configured with {@link TestApplication}. */
+	private ApplicationContextRunner getContextRunner() {
+		return new ApplicationContextRunner().withUserConfiguration(TestApplication.class);
+	}
+
+	/**
+	 * Runs before each test and asks the store to delete the id {@code "_all"}.
+	 */
+	@BeforeEach
+	void cleanDatabase() {
+		getContextRunner().run(context -> {
+			VectorStore vectorStore = context.getBean(VectorStore.class);
+			// NOTE(review): "_all" is passed as a plain document id; whether this clears the
+			// whole index is not evident from the store's delete() - confirm the intent.
+			vectorStore.delete(List.of("_all"));
+		});
+	}
+
+	/**
+	 * Adds the shared fixture, waits until a top-1 search for "Great Depression" returns a
+	 * hit, checks that the hit is the third fixture document with {@code meta2} and
+	 * {@code distance} metadata, then deletes the fixture and waits for an empty result.
+	 * @param similarityFunction the similarity function under test
+	 */
+	@ParameterizedTest(name = "{0} : {displayName} ")
+	@ValueSource(strings = {DEFAULT, "l1", "l2", "linf"})
+	public void addAndSearchTest(String similarityFunction) {
+
+		getContextRunner().run(context -> {
+			OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class);
+
+			if (!DEFAULT.equals(similarityFunction)) {
+				vectorStore.withSimilarityFunction(similarityFunction);
+			}
+
+			vectorStore.add(documents);
+
+			Awaitility.await()
+				.until(() -> vectorStore
+					.similaritySearch(
+							SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)),
+						hasSize(1));
+
+			List results = vectorStore
+				.similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0));
+
+			assertThat(results).hasSize(1);
+			Document resultDoc = results.get(0);
+			assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId());
+			assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock");
+			assertThat(resultDoc.getMetadata()).hasSize(2);
+			assertThat(resultDoc.getMetadata()).containsKey("meta2");
+			assertThat(resultDoc.getMetadata()).containsKey("distance");
+
+			// Remove all documents from the store
+			vectorStore.delete(documents.stream().map(Document::getId).toList());
+
+			Awaitility.await()
+				.until(() -> vectorStore
+					.similaritySearch(
+							SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)),
+						hasSize(0));
+		});
+	}
+
+	/**
+	 * Adds three documents with {@code country}, {@code year} and {@code activationDate}
+	 * metadata and checks the results of equality, conjunction, {@code in},
+	 * {@code not in}, negation and date-comparison filter expressions.
+	 * @param similarityFunction the similarity function under test
+	 */
+	@ParameterizedTest(name = "{0} : {displayName} ")
+	@ValueSource(strings = {DEFAULT, "l1", "l2", "linf"})
+	public void searchWithFilters(String similarityFunction) {
+
+		getContextRunner().run(context -> {
+			OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class);
+
+			if (!DEFAULT.equals(similarityFunction)) {
+				vectorStore.withSimilarityFunction(similarityFunction);
+			}
+
+			var bgDocument = new Document("1", "The World is Big and Salvation Lurks Around the Corner",
+					Map.of("country", "BG", "year", 2020, "activationDate", new Date(1000)));
+			var nlDocument = new Document("2", "The World is Big and Salvation Lurks Around the Corner",
+					Map.of("country", "NL", "activationDate", new Date(2000)));
+			var bgDocument2 = new Document("3", "The World is Big and Salvation Lurks Around the Corner",
+					Map.of("country", "BG", "year", 2023, "activationDate", new Date(3000)));
+
+			vectorStore.add(List.of(bgDocument, nlDocument, bgDocument2));
+
+			Awaitility.await()
+				.until(() -> vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(5)),
+						hasSize(3));
+
+			List results = vectorStore.similaritySearch(SearchRequest.query("The World")
+				.withTopK(5)
+				.withSimilarityThresholdAll()
+				.withFilterExpression("country == 'NL'"));
+
+			assertThat(results).hasSize(1);
+			assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId());
+
+			results = vectorStore.similaritySearch(SearchRequest.query("The World")
+				.withTopK(5)
+				.withSimilarityThresholdAll()
+				.withFilterExpression("country == 'BG'"));
+
+			assertThat(results).hasSize(2);
+			assertThat(results.get(0).getId()).isIn(bgDocument.getId(), bgDocument2.getId());
+			assertThat(results.get(1).getId()).isIn(bgDocument.getId(), bgDocument2.getId());
+
+			results = vectorStore.similaritySearch(SearchRequest.query("The World")
+				.withTopK(5)
+				.withSimilarityThresholdAll()
+				.withFilterExpression("country == 'BG' && year == 2020"));
+
+			assertThat(results).hasSize(1);
+			assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId());
+
+			results = vectorStore.similaritySearch(SearchRequest.query("The World")
+				.withTopK(5)
+				.withSimilarityThresholdAll()
+				.withFilterExpression("country in ['BG']"));
+
+			assertThat(results).hasSize(2);
+			assertThat(results.get(0).getId()).isIn(bgDocument.getId(), bgDocument2.getId());
+			assertThat(results.get(1).getId()).isIn(bgDocument.getId(), bgDocument2.getId());
+
+			results = vectorStore.similaritySearch(SearchRequest.query("The World")
+				.withTopK(5)
+				.withSimilarityThresholdAll()
+				.withFilterExpression("country in ['BG','NL']"));
+
+			assertThat(results).hasSize(3);
+
+			results = vectorStore.similaritySearch(SearchRequest.query("The World")
+				.withTopK(5)
+				.withSimilarityThresholdAll()
+				.withFilterExpression("country not in ['BG']"));
+
+			assertThat(results).hasSize(1);
+			assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId());
+
+			results = vectorStore.similaritySearch(SearchRequest.query("The World")
+				.withTopK(5)
+				.withSimilarityThresholdAll()
+				.withFilterExpression("NOT(country not in ['BG'])"));
+
+			assertThat(results).hasSize(2);
+			assertThat(results.get(0).getId()).isIn(bgDocument.getId(), bgDocument2.getId());
+			assertThat(results.get(1).getId()).isIn(bgDocument.getId(), bgDocument2.getId());
+
+			results = vectorStore.similaritySearch(SearchRequest.query("The World")
+				.withTopK(5)
+				.withSimilarityThresholdAll()
+				.withFilterExpression(
+						"activationDate > "
+								+ ZonedDateTime.parse("1970-01-01T00:00:02Z").toInstant().toEpochMilli()));
+
+			assertThat(results).hasSize(1);
+			assertThat(results.get(0).getId()).isEqualTo(bgDocument2.getId());
+
+			// Remove the documents this test added. The ids are taken from the test's own
+			// documents rather than from the shared fixture, whose ids only happen to match.
+			vectorStore.delete(List.of(bgDocument.getId(), nlDocument.getId(), bgDocument2.getId()));
+
+			Awaitility.await()
+				.until(() -> vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(1)),
+						hasSize(0));
+		});
+	}
+
+	/**
+	 * Adds a document, then adds a second document with the same id and checks that a
+	 * search returns a single document carrying the second document's content and metadata.
+	 * @param similarityFunction the similarity function under test
+	 */
+	@ParameterizedTest(name = "{0} : {displayName} ")
+	@ValueSource(strings = {DEFAULT, "l1", "l2", "linf"})
+	public void documentUpdateTest(String similarityFunction) {
+
+		getContextRunner().run(context -> {
+			OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class);
+			if (!DEFAULT.equals(similarityFunction)) {
+				vectorStore.withSimilarityFunction(similarityFunction);
+			}
+
+			Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!",
+					Map.of("meta1", "meta1"));
+			vectorStore.add(List.of(document));
+
+			Awaitility.await().until(() -> vectorStore.similaritySearch(
+					SearchRequest.query("Spring").withSimilarityThreshold(0).withTopK(5)), hasSize(1));
+
+			List results = vectorStore
+				.similaritySearch(SearchRequest.query("Spring").withSimilarityThreshold(0).withTopK(5));
+
+			assertThat(results).hasSize(1);
+			Document resultDoc = results.get(0);
+			assertThat(resultDoc.getId()).isEqualTo(document.getId());
+			assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
+			assertThat(resultDoc.getMetadata()).containsKey("meta1");
+			assertThat(resultDoc.getMetadata()).containsKey("distance");
+
+			Document sameIdDocument = new Document(document.getId(),
+					"The World is Big and Salvation Lurks Around the Corner", Map.of("meta2", "meta2"));
+
+			vectorStore.add(List.of(sameIdDocument));
+			SearchRequest fooBarSearchRequest = SearchRequest.query("FooBar").withTopK(5);
+
+			Awaitility.await()
+				.until(() -> vectorStore.similaritySearch(fooBarSearchRequest).get(0).getContent(),
+						equalTo("The World is Big and Salvation Lurks Around the Corner"));
+
+			results = vectorStore.similaritySearch(fooBarSearchRequest);
+
+			assertThat(results).hasSize(1);
+			resultDoc = results.get(0);
+			assertThat(resultDoc.getId()).isEqualTo(document.getId());
+			assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
+			assertThat(resultDoc.getMetadata()).containsKey("meta2");
+			assertThat(resultDoc.getMetadata()).containsKey("distance");
+
+			// Remove all documents from the store
+			vectorStore.delete(List.of(document.getId()));
+
+			Awaitility.await().until(() -> vectorStore.similaritySearch(fooBarSearchRequest), hasSize(0));
+
+		});
+	}
+
+	/**
+	 * Adds the shared fixture, reads the {@code distance} of all three results, and checks
+	 * that searching with a similarity threshold of one minus the midpoint of the first two
+	 * distances returns only the third fixture document.
+	 * @param similarityFunction the similarity function under test
+	 */
+	@ParameterizedTest(name = "{0} : {displayName} ")
+	@ValueSource(strings = {DEFAULT, "l1", "l2", "linf"})
+	public void searchThresholdTest(String similarityFunction) {
+
+		getContextRunner().run(context -> {
+			OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class);
+			if (!DEFAULT.equals(similarityFunction)) {
+				vectorStore.withSimilarityFunction(similarityFunction);
+			}
+
+			vectorStore.add(documents);
+
+			SearchRequest query = SearchRequest.query("Great Depression")
+				.withTopK(50)
+				.withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL);
+
+			Awaitility.await().until(() -> vectorStore.similaritySearch(query), hasSize(3));
+
+			List fullResult = vectorStore.similaritySearch(query);
+
+			List distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
+
+			assertThat(distances).hasSize(3);
+
+			float threshold = (distances.get(0) + distances.get(1)) / 2;
+
+			List results = vectorStore.similaritySearch(
+					SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(1 - threshold));
+
+			assertThat(results).hasSize(1);
+			Document resultDoc = results.get(0);
+			assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId());
+			assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock");
+			assertThat(resultDoc.getMetadata()).containsKey("meta2");
+			assertThat(resultDoc.getMetadata()).containsKey("distance");
+
+			// Remove all documents from the store
+			vectorStore.delete(documents.stream().map(Document::getId).toList());
+
+			Awaitility.await()
+				.until(() -> vectorStore
+					.similaritySearch(
+							SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(0)),
+						hasSize(0));
+		});
+	}
+
+	/**
+	 * Test configuration providing the vector store under test and its embedding client.
+	 */
+	@SpringBootConfiguration
+	@EnableAutoConfiguration(exclude = {DataSourceAutoConfiguration.class})
+	public static class TestApplication {
+
+		/** Builds the store on an OpenSearch client pointed at the container's HTTP host address. */
+		@Bean
+		public OpenSearchVectorStore vectorStore(EmbeddingClient embeddingClient) {
+			try {
+				return new OpenSearchVectorStore(new OpenSearchClient(ApacheHttpClient5TransportBuilder.builder(
+						HttpHost.create(opensearchContainer.getHttpHostAddress())).build()), embeddingClient);
+			} catch (URISyntaxException e) {
+				throw new RuntimeException(e);
+			}
+		}
+
+		/** Builds the OpenAI embedding client from the {@code OPENAI_API_KEY} environment variable. */
+		@Bean
+		public EmbeddingClient embeddingClient() {
+			return new OpenAiEmbeddingClient(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
+		}
+
+	}
+
+}
diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java
index 1d85ee361df..c2b7954d900 100644
--- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java
+++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java
@@ -60,6 +60,8 @@ public class PgVectorStore implements VectorStore, InitializingBean {
 	public static final String
VECTOR_TABLE_NAME = "vector_store"; + public static final String VECTOR_INDEX_NAME = "spring_ai_vector_index"; + public final FilterExpressionConverter filterExpressionConverter = new PgVectorFilterExpressionConverter(); private final JdbcTemplate jdbcTemplate; @@ -352,8 +354,8 @@ embedding vector(%d) if (this.createIndexMethod != PgIndexType.NONE) { this.jdbcTemplate.execute(String.format(""" - CREATE INDEX ON %s USING %s (embedding %s) - """, VECTOR_TABLE_NAME, this.createIndexMethod, this.getDistanceType().index)); + CREATE INDEX IF NOT EXISTS %s ON %s USING %s (embedding %s) + """, VECTOR_INDEX_NAME, VECTOR_TABLE_NAME, this.createIndexMethod, this.getDistanceType().index)); } }