diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfiguration.java index 9f98736cef6..2a2ef509dd3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfiguration.java @@ -18,8 +18,8 @@ import com.datastax.oss.driver.api.core.CqlSession; -import org.springframework.ai.chat.memory.CassandraChatMemory; -import org.springframework.ai.chat.memory.CassandraChatMemoryConfig; +import org.springframework.ai.chat.memory.cassandra.CassandraChatMemory; +import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryConfig; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryProperties.java index c815131f39b..96a4c4ee325 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryProperties.java @@ -22,7 +22,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.autoconfigure.chat.memory.CommonChatMemoryProperties; -import org.springframework.ai.chat.memory.CassandraChatMemoryConfig; +import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryConfig; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.lang.Nullable; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java index 16104ab53f9..64890673b6f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java @@ -25,8 +25,7 @@ import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; -import org.springframework.ai.vectorstore.CassandraVectorStore; -import org.springframework.ai.vectorstore.CassandraVectorStoreConfig; +import org.springframework.ai.vectorstore.cassandra.CassandraVectorStore; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.AutoConfiguration; @@ -63,25 +62,21 @@ public CassandraVectorStore vectorStore(EmbeddingModel embeddingModel, Cassandra ObjectProvider customObservationConvention, BatchingStrategy batchingStrategy) { - var builder = CassandraVectorStoreConfig.builder().withCqlSession(cqlSession); - - builder = builder.withKeyspaceName(properties.getKeyspace()) - .withTableName(properties.getTable()) - .withContentColumnName(properties.getContentColumnName()) - .withEmbeddingColumnName(properties.getEmbeddingColumnName()) - .withIndexName(properties.getIndexName()) - .withFixedThreadPoolExecutorSize(properties.getFixedThreadPoolExecutorSize()); - - if (!properties.isInitializeSchema()) { - builder = builder.disallowSchemaChanges(); - } - if (properties.getReturnEmbeddings()) { - builder = builder.returnEmbeddings(); - } - - return new CassandraVectorStore(builder.build(), embeddingModel, - observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP), - customObservationConvention.getIfAvailable(() -> null), batchingStrategy); + return CassandraVectorStore.builder() + .session(cqlSession) + .keyspace(properties.getKeyspace()) + .table(properties.getTable()) + .contentColumnName(properties.getContentColumnName()) + .embeddingColumnName(properties.getEmbeddingColumnName()) + .indexName(properties.getIndexName()) + .fixedThreadPoolExecutorSize(properties.getFixedThreadPoolExecutorSize()) + .disallowSchemaChanges(!properties.isInitializeSchema()) + .returnEmbeddings(properties.getReturnEmbeddings()) + .embeddingModel(embeddingModel) + .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) + .customObservationConvention(customObservationConvention.getIfAvailable(() -> null)) + .batchingStrategy(batchingStrategy) + .build(); } @Bean diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreProperties.java index 5b80a9bc474..19174772f6b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreProperties.java @@ -21,7 +21,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; -import org.springframework.ai.vectorstore.CassandraVectorStoreConfig; +import org.springframework.ai.vectorstore.cassandra.CassandraVectorStore; import org.springframework.boot.context.properties.ConfigurationProperties; /** @@ -37,19 +37,19 @@ public class CassandraVectorStoreProperties extends CommonVectorStoreProperties private static final Logger logger = LoggerFactory.getLogger(CassandraVectorStoreProperties.class); - private String keyspace = CassandraVectorStoreConfig.DEFAULT_KEYSPACE_NAME; + private String keyspace = CassandraVectorStore.DEFAULT_KEYSPACE_NAME; - private String table = CassandraVectorStoreConfig.DEFAULT_TABLE_NAME; + private String table = CassandraVectorStore.DEFAULT_TABLE_NAME; private String indexName = null; - private String contentColumnName = CassandraVectorStoreConfig.DEFAULT_CONTENT_COLUMN_NAME; + private String contentColumnName = CassandraVectorStore.DEFAULT_CONTENT_COLUMN_NAME; - private String embeddingColumnName = CassandraVectorStoreConfig.DEFAULT_EMBEDDING_COLUMN_NAME; + private String embeddingColumnName = CassandraVectorStore.DEFAULT_EMBEDDING_COLUMN_NAME; private boolean returnEmbeddings = false; - private int fixedThreadPoolExecutorSize = CassandraVectorStoreConfig.DEFAULT_ADD_CONCURRENCY; + private int fixedThreadPoolExecutorSize = CassandraVectorStore.DEFAULT_ADD_CONCURRENCY; public String getKeyspace() { return this.keyspace; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfigurationIT.java index 8dbc9dc4ac2..6036c2bc921 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfigurationIT.java @@ -26,7 +26,7 @@ import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.utility.DockerImageName; -import org.springframework.ai.chat.memory.CassandraChatMemory; +import org.springframework.ai.chat.memory.cassandra.CassandraChatMemory; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.UserMessage; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryPropertiesTest.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryPropertiesTest.java index ee44f27b49a..9d779b9b558 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryPropertiesTest.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryPropertiesTest.java @@ -20,7 +20,7 @@ import org.junit.jupiter.api.Test; -import org.springframework.ai.chat.memory.CassandraChatMemoryConfig; +import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryConfig; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStorePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStorePropertiesTests.java index d2ed94ab7ac..98266e51928 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStorePropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStorePropertiesTests.java @@ -18,7 +18,8 @@ import org.junit.jupiter.api.Test; -import org.springframework.ai.vectorstore.CassandraVectorStoreConfig; +import org.springframework.ai.vectorstore.cassandra.CassandraVectorStore; +import org.springframework.ai.vectorstore.cassandra.CassandraVectorStoreConfig; import static org.assertj.core.api.Assertions.assertThat; @@ -31,13 +32,12 @@ class CassandraVectorStorePropertiesTests { @Test void defaultValues() { var props = new CassandraVectorStoreProperties(); - assertThat(props.getKeyspace()).isEqualTo(CassandraVectorStoreConfig.DEFAULT_KEYSPACE_NAME); - assertThat(props.getTable()).isEqualTo(CassandraVectorStoreConfig.DEFAULT_TABLE_NAME); - assertThat(props.getContentColumnName()).isEqualTo(CassandraVectorStoreConfig.DEFAULT_CONTENT_COLUMN_NAME); - assertThat(props.getEmbeddingColumnName()).isEqualTo(CassandraVectorStoreConfig.DEFAULT_EMBEDDING_COLUMN_NAME); + assertThat(props.getKeyspace()).isEqualTo(CassandraVectorStore.DEFAULT_KEYSPACE_NAME); + assertThat(props.getTable()).isEqualTo(CassandraVectorStore.DEFAULT_TABLE_NAME); + assertThat(props.getContentColumnName()).isEqualTo(CassandraVectorStore.DEFAULT_CONTENT_COLUMN_NAME); + assertThat(props.getEmbeddingColumnName()).isEqualTo(CassandraVectorStore.DEFAULT_EMBEDDING_COLUMN_NAME); assertThat(props.getIndexName()).isNull(); - assertThat(props.getFixedThreadPoolExecutorSize()) - .isEqualTo(CassandraVectorStoreConfig.DEFAULT_ADD_CONCURRENCY); + assertThat(props.getFixedThreadPoolExecutorSize()).isEqualTo(CassandraVectorStore.DEFAULT_ADD_CONCURRENCY); } @Test diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemory.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemory.java similarity index 96% rename from vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemory.java rename to vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemory.java index c6453de5b27..035aed0ca5e 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemory.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemory.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.chat.memory; +package org.springframework.ai.chat.memory.cassandra; import java.time.Instant; import java.util.ArrayList; @@ -32,7 +32,8 @@ import com.datastax.oss.driver.api.querybuilder.select.Select; import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; -import org.springframework.ai.chat.memory.CassandraChatMemoryConfig.SchemaColumn; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryConfig.SchemaColumn; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; @@ -42,7 +43,7 @@ CassandraChatMemory.create(CassandraChatMemoryConfig.builder().withTimeToLive(Duration.ofDays(1)).build()); * - * For example @see org.springframework.ai.chat.memory.CassandraChatMemory + * For example @see org.springframework.ai.chat.memory.cassandra.CassandraChatMemory * * @author Mick Semb Wever * @since 1.0.0 diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemoryConfig.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryConfig.java similarity index 99% rename from vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemoryConfig.java rename to vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryConfig.java index 2c25346f172..fa62afe5950 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemoryConfig.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryConfig.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.chat.memory; +package org.springframework.ai.chat.memory.cassandra; import java.net.InetSocketAddress; import java.time.Duration; diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java deleted file mode 100644 index c4799f9a488..00000000000 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java +++ /dev/null @@ -1,409 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.vectorstore; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; - -import com.datastax.oss.driver.api.core.cql.BoundStatement; -import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder; -import com.datastax.oss.driver.api.core.cql.PreparedStatement; -import com.datastax.oss.driver.api.core.cql.Row; -import com.datastax.oss.driver.api.core.cql.SimpleStatement; -import com.datastax.oss.driver.api.core.data.CqlVector; -import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata; -import com.datastax.oss.driver.api.querybuilder.QueryBuilder; -import com.datastax.oss.driver.api.querybuilder.delete.Delete; -import com.datastax.oss.driver.api.querybuilder.delete.DeleteSelection; -import com.datastax.oss.driver.api.querybuilder.insert.InsertInto; -import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert; -import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; -import io.micrometer.observation.ObservationRegistry; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.springframework.ai.document.Document; -import org.springframework.ai.document.DocumentMetadata; -import org.springframework.ai.embedding.BatchingStrategy; -import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; -import org.springframework.ai.embedding.TokenCountBatchingStrategy; -import org.springframework.ai.model.EmbeddingUtils; -import org.springframework.ai.observation.conventions.VectorStoreProvider; -import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; -import org.springframework.ai.vectorstore.CassandraVectorStoreConfig.SchemaColumn; -import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; -import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext.Builder; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; - -/** - * The CassandraVectorStore is for managing and querying vector data in an Apache - * Cassandra db. It offers functionalities like adding, deleting, and performing - * similarity searches on documents. - * - * The store utilizes CQL to index and search vector data. It allows for custom metadata - * fields in the documents to be stored alongside the vector and content data. - * - * This class requires a CassandraVectorStoreConfig configuration object for - * initialization, which includes settings like connection details, index name, column - * names, etc. It also requires an EmbeddingModel to convert documents into embeddings - * before storing them. - * - * A schema matching the configuration is automatically created if it doesn't exist. - * Missing columns and indexes in existing tables will also be automatically created. - * Disable this with the CassandraVectorStoreConfig#disallowSchemaChanges(). - * - * This class is designed to work with brand new tables that it creates for you, or on top - * of existing Cassandra tables. The latter is appropriate when wanting to keep data in - * place, creating embeddings next to it, and performing vector similarity searches - * in-situ. - * - * Instances of this class are not dynamic against server-side schema changes. If you - * change the schema server-side you need a new CassandraVectorStore instance. - * - * When adding documents with the method {@link #add(List)} it first calls - * embeddingModel to create the embeddings. This is slow. Configure - * {@link CassandraVectorStoreConfig.Builder#withFixedThreadPoolExecutorSize(int)} - * accordingly to improve performance so embeddings are created and the documents are - * added concurrently. The default concurrency is 16 - * ({@link CassandraVectorStoreConfig#DEFAULT_ADD_CONCURRENCY}). Remote transformers - * probably want higher concurrency, and local transformers may need lower concurrency. - * This concurrency limit does not need to be higher than the max parallel calls made to - * the {@link #add(List)} method multiplied by the list size. This setting can - * also serve as a protecting throttle against your embedding model. - * - * @author Mick Semb Wever - * @author Christian Tzolov - * @author Thomas Vitale - * @author Soby Chacko - * @see VectorStore - * @see org.springframework.ai.vectorstore.CassandraVectorStoreConfig - * @see EmbeddingModel - * @since 1.0.0 - */ -public class CassandraVectorStore extends AbstractObservationVectorStore implements AutoCloseable { - - public static final String DRIVER_PROFILE_UPDATES = "spring-ai-updates"; - - public static final String DRIVER_PROFILE_SEARCH = "spring-ai-search"; - - private static final String QUERY_FORMAT = "select %s,%s,%s%s from %s.%s ? order by %s ann of ? limit ?"; - - private static final Logger logger = LoggerFactory.getLogger(CassandraVectorStore.class); - - private static Map SIMILARITY_TYPE_MAPPING = Map.of(Similarity.COSINE, - VectorStoreSimilarityMetric.COSINE, Similarity.EUCLIDEAN, VectorStoreSimilarityMetric.EUCLIDEAN, - Similarity.DOT_PRODUCT, VectorStoreSimilarityMetric.DOT); - - private final CassandraVectorStoreConfig conf; - - private final EmbeddingModel embeddingModel; - - private final FilterExpressionConverter filterExpressionConverter; - - private final ConcurrentMap, PreparedStatement> addStmts = new ConcurrentHashMap<>(); - - private final PreparedStatement deleteStmt; - - private final String similarityStmt; - - private final Similarity similarity; - - private final BatchingStrategy batchingStrategy; - - public CassandraVectorStore(CassandraVectorStoreConfig conf, EmbeddingModel embeddingModel) { - this(conf, embeddingModel, ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy()); - } - - public CassandraVectorStore(CassandraVectorStoreConfig conf, EmbeddingModel embeddingModel, - ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, - BatchingStrategy batchingStrategy) { - - super(observationRegistry, customObservationConvention); - - Preconditions.checkArgument(null != conf, "Config must not be null"); - Preconditions.checkArgument(null != embeddingModel, "Embedding model must not be null"); - - this.conf = conf; - this.embeddingModel = embeddingModel; - conf.ensureSchemaExists(embeddingModel.dimensions()); - prepareAddStatement(Set.of()); - this.deleteStmt = prepareDeleteStatement(); - - TableMetadata cassandraMetadata = conf.session.getMetadata() - .getKeyspace(conf.schema.keyspace()) - .get() - .getTable(conf.schema.table()) - .get(); - - this.similarity = getIndexSimilarity(cassandraMetadata); - this.similarityStmt = similaritySearchStatement(); - - this.filterExpressionConverter = new CassandraFilterExpressionConverter( - cassandraMetadata.getColumns().values()); - this.batchingStrategy = batchingStrategy; - } - - private static Float[] toFloatArray(float[] embedding) { - Float[] embeddingFloat = new Float[embedding.length]; - int i = 0; - for (Float d : embedding) { - embeddingFloat[i++] = d.floatValue(); - } - return embeddingFloat; - } - - @Override - public void doAdd(List documents) { - var futures = new CompletableFuture[documents.size()]; - - List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), - this.batchingStrategy); - - int i = 0; - for (Document d : documents) { - futures[i++] = CompletableFuture.runAsync(() -> { - List primaryKeyValues = this.conf.documentIdTranslator.apply(d.getId()); - - BoundStatementBuilder builder = prepareAddStatement(d.getMetadata().keySet()).boundStatementBuilder(); - for (int k = 0; k < primaryKeyValues.size(); ++k) { - SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k); - builder = builder.set(keyColumn.name(), primaryKeyValues.get(k), keyColumn.javaType()); - } - - builder = builder.setString(this.conf.schema.content(), d.getContent()) - .setVector(this.conf.schema.embedding(), - CqlVector.newInstance(EmbeddingUtils.toList(embeddings.get(documents.indexOf(d)))), - Float.class); - - for (var metadataColumn : this.conf.schema.metadataColumns() - .stream() - .filter(mc -> d.getMetadata().containsKey(mc.name())) - .toList()) { - - builder = builder.set(metadataColumn.name(), d.getMetadata().get(metadataColumn.name()), - metadataColumn.javaType()); - } - BoundStatement s = builder.build().setExecutionProfileName(DRIVER_PROFILE_UPDATES); - this.conf.session.execute(s); - }, this.conf.executor); - } - CompletableFuture.allOf(futures).join(); - } - - @Override - public Optional doDelete(List idList) { - CompletableFuture[] futures = new CompletableFuture[idList.size()]; - int i = 0; - for (String id : idList) { - List primaryKeyValues = this.conf.documentIdTranslator.apply(id); - BoundStatement s = this.deleteStmt.bind(primaryKeyValues.toArray()); - futures[i++] = this.conf.session.executeAsync(s).toCompletableFuture(); - } - CompletableFuture.allOf(futures).join(); - return Optional.of(Boolean.TRUE); - } - - @Override - public List doSimilaritySearch(SearchRequest request) { - Preconditions.checkArgument(request.getTopK() <= 1000); - var embedding = toFloatArray(this.embeddingModel.embed(request.getQuery())); - CqlVector cqlVector = CqlVector.newInstance(embedding); - - String whereClause = ""; - if (request.hasFilterExpression()) { - String expression = this.filterExpressionConverter.convertExpression(request.getFilterExpression()); - if (!expression.isBlank()) { - whereClause = String.format("where %s", expression); - } - } - - String query = String.format(this.similarityStmt, cqlVector, whereClause, cqlVector, request.getTopK()); - List documents = new ArrayList<>(); - logger.trace("Executing {}", query); - SimpleStatement s = SimpleStatement.newInstance(query).setExecutionProfileName(DRIVER_PROFILE_SEARCH); - - for (Row row : this.conf.session.execute(s)) { - float score = row.getFloat(0); - if (score < request.getSimilarityThreshold()) { - break; - } - Map docFields = new HashMap<>(); - docFields.put(DocumentMetadata.DISTANCE.value(), 1 - score); - for (var metadata : this.conf.schema.metadataColumns()) { - var value = row.get(metadata.name(), metadata.javaType()); - if (null != value) { - docFields.put(metadata.name(), value); - } - } - Document doc = Document.builder() - .id(getDocumentId(row)) - .text(row.getString(this.conf.schema.content())) - .metadata(docFields) - .score((double) score) - .build(); - - documents.add(doc); - } - return documents; - } - - @Override - public void close() throws Exception { - this.conf.close(); - } - - void checkSchemaValid() { - this.conf.checkSchemaValid(this.embeddingModel.dimensions()); - } - - private Similarity getIndexSimilarity(TableMetadata metadata) { - - return Similarity.valueOf(metadata.getIndex(this.conf.schema.index()) - .get() - .getOptions() - .getOrDefault("similarity_function", "COSINE") - .toUpperCase()); - } - - private PreparedStatement prepareDeleteStatement() { - Delete stmt = null; - DeleteSelection stmtStart = QueryBuilder.deleteFrom(this.conf.schema.keyspace(), this.conf.schema.table()); - - for (var c : this.conf.schema.partitionKeys()) { - stmt = (null != stmt ? stmt : stmtStart).whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name())); - } - for (var c : this.conf.schema.clusteringKeys()) { - stmt = stmt.whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name())); - } - - return this.conf.session.prepare(stmt.build()); - } - - private PreparedStatement prepareAddStatement(Set metadataFields) { - - // metadata fields that are not configured as metadata columns are not added - Set fieldsThatAreColumns = new HashSet<>(this.conf.schema.metadataColumns() - .stream() - .map(mc -> mc.name()) - .filter(mc -> metadataFields.contains(mc)) - .toList()); - - return this.addStmts.computeIfAbsent(fieldsThatAreColumns, fields -> { - - RegularInsert stmt = null; - InsertInto stmtStart = QueryBuilder.insertInto(this.conf.schema.keyspace(), this.conf.schema.table()); - - for (var c : this.conf.schema.partitionKeys()) { - stmt = (null != stmt ? stmt : stmtStart).value(c.name(), QueryBuilder.bindMarker(c.name())); - } - for (var c : this.conf.schema.clusteringKeys()) { - stmt = stmt.value(c.name(), QueryBuilder.bindMarker(c.name())); - } - - stmt = stmt.value(this.conf.schema.content(), QueryBuilder.bindMarker(this.conf.schema.content())) - .value(this.conf.schema.embedding(), QueryBuilder.bindMarker(this.conf.schema.embedding())); - - for (String metadataField : fields) { - stmt = stmt.value(metadataField, QueryBuilder.bindMarker(metadataField)); - } - return this.conf.session.prepare(stmt.build()); - }); - } - - private String similaritySearchStatement() { - StringBuilder ids = new StringBuilder(); - for (var m : this.conf.schema.partitionKeys()) { - ids.append(m.name()).append(','); - } - for (var m : this.conf.schema.clusteringKeys()) { - ids.append(m.name()).append(','); - } - ids.deleteCharAt(ids.length() - 1); - - String similarityFunction = new StringBuilder("similarity_").append(this.similarity.toString().toLowerCase()) - .append('(') - .append(this.conf.schema.embedding()) - .append(",?)") - .toString(); - - StringBuilder extraSelectFields = new StringBuilder(); - for (var m : this.conf.schema.metadataColumns()) { - extraSelectFields.append(',').append(m.name()); - } - if (this.conf.returnEmbeddings) { - extraSelectFields.append(',').append(this.conf.schema.embedding()); - } - - // java-driver-query-builder doesn't support orderByAnnOf yet - String query = String.format(QUERY_FORMAT, similarityFunction, ids.toString(), this.conf.schema.content(), - extraSelectFields.toString(), this.conf.schema.keyspace(), this.conf.schema.table(), - this.conf.schema.embedding()); - - query = query.replace("?", "%s"); - logger.debug("preparing {}", query); - return query; - } - - private String getDocumentId(Row row) { - List primaryKeyValues = new ArrayList<>(); - for (var m : this.conf.schema.partitionKeys()) { - primaryKeyValues.add(row.get(m.name(), m.javaType())); - } - for (var m : this.conf.schema.clusteringKeys()) { - primaryKeyValues.add(row.get(m.name(), m.javaType())); - } - return this.conf.primaryKeyTranslator.apply(primaryKeyValues); - } - - @Override - public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { - return VectorStoreObservationContext.builder(VectorStoreProvider.CASSANDRA.value(), operationName) - .withCollectionName(this.conf.schema.table()) - .withDimensions(this.embeddingModel.dimensions()) - .withNamespace(this.conf.schema.keyspace()) - .withSimilarityMetric(getSimilarityMetric()); - } - - private String getSimilarityMetric() { - if (!SIMILARITY_TYPE_MAPPING.containsKey(this.similarity)) { - return this.similarity.name(); - } - return SIMILARITY_TYPE_MAPPING.get(this.similarity).value(); - } - - /** - * Indexes are automatically created with COSINE. This can be changed manually via - * cqlsh - */ - public enum Similarity { - - COSINE, DOT_PRODUCT, EUCLIDEAN - - } - -} diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/cassandra/CassandraFilterExpressionConverter.java similarity index 98% rename from vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java rename to vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/cassandra/CassandraFilterExpressionConverter.java index f1f0e8b5b61..4437a5fc865 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/cassandra/CassandraFilterExpressionConverter.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.vectorstore; +package org.springframework.ai.vectorstore.cassandra; import java.util.Collection; import java.util.Map; diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/cassandra/CassandraVectorStore.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/cassandra/CassandraVectorStore.java new file mode 100644 index 00000000000..a5a6b670bcf --- /dev/null +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/cassandra/CassandraVectorStore.java @@ -0,0 +1,1038 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore.cassandra; + +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.function.Function; +import java.util.stream.Stream; + +import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.api.core.CqlSessionBuilder; +import com.datastax.oss.driver.api.core.cql.BoundStatement; +import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder; +import com.datastax.oss.driver.api.core.cql.PreparedStatement; +import com.datastax.oss.driver.api.core.cql.Row; +import com.datastax.oss.driver.api.core.cql.SimpleStatement; +import com.datastax.oss.driver.api.core.data.CqlVector; +import com.datastax.oss.driver.api.core.metadata.schema.ColumnMetadata; +import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata; +import com.datastax.oss.driver.api.core.type.DataType; +import com.datastax.oss.driver.api.core.type.DataTypes; +import com.datastax.oss.driver.api.core.type.codec.registry.CodecRegistry; +import com.datastax.oss.driver.api.core.type.reflect.GenericType; +import com.datastax.oss.driver.api.querybuilder.BuildableQuery; +import com.datastax.oss.driver.api.querybuilder.QueryBuilder; +import com.datastax.oss.driver.api.querybuilder.SchemaBuilder; +import com.datastax.oss.driver.api.querybuilder.delete.Delete; +import com.datastax.oss.driver.api.querybuilder.delete.DeleteSelection; +import com.datastax.oss.driver.api.querybuilder.insert.InsertInto; +import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert; +import com.datastax.oss.driver.api.querybuilder.schema.AlterTableAddColumn; +import com.datastax.oss.driver.api.querybuilder.schema.AlterTableAddColumnEnd; +import com.datastax.oss.driver.api.querybuilder.schema.CreateTable; +import com.datastax.oss.driver.api.querybuilder.schema.CreateTableStart; +import com.datastax.oss.driver.shaded.guava.common.annotations.VisibleForTesting; +import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; +import io.micrometer.observation.ObservationRegistry; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.cassandra.SchemaUtil; +import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; +import org.springframework.ai.embedding.BatchingStrategy; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; +import org.springframework.ai.model.EmbeddingUtils; +import org.springframework.ai.observation.conventions.VectorStoreProvider; +import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; +import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; +import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; +import org.springframework.util.Assert; + +/** + * The CassandraVectorStore is for managing and querying vector data in an Apache + * Cassandra db. It offers functionalities like adding, deleting, and performing + * similarity searches on documents. + * + * The store utilizes CQL to index and search vector data. It allows for custom metadata + * fields in the documents to be stored alongside the vector and content data. + * + * This class requires a CassandraVectorStore#CassandraBuilder configuration object for + * initialization, which includes settings like connection details, index name, column + * names, etc. It also requires an EmbeddingModel to convert documents into embeddings + * before storing them. + * + * A schema matching the configuration is automatically created if it doesn't exist. + * Missing columns and indexes in existing tables will also be automatically created. + * Disable this with the CassandraBuilder#disallowSchemaChanges(). + * + * This class is designed to work with brand new tables that it creates for you, or on top + * of existing Cassandra tables. The latter is appropriate when wanting to keep data in + * place, creating embeddings next to it, and performing vector similarity searches + * in-situ. + * + * Instances of this class are not dynamic against server-side schema changes. If you + * change the schema server-side you need a new CassandraVectorStore instance. + * + * When adding documents with the method {@link #add(List)} it first calls + * embeddingModel to create the embeddings. This is slow. Configure + * {@link CassandraVectorStore.CassandraBuilder#fixedThreadPoolExecutorSize(int)} + * accordingly to improve performance so embeddings are created and the documents are + * added concurrently. The default concurrency is 16 + * ({@link CassandraVectorStore.CassandraBuilder#DEFAULT_ADD_CONCURRENCY}). Remote + * transformers probably want higher concurrency, and local transformers may need lower + * concurrency. This concurrency limit does not need to be higher than the max parallel + * calls made to the {@link #add(List)} method multiplied by the list size. This + * setting can also serve as a protecting throttle against your embedding model. + * + * @author Mick Semb Wever + * @author Christian Tzolov + * @author Thomas Vitale + * @author Soby Chacko + * @see VectorStore + * @see EmbeddingModel + * @since 1.0.0 + */ +public class CassandraVectorStore extends AbstractObservationVectorStore implements AutoCloseable { + + public static final String DEFAULT_KEYSPACE_NAME = "springframework"; + + public static final String DEFAULT_TABLE_NAME = "ai_vector_store"; + + public static final String DEFAULT_ID_NAME = "id"; + + public static final String DEFAULT_INDEX_SUFFIX = "idx"; + + public static final String DEFAULT_CONTENT_COLUMN_NAME = "content"; + + public static final String DEFAULT_EMBEDDING_COLUMN_NAME = "embedding"; + + public static final int DEFAULT_ADD_CONCURRENCY = 16; + + public static final String DRIVER_PROFILE_UPDATES = "spring-ai-updates"; + + public static final String DRIVER_PROFILE_SEARCH = "spring-ai-search"; + + private static final String QUERY_FORMAT = "select %s,%s,%s%s from %s.%s ? order by %s ann of ? limit ?"; + + private static final Logger logger = LoggerFactory.getLogger(CassandraVectorStore.class); + + private static final Map SIMILARITY_TYPE_MAPPING = Map.of( + Similarity.COSINE, VectorStoreSimilarityMetric.COSINE, Similarity.EUCLIDEAN, + VectorStoreSimilarityMetric.EUCLIDEAN, Similarity.DOT_PRODUCT, VectorStoreSimilarityMetric.DOT); + + private final CqlSession session; + + private final Schema schema; + + private final boolean disallowSchemaChanges; + + private final FilterExpressionConverter filterExpressionConverter; + + private final DocumentIdTranslator documentIdTranslator; + + private final PrimaryKeyTranslator primaryKeyTranslator; + + private final Executor executor; + + private final boolean closeSessionOnClose; + + private final BatchingStrategy batchingStrategy; + + private final ConcurrentMap, PreparedStatement> addStmts = new ConcurrentHashMap<>(); + + private final PreparedStatement deleteStmt; + + private final String similarityStmt; + + private final Similarity similarity; + + // TODO: Remove this flag as the document no longer holds embeddings. + @Deprecated(since = "1.0.0-M5", forRemoval = true) + private final boolean returnEmbeddings; + + /** + * @deprecated since 1.0.0-M5, use {@link #builder()} instead + */ + @Deprecated(since = "1.0.0-M5", forRemoval = true) + public CassandraVectorStore(CassandraVectorStoreConfig conf, EmbeddingModel embeddingModel) { + this(conf, embeddingModel, ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy()); + } + + /** + * @deprecated since 1.0.0-M5, use {@link #builder()} instead + */ + @Deprecated(since = "1.0.0-M5", forRemoval = true) + public CassandraVectorStore(CassandraVectorStoreConfig conf, EmbeddingModel embeddingModel, + ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, + BatchingStrategy batchingStrategy) { + this(builder().session(conf.session) + .embeddingModel(embeddingModel) + .observationRegistry(observationRegistry) + .customObservationConvention(customObservationConvention) + .batchingStrategy(batchingStrategy)); + } + + protected CassandraVectorStore(CassandraBuilder builder) { + super(builder); + + Assert.notNull(builder.session, "Session must not be null"); + + this.session = builder.session; + this.schema = builder.buildSchema(); + this.disallowSchemaChanges = builder.disallowSchemaChanges; + this.documentIdTranslator = builder.documentIdTranslator; + this.primaryKeyTranslator = builder.primaryKeyTranslator; + this.executor = Executors.newFixedThreadPool(builder.fixedThreadPoolExecutorSize); + this.closeSessionOnClose = builder.closeSessionOnClose; + this.batchingStrategy = builder.batchingStrategy; + + ensureSchemaExists(embeddingModel.dimensions()); + prepareAddStatement(Set.of()); + this.deleteStmt = prepareDeleteStatement(); + + TableMetadata cassandraMetadata = session.getMetadata() + .getKeyspace(schema.keyspace()) + .get() + .getTable(schema.table()) + .get(); + + this.similarity = getIndexSimilarity(cassandraMetadata); + this.similarityStmt = similaritySearchStatement(); + + this.filterExpressionConverter = builder.filterExpressionConverter != null ? builder.filterExpressionConverter + : new CassandraFilterExpressionConverter(cassandraMetadata.getColumns().values()); + + this.returnEmbeddings = builder.returnEmbeddings; + } + + public static CassandraBuilder builder() { + return new CassandraBuilder(); + } + + private static Float[] toFloatArray(float[] embedding) { + Float[] embeddingFloat = new Float[embedding.length]; + int i = 0; + for (Float d : embedding) { + embeddingFloat[i++] = d.floatValue(); + } + return embeddingFloat; + } + + @Override + public void doAdd(List documents) { + var futures = new CompletableFuture[documents.size()]; + + List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), + this.batchingStrategy); + + int i = 0; + for (Document d : documents) { + futures[i++] = CompletableFuture.runAsync(() -> { + List primaryKeyValues = this.documentIdTranslator.apply(d.getId()); + + BoundStatementBuilder builder = prepareAddStatement(d.getMetadata().keySet()).boundStatementBuilder(); + for (int k = 0; k < primaryKeyValues.size(); ++k) { + SchemaColumn keyColumn = this.getPrimaryKeyColumn(k); + builder = builder.set(keyColumn.name(), primaryKeyValues.get(k), keyColumn.javaType()); + } + + builder = builder.setString(this.schema.content(), d.getContent()) + .setVector(this.schema.embedding(), + CqlVector.newInstance(EmbeddingUtils.toList(embeddings.get(documents.indexOf(d)))), + Float.class); + + for (var metadataColumn : this.schema.metadataColumns() + .stream() + .filter(mc -> d.getMetadata().containsKey(mc.name())) + .toList()) { + + builder = builder.set(metadataColumn.name(), d.getMetadata().get(metadataColumn.name()), + metadataColumn.javaType()); + } + BoundStatement s = builder.build().setExecutionProfileName(DRIVER_PROFILE_UPDATES); + this.session.execute(s); + }, this.executor); + } + CompletableFuture.allOf(futures).join(); + } + + @Override + public Optional doDelete(List idList) { + CompletableFuture[] futures = new CompletableFuture[idList.size()]; + int i = 0; + for (String id : idList) { + List primaryKeyValues = this.documentIdTranslator.apply(id); + BoundStatement s = this.deleteStmt.bind(primaryKeyValues.toArray()); + futures[i++] = this.session.executeAsync(s).toCompletableFuture(); + } + CompletableFuture.allOf(futures).join(); + return Optional.of(Boolean.TRUE); + } + + @Override + public List doSimilaritySearch(SearchRequest request) { + Preconditions.checkArgument(request.getTopK() <= 1000); + var embedding = toFloatArray(this.embeddingModel.embed(request.getQuery())); + CqlVector cqlVector = CqlVector.newInstance(embedding); + + String whereClause = ""; + if (request.hasFilterExpression()) { + String expression = this.filterExpressionConverter.convertExpression(request.getFilterExpression()); + if (!expression.isBlank()) { + whereClause = String.format("where %s", expression); + } + } + + String query = String.format(this.similarityStmt, cqlVector, whereClause, cqlVector, request.getTopK()); + List documents = new ArrayList<>(); + logger.trace("Executing {}", query); + SimpleStatement s = SimpleStatement.newInstance(query).setExecutionProfileName(DRIVER_PROFILE_SEARCH); + + for (Row row : this.session.execute(s)) { + float score = row.getFloat(0); + if (score < request.getSimilarityThreshold()) { + break; + } + Map docFields = new HashMap<>(); + docFields.put(DocumentMetadata.DISTANCE.value(), 1 - score); + for (var metadata : this.schema.metadataColumns()) { + var value = row.get(metadata.name(), metadata.javaType()); + if (null != value) { + docFields.put(metadata.name(), value); + } + } + Document doc = Document.builder() + .id(getDocumentId(row)) + .text(row.getString(this.schema.content())) + .metadata(docFields) + .score((double) score) + .build(); + + documents.add(doc); + } + return documents; + } + + void checkSchemaValid() { + this.checkSchemaValid(this.embeddingModel.dimensions()); + } + + private Similarity getIndexSimilarity(TableMetadata metadata) { + + return Similarity.valueOf(metadata.getIndex(this.schema.index()) + .get() + .getOptions() + .getOrDefault("similarity_function", "COSINE") + .toUpperCase()); + } + + private PreparedStatement prepareDeleteStatement() { + Delete stmt = null; + DeleteSelection stmtStart = QueryBuilder.deleteFrom(this.schema.keyspace(), this.schema.table()); + + for (var c : this.schema.partitionKeys()) { + stmt = (null != stmt ? stmt : stmtStart).whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name())); + } + for (var c : this.schema.clusteringKeys()) { + stmt = stmt.whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name())); + } + + return this.session.prepare(stmt.build()); + } + + private PreparedStatement prepareAddStatement(Set metadataFields) { + + // metadata fields that are not configured as metadata columns are not added + Set fieldsThatAreColumns = new HashSet<>(this.schema.metadataColumns() + .stream() + .map(mc -> mc.name()) + .filter(mc -> metadataFields.contains(mc)) + .toList()); + + return this.addStmts.computeIfAbsent(fieldsThatAreColumns, fields -> { + + RegularInsert stmt = null; + InsertInto stmtStart = QueryBuilder.insertInto(this.schema.keyspace(), this.schema.table()); + + for (var c : this.schema.partitionKeys()) { + stmt = (null != stmt ? stmt : stmtStart).value(c.name(), QueryBuilder.bindMarker(c.name())); + } + for (var c : this.schema.clusteringKeys()) { + stmt = stmt.value(c.name(), QueryBuilder.bindMarker(c.name())); + } + + stmt = stmt.value(this.schema.content(), QueryBuilder.bindMarker(this.schema.content())) + .value(this.schema.embedding(), QueryBuilder.bindMarker(this.schema.embedding())); + + for (String metadataField : fields) { + stmt = stmt.value(metadataField, QueryBuilder.bindMarker(metadataField)); + } + return this.session.prepare(stmt.build()); + }); + } + + private String similaritySearchStatement() { + StringBuilder ids = new StringBuilder(); + for (var m : this.schema.partitionKeys()) { + ids.append(m.name()).append(','); + } + for (var m : this.schema.clusteringKeys()) { + ids.append(m.name()).append(','); + } + ids.deleteCharAt(ids.length() - 1); + + String similarityFunction = new StringBuilder("similarity_").append(this.similarity.toString().toLowerCase()) + .append('(') + .append(this.schema.embedding()) + .append(",?)") + .toString(); + + StringBuilder extraSelectFields = new StringBuilder(); + for (var m : this.schema.metadataColumns()) { + extraSelectFields.append(',').append(m.name()); + } + if (this.returnEmbeddings) { + extraSelectFields.append(',').append(this.schema.embedding()); + } + + // java-driver-query-builder doesn't support orderByAnnOf yet + String query = String.format(QUERY_FORMAT, similarityFunction, ids.toString(), this.schema.content(), + extraSelectFields.toString(), this.schema.keyspace(), this.schema.table(), this.schema.embedding()); + + query = query.replace("?", "%s"); + logger.debug("preparing {}", query); + return query; + } + + private String getDocumentId(Row row) { + List primaryKeyValues = new ArrayList<>(); + for (var m : this.schema.partitionKeys()) { + primaryKeyValues.add(row.get(m.name(), m.javaType())); + } + for (var m : this.schema.clusteringKeys()) { + primaryKeyValues.add(row.get(m.name(), m.javaType())); + } + return this.primaryKeyTranslator.apply(primaryKeyValues); + } + + @Override + public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { + return VectorStoreObservationContext.builder(VectorStoreProvider.CASSANDRA.value(), operationName) + .withCollectionName(this.schema.table()) + .withDimensions(this.embeddingModel.dimensions()) + .withNamespace(this.schema.keyspace()) + .withSimilarityMetric(getSimilarityMetric()); + } + + private String getSimilarityMetric() { + if (!SIMILARITY_TYPE_MAPPING.containsKey(this.similarity)) { + return this.similarity.name(); + } + return SIMILARITY_TYPE_MAPPING.get(this.similarity).value(); + } + + @Override + public void close() throws Exception { + if (this.closeSessionOnClose) { + this.session.close(); + } + } + + SchemaColumn getPrimaryKeyColumn(int index) { + return index < this.schema.partitionKeys().size() ? this.schema.partitionKeys().get(index) + : this.schema.clusteringKeys().get(index - this.schema.partitionKeys().size()); + } + + @VisibleForTesting + static void dropKeyspace(CassandraBuilder builder) { + Preconditions.checkState(builder.keyspace.startsWith("test_"), "Only test keyspaces can be dropped"); + builder.session.execute(SchemaBuilder.dropKeyspace(builder.keyspace).ifExists().build()); + } + + void ensureSchemaExists(int vectorDimension) { + if (!this.disallowSchemaChanges) { + SchemaUtil.ensureKeyspaceExists(this.session, this.schema.keyspace); + ensureTableExists(vectorDimension); + ensureTableColumnsExist(vectorDimension); + ensureIndexesExists(); + SchemaUtil.checkSchemaAgreement(this.session); + } + else { + checkSchemaValid(vectorDimension); + } + } + + void checkSchemaValid(int vectorDimension) { + + Preconditions.checkState(this.session.getMetadata().getKeyspace(this.schema.keyspace).isPresent(), + "keyspace %s does not exist", this.schema.keyspace); + + Preconditions.checkState(this.session.getMetadata() + .getKeyspace(this.schema.keyspace) + .get() + .getTable(this.schema.table) + .isPresent(), "table %s does not exist"); + + TableMetadata tableMetadata = this.session.getMetadata() + .getKeyspace(this.schema.keyspace) + .get() + .getTable(this.schema.table) + .get(); + + Preconditions.checkState(tableMetadata.getColumn(this.schema.content).isPresent(), "column %s does not exist", + this.schema.content); + + Preconditions.checkState(tableMetadata.getColumn(this.schema.embedding).isPresent(), "column %s does not exist", + this.schema.embedding); + + for (SchemaColumn m : this.schema.metadataColumns) { + Optional column = tableMetadata.getColumn(m.name()); + Preconditions.checkState(column.isPresent(), "column %s does not exist", m.name()); + + Preconditions.checkArgument(column.get().getType().equals(m.type()), + "Mismatching type on metadata column %s of %s vs %s", m.name(), column.get().getType(), m.type()); + + if (m.indexed()) { + Preconditions.checkState( + tableMetadata.getIndexes().values().stream().anyMatch(i -> i.getTarget().equals(m.name())), + "index %s does not exist", m.name()); + } + } + + } + + private void ensureIndexesExists() { + + SimpleStatement indexStmt = SchemaBuilder.createIndex(this.schema.index) + .ifNotExists() + .custom("StorageAttachedIndex") + .onTable(this.schema.keyspace, this.schema.table) + .andColumn(this.schema.embedding) + .build(); + + logger.debug("Executing {}", indexStmt.getQuery()); + this.session.execute(indexStmt); + + Stream + .concat(this.schema.partitionKeys.stream(), + Stream.concat(this.schema.clusteringKeys.stream(), this.schema.metadataColumns.stream())) + .filter(cs -> cs.indexed()) + .forEach(metadata -> { + + SimpleStatement indexStatement = SchemaBuilder.createIndex(String.format("%s_idx", metadata.name())) + .ifNotExists() + .custom("StorageAttachedIndex") + .onTable(this.schema.keyspace, this.schema.table) + .andColumn(metadata.name()) + .build(); + + logger.debug("Executing {}", indexStatement.getQuery()); + this.session.execute(indexStatement); + }); + } + + private void ensureTableExists(int vectorDimension) { + if (this.session.getMetadata().getKeyspace(this.schema.keyspace).get().getTable(this.schema.table).isEmpty()) { + + CreateTable createTable = null; + + CreateTableStart createTableStart = SchemaBuilder.createTable(this.schema.keyspace, this.schema.table) + .ifNotExists(); + + for (SchemaColumn partitionKey : this.schema.partitionKeys) { + createTable = (null != createTable ? createTable : createTableStart).withPartitionKey(partitionKey.name, + partitionKey.type); + } + for (SchemaColumn clusteringKey : this.schema.clusteringKeys) { + createTable = createTable.withClusteringColumn(clusteringKey.name, clusteringKey.type); + } + + createTable = createTable.withColumn(this.schema.content, DataTypes.TEXT); + + for (SchemaColumn metadata : this.schema.metadataColumns) { + createTable = createTable.withColumn(metadata.name(), metadata.type()); + } + + // https://datastax-oss.atlassian.net/browse/JAVA-3118 + // .withColumn(config.embedding, new DefaultVectorType(DataTypes.FLOAT, + // vectorDimension)); + + StringBuilder tableStmt = new StringBuilder(createTable.asCql()); + tableStmt.setLength(tableStmt.length() - 1); + tableStmt.append(',') + .append(this.schema.embedding) + .append(" vector)"); + logger.debug("Executing {}", tableStmt.toString()); + this.session.execute(tableStmt.toString()); + } + } + + private void ensureTableColumnsExist(int vectorDimension) { + + TableMetadata tableMetadata = this.session.getMetadata() + .getKeyspace(this.schema.keyspace) + .get() + .getTable(this.schema.table) + .get(); + + Set newColumns = new HashSet<>(); + boolean addContent = tableMetadata.getColumn(this.schema.content).isEmpty(); + boolean addEmbedding = tableMetadata.getColumn(this.schema.embedding).isEmpty(); + + for (SchemaColumn metadata : this.schema.metadataColumns) { + Optional column = tableMetadata.getColumn(metadata.name()); + if (column.isPresent()) { + + Preconditions.checkArgument(column.get().getType().equals(metadata.type()), + "Cannot change type on metadata column %s from %s to %s", metadata.name(), + column.get().getType(), metadata.type()); + } + else { + newColumns.add(metadata); + } + } + + if (!newColumns.isEmpty() || addContent || addEmbedding) { + AlterTableAddColumn alterTable = SchemaBuilder.alterTable(this.schema.keyspace, this.schema.table); + for (SchemaColumn metadata : newColumns) { + alterTable = alterTable.addColumn(metadata.name(), metadata.type()); + } + if (addContent) { + alterTable = alterTable.addColumn(this.schema.content, DataTypes.TEXT); + } + if (addEmbedding) { + // special case for embedding column, bc JAVA-3118, as above + StringBuilder alterTableStmt = new StringBuilder(((BuildableQuery) alterTable).asCql()); + if (newColumns.isEmpty() && !addContent) { + alterTableStmt.append(" ADD ("); + } + else { + alterTableStmt.setLength(alterTableStmt.length() - 1); + alterTableStmt.append(','); + } + alterTableStmt.append(this.schema.embedding) + .append(" vector)"); + + logger.debug("Executing {}", alterTableStmt.toString()); + this.session.execute(alterTableStmt.toString()); + } + else { + SimpleStatement stmt = ((AlterTableAddColumnEnd) alterTable).build(); + logger.debug("Executing {}", stmt.getQuery()); + this.session.execute(stmt); + } + } + } + + /** + * Indexes are automatically created with COSINE. This can be changed manually via + * cqlsh + */ + public enum Similarity { + + COSINE, DOT_PRODUCT, EUCLIDEAN + + } + + public enum SchemaColumnTags { + + INDEXED + + } + + /** + * Given a string document id, return the value for each primary key column. + * + * It is a requirement that an empty {@code List} returns an example formatted + * id + */ + public interface DocumentIdTranslator extends Function> { + + } + + /** Given a list of primary key column values, return the document id. */ + public interface PrimaryKeyTranslator extends Function, String> { + + } + + record Schema(String keyspace, String table, List partitionKeys, List clusteringKeys, + String content, String embedding, String index, Set metadataColumns) { + + } + + public record SchemaColumn(String name, DataType type, SchemaColumnTags... tags) { + + public SchemaColumn(String name, DataType type) { + this(name, type, new SchemaColumnTags[0]); + } + + public GenericType javaType() { + return CodecRegistry.DEFAULT.codecFor(this.type).getJavaType(); + } + + public boolean indexed() { + for (SchemaColumnTags t : this.tags) { + if (SchemaColumnTags.INDEXED == t) { + return true; + } + } + return false; + } + + } + + /** + * Builder for the Cassandra vector store. + * + * All metadata columns configured to the store will be fetched and added to all + * queried documents. + * + * To filter expression search against a metadata column configure it with + * SchemaColumnTags.INDEXED + * + * The Cassandra Java Driver is configured via the application.conf resource found in + * the classpath. See + * https://github.com/apache/cassandra-java-driver/tree/4.x/manual/core/configuration + * + */ + public static class CassandraBuilder extends AbstractVectorStoreBuilder { + + private CqlSession session; + + private CqlSessionBuilder sessionBuilder; + + private boolean closeSessionOnClose; + + private String keyspace = DEFAULT_KEYSPACE_NAME; + + private String table = DEFAULT_TABLE_NAME; + + private List partitionKeys = List.of(new SchemaColumn(DEFAULT_ID_NAME, DataTypes.TEXT)); + + private List clusteringKeys = List.of(); + + private String indexName; + + private String contentColumnName = DEFAULT_CONTENT_COLUMN_NAME; + + private String embeddingColumnName = DEFAULT_EMBEDDING_COLUMN_NAME; + + private Set metadataColumns = new HashSet<>(); + + private boolean disallowSchemaChanges = false; + + private int fixedThreadPoolExecutorSize = DEFAULT_ADD_CONCURRENCY; + + private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy(); + + private FilterExpressionConverter filterExpressionConverter; + + private DocumentIdTranslator documentIdTranslator = (String id) -> List.of(id); + + private PrimaryKeyTranslator primaryKeyTranslator = (List primaryKeyColumns) -> { + if (primaryKeyColumns.isEmpty()) { + return "test"; + } + Preconditions.checkArgument(1 == primaryKeyColumns.size()); + return (String) primaryKeyColumns.get(0); + }; + + private boolean returnEmbeddings = false; + + /** + * Executor to use when adding documents. The hotspot is the call to the + * embeddingModel. For remote transformers you probably want a higher value to + * utilize network. For local transformers you probably want a lower value to + * avoid saturation. + **/ + public CassandraBuilder fixedThreadPoolExecutorSize(int threads) { + Preconditions.checkArgument(0 < threads); + this.fixedThreadPoolExecutorSize = threads; + return this; + } + + /** + * Sets the CQL session. + * @param session the CQL session to use + * @return the builder instance + * @throws IllegalArgumentException if session is null + */ + public CassandraBuilder session(CqlSession session) { + Assert.notNull(session, "Session must not be null"); + this.session = session; + return this; + } + + /** + * Sets the keyspace name. + * @param keyspace the keyspace name + * @return the builder instance + * @throws IllegalArgumentException if keyspace is null or empty + */ + public CassandraBuilder keyspace(String keyspace) { + Assert.hasText(keyspace, "Keyspace must not be null or empty"); + this.keyspace = keyspace; + return this; + } + + /** + * Adds a contact point to the session builder. + * @param contactPoint the contact point to add + * @return the builder instance + * @throws IllegalStateException if session is already set + */ + public CassandraBuilder contactPoint(InetSocketAddress contactPoint) { + Assert.state(session == null, "Cannot call addContactPoint(..) when session is already set"); + if (sessionBuilder == null) { + sessionBuilder = new CqlSessionBuilder(); + } + sessionBuilder.addContactPoint(contactPoint); + return this; + } + + /** + * Sets the local datacenter for the session builder. + * @param localDatacenter the local datacenter name + * @return the builder instance + * @throws IllegalStateException if session is already set + */ + public CassandraBuilder localDatacenter(String localDatacenter) { + Assert.state(session == null, "Cannot call withLocalDatacenter(..) when session is already set"); + if (sessionBuilder == null) { + sessionBuilder = new CqlSessionBuilder(); + } + sessionBuilder.withLocalDatacenter(localDatacenter); + return this; + } + + /** + * Sets the table name. + * @param table the table name + * @return the builder instance + * @throws IllegalArgumentException if table is null or empty + */ + public CassandraBuilder table(String table) { + Assert.hasText(table, "Table must not be null or empty"); + this.table = table; + return this; + } + + /** + * Sets the partition keys. + * @param partitionKeys the partition keys + * @return the builder instance + * @throws IllegalArgumentException if partitionKeys is null or empty + */ + public CassandraBuilder partitionKeys(List partitionKeys) { + Assert.notEmpty(partitionKeys, "Partition keys must not be null or empty"); + this.partitionKeys = partitionKeys; + return this; + } + + /** + * Sets the clustering keys. + * @param clusteringKeys the clustering keys + * @return the builder instance + */ + public CassandraBuilder clusteringKeys(List clusteringKeys) { + this.clusteringKeys = clusteringKeys != null ? clusteringKeys : List.of(); + return this; + } + + /** + * Sets the index name. + * @param indexName the index name + * @return the builder instance + */ + public CassandraBuilder indexName(String indexName) { + this.indexName = indexName; + return this; + } + + /** + * Sets whether to disallow schema changes. + * @param disallowSchemaChanges true to disallow schema changes + * @return the builder instance + */ + public CassandraBuilder disallowSchemaChanges(boolean disallowSchemaChanges) { + this.disallowSchemaChanges = disallowSchemaChanges; + return this; + } + + /** + * Sets the batching strategy. + * @param batchingStrategy the batching strategy to use + * @return the builder instance + * @throws IllegalArgumentException if batchingStrategy is null + */ + public CassandraBuilder batchingStrategy(BatchingStrategy batchingStrategy) { + Assert.notNull(batchingStrategy, "BatchingStrategy must not be null"); + this.batchingStrategy = batchingStrategy; + return this; + } + + /** + * Sets the filter expression converter. + * @param converter the filter expression converter to use + * @return the builder instance + * @throws IllegalArgumentException if converter is null + */ + public CassandraBuilder filterExpressionConverter(FilterExpressionConverter converter) { + Assert.notNull(converter, "FilterExpressionConverter must not be null"); + this.filterExpressionConverter = converter; + return this; + } + + /** + * Sets the document ID translator. + * @param translator the document ID translator to use + * @return the builder instance + * @throws IllegalArgumentException if translator is null + */ + public CassandraBuilder documentIdTranslator(DocumentIdTranslator translator) { + Assert.notNull(translator, "DocumentIdTranslator must not be null"); + this.documentIdTranslator = translator; + return this; + } + + public CassandraBuilder contentColumnName(String contentColumnName) { + this.contentColumnName = contentColumnName; + return this; + } + + public CassandraBuilder embeddingColumnName(String embeddingColumnName) { + this.embeddingColumnName = embeddingColumnName; + return this; + } + + public CassandraBuilder addMetadataColumns(SchemaColumn... columns) { + CassandraBuilder builder = this; + for (SchemaColumn f : columns) { + builder = builder.addMetadataColumn(f); + } + return builder; + } + + public CassandraBuilder addMetadataColumns(List columns) { + CassandraBuilder builder = this; + this.metadataColumns.addAll(columns); + return builder; + } + + public CassandraBuilder addMetadataColumn(SchemaColumn column) { + + Preconditions.checkArgument(this.metadataColumns.stream().noneMatch(sc -> sc.name().equals(column.name())), + "A metadata column with name %s has already been added", column.name()); + + this.metadataColumns.add(column); + return this; + } + + /** + * Sets the primary key translator. + * @param translator the primary key translator to use + * @return the builder instance + * @throws IllegalArgumentException if translator is null + */ + public CassandraBuilder primaryKeyTranslator(PrimaryKeyTranslator translator) { + Assert.notNull(translator, "PrimaryKeyTranslator must not be null"); + this.primaryKeyTranslator = translator; + return this; + } + + public CassandraBuilder returnEmbeddings(boolean returnEmbeddings) { + this.returnEmbeddings = true; + return this; + } + + Schema buildSchema() { + if (this.indexName == null) { + this.indexName = String.format("%s_%s_%s", table, embeddingColumnName, DEFAULT_INDEX_SUFFIX); + } + + validateSchema(); + + return new Schema(keyspace, table, partitionKeys, clusteringKeys, contentColumnName, embeddingColumnName, + indexName, metadataColumns); + } + + private void validateSchema() { + for (SchemaColumn metadata : metadataColumns) { + Assert.isTrue(!partitionKeys.stream().anyMatch(c -> c.name().equals(metadata.name())), + "metadataColumn " + metadata.name() + " cannot have same name as a partition key"); + + Assert.isTrue(!clusteringKeys.stream().anyMatch(c -> c.name().equals(metadata.name())), + "metadataColumn " + metadata.name() + " cannot have same name as a clustering key"); + + Assert.isTrue(!metadata.name().equals(contentColumnName), + "metadataColumn " + metadata.name() + " cannot have same name as content column name"); + + Assert.isTrue(!metadata.name().equals(embeddingColumnName), + "metadataColumn " + metadata.name() + " cannot have same name as embedding column name"); + } + + int primaryKeyColumnsCount = partitionKeys.size() + clusteringKeys.size(); + String exampleId = primaryKeyTranslator.apply(Collections.emptyList()); + List testIdTranslation = documentIdTranslator.apply(exampleId); + + Assert.isTrue(testIdTranslation.size() == primaryKeyColumnsCount, + "documentIdTranslator results length " + testIdTranslation.size() + + " doesn't match number of primary key columns " + primaryKeyColumnsCount); + + Assert.isTrue(exampleId.equals(primaryKeyTranslator.apply(documentIdTranslator.apply(exampleId))), + "primaryKeyTranslator is not an inverse function to documentIdTranslator"); + } + + @Override + public CassandraVectorStore build() { + validate(); + if (session == null && sessionBuilder != null) { + session = sessionBuilder.build(); + closeSessionOnClose = true; + } + Assert.notNull(session, "Either session must be set directly or configured via sessionBuilder"); + return new CassandraVectorStore(this); + } + + } + +} diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/cassandra/CassandraVectorStoreConfig.java similarity index 92% rename from vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java rename to vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/cassandra/CassandraVectorStoreConfig.java index 518eb622ac4..725eacdefa7 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/cassandra/CassandraVectorStoreConfig.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.vectorstore; +package org.springframework.ai.vectorstore.cassandra; import java.net.InetSocketAddress; import java.util.Collections; @@ -65,7 +65,9 @@ * * @author Mick Semb Wever * @since 1.0.0 + * @deprecated since 1.0.0-M5, use {@link CassandraVectorStore#builder()} instead */ +@Deprecated(since = "1.0.0-M5", forRemoval = true) public final class CassandraVectorStoreConfig implements AutoCloseable { public static final String DEFAULT_KEYSPACE_NAME = "springframework"; @@ -317,6 +319,7 @@ private void ensureTableColumnsExist(int vectorDimension) { } } + @Deprecated(since = "1.0.0-M5", forRemoval = true) public enum SchemaColumnTags { INDEXED @@ -329,20 +332,24 @@ public enum SchemaColumnTags { * It is a requirement that an empty {@code List} returns an example formatted * id */ + @Deprecated(since = "1.0.0-M5", forRemoval = true) public interface DocumentIdTranslator extends Function> { } + @Deprecated(since = "1.0.0-M5", forRemoval = true) /** Given a list of primary key column values, return the document id. */ public interface PrimaryKeyTranslator extends Function, String> { } + @Deprecated(since = "1.0.0-M5", forRemoval = true) record Schema(String keyspace, String table, List partitionKeys, List clusteringKeys, String content, String embedding, String index, Set metadataColumns) { } + @Deprecated(since = "1.0.0-M5", forRemoval = true) public record SchemaColumn(String name, DataType type, SchemaColumnTags... tags) { public SchemaColumn(String name, DataType type) { @@ -364,6 +371,7 @@ public boolean indexed() { } + @Deprecated(since = "1.0.0-M5", forRemoval = true) public static final class Builder { private CqlSession session = null; @@ -405,6 +413,7 @@ public static final class Builder { private Builder() { } + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder withCqlSession(CqlSession session) { Preconditions.checkState(null == this.sessionBuilder, "Cannot call withContactPoint(..) or withLocalDatacenter(..) and this method"); @@ -413,6 +422,7 @@ public Builder withCqlSession(CqlSession session) { return this; } + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder addContactPoint(InetSocketAddress contactPoint) { Preconditions.checkState(null == this.session, "Cannot call withCqlSession(..) and this method"); if (null == this.sessionBuilder) { @@ -422,6 +432,7 @@ public Builder addContactPoint(InetSocketAddress contactPoint) { return this; } + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder withLocalDatacenter(String localDC) { Preconditions.checkState(null == this.session, "Cannot call withCqlSession(..) and this method"); if (null == this.sessionBuilder) { @@ -431,22 +442,26 @@ public Builder withLocalDatacenter(String localDC) { return this; } + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder withKeyspaceName(String keyspace) { this.keyspace = keyspace; return this; } + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder withTableName(String table) { this.table = table; return this; } + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder withPartitionKeys(List partitionKeys) { Preconditions.checkArgument(!partitionKeys.isEmpty()); this.partitionKeys = partitionKeys; return this; } + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder withClusteringKeys(List clusteringKeys) { this.clusteringKeys = clusteringKeys; return this; @@ -456,21 +471,25 @@ public Builder withClusteringKeys(List clusteringKeys) { * defaults (if null) to '<table_name>_<embedding_column_name>_idx' **/ @Nullable + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder withIndexName(String name) { this.indexName = name; return this; } + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder withContentColumnName(String name) { this.contentColumnName = name; return this; } + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder withEmbeddingColumnName(String name) { this.embeddingColumnName = name; return this; } + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder addMetadataColumns(SchemaColumn... columns) { Builder builder = this; for (SchemaColumn f : columns) { @@ -479,12 +498,14 @@ public Builder addMetadataColumns(SchemaColumn... columns) { return builder; } + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder addMetadataColumns(List columns) { Builder builder = this; this.metadataColumns.addAll(columns); return builder; } + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder addMetadataColumn(SchemaColumn column) { Preconditions.checkArgument(this.metadataColumns.stream().noneMatch(sc -> sc.name().equals(column.name())), @@ -494,11 +515,13 @@ public Builder addMetadataColumn(SchemaColumn column) { return this; } + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder disallowSchemaChanges() { this.disallowSchemaChanges = true; return this; } + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder returnEmbeddings() { this.returnEmbeddings = true; return this; @@ -510,22 +533,26 @@ public Builder returnEmbeddings() { * utilize network. For local transformers you probably want a lower value to * avoid saturation. **/ + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder withFixedThreadPoolExecutorSize(int threads) { Preconditions.checkArgument(0 < threads); this.fixedThreadPoolExecutorSize = threads; return this; } + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder withDocumentIdTranslator(DocumentIdTranslator documentIdTranslator) { this.documentIdTranslator = documentIdTranslator; return this; } + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Builder withPrimaryKeyTranslator(PrimaryKeyTranslator primaryKeyTranslator) { this.primaryKeyTranslator = primaryKeyTranslator; return this; } + @Deprecated(since = "1.0.0-M5", forRemoval = true) public CassandraVectorStoreConfig build() { if (null == this.indexName) { this.indexName = String.format("%s_%s_%s", this.table, this.embeddingColumnName, DEFAULT_INDEX_SUFFIX); diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/cassandra/CassandraImage.java similarity index 95% rename from vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java rename to vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/cassandra/CassandraImage.java index 9bfd6eb2060..ae9d8ec0ea9 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/cassandra/CassandraImage.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai; +package org.springframework.ai.cassandra; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/chat/memory/CassandraChatMemoryIT.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java similarity index 96% rename from vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/chat/memory/CassandraChatMemoryIT.java rename to vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java index 802b046b536..50a2180c1b3 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/chat/memory/CassandraChatMemoryIT.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.chat.memory; +package org.springframework.ai.chat.memory.cassandra; import java.time.Duration; @@ -26,7 +26,7 @@ import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; -import org.springframework.ai.CassandraImage; +import org.springframework.ai.cassandra.CassandraImage; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverterTests.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/CassandraFilterExpressionConverterTests.java similarity index 99% rename from vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverterTests.java rename to vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/CassandraFilterExpressionConverterTests.java index 89db32064ea..18a5927d592 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/CassandraFilterExpressionConverterTests.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.vectorstore; +package org.springframework.ai.vectorstore.cassandra; import java.util.Collection; import java.util.HashSet; diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/CassandraRichSchemaVectorStoreIT.java similarity index 83% rename from vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java rename to vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/CassandraRichSchemaVectorStoreIT.java index 4a31c8dbe13..a48cb29e975 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/CassandraRichSchemaVectorStoreIT.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.vectorstore; +package org.springframework.ai.vectorstore.cassandra; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -43,11 +43,12 @@ import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.shaded.org.apache.commons.lang3.RandomStringUtils; -import org.springframework.ai.CassandraImage; +import org.springframework.ai.cassandra.CassandraImage; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; -import org.springframework.ai.vectorstore.CassandraVectorStoreConfig.SchemaColumn; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.cassandra.CassandraVectorStore.SchemaColumn; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; @@ -92,66 +93,65 @@ class CassandraRichSchemaVectorStoreIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withUserConfiguration(TestApplication.class); - static CassandraVectorStoreConfig.Builder storeBuilder(ApplicationContext context, - List columnOverrides) throws IOException { + static CassandraVectorStore.CassandraBuilder storeBuilder(ApplicationContext context, + List columnOverrides) throws IOException { - Optional wikiOverride = columnOverrides.stream().filter(f -> "wiki".equals(f.name())).findFirst(); + Optional wikiOverride = columnOverrides.stream() + .filter(f -> "wiki".equals(f.name())) + .findFirst(); - Optional langOverride = columnOverrides.stream() + Optional langOverride = columnOverrides.stream() .filter(f -> "language".equals(f.name())) .findFirst(); - Optional titleOverride = columnOverrides.stream() + Optional titleOverride = columnOverrides.stream() .filter(f -> "title".equals(f.name())) .findFirst(); - Optional chunkNoOverride = columnOverrides.stream() + Optional chunkNoOverride = columnOverrides.stream() .filter(f -> "chunk_no".equals(f.name())) .findFirst(); - SchemaColumn wikiSC = wikiOverride.orElse(new SchemaColumn("wiki", DataTypes.TEXT)); - SchemaColumn langSC = langOverride.orElse(new SchemaColumn("language", DataTypes.TEXT)); - SchemaColumn titleSC = titleOverride.orElse(new SchemaColumn("title", DataTypes.TEXT)); - SchemaColumn chunkNoSC = chunkNoOverride.orElse(new SchemaColumn("chunk_no", DataTypes.INT)); - - List partitionKeys = List.of(wikiSC, langSC, titleSC); - List clusteringKeys = List.of(chunkNoSC); - - CassandraVectorStoreConfig.Builder builder = CassandraVectorStoreConfig.builder() - .withCqlSession(context.getBean(CqlSession.class)) - .withKeyspaceName("test_wikidata") - .withTableName("articles") - .withPartitionKeys(partitionKeys) - .withClusteringKeys(clusteringKeys) - .withContentColumnName("body") - .withEmbeddingColumnName("all_minilm_l6_v2_embedding") - .withIndexName("all_minilm_l6_v2_ann") - - .addMetadataColumns(new SchemaColumn("revision", DataTypes.INT), - new SchemaColumn("id", DataTypes.INT, CassandraVectorStoreConfig.SchemaColumnTags.INDEXED)) - + var wikiSC = wikiOverride.orElse(new CassandraVectorStore.SchemaColumn("wiki", DataTypes.TEXT)); + var langSC = langOverride.orElse(new CassandraVectorStore.SchemaColumn("language", DataTypes.TEXT)); + var titleSC = titleOverride.orElse(new CassandraVectorStore.SchemaColumn("title", DataTypes.TEXT)); + var chunkNoSC = chunkNoOverride.orElse(new CassandraVectorStore.SchemaColumn("chunk_no", DataTypes.INT)); + + List partitionKeys = List.of(wikiSC, langSC, titleSC); + List clusteringKeys = List.of(chunkNoSC); + + return CassandraVectorStore.builder() + .session(context.getBean(CqlSession.class)) + .keyspace("test_wikidata") + .table("articles") + .partitionKeys(partitionKeys) + .clusteringKeys(clusteringKeys) + .contentColumnName("body") + .embeddingColumnName("all_minilm_l6_v2_embedding") + .indexName("all_minilm_l6_v2_ann") + .addMetadataColumns(new CassandraVectorStore.SchemaColumn("revision", DataTypes.INT), + new CassandraVectorStore.SchemaColumn("id", DataTypes.INT, + CassandraVectorStore.SchemaColumnTags.INDEXED)) // this store uses '§¶' as a deliminator in the document id between db columns // 'title' and 'chunk_no' - .withPrimaryKeyTranslator((List primaryKeys) -> { + .primaryKeyTranslator((List primaryKeys) -> { if (primaryKeys.isEmpty()) { return "test§¶0"; } - return java.lang.String.format("%s§¶%s", primaryKeys.get(2), primaryKeys.get(3)); + return String.format("%s§¶%s", primaryKeys.get(2), primaryKeys.get(3)); }) - .withDocumentIdTranslator(id -> { + .documentIdTranslator(id -> { String[] parts = id.split("§¶"); String title = parts[0]; int chunk_no = 0 < parts.length ? Integer.parseInt(parts[1]) : 0; return List.of("simplewiki", "en", title, chunk_no); }); - - return builder; } @Test void ensureSchemaCreation() { this.contextRunner.run(context -> { - try (CassandraVectorStore store = createStore(context, false).store()) { + try (CassandraVectorStore store = createStore(context, false)) { Assertions.assertNotNull(store); store.checkSchemaValid(); store.similaritySearch(SearchRequest.query("1843").withTopK(1)); @@ -163,14 +163,16 @@ void ensureSchemaCreation() { void ensureSchemaNoCreation() { this.contextRunner.run(context -> { executeCqlFile(context, "test_wiki_full_schema.cql"); - var wrapper = createStore(context, List.of(), true, false); + var builder = createBuilder(context, List.of(), true, false); + Assertions.assertNotNull(builder); + var store = new CassandraVectorStore(builder); try { - Assertions.assertNotNull(wrapper.store()); - wrapper.store().checkSchemaValid(); - wrapper.store().similaritySearch(SearchRequest.query("1843").withTopK(1)); + store.checkSchemaValid(); + + store.similaritySearch(SearchRequest.query("1843").withTopK(1)); - wrapper.conf().dropKeyspace(); + CassandraVectorStore.dropKeyspace(builder); executeCqlFile(context, "test_wiki_partial_3_schema.cql"); // IllegalStateException: column all_minilm_l6_v2_embedding does not exist @@ -180,8 +182,8 @@ void ensureSchemaNoCreation() { Assertions.assertEquals("column all_minilm_l6_v2_embedding does not exist", ise.getMessage()); } finally { - wrapper.conf().dropKeyspace(); - wrapper.store().close(); + CassandraVectorStore.dropKeyspace(builder); + store.close(); } }); } @@ -192,16 +194,18 @@ void ensureSchemaPartialCreation() { int PARTIAL_FILES = 5; for (int i = 0; i < PARTIAL_FILES; ++i) { executeCqlFile(context, java.lang.String.format("test_wiki_partial_%d_schema.cql", i)); - var wrapper = createStore(context, List.of(), false, false); + var builder = createBuilder(context, List.of(), false, false); + Assertions.assertNotNull(builder); + CassandraVectorStore.dropKeyspace(builder); + var store = builder.build(); try { - Assertions.assertNotNull(wrapper.store()); - wrapper.store().checkSchemaValid(); + store.checkSchemaValid(); - wrapper.store().similaritySearch(SearchRequest.query("1843").withTopK(1)); - wrapper.conf().dropKeyspace(); + store.similaritySearch(SearchRequest.query("1843").withTopK(1)); } finally { - wrapper.store().close(); + CassandraVectorStore.dropKeyspace(builder); + store.close(); } } // make sure there's not more files to test @@ -213,7 +217,7 @@ void ensureSchemaPartialCreation() { @Test void addAndSearch() { this.contextRunner.run(context -> { - try (CassandraVectorStore store = createStore(context, false).store()) { + try (CassandraVectorStore store = createStore(context, false)) { store.add(documents); List results = store @@ -241,16 +245,16 @@ void addAndSearch() { @Test void addAndSearchPoormansBench() { // todo – replace with JMH (parameters: nThreads, rounds, runs, docsPerAdd) - int nThreads = CassandraVectorStoreConfig.DEFAULT_ADD_CONCURRENCY; + int nThreads = CassandraVectorStore.DEFAULT_ADD_CONCURRENCY; int runs = 10; // 100; int docsPerAdd = 12; // 128; int rounds = 3; this.contextRunner.run(context -> { - try (CassandraVectorStore store = new CassandraVectorStore( - storeBuilder(context, List.of()).withFixedThreadPoolExecutorSize(nThreads).build(), - context.getBean(EmbeddingModel.class))) { + try (CassandraVectorStore store = storeBuilder(context, List.of()).fixedThreadPoolExecutorSize(nThreads) + .embeddingModel(context.getBean(EmbeddingModel.class)) + .build()) { var executor = Executors.newFixedThreadPool((int) (nThreads * 1.2)); for (int k = 0; k < rounds; ++k) { @@ -286,7 +290,7 @@ void addAndSearchPoormansBench() { @Test void searchWithPartitionFilter() throws InterruptedException { this.contextRunner.run(context -> { - try (CassandraVectorStore store = createStore(context, false).store()) { + try (CassandraVectorStore store = createStore(context, false)) { store.add(documents); List results = store.similaritySearch(SearchRequest.query("Great Dark Spot").withTopK(5)); @@ -336,7 +340,7 @@ void searchWithPartitionFilter() throws InterruptedException { @Test void unsearchableFilters() throws InterruptedException { this.contextRunner.run(context -> { - try (CassandraVectorStore store = createStore(context, false).store()) { + try (CassandraVectorStore store = createStore(context, false)) { store.add(documents); List results = store.similaritySearch(SearchRequest.query("Great Dark Spot").withTopK(5)); @@ -354,7 +358,7 @@ void unsearchableFilters() throws InterruptedException { @Test void searchWithFilters() throws InterruptedException { this.contextRunner.run(context -> { - try (CassandraVectorStore store = createStore(context, false).store()) { + try (CassandraVectorStore store = createStore(context, false)) { store.add(documents); List results = store.similaritySearch(SearchRequest.query(URANUS_ORBIT_QUERY).withTopK(5)); @@ -418,10 +422,10 @@ void searchWithFilterOnPrimaryKeys() throws InterruptedException { this.contextRunner.run(context -> { List overrides = List.of( - new SchemaColumn("title", DataTypes.TEXT, CassandraVectorStoreConfig.SchemaColumnTags.INDEXED), - new SchemaColumn("chunk_no", DataTypes.INT, CassandraVectorStoreConfig.SchemaColumnTags.INDEXED)); + new SchemaColumn("title", DataTypes.TEXT, CassandraVectorStore.SchemaColumnTags.INDEXED), + new SchemaColumn("chunk_no", DataTypes.INT, CassandraVectorStore.SchemaColumnTags.INDEXED)); - try (CassandraVectorStore store = createStore(context, overrides, false, true).store()) { + try (CassandraVectorStore store = createStore(context, overrides, false, true)) { store.add(documents); @@ -452,7 +456,7 @@ void searchWithFilterOnPrimaryKeys() throws InterruptedException { @Test void documentUpdate() { this.contextRunner.run(context -> { - try (CassandraVectorStore store = createStore(context, false).store()) { + try (CassandraVectorStore store = createStore(context, false)) { store.add(documents); List results = store.similaritySearch(SearchRequest.query(URANUS_ORBIT_QUERY).withTopK(1)); @@ -502,7 +506,7 @@ void documentUpdate() { @Test void searchWithThreshold() { this.contextRunner.run(context -> { - try (CassandraVectorStore store = createStore(context, false).store()) { + try (CassandraVectorStore store = createStore(context, false)) { store.add(documents); List fullResult = store @@ -530,26 +534,43 @@ void searchWithThreshold() { }); } - private StoreWrapper createStore(ApplicationContext context, - boolean disallowSchemaCreation) throws IOException { + private CassandraVectorStore createStore(ApplicationContext context, boolean disallowSchemaCreation) + throws IOException { return createStore(context, List.of(), disallowSchemaCreation, true); } - private StoreWrapper createStore(ApplicationContext context, + private CassandraVectorStore createStore(ApplicationContext context, List columnOverrides, + boolean disallowSchemaCreation, boolean dropKeyspaceFirst) throws IOException { + + CassandraVectorStore.CassandraBuilder builder = storeBuilder(context, columnOverrides); + if (disallowSchemaCreation) { + builder = builder.disallowSchemaChanges(true); + } + + if (dropKeyspaceFirst) { + CassandraVectorStore.dropKeyspace(builder); + } + + builder.embeddingModel(context.getBean(EmbeddingModel.class)); + return new CassandraVectorStore(builder); + } + + private CassandraVectorStore.CassandraBuilder createBuilder(ApplicationContext context, List columnOverrides, boolean disallowSchemaCreation, boolean dropKeyspaceFirst) throws IOException { - CassandraVectorStoreConfig.Builder builder = storeBuilder(context, columnOverrides); + CassandraVectorStore.CassandraBuilder builder = storeBuilder(context, columnOverrides); if (disallowSchemaCreation) { - builder = builder.disallowSchemaChanges(); + builder = builder.disallowSchemaChanges(true); } - CassandraVectorStoreConfig conf = builder.build(); if (dropKeyspaceFirst) { - conf.dropKeyspace(); + CassandraVectorStore.dropKeyspace(builder); } - return new StoreWrapper(new CassandraVectorStore(conf, context.getBean(EmbeddingModel.class)), conf); + + builder.embeddingModel(context.getBean(EmbeddingModel.class)); + return builder; } private void executeCqlFile(ApplicationContext context, String filename) throws IOException { @@ -588,8 +609,4 @@ public CqlSession cqlSession() { } - public record StoreWrapper(K store, V conf) { - - } - } diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/CassandraVectorStoreIT.java similarity index 89% rename from vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java rename to vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/CassandraVectorStoreIT.java index d50414bb02e..b86a95086e5 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/CassandraVectorStoreIT.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.vectorstore; +package org.springframework.ai.vectorstore.cassandra; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -35,12 +35,13 @@ import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; -import org.springframework.ai.CassandraImage; +import org.springframework.ai.cassandra.CassandraImage; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; -import org.springframework.ai.vectorstore.CassandraVectorStoreConfig.SchemaColumn; -import org.springframework.ai.vectorstore.CassandraVectorStoreConfig.SchemaColumnTags; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.cassandra.CassandraVectorStore.SchemaColumn; +import org.springframework.ai.vectorstore.cassandra.CassandraVectorStore.SchemaColumnTags; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; @@ -84,24 +85,25 @@ private static String getText(String uri) { } } - private static CassandraVectorStoreConfig.Builder storeBuilder(CqlSession cqlSession) { - return CassandraVectorStoreConfig.builder() - .withCqlSession(cqlSession) - .withKeyspaceName("test_" + CassandraVectorStoreConfig.DEFAULT_KEYSPACE_NAME); + private static CassandraVectorStore.CassandraBuilder storeBuilder(CqlSession cqlSession) { + return CassandraVectorStore.builder() + .session(cqlSession) + .keyspace("test_" + CassandraVectorStore.DEFAULT_KEYSPACE_NAME); } private static CassandraVectorStore createTestStore(ApplicationContext context, SchemaColumn... metadataFields) { - CassandraVectorStoreConfig.Builder builder = storeBuilder(context.getBean(CqlSession.class)) + CassandraVectorStore.CassandraBuilder builder = storeBuilder(context.getBean(CqlSession.class)) .addMetadataColumns(metadataFields); return createTestStore(context, builder); } private static CassandraVectorStore createTestStore(ApplicationContext context, - CassandraVectorStoreConfig.Builder builder) { - CassandraVectorStoreConfig conf = builder.build(); - conf.dropKeyspace(); - return new CassandraVectorStore(conf, context.getBean(EmbeddingModel.class)); + CassandraVectorStore.CassandraBuilder builder) { + CassandraVectorStore.dropKeyspace(builder); + builder.embeddingModel(context.getBean(EmbeddingModel.class)); + CassandraVectorStore store = builder.build(); + return store; } @Test @@ -147,8 +149,8 @@ void addAndSearch() { @Test void addAndSearchReturnEmbeddings() { this.contextRunner.run(context -> { - CassandraVectorStoreConfig.Builder builder = storeBuilder(context.getBean(CqlSession.class)) - .returnEmbeddings(); + CassandraVectorStore.CassandraBuilder builder = storeBuilder(context.getBean(CqlSession.class)) + .returnEmbeddings(true); try (CassandraVectorStore store = createTestStore(context, builder)) { List documents = documents(); @@ -197,8 +199,7 @@ void searchWithPartitionFilter() throws InterruptedException { results = store.similaritySearch(SearchRequest.query("The World") .withTopK(5) .withSimilarityThresholdAll() - .withFilterExpression( - java.lang.String.format("%s == 'NL'", CassandraVectorStoreConfig.DEFAULT_ID_NAME))); + .withFilterExpression(java.lang.String.format("%s == 'NL'", CassandraVectorStore.DEFAULT_ID_NAME))); assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); @@ -207,7 +208,7 @@ void searchWithPartitionFilter() throws InterruptedException { .withTopK(5) .withSimilarityThresholdAll() .withFilterExpression( - java.lang.String.format("%s == 'BG2'", CassandraVectorStoreConfig.DEFAULT_ID_NAME))); + java.lang.String.format("%s == 'BG2'", CassandraVectorStore.DEFAULT_ID_NAME))); assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(bgDocument2.getId()); @@ -216,7 +217,7 @@ void searchWithPartitionFilter() throws InterruptedException { .withTopK(5) .withSimilarityThresholdAll() .withFilterExpression(java.lang.String.format("%s == 'BG' && year == 2020", - CassandraVectorStoreConfig.DEFAULT_ID_NAME))); + CassandraVectorStore.DEFAULT_ID_NAME))); assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); @@ -227,7 +228,7 @@ void searchWithPartitionFilter() throws InterruptedException { .withTopK(5) .withSimilarityThresholdAll() .withFilterExpression(java.lang.String.format("NOT(%s == 'BG' && year == 2020)", - CassandraVectorStoreConfig.DEFAULT_ID_NAME)))); + CassandraVectorStore.DEFAULT_ID_NAME)))); } }); } @@ -394,14 +395,15 @@ public static class TestApplication { @Bean public CassandraVectorStore store(CqlSession cqlSession, EmbeddingModel embeddingModel) { - CassandraVectorStoreConfig conf = storeBuilder(cqlSession) - .addMetadataColumns(new SchemaColumn("meta1", DataTypes.TEXT), - new SchemaColumn("meta2", DataTypes.TEXT), new SchemaColumn("country", DataTypes.TEXT), - new SchemaColumn("year", DataTypes.SMALLINT)) - .build(); + CassandraVectorStore.CassandraBuilder builder = storeBuilder(cqlSession) + .addMetadataColumns(new CassandraVectorStore.SchemaColumn("meta1", DataTypes.TEXT), + new CassandraVectorStore.SchemaColumn("meta2", DataTypes.TEXT), + new CassandraVectorStore.SchemaColumn("country", DataTypes.TEXT), + new CassandraVectorStore.SchemaColumn("year", DataTypes.SMALLINT)) + .embeddingModel(embeddingModel); - conf.dropKeyspace(); - return new CassandraVectorStore(conf, embeddingModel); + CassandraVectorStore.dropKeyspace(builder); + return builder.build(); } @Bean diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreObservationIT.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/CassandraVectorStoreObservationIT.java similarity index 88% rename from vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreObservationIT.java rename to vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/CassandraVectorStoreObservationIT.java index e92349efaa4..6aba243e627 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/CassandraVectorStoreObservationIT.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.vectorstore; +package org.springframework.ai.vectorstore.cassandra; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -32,7 +32,7 @@ import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; -import org.springframework.ai.CassandraImage; +import org.springframework.ai.cassandra.CassandraImage; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -40,7 +40,8 @@ import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; import org.springframework.ai.transformers.TransformersEmbeddingModel; -import org.springframework.ai.vectorstore.CassandraVectorStoreConfig.SchemaColumn; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.observation.DefaultVectorStoreObservationConvention; import org.springframework.ai.vectorstore.observation.VectorStoreObservationDocumentation.HighCardinalityKeyNames; import org.springframework.ai.vectorstore.observation.VectorStoreObservationDocumentation.LowCardinalityKeyNames; @@ -80,12 +81,6 @@ public static String getText(String uri) { } } - private static CassandraVectorStoreConfig.Builder storeBuilder(CqlSession cqlSession) { - return CassandraVectorStoreConfig.builder() - .withCqlSession(cqlSession) - .withKeyspaceName("test_" + CassandraVectorStoreConfig.DEFAULT_KEYSPACE_NAME); - } - @Test void observationVectorStoreAddAndQueryOperations() { @@ -110,7 +105,7 @@ void observationVectorStoreAddAndQueryOperations() { .doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.DB_VECTOR_QUERY_CONTENT.asString()) .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_VECTOR_DIMENSION_COUNT.asString(), "384") .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_COLLECTION_NAME.asString(), - CassandraVectorStoreConfig.DEFAULT_TABLE_NAME) + CassandraVectorStore.DEFAULT_TABLE_NAME) .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_NAMESPACE.asString(), "test_springframework") .doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.DB_VECTOR_FIELD_NAME.asString()) .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_SEARCH_SIMILARITY_METRIC.asString(), @@ -144,7 +139,7 @@ void observationVectorStoreAddAndQueryOperations() { "What is Great Depression") .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_VECTOR_DIMENSION_COUNT.asString(), "384") .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_COLLECTION_NAME.asString(), - CassandraVectorStoreConfig.DEFAULT_TABLE_NAME) + CassandraVectorStore.DEFAULT_TABLE_NAME) .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_NAMESPACE.asString(), "test_springframework") .doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.DB_VECTOR_FIELD_NAME.asString()) .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_SEARCH_SIMILARITY_METRIC.asString(), @@ -177,15 +172,20 @@ public EmbeddingModel embeddingModel() { public CassandraVectorStore store(CqlSession cqlSession, EmbeddingModel embeddingModel, ObservationRegistry observationRegistry) { - CassandraVectorStoreConfig conf = storeBuilder(cqlSession) - .addMetadataColumns(new SchemaColumn("meta1", DataTypes.TEXT), - new SchemaColumn("meta2", DataTypes.TEXT), new SchemaColumn("country", DataTypes.TEXT), - new SchemaColumn("year", DataTypes.SMALLINT)) - .build(); - - conf.dropKeyspace(); - return new CassandraVectorStore(conf, embeddingModel, observationRegistry, null, - new TokenCountBatchingStrategy()); + CassandraVectorStore.CassandraBuilder builder = CassandraVectorStore.builder() + .session(cqlSession) + .session(cqlSession) + .keyspace("test_" + CassandraVectorStore.DEFAULT_KEYSPACE_NAME) + .addMetadataColumns(new CassandraVectorStore.SchemaColumn("meta1", DataTypes.TEXT), + new CassandraVectorStore.SchemaColumn("meta2", DataTypes.TEXT), + new CassandraVectorStore.SchemaColumn("country", DataTypes.TEXT), + new CassandraVectorStore.SchemaColumn("year", DataTypes.SMALLINT)) + .embeddingModel(embeddingModel) + .observationRegistry(observationRegistry) + .batchingStrategy(new TokenCountBatchingStrategy()); + + CassandraVectorStore.dropKeyspace(builder); + return builder.build(); } @Bean diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/WikiVectorStoreExample.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/WikiVectorStoreExample.java similarity index 83% rename from vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/WikiVectorStoreExample.java rename to vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/WikiVectorStoreExample.java index 8330b3a443f..25d9cac2e8b 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/WikiVectorStoreExample.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/cassandra/WikiVectorStoreExample.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.vectorstore; +package org.springframework.ai.vectorstore.cassandra; import java.util.List; @@ -27,7 +27,8 @@ import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; -import org.springframework.ai.vectorstore.CassandraVectorStoreConfig.SchemaColumn; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.cassandra.CassandraVectorStore.SchemaColumn; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; @@ -90,36 +91,33 @@ public CassandraVectorStore store(CqlSession cqlSession, EmbeddingModel embeddin List extraColumns = List.of(new SchemaColumn("revision", DataTypes.INT), new SchemaColumn("id", DataTypes.INT)); - CassandraVectorStoreConfig conf = CassandraVectorStoreConfig.builder() - .withCqlSession(cqlSession) - .withKeyspaceName("wikidata") - .withTableName("articles") - .withPartitionKeys(partitionColumns) - .withClusteringKeys(clusteringColumns) - .withContentColumnName("body") - .withEmbeddingColumnName("all_minilm_l6_v2_embedding") - .withIndexName("all_minilm_l6_v2_ann") - .disallowSchemaChanges() + return CassandraVectorStore.builder() + .session(cqlSession) + .keyspace("wikidata") + .table("articles") + .partitionKeys(partitionColumns) + .clusteringKeys(clusteringColumns) + .contentColumnName("body") + .embeddingColumnName("all_minilm_l6_v2_embedding") + .indexName("all_minilm_l6_v2_ann") + .disallowSchemaChanges(true) .addMetadataColumns(extraColumns) - - .withPrimaryKeyTranslator((List primaryKeys) -> { + .primaryKeyTranslator((List primaryKeys) -> { // the deliminator used to join fields together into the document's id // is arbitary, here "§¶" is used if (primaryKeys.isEmpty()) { return "test§¶0"; } - return java.lang.String.format("%s§¶%s", primaryKeys.get(2), primaryKeys.get(3)); + return String.format("%s§¶%s", primaryKeys.get(2), primaryKeys.get(3)); }) - - .withDocumentIdTranslator(id -> { + .documentIdTranslator(id -> { String[] parts = id.split("§¶"); String title = parts[0]; int chunk_no = 0 < parts.length ? Integer.parseInt(parts[1]) : 0; return List.of("simplewiki", "en", title, chunk_no, 0); }) + .embeddingModel(embeddingModel()) .build(); - - return new CassandraVectorStore(conf, embeddingModel()); } @Bean