diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/package-info.java new file mode 100644 index 00000000000..e28a215d457 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/package-info.java @@ -0,0 +1,25 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Provides the API for embedding observations. + */ +@NonNullApi +@NonNullFields +package org.springframework.ai.embedding; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/AbstractVectorStoreBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/AbstractVectorStoreBuilder.java index 0affbdcc8b0..8ee7f88b5c9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/AbstractVectorStoreBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/AbstractVectorStoreBuilder.java @@ -33,13 +33,18 @@ public abstract class AbstractVectorStoreBuilder> implements VectorStore.Builder { - protected EmbeddingModel embeddingModel; + protected final EmbeddingModel embeddingModel; protected ObservationRegistry observationRegistry = ObservationRegistry.NOOP; @Nullable protected VectorStoreObservationConvention customObservationConvention; + public AbstractVectorStoreBuilder(EmbeddingModel embeddingModel) { + Assert.notNull(embeddingModel, "EmbeddingModel must be configured"); + this.embeddingModel = embeddingModel; + } + public EmbeddingModel getEmbeddingModel() { return this.embeddingModel; } @@ -71,20 +76,9 @@ public T observationRegistry(ObservationRegistry observationRegistry) { } @Override - public T customObservationConvention(VectorStoreObservationConvention convention) { + public T customObservationConvention(@Nullable VectorStoreObservationConvention convention) { this.customObservationConvention = convention; return self(); } - @Override - public T embeddingModel(EmbeddingModel embeddingModel) { - Assert.notNull(embeddingModel, "EmbeddingModel must not be null"); - this.embeddingModel = embeddingModel; - return self(); - } - - protected void validate() { - Assert.notNull(this.embeddingModel, "EmbeddingModel must be configured"); - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java index 67b11fde579..8273a11188f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java @@ -92,8 +92,7 @@ public SimpleVectorStore(EmbeddingModel embeddingModel) { @Deprecated(forRemoval = true, since = "1.0.0-M5") public SimpleVectorStore(EmbeddingModel embeddingModel, ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention) { - this(builder().embeddingModel(embeddingModel) - .observationRegistry(observationRegistry) + this(builder(embeddingModel).observationRegistry(observationRegistry) .customObservationConvention(customObservationConvention)); } @@ -106,8 +105,8 @@ protected SimpleVectorStore(SimpleVectorStoreBuilder builder) { * Creates an instance of SimpleVectorStore builder. * @return the SimpleVectorStore builder. */ - public static SimpleVectorStoreBuilder builder() { - return new SimpleVectorStoreBuilder(); + public static SimpleVectorStoreBuilder builder(EmbeddingModel embeddingModel) { + return new SimpleVectorStoreBuilder(embeddingModel); } @Override @@ -297,9 +296,12 @@ public static float norm(float[] vector) { public static final class SimpleVectorStoreBuilder extends AbstractVectorStoreBuilder { + private SimpleVectorStoreBuilder(EmbeddingModel embeddingModel) { + super(embeddingModel); + } + @Override public SimpleVectorStore build() { - validate(); return new SimpleVectorStore(this); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/VectorStore.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/VectorStore.java index faea243a792..5c2dd00da1b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/VectorStore.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/VectorStore.java @@ -23,7 +23,6 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentWriter; -import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.vectorstore.observation.DefaultVectorStoreObservationConvention; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; import org.springframework.lang.Nullable; @@ -59,6 +58,7 @@ default void accept(List documents) { * @param idList list of document ids for which documents will be removed. * @return Returns true if the documents were successfully deleted. */ + @Nullable Optional delete(List idList); /** @@ -68,6 +68,7 @@ default void accept(List documents) { * topK, similarity threshold and metadata filter expressions. * @return Returns documents th match the query request conditions. */ + @Nullable List similaritySearch(SearchRequest request); /** @@ -77,6 +78,7 @@ default void accept(List documents) { * @return Returns a list of documents that have embeddings similar to the query text * embedding. */ + @Nullable default List similaritySearch(String query) { return this.similaritySearch(SearchRequest.query(query)); } @@ -90,8 +92,6 @@ default List similaritySearch(String query) { */ interface Builder> { - T embeddingModel(EmbeddingModel embeddingModel); - /** * Sets the registry for collecting observations and metrics. Defaults to * {@link ObservationRegistry#NOOP} if not specified. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/AbstractObservationVectorStore.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/AbstractObservationVectorStore.java index 81e704d99b5..ec6c641bec6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/AbstractObservationVectorStore.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/AbstractObservationVectorStore.java @@ -45,7 +45,6 @@ public abstract class AbstractObservationVectorStore implements VectorStore { @Nullable private final VectorStoreObservationConvention customObservationConvention; - @Nullable protected final EmbeddingModel embeddingModel; /** @@ -59,8 +58,7 @@ public AbstractObservationVectorStore(ObservationRegistry observationRegistry, this(null, observationRegistry, customObservationConvention); } - private AbstractObservationVectorStore(@Nullable EmbeddingModel embeddingModel, - ObservationRegistry observationRegistry, + private AbstractObservationVectorStore(EmbeddingModel embeddingModel, ObservationRegistry observationRegistry, @Nullable VectorStoreObservationConvention customObservationConvention) { this.embeddingModel = embeddingModel; this.observationRegistry = observationRegistry; @@ -94,6 +92,7 @@ public void add(List documents) { } @Override + @Nullable public Optional delete(List deleteDocIds) { VectorStoreObservationContext observationContext = this @@ -107,6 +106,7 @@ public Optional delete(List deleteDocIds) { } @Override + @Nullable public List similaritySearch(SearchRequest request) { VectorStoreObservationContext searchObservationContext = this diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreTests.java index 41bdbb457da..6e05264a662 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreTests.java @@ -57,7 +57,7 @@ void setUp() { when(this.mockEmbeddingModel.dimensions()).thenReturn(3); when(this.mockEmbeddingModel.embed(any(String.class))).thenReturn(new float[] { 0.1f, 0.2f, 0.3f }); when(this.mockEmbeddingModel.embed(any(Document.class))).thenReturn(new float[] { 0.1f, 0.2f, 0.3f }); - this.vectorStore = new SimpleVectorStore(SimpleVectorStore.builder().embeddingModel(this.mockEmbeddingModel)); + this.vectorStore = new SimpleVectorStore(SimpleVectorStore.builder(this.mockEmbeddingModel)); } @Test diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfiguration.java index 7db2a7936b8..d3f702a2a20 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfiguration.java @@ -84,9 +84,7 @@ public AzureVectorStore vectorStore(SearchIndexClient searchIndexClient, Embeddi ObjectProvider customObservationConvention, BatchingStrategy batchingStrategy) { - var builder = AzureVectorStore.builder() - .searchIndexClient(searchIndexClient) - .embeddingModel(embeddingModel) + var builder = AzureVectorStore.builder(searchIndexClient, embeddingModel) .initializeSchema(properties.isInitializeSchema()) .filterMetadataFields(List.of()) .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java index 64890673b6f..2e63baf7aee 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java @@ -62,7 +62,7 @@ public CassandraVectorStore vectorStore(EmbeddingModel embeddingModel, Cassandra ObjectProvider customObservationConvention, BatchingStrategy batchingStrategy) { - return CassandraVectorStore.builder() + return CassandraVectorStore.builder(embeddingModel) .session(cqlSession) .keyspace(properties.getKeyspace()) .table(properties.getTable()) @@ -72,7 +72,6 @@ public CassandraVectorStore vectorStore(EmbeddingModel embeddingModel, Cassandra .fixedThreadPoolExecutorSize(properties.getFixedThreadPoolExecutorSize()) .disallowSchemaChanges(!properties.isInitializeSchema()) .returnEmbeddings(properties.getReturnEmbeddings()) - .embeddingModel(embeddingModel) .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) .customObservationConvention(customObservationConvention.getIfAvailable(() -> null)) .batchingStrategy(batchingStrategy) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java index 50f75891882..898a04444be 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java @@ -86,9 +86,7 @@ public ChromaVectorStore vectorStore(EmbeddingModel embeddingModel, ChromaApi ch ChromaVectorStoreProperties storeProperties, ObjectProvider observationRegistry, ObjectProvider customObservationConvention, BatchingStrategy chromaBatchingStrategy) { - return ChromaVectorStore.builder() - .chromaApi(chromaApi) - .embeddingModel(embeddingModel) + return ChromaVectorStore.builder(chromaApi, embeddingModel) .collectionName(storeProperties.getCollectionName()) .initializeSchema(storeProperties.isInitializeSchema()) .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java index 33d0c2810a2..8541e8e833a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java @@ -73,10 +73,8 @@ ElasticsearchVectorStore vectorStore(ElasticsearchVectorStoreProperties properti elasticsearchVectorStoreOptions.setSimilarity(properties.getSimilarity()); } - return ElasticsearchVectorStore.builder() - .restClient(restClient) + return ElasticsearchVectorStore.builder(restClient, embeddingModel) .options(elasticsearchVectorStoreOptions) - .embeddingModel(embeddingModel) .initializeSchema(properties.isInitializeSchema()) .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) .customObservationConvention(customObservationConvention.getIfAvailable(() -> null)) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfiguration.java index 5b2ae6d1c79..26eb432cd18 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfiguration.java @@ -64,7 +64,7 @@ public GemFireVectorStore gemfireVectorStore(EmbeddingModel embeddingModel, GemF ObjectProvider customObservationConvention, BatchingStrategy batchingStrategy) { - return GemFireVectorStore.builder() + return GemFireVectorStore.builder(embeddingModel) .host(gemFireConnectionDetails.getHost()) .port(gemFireConnectionDetails.getPort()) .indexName(properties.getIndexName()) @@ -74,7 +74,6 @@ public GemFireVectorStore gemfireVectorStore(EmbeddingModel embeddingModel, GemF .vectorSimilarityFunction(properties.getVectorSimilarityFunction()) .fields(properties.getFields()) .sslEnabled(properties.isSslEnabled()) - .embeddingModel(embeddingModel) .initializeSchema(properties.isInitializeSchema()) .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) .customObservationConvention(customObservationConvention.getIfAvailable(() -> null)) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreAutoConfiguration.java index 8c0c856285f..3870f8d60fe 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreAutoConfiguration.java @@ -53,9 +53,7 @@ public HanaCloudVectorStore vectorStore(HanaVectorRepository observationRegistry, ObjectProvider customObservationConvention) { - return HanaCloudVectorStore.builder() - .repository(repository) - .embeddingModel(embeddingModel) + return HanaCloudVectorStore.builder(repository, embeddingModel) .tableName(properties.getTableName()) .topK(properties.getTopK()) .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mariadb/MariaDbStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mariadb/MariaDbStoreAutoConfiguration.java index b5b69e53ed6..e3c72674e0d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mariadb/MariaDbStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mariadb/MariaDbStoreAutoConfiguration.java @@ -57,8 +57,7 @@ public MariaDBVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel var initializeSchema = properties.isInitializeSchema(); - return MariaDBVectorStore.builder(jdbcTemplate) - .embeddingModel(embeddingModel) + return MariaDBVectorStore.builder(jdbcTemplate, embeddingModel) .schemaName(properties.getSchemaName()) .vectorTableName(properties.getTableName()) .schemaValidation(properties.isSchemaValidation()) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfiguration.java index 1dc3eb5dfb0..029a1be61f6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfiguration.java @@ -71,9 +71,7 @@ public MilvusVectorStore vectorStore(MilvusServiceClient milvusClient, Embedding ObjectProvider observationRegistry, ObjectProvider customObservationConvention) { - return MilvusVectorStore.builder() - .milvusClient(milvusClient) - .embeddingModel(embeddingModel) + return MilvusVectorStore.builder(milvusClient, embeddingModel) .initializeSchema(properties.isInitializeSchema()) .batchingStrategy(batchingStrategy) .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfiguration.java index ec3a9eb985a..8bb97634695 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfiguration.java @@ -66,9 +66,7 @@ MongoDBAtlasVectorStore vectorStore(MongoTemplate mongoTemplate, EmbeddingModel ObjectProvider customObservationConvention, BatchingStrategy batchingStrategy) { - MongoDBAtlasVectorStore.MongoDBBuilder builder = MongoDBAtlasVectorStore.builder() - .mongoTemplate(mongoTemplate) - .embeddingModel(embeddingModel) + MongoDBAtlasVectorStore.MongoDBBuilder builder = MongoDBAtlasVectorStore.builder(mongoTemplate, embeddingModel) .initializeSchema(properties.isInitializeSchema()) .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) .customObservationConvention(customObservationConvention.getIfAvailable(() -> null)) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfiguration.java index aac5576d11e..5d3d114c145 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfiguration.java @@ -58,9 +58,7 @@ public Neo4jVectorStore vectorStore(Driver driver, EmbeddingModel embeddingModel ObjectProvider customObservationConvention, BatchingStrategy batchingStrategy) { - return Neo4jVectorStore.builder() - .driver(driver) - .embeddingModel(embeddingModel) + return Neo4jVectorStore.builder(driver, embeddingModel) .initializeSchema(properties.isInitializeSchema()) .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) .customObservationConvention(customObservationConvention.getIfAvailable(() -> null)) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfiguration.java index 28a4e4529ad..b844e9bbfa3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfiguration.java @@ -79,10 +79,8 @@ OpenSearchVectorStore vectorStore(OpenSearchVectorStoreProperties properties, Op var mappingJson = Optional.ofNullable(properties.getMappingJson()) .orElse(OpenSearchVectorStore.DEFAULT_MAPPING_EMBEDDING_TYPE_KNN_VECTOR_DIMENSION); - return OpenSearchVectorStore.builder() + return OpenSearchVectorStore.builder(openSearchClient, embeddingModel) .index(indexName) - .openSearchClient(openSearchClient) - .embeddingModel(embeddingModel) .mappingJson(mappingJson) .initializeSchema(properties.isInitializeSchema()) .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfiguration.java index f8fcff20d6d..5e2d22b699e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfiguration.java @@ -60,9 +60,7 @@ public OracleVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel e ObjectProvider customObservationConvention, BatchingStrategy batchingStrategy) { - return OracleVectorStore.builder() - .jdbcTemplate(jdbcTemplate) - .embeddingModel(embeddingModel) + return OracleVectorStore.builder(jdbcTemplate, embeddingModel) .tableName(properties.getTableName()) .indexType(properties.getIndexType()) .distanceType(properties.getDistanceType()) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java index 441aca3345e..95075ed2064 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java @@ -62,9 +62,7 @@ public PgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embed var initializeSchema = properties.isInitializeSchema(); - return PgVectorStore.builder() - .jdbcTemplate(jdbcTemplate) - .embeddingModel(embeddingModel) + return PgVectorStore.builder(jdbcTemplate, embeddingModel) .schemaName(properties.getSchemaName()) .vectorTableName(properties.getTableName()) .vectorTableValidationsEnabled(properties.isSchemaValidation()) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfiguration.java index ba39ecec79b..8ab1f4c1b87 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfiguration.java @@ -55,12 +55,9 @@ public PineconeVectorStore vectorStore(EmbeddingModel embeddingModel, PineconeVe ObjectProvider customObservationConvention, BatchingStrategy batchingStrategy) { - return PineconeVectorStore.builder() - .embeddingModel(embeddingModel) - .apiKey(properties.getApiKey()) - .environment(properties.getEnvironment()) - .projectId(properties.getProjectId()) - .indexName(properties.getIndexName()) + return PineconeVectorStore + .builder(embeddingModel, properties.getApiKey(), properties.getProjectId(), properties.getEnvironment(), + properties.getIndexName()) .namespace(properties.getNamespace()) .contentFieldName(properties.getContentFieldName()) .distanceMetadataFieldName(properties.getDistanceMetadataFieldName()) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfiguration.java index 790b53c3aef..0521484be57 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfiguration.java @@ -77,9 +77,8 @@ public QdrantVectorStore vectorStore(EmbeddingModel embeddingModel, QdrantVector QdrantClient qdrantClient, ObjectProvider observationRegistry, ObjectProvider customObservationConvention, BatchingStrategy batchingStrategy) { - return QdrantVectorStore.builder(qdrantClient) + return QdrantVectorStore.builder(qdrantClient, embeddingModel) .collectionName(properties.getCollectionName()) - .embeddingModel(embeddingModel) .initializeSchema(properties.isInitializeSchema()) .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) .customObservationConvention(customObservationConvention.getIfAvailable(() -> null)) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java index 42516395c63..9b318767178 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java @@ -60,9 +60,9 @@ public RedisVectorStore vectorStore(EmbeddingModel embeddingModel, RedisVectorSt ObjectProvider customObservationConvention, BatchingStrategy batchingStrategy) { - return RedisVectorStore.builder() - .jedis(new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort())) - .embeddingModel(embeddingModel) + JedisPooled jedisPooled = new JedisPooled(jedisConnectionFactory.getHostName(), + jedisConnectionFactory.getPort()); + return RedisVectorStore.builder(jedisPooled, embeddingModel) .initializeSchema(properties.isInitializeSchema()) .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) .customObservationConvention(customObservationConvention.getIfAvailable(() -> null)) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfiguration.java index 9d259a274d2..823c325660e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfiguration.java @@ -70,9 +70,7 @@ public TypesenseVectorStore vectorStore(Client typesenseClient, EmbeddingModel e ObjectProvider customObservationConvention, BatchingStrategy batchingStrategy) { - return TypesenseVectorStore.builder() - .client(typesenseClient) - .embeddingModel(embeddingModel) + return TypesenseVectorStore.builder(typesenseClient, embeddingModel) .collectionName(properties.getCollectionName()) .embeddingDimension(properties.getEmbeddingDimension()) .initializeSchema(properties.isInitializeSchema()) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfiguration.java index 1e710170072..3fd8d7c518d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfiguration.java @@ -81,9 +81,7 @@ public WeaviateVectorStore vectorStore(EmbeddingModel embeddingModel, WeaviateCl ObjectProvider customObservationConvention, BatchingStrategy batchingStrategy) { - return WeaviateVectorStore.builder() - .weaviateClient(weaviateClient) - .embeddingModel(embeddingModel) + return WeaviateVectorStore.builder(weaviateClient, embeddingModel) .objectClass(properties.getObjectClass()) .filterMetadataFields(properties.getFilterField() .entrySet() diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStore.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStore.java index 8de1f30d362..6a5db8b20b7 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStore.java +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStore.java @@ -69,6 +69,7 @@ import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** @@ -108,7 +109,8 @@ public class CosmosDBVectorStore extends AbstractObservationVectorStore implemen * @param cosmosClient the Cosmos DB client * @param properties the configuration properties * @param embeddingModel the embedding model - * @deprecated Since 1.0.0-M5, use {@link #builder()} instead + * @deprecated Since 1.0.0-M5, use {@link #builder(CosmosAsyncClient, EmbeddingModel)} + * ()} instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public CosmosDBVectorStore(ObservationRegistry observationRegistry, @@ -126,15 +128,14 @@ public CosmosDBVectorStore(ObservationRegistry observationRegistry, * @param properties the configuration properties * @param embeddingModel the embedding model * @param batchingStrategy the batching strategy - * @deprecated Since 1.0.0-M5, use {@link #builder()} instead + * @deprecated Since 1.0.0-M5, use {@link #builder(CosmosAsyncClient, EmbeddingModel)} + * ()} instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public CosmosDBVectorStore(ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, CosmosAsyncClient cosmosClient, CosmosDBVectorStoreConfig properties, EmbeddingModel embeddingModel, BatchingStrategy batchingStrategy) { - this(builder().cosmosClient(cosmosClient) - .embeddingModel(embeddingModel) - .containerName(properties.getContainerName()) + this(builder(cosmosClient, embeddingModel).containerName(properties.getContainerName()) .databaseName(properties.getDatabaseName()) .partitionKeyPath(properties.getPartitionKeyPath()) .vectorStoreThroughput(properties.getVectorStoreThroughput()) @@ -171,8 +172,8 @@ protected CosmosDBVectorStore(CosmosDBBuilder builder) { initializeContainer(containerName, databaseName, vectorStoreThroughput, vectorDimensions, partitionKeyPath); } - public static CosmosDBBuilder builder() { - return new CosmosDBBuilder(); + public static CosmosDBBuilder builder(CosmosAsyncClient cosmosClient, EmbeddingModel embeddingModel) { + return new CosmosDBBuilder(cosmosClient, embeddingModel); } private void initializeContainer(String containerName, String databaseName, int vectorStoreThroughput, @@ -243,7 +244,7 @@ private JsonNode mapCosmosDocument(Document document, float[] queryEmbedding) { ObjectMapper objectMapper = new ObjectMapper(); String id = document.getId(); - String content = document.getContent(); + String content = document.getText(); // Convert metadata and embedding directly to JsonNode JsonNode metadataNode = objectMapper.valueToTree(document.getMetadata()); @@ -430,12 +431,15 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str */ public static class CosmosDBBuilder extends AbstractVectorStoreBuilder { - private CosmosAsyncClient cosmosClient; + private final CosmosAsyncClient cosmosClient; + @Nullable private String containerName; + @Nullable private String databaseName; + @Nullable private String partitionKeyPath; private int vectorStoreThroughput = 400; @@ -446,16 +450,10 @@ public static class CosmosDBBuilder extends AbstractVectorStoreBuilder filterMetadataFields; + @Nullable private SearchClient searchClient; private int defaultTopK; @@ -135,7 +137,8 @@ public class AzureVectorStore extends AbstractObservationVectorStore implements * @param searchIndexClient the Azure search index client * @param embeddingModel the embedding model to use * @param initializeSchema whether to initialize schema - * @deprecated Since 1.0.0-M5, use {@link #builder()} instead + * @deprecated Since 1.0.0-M5, use {@link #builder(SearchIndexClient, EmbeddingModel)} + * ()} instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public AzureVectorStore(SearchIndexClient searchIndexClient, EmbeddingModel embeddingModel, @@ -149,7 +152,8 @@ public AzureVectorStore(SearchIndexClient searchIndexClient, EmbeddingModel embe * @param embeddingModel the embedding model to use * @param initializeSchema whether to initialize schema * @param filterMetadataFields list of metadata fields for filtering - * @deprecated Since 1.0.0-M5, use {@link #builder()} instead + * @deprecated Since 1.0.0-M5, use {@link #builder(SearchIndexClient, EmbeddingModel)} + * ()} instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public AzureVectorStore(SearchIndexClient searchIndexClient, EmbeddingModel embeddingModel, @@ -166,16 +170,15 @@ public AzureVectorStore(SearchIndexClient searchIndexClient, EmbeddingModel embe * @param filterMetadataFields list of metadata fields for filtering * @param observationRegistry the observation registry * @param customObservationConvention the custom observation convention - * @deprecated Since 1.0.0-M5, use {@link #builder()} instead + * @deprecated Since 1.0.0-M5, use {@link #builder(SearchIndexClient, EmbeddingModel)} + * ()} instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public AzureVectorStore(SearchIndexClient searchIndexClient, EmbeddingModel embeddingModel, boolean initializeSchema, List filterMetadataFields, ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { - this(builder().searchIndexClient(searchIndexClient) - .embeddingModel(embeddingModel) - .initializeSchema(initializeSchema) + this(builder(searchIndexClient, embeddingModel).initializeSchema(initializeSchema) .filterMetadataFields(filterMetadataFields) .observationRegistry(observationRegistry) .customObservationConvention(customObservationConvention) @@ -203,15 +206,15 @@ protected AzureVectorStore(AzureBuilder builder) { this.filterExpressionConverter = new AzureAiSearchFilterExpressionConverter(filterMetadataFields); } - public static AzureBuilder builder() { - return new AzureBuilder(); + public static AzureBuilder builder(SearchIndexClient searchIndexClient, EmbeddingModel embeddingModel) { + return new AzureBuilder(searchIndexClient, embeddingModel); } /** * Change the Index Name. * @param indexName The Azure VectorStore index name to use. - * @deprecated Since 1.0.0-M5, use {@link #builder()} with - * {@link AzureBuilder#indexName(String)} instead + * @deprecated Since 1.0.0-M5, use {@link #builder(SearchIndexClient, EmbeddingModel)} + * ()} with {@link AzureBuilder#indexName(String)} instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public void setIndexName(String indexName) { @@ -222,8 +225,8 @@ public void setIndexName(String indexName) { /** * Sets the a default maximum number of similar documents returned. * @param topK The default maximum number of similar documents returned. - * @deprecated Since 1.0.0-M5, use {@link #builder()} with - * {@link AzureBuilder#indexName(String)} instead + * @deprecated Since 1.0.0-M5, use {@link #builder(SearchIndexClient, EmbeddingModel)} + * ()} with {@link AzureBuilder#indexName(String)} instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public void setDefaultTopK(int topK) { @@ -235,8 +238,8 @@ public void setDefaultTopK(int topK) { * Sets the a default similarity threshold for returned documents. * @param similarityThreshold The a default similarity threshold for returned * documents. - * @deprecated Since 1.0.0-M5, use {@link #builder()} with - * {@link AzureBuilder#indexName(String)} instead + * @deprecated Since 1.0.0-M5, use {@link #builder(SearchIndexClient, EmbeddingModel)} + * ()} with {@link AzureBuilder#indexName(String)} instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public void setDefaultSimilarityThreshold(Double similarityThreshold) { @@ -260,7 +263,7 @@ public void doAdd(List documents) { SearchDocument searchDocument = new SearchDocument(); searchDocument.put(ID_FIELD_NAME, document.getId()); searchDocument.put(EMBEDDING_FIELD_NAME, embeddings.get(documents.indexOf(document))); - searchDocument.put(CONTENT_FIELD_NAME, document.getContent()); + searchDocument.put(CONTENT_FIELD_NAME, document.getText()); searchDocument.put(METADATA_FIELD_NAME, new JSONObject(document.getMetadata()).toJSONString()); // Add the filterable metadata fields as top level fields, allowing filler @@ -469,7 +472,7 @@ private record AzureSearchDocument(String id, String content, List embedd */ public static class AzureBuilder extends AbstractVectorStoreBuilder { - private SearchIndexClient searchIndexClient; + private final SearchIndexClient searchIndexClient; private boolean initializeSchema = false; @@ -483,16 +486,10 @@ public static class AzureBuilder extends AbstractVectorStoreBuilder documents) { builder = builder.set(keyColumn.name(), primaryKeyValues.get(k), keyColumn.javaType()); } - builder = builder.setString(this.schema.content(), d.getContent()) + builder = builder.setString(this.schema.content(), d.getText()) .setVector(this.schema.embedding(), CqlVector.newInstance(EmbeddingUtils.toList(embeddings.get(documents.indexOf(d)))), Float.class); @@ -827,16 +826,9 @@ public static class CassandraBuilder extends AbstractVectorStoreBuilder partitionKeys = List.of(wikiSC, langSC, titleSC); List clusteringKeys = List.of(chunkNoSC); - return CassandraVectorStore.builder() + return CassandraVectorStore.builder(context.getBean(EmbeddingModel.class)) .session(context.getBean(CqlSession.class)) .keyspace("test_wikidata") .table("articles") @@ -253,7 +253,6 @@ void addAndSearchPoormansBench() { this.contextRunner.run(context -> { try (CassandraVectorStore store = storeBuilder(context, List.of()).fixedThreadPoolExecutorSize(nThreads) - .embeddingModel(context.getBean(EmbeddingModel.class)) .build()) { var executor = Executors.newFixedThreadPool((int) (nThreads * 1.2)); @@ -552,7 +551,6 @@ private CassandraVectorStore createStore(ApplicationContext context, List { - CassandraVectorStore.CassandraBuilder builder = storeBuilder(context.getBean(CqlSession.class)) + CassandraVectorStore.CassandraBuilder builder = storeBuilder(context.getBean(CqlSession.class), + context.getBean(EmbeddingModel.class)) .returnEmbeddings(true); try (CassandraVectorStore store = createTestStore(context, builder)) { @@ -395,12 +397,11 @@ public static class TestApplication { @Bean public CassandraVectorStore store(CqlSession cqlSession, EmbeddingModel embeddingModel) { - CassandraVectorStore.CassandraBuilder builder = storeBuilder(cqlSession) - .addMetadataColumns(new CassandraVectorStore.SchemaColumn("meta1", DataTypes.TEXT), - new CassandraVectorStore.SchemaColumn("meta2", DataTypes.TEXT), - new CassandraVectorStore.SchemaColumn("country", DataTypes.TEXT), - new CassandraVectorStore.SchemaColumn("year", DataTypes.SMALLINT)) - .embeddingModel(embeddingModel); + CassandraVectorStore.CassandraBuilder builder = storeBuilder(cqlSession, embeddingModel).addMetadataColumns( + new CassandraVectorStore.SchemaColumn("meta1", DataTypes.TEXT), + new CassandraVectorStore.SchemaColumn("meta2", DataTypes.TEXT), + new CassandraVectorStore.SchemaColumn("country", DataTypes.TEXT), + new CassandraVectorStore.SchemaColumn("year", DataTypes.SMALLINT)); CassandraVectorStore.dropKeyspace(builder); return builder.build(); diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/CassandraVectorStoreObservationIT.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/CassandraVectorStoreObservationIT.java index 6aba243e627..f65e7508ffb 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/CassandraVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/CassandraVectorStoreObservationIT.java @@ -172,7 +172,7 @@ public EmbeddingModel embeddingModel() { public CassandraVectorStore store(CqlSession cqlSession, EmbeddingModel embeddingModel, ObservationRegistry observationRegistry) { - CassandraVectorStore.CassandraBuilder builder = CassandraVectorStore.builder() + CassandraVectorStore.CassandraBuilder builder = CassandraVectorStore.builder(embeddingModel) .session(cqlSession) .session(cqlSession) .keyspace("test_" + CassandraVectorStore.DEFAULT_KEYSPACE_NAME) @@ -180,7 +180,6 @@ public CassandraVectorStore store(CqlSession cqlSession, EmbeddingModel embeddin new CassandraVectorStore.SchemaColumn("meta2", DataTypes.TEXT), new CassandraVectorStore.SchemaColumn("country", DataTypes.TEXT), new CassandraVectorStore.SchemaColumn("year", DataTypes.SMALLINT)) - .embeddingModel(embeddingModel) .observationRegistry(observationRegistry) .batchingStrategy(new TokenCountBatchingStrategy()); diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/WikiVectorStoreExample.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/WikiVectorStoreExample.java index 82846a2d6a3..66a4261b38c 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/WikiVectorStoreExample.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/WikiVectorStoreExample.java @@ -93,7 +93,7 @@ public CassandraVectorStore store(CqlSession cqlSession, EmbeddingModel embeddin List extraColumns = List.of(new SchemaColumn("revision", DataTypes.INT), new SchemaColumn("id", DataTypes.INT)); - return CassandraVectorStore.builder() + return CassandraVectorStore.builder(embeddingModel) .session(cqlSession) .keyspace("wikidata") .table("articles") @@ -118,7 +118,6 @@ public CassandraVectorStore store(CqlSession cqlSession, EmbeddingModel embeddin int chunk_no = 0 < parts.length ? Integer.parseInt(parts[1]) : 0; return List.of("simplewiki", "en", title, chunk_no, 0); }) - .embeddingModel(embeddingModel()) .build(); } diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaApi.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaApi.java index c67a969c8c6..04c12129dc6 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaApi.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaApi.java @@ -33,6 +33,7 @@ import org.springframework.http.MediaType; import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.http.client.support.BasicAuthenticationInterceptor; +import org.springframework.lang.Nullable; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; import org.springframework.web.client.HttpClientErrorException; @@ -58,6 +59,7 @@ public class ChromaApi { private RestClient restClient; + @Nullable private String keyToken; public ChromaApi(String baseUrl) { @@ -99,7 +101,7 @@ public ChromaApi withBasicAuthCredentials(String username, String password) { return this; } - public List toEmbeddingResponseList(QueryResponse queryResponse) { + public List toEmbeddingResponseList(@Nullable QueryResponse queryResponse) { List result = new ArrayList<>(); if (queryResponse != null && !CollectionUtils.isEmpty(queryResponse.ids())) { @@ -113,6 +115,7 @@ public List toEmbeddingResponseList(QueryResponse queryResponse) { return result; } + @Nullable public Collection createCollection(CreateCollectionRequest createCollectionRequest) { return this.restClient.post() @@ -138,6 +141,7 @@ public void deleteCollection(String collectionName) { .toBodilessEntity(); } + @Nullable public Collection getCollection(String collectionName) { try { @@ -157,6 +161,7 @@ public Collection getCollection(String collectionName) { } } + @Nullable public List listCollections() { return this.restClient.get() @@ -167,7 +172,7 @@ public List listCollections() { .getBody(); } - public void upsertEmbeddings(String collectionId, AddEmbeddingsRequest embedding) { + public void upsertEmbeddings(@Nullable String collectionId, AddEmbeddingsRequest embedding) { this.restClient.post() .uri("/api/v1/collections/{collection_id}/upsert", collectionId) @@ -177,7 +182,7 @@ public void upsertEmbeddings(String collectionId, AddEmbeddingsRequest embedding .toBodilessEntity(); } - public int deleteEmbeddings(String collectionId, DeleteEmbeddingsRequest deleteRequest) { + public int deleteEmbeddings(@Nullable String collectionId, DeleteEmbeddingsRequest deleteRequest) { return this.restClient.post() .uri("/api/v1/collections/{collection_id}/delete", collectionId) .headers(this::httpHeaders) @@ -188,6 +193,7 @@ public int deleteEmbeddings(String collectionId, DeleteEmbeddingsRequest deleteR .value(); } + @Nullable public Long countEmbeddings(String collectionId) { return this.restClient.get() @@ -198,7 +204,8 @@ public Long countEmbeddings(String collectionId) { .getBody(); } - public QueryResponse queryCollection(String collectionId, QueryRequest queryRequest) { + @Nullable + public QueryResponse queryCollection(@Nullable String collectionId, QueryRequest queryRequest) { return this.restClient.post() .uri("/api/v1/collections/{collection_id}/query", collectionId) @@ -212,7 +219,7 @@ public QueryResponse queryCollection(String collectionId, QueryRequest queryRequ // // Chroma Client API (https://docs.trychroma.com/js_reference/Client) // - + @Nullable public GetEmbeddingResponse getEmbeddings(String collectionId, GetEmbeddingsRequest getEmbeddingsRequest) { return this.restClient.post() @@ -332,7 +339,7 @@ public AddEmbeddingsRequest(String id, float[] embedding, Map me @JsonInclude(JsonInclude.Include.NON_NULL) public record DeleteEmbeddingsRequest(// @formatter:off @JsonProperty("ids") List ids, - @JsonProperty("where") Map where) { // @formatter:on + @Nullable @JsonProperty("where") Map where) { // @formatter:on public DeleteEmbeddingsRequest(List ids) { this(ids, null); @@ -353,7 +360,7 @@ public DeleteEmbeddingsRequest(List ids) { @JsonInclude(JsonInclude.Include.NON_NULL) public record GetEmbeddingsRequest(// @formatter:off @JsonProperty("ids") List ids, - @JsonProperty("where") Map where, + @Nullable @JsonProperty("where") Map where, @JsonProperty("limit") Integer limit, @JsonProperty("offset") Integer offset, @JsonProperty("include") List include) { // @formatter:on @@ -404,7 +411,7 @@ public record GetEmbeddingResponse(// @formatter:off public record QueryRequest(// @formatter:off @JsonProperty("query_embeddings") List queryEmbeddings, @JsonProperty("n_results") Integer nResults, - @JsonProperty("where") Map where, + @Nullable @JsonProperty("where") Map where, @JsonProperty("include") List include) { // @formatter:on /** @@ -414,7 +421,7 @@ public QueryRequest(float[] queryEmbedding, Integer nResults) { this(List.of(queryEmbedding), nResults, null, Include.all); } - public QueryRequest(float[] queryEmbedding, Integer nResults, Map where) { + public QueryRequest(float[] queryEmbedding, Integer nResults, @Nullable Map where) { this(List.of(queryEmbedding), nResults, CollectionUtils.isEmpty(where) ? null : where, Include.all); } @@ -471,7 +478,7 @@ public record Embedding(// @formatter:off @JsonProperty("id") String id, @JsonProperty("embedding") float[] embedding, @JsonProperty("document") String document, - @JsonProperty("metadata") Map metadata, + @Nullable @JsonProperty("metadata") Map metadata, @JsonProperty("distances") Double distances) { // @formatter:on } diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStore.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStore.java index 1773804d174..bed32ec933d 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStore.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStore.java @@ -100,11 +100,9 @@ public ChromaVectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi, Str @Deprecated(since = "1.0.0-M5", forRemoval = true) public ChromaVectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi, String collectionName, boolean initializeSchema, ObservationRegistry observationRegistry, - VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { + @Nullable VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { - this(builder().chromaApi(chromaApi) - .embeddingModel(embeddingModel) - .collectionName(collectionName) + this(builder(chromaApi, embeddingModel).collectionName(collectionName) .initializeSchema(initializeSchema) .observationRegistry(observationRegistry) .customObservationConvention(customObservationConvention) @@ -117,8 +115,6 @@ public ChromaVectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi, Str protected ChromaVectorStore(ChromaBuilder builder) { super(builder); - Assert.notNull(builder.chromaApi, "ChromaApi must not be null"); - this.chromaApi = builder.chromaApi; this.collectionName = builder.collectionName; this.initializeSchema = builder.initializeSchema; @@ -136,6 +132,10 @@ protected ChromaVectorStore(ChromaBuilder builder) { } } + public static ChromaBuilder builder(ChromaApi chromaApi, EmbeddingModel embeddingModel) { + return new ChromaBuilder(chromaApi, embeddingModel); + } + @Override public void afterPropertiesSet() throws Exception { if (!this.initialized) { @@ -150,15 +150,13 @@ public void afterPropertiesSet() throws Exception { + " doesn't exist and won't be created as the initializeSchema is set to false."); } } - this.collectionId = collection.id(); + if (collection != null) { + this.collectionId = collection.id(); + } this.initialized = true; } } - public static ChromaBuilder builder() { - return new ChromaBuilder(); - } - @Override public void doAdd(@NonNull List documents) { Assert.notNull(documents, "Documents must not be null"); @@ -177,7 +175,7 @@ public void doAdd(@NonNull List documents) { for (Document document : documents) { ids.add(document.getId()); metadatas.add(document.getMetadata()); - contents.add(document.getContent()); + contents.add(document.getText()); embeddings.add(documentEmbeddings.get(documents.indexOf(document))); } @@ -278,7 +276,7 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str public static class ChromaBuilder extends AbstractVectorStoreBuilder { - private ChromaApi chromaApi; + private final ChromaApi chromaApi; private String collectionName = DEFAULT_COLLECTION_NAME; @@ -290,10 +288,10 @@ public static class ChromaBuilder extends AbstractVectorStoreBuilder new ChromaVectorStore.ChromaBuilder().chromaApi(this.chromaApi) - .embeddingModel(this.embeddingModel) + assertThatThrownBy(() -> ChromaVectorStore.builder(this.chromaApi, this.embeddingModel) .collectionName("non-existent") .initializeSchema(false) .initializeImmediately(true) diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreIT.java index c48c57a779b..4a396a52031 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreIT.java @@ -252,9 +252,7 @@ public ChromaApi chromaApi(RestClient.Builder builder) { @Bean public VectorStore chromaVectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi) { - return ChromaVectorStore.builder() - .chromaApi(chromaApi) - .embeddingModel(embeddingModel) + return ChromaVectorStore.builder(chromaApi, embeddingModel) .collectionName("TestCollection") .initializeSchema(true) .build(); diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreObservationIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreObservationIT.java index 22157fee4fe..a7ebe191330 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreObservationIT.java @@ -176,9 +176,7 @@ public ChromaApi chromaApi(RestClient.Builder builder) { @Bean public VectorStore chromaVectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi, ObservationRegistry observationRegistry) { - return ChromaVectorStore.builder() - .chromaApi(chromaApi) - .embeddingModel(embeddingModel) + return ChromaVectorStore.builder(chromaApi, embeddingModel) .collectionName("TestCollection") .initializeSchema(true) .observationRegistry(observationRegistry) diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/TokenSecuredChromaWhereIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/TokenSecuredChromaWhereIT.java index 88b5f56b8d0..0b0414e001e 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/TokenSecuredChromaWhereIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/TokenSecuredChromaWhereIT.java @@ -144,9 +144,7 @@ public ChromaApi chromaApi(RestClient.Builder builder) { @Bean public VectorStore chromaVectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi) { - return ChromaVectorStore.builder() - .chromaApi(chromaApi) - .embeddingModel(embeddingModel) + return ChromaVectorStore.builder(chromaApi, embeddingModel) .collectionName("TestCollection") .initializeSchema(true) .build(); diff --git a/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/coherence/CoherenceVectorStore.java b/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/coherence/CoherenceVectorStore.java index 208588bedf9..6ada1221059 100644 --- a/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/coherence/CoherenceVectorStore.java +++ b/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/coherence/CoherenceVectorStore.java @@ -143,11 +143,12 @@ public enum DistanceType { * Creates a new CoherenceVectorStore with minimal configuration. * @param embeddingModel the embedding model to use * @param session the Coherence session - * @deprecated Since 1.0.0-M5, use {@link #builder()} instead + * @deprecated Since 1.0.0-M5, use {@link #builder(Session, EmbeddingModel)} ()} + * instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public CoherenceVectorStore(EmbeddingModel embeddingModel, Session session) { - this(builder().embeddingModel(embeddingModel).session(session)); + this(builder(session, embeddingModel)); } /** @@ -172,12 +173,13 @@ protected CoherenceVectorStore(CoherenceBuilder builder) { * Creates a new builder for configuring and creating CoherenceVectorStore instances. * @return a new builder instance */ - public static CoherenceBuilder builder() { - return new CoherenceBuilder(); + public static CoherenceBuilder builder(Session session, EmbeddingModel embeddingModel) { + return new CoherenceBuilder(session, embeddingModel); } /** - * @deprecated Since 1.0.0-M5, use {@link #builder()} instead + * @deprecated Since 1.0.0-M5, use {@link #builder(Session, EmbeddingModel)} ()} + * instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public CoherenceVectorStore setMapName(String mapName) { @@ -186,7 +188,8 @@ public CoherenceVectorStore setMapName(String mapName) { } /** - * @deprecated Since 1.0.0-M5, use {@link #builder()} instead + * @deprecated Since 1.0.0-M5, use {@link #builder(Session, EmbeddingModel)} ()} + * instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public CoherenceVectorStore setDistanceType(DistanceType distanceType) { @@ -195,7 +198,8 @@ public CoherenceVectorStore setDistanceType(DistanceType distanceType) { } /** - * @deprecated Since 1.0.0-M5, use {@link #builder()} instead + * @deprecated Since 1.0.0-M5, use {@link #builder(Session, EmbeddingModel)} ()} + * instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public CoherenceVectorStore setIndexType(IndexType indexType) { @@ -204,7 +208,8 @@ public CoherenceVectorStore setIndexType(IndexType indexType) { } /** - * @deprecated Since 1.0.0-M5, use {@link #builder()} instead + * @deprecated Since 1.0.0-M5, use {@link #builder(Session, EmbeddingModel)} ()} + * instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public CoherenceVectorStore setForcedNormalization(boolean forcedNormalization) { @@ -217,7 +222,7 @@ public void doAdd(final List documents) { Map chunks = new HashMap<>((int) Math.ceil(documents.size() / 0.75f)); for (Document doc : documents) { var id = toChunkId(doc.getId()); - var chunk = new DocumentChunk(doc.getContent(), doc.getMetadata(), + var chunk = new DocumentChunk(doc.getText(), doc.getMetadata(), toFloat32Vector(this.embeddingModel.embed(doc))); chunks.put(id, chunk); } @@ -332,7 +337,7 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str */ public static class CoherenceBuilder extends AbstractVectorStoreBuilder { - private Session session; + private final Session session; private String mapName = DEFAULT_MAP_NAME; @@ -342,16 +347,10 @@ public static class CoherenceBuilder extends AbstractVectorStoreBuilder SIMILARITY_TYPE_MAPPING = Map.of( + private static final Map SIMILARITY_TYPE_MAPPING = Map.of( SimilarityFunction.cosine, VectorStoreSimilarityMetric.COSINE, SimilarityFunction.l2_norm, VectorStoreSimilarityMetric.EUCLIDEAN, SimilarityFunction.dot_product, VectorStoreSimilarityMetric.DOT); @@ -187,9 +187,7 @@ public ElasticsearchVectorStore(ElasticsearchVectorStoreOptions options, RestCli EmbeddingModel embeddingModel, boolean initializeSchema, ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { - this(builder().restClient(restClient) - .options(options) - .embeddingModel(embeddingModel) + this(builder(restClient, embeddingModel).options(options) .initializeSchema(initializeSchema) .observationRegistry(observationRegistry) .customObservationConvention(customObservationConvention) @@ -226,7 +224,7 @@ public void doAdd(List documents) { this.batchingStrategy); for (Document document : documents) { - ElasticSearchDocument doc = new ElasticSearchDocument(document.getId(), document.getContent(), + ElasticSearchDocument doc = new ElasticSearchDocument(document.getId(), document.getText(), document.getMetadata(), embeddings.get(documents.indexOf(document))); bulkRequestBuilder.operations( op -> op.index(idx -> idx.index(this.options.getIndexName()).id(document.getId()).document(doc))); @@ -389,13 +387,13 @@ public record ElasticSearchDocument(String id, String content, Map { - private RestClient restClient; + private final RestClient restClient; private ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions(); @@ -411,10 +409,10 @@ public static class ElasticsearchBuilder extends AbstractVectorStoreBuilder documents) { this.batchingStrategy); UploadRequest upload = new UploadRequest(documents.stream() .map(document -> new UploadRequest.Embedding(document.getId(), embeddings.get(documents.indexOf(document)), - DOCUMENT_FIELD, document.getContent(), document.getMetadata())) + DOCUMENT_FIELD, document.getText(), document.getMetadata())) .toList()); String embeddingsJson = null; @@ -296,6 +296,7 @@ public Optional doDelete(List idList) { } @Override + @Nullable public List doSimilaritySearch(SearchRequest request) { if (request.hasFilterExpression()) { throw new UnsupportedOperationException("GemFire currently does not support metadata filter expressions."); @@ -515,7 +516,6 @@ public Map getMetadata() { private static final class QueryRequest { @JsonProperty("vector") - @NonNull private final float[] vector; @JsonProperty("top-k") @@ -650,7 +650,8 @@ public static final class GemFireVectorStoreConfig { boolean sslEnabled; /** - * @deprecated Since 1.0.0-M5, use {@link GemFireVectorStore#builder()} instead + * @deprecated Since 1.0.0-M5, use + * {@link GemFireVectorStore#builder(EmbeddingModel)} ()} instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) private GemFireVectorStoreConfig(Builder builder) { @@ -668,7 +669,8 @@ private GemFireVectorStoreConfig(Builder builder) { /** * Start building a new configuration. * @return The entry point for creating a new configuration. - * @deprecated Since 1.0.0-M5, use {@link GemFireVectorStore#builder()} instead + * @deprecated Since 1.0.0-M5, use + * {@link GemFireVectorStore#builder(EmbeddingModel)} ()} instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public static Builder builder() { @@ -676,7 +678,8 @@ public static Builder builder() { } /** - * @deprecated Since 1.0.0-M5, use {@link GemFireVectorStore#builder()} instead + * @deprecated Since 1.0.0-M5, use + * {@link GemFireVectorStore#builder(EmbeddingModel)} ()} instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public static class Builder { @@ -701,8 +704,8 @@ public static class Builder { boolean sslEnabled = GemFireVectorStoreConfig.DEFAULT_SSL_ENABLED; /** - * @deprecated Since 1.0.0-M5, use {@link GemFireVectorStore#builder()} - * instead + * @deprecated Since 1.0.0-M5, use + * {@link GemFireVectorStore#builder(EmbeddingModel)} ()} instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder setHost(String host) { @@ -712,8 +715,8 @@ public Builder setHost(String host) { } /** - * @deprecated Since 1.0.0-M5, use {@link GemFireVectorStore#builder()} - * instead + * @deprecated Since 1.0.0-M5, use + * {@link GemFireVectorStore#builder(EmbeddingModel)} ()} instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder setPort(int port) { @@ -723,8 +726,8 @@ public Builder setPort(int port) { } /** - * @deprecated Since 1.0.0-M5, use {@link GemFireVectorStore#builder()} - * instead + * @deprecated Since 1.0.0-M5, use + * {@link GemFireVectorStore#builder(EmbeddingModel)} ()} instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder setSslEnabled(boolean sslEnabled) { @@ -733,8 +736,8 @@ public Builder setSslEnabled(boolean sslEnabled) { } /** - * @deprecated Since 1.0.0-M5, use {@link GemFireVectorStore#builder()} - * instead + * @deprecated Since 1.0.0-M5, use + * {@link GemFireVectorStore#builder(EmbeddingModel)} ()} instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder setIndexName(String indexName) { @@ -744,8 +747,8 @@ public Builder setIndexName(String indexName) { } /** - * @deprecated Since 1.0.0-M5, use {@link GemFireVectorStore#builder()} - * instead + * @deprecated Since 1.0.0-M5, use + * {@link GemFireVectorStore#builder(EmbeddingModel)} ()} instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder setBeamWidth(int beamWidth) { @@ -757,8 +760,8 @@ public Builder setBeamWidth(int beamWidth) { } /** - * @deprecated Since 1.0.0-M5, use {@link GemFireVectorStore#builder()} - * instead + * @deprecated Since 1.0.0-M5, use + * {@link GemFireVectorStore#builder(EmbeddingModel)} ()} instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder setMaxConnections(int maxConnections) { @@ -771,8 +774,8 @@ public Builder setMaxConnections(int maxConnections) { } /** - * @deprecated Since 1.0.0-M5, use {@link GemFireVectorStore#builder()} - * instead + * @deprecated Since 1.0.0-M5, use + * {@link GemFireVectorStore#builder(EmbeddingModel)} ()} instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder setBuckets(int buckets) { @@ -782,8 +785,8 @@ public Builder setBuckets(int buckets) { } /** - * @deprecated Since 1.0.0-M5, use {@link GemFireVectorStore#builder()} - * instead + * @deprecated Since 1.0.0-M5, use + * {@link GemFireVectorStore#builder(EmbeddingModel)} ()} instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder setVectorSimilarityFunction(String vectorSimilarityFunction) { @@ -793,8 +796,8 @@ public Builder setVectorSimilarityFunction(String vectorSimilarityFunction) { } /** - * @deprecated Since 1.0.0-M5, use {@link GemFireVectorStore#builder()} - * instead + * @deprecated Since 1.0.0-M5, use + * {@link GemFireVectorStore#builder(EmbeddingModel)} ()} instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder setFields(String[] fields) { @@ -803,8 +806,8 @@ public Builder setFields(String[] fields) { } /** - * @deprecated Since 1.0.0-M5, use {@link GemFireVectorStore#builder()} - * instead + * @deprecated Since 1.0.0-M5, use + * {@link GemFireVectorStore#builder(EmbeddingModel)} ()} instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public GemFireVectorStoreConfig build() { @@ -846,6 +849,10 @@ public static class GemFireBuilder extends AbstractVectorStoreBuilder repository, @@ -108,16 +110,15 @@ public HanaCloudVectorStore(HanaVectorRepository rep * @param config the vector store configuration * @param observationRegistry the observation registry * @param customObservationConvention the custom observation convention - * @deprecated Since 1.0.0-M5, use {@link #builder()} instead + * @deprecated Since 1.0.0-M5, use + * {@link #builder(HanaVectorRepository, EmbeddingModel)} ()} instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public HanaCloudVectorStore(HanaVectorRepository repository, EmbeddingModel embeddingModel, HanaCloudVectorStoreConfig config, ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention) { - this(builder().repository(repository) - .embeddingModel(embeddingModel) - .tableName(config.getTableName()) + this(builder(repository, embeddingModel).tableName(config.getTableName()) .topK(config.getTopK()) .observationRegistry(observationRegistry) .customObservationConvention(customObservationConvention)); @@ -143,8 +144,9 @@ protected HanaCloudVectorStore(HanaCloudBuilder builder) { * Creates a new builder for configuring and creating HanaCloudVectorStore instances. * @return a new builder instance */ - public static HanaCloudBuilder builder() { - return new HanaCloudBuilder(); + public static HanaCloudBuilder builder(HanaVectorRepository repository, + EmbeddingModel embeddingModel) { + return new HanaCloudBuilder(repository, embeddingModel); } @Override @@ -153,7 +155,7 @@ public void doAdd(List documents) { for (Document document : documents) { logger.info("[{}/{}] Calling EmbeddingModel for document id = {}", count++, documents.size(), document.getId()); - String content = document.getContent().replaceAll("\\s+", " "); + String content = document.getText().replaceAll("\\s+", " "); String embedding = getEmbedding(document); this.repository.save(this.tableName, document.getId(), embedding, content); } @@ -233,8 +235,9 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str */ public static class HanaCloudBuilder extends AbstractVectorStoreBuilder { - private HanaVectorRepository repository; + private final HanaVectorRepository repository; + @Nullable private String tableName; private int topK; @@ -245,10 +248,11 @@ public static class HanaCloudBuilder extends AbstractVectorStoreBuilder repository) { + private HanaCloudBuilder(HanaVectorRepository repository, + EmbeddingModel embeddingModel) { + super(embeddingModel); Assert.notNull(repository, "Repository must not be null"); this.repository = repository; - return this; } /** @@ -273,7 +277,6 @@ public HanaCloudBuilder topK(int topK) { @Override public HanaCloudVectorStore build() { - validate(); return new HanaCloudVectorStore(this); } diff --git a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/hanadb/package-info.java b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/hanadb/package-info.java new file mode 100644 index 00000000000..9ec356db0e3 --- /dev/null +++ b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/hanadb/package-info.java @@ -0,0 +1,25 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Provides the API for embedding observations. + */ +@NonNullApi +@NonNullFields +package org.springframework.ai.vectorstore.hanadb; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/hanadb/HanaCloudVectorStoreIT.java b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/hanadb/HanaCloudVectorStoreIT.java index 7015c31b6aa..6a4131f3475 100644 --- a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/hanadb/HanaCloudVectorStoreIT.java +++ b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/hanadb/HanaCloudVectorStoreIT.java @@ -91,9 +91,7 @@ public static class HanaTestApplication { public VectorStore hanaCloudVectorStore(CricketWorldCupRepository cricketWorldCupRepository, EmbeddingModel embeddingModel) { - return HanaCloudVectorStore.builder() - .repository(cricketWorldCupRepository) - .embeddingModel(embeddingModel) + return HanaCloudVectorStore.builder(cricketWorldCupRepository, embeddingModel) .tableName("CRICKET_WORLD_CUP") .topK(1) .build(); diff --git a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/hanadb/HanaVectorStoreObservationIT.java b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/hanadb/HanaVectorStoreObservationIT.java index d537aec4214..a37d5af13ba 100644 --- a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/hanadb/HanaVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/hanadb/HanaVectorStoreObservationIT.java @@ -167,9 +167,7 @@ public TestObservationRegistry observationRegistry() { public VectorStore hanaCloudVectorStore(CricketWorldCupRepository cricketWorldCupRepository, EmbeddingModel embeddingModel, ObservationRegistry observationRegistry) { - return HanaCloudVectorStore.builder() - .repository(cricketWorldCupRepository) - .embeddingModel(embeddingModel) + return HanaCloudVectorStore.builder(cricketWorldCupRepository, embeddingModel) .tableName(TEST_TABLE_NAME) .topK(1) .observationRegistry(observationRegistry) diff --git a/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/mariadb/MariaDBVectorStore.java b/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/mariadb/MariaDBVectorStore.java index e801c3ad5b3..b5ec600cac3 100644 --- a/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/mariadb/MariaDBVectorStore.java +++ b/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/mariadb/MariaDBVectorStore.java @@ -196,7 +196,8 @@ public class MariaDBVectorStore extends AbstractObservationVectorStore implement private final int maxDocumentBatchSize; /** - * @deprecated Use {@link #builder(JdbcTemplate)} instead + * @deprecated Use {@link #builder(JdbcTemplate, EmbeddingModel)} (JdbcTemplate)} + * instead */ @Deprecated(forRemoval = true, since = "1.0.0-M5") public MariaDBVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { @@ -204,7 +205,8 @@ public MariaDBVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingMod } /** - * @deprecated Use {@link #builder(JdbcTemplate)} instead + * @deprecated Use {@link #builder(JdbcTemplate, EmbeddingModel)} (JdbcTemplate)} + * instead */ @Deprecated(forRemoval = true, since = "1.0.0-M5") public MariaDBVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions) { @@ -212,7 +214,8 @@ public MariaDBVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingMod } /** - * @deprecated Use {@link #builder(JdbcTemplate)} instead + * @deprecated Use {@link #builder(JdbcTemplate, EmbeddingModel)} (JdbcTemplate)} + * instead */ @Deprecated(forRemoval = true, since = "1.0.0-M5") public MariaDBVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, @@ -222,7 +225,8 @@ public MariaDBVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingMod } /** - * @deprecated Use {@link #builder(JdbcTemplate)} instead + * @deprecated Use {@link #builder(JdbcTemplate, EmbeddingModel)} (JdbcTemplate)} + * instead */ @Deprecated(forRemoval = true, since = "1.0.0-M5") public MariaDBVectorStore(String vectorTableName, JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, @@ -233,7 +237,8 @@ public MariaDBVectorStore(String vectorTableName, JdbcTemplate jdbcTemplate, Emb } /** - * @deprecated Use {@link #builder(JdbcTemplate)} instead + * @deprecated Use {@link #builder(JdbcTemplate, EmbeddingModel)} (JdbcTemplate)} + * instead */ @Deprecated(forRemoval = true, since = "1.0.0-M5") private MariaDBVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled, @@ -246,7 +251,8 @@ private MariaDBVectorStore(String schemaName, String vectorTableName, boolean ve } /** - * @deprecated Use {@link #builder(JdbcTemplate)} instead + * @deprecated Use {@link #builder(JdbcTemplate, EmbeddingModel)} (JdbcTemplate)} + * instead */ @Deprecated(forRemoval = true, since = "1.0.0-M5") private MariaDBVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled, @@ -256,8 +262,7 @@ private MariaDBVectorStore(String schemaName, String vectorTableName, boolean ve int maxDocumentBatchSize, String contentFieldName, String embeddingFieldName, String idFieldName, String metadataFieldName) { - this(builder(jdbcTemplate).vectorTableName(vectorTableName) - .embeddingModel(embeddingModel) + this(builder(jdbcTemplate, embeddingModel).vectorTableName(vectorTableName) .dimensions(dimensions) .distanceType(distanceType) .removeExistingVectorStoreTable(removeExistingVectorStoreTable) @@ -287,12 +292,10 @@ protected MariaDBVectorStore(MariaDBBuilder builder) { this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build(); - this.vectorTableName = (null == builder.vectorTableName || builder.vectorTableName.isEmpty()) - ? DEFAULT_TABLE_NAME + this.vectorTableName = builder.vectorTableName.isEmpty() ? DEFAULT_TABLE_NAME : MariaDBSchemaValidator.validateAndEnquoteIdentifier(builder.vectorTableName.trim(), false); - logger.info("Using the vector table name: {}. Is empty: {}", this.vectorTableName, - (vectorTableName == null || vectorTableName.isEmpty())); + logger.info("Using the vector table name: {}. Is empty: {}", this.vectorTableName, vectorTableName.isEmpty()); this.schemaName = builder.schemaName == null ? null : MariaDBSchemaValidator.validateAndEnquoteIdentifier(builder.schemaName, false); @@ -319,8 +322,8 @@ protected MariaDBVectorStore(MariaDBBuilder builder) { * MariaDBVectorStore. * @return a new MariaDBBuilder instance */ - public static MariaDBBuilder builder(JdbcTemplate jdbcTemplate) { - return new MariaDBBuilder(jdbcTemplate); + public static MariaDBBuilder builder(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { + return new MariaDBBuilder(jdbcTemplate, embeddingModel); } public MariaDBDistanceType getDistanceType() { @@ -342,14 +345,14 @@ private List> batchDocuments(List documents, Lis List mariaDBDocuments = new ArrayList<>(documents.size()); if (embeddings.size() == documents.size()) { for (Document document : documents) { - mariaDBDocuments.add(new MariaDBDocument(document.getId(), document.getContent(), - document.getMetadata(), embeddings.get(documents.indexOf(document)))); + mariaDBDocuments.add(new MariaDBDocument(document.getId(), document.getText(), document.getMetadata(), + embeddings.get(documents.indexOf(document)))); } } else { for (Document document : documents) { mariaDBDocuments - .add(new MariaDBDocument(document.getId(), document.getContent(), document.getMetadata(), null)); + .add(new MariaDBDocument(document.getId(), document.getText(), document.getMetadata(), null)); } } @@ -572,6 +575,7 @@ public static final class MariaDBBuilder extends AbstractVectorStoreBuilder metadata, float[] embedding) { + public record MariaDBDocument(String id, @Nullable String content, Map metadata, + @Nullable float[] embedding) { } } diff --git a/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/mariadb/package-info.java b/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/mariadb/package-info.java new file mode 100644 index 00000000000..d2cb9365b88 --- /dev/null +++ b/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/mariadb/package-info.java @@ -0,0 +1,25 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Provides the API for embedding observations. + */ +@NonNullApi +@NonNullFields +package org.springframework.ai.vectorstore.mariadb; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBEmbeddingDimensionsTests.java b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBEmbeddingDimensionsTests.java index c57f8e67668..2e7ec9abe54 100644 --- a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBEmbeddingDimensionsTests.java +++ b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBEmbeddingDimensionsTests.java @@ -47,8 +47,7 @@ public void explicitlySetDimensions() { final int explicitDimensions = 696; - MariaDBVectorStore mariaDBVectorStore = MariaDBVectorStore.builder(this.jdbcTemplate) - .embeddingModel(this.embeddingModel) + MariaDBVectorStore mariaDBVectorStore = MariaDBVectorStore.builder(this.jdbcTemplate, this.embeddingModel) .dimensions(explicitDimensions) .build(); var dim = mariaDBVectorStore.embeddingDimensions(); @@ -61,8 +60,7 @@ public void explicitlySetDimensions() { public void embeddingModelDimensions() { when(this.embeddingModel.dimensions()).thenReturn(969); - MariaDBVectorStore mariaDBVectorStore = MariaDBVectorStore.builder(this.jdbcTemplate) - .embeddingModel(this.embeddingModel) + MariaDBVectorStore mariaDBVectorStore = MariaDBVectorStore.builder(this.jdbcTemplate, this.embeddingModel) .build(); var dim = mariaDBVectorStore.embeddingDimensions(); @@ -76,8 +74,7 @@ public void fallBackToDefaultDimensions() { when(this.embeddingModel.dimensions()).thenThrow(new RuntimeException()); - MariaDBVectorStore mariaDBVectorStore = MariaDBVectorStore.builder(this.jdbcTemplate) - .embeddingModel(this.embeddingModel) + MariaDBVectorStore mariaDBVectorStore = MariaDBVectorStore.builder(this.jdbcTemplate, this.embeddingModel) .build(); var dim = mariaDBVectorStore.embeddingDimensions(); diff --git a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreCustomNamesIT.java b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreCustomNamesIT.java index 3f60e501ec8..7c24948b2cf 100644 --- a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreCustomNamesIT.java +++ b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreCustomNamesIT.java @@ -217,8 +217,7 @@ public static class TestApplication { @Bean public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { - return MariaDBVectorStore.builder(jdbcTemplate) - .embeddingModel(embeddingModel) + return MariaDBVectorStore.builder(jdbcTemplate, embeddingModel) .schemaName(this.schemaName) .vectorTableName(this.vectorTableName) .schemaValidation(this.schemaValidation) diff --git a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreIT.java b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreIT.java index 54ab39ab905..8a14311407b 100644 --- a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreIT.java +++ b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreIT.java @@ -347,8 +347,7 @@ public static class TestApplication { @Bean public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { - return MariaDBVectorStore.builder(jdbcTemplate) - .embeddingModel(embeddingModel) + return MariaDBVectorStore.builder(jdbcTemplate, embeddingModel) .dimensions(MariaDBVectorStore.INVALID_EMBEDDING_DIMENSION) .distanceType(this.distanceType) .removeExistingVectorStoreTable(true) diff --git a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreObservationIT.java b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreObservationIT.java index 2e183bc525f..caca46eea76 100644 --- a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreObservationIT.java +++ b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreObservationIT.java @@ -177,8 +177,7 @@ public TestObservationRegistry observationRegistry() { @Bean public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, ObservationRegistry observationRegistry) { - return MariaDBVectorStore.builder(jdbcTemplate) - .embeddingModel(embeddingModel) + return MariaDBVectorStore.builder(jdbcTemplate, embeddingModel) .schemaName(schemaName) .distanceType(MariaDBVectorStore.MariaDBDistanceType.COSINE) .observationRegistry(observationRegistry) diff --git a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreTests.java b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreTests.java index defe99548e0..be36e07920d 100644 --- a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreTests.java +++ b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreTests.java @@ -70,8 +70,7 @@ void shouldAddDocumentsInBatchesAndEmbedOnce() { // Given var jdbcTemplate = mock(JdbcTemplate.class); var embeddingModel = mock(EmbeddingModel.class); - var mariadbVectorStore = MariaDBVectorStore.builder(jdbcTemplate) - .embeddingModel(embeddingModel) + var mariadbVectorStore = MariaDBVectorStore.builder(jdbcTemplate, embeddingModel) .maxDocumentBatchSize(1000) .build(); diff --git a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBVectorStoreBuilderTests.java b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBVectorStoreBuilderTests.java index 4e0e1eeb76e..ed28a7f25fd 100644 --- a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBVectorStoreBuilderTests.java +++ b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBVectorStoreBuilderTests.java @@ -38,22 +38,21 @@ class MariaDBVectorStoreBuilderTests { @Test void shouldFailOnMissingEmbeddingModel() { - assertThatThrownBy(() -> MariaDBVectorStore.builder(jdbcTemplate).build()) + assertThatThrownBy(() -> MariaDBVectorStore.builder(jdbcTemplate, null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("EmbeddingModel must be configured"); } @Test void shouldFailOnMissingJdbcTemplate() { - assertThatThrownBy(() -> MariaDBVectorStore.builder(null).build()).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> MariaDBVectorStore.builder(null, embeddingModel).build()) + .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("JdbcTemplate must not be null"); } @Test void shouldUseDefaultValues() { - MariaDBVectorStore vectorStore = MariaDBVectorStore.builder(jdbcTemplate) - .embeddingModel(embeddingModel) - .build(); + MariaDBVectorStore vectorStore = MariaDBVectorStore.builder(jdbcTemplate, embeddingModel).build(); assertThat(vectorStore).hasFieldOrPropertyWithValue("vectorTableName", "vector_store") .hasFieldOrPropertyWithValue("schemaName", null) @@ -71,8 +70,7 @@ void shouldUseDefaultValues() { @Test void shouldConfigureCustomValues() { - MariaDBVectorStore vectorStore = MariaDBVectorStore.builder(jdbcTemplate) - .embeddingModel(embeddingModel) + MariaDBVectorStore vectorStore = MariaDBVectorStore.builder(jdbcTemplate, embeddingModel) .schemaName("custom_schema") .vectorTableName("custom_vectors") .schemaValidation(true) @@ -103,60 +101,49 @@ void shouldConfigureCustomValues() { @Test void shouldValidateFieldNames() { - assertThatThrownBy(() -> MariaDBVectorStore.builder(jdbcTemplate) - .embeddingModel(embeddingModel) - .contentFieldName("") - .build()).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> MariaDBVectorStore.builder(jdbcTemplate, embeddingModel).contentFieldName("").build()) + .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("ContentFieldName must not be empty"); - assertThatThrownBy(() -> MariaDBVectorStore.builder(jdbcTemplate) - .embeddingModel(embeddingModel) - .embeddingFieldName("") - .build()).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy( + () -> MariaDBVectorStore.builder(jdbcTemplate, embeddingModel).embeddingFieldName("").build()) + .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("EmbeddingFieldName must not be empty"); - assertThatThrownBy( - () -> MariaDBVectorStore.builder(jdbcTemplate).embeddingModel(embeddingModel).idFieldName("").build()) + assertThatThrownBy(() -> MariaDBVectorStore.builder(jdbcTemplate, embeddingModel).idFieldName("").build()) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("IdFieldName must not be empty"); - assertThatThrownBy(() -> MariaDBVectorStore.builder(jdbcTemplate) - .embeddingModel(embeddingModel) - .metadataFieldName("") - .build()).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> MariaDBVectorStore.builder(jdbcTemplate, embeddingModel).metadataFieldName("").build()) + .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("MetadataFieldName must not be empty"); } @Test void shouldValidateMaxDocumentBatchSize() { - assertThatThrownBy(() -> MariaDBVectorStore.builder(jdbcTemplate) - .embeddingModel(embeddingModel) - .maxDocumentBatchSize(0) - .build()).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy( + () -> MariaDBVectorStore.builder(jdbcTemplate, embeddingModel).maxDocumentBatchSize(0).build()) + .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("MaxDocumentBatchSize must be positive"); - assertThatThrownBy(() -> MariaDBVectorStore.builder(jdbcTemplate) - .embeddingModel(embeddingModel) - .maxDocumentBatchSize(-1) - .build()).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy( + () -> MariaDBVectorStore.builder(jdbcTemplate, embeddingModel).maxDocumentBatchSize(-1).build()) + .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("MaxDocumentBatchSize must be positive"); } @Test void shouldValidateDistanceType() { - assertThatThrownBy(() -> MariaDBVectorStore.builder(jdbcTemplate) - .embeddingModel(embeddingModel) - .distanceType(null) - .build()).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> MariaDBVectorStore.builder(jdbcTemplate, embeddingModel).distanceType(null).build()) + .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("DistanceType must not be null"); } @Test void shouldValidateBatchingStrategy() { - assertThatThrownBy(() -> MariaDBVectorStore.builder(jdbcTemplate) - .embeddingModel(embeddingModel) - .batchingStrategy(null) - .build()).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy( + () -> MariaDBVectorStore.builder(jdbcTemplate, embeddingModel).batchingStrategy(null).build()) + .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("BatchingStrategy must not be null"); } diff --git a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/milvus/vectorstore/MilvusVectorStore.java b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/milvus/vectorstore/MilvusVectorStore.java index 7976b31bb90..d8a5853f690 100644 --- a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/milvus/vectorstore/MilvusVectorStore.java +++ b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/milvus/vectorstore/MilvusVectorStore.java @@ -169,9 +169,9 @@ public class MilvusVectorStore extends AbstractObservationVectorStore implements private static final Logger logger = LoggerFactory.getLogger(MilvusVectorStore.class); - private static Map SIMILARITY_TYPE_MAPPING = Map.of(MetricType.COSINE, - VectorStoreSimilarityMetric.COSINE, MetricType.L2, VectorStoreSimilarityMetric.EUCLIDEAN, MetricType.IP, - VectorStoreSimilarityMetric.DOT); + private static final Map SIMILARITY_TYPE_MAPPING = Map.of( + MetricType.COSINE, VectorStoreSimilarityMetric.COSINE, MetricType.L2, VectorStoreSimilarityMetric.EUCLIDEAN, + MetricType.IP, VectorStoreSimilarityMetric.DOT); public final FilterExpressionConverter filterExpressionConverter = new MilvusFilterExpressionConverter(); @@ -230,9 +230,7 @@ public MilvusVectorStore(MilvusServiceClient milvusClient, EmbeddingModel embedd MilvusVectorStoreConfig config, boolean initializeSchema, BatchingStrategy batchingStrategy, ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention) { - this(builder().milvusClient(milvusClient) - .embeddingModel(embeddingModel) - .observationRegistry(observationRegistry) + this(builder(milvusClient, embeddingModel).observationRegistry(observationRegistry) .customObservationConvention(customObservationConvention) .initializeSchema(initializeSchema) .batchingStrategy(batchingStrategy)); @@ -268,8 +266,8 @@ protected MilvusVectorStore(MilvusBuilder builder) { * recommended way to instantiate a MilvusBuilder. * @return a new MilvusBuilder instance */ - public static MilvusBuilder builder() { - return new MilvusBuilder(); + public static MilvusBuilder builder(MilvusServiceClient milvusClient, EmbeddingModel embeddingModel) { + return new MilvusBuilder(milvusClient, embeddingModel); } @Override @@ -290,7 +288,7 @@ public void doAdd(List documents) { docIdArray.add(document.getId()); // Use a (future) DocumentTextLayoutFormatter instance to extract // the content used to compute the embeddings - contentArray.add(document.getContent()); + contentArray.add(document.getText()); metadataArray.add(new JSONObject(document.getMetadata())); embeddingArray.add(EmbeddingUtils.toList(embeddings.get(documents.indexOf(document)))); } @@ -579,6 +577,8 @@ private String getSimilarityMetric() { public static final class MilvusBuilder extends AbstractVectorStoreBuilder { + private final MilvusServiceClient milvusClient; + private String databaseName = DEFAULT_DATABASE_NAME; private String collectionName = DEFAULT_COLLECTION_NAME; @@ -603,18 +603,16 @@ public static final class MilvusBuilder extends AbstractVectorStoreBuilder MilvusVectorStore.builder() - .milvusClient(this.milvusClient) + ThrowableAssert.ThrowingCallable actual = () -> MilvusVectorStore + .builder(this.milvusClient, this.embeddingModel) .embeddingDimension(explicitDimensions) .build(); diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/milvus/vectorstore/MilvusVectorStoreCustomFieldNamesIT.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/milvus/vectorstore/MilvusVectorStoreCustomFieldNamesIT.java index 0934fcc3c18..cc6eddbe151 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/milvus/vectorstore/MilvusVectorStoreCustomFieldNamesIT.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/milvus/vectorstore/MilvusVectorStoreCustomFieldNamesIT.java @@ -230,9 +230,7 @@ static class TestApplication { @Bean VectorStore vectorStore(MilvusServiceClient milvusClient, EmbeddingModel embeddingModel) { - return MilvusVectorStore.builder() - .milvusClient(milvusClient) - .embeddingModel(embeddingModel) + return MilvusVectorStore.builder(milvusClient, embeddingModel) .collectionName("test_vector_store_custom_fields") .databaseName("default") .indexType(IndexType.IVF_FLAT) diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/milvus/vectorstore/MilvusVectorStoreIT.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/milvus/vectorstore/MilvusVectorStoreIT.java index b3f08670aca..18cd23f5752 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/milvus/vectorstore/MilvusVectorStoreIT.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/milvus/vectorstore/MilvusVectorStoreIT.java @@ -269,9 +269,7 @@ public static class TestApplication { @Bean public VectorStore vectorStore(MilvusServiceClient milvusClient, EmbeddingModel embeddingModel) { - return MilvusVectorStore.builder() - .milvusClient(milvusClient) - .embeddingModel(embeddingModel) + return MilvusVectorStore.builder(milvusClient, embeddingModel) .collectionName("test_vector_store") .databaseName("default") .indexType(IndexType.IVF_FLAT) diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/milvus/vectorstore/MilvusVectorStoreObservationIT.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/milvus/vectorstore/MilvusVectorStoreObservationIT.java index 1df3ac1aedb..89e4df509b5 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/milvus/vectorstore/MilvusVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/milvus/vectorstore/MilvusVectorStoreObservationIT.java @@ -170,9 +170,7 @@ public TestObservationRegistry observationRegistry() { @Bean public VectorStore vectorStore(MilvusServiceClient milvusClient, EmbeddingModel embeddingModel, ObservationRegistry observationRegistry) { - return MilvusVectorStore.builder() - .milvusClient(milvusClient) - .embeddingModel(embeddingModel) + return MilvusVectorStore.builder(milvusClient, embeddingModel) .observationRegistry(observationRegistry) .collectionName(TEST_COLLECTION_NAME) .databaseName("default") diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/mongodb/atlas/MongoDBAtlasVectorStore.java b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/mongodb/atlas/MongoDBAtlasVectorStore.java index 067a7292d90..8a336599d4d 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/mongodb/atlas/MongoDBAtlasVectorStore.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/mongodb/atlas/MongoDBAtlasVectorStore.java @@ -187,9 +187,7 @@ public MongoDBAtlasVectorStore(MongoTemplate mongoTemplate, EmbeddingModel embed MongoDBVectorStoreConfig config, boolean initializeSchema, ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { - this(builder().mongoTemplate(mongoTemplate) - .embeddingModel(embeddingModel) - .collectionName(config.collectionName) + this(builder(mongoTemplate, embeddingModel).collectionName(config.collectionName) .vectorIndexName(config.vectorIndexName) .pathName(config.pathName) .numCandidates(config.numCandidates) @@ -296,7 +294,7 @@ public void doAdd(List documents) { List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); for (Document document : documents) { - MongoDBDocument mdbDocument = new MongoDBDocument(document.getId(), document.getContent(), + MongoDBDocument mdbDocument = new MongoDBDocument(document.getId(), document.getText(), document.getMetadata(), embeddings.get(documents.indexOf(document))); this.mongoTemplate.save(mdbDocument, this.collectionName); } @@ -354,13 +352,13 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str * Creates a new builder instance for MongoDBAtlasVectorStore. * @return a new MongoDBBuilder instance */ - public static MongoDBBuilder builder() { - return new MongoDBBuilder(); + public static MongoDBBuilder builder(MongoTemplate mongoTemplate, EmbeddingModel embeddingModel) { + return new MongoDBBuilder(mongoTemplate, embeddingModel); } public static class MongoDBBuilder extends AbstractVectorStoreBuilder { - private MongoTemplate mongoTemplate; + private final MongoTemplate mongoTemplate; private String collectionName = DEFAULT_VECTOR_COLLECTION_NAME; @@ -381,10 +379,10 @@ public static class MongoDBBuilder extends AbstractVectorStoreBuilder SIMILARITY_TYPE_MAPPING = Map.of( + private static final Map SIMILARITY_TYPE_MAPPING = Map.of( Neo4jDistanceType.COSINE, VectorStoreSimilarityMetric.COSINE, Neo4jDistanceType.EUCLIDEAN, VectorStoreSimilarityMetric.EUCLIDEAN); @@ -196,9 +196,7 @@ public Neo4jVectorStore(Driver driver, EmbeddingModel embeddingModel, Neo4jVecto boolean initializeSchema, ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { - this(builder().driver(driver) - .embeddingModel(embeddingModel) - .sessionConfig(config.sessionConfig) + this(builder(driver, embeddingModel).sessionConfig(config.sessionConfig) .embeddingDimension(config.embeddingDimension) .distanceType(config.distanceType) .embeddingProperty(config.embeddingProperty) @@ -341,7 +339,7 @@ private Map documentToRecord(Document document, float[] embeddin row.put("id", document.getId()); var properties = new HashMap(); - properties.put("text", document.getContent()); + properties.put("text", document.getText()); document.getMetadata().forEach((k, v) -> properties.put("metadata." + k, Values.value(v))); row.put("properties", properties); @@ -400,13 +398,13 @@ public enum Neo4jDistanceType { } - public static Neo4jBuilder builder() { - return new Neo4jBuilder(); + public static Neo4jBuilder builder(Driver driver, EmbeddingModel embeddingModel) { + return new Neo4jBuilder(driver, embeddingModel); } public static class Neo4jBuilder extends AbstractVectorStoreBuilder { - private Driver driver; + private final Driver driver; private SessionConfig sessionConfig = SessionConfig.defaultConfig(); @@ -428,10 +426,10 @@ public static class Neo4jBuilder extends AbstractVectorStoreBuilder documents) { this.batchingStrategy); BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder(); for (Document document : documents) { - OpenSearchDocument openSearchDocument = new OpenSearchDocument(document.getId(), document.getContent(), + OpenSearchDocument openSearchDocument = new OpenSearchDocument(document.getId(), document.getText(), document.getMetadata(), embedding.get(documents.indexOf(document))); bulkRequestBuilder.operations(op -> op .index(idx -> idx.index(this.index).id(openSearchDocument.id()).document(openSearchDocument))); @@ -450,7 +448,7 @@ public record OpenSearchDocument(String id, String content, Map */ public static class OpenSearchBuilder extends AbstractVectorStoreBuilder { - private OpenSearchClient openSearchClient; + private final OpenSearchClient openSearchClient; private String index = DEFAULT_INDEX_NAME; @@ -470,22 +468,10 @@ public static class OpenSearchBuilder extends AbstractVectorStoreBuilder SIMILARITY_TYPE_MAPPING = Map.of( - OracleVectorStoreDistanceType.COSINE, VectorStoreSimilarityMetric.COSINE, - OracleVectorStoreDistanceType.EUCLIDEAN, VectorStoreSimilarityMetric.EUCLIDEAN, - OracleVectorStoreDistanceType.DOT, VectorStoreSimilarityMetric.DOT); + private static final Map SIMILARITY_TYPE_MAPPING = Map + .of(OracleVectorStoreDistanceType.COSINE, VectorStoreSimilarityMetric.COSINE, + OracleVectorStoreDistanceType.EUCLIDEAN, VectorStoreSimilarityMetric.EUCLIDEAN, + OracleVectorStoreDistanceType.DOT, VectorStoreSimilarityMetric.DOT); public final FilterExpressionConverter filterExpressionConverter = new SqlJsonPathFilterExpressionConverter(); @@ -150,7 +150,8 @@ public class OracleVectorStore extends AbstractObservationVectorStore implements * Creates a new OracleVectorStore with default configuration. * @param jdbcTemplate the JDBC template to use * @param embeddingModel the embedding model to use - * @deprecated Since 1.0.0-M5, use {@link #builder()} instead + * @deprecated Since 1.0.0-M5, use {@link #builder(JdbcTemplate, EmbeddingModel)} ()} + * instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { @@ -163,7 +164,8 @@ public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingMode * @param jdbcTemplate the JDBC template to use * @param embeddingModel the embedding model to use * @param initializeSchema whether to initialize the schema - * @deprecated Since 1.0.0-M5, use {@link #builder()} instead + * @deprecated Since 1.0.0-M5, use {@link #builder(JdbcTemplate, EmbeddingModel)} ()} + * instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, boolean initializeSchema) { @@ -183,7 +185,8 @@ public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingMode * @param initializeSchema whether to initialize the schema * @param removeExistingVectorStoreTable whether to remove existing vector store table * @param forcedNormalization whether to force vector normalization - * @deprecated Since 1.0.0-M5, use {@link #builder()} instead + * @deprecated Since 1.0.0-M5, use {@link #builder(JdbcTemplate, EmbeddingModel)} ()} + * instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, String tableName, @@ -211,7 +214,8 @@ public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingMode * @param observationRegistry the observation registry * @param customObservationConvention the custom observation convention * @param batchingStrategy the batching strategy - * @deprecated Since 1.0.0-M5, use {@link #builder()} instead + * @deprecated Since 1.0.0-M5, use {@link #builder(JdbcTemplate, EmbeddingModel)} ()} + * instead */ @Deprecated(since = "1.0.0-M5", forRemoval = true) public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, String tableName, @@ -220,9 +224,7 @@ public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingMode boolean forcedNormalization, ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { - this(builder().jdbcTemplate(jdbcTemplate) - .embeddingModel(embeddingModel) - .tableName(tableName) + this(builder(jdbcTemplate, embeddingModel).tableName(tableName) .indexType(indexType) .distanceType(distanceType) .dimensions(dimensions) @@ -257,8 +259,8 @@ protected OracleVectorStore(OracleBuilder builder) { this.batchingStrategy = builder.batchingStrategy; } - public static OracleBuilder builder() { - return new OracleBuilder(); + public static OracleBuilder builder(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { + return new OracleBuilder(jdbcTemplate, embeddingModel); } @Override @@ -270,7 +272,7 @@ public void doAdd(final List documents) { @Override public void setValues(PreparedStatement ps, int i) throws SQLException { final Document document = documents.get(i); - final String content = document.getContent(); + final String content = document.getText(); final byte[] json = toJson(document.getMetadata()); final VECTOR embeddingVector = toVECTOR(embeddings.get(documents.indexOf(document))); @@ -747,7 +749,7 @@ private List toFloatList(final float[] embeddings) { */ public static class OracleBuilder extends AbstractVectorStoreBuilder { - private JdbcTemplate jdbcTemplate; + private final JdbcTemplate jdbcTemplate; private String tableName = DEFAULT_TABLE_NAME; @@ -773,10 +775,10 @@ public static class OracleBuilder extends AbstractVectorStoreBuilder batch, List documents, public void setValues(PreparedStatement ps, int i) throws SQLException { var document = batch.get(i); - var content = document.getContent(); + var content = document.getText(); var json = toJson(document.getMetadata()); var embedding = embeddings.get(documents.indexOf(document)); var pGvector = new PGvector(embedding); @@ -389,7 +388,6 @@ public List embeddingDistance(String query) { new RowMapper() { @Override - @Nullable public Double mapRow(ResultSet rs, int rowNum) throws SQLException { return rs.getDouble(DocumentRowMapper.COLUMN_DISTANCE); } @@ -625,7 +623,7 @@ private Map toMap(PGobject pgObject) { public static class PgVectorStoreBuilder extends AbstractVectorStoreBuilder { - private JdbcTemplate jdbcTemplate; + private final JdbcTemplate jdbcTemplate; private String schemaName = PgVectorStore.DEFAULT_SCHEMA_NAME; @@ -647,10 +645,10 @@ public static class PgVectorStoreBuilder extends AbstractVectorStoreBuilder { - private String apiKey; + private final String apiKey; - private String projectId; + private final String projectId; - private String environment; + private final String environment; - private String indexName; + private final String indexName; private String namespace = ""; @@ -358,52 +358,19 @@ public static class PineconeBuilder extends AbstractVectorStoreBuilder toPayload(Document document) { try { var payload = QdrantValueFactory.toValueMap(document.getMetadata()); - payload.put(CONTENT_FIELD_NAME, io.qdrant.client.ValueFactory.value(document.getContent())); + payload.put(CONTENT_FIELD_NAME, io.qdrant.client.ValueFactory.value(document.getText())); return payload; } catch (Exception e) { @@ -386,7 +385,8 @@ public static final class QdrantBuilder extends AbstractVectorStoreBuilder QdrantVectorStore.builder(null)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> QdrantVectorStore.builder(null, null)).isInstanceOf(IllegalArgumentException.class) .hasMessage("QdrantClient must not be null"); } @Test void nullEmbeddingModelShouldThrowException() { - assertThatThrownBy(() -> QdrantVectorStore.builder(qdrantClient).embeddingModel(null).build()) + assertThatThrownBy(() -> QdrantVectorStore.builder(qdrantClient, null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("EmbeddingModel must not be null"); } @Test void emptyCollectionNameShouldThrowException() { - assertThatThrownBy( - () -> QdrantVectorStore.builder(qdrantClient).embeddingModel(embeddingModel).collectionName("").build()) + assertThatThrownBy(() -> QdrantVectorStore.builder(qdrantClient, embeddingModel).collectionName("").build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("collectionName must not be empty"); } @Test void nullBatchingStrategyShouldThrowException() { - assertThatThrownBy(() -> QdrantVectorStore.builder(qdrantClient) - .embeddingModel(embeddingModel) - .batchingStrategy(null) - .build()).isInstanceOf(IllegalArgumentException.class).hasMessage("BatchingStrategy must not be null"); + assertThatThrownBy(() -> QdrantVectorStore.builder(qdrantClient, embeddingModel).batchingStrategy(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("BatchingStrategy must not be null"); } } diff --git a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java index 282fb5577bf..b5a89c8939b 100644 --- a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java +++ b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java @@ -254,9 +254,8 @@ public QdrantClient qdrantClient() { @Bean public VectorStore qdrantVectorStore(EmbeddingModel embeddingModel, QdrantClient qdrantClient) { - return QdrantVectorStore.builder(qdrantClient) + return QdrantVectorStore.builder(qdrantClient, embeddingModel) .collectionName(COLLECTION_NAME) - .embeddingModel(embeddingModel) .initializeSchema(true) .build(); } diff --git a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreObservationIT.java b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreObservationIT.java index f54a2b480a0..a227b7d5457 100644 --- a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreObservationIT.java @@ -195,9 +195,8 @@ public QdrantClient qdrantClient() { @Bean public VectorStore qdrantVectorStore(EmbeddingModel embeddingModel, QdrantClient qdrantClient, ObservationRegistry observationRegistry) { - return QdrantVectorStore.builder(qdrantClient) + return QdrantVectorStore.builder(qdrantClient, embeddingModel) .collectionName(COLLECTION_NAME) - .embeddingModel(embeddingModel) .initializeSchema(true) .observationRegistry(observationRegistry) .customObservationConvention(null) diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/RedisVectorStore.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/RedisVectorStore.java index ad801d089af..fecba59e7fe 100644 --- a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/RedisVectorStore.java +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/RedisVectorStore.java @@ -62,6 +62,7 @@ import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; import org.springframework.beans.factory.InitializingBean; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -248,9 +249,7 @@ public RedisVectorStore(RedisVectorStoreConfig config, EmbeddingModel embeddingM boolean initializeSchema, ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { - this(builder().jedis(jedis) - .embeddingModel(embeddingModel) - .indexName(config.indexName) + this(builder(jedis, embeddingModel).indexName(config.indexName) .prefix(config.prefix) .contentFieldName(config.contentFieldName) .embeddingFieldName(config.embeddingFieldName) @@ -293,7 +292,7 @@ public void doAdd(List documents) { for (Document document : documents) { var fields = new HashMap(); fields.put(this.embeddingFieldName, embeddings.get(documents.indexOf(document))); - fields.put(this.contentFieldName, document.getContent()); + fields.put(this.contentFieldName, document.getText()); fields.putAll(document.getMetadata()); pipeline.jsonSetWithEscape(key(document.getId()), JSON_SET_PATH, fields); } @@ -483,13 +482,13 @@ public static MetadataField tag(String name) { } - public static RedisBuilder builder() { - return new RedisBuilder(); + public static RedisBuilder builder(JedisPooled jedis, EmbeddingModel embeddingModel) { + return new RedisBuilder(jedis, embeddingModel); } public static class RedisBuilder extends AbstractVectorStoreBuilder { - private JedisPooled jedis; + private final JedisPooled jedis; private String indexName = DEFAULT_INDEX_NAME; @@ -507,10 +506,10 @@ public static class RedisBuilder extends AbstractVectorStoreBuilder fields) { + public RedisBuilder metadataFields(@Nullable List fields) { if (fields != null && !fields.isEmpty()) { this.metadataFields = new ArrayList<>(fields); } @@ -618,7 +617,6 @@ public RedisBuilder batchingStrategy(BatchingStrategy batchingStrategy) { @Override public RedisVectorStore build() { - validate(); return new RedisVectorStore(this); } diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/package-info.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/package-info.java new file mode 100644 index 00000000000..00ca265ff0e --- /dev/null +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/package-info.java @@ -0,0 +1,25 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Provides the API for embedding observations. + */ +@NonNullApi +@NonNullFields +package org.springframework.ai.vectorstore.redis; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/redis/RedisVectorStoreIT.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/redis/RedisVectorStoreIT.java index 250f07b2587..8973e56f287 100644 --- a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/redis/RedisVectorStoreIT.java +++ b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/redis/RedisVectorStoreIT.java @@ -256,9 +256,9 @@ public static class TestApplication { @Bean public RedisVectorStore vectorStore(EmbeddingModel embeddingModel, JedisConnectionFactory jedisConnectionFactory) { - return RedisVectorStore.builder() - .jedis(new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort())) - .embeddingModel(embeddingModel) + return RedisVectorStore + .builder(new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort()), + embeddingModel) .metadataFields(MetadataField.tag("meta1"), MetadataField.tag("meta2"), MetadataField.tag("country"), MetadataField.numeric("year")) .initializeSchema(true) diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/redis/RedisVectorStoreObservationIT.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/redis/RedisVectorStoreObservationIT.java index b88928f04af..8b8ed213079 100644 --- a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/redis/RedisVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/redis/RedisVectorStoreObservationIT.java @@ -176,9 +176,9 @@ public TestObservationRegistry observationRegistry() { @Bean public RedisVectorStore vectorStore(EmbeddingModel embeddingModel, JedisConnectionFactory jedisConnectionFactory, ObservationRegistry observationRegistry) { - return RedisVectorStore.builder() - .jedis(new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort())) - .embeddingModel(embeddingModel) + return RedisVectorStore + .builder(new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort()), + embeddingModel) .observationRegistry(observationRegistry) .customObservationConvention(null) .initializeSchema(true) diff --git a/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/typesense/TypesenseVectorStore.java b/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/typesense/TypesenseVectorStore.java index d15103680ca..590adddce44 100644 --- a/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/typesense/TypesenseVectorStore.java +++ b/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/typesense/TypesenseVectorStore.java @@ -52,6 +52,7 @@ import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; import org.springframework.beans.factory.InitializingBean; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** @@ -169,7 +170,7 @@ public class TypesenseVectorStore extends AbstractObservationVectorStore impleme private final int embeddingDimension; /** - * @deprecated Use {@link #builder()} instead + * @deprecated Use {@link #builder(Client, EmbeddingModel)} ()} instead */ @Deprecated(forRemoval = true, since = "1.0.0-M5") public TypesenseVectorStore(Client client, EmbeddingModel embeddingModel) { @@ -177,7 +178,7 @@ public TypesenseVectorStore(Client client, EmbeddingModel embeddingModel) { } /** - * @deprecated Use {@link #builder()} instead + * @deprecated Use {@link #builder(Client, EmbeddingModel)} ()} instead */ @Deprecated(forRemoval = true, since = "1.0.0-M5") public TypesenseVectorStore(Client client, EmbeddingModel embeddingModel, TypesenseVectorStoreConfig config, @@ -187,16 +188,14 @@ public TypesenseVectorStore(Client client, EmbeddingModel embeddingModel, Typese } /** - * @deprecated Use {@link #builder()} instead + * @deprecated Use {@link #builder(Client, EmbeddingModel)} ()} instead */ @Deprecated(forRemoval = true, since = "1.0.0-M5") public TypesenseVectorStore(Client client, EmbeddingModel embeddingModel, TypesenseVectorStoreConfig config, boolean initializeSchema, ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { - this(builder().client(client) - .embeddingModel(embeddingModel) - .collectionName(config.collectionName) + this(builder(client, embeddingModel).collectionName(config.collectionName) .embeddingDimension(config.embeddingDimension) .initializeSchema(initializeSchema) .observationRegistry(observationRegistry) @@ -232,8 +231,8 @@ protected TypesenseVectorStore(TypesenseBuilder builder) { * a TypesenseVectorStore. * @return a new TypesenseBuilder instance */ - public static TypesenseBuilder builder() { - return new TypesenseBuilder(); + public static TypesenseBuilder builder(Client client, EmbeddingModel embeddingModel) { + return new TypesenseBuilder(client, embeddingModel); } @Override @@ -246,7 +245,7 @@ public void doAdd(List documents) { List> documentList = documents.stream().map(document -> { HashMap typesenseDoc = new HashMap<>(); typesenseDoc.put(DOC_ID_FIELD_NAME, document.getId()); - typesenseDoc.put(CONTENT_FIELD_NAME, document.getContent()); + typesenseDoc.put(CONTENT_FIELD_NAME, document.getText()); typesenseDoc.put(METADATA_FIELD_NAME, document.getMetadata()); typesenseDoc.put(EMBEDDING_FIELD_NAME, embeddings.get(documents.indexOf(document))); @@ -426,6 +425,7 @@ void dropCollection() { } } + @Nullable Map getCollectionInfo() { try { CollectionResponse retrievedCollection = this.client.collections(this.collectionName).retrieve(); @@ -455,22 +455,23 @@ public static final class TypesenseBuilder extends AbstractVectorStoreBuilder TypesenseVectorStore.builder().client(null).build()) + assertThatThrownBy(() -> TypesenseVectorStore.builder(null, embeddingModel).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("client must not be null"); } @Test void nullEmbeddingModelShouldThrowException() { - assertThatThrownBy(() -> TypesenseVectorStore.builder().client(client).embeddingModel(null).build()) + assertThatThrownBy(() -> TypesenseVectorStore.builder(client, null).build()) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("EmbeddingModel must not be null"); + .hasMessage("EmbeddingModel must be configured"); } @Test void invalidEmbeddingDimensionShouldThrowException() { - assertThatThrownBy(() -> TypesenseVectorStore.builder() - .client(client) - .embeddingModel(embeddingModel) - .embeddingDimension(0) - .build()).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> TypesenseVectorStore.builder(client, embeddingModel).embeddingDimension(0).build()) + .isInstanceOf(IllegalArgumentException.class) .hasMessage("Embedding dimension must be greater than 0"); } @Test void emptyCollectionNameShouldThrowException() { - assertThatThrownBy(() -> TypesenseVectorStore.builder() - .client(client) - .embeddingModel(embeddingModel) - .collectionName("") - .build()).isInstanceOf(IllegalArgumentException.class).hasMessage("collectionName must not be empty"); + assertThatThrownBy(() -> TypesenseVectorStore.builder(client, embeddingModel).collectionName("").build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("collectionName must not be empty"); } @Test void nullBatchingStrategyShouldThrowException() { - assertThatThrownBy(() -> TypesenseVectorStore.builder() - .client(client) - .embeddingModel(embeddingModel) - .batchingStrategy(null) - .build()).isInstanceOf(IllegalArgumentException.class).hasMessage("batchingStrategy must not be null"); + assertThatThrownBy(() -> TypesenseVectorStore.builder(client, embeddingModel).batchingStrategy(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("batchingStrategy must not be null"); } } diff --git a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/typesense/TypesenseVectorStoreIT.java b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/typesense/TypesenseVectorStoreIT.java index 6cd76c8fd59..c44953bc778 100644 --- a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/typesense/TypesenseVectorStoreIT.java +++ b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/typesense/TypesenseVectorStoreIT.java @@ -242,9 +242,7 @@ public static class TestApplication { @Bean public VectorStore vectorStore(Client client, EmbeddingModel embeddingModel) { - return TypesenseVectorStore.builder() - .client(client) - .embeddingModel(embeddingModel) + return TypesenseVectorStore.builder(client, embeddingModel) .collectionName("test_vector_store") .embeddingDimension(embeddingModel.dimensions()) .initializeSchema(true) diff --git a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/typesense/TypesenseVectorStoreObservationIT.java b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/typesense/TypesenseVectorStoreObservationIT.java index 86223036cf4..e94d76d2a35 100644 --- a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/typesense/TypesenseVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/typesense/TypesenseVectorStoreObservationIT.java @@ -170,9 +170,7 @@ public TestObservationRegistry observationRegistry() { public VectorStore vectorStore(Client client, EmbeddingModel embeddingModel, ObservationRegistry observationRegistry) { - return TypesenseVectorStore.builder() - .client(client) - .embeddingModel(embeddingModel) + return TypesenseVectorStore.builder(client, embeddingModel) .collectionName(TEST_COLLECTION_NAME) .embeddingDimension(embeddingModel.dimensions()) .initializeSchema(true) diff --git a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/weaviate/WeaviateVectorStore.java b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/weaviate/WeaviateVectorStore.java index 8364500696b..ae1c50d012e 100644 --- a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/weaviate/WeaviateVectorStore.java +++ b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/weaviate/WeaviateVectorStore.java @@ -32,6 +32,7 @@ import io.weaviate.client.base.WeaviateErrorMessage; import io.weaviate.client.v1.batch.model.BatchDeleteResponse; import io.weaviate.client.v1.batch.model.ObjectGetResponse; +import io.weaviate.client.v1.batch.model.ObjectsGetResponseAO2Result; import io.weaviate.client.v1.data.model.WeaviateObject; import io.weaviate.client.v1.filters.Operator; import io.weaviate.client.v1.filters.WhereFilter; @@ -151,9 +152,10 @@ public class WeaviateVectorStore extends AbstractObservationVectorStore { * @param vectorStoreConfig The configuration for the store * @param embeddingModel The client for embedding operations * @param weaviateClient The client for Weaviate operations - * @deprecated Use {@link #builder()} instead to create instances of - * WeaviateVectorStore. This constructor will be removed in a future release. - * @see #builder() + * @deprecated Use {@link #builder(WeaviateClient, EmbeddingModel)} ()} instead to + * create instances of WeaviateVectorStore. This constructor will be removed in a + * future release. + * @see #builder(WeaviateClient, EmbeddingModel) () * @since 1.0.0 */ @Deprecated(forRemoval = true, since = "1.0.0-M5") @@ -171,9 +173,10 @@ public WeaviateVectorStore(WeaviateVectorStoreConfig vectorStoreConfig, Embeddin * @param observationRegistry The registry for observations * @param customObservationConvention The custom observation convention * @param batchingStrategy The strategy for batching operations - * @deprecated Use {@link #builder()} instead to create instances of - * WeaviateVectorStore. This constructor will be removed in a future release. - * @see #builder() + * @deprecated Use {@link #builder(WeaviateClient, EmbeddingModel)} ()} instead to + * create instances of WeaviateVectorStore. This constructor will be removed in a + * future release. + * @see #builder(WeaviateClient, EmbeddingModel) () * @since 1.0.0 */ @Deprecated(forRemoval = true, since = "1.0.0-M5") @@ -181,9 +184,7 @@ public WeaviateVectorStore(WeaviateVectorStoreConfig vectorStoreConfig, Embeddin WeaviateClient weaviateClient, ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { - this(builder().embeddingModel(embeddingModel) - .weaviateClient(weaviateClient) - .observationRegistry(observationRegistry) + this(builder(weaviateClient, embeddingModel).observationRegistry(observationRegistry) .customObservationConvention(customObservationConvention) .batchingStrategy(batchingStrategy)); } @@ -217,8 +218,8 @@ protected WeaviateVectorStore(WeaviateBuilder builder) { * a WeaviateVectorStore. * @return a new WeaviateBuilder instance */ - public static WeaviateBuilder builder() { - return new WeaviateBuilder(); + public static WeaviateBuilder builder(WeaviateClient weaviateClient, EmbeddingModel embeddingModel) { + return new WeaviateBuilder(weaviateClient, embeddingModel); } private Field[] buildWeaviateSimilaritySearchFields() { @@ -267,7 +268,7 @@ public void doAdd(List documents) { errorMessages.add(response.getError() .getMessages() .stream() - .map(wm -> wm.getMessage()) + .map(WeaviateErrorMessage::getMessage) .collect(Collectors.joining(System.lineSeparator()))); throw new RuntimeException("Failed to add documents because: \n" + errorMessages); } @@ -278,7 +279,7 @@ public void doAdd(List documents) { var error = r.getResult().getErrors(); errorMessages.add(error.getError() .stream() - .map(e -> e.getMessage()) + .map(ObjectsGetResponseAO2Result.ErrorItem::getMessage) .collect(Collectors.joining(System.lineSeparator()))); } } @@ -293,13 +294,13 @@ private WeaviateObject toWeaviateObject(Document document, List docume // https://weaviate.io/developers/weaviate/config-refs/datatypes Map fields = new HashMap<>(); - fields.put(CONTENT_FIELD_NAME, document.getContent()); + fields.put(CONTENT_FIELD_NAME, document.getText()); try { String metadataString = this.objectMapper.writeValueAsString(document.getMetadata()); fields.put(METADATA_FIELD_NAME, metadataString); } catch (JsonProcessingException e) { - throw new RuntimeException("Failed to serialize the Document metadata: " + document.getContent()); + throw new RuntimeException("Failed to serialize the Document metadata: " + document.getText()); } // Add the filterable metadata fields as top level fields, allowing filler @@ -336,7 +337,7 @@ public Optional doDelete(List documentIds) { String errorMessages = result.getError() .getMessages() .stream() - .map(wm -> wm.getMessage()) + .map(WeaviateErrorMessage::getMessage) .collect(Collectors.joining(",")); throw new RuntimeException("Failed to delete documents because: \n" + errorMessages); } @@ -540,20 +541,21 @@ public static final class WeaviateBuilder extends AbstractVectorStoreBuilder filterMetadataFields = List.of(); - private WeaviateClient weaviateClient; + private final WeaviateClient weaviateClient; private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy(); /** - * Configures the Weaviate client. - * @param weaviateClient the client for Weaviate operations - * @return this builder instance + * Constructs a new WeaviateBuilder instance. + * @param weaviateClient The Weaviate client instance used for database + * operations. Must not be null. + * @param embeddingModel The embedding model used for vector transformations. * @throws IllegalArgumentException if weaviateClient is null */ - public WeaviateBuilder weaviateClient(WeaviateClient weaviateClient) { - Assert.notNull(weaviateClient, "weaviateClient must not be null"); + private WeaviateBuilder(WeaviateClient weaviateClient, EmbeddingModel embeddingModel) { + super(embeddingModel); + Assert.notNull(weaviateClient, "WeaviateClient must not be null"); this.weaviateClient = weaviateClient; - return this; } /** @@ -612,7 +614,6 @@ public WeaviateBuilder batchingStrategy(BatchingStrategy batchingStrategy) { */ @Override public WeaviateVectorStore build() { - validate(); return new WeaviateVectorStore(this); } @@ -621,9 +622,9 @@ public WeaviateVectorStore build() { /** * Configuration class for WeaviateVectorStore. * - * @deprecated Use {@link WeaviateVectorStore#builder()} instead to configure and - * create instances of WeaviateVectorStore. This class will be removed in a future - * release. Example migration:
{@code
+	 * @deprecated Use {@link WeaviateVectorStore#builder(WeaviateClient, EmbeddingModel)}
+	 * ()} instead to configure and create instances of WeaviateVectorStore. This class
+	 * will be removed in a future release. Example migration: 
{@code
 	 * // Old approach:
 	 * WeaviateVectorStoreConfig config = WeaviateVectorStoreConfig.builder()
 	 *     .withObjectClass("CustomClass")
@@ -636,7 +637,7 @@ public WeaviateVectorStore build() {
 	 *     .consistencyLevel(ConsistentLevel.QUORUM)
 	 *     .build();
 	 * }
- * @see WeaviateVectorStore#builder() + * @see WeaviateVectorStore#builder(WeaviateClient, EmbeddingModel) () * @since 1.0.0 */ @Deprecated(forRemoval = true, since = "1.0.0-M5") @@ -658,7 +659,8 @@ public static final class WeaviateVectorStoreConfig { /** * Constructor using the builder. * @param builder The configuration builder - * @deprecated Use {@link WeaviateVectorStore#builder()} instead + * @deprecated Use + * {@link WeaviateVectorStore#builder(WeaviateClient, EmbeddingModel)} ()} instead */ @Deprecated(forRemoval = true, since = "1.0.0-M5") public WeaviateVectorStoreConfig(Builder builder) { @@ -671,8 +673,9 @@ public WeaviateVectorStoreConfig(Builder builder) { /** * Start building a new configuration. * @return The entry point for creating a new configuration - * @deprecated Use {@link WeaviateVectorStore#builder()} instead to configure and - * create instances of WeaviateVectorStore + * @deprecated Use + * {@link WeaviateVectorStore#builder(WeaviateClient, EmbeddingModel)} ()} instead + * to configure and create instances of WeaviateVectorStore */ @Deprecated(forRemoval = true, since = "1.0.0-M5") public static Builder builder() { @@ -682,8 +685,9 @@ public static Builder builder() { /** * Returns the default configuration. * @return the default configuration - * @deprecated Use {@link WeaviateVectorStore#builder()} instead to configure and - * create instances of WeaviateVectorStore with default settings + * @deprecated Use + * {@link WeaviateVectorStore#builder(WeaviateClient, EmbeddingModel)} ()} instead + * to configure and create instances of WeaviateVectorStore with default settings */ @Deprecated(forRemoval = true, since = "1.0.0-M5") public static WeaviateVectorStoreConfig defaultConfig() { @@ -804,8 +808,9 @@ public enum Type { /** * Builder for WeaviateVectorStoreConfig. * - * @deprecated Use {@link WeaviateVectorStore#builder()} instead to configure and - * create instances of WeaviateVectorStore + * @deprecated Use + * {@link WeaviateVectorStore#builder(WeaviateClient, EmbeddingModel)} ()} instead + * to configure and create instances of WeaviateVectorStore * @since 1.0.0 */ @Deprecated(forRemoval = true, since = "1.0.0-M5") @@ -844,7 +849,7 @@ public Builder withFilterableMetadataFields(List filterMetadataFi * @return this builder * @throws IllegalArgumentException if headers is null * @deprecated Use the new builder API in - * {@link WeaviateVectorStore#builder()} + * {@link WeaviateVectorStore#builder(WeaviateClient, EmbeddingModel)} ()} */ @Deprecated(forRemoval = true, since = "1.0.0-M5") public Builder withHeaders(Map headers) { @@ -887,8 +892,9 @@ public Builder withConsistencyLevel(ConsistentLevel consistencyLevel) { /** * Builds and returns the immutable configuration. * @return the immutable configuration - * @deprecated Use {@link WeaviateVectorStore#builder()} instead to configure - * and create instances of WeaviateVectorStore + * @deprecated Use + * {@link WeaviateVectorStore#builder(WeaviateClient, EmbeddingModel)} ()} + * instead to configure and create instances of WeaviateVectorStore */ @Deprecated(forRemoval = true, since = "1.0.0-M5") public WeaviateVectorStoreConfig build() { diff --git a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/weaviate/package-info.java b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/weaviate/package-info.java new file mode 100644 index 00000000000..0226cc271ff --- /dev/null +++ b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/weaviate/package-info.java @@ -0,0 +1,25 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Provides the API for embedding observations. + */ +@NonNullApi +@NonNullFields +package org.springframework.ai.vectorstore.weaviate; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/weaviate/WeaviateVectorStoreBuilderTests.java b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/weaviate/WeaviateVectorStoreBuilderTests.java index 739ed525a2e..03986de31d2 100644 --- a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/weaviate/WeaviateVectorStoreBuilderTests.java +++ b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/weaviate/WeaviateVectorStoreBuilderTests.java @@ -47,10 +47,7 @@ class WeaviateVectorStoreBuilderTests { void shouldBuildWithMinimalConfiguration() { WeaviateClient weaviateClient = new WeaviateClient(new Config("http", "localhost:8080")); - WeaviateVectorStore vectorStore = WeaviateVectorStore.builder() - .weaviateClient(weaviateClient) - .embeddingModel(embeddingModel) - .build(); + WeaviateVectorStore vectorStore = WeaviateVectorStore.builder(weaviateClient, embeddingModel).build(); assertThat(vectorStore).isNotNull(); } @@ -59,9 +56,7 @@ void shouldBuildWithMinimalConfiguration() { void shouldBuildWithCustomConfiguration() { WeaviateClient weaviateClient = new WeaviateClient(new Config("http", "localhost:8080")); - WeaviateVectorStore vectorStore = WeaviateVectorStore.builder() - .weaviateClient(weaviateClient) - .embeddingModel(embeddingModel) + WeaviateVectorStore vectorStore = WeaviateVectorStore.builder(weaviateClient, embeddingModel) .objectClass("CustomClass") .consistencyLevel(ConsistentLevel.QUORUM) .filterMetadataFields(List.of(MetadataField.text("country"), MetadataField.number("year"))) @@ -72,7 +67,7 @@ void shouldBuildWithCustomConfiguration() { @Test void shouldFailWithoutWeaviateClient() { - assertThatThrownBy(() -> WeaviateVectorStore.builder().embeddingModel(embeddingModel).build()) + assertThatThrownBy(() -> WeaviateVectorStore.builder(null, embeddingModel).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("WeaviateClient must not be null"); } @@ -81,7 +76,7 @@ void shouldFailWithoutWeaviateClient() { void shouldFailWithoutEmbeddingModel() { WeaviateClient weaviateClient = new WeaviateClient(new Config("http", "localhost:8080")); - assertThatThrownBy(() -> WeaviateVectorStore.builder().weaviateClient(weaviateClient).build()) + assertThatThrownBy(() -> WeaviateVectorStore.builder(weaviateClient, null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("EmbeddingModel must be configured"); } @@ -90,33 +85,29 @@ void shouldFailWithoutEmbeddingModel() { void shouldFailWithInvalidObjectClass() { WeaviateClient weaviateClient = new WeaviateClient(new Config("http", "localhost:8080")); - assertThatThrownBy(() -> WeaviateVectorStore.builder() - .weaviateClient(weaviateClient) - .embeddingModel(embeddingModel) - .objectClass("") - .build()).isInstanceOf(IllegalArgumentException.class).hasMessage("objectClass must not be empty"); + assertThatThrownBy(() -> WeaviateVectorStore.builder(weaviateClient, embeddingModel).objectClass("").build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("objectClass must not be empty"); } @Test void shouldFailWithNullConsistencyLevel() { WeaviateClient weaviateClient = new WeaviateClient(new Config("http", "localhost:8080")); - assertThatThrownBy(() -> WeaviateVectorStore.builder() - .weaviateClient(weaviateClient) - .embeddingModel(embeddingModel) - .consistencyLevel(null) - .build()).isInstanceOf(IllegalArgumentException.class).hasMessage("consistencyLevel must not be null"); + assertThatThrownBy( + () -> WeaviateVectorStore.builder(weaviateClient, embeddingModel).consistencyLevel(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("consistencyLevel must not be null"); } @Test void shouldFailWithNullFilterMetadataFields() { WeaviateClient weaviateClient = new WeaviateClient(new Config("http", "localhost:8080")); - assertThatThrownBy(() -> WeaviateVectorStore.builder() - .weaviateClient(weaviateClient) - .embeddingModel(embeddingModel) - .filterMetadataFields(null) - .build()).isInstanceOf(IllegalArgumentException.class).hasMessage("filterMetadataFields must not be null"); + assertThatThrownBy( + () -> WeaviateVectorStore.builder(weaviateClient, embeddingModel).filterMetadataFields(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("filterMetadataFields must not be null"); } @Test diff --git a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/weaviate/WeaviateVectorStoreIT.java b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/weaviate/WeaviateVectorStoreIT.java index b675b5f9688..28c45b721eb 100644 --- a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/weaviate/WeaviateVectorStoreIT.java +++ b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/weaviate/WeaviateVectorStoreIT.java @@ -252,9 +252,7 @@ public VectorStore vectorStore(EmbeddingModel embeddingModel) { WeaviateClient weaviateClient = new WeaviateClient( new Config("http", weaviateContainer.getHttpHostAddress())); - return WeaviateVectorStore.builder() - .weaviateClient(weaviateClient) - .embeddingModel(embeddingModel) + return WeaviateVectorStore.builder(weaviateClient, embeddingModel) .filterMetadataFields(List.of(WeaviateVectorStore.MetadataField.text("country"), WeaviateVectorStore.MetadataField.number("year"))) .consistencyLevel(WeaviateVectorStore.ConsistentLevel.ONE) diff --git a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/weaviate/WeaviateVectorStoreObservationIT.java b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/weaviate/WeaviateVectorStoreObservationIT.java index b3b8ea8f7c3..23f0124918c 100644 --- a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/weaviate/WeaviateVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/weaviate/WeaviateVectorStoreObservationIT.java @@ -166,9 +166,7 @@ public WeaviateVectorStore vectorStore(EmbeddingModel embeddingModel, Observatio WeaviateClient weaviateClient = new WeaviateClient( new io.weaviate.client.Config("http", weaviateContainer.getHttpHostAddress())); - return WeaviateVectorStore.builder() - .weaviateClient(weaviateClient) - .embeddingModel(embeddingModel) + return WeaviateVectorStore.builder(weaviateClient, embeddingModel) .consistencyLevel(WeaviateVectorStore.ConsistentLevel.ONE) .observationRegistry(observationRegistry) .batchingStrategy(new TokenCountBatchingStrategy())