Skip to content

Commit 25352b8

Browse files
committed
Add support for Cassandra Vector search.
Closes #1504
1 parent 92fa66b commit 25352b8

32 files changed

+1307
-147
lines changed

spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/StatementFactory.java

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
import org.springframework.data.cassandra.core.query.Update.SetAtIndexOp;
6060
import org.springframework.data.cassandra.core.query.Update.SetAtKeyOp;
6161
import org.springframework.data.cassandra.core.query.Update.SetOp;
62+
import org.springframework.data.cassandra.core.query.VectorSort;
6263
import org.springframework.data.convert.EntityWriter;
6364
import org.springframework.data.domain.Sort;
6465
import org.springframework.data.mapping.PersistentProperty;
@@ -72,6 +73,7 @@
7273
import org.springframework.util.ClassUtils;
7374

7475
import com.datastax.oss.driver.api.core.CqlIdentifier;
76+
import com.datastax.oss.driver.api.core.data.CqlVector;
7577
import com.datastax.oss.driver.api.core.metadata.schema.ClusteringOrder;
7678
import com.datastax.oss.driver.api.querybuilder.BindMarker;
7779
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
@@ -688,38 +690,54 @@ private CqlIdentifier getKeyspace(CassandraPersistentEntity<?> entity, CqlIdenti
688690
private StatementBuilder<Select> createSelectAndOrder(List<Selector> selectors, CassandraPersistentEntity<?> entity,
689691
CqlIdentifier from, Filter filter, Sort sort) {
690692

691-
Select select;
693+
StatementBuilder<Select> builder = StatementBuilder.of((Select) QueryBuilder.selectFrom(from),
694+
cassandraConverter.getCodecRegistry());
692695

693-
if (selectors.isEmpty()) {
694-
select = QueryBuilder.selectFrom(getKeyspace(entity, from), from).all();
695-
} else {
696+
builder.bind((statement, factory) -> {
696697

697-
List<com.datastax.oss.driver.api.querybuilder.select.Selector> mappedSelectors = new ArrayList<>(
698-
selectors.size());
699-
for (Selector selector : selectors) {
700-
com.datastax.oss.driver.api.querybuilder.select.Selector orElseGet = selector.getAlias()
701-
.map(it -> getSelection(selector).as(it)).orElseGet(() -> getSelection(selector));
702-
mappedSelectors.add(orElseGet);
703-
}
698+
Select select;
704699

705-
select = QueryBuilder.selectFrom(getKeyspace(entity, from), from).selectors(mappedSelectors);
706-
}
700+
if (selectors.isEmpty()) {
701+
select = QueryBuilder.selectFrom(getKeyspace(entity, from), from).all();
702+
} else {
707703

708-
StatementBuilder<Select> builder = StatementBuilder.of(select, cassandraConverter.getCodecRegistry());
704+
List<com.datastax.oss.driver.api.querybuilder.select.Selector> mappedSelectors = new ArrayList<>(
705+
selectors.size());
706+
for (Selector selector : selectors) {
707+
com.datastax.oss.driver.api.querybuilder.select.Selector orElseGet = selector.getAlias()
708+
.map(it -> getSelection(selector, factory).as(it)).orElseGet(() -> getSelection(selector, factory));
709+
mappedSelectors.add(orElseGet);
710+
}
711+
712+
select = QueryBuilder.selectFrom(getKeyspace(entity, from), from).selectors(mappedSelectors);
713+
}
714+
715+
return select;
716+
});
709717

710718
builder.bind((statement, factory) -> {
711719
return statement.where(getRelations(filter, factory));
712720
});
713721

714722
if (sort.isSorted()) {
715723

716-
builder.apply((statement) -> {
724+
builder.bind((statement, factory) -> {
717725

718726
Select statementToUse = statement;
719727

720-
for (Sort.Order order : sort) {
721-
statementToUse = statementToUse.orderBy(order.getProperty(),
722-
order.isAscending() ? ClusteringOrder.ASC : ClusteringOrder.DESC);
728+
if (sort instanceof VectorSort vs) {
729+
730+
for (Sort.Order order : sort) {
731+
732+
Object vector = vs.getVector();
733+
statementToUse = statementToUse.orderByAnnOf(order.getProperty(), (CqlVector<?>) vector);
734+
}
735+
} else {
736+
737+
for (Sort.Order order : sort) {
738+
statementToUse = statementToUse.orderBy(order.getProperty(),
739+
order.isAscending() ? ClusteringOrder.ASC : ClusteringOrder.DESC);
740+
}
723741
}
724742

725743
return statementToUse;
@@ -730,25 +748,33 @@ private StatementBuilder<Select> createSelectAndOrder(List<Selector> selectors,
730748
}
731749

732750
private static List<Relation> getRelations(Filter filter, TermFactory factory) {
751+
733752
List<Relation> relations = new ArrayList<>();
753+
734754
for (CriteriaDefinition criteriaDefinition : filter) {
735755
relations.add(toClause(criteriaDefinition, factory));
736756
}
757+
737758
return relations;
738759
}
739760

740-
private static com.datastax.oss.driver.api.querybuilder.select.Selector getSelection(Selector selector) {
761+
private static com.datastax.oss.driver.api.querybuilder.select.Selector getSelection(Selector selector,
762+
TermFactory factory) {
741763

742764
if (selector instanceof FunctionCall) {
743765

744766
com.datastax.oss.driver.api.querybuilder.select.Selector[] arguments = ((FunctionCall) selector).getParameters()
745767
.stream().map(param -> {
746768

747-
if (param instanceof ColumnSelector) {
769+
if (param instanceof ColumnSelector s) {
748770

749-
return com.datastax.oss.driver.api.querybuilder.select.Selector
750-
.column(((ColumnSelector) param).getExpression());
771+
return com.datastax.oss.driver.api.querybuilder.select.Selector.column(s.getExpression());
751772
}
773+
774+
if (param instanceof CqlIdentifier i) {
775+
return com.datastax.oss.driver.api.querybuilder.select.Selector.column(i.toString());
776+
}
777+
752778
return new SimpleSelector(param.toString());
753779

754780
}).toArray(com.datastax.oss.driver.api.querybuilder.select.Selector[]::new);

spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/convert/CassandraConverters.java

Lines changed: 139 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.time.Instant;
2121
import java.time.LocalDate;
2222
import java.util.ArrayList;
23+
import java.util.Arrays;
2324
import java.util.Collection;
2425
import java.util.Date;
2526
import java.util.List;
@@ -30,18 +31,20 @@
3031
import org.springframework.data.cassandra.core.cql.converter.RowToListConverter;
3132
import org.springframework.data.cassandra.core.cql.converter.RowToMapConverter;
3233
import org.springframework.data.convert.ReadingConverter;
34+
import org.springframework.data.domain.Vector;
3335
import org.springframework.util.Assert;
3436
import org.springframework.util.NumberUtils;
3537

3638
import com.datastax.oss.driver.api.core.cql.Row;
39+
import com.datastax.oss.driver.api.core.data.CqlVector;
3740

3841
/**
3942
* Wrapper class containing useful converters for use with Cassandra.
4043
*
4144
* @author Mark Paluch
4245
* @since 1.5
4346
*/
44-
abstract class CassandraConverters {
47+
public abstract class CassandraConverters {
4548

4649
/**
4750
* Private constructor to prevent instantiation.
@@ -66,6 +69,17 @@ static Collection<Object> getConvertersToRegister() {
6669
converters.add(RowToStringConverter.INSTANCE);
6770
converters.add(RowToUuidConverter.INSTANCE);
6871

72+
converters.add(VectorToFloatArrayConverter.INSTANCE);
73+
converters.add(VectorToDoubleArrayConverter.INSTANCE);
74+
converters.add(VectorToFloatListConverter.INSTANCE);
75+
76+
converters.add(FloatArrayToVectorConverter.INSTANCE);
77+
converters.add(DoubleArrayToVectorConverter.INSTANCE);
78+
converters.add(NumberListToVectorConverter.INSTANCE);
79+
80+
converters.add(VectorToCqlVectorConverter.INSTANCE);
81+
converters.add(CqlVectorToVectorConverter.INSTANCE);
82+
6983
return converters;
7084
}
7185

@@ -222,4 +236,128 @@ public LocalDate convert(Row row) {
222236
return row.getLocalDate(0);
223237
}
224238
}
239+
240+
@ReadingConverter
241+
public enum DoubleArrayToVectorConverter implements Converter<double[], CqlVector<Double>> {
242+
243+
INSTANCE;
244+
245+
@Override
246+
public CqlVector<Double> convert(double[] source) {
247+
248+
Double[] converted = new Double[source.length];
249+
for (int i = 0; i < converted.length; i++) {
250+
converted[i] = source[i];
251+
}
252+
return CqlVector.newInstance(converted);
253+
}
254+
}
255+
256+
public enum CqlVectorToVectorConverter implements Converter<CqlVector<?>, Vector> {
257+
258+
INSTANCE;
259+
260+
@Override
261+
public Vector convert(CqlVector<?> source) {
262+
return CassandraVector.of(source);
263+
}
264+
}
265+
266+
public enum VectorToCqlVectorConverter implements Converter<Vector, CqlVector<?>> {
267+
268+
INSTANCE;
269+
270+
@Override
271+
public CqlVector<?> convert(Vector source) {
272+
273+
if (source instanceof CassandraVector cv) {
274+
return cv.getSource();
275+
}
276+
277+
if (source.getType() == Float.class || source.getType() == Float.TYPE) {
278+
279+
float[] floatArray = source.toFloatArray();
280+
List<Float> boxed = new ArrayList<>(floatArray.length);
281+
282+
for (float v : floatArray) {
283+
boxed.add(v);
284+
}
285+
return CqlVector.newInstance(boxed);
286+
}
287+
288+
return CqlVector.newInstance(Arrays.stream(source.toDoubleArray()).boxed().toList());
289+
}
290+
}
291+
292+
@ReadingConverter
293+
public enum FloatArrayToVectorConverter implements Converter<float[], CqlVector<Float>> {
294+
295+
INSTANCE;
296+
297+
@Override
298+
public CqlVector<Float> convert(float[] source) {
299+
300+
Float[] converted = new Float[source.length];
301+
for (int i = 0; i < converted.length; i++) {
302+
converted[i] = source[i];
303+
}
304+
return CqlVector.newInstance(converted);
305+
}
306+
}
307+
308+
@ReadingConverter
309+
public enum NumberListToVectorConverter implements Converter<List<Number>, CqlVector<Number>> {
310+
311+
INSTANCE;
312+
313+
@Override
314+
public CqlVector<Number> convert(List<Number> source) {
315+
return CqlVector.newInstance(source);
316+
}
317+
}
318+
319+
@ReadingConverter
320+
public enum VectorToFloatArrayConverter implements Converter<CqlVector<Number>, float[]> {
321+
322+
INSTANCE;
323+
324+
@Override
325+
public float[] convert(CqlVector<Number> source) {
326+
float[] array = new float[source.size()];
327+
for (int i = 0; i < array.length; i++) {
328+
array[i] = source.get(i).floatValue();
329+
}
330+
return array;
331+
}
332+
}
333+
334+
@ReadingConverter
335+
public enum VectorToDoubleArrayConverter implements Converter<CqlVector<Number>, double[]> {
336+
337+
INSTANCE;
338+
339+
@Override
340+
public double[] convert(CqlVector<Number> source) {
341+
double[] array = new double[source.size()];
342+
for (int i = 0; i < array.length; i++) {
343+
array[i] = source.get(i).doubleValue();
344+
}
345+
return array;
346+
}
347+
}
348+
349+
@ReadingConverter
350+
public enum VectorToFloatListConverter implements Converter<CqlVector<Number>, List<Float>> {
351+
352+
INSTANCE;
353+
354+
@Override
355+
public List<Float> convert(CqlVector<Number> source) {
356+
List<Float> values = new ArrayList<>(source.size());
357+
for (int i = 0; i < source.size(); i++) {
358+
values.add(source.get(i).floatValue());
359+
}
360+
return values;
361+
}
362+
}
225363
}

spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/convert/CassandraCustomConversions.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ static class CassandraConverterConfiguration extends ConverterConfiguration {
109109

110110
CassandraConverterConfiguration(List<?> converters) {
111111
super(STORE_CONVERSIONS, converters, getConverterFilter());
112-
113112
}
114113

115114
CassandraConverterConfiguration(List<?> userConverters, PropertyValueConversions propertyValueConversions) {

0 commit comments

Comments
 (0)