Skip to content

Commit 3e5da60

Browse files
committed
Explore returning Search Results.
1 parent 6f57f2f commit 3e5da60

File tree

3 files changed

+64
-74
lines changed

3 files changed

+64
-74
lines changed

spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/CassandraTemplate.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,8 @@ public <T> List<T> select(Statement<?> statement, Class<T> entityClass) {
354354
Assert.notNull(statement, "Statement must not be null");
355355
Assert.notNull(entityClass, "Entity type must not be null");
356356

357-
return doSelect(statement, entityClass, getTableName(entityClass), entityClass, QueryResultConverter.entity());
357+
return doSelect(statement, entityClass, EntityQueryUtils.getTableName(statement), entityClass,
358+
QueryResultConverter.entity());
358359
}
359360

360361
<T, R> List<R> doSelect(Statement<?> statement, Class<?> entityClass, CqlIdentifier tableName, Class<T> returnType,

spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/QueryStatementCreator.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,6 @@ SimpleStatement select(StatementFactory statementFactory, PartTree tree, Cassand
126126
LOG.debug(String.format("Created query [%s]", statement));
127127
}
128128

129-
System.out.println(statement.getQuery());
130-
131129
return statement;
132130
};
133131

spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/VectorSearchIntegrationTests.java

Lines changed: 62 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030
import org.springframework.context.annotation.Configuration;
3131
import org.springframework.context.annotation.FilterType;
3232
import org.springframework.data.annotation.Id;
33+
import org.springframework.data.annotation.PersistenceCreator;
3334
import org.springframework.data.cassandra.config.SchemaAction;
34-
import org.springframework.data.cassandra.core.mapping.Indexed;
3535
import org.springframework.data.cassandra.core.mapping.SaiIndexed;
3636
import org.springframework.data.cassandra.core.mapping.Table;
3737
import org.springframework.data.cassandra.core.mapping.VectorType;
@@ -55,14 +55,16 @@
5555
@SpringJUnitConfig
5656
class VectorSearchIntegrationTests extends AbstractSpringDataEmbeddedCassandraIntegrationTest {
5757

58+
Vector VECTOR = Vector.of(0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f);
59+
5860
@Configuration
59-
@EnableCassandraRepositories(basePackageClasses = CommentsRepository.class, considerNestedRepositories = true,
60-
includeFilters = @ComponentScan.Filter(classes = CommentsRepository.class, type = FilterType.ASSIGNABLE_TYPE))
61+
@EnableCassandraRepositories(basePackageClasses = VectorSearchRepository.class, considerNestedRepositories = true,
62+
includeFilters = @ComponentScan.Filter(classes = VectorSearchRepository.class, type = FilterType.ASSIGNABLE_TYPE))
6163
public static class Config extends IntegrationTestConfig {
6264

6365
@Override
6466
protected Set<Class<?>> getInitialEntitySet() {
65-
return Collections.singleton(Comments.class);
67+
return Collections.singleton(WithVectorFields.class);
6668
}
6769

6870
@Override
@@ -71,132 +73,121 @@ public SchemaAction getSchemaAction() {
7173
}
7274
}
7375

74-
@Autowired CommentsRepository repository;
76+
@Autowired VectorSearchRepository repository;
7577

7678
@BeforeEach
7779
void setUp() {
7880

7981
repository.deleteAll();
8082

81-
Comments one = new Comments();
82-
one.setId(UUID.randomUUID());
83-
one.setLanguage("en");
84-
one.setEmbedding(Vector.of(0.45f, 0.09f, 0.01f, 0.2f, 0.11f));
85-
one.setComment("Raining too hard should have postponed");
86-
87-
Comments two = new Comments();
88-
two.setId(UUID.randomUUID());
89-
two.setLanguage("en");
90-
two.setEmbedding(Vector.of(0.99f, 0.5f, -10.99f, -100.1f, 0.34f));
91-
two.setComment("Second rest stop was out of water");
92-
93-
Comments three = new Comments();
94-
three.setId(UUID.randomUUID());
95-
three.setLanguage("en");
96-
three.setEmbedding(Vector.of(0.9f, 0.54f, 0.12f, 0.1f, 0.95f));
97-
three.setComment("LATE RIDERS SHOULD NOT DELAY THE START");
98-
99-
repository.saveAll(List.of(one, two, three));
83+
WithVectorFields w1 = new WithVectorFields("de", "one", Vector.of(0.1001f, 0.22345f, 0.33456f, 0.44567f, 0.55678f));
84+
WithVectorFields w2 = new WithVectorFields("de", "two", Vector.of(0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f));
85+
WithVectorFields w3 = new WithVectorFields("en", "three",
86+
Vector.of(0.9001f, 0.82345f, 0.73456f, 0.64567f, 0.55678f));
87+
WithVectorFields w4 = new WithVectorFields("de", "four",
88+
Vector.of(0.9001f, 0.92345f, 0.93456f, 0.94567f, 0.95678f));
10089

90+
repository.saveAll(List.of(w1, w2, w3, w4));
10191
}
10292

10393
@Test // GH-
10494
void shouldConsiderScoringFunction() {
10595

10696
Vector vector = Vector.of(0.9f, 0.54f, 0.12f, 0.1f, 0.95f);
10797

108-
SearchResults<Comments> result = repository.searchByEmbeddingNear(vector,
98+
SearchResults<WithVectorFields> results = repository.searchByEmbeddingNear(vector,
10999
VectorScoringFunctions.COSINE, Limit.of(100));
110100

111-
assertThat(result).hasSize(3);
112-
for (SearchResult<Comments> commentSearch : result) {
113-
assertThat(commentSearch.getScore().getValue()).isNotCloseTo(0d, offset(0.1d));
101+
assertThat(results).hasSize(4);
102+
for (SearchResult<WithVectorFields> result : results) {
103+
assertThat(result.getScore().getValue()).isNotCloseTo(0d, offset(0.1d));
114104
}
115105

116-
result = repository.searchByEmbeddingNear(vector, VectorScoringFunctions.EUCLIDEAN, Limit.of(100));
106+
results = repository.searchByEmbeddingNear(VECTOR, VectorScoringFunctions.EUCLIDEAN, Limit.of(100));
117107

118-
assertThat(result).hasSize(3);
119-
for (SearchResult<Comments> commentSearch : result) {
120-
assertThat(commentSearch.getScore().getValue()).isNotCloseTo(0.3d, offset(0.1d));
108+
assertThat(results).hasSize(4);
109+
for (SearchResult<WithVectorFields> result : results) {
110+
assertThat(result.getScore().getValue()).isNotCloseTo(0.3d, offset(0.1d));
121111
}
122112
}
123113

124114
@Test // GH-
125115
void shouldRunAnnotatedSearchByVector() {
126116

127-
Vector vector = Vector.of(0.9f, 0.54f, 0.12f, 0.1f, 0.95f);
117+
SearchResults<WithVectorFields> results = repository.searchAnnotatedByEmbeddingNear(VECTOR, Limit.of(100));
128118

129-
SearchResults<Comments> result = repository.searchAnnotatedByEmbeddingNear(vector, Limit.of(100));
130119

131-
assertThat(result).hasSize(3);
132-
for (SearchResult<Comments> commentSearch : result) {
133-
assertThat(commentSearch.getScore().getValue()).isNotCloseTo(0d, offset(0.1d));
120+
assertThat(results).hasSize(4);
121+
for (SearchResult<WithVectorFields> result : results) {
122+
assertThat(result.getScore().getValue()).isNotCloseTo(0d, offset(0.1d));
134123
}
135124
}
136125

137126
@Test // GH-
138127
void shouldFindByVector() {
139128

140-
Vector vector = Vector.of(0.9f, 0.54f, 0.12f, 0.1f, 0.95f);
129+
List<WithVectorFields> result = repository.findByEmbeddingNear(VECTOR, Limit.of(100));
141130

142-
List<Comments> result = repository.findByEmbeddingNear(vector, Limit.of(100));
131+
assertThat(result).hasSize(4);
132+
}
133+
134+
interface VectorSearchRepository extends CrudRepository<WithVectorFields, UUID> {
135+
136+
SearchResults<WithVectorFields> searchByEmbeddingNear(Vector embedding, ScoringFunction function, Limit limit);
137+
138+
List<WithVectorFields> findByEmbeddingNear(Vector embedding, Limit limit);
139+
140+
@Query("SELECT id,description,country,similarity_cosine(embedding,:embedding) AS score FROM withvectorfields ORDER BY embedding ANN OF :embedding LIMIT :limit")
141+
SearchResults<WithVectorFields> searchAnnotatedByEmbeddingNear(Vector embedding, Limit limit);
143142

144-
assertThat(result).hasSize(3);
145143
}
146144

147145
@Table
148-
static class Comments {
149-
150-
@Id UUID id;
151-
String comment;
146+
static class WithVectorFields {
152147

153-
@Indexed String language;
148+
@Id String id;
149+
String country;
150+
String description;
154151

155152
@VectorType(dimensions = 5)
156153
@SaiIndexed Vector embedding;
157154

158-
public UUID getId() {
159-
return id;
160-
}
161-
162-
public void setId(UUID id) {
155+
@PersistenceCreator
156+
public WithVectorFields(String id, String country, String description, Vector embedding) {
163157
this.id = id;
158+
this.country = country;
159+
this.description = description;
160+
this.embedding = embedding;
164161
}
165162

166-
public String getLanguage() {
167-
return language;
163+
public WithVectorFields(String country, String description, Vector embedding) {
164+
this.id = UUID.randomUUID().toString();
165+
this.country = country;
166+
this.description = description;
167+
this.embedding = embedding;
168168
}
169169

170-
public void setLanguage(String language) {
171-
this.language = language;
170+
public String getId() {
171+
return id;
172172
}
173173

174-
public String getComment() {
175-
return comment;
174+
public String getCountry() {
175+
return country;
176176
}
177177

178-
public void setComment(String comment) {
179-
this.comment = comment;
178+
public String getDescription() {
179+
return description;
180180
}
181181

182182
public Vector getEmbedding() {
183183
return embedding;
184184
}
185185

186-
public void setEmbedding(Vector embedding) {
187-
this.embedding = embedding;
186+
@Override
187+
public String toString() {
188+
return "WithVectorFields{" + "id='" + id + '\'' + ", country='" + country + '\'' + ", description='" + description
189+
+ '\'' + '}';
188190
}
189191
}
190192

191-
interface CommentsRepository extends CrudRepository<Comments, UUID> {
192-
193-
SearchResults<Comments> searchByEmbeddingNear(Vector embedding, ScoringFunction function, Limit limit);
194-
195-
List<Comments> findByEmbeddingNear(Vector embedding, Limit limit);
196-
197-
@Query("SELECT id,comment,language,similarity_cosine(embedding,:embedding) AS score FROM comments ORDER BY embedding ANN OF :embedding LIMIT :limit")
198-
SearchResults<Comments> searchAnnotatedByEmbeddingNear(Vector embedding, Limit limit);
199-
200-
}
201-
202193
}

0 commit comments

Comments
 (0)