3030import  com .datastax .oss .driver .api .core .cql .BoundStatement ;
3131import  com .datastax .oss .driver .api .core .cql .BoundStatementBuilder ;
3232import  com .datastax .oss .driver .api .core .cql .PreparedStatement ;
33+ import  com .datastax .oss .driver .api .core .cql .ResultSet ;
3334import  com .datastax .oss .driver .api .core .cql .Row ;
3435import  com .datastax .oss .driver .api .core .cql .SimpleStatement ;
3536import  com .datastax .oss .driver .api .core .data .CqlVector ;
3637import  com .datastax .oss .driver .api .core .metadata .schema .TableMetadata ;
3738import  com .datastax .oss .driver .api .querybuilder .QueryBuilder ;
39+ import  static  com .datastax .oss .driver .api .querybuilder .QueryBuilder .literal ;
3840import  com .datastax .oss .driver .api .querybuilder .delete .Delete ;
3941import  com .datastax .oss .driver .api .querybuilder .delete .DeleteSelection ;
4042import  com .datastax .oss .driver .api .querybuilder .insert .InsertInto ;
4143import  com .datastax .oss .driver .api .querybuilder .insert .RegularInsert ;
44+ import  com .datastax .oss .driver .api .querybuilder .select .Select ;
45+ import  com .datastax .oss .driver .api .querybuilder .select .Selector ;
4246import  com .datastax .oss .driver .shaded .guava .common .base .Preconditions ;
4347import  io .micrometer .observation .ObservationRegistry ;
4448import  org .slf4j .Logger ;
@@ -112,8 +116,6 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme
112116
113117	public  static  final  String  DRIVER_PROFILE_SEARCH  = "spring-ai-search" ;
114118
115- 	private  static  final  String  QUERY_FORMAT  = "select %s,%s,%s%s from %s.%s ? order by %s ann of ? limit ?" ;
116- 
117119	private  static  final  Logger  logger  = LoggerFactory .getLogger (CassandraVectorStore .class );
118120
119121	private  static  Map <Similarity , VectorStoreSimilarityMetric > SIMILARITY_TYPE_MAPPING  = Map .of (Similarity .COSINE ,
@@ -130,8 +132,6 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme
130132
131133	private  final  PreparedStatement  deleteStmt ;
132134
133- 	private  final  String  similarityStmt ;
134- 
135135	private  final  Similarity  similarity ;
136136
137137	private  final  BatchingStrategy  batchingStrategy ;
@@ -162,7 +162,6 @@ public CassandraVectorStore(CassandraVectorStoreConfig conf, EmbeddingModel embe
162162			.get ();
163163
164164		this .similarity  = getIndexSimilarity (cassandraMetadata );
165- 		this .similarityStmt  = similaritySearchStatement ();
166165
167166		this .filterExpressionConverter  = new  CassandraFilterExpressionConverter (
168167				cassandraMetadata .getColumns ().values ());
@@ -232,21 +231,14 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
232231		Preconditions .checkArgument (request .getTopK () <= 1000 );
233232		var  embedding  = toFloatArray (this .embeddingModel .embed (request .getQuery ()));
234233		CqlVector <Float > cqlVector  = CqlVector .newInstance (embedding );
235- 
236- 		String  whereClause  = "" ;
237- 		if  (request .hasFilterExpression ()) {
238- 			String  expression  = this .filterExpressionConverter .convertExpression (request .getFilterExpression ());
239- 			if  (!expression .isBlank ()) {
240- 				whereClause  = String .format ("where %s" , expression );
241- 			}
242- 		}
243- 
244- 		String  query  = String .format (this .similarityStmt , cqlVector , whereClause , cqlVector , request .getTopK ());
234+ 		String  cql  = createSimilaritySearchCql (request , cqlVector , request .getTopK ());
245235		List <Document > documents  = new  ArrayList <>();
246- 		logger .trace ("Executing {}" , query );
247- 		SimpleStatement  s  = SimpleStatement .newInstance (query ).setExecutionProfileName (DRIVER_PROFILE_SEARCH );
236+ 		logger .trace ("Executing {}" , cql );
248237
249- 		for  (Row  row  : this .conf .session .execute (s )) {
238+ 		ResultSet  result  = this .conf .session 
239+ 			.execute (SimpleStatement .newInstance (cql ).setExecutionProfileName (DRIVER_PROFILE_SEARCH ));
240+ 
241+ 		for  (Row  row  : result ) {
250242			float  score  = row .getFloat (0 );
251243			if  (score  < request .getSimilarityThreshold ()) {
252244				break ;
@@ -333,38 +325,36 @@ private PreparedStatement prepareAddStatement(Set<String> metadataFields) {
333325		});
334326	}
335327
336- 	private  String  similaritySearchStatement () {
337- 		StringBuilder  ids  = new  StringBuilder ();
338- 		for  (var  m  : this .conf .schema .partitionKeys ()) {
339- 			ids .append (m .name ()).append (',' );
340- 		}
341- 		for  (var  m  : this .conf .schema .clusteringKeys ()) {
342- 			ids .append (m .name ()).append (',' );
343- 		}
344- 		ids .deleteCharAt (ids .length () - 1 );
328+ 	private  String  createSimilaritySearchCql (SearchRequest  request , CqlVector <Float > cqlVector , int  topK ) {
345329
346- 		String  similarityFunction  = new  StringBuilder ("similarity_" ).append (this .similarity .toString ().toLowerCase ())
347- 			.append ('(' )
348- 			.append (this .conf .schema .embedding ())
349- 			.append (",?)" )
350- 			.toString ();
330+ 		Select  stmt  = QueryBuilder .selectFrom (this .conf .schema .keyspace (), this .conf .schema .table ())
331+ 			.function ("similarity_"  + this .similarity .toString ().toLowerCase (),
332+ 					Selector .column (this .conf .schema .embedding ()), literal (cqlVector ));
351333
352- 		StringBuilder  extraSelectFields  = new  StringBuilder ();
334+ 		for  (var  c  : this .conf .schema .partitionKeys ()) {
335+ 			stmt  = stmt .column (c .name ());
336+ 		}
337+ 		for  (var  c  : this .conf .schema .clusteringKeys ()) {
338+ 			stmt  = stmt .column (c .name ());
339+ 		}
340+ 		stmt  = stmt .column (this .conf .schema .content ());
353341		for  (var  m  : this .conf .schema .metadataColumns ()) {
354- 			extraSelectFields . append ( ',' ). append (m .name ());
342+ 			stmt  =  stmt . column (m .name ());
355343		}
356344		if  (this .conf .returnEmbeddings ) {
357- 			extraSelectFields . append ( ',' ). append (this .conf .schema .embedding ());
345+ 			stmt  =  stmt . column (this .conf .schema .embedding ());
358346		}
359347
360- 		// java-driver-query-builder doesn't support orderByAnnOf yet 
361- 		String  query  = String .format (QUERY_FORMAT , similarityFunction , ids .toString (), this .conf .schema .content (),
362- 				extraSelectFields .toString (), this .conf .schema .keyspace (), this .conf .schema .table (),
363- 				this .conf .schema .embedding ());
364- 
365- 		query  = query .replace ("?" , "%s" );
366- 		logger .debug ("preparing {}" , query );
367- 		return  query ;
348+ 		// the filterExpression is a string so we go back to building a CQL string 
349+ 		String  whereClause  = "" ;
350+ 		if  (request .hasFilterExpression ()) {
351+ 			String  expression  = this .filterExpressionConverter .convertExpression (request .getFilterExpression ());
352+ 			if  (!expression .isBlank ()) {
353+ 				whereClause  = String .format ("WHERE %s" , expression );
354+ 			}
355+ 		}
356+ 		String  cql  = stmt .orderByAnnOf (this .conf .schema .embedding (), cqlVector ).limit (topK ).asCql ();
357+ 		return  cql .replace (" ORDER " , whereClause  + " ORDER " );
368358	}
369359
370360	private  String  getDocumentId (Row  row ) {
0 commit comments