Skip to content

Commit 3b15ff9

Browse files
DOC-5557 added local copies of source files
1 parent 14c09c8 commit 3b15ff9

File tree

2 files changed

+682
-0
lines changed

2 files changed

+682
-0
lines changed
Lines changed: 356 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,356 @@
1+
// EXAMPLE: home_query_vec
2+
package io.redis.examples.async;
3+
4+
// STEP_START import
5+
// Lettuce client and query engine classes.
6+
import io.lettuce.core.*;
7+
import io.lettuce.core.api.StatefulRedisConnection;
8+
import io.lettuce.core.api.async.RedisAsyncCommands;
9+
import io.lettuce.core.search.arguments.*;
10+
import io.lettuce.core.search.SearchReply;
11+
import io.lettuce.core.json.JsonParser;
12+
import io.lettuce.core.json.JsonObject;
13+
import io.lettuce.core.json.JsonPath;
14+
15+
// Standard library classes for data manipulation and
16+
// asynchronous programming.
17+
import java.nio.ByteBuffer;
18+
import java.nio.ByteOrder;
19+
import java.nio.charset.StandardCharsets;
20+
import java.util.*;
21+
import java.util.concurrent.CompletableFuture;
22+
23+
// DJL classes for model loading and inference.
24+
import ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory;
25+
import ai.djl.inference.Predictor;
26+
import ai.djl.repository.zoo.Criteria;
27+
import ai.djl.training.util.ProgressBar;
28+
// STEP_END
29+
// REMOVE_START
30+
import org.junit.jupiter.api.Test;
31+
import static org.assertj.core.api.Assertions.assertThat;
32+
// REMOVE_END
33+
34+
public class HomeQueryVecExample {
35+
36+
// STEP_START helper_method
37+
private ByteBuffer floatArrayToByteBuffer(float[] vector) {
38+
ByteBuffer buffer = ByteBuffer.allocate(vector.length * 4).order(ByteOrder.LITTLE_ENDIAN);
39+
for (float value : vector) {
40+
buffer.putFloat(value);
41+
}
42+
return (ByteBuffer) buffer.flip();
43+
}
44+
// STEP_END
45+
46+
// REMOVE_START
47+
@Test
48+
// REMOVE_END
49+
public void run() {
50+
// STEP_START model
51+
Predictor<String, float[]> predictor = null;
52+
53+
try {
54+
Criteria<String, float[]> criteria = Criteria.builder().setTypes(String.class, float[].class)
55+
.optModelUrls("djl://ai.djl.huggingface.pytorch/sentence-transformers/all-MiniLM-L6-v2")
56+
.optEngine("PyTorch").optTranslatorFactory(new TextEmbeddingTranslatorFactory())
57+
.optProgress(new ProgressBar()).build();
58+
59+
predictor = criteria.loadModel().newPredictor();
60+
} catch (Exception e) {
61+
// ...
62+
}
63+
// STEP_END
64+
65+
// STEP_START connect
66+
RedisClient redisClient = RedisClient.create("redis://localhost:6379");
67+
68+
try (StatefulRedisConnection<String, String> connection = redisClient.connect();
69+
StatefulRedisConnection<ByteBuffer, ByteBuffer> binConnection = redisClient.connect(new ByteBufferCodec())) {
70+
RedisAsyncCommands<String, String> asyncCommands = connection.async();
71+
RedisAsyncCommands<ByteBuffer, ByteBuffer> binAsyncCommands = binConnection.async();
72+
// ...
73+
// STEP_END
74+
// REMOVE_START
75+
asyncCommands.del("doc:1", "doc:2", "doc:3", "jdoc:1", "jdoc:2", "jdoc:3").toCompletableFuture().join();
76+
77+
asyncCommands.ftDropindex("vector_idx").exceptionally(ex -> null) // Ignore errors if the index doesn't exist.
78+
.toCompletableFuture().join();
79+
80+
asyncCommands.ftDropindex("vector_json_idx").exceptionally(ex -> null) // Ignore errors if the index doesn't exist.
81+
.toCompletableFuture().join();
82+
// REMOVE_END
83+
84+
// STEP_START create_index
85+
List<FieldArgs<String>> schema = Arrays.asList(TextFieldArgs.<String> builder().name("content").build(),
86+
TagFieldArgs.<String> builder().name("genre").build(),
87+
VectorFieldArgs.<String> builder().name("embedding").hnsw().type(VectorFieldArgs.VectorType.FLOAT32)
88+
.dimensions(384).distanceMetric(VectorFieldArgs.DistanceMetric.L2).build());
89+
90+
CreateArgs<String, String> createArgs = CreateArgs.<String, String> builder().on(CreateArgs.TargetType.HASH)
91+
.withPrefix("doc:").build();
92+
93+
CompletableFuture<Void> createIndex = asyncCommands.ftCreate("vector_idx", createArgs, schema)
94+
// REMOVE_START
95+
.thenApply(result -> {
96+
System.out.println(result); // >>> OK
97+
98+
assertThat(result).isEqualTo("OK");
99+
return result;
100+
})
101+
// REMOVE_END
102+
.thenAccept(System.out::println).toCompletableFuture();
103+
// STEP_END
104+
createIndex.join();
105+
106+
// STEP_START add_data
107+
String sentence1 = "That is a very happy person";
108+
109+
Map<ByteBuffer, ByteBuffer> doc1 = new HashMap<>();
110+
doc1.put(ByteBuffer.wrap("content".getBytes()), ByteBuffer.wrap(sentence1.getBytes()));
111+
doc1.put(ByteBuffer.wrap("genre".getBytes()), ByteBuffer.wrap("persons".getBytes()));
112+
113+
try {
114+
doc1.put(ByteBuffer.wrap("embedding".getBytes()), floatArrayToByteBuffer(predictor.predict(sentence1)));
115+
} catch (Exception e) {
116+
// ...
117+
}
118+
119+
CompletableFuture<Long> addDoc1 = binAsyncCommands.hset(ByteBuffer.wrap("doc:1".getBytes()), doc1)
120+
.thenApply(result -> {
121+
// REMOVE_START
122+
assertThat(result).isEqualTo(3L);
123+
// REMOVE_END
124+
System.out.println(result); // >>> 3
125+
return result;
126+
}).toCompletableFuture();
127+
128+
String sentence2 = "That is a happy dog";
129+
130+
Map<ByteBuffer, ByteBuffer> doc2 = new HashMap<>();
131+
doc2.put(ByteBuffer.wrap("content".getBytes()), ByteBuffer.wrap(sentence2.getBytes()));
132+
doc2.put(ByteBuffer.wrap("genre".getBytes()), ByteBuffer.wrap("pets".getBytes()));
133+
134+
try {
135+
doc2.put(ByteBuffer.wrap("embedding".getBytes()), floatArrayToByteBuffer(predictor.predict(sentence2)));
136+
} catch (Exception e) {
137+
// ...
138+
}
139+
140+
CompletableFuture<Long> addDoc2 = binAsyncCommands.hset(ByteBuffer.wrap("doc:2".getBytes()), doc2)
141+
.thenApply(result -> {
142+
// REMOVE_START
143+
assertThat(result).isEqualTo(3L);
144+
// REMOVE_END
145+
System.out.println(result); // >>> 3
146+
return result;
147+
}).toCompletableFuture();
148+
149+
String sentence3 = "Today is a sunny day";
150+
151+
Map<ByteBuffer, ByteBuffer> doc3 = new HashMap<>();
152+
doc3.put(ByteBuffer.wrap("content".getBytes()), ByteBuffer.wrap(sentence3.getBytes()));
153+
doc3.put(ByteBuffer.wrap("genre".getBytes()), ByteBuffer.wrap("weather".getBytes()));
154+
155+
try {
156+
doc3.put(ByteBuffer.wrap("embedding".getBytes()), floatArrayToByteBuffer(predictor.predict(sentence3)));
157+
} catch (Exception e) {
158+
// ...
159+
}
160+
161+
CompletableFuture<Long> addDoc3 = binAsyncCommands.hset(ByteBuffer.wrap("doc:3".getBytes()), doc3)
162+
.thenApply(result -> {
163+
// REMOVE_START
164+
assertThat(result).isEqualTo(3L);
165+
// REMOVE_END
166+
System.out.println(result); // >>> 3
167+
return result;
168+
}).toCompletableFuture();
169+
// STEP_END
170+
CompletableFuture.allOf(addDoc1, addDoc2, addDoc3).join();
171+
172+
// STEP_START query
173+
String query = "That is a happy person";
174+
float[] queryEmbedding = null;
175+
176+
try {
177+
queryEmbedding = predictor.predict(query);
178+
} catch (Exception e) {
179+
// ...
180+
}
181+
182+
SearchArgs<ByteBuffer, ByteBuffer> searchArgs = SearchArgs.<ByteBuffer, ByteBuffer> builder()
183+
.param(ByteBuffer.wrap("vec".getBytes()), floatArrayToByteBuffer(queryEmbedding))
184+
.returnField(ByteBuffer.wrap("content".getBytes()))
185+
.returnField(ByteBuffer.wrap("vector_distance".getBytes()))
186+
.sortBy(SortByArgs.<ByteBuffer> builder().attribute(ByteBuffer.wrap("vector_distance".getBytes())).build())
187+
.build();
188+
189+
CompletableFuture<SearchReply<ByteBuffer, ByteBuffer>> hashQuery = binAsyncCommands
190+
.ftSearch(ByteBuffer.wrap("vector_idx".getBytes()),
191+
ByteBuffer.wrap("*=>[KNN 3 @embedding $vec AS vector_distance]".getBytes()), searchArgs)
192+
.thenApply(result -> {
193+
List<SearchReply.SearchResult<ByteBuffer, ByteBuffer>> results = result.getResults();
194+
195+
results.forEach(r -> {
196+
String id = StandardCharsets.UTF_8.decode(r.getId()).toString();
197+
String content = StandardCharsets.UTF_8
198+
.decode(r.getFields().get(ByteBuffer.wrap("content".getBytes()))).toString();
199+
String distance = StandardCharsets.UTF_8
200+
.decode(r.getFields().get(ByteBuffer.wrap("vector_distance".getBytes()))).toString();
201+
202+
System.out.println("ID: " + id + ", Content: " + content + ", Distance: " + distance);
203+
});
204+
// >>> ID: doc:1, Content: That is a very happy person, Distance: 0.114169836044
205+
// >>> ID: doc:2, Content: That is a happy dog, Distance: 0.610845506191
206+
// >>> ID: doc:3, Content: Today is a sunny day, Distance: 1.48624765873
207+
208+
// REMOVE_START
209+
assertThat(result.getCount()).isEqualTo(3);
210+
// REMOVE_END
211+
return result;
212+
}).toCompletableFuture();
213+
// STEP_END
214+
hashQuery.join();
215+
216+
// STEP_START json_schema
217+
List<FieldArgs<String>> jsonSchema = Arrays.asList(
218+
TextFieldArgs.<String> builder().name("$.content").as("content").build(),
219+
TagFieldArgs.<String> builder().name("$.genre").as("genre").build(),
220+
VectorFieldArgs.<String> builder().name("$.embedding").as("embedding").hnsw()
221+
.type(VectorFieldArgs.VectorType.FLOAT32).dimensions(384)
222+
.distanceMetric(VectorFieldArgs.DistanceMetric.L2).build());
223+
224+
CreateArgs<String, String> jsonCreateArgs = CreateArgs.<String, String> builder().on(CreateArgs.TargetType.JSON)
225+
.withPrefix("jdoc:").build();
226+
227+
CompletableFuture<Void> jsonCreateIndex = asyncCommands.ftCreate("vector_json_idx", jsonCreateArgs, jsonSchema)
228+
// REMOVE_START
229+
.thenApply(result -> {
230+
System.out.println(result); // >>> OK
231+
232+
assertThat(result).isEqualTo("OK");
233+
return result;
234+
})
235+
// REMOVE_END
236+
.thenAccept(System.out::println).toCompletableFuture();
237+
// STEP_END
238+
jsonCreateIndex.join();
239+
240+
// STEP_START json_data
241+
JsonParser parser = asyncCommands.getJsonParser();
242+
243+
String jSentence1 = "\"That is a very happy person\"";
244+
245+
JsonObject jDoc1 = parser.createJsonObject();
246+
jDoc1.put("content", parser.createJsonValue(jSentence1));
247+
jDoc1.put("genre", parser.createJsonValue("\"persons\""));
248+
249+
try {
250+
jDoc1.put("embedding", parser.createJsonValue(Arrays.toString(predictor.predict(jSentence1))));
251+
} catch (Exception e) {
252+
// ...
253+
}
254+
255+
CompletableFuture<String> jsonAddDoc1 = asyncCommands.jsonSet("jdoc:1", JsonPath.ROOT_PATH, jDoc1)
256+
.thenApply(result -> {
257+
// REMOVE_START
258+
assertThat(result).isEqualTo("OK");
259+
// REMOVE_END
260+
System.out.println(result); // >>> OK
261+
return result;
262+
}).toCompletableFuture();
263+
264+
String jSentence2 = "\"That is a happy dog\"";
265+
266+
JsonObject jDoc2 = parser.createJsonObject();
267+
jDoc2.put("content", parser.createJsonValue(jSentence2));
268+
jDoc2.put("genre", parser.createJsonValue("\"pets\""));
269+
270+
try {
271+
jDoc2.put("embedding", parser.createJsonValue(Arrays.toString(predictor.predict(jSentence2))));
272+
} catch (Exception e) {
273+
// ...
274+
}
275+
276+
CompletableFuture<String> jsonAddDoc2 = asyncCommands.jsonSet("jdoc:2", JsonPath.ROOT_PATH, jDoc2)
277+
.thenApply(result -> {
278+
// REMOVE_START
279+
assertThat(result).isEqualTo("OK");
280+
// REMOVE_END
281+
System.out.println(result); // >>> OK
282+
return result;
283+
}).toCompletableFuture();
284+
285+
String jSentence3 = "\"Today is a sunny day\"";
286+
287+
JsonObject jDoc3 = parser.createJsonObject();
288+
jDoc3.put("content", parser.createJsonValue(jSentence3));
289+
jDoc3.put("genre", parser.createJsonValue("\"weather\""));
290+
291+
try {
292+
jDoc3.put("embedding", parser.createJsonValue(Arrays.toString(predictor.predict(jSentence3))));
293+
} catch (Exception e) {
294+
// ...
295+
}
296+
297+
CompletableFuture<String> jsonAddDoc3 = asyncCommands.jsonSet("jdoc:3", JsonPath.ROOT_PATH, jDoc3)
298+
.thenApply(result -> {
299+
// REMOVE_START
300+
assertThat(result).isEqualTo("OK");
301+
// REMOVE_END
302+
System.out.println(result); // >>> OK
303+
return result;
304+
}).toCompletableFuture();
305+
// STEP_END
306+
CompletableFuture.allOf(jsonAddDoc1, jsonAddDoc2, jsonAddDoc3).join();
307+
308+
// STEP_START json_query
309+
String jQuery = "That is a happy person";
310+
float[] jsonQueryEmbedding = null;
311+
312+
try {
313+
jsonQueryEmbedding = predictor.predict(jQuery);
314+
} catch (Exception e) {
315+
// ...
316+
}
317+
318+
SearchArgs<ByteBuffer, ByteBuffer> jsonSearchArgs = SearchArgs.<ByteBuffer, ByteBuffer> builder()
319+
.param(ByteBuffer.wrap("vec".getBytes()), floatArrayToByteBuffer(jsonQueryEmbedding))
320+
.returnField(ByteBuffer.wrap("content".getBytes()))
321+
.returnField(ByteBuffer.wrap("vector_distance".getBytes()))
322+
.sortBy(SortByArgs.<ByteBuffer> builder().attribute(ByteBuffer.wrap("vector_distance".getBytes())).build())
323+
.build();
324+
325+
CompletableFuture<SearchReply<ByteBuffer, ByteBuffer>> jsonQuery = binAsyncCommands
326+
.ftSearch(ByteBuffer.wrap("vector_json_idx".getBytes()),
327+
ByteBuffer.wrap("*=>[KNN 3 @embedding $vec AS vector_distance]".getBytes()), jsonSearchArgs)
328+
.thenApply(result -> {
329+
List<SearchReply.SearchResult<ByteBuffer, ByteBuffer>> results = result.getResults();
330+
331+
results.forEach(r -> {
332+
String id = StandardCharsets.UTF_8.decode(r.getId()).toString();
333+
String content = StandardCharsets.UTF_8
334+
.decode(r.getFields().get(ByteBuffer.wrap("content".getBytes()))).toString();
335+
String distance = StandardCharsets.UTF_8
336+
.decode(r.getFields().get(ByteBuffer.wrap("vector_distance".getBytes()))).toString();
337+
338+
System.out.println("ID: " + id + ", Content: " + content + ", Distance: " + distance);
339+
});
340+
// >>> ID: jdoc:1, Content: "That is a very happy person", Distance:0.628328084946
341+
// >>> ID: jdoc:2, Content: That is a happy dog, Distance: 0.895147025585
342+
// >>> ID: jdoc:3, Content: "Today is a sunny day", Distance: 1.49569523335
343+
344+
// REMOVE_START
345+
assertThat(result.getCount()).isEqualTo(3);
346+
// REMOVE_END
347+
return result;
348+
}).toCompletableFuture();
349+
// STEP_END
350+
jsonQuery.join();
351+
} finally {
352+
redisClient.shutdown();
353+
}
354+
}
355+
356+
}

0 commit comments

Comments
 (0)