3030import org .springframework .context .annotation .Configuration ;
3131import org .springframework .context .annotation .FilterType ;
3232import org .springframework .data .annotation .Id ;
33+ import org .springframework .data .annotation .PersistenceCreator ;
3334import org .springframework .data .cassandra .config .SchemaAction ;
34- import org .springframework .data .cassandra .core .mapping .Indexed ;
3535import org .springframework .data .cassandra .core .mapping .SaiIndexed ;
3636import org .springframework .data .cassandra .core .mapping .Table ;
3737import org .springframework .data .cassandra .core .mapping .VectorType ;
5555@ SpringJUnitConfig
5656class VectorSearchIntegrationTests extends AbstractSpringDataEmbeddedCassandraIntegrationTest {
5757
58+ Vector VECTOR = Vector .of (0.2001f , 0.32345f , 0.43456f , 0.54567f , 0.65678f );
59+
5860 @ Configuration
59- @ EnableCassandraRepositories (basePackageClasses = CommentsRepository .class , considerNestedRepositories = true ,
60- includeFilters = @ ComponentScan .Filter (classes = CommentsRepository .class , type = FilterType .ASSIGNABLE_TYPE ))
61+ @ EnableCassandraRepositories (basePackageClasses = VectorSearchRepository .class , considerNestedRepositories = true ,
62+ includeFilters = @ ComponentScan .Filter (classes = VectorSearchRepository .class , type = FilterType .ASSIGNABLE_TYPE ))
6163 public static class Config extends IntegrationTestConfig {
6264
6365 @ Override
6466 protected Set <Class <?>> getInitialEntitySet () {
65- return Collections .singleton (Comments .class );
67+ return Collections .singleton (WithVectorFields .class );
6668 }
6769
6870 @ Override
@@ -71,132 +73,121 @@ public SchemaAction getSchemaAction() {
7173 }
7274 }
7375
74- @ Autowired CommentsRepository repository ;
76+ @ Autowired VectorSearchRepository repository ;
7577
7678 @ BeforeEach
7779 void setUp () {
7880
7981 repository .deleteAll ();
8082
81- Comments one = new Comments ();
82- one .setId (UUID .randomUUID ());
83- one .setLanguage ("en" );
84- one .setEmbedding (Vector .of (0.45f , 0.09f , 0.01f , 0.2f , 0.11f ));
85- one .setComment ("Raining too hard should have postponed" );
86-
87- Comments two = new Comments ();
88- two .setId (UUID .randomUUID ());
89- two .setLanguage ("en" );
90- two .setEmbedding (Vector .of (0.99f , 0.5f , -10.99f , -100.1f , 0.34f ));
91- two .setComment ("Second rest stop was out of water" );
92-
93- Comments three = new Comments ();
94- three .setId (UUID .randomUUID ());
95- three .setLanguage ("en" );
96- three .setEmbedding (Vector .of (0.9f , 0.54f , 0.12f , 0.1f , 0.95f ));
97- three .setComment ("LATE RIDERS SHOULD NOT DELAY THE START" );
98-
99- repository .saveAll (List .of (one , two , three ));
83+ WithVectorFields w1 = new WithVectorFields ("de" , "one" , Vector .of (0.1001f , 0.22345f , 0.33456f , 0.44567f , 0.55678f ));
84+ WithVectorFields w2 = new WithVectorFields ("de" , "two" , Vector .of (0.2001f , 0.32345f , 0.43456f , 0.54567f , 0.65678f ));
85+ WithVectorFields w3 = new WithVectorFields ("en" , "three" ,
86+ Vector .of (0.9001f , 0.82345f , 0.73456f , 0.64567f , 0.55678f ));
87+ WithVectorFields w4 = new WithVectorFields ("de" , "four" ,
88+ Vector .of (0.9001f , 0.92345f , 0.93456f , 0.94567f , 0.95678f ));
10089
90+ repository .saveAll (List .of (w1 , w2 , w3 , w4 ));
10191 }
10292
10393 @ Test // GH-
10494 void shouldConsiderScoringFunction () {
10595
10696 Vector vector = Vector .of (0.9f , 0.54f , 0.12f , 0.1f , 0.95f );
10797
108- SearchResults <Comments > result = repository .searchByEmbeddingNear (vector ,
98+ SearchResults <WithVectorFields > results = repository .searchByEmbeddingNear (vector ,
10999 VectorScoringFunctions .COSINE , Limit .of (100 ));
110100
111- assertThat (result ).hasSize (3 );
112- for (SearchResult <Comments > commentSearch : result ) {
113- assertThat (commentSearch .getScore ().getValue ()).isNotCloseTo (0d , offset (0.1d ));
101+ assertThat (results ).hasSize (4 );
102+ for (SearchResult <WithVectorFields > result : results ) {
103+ assertThat (result .getScore ().getValue ()).isNotCloseTo (0d , offset (0.1d ));
114104 }
115105
116- result = repository .searchByEmbeddingNear (vector , VectorScoringFunctions .EUCLIDEAN , Limit .of (100 ));
106+ results = repository .searchByEmbeddingNear (VECTOR , VectorScoringFunctions .EUCLIDEAN , Limit .of (100 ));
117107
118- assertThat (result ).hasSize (3 );
119- for (SearchResult <Comments > commentSearch : result ) {
120- assertThat (commentSearch .getScore ().getValue ()).isNotCloseTo (0.3d , offset (0.1d ));
108+ assertThat (results ).hasSize (4 );
109+ for (SearchResult <WithVectorFields > result : results ) {
110+ assertThat (result .getScore ().getValue ()).isNotCloseTo (0.3d , offset (0.1d ));
121111 }
122112 }
123113
124114 @ Test // GH-
125115 void shouldRunAnnotatedSearchByVector () {
126116
127- Vector vector = Vector . of ( 0.9f , 0.54f , 0.12f , 0.1f , 0.95f );
117+ SearchResults < WithVectorFields > results = repository . searchAnnotatedByEmbeddingNear ( VECTOR , Limit . of ( 100 ) );
128118
129- SearchResults <Comments > result = repository .searchAnnotatedByEmbeddingNear (vector , Limit .of (100 ));
130119
131- assertThat (result ).hasSize (3 );
132- for (SearchResult <Comments > commentSearch : result ) {
133- assertThat (commentSearch .getScore ().getValue ()).isNotCloseTo (0d , offset (0.1d ));
120+ assertThat (results ).hasSize (4 );
121+ for (SearchResult <WithVectorFields > result : results ) {
122+ assertThat (result .getScore ().getValue ()).isNotCloseTo (0d , offset (0.1d ));
134123 }
135124 }
136125
137126 @ Test // GH-
138127 void shouldFindByVector () {
139128
140- Vector vector = Vector . of ( 0.9f , 0.54f , 0.12f , 0.1f , 0.95f );
129+ List < WithVectorFields > result = repository . findByEmbeddingNear ( VECTOR , Limit . of ( 100 ) );
141130
142- List <Comments > result = repository .findByEmbeddingNear (vector , Limit .of (100 ));
131+ assertThat (result ).hasSize (4 );
132+ }
133+
134+ interface VectorSearchRepository extends CrudRepository <WithVectorFields , UUID > {
135+
136+ SearchResults <WithVectorFields > searchByEmbeddingNear (Vector embedding , ScoringFunction function , Limit limit );
137+
138+ List <WithVectorFields > findByEmbeddingNear (Vector embedding , Limit limit );
139+
140+ @ Query ("SELECT id,description,country,similarity_cosine(embedding,:embedding) AS score FROM withvectorfields ORDER BY embedding ANN OF :embedding LIMIT :limit" )
141+ SearchResults <WithVectorFields > searchAnnotatedByEmbeddingNear (Vector embedding , Limit limit );
143142
144- assertThat (result ).hasSize (3 );
145143 }
146144
147145 @ Table
148- static class Comments {
149-
150- @ Id UUID id ;
151- String comment ;
146+ static class WithVectorFields {
152147
153- @ Indexed String language ;
148+ @ Id String id ;
149+ String country ;
150+ String description ;
154151
155152 @ VectorType (dimensions = 5 )
156153 @ SaiIndexed Vector embedding ;
157154
158- public UUID getId () {
159- return id ;
160- }
161-
162- public void setId (UUID id ) {
155+ @ PersistenceCreator
156+ public WithVectorFields (String id , String country , String description , Vector embedding ) {
163157 this .id = id ;
158+ this .country = country ;
159+ this .description = description ;
160+ this .embedding = embedding ;
164161 }
165162
166- public String getLanguage () {
167- return language ;
163+ public WithVectorFields (String country , String description , Vector embedding ) {
164+ this .id = UUID .randomUUID ().toString ();
165+ this .country = country ;
166+ this .description = description ;
167+ this .embedding = embedding ;
168168 }
169169
170- public void setLanguage ( String language ) {
171- this . language = language ;
170+ public String getId ( ) {
171+ return id ;
172172 }
173173
174- public String getComment () {
175- return comment ;
174+ public String getCountry () {
175+ return country ;
176176 }
177177
178- public void setComment ( String comment ) {
179- this . comment = comment ;
178+ public String getDescription ( ) {
179+ return description ;
180180 }
181181
182182 public Vector getEmbedding () {
183183 return embedding ;
184184 }
185185
186- public void setEmbedding (Vector embedding ) {
187- this .embedding = embedding ;
186+ @ Override
187+ public String toString () {
188+ return "WithVectorFields{" + "id='" + id + '\'' + ", country='" + country + '\'' + ", description='" + description
189+ + '\'' + '}' ;
188190 }
189191 }
190192
191- interface CommentsRepository extends CrudRepository <Comments , UUID > {
192-
193- SearchResults <Comments > searchByEmbeddingNear (Vector embedding , ScoringFunction function , Limit limit );
194-
195- List <Comments > findByEmbeddingNear (Vector embedding , Limit limit );
196-
197- @ Query ("SELECT id,comment,language,similarity_cosine(embedding,:embedding) AS score FROM comments ORDER BY embedding ANN OF :embedding LIMIT :limit" )
198- SearchResults <Comments > searchAnnotatedByEmbeddingNear (Vector embedding , Limit limit );
199-
200- }
201-
202193}
0 commit comments