Skip to content

Commit 72aed1f

Browse files
committed
Add reactive search implementation.
1 parent 883a81d commit 72aed1f

File tree

6 files changed

+277
-10
lines changed

6 files changed

+277
-10
lines changed

spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/AbstractReactiveCassandraQuery.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import reactor.core.publisher.Mono;
1919

2020
import org.reactivestreams.Publisher;
21+
2122
import org.springframework.core.convert.converter.Converter;
2223
import org.springframework.data.cassandra.ReactiveResultSet;
2324
import org.springframework.data.cassandra.core.CassandraOperations;
@@ -26,6 +27,7 @@
2627
import org.springframework.data.cassandra.repository.query.ReactiveCassandraQueryExecution.ExistsExecution;
2728
import org.springframework.data.cassandra.repository.query.ReactiveCassandraQueryExecution.ResultProcessingConverter;
2829
import org.springframework.data.cassandra.repository.query.ReactiveCassandraQueryExecution.ResultProcessingExecution;
30+
import org.springframework.data.cassandra.repository.query.ReactiveCassandraQueryExecution.SearchExecution;
2931
import org.springframework.data.cassandra.repository.query.ReactiveCassandraQueryExecution.SingleEntityExecution;
3032
import org.springframework.data.cassandra.repository.query.ReactiveCassandraQueryExecution.SlicedExecution;
3133
import org.springframework.data.cassandra.repository.query.ReactiveCassandraQueryExecution.WindowExecution;
@@ -126,6 +128,8 @@ private ReactiveCassandraQueryExecution getExecutionToWrap(CassandraParameterAcc
126128
} else if (getQueryMethod().isScrollQuery()) {
127129
return new WindowExecution(getReactiveCassandraOperations(), parameterAccessor.getScrollPosition(),
128130
parameterAccessor.getLimit());
131+
} else if (getQueryMethod().isSearchQuery()) {
132+
return new SearchExecution(getReactiveCassandraOperations(), parameterAccessor);
129133
} else if (getQueryMethod().isCollectionQuery()) {
130134
return new CollectionExecution(getReactiveCassandraOperations());
131135
} else if (isCountQuery()) {

spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/CassandraQueryExecution.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.springframework.data.domain.ScoringFunction;
3434
import org.springframework.data.domain.SearchResult;
3535
import org.springframework.data.domain.SearchResults;
36+
import org.springframework.data.domain.Similarity;
3637
import org.springframework.data.domain.Slice;
3738
import org.springframework.data.domain.SliceImpl;
3839
import org.springframework.data.mapping.context.MappingContext;
@@ -204,8 +205,10 @@ public Object execute(Statement<?> statement, Class<?> type) {
204205
private Score getScore(Row row, String columnName, @Nullable ScoringFunction function) {
205206

206207
Object object = row.getObject(columnName);
207-
return Score.of(((Number) object).doubleValue(), function == null ? ScoringFunction.UNSPECIFIED : function);
208+
return Similarity.raw(((Number) object).doubleValue(),
209+
function == null ? ScoringFunction.unspecified() : function);
208210
}
211+
209212
}
210213

211214
/**

spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/QueryStatementCreator.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@
6161
class QueryStatementCreator {
6262

6363
private static final Map<ScoringFunction, SimilarityFunction> SIMILARITY_FUNCTIONS = Map.of(
64-
VectorScoringFunctions.COSINE, SimilarityFunction.COSINE, VectorScoringFunctions.EUCLIDEAN,
65-
SimilarityFunction.EUCLIDEAN, VectorScoringFunctions.DOT, SimilarityFunction.DOT_PRODUCT,
66-
VectorScoringFunctions.INNER_PRODUCT, SimilarityFunction.DOT_PRODUCT);
64+
VectorScoringFunctions.COSINE, SimilarityFunction.COSINE, //
65+
VectorScoringFunctions.EUCLIDEAN, SimilarityFunction.EUCLIDEAN, //
66+
VectorScoringFunctions.DOT_PRODUCT, SimilarityFunction.DOT_PRODUCT);
6767

6868
private static final Log LOG = LogFactory.getLog(QueryStatementCreator.class);
6969

@@ -148,7 +148,7 @@ private SimilarityFunction getSimilarityFunction(@Nullable ScoringFunction funct
148148

149149
if (function == null) {
150150
throw new IllegalStateException(
151-
"Cannot determine ScoringFunction. No Score or bounded Score Range parameters provided.");
151+
"Cannot determine ScoringFunction. No ScoringFunction, Score/Similarity or bounded Score Range parameters provided.");
152152
}
153153

154154
SimilarityFunction similarityFunction = SIMILARITY_FUNCTIONS.get(function);

spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/ReactiveCassandraQueryExecution.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020

2121
import java.util.List;
2222

23+
import org.jspecify.annotations.Nullable;
2324
import org.reactivestreams.Publisher;
25+
2426
import org.springframework.core.convert.converter.Converter;
2527
import org.springframework.dao.IncorrectResultSizeDataAccessException;
2628
import org.springframework.data.cassandra.core.ReactiveCassandraOperations;
@@ -31,6 +33,10 @@
3133
import org.springframework.data.convert.DtoInstantiatingConverter;
3234
import org.springframework.data.domain.Limit;
3335
import org.springframework.data.domain.Pageable;
36+
import org.springframework.data.domain.Score;
37+
import org.springframework.data.domain.ScoringFunction;
38+
import org.springframework.data.domain.SearchResult;
39+
import org.springframework.data.domain.Similarity;
3440
import org.springframework.data.domain.Slice;
3541
import org.springframework.data.domain.SliceImpl;
3642
import org.springframework.data.mapping.context.MappingContext;
@@ -152,6 +158,44 @@ public Publisher<? extends Object> execute(Statement<?> statement, Class<?> type
152158

153159
}
154160

161+
final class SearchExecution implements ReactiveCassandraQueryExecution {
162+
163+
private final ReactiveCassandraOperations operations;
164+
private final CassandraParameterAccessor accessor;
165+
166+
public SearchExecution(ReactiveCassandraOperations operations, CassandraParameterAccessor accessor) {
167+
168+
this.operations = operations;
169+
this.accessor = accessor;
170+
}
171+
172+
@Override
173+
public Publisher<? extends Object> execute(Statement<?> statement, Class<?> type) {
174+
175+
ScoringFunction function = accessor.getScoringFunction();
176+
177+
return operations.query(statement).as(type).map((row, reader) -> {
178+
179+
Object entity = reader.get();
180+
if (row.getColumnDefinitions().contains("__score__")) {
181+
return new SearchResult<>(entity, getScore(row, "__score__", function));
182+
}
183+
184+
if (row.getColumnDefinitions().contains("score")) {
185+
return new SearchResult<>(entity, getScore(row, "score", function));
186+
}
187+
return new SearchResult<>(entity, 0);
188+
}).all();
189+
}
190+
191+
private Score getScore(Row row, String columnName, @Nullable ScoringFunction function) {
192+
193+
Object object = row.getObject(columnName);
194+
return Similarity.raw(((Number) object).doubleValue(),
195+
function == null ? ScoringFunction.unspecified() : function);
196+
}
197+
}
198+
155199
/**
156200
* {@link ReactiveCassandraQueryExecution} to return a single entity.
157201
*
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
/*
2+
* Copyright 2016-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.data.cassandra.repository;
17+
18+
import static org.assertj.core.api.Assertions.*;
19+
20+
import reactor.core.publisher.Flux;
21+
import reactor.test.StepVerifier;
22+
23+
import java.util.Collections;
24+
import java.util.List;
25+
import java.util.Set;
26+
import java.util.UUID;
27+
28+
import org.junit.jupiter.api.BeforeEach;
29+
import org.junit.jupiter.api.Test;
30+
31+
import org.springframework.beans.factory.annotation.Autowired;
32+
import org.springframework.context.annotation.ComponentScan;
33+
import org.springframework.context.annotation.Configuration;
34+
import org.springframework.context.annotation.FilterType;
35+
import org.springframework.data.annotation.Id;
36+
import org.springframework.data.annotation.PersistenceCreator;
37+
import org.springframework.data.cassandra.config.SchemaAction;
38+
import org.springframework.data.cassandra.core.mapping.SaiIndexed;
39+
import org.springframework.data.cassandra.core.mapping.Table;
40+
import org.springframework.data.cassandra.core.mapping.VectorType;
41+
import org.springframework.data.cassandra.repository.config.EnableReactiveCassandraRepositories;
42+
import org.springframework.data.cassandra.repository.support.AbstractSpringDataEmbeddedCassandraIntegrationTest;
43+
import org.springframework.data.cassandra.repository.support.IntegrationTestConfig;
44+
import org.springframework.data.domain.Limit;
45+
import org.springframework.data.domain.ScoringFunction;
46+
import org.springframework.data.domain.SearchResult;
47+
import org.springframework.data.domain.Similarity;
48+
import org.springframework.data.domain.Vector;
49+
import org.springframework.data.domain.VectorScoringFunctions;
50+
import org.springframework.data.repository.reactive.ReactiveCrudRepository;
51+
import org.springframework.test.context.junit.jupiter.SpringJUnitConfig;
52+
53+
/**
54+
* Integration tests for Vector Search using reactive repositories.
55+
*
56+
* @author Mark Paluch
57+
*/
58+
@SpringJUnitConfig
59+
class ReactiveVectorSearchIntegrationTests extends AbstractSpringDataEmbeddedCassandraIntegrationTest {
60+
61+
Vector VECTOR = Vector.of(0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f);
62+
63+
@Configuration
64+
@EnableReactiveCassandraRepositories(basePackageClasses = ReactiveVectorSearchRepository.class,
65+
considerNestedRepositories = true,
66+
includeFilters = @ComponentScan.Filter(classes = ReactiveVectorSearchRepository.class,
67+
type = FilterType.ASSIGNABLE_TYPE))
68+
public static class Config extends IntegrationTestConfig {
69+
70+
@Override
71+
protected Set<Class<?>> getInitialEntitySet() {
72+
return Collections.singleton(WithVectorFields.class);
73+
}
74+
75+
@Override
76+
public SchemaAction getSchemaAction() {
77+
return SchemaAction.RECREATE_DROP_UNUSED;
78+
}
79+
}
80+
81+
@Autowired ReactiveVectorSearchRepository repository;
82+
83+
@BeforeEach
84+
void setUp() {
85+
86+
repository.deleteAll().as(StepVerifier::create).verifyComplete();
87+
88+
WithVectorFields w1 = new WithVectorFields("de", "one", Vector.of(0.1001f, 0.22345f, 0.33456f, 0.44567f, 0.55678f));
89+
WithVectorFields w2 = new WithVectorFields("de", "two", Vector.of(0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f));
90+
WithVectorFields w3 = new WithVectorFields("en", "three",
91+
Vector.of(0.9001f, 0.82345f, 0.73456f, 0.64567f, 0.55678f));
92+
WithVectorFields w4 = new WithVectorFields("de", "four",
93+
Vector.of(0.9001f, 0.92345f, 0.93456f, 0.94567f, 0.95678f));
94+
95+
repository.saveAll(List.of(w1, w2, w3, w4)).as(StepVerifier::create).expectNextCount(4).verifyComplete();
96+
}
97+
98+
@Test // GH-
99+
void shouldConsiderScoringFunction() {
100+
101+
Vector vector = Vector.of(0.9f, 0.54f, 0.12f, 0.1f, 0.95f);
102+
103+
List<SearchResult<WithVectorFields>> results = repository
104+
.searchByEmbeddingNear(vector, VectorScoringFunctions.COSINE, Limit.of(100)).collectList().block();
105+
106+
assertThat(results).hasSize(4);
107+
for (SearchResult<WithVectorFields> result : results) {
108+
assertThat(result.getScore()).isInstanceOf(Similarity.class);
109+
assertThat(result.getScore().getValue()).isNotCloseTo(0d, offset(0.1d));
110+
}
111+
112+
results = repository.searchByEmbeddingNear(VECTOR, VectorScoringFunctions.EUCLIDEAN, Limit.of(100)).collectList()
113+
.block();
114+
115+
assertThat(results).hasSize(4);
116+
for (SearchResult<WithVectorFields> result : results) {
117+
assertThat(result.getScore()).isInstanceOf(Similarity.class);
118+
assertThat(result.getScore().getValue()).isNotCloseTo(0.3d, offset(0.1d));
119+
}
120+
}
121+
122+
@Test // GH-
123+
void shouldRunAnnotatedSearchByVector() {
124+
125+
List<SearchResult<WithVectorFields>> results = repository.searchAnnotatedByEmbeddingNear(VECTOR, Limit.of(100))
126+
.collectList().block();
127+
128+
assertThat(results).hasSize(4);
129+
for (SearchResult<WithVectorFields> result : results) {
130+
assertThat(result.getScore()).isInstanceOf(Similarity.class);
131+
assertThat(result.getScore().getValue()).isNotCloseTo(0d, offset(0.1d));
132+
}
133+
}
134+
135+
@Test // GH-
136+
void shouldFindByVector() {
137+
138+
List<WithVectorFields> result = repository.findByEmbeddingNear(VECTOR, Limit.of(100)).collectList().block();
139+
140+
assertThat(result).hasSize(4);
141+
}
142+
143+
interface ReactiveVectorSearchRepository extends ReactiveCrudRepository<WithVectorFields, UUID> {
144+
145+
Flux<SearchResult<WithVectorFields>> searchByEmbeddingNear(Vector embedding, ScoringFunction function, Limit limit);
146+
147+
Flux<WithVectorFields> findByEmbeddingNear(Vector embedding, Limit limit);
148+
149+
@Query("SELECT id,description,country,similarity_cosine(embedding,:embedding) AS score FROM withvectorfields ORDER BY embedding ANN OF :embedding LIMIT :limit")
150+
Flux<SearchResult<WithVectorFields>> searchAnnotatedByEmbeddingNear(Vector embedding, Limit limit);
151+
152+
}
153+
154+
@Table
155+
static class WithVectorFields {
156+
157+
@Id String id;
158+
String country;
159+
String description;
160+
161+
@VectorType(dimensions = 5)
162+
@SaiIndexed Vector embedding;
163+
164+
@PersistenceCreator
165+
public WithVectorFields(String id, String country, String description, Vector embedding) {
166+
this.id = id;
167+
this.country = country;
168+
this.description = description;
169+
this.embedding = embedding;
170+
}
171+
172+
public WithVectorFields(String country, String description, Vector embedding) {
173+
this.id = UUID.randomUUID().toString();
174+
this.country = country;
175+
this.description = description;
176+
this.embedding = embedding;
177+
}
178+
179+
public String getId() {
180+
return id;
181+
}
182+
183+
public String getCountry() {
184+
return country;
185+
}
186+
187+
public String getDescription() {
188+
return description;
189+
}
190+
191+
public Vector getEmbedding() {
192+
return embedding;
193+
}
194+
195+
@Override
196+
public String toString() {
197+
return "WithVectorFields{" + "id='" + id + '\'' + ", country='" + country + '\'' + ", description='" + description
198+
+ '\'' + '}';
199+
}
200+
}
201+
202+
}

0 commit comments

Comments
 (0)