55import static java .util .Collections .singletonList ;
66import static java .util .stream .Collectors .toList ;
77
8- import java .nio .ByteBuffer ;
9- import java .nio .ByteOrder ;
108import java .util .HashMap ;
119import java .util .List ;
1210import java .util .Map ;
1311import java .util .Set ;
1412import java .util .stream .Collectors ;
15- import java .util .stream .StreamSupport ;
1613
1714import org .jboss .logging .Logger ;
1815
3027import io .quarkus .redis .datasource .ReactiveRedisDataSource ;
3128import io .quarkus .redis .datasource .json .ReactiveJsonCommands ;
3229import io .quarkus .redis .datasource .keys .KeyScanArgs ;
30+ import io .quarkus .redis .datasource .search .CreateArgs ;
31+ import io .quarkus .redis .datasource .search .Document ;
32+ import io .quarkus .redis .datasource .search .QueryArgs ;
33+ import io .quarkus .redis .datasource .search .SearchQueryResponse ;
3334import io .smallrye .mutiny .Uni ;
3435import io .vertx .mutiny .redis .client .Command ;
3536import io .vertx .mutiny .redis .client .Request ;
36- import io .vertx .mutiny .redis .client .Response ;
3737
3838public class RedisEmbeddingStore implements EmbeddingStore <TextSegment > {
3939
@@ -65,19 +65,12 @@ private void createIndexIfDoesNotExist() {
6565 }
6666 }).await ().indefinitely ();
6767 if (!indexes .contains (schema .getIndexName ())) {
68- // TODO: rewrite to use the typesafe data source API
69- Request request = Request .cmd (Command .FT_CREATE )
70- .arg (schema .getIndexName ())
71- .arg ("ON" )
72- .arg ("JSON" )
73- .arg ("PREFIX" )
74- .arg ("1" )
75- .arg (schema .getPrefix ())
76- .arg ("SCHEMA" );
77- schema .defineFields (request );
78- LOG .debug (
79- "Creating index with command: " + request .toString ().replaceAll ("\r \n " , " " ));
80- ds .getRedis ().send (request ).await ().indefinitely ();
68+ CreateArgs indexCreateArgs = new CreateArgs ()
69+ .onJson ()
70+ .prefixes (schema .getPrefix ());
71+ schema .defineFields (indexCreateArgs );
72+ LOG .debug ("Creating Redis index " + schema .getIndexName ());
73+ ds .search ().ftCreate (schema .getIndexName (), indexCreateArgs ).await ().indefinitely ();
8174 } else {
8275 LOG .debug ("Index in Redis already exists: " + schema .getIndexName ());
8376 }
@@ -152,63 +145,28 @@ public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbeddi
152145 double minScore ) {
153146 String queryTemplate = "*=>[ KNN %d @%s $BLOB AS %s ]" ;
154147 String query = format (queryTemplate , maxResults , schema .getVectorFieldName (), SCORE_FIELD_NAME );
155- // TODO: rewrite to the data source api, but we need a new
156- // method QueryArgs.param(String, byte[]) to get it working
157-
158- // QueryArgs args = new QueryArgs()
159- // .sortByAscending(SCORE_FIELD_NAME)
160- // .param("DIALECT", "2")
161- // .param("BLOB", toByteArray(referenceEmbedding.vector()));
162- // Uni<SearchQueryResponse> search = ds.search()
163- // .ftSearch(schema.getIndexName(), query, args);
164- // SearchQueryResponse response = search.await().indefinitely();
165- Request request = Request .cmd (Command .FT_SEARCH )
166- .arg (schema .getIndexName ())
167- .arg (query )
168- .arg ("PARAMS" )
169- .arg ("2" )
170- .arg ("BLOB" )
171- .arg (toByteArray (referenceEmbedding .vector ()))
172- .arg ("DIALECT" )
173- .arg ("2" );
174- Response response = ds .getRedis ().send (request ).await ().indefinitely ();
175- return StreamSupport .stream (response .get ("results" ).spliterator (), false )
176- .map (this ::toEmbeddingMatch )
148+ QueryArgs args = new QueryArgs ()
149+ .sortByAscending (SCORE_FIELD_NAME )
150+ .param ("DIALECT" , "2" )
151+ .param ("BLOB" , referenceEmbedding .vector ());
152+ Uni <SearchQueryResponse > search = ds .search ()
153+ .ftSearch (schema .getIndexName (), query , args );
154+ SearchQueryResponse response = search .await ().indefinitely ();
155+ return response .documents ().stream ().map (this ::extractEmbeddingMatch )
177156 .filter (embeddingMatch -> embeddingMatch .score () >= minScore )
178157 .collect (toList ());
179158 }
180159
181- /**
182- * Deletes all keys with the prefix that is used by this embedding store.
183- */
184- public void deleteAll () {
185- KeyScanArgs args = new KeyScanArgs ().match (schema .getPrefix () + "*" );
186- Set <String > keysToDelete = ds .key ().scan (args ).toMulti ().collect ().asSet ().await ().indefinitely ();
187- if (!keysToDelete .isEmpty ()) {
188- Request command = Request .cmd (Command .DEL );
189- keysToDelete .forEach (command ::arg );
190- ds .getRedis ().send (command ).await ().indefinitely ();
191- LOG .debug ("Deleted " + keysToDelete .size () + " keys" );
192- }
193- }
194-
195- public static byte [] toByteArray (float [] input ) {
196- byte [] bytes = new byte [Float .BYTES * input .length ];
197- ByteBuffer .wrap (bytes ).order (ByteOrder .LITTLE_ENDIAN ).asFloatBuffer ().put (input );
198- return bytes ;
199- }
200-
201- private EmbeddingMatch <TextSegment > toEmbeddingMatch (Response response ) {
202- String document = response .get (EXTRA_ATTRIBUTES ).get ("$" ).toString ();
160+ private EmbeddingMatch <TextSegment > extractEmbeddingMatch (Document document ) {
203161 try {
204- JsonNode jsonNode = QuarkusJsonCodecFactory .ObjectMapperHolder .MAPPER .readTree (document );
162+ JsonNode jsonNode = QuarkusJsonCodecFactory .ObjectMapperHolder .MAPPER
163+ .readTree (document .property ("$" ).asString ());
205164 JsonNode embedded = jsonNode .get (schema .getScalarFieldName ());
206165 Embedding embedding = new Embedding (
207166 Json .fromJson (jsonNode .get (schema .getVectorFieldName ()).toString (), float [].class ));
208- double score = (2 - response .get (EXTRA_ATTRIBUTES ).get (SCORE_FIELD_NAME ).toDouble ()) / 2 ;
209- String id = response .get (ID ).toString ().substring (schema .getPrefix ().length ());
210- List <String > metadataFields = schema .getMetadataFields ();
211- Map <String , String > metadata = metadataFields .stream ()
167+ double score = (2 - document .property (SCORE_FIELD_NAME ).asDouble ()) / 2 ;
168+ String id = document .key ().substring (schema .getPrefix ().length ());
169+ Map <String , String > metadata = schema .getMetadataFields ().stream ()
212170 .filter (jsonNode ::has )
213171 .collect (Collectors .toMap (metadataFieldName -> metadataFieldName ,
214172 (name ) -> jsonNode .get (name ).asText ()));
@@ -220,6 +178,20 @@ private EmbeddingMatch<TextSegment> toEmbeddingMatch(Response response) {
220178
221179 }
222180
181+ /**
182+ * Deletes all keys with the prefix that is used by this embedding store.
183+ */
184+ public void deleteAll () {
185+ KeyScanArgs args = new KeyScanArgs ().match (schema .getPrefix () + "*" );
186+ Set <String > keysToDelete = ds .key ().scan (args ).toMulti ().collect ().asSet ().await ().indefinitely ();
187+ if (!keysToDelete .isEmpty ()) {
188+ Request command = Request .cmd (Command .DEL );
189+ keysToDelete .forEach (command ::arg );
190+ ds .getRedis ().send (command ).await ().indefinitely ();
191+ LOG .debug ("Deleted " + keysToDelete .size () + " keys" );
192+ }
193+ }
194+
223195 public static class Builder {
224196
225197 private ReactiveRedisDataSource redisClient ;
0 commit comments