Skip to content

Commit 3d10e48

Browse files
Use cassandra-java-driver's QueryBuidler vector support
ref: JAVA-3118 – apache/cassandra-java-driver#1931
1 parent 64c4567 commit 3d10e48

File tree

4 files changed

+44
-81
lines changed

4 files changed

+44
-81
lines changed

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@
216216
<!-- also managed by boot bom -->
217217
<oracle.version>23.4.0.24.05</oracle.version>
218218
<postgresql.version>42.7.2</postgresql.version>
219-
<cassandra.java-driver.version>4.18.1</cassandra.java-driver.version>
219+
<cassandra.java-driver.version>4.18.2-SNAPSHOT</cassandra.java-driver.version>
220220
<elasticsearch-java.version>8.13.3</elasticsearch-java.version>
221221
<spring-retry.version>2.0.9</spring-retry.version>
222222
<jackson.version>2.16.1</jackson.version>

vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java

Lines changed: 33 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,19 @@
3030
import com.datastax.oss.driver.api.core.cql.BoundStatement;
3131
import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder;
3232
import com.datastax.oss.driver.api.core.cql.PreparedStatement;
33+
import com.datastax.oss.driver.api.core.cql.ResultSet;
3334
import com.datastax.oss.driver.api.core.cql.Row;
3435
import com.datastax.oss.driver.api.core.cql.SimpleStatement;
3536
import com.datastax.oss.driver.api.core.data.CqlVector;
3637
import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata;
3738
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
39+
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.literal;
3840
import com.datastax.oss.driver.api.querybuilder.delete.Delete;
3941
import com.datastax.oss.driver.api.querybuilder.delete.DeleteSelection;
4042
import com.datastax.oss.driver.api.querybuilder.insert.InsertInto;
4143
import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert;
44+
import com.datastax.oss.driver.api.querybuilder.select.Select;
45+
import com.datastax.oss.driver.api.querybuilder.select.Selector;
4246
import com.datastax.oss.driver.shaded.guava.common.base.Preconditions;
4347
import io.micrometer.observation.ObservationRegistry;
4448
import org.slf4j.Logger;
@@ -112,8 +116,6 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme
112116

113117
public static final String DRIVER_PROFILE_SEARCH = "spring-ai-search";
114118

115-
private static final String QUERY_FORMAT = "select %s,%s,%s%s from %s.%s ? order by %s ann of ? limit ?";
116-
117119
private static final Logger logger = LoggerFactory.getLogger(CassandraVectorStore.class);
118120

119121
private static Map<Similarity, VectorStoreSimilarityMetric> SIMILARITY_TYPE_MAPPING = Map.of(Similarity.COSINE,
@@ -130,8 +132,6 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme
130132

131133
private final PreparedStatement deleteStmt;
132134

133-
private final String similarityStmt;
134-
135135
private final Similarity similarity;
136136

137137
private final BatchingStrategy batchingStrategy;
@@ -162,7 +162,6 @@ public CassandraVectorStore(CassandraVectorStoreConfig conf, EmbeddingModel embe
162162
.get();
163163

164164
this.similarity = getIndexSimilarity(cassandraMetadata);
165-
this.similarityStmt = similaritySearchStatement();
166165

167166
this.filterExpressionConverter = new CassandraFilterExpressionConverter(
168167
cassandraMetadata.getColumns().values());
@@ -232,21 +231,14 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
232231
Preconditions.checkArgument(request.getTopK() <= 1000);
233232
var embedding = toFloatArray(this.embeddingModel.embed(request.getQuery()));
234233
CqlVector<Float> cqlVector = CqlVector.newInstance(embedding);
235-
236-
String whereClause = "";
237-
if (request.hasFilterExpression()) {
238-
String expression = this.filterExpressionConverter.convertExpression(request.getFilterExpression());
239-
if (!expression.isBlank()) {
240-
whereClause = String.format("where %s", expression);
241-
}
242-
}
243-
244-
String query = String.format(this.similarityStmt, cqlVector, whereClause, cqlVector, request.getTopK());
234+
String cql = createSimilaritySearchCql(request, cqlVector, request.getTopK());
245235
List<Document> documents = new ArrayList<>();
246-
logger.trace("Executing {}", query);
247-
SimpleStatement s = SimpleStatement.newInstance(query).setExecutionProfileName(DRIVER_PROFILE_SEARCH);
236+
logger.trace("Executing {}", cql);
248237

249-
for (Row row : this.conf.session.execute(s)) {
238+
ResultSet result = this.conf.session
239+
.execute(SimpleStatement.newInstance(cql).setExecutionProfileName(DRIVER_PROFILE_SEARCH));
240+
241+
for (Row row : result) {
250242
float score = row.getFloat(0);
251243
if (score < request.getSimilarityThreshold()) {
252244
break;
@@ -333,38 +325,36 @@ private PreparedStatement prepareAddStatement(Set<String> metadataFields) {
333325
});
334326
}
335327

336-
private String similaritySearchStatement() {
337-
StringBuilder ids = new StringBuilder();
338-
for (var m : this.conf.schema.partitionKeys()) {
339-
ids.append(m.name()).append(',');
340-
}
341-
for (var m : this.conf.schema.clusteringKeys()) {
342-
ids.append(m.name()).append(',');
343-
}
344-
ids.deleteCharAt(ids.length() - 1);
328+
private String createSimilaritySearchCql(SearchRequest request, CqlVector<Float> cqlVector, int topK) {
345329

346-
String similarityFunction = new StringBuilder("similarity_").append(this.similarity.toString().toLowerCase())
347-
.append('(')
348-
.append(this.conf.schema.embedding())
349-
.append(",?)")
350-
.toString();
330+
Select stmt = QueryBuilder.selectFrom(this.conf.schema.keyspace(), this.conf.schema.table())
331+
.function("similarity_" + this.similarity.toString().toLowerCase(),
332+
Selector.column(this.conf.schema.embedding()), literal(cqlVector));
351333

352-
StringBuilder extraSelectFields = new StringBuilder();
334+
for (var c : this.conf.schema.partitionKeys()) {
335+
stmt = stmt.column(c.name());
336+
}
337+
for (var c : this.conf.schema.clusteringKeys()) {
338+
stmt = stmt.column(c.name());
339+
}
340+
stmt = stmt.column(this.conf.schema.content());
353341
for (var m : this.conf.schema.metadataColumns()) {
354-
extraSelectFields.append(',').append(m.name());
342+
stmt = stmt.column(m.name());
355343
}
356344
if (this.conf.returnEmbeddings) {
357-
extraSelectFields.append(',').append(this.conf.schema.embedding());
345+
stmt = stmt.column(this.conf.schema.embedding());
358346
}
359347

360-
// java-driver-query-builder doesn't support orderByAnnOf yet
361-
String query = String.format(QUERY_FORMAT, similarityFunction, ids.toString(), this.conf.schema.content(),
362-
extraSelectFields.toString(), this.conf.schema.keyspace(), this.conf.schema.table(),
363-
this.conf.schema.embedding());
364-
365-
query = query.replace("?", "%s");
366-
logger.debug("preparing {}", query);
367-
return query;
348+
// the filterExpression is a string so we go back to building a CQL string
349+
String whereClause = "";
350+
if (request.hasFilterExpression()) {
351+
String expression = this.filterExpressionConverter.convertExpression(request.getFilterExpression());
352+
if (!expression.isBlank()) {
353+
whereClause = String.format("WHERE %s", expression);
354+
}
355+
}
356+
String cql = stmt.orderByAnnOf(this.conf.schema.embedding(), cqlVector).limit(topK).asCql();
357+
return cql.replace(" ORDER ", whereClause + " ORDER ");
368358
}
369359

370360
private String getDocumentId(Row row) {

vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java

Lines changed: 9 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
import com.datastax.oss.driver.api.core.type.DataTypes;
3737
import com.datastax.oss.driver.api.core.type.codec.registry.CodecRegistry;
3838
import com.datastax.oss.driver.api.core.type.reflect.GenericType;
39-
import com.datastax.oss.driver.api.querybuilder.BuildableQuery;
4039
import com.datastax.oss.driver.api.querybuilder.SchemaBuilder;
4140
import com.datastax.oss.driver.api.querybuilder.schema.AlterTableAddColumn;
4241
import com.datastax.oss.driver.api.querybuilder.schema.AlterTableAddColumnEnd;
@@ -234,25 +233,15 @@ private void ensureTableExists(int vectorDimension) {
234233
createTable = createTable.withClusteringColumn(clusteringKey.name, clusteringKey.type);
235234
}
236235

237-
createTable = createTable.withColumn(this.schema.content, DataTypes.TEXT);
236+
createTable = createTable.withColumn(this.schema.content, DataTypes.TEXT)
237+
.withColumn(this.schema.embedding, DataTypes.vectorOf(DataTypes.FLOAT, vectorDimension));
238238

239239
for (SchemaColumn metadata : this.schema.metadataColumns) {
240240
createTable = createTable.withColumn(metadata.name(), metadata.type());
241241
}
242242

243-
// https://datastax-oss.atlassian.net/browse/JAVA-3118
244-
// .withColumn(config.embedding, new DefaultVectorType(DataTypes.FLOAT,
245-
// vectorDimension));
246-
247-
StringBuilder tableStmt = new StringBuilder(createTable.asCql());
248-
tableStmt.setLength(tableStmt.length() - 1);
249-
tableStmt.append(',')
250-
.append(this.schema.embedding)
251-
.append(" vector<float,")
252-
.append(vectorDimension)
253-
.append(">)");
254-
logger.debug("Executing {}", tableStmt.toString());
255-
this.session.execute(tableStmt.toString());
243+
logger.debug("Executing {}", createTable.asCql());
244+
this.session.execute(createTable.build());
256245
}
257246
}
258247

@@ -290,28 +279,12 @@ private void ensureTableColumnsExist(int vectorDimension) {
290279
alterTable = alterTable.addColumn(this.schema.content, DataTypes.TEXT);
291280
}
292281
if (addEmbedding) {
293-
// special case for embedding column, bc JAVA-3118, as above
294-
StringBuilder alterTableStmt = new StringBuilder(((BuildableQuery) alterTable).asCql());
295-
if (newColumns.isEmpty() && !addContent) {
296-
alterTableStmt.append(" ADD (");
297-
}
298-
else {
299-
alterTableStmt.setLength(alterTableStmt.length() - 1);
300-
alterTableStmt.append(',');
301-
}
302-
alterTableStmt.append(this.schema.embedding)
303-
.append(" vector<float,")
304-
.append(vectorDimension)
305-
.append(">)");
306-
307-
logger.debug("Executing {}", alterTableStmt.toString());
308-
this.session.execute(alterTableStmt.toString());
309-
}
310-
else {
311-
SimpleStatement stmt = ((AlterTableAddColumnEnd) alterTable).build();
312-
logger.debug("Executing {}", stmt.getQuery());
313-
this.session.execute(stmt);
282+
alterTable = alterTable.addColumn(this.schema.embedding,
283+
DataTypes.vectorOf(DataTypes.FLOAT, vectorDimension));
314284
}
285+
SimpleStatement stmt = ((AlterTableAddColumnEnd) alterTable).build();
286+
logger.debug("Executing {}", stmt.getQuery());
287+
this.session.execute(stmt);
315288
}
316289
}
317290

vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
*/
2424
public final class CassandraImage {
2525

26-
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("cassandra:5.0");
26+
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("cassandra:5.0.2");
2727

2828
private CassandraImage() {
2929

0 commit comments

Comments
 (0)