Skip to content

Commit 2daac44

Browse files
authored
Merge pull request #16 from jmartisk/redis-vectors
Make use of new stuff in the Redis DataSource API
2 parents a4f9b57 + bcbcc64 commit 2daac44

File tree

7 files changed

+69
-139
lines changed

7 files changed

+69
-139
lines changed

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
<maven.compiler.release>17</maven.compiler.release>
3434
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
3535
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
36-
<quarkus.version>3.5.3</quarkus.version>
36+
<quarkus.version>3.6.0</quarkus.version>
3737
<langchain4j.version>0.24.0</langchain4j.version>
3838
<quarkus-poi.version>2.0.4</quarkus-poi.version>
3939
<assertj.version>3.24.2</assertj.version>

redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/MetricType.java

Lines changed: 0 additions & 22 deletions
This file was deleted.

redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/RedisEmbeddingStore.java

Lines changed: 38 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,11 @@
55
import static java.util.Collections.singletonList;
66
import static java.util.stream.Collectors.toList;
77

8-
import java.nio.ByteBuffer;
9-
import java.nio.ByteOrder;
108
import java.util.HashMap;
119
import java.util.List;
1210
import java.util.Map;
1311
import java.util.Set;
1412
import java.util.stream.Collectors;
15-
import java.util.stream.StreamSupport;
1613

1714
import org.jboss.logging.Logger;
1815

@@ -30,10 +27,13 @@
3027
import io.quarkus.redis.datasource.ReactiveRedisDataSource;
3128
import io.quarkus.redis.datasource.json.ReactiveJsonCommands;
3229
import io.quarkus.redis.datasource.keys.KeyScanArgs;
30+
import io.quarkus.redis.datasource.search.CreateArgs;
31+
import io.quarkus.redis.datasource.search.Document;
32+
import io.quarkus.redis.datasource.search.QueryArgs;
33+
import io.quarkus.redis.datasource.search.SearchQueryResponse;
3334
import io.smallrye.mutiny.Uni;
3435
import io.vertx.mutiny.redis.client.Command;
3536
import io.vertx.mutiny.redis.client.Request;
36-
import io.vertx.mutiny.redis.client.Response;
3737

3838
public class RedisEmbeddingStore implements EmbeddingStore<TextSegment> {
3939

@@ -65,19 +65,12 @@ private void createIndexIfDoesNotExist() {
6565
}
6666
}).await().indefinitely();
6767
if (!indexes.contains(schema.getIndexName())) {
68-
// TODO: rewrite to use the typesafe data source API
69-
Request request = Request.cmd(Command.FT_CREATE)
70-
.arg(schema.getIndexName())
71-
.arg("ON")
72-
.arg("JSON")
73-
.arg("PREFIX")
74-
.arg("1")
75-
.arg(schema.getPrefix())
76-
.arg("SCHEMA");
77-
schema.defineFields(request);
78-
LOG.debug(
79-
"Creating index with command: " + request.toString().replaceAll("\r\n", " "));
80-
ds.getRedis().send(request).await().indefinitely();
68+
CreateArgs indexCreateArgs = new CreateArgs()
69+
.onJson()
70+
.prefixes(schema.getPrefix());
71+
schema.defineFields(indexCreateArgs);
72+
LOG.debug("Creating Redis index " + schema.getIndexName());
73+
ds.search().ftCreate(schema.getIndexName(), indexCreateArgs).await().indefinitely();
8174
} else {
8275
LOG.debug("Index in Redis already exists: " + schema.getIndexName());
8376
}
@@ -152,63 +145,28 @@ public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbeddi
152145
double minScore) {
153146
String queryTemplate = "*=>[ KNN %d @%s $BLOB AS %s ]";
154147
String query = format(queryTemplate, maxResults, schema.getVectorFieldName(), SCORE_FIELD_NAME);
155-
// TODO: rewrite to the data source api, but we need a new
156-
// method QueryArgs.param(String, byte[]) to get it working
157-
158-
// QueryArgs args = new QueryArgs()
159-
// .sortByAscending(SCORE_FIELD_NAME)
160-
// .param("DIALECT", "2")
161-
// .param("BLOB", toByteArray(referenceEmbedding.vector()));
162-
// Uni<SearchQueryResponse> search = ds.search()
163-
// .ftSearch(schema.getIndexName(), query, args);
164-
// SearchQueryResponse response = search.await().indefinitely();
165-
Request request = Request.cmd(Command.FT_SEARCH)
166-
.arg(schema.getIndexName())
167-
.arg(query)
168-
.arg("PARAMS")
169-
.arg("2")
170-
.arg("BLOB")
171-
.arg(toByteArray(referenceEmbedding.vector()))
172-
.arg("DIALECT")
173-
.arg("2");
174-
Response response = ds.getRedis().send(request).await().indefinitely();
175-
return StreamSupport.stream(response.get("results").spliterator(), false)
176-
.map(this::toEmbeddingMatch)
148+
QueryArgs args = new QueryArgs()
149+
.sortByAscending(SCORE_FIELD_NAME)
150+
.param("DIALECT", "2")
151+
.param("BLOB", referenceEmbedding.vector());
152+
Uni<SearchQueryResponse> search = ds.search()
153+
.ftSearch(schema.getIndexName(), query, args);
154+
SearchQueryResponse response = search.await().indefinitely();
155+
return response.documents().stream().map(this::extractEmbeddingMatch)
177156
.filter(embeddingMatch -> embeddingMatch.score() >= minScore)
178157
.collect(toList());
179158
}
180159

181-
/**
182-
* Deletes all keys with the prefix that is used by this embedding store.
183-
*/
184-
public void deleteAll() {
185-
KeyScanArgs args = new KeyScanArgs().match(schema.getPrefix() + "*");
186-
Set<String> keysToDelete = ds.key().scan(args).toMulti().collect().asSet().await().indefinitely();
187-
if (!keysToDelete.isEmpty()) {
188-
Request command = Request.cmd(Command.DEL);
189-
keysToDelete.forEach(command::arg);
190-
ds.getRedis().send(command).await().indefinitely();
191-
LOG.debug("Deleted " + keysToDelete.size() + " keys");
192-
}
193-
}
194-
195-
public static byte[] toByteArray(float[] input) {
196-
byte[] bytes = new byte[Float.BYTES * input.length];
197-
ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().put(input);
198-
return bytes;
199-
}
200-
201-
private EmbeddingMatch<TextSegment> toEmbeddingMatch(Response response) {
202-
String document = response.get(EXTRA_ATTRIBUTES).get("$").toString();
160+
private EmbeddingMatch<TextSegment> extractEmbeddingMatch(Document document) {
203161
try {
204-
JsonNode jsonNode = QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER.readTree(document);
162+
JsonNode jsonNode = QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER
163+
.readTree(document.property("$").asString());
205164
JsonNode embedded = jsonNode.get(schema.getScalarFieldName());
206165
Embedding embedding = new Embedding(
207166
Json.fromJson(jsonNode.get(schema.getVectorFieldName()).toString(), float[].class));
208-
double score = (2 - response.get(EXTRA_ATTRIBUTES).get(SCORE_FIELD_NAME).toDouble()) / 2;
209-
String id = response.get(ID).toString().substring(schema.getPrefix().length());
210-
List<String> metadataFields = schema.getMetadataFields();
211-
Map<String, String> metadata = metadataFields.stream()
167+
double score = (2 - document.property(SCORE_FIELD_NAME).asDouble()) / 2;
168+
String id = document.key().substring(schema.getPrefix().length());
169+
Map<String, String> metadata = schema.getMetadataFields().stream()
212170
.filter(jsonNode::has)
213171
.collect(Collectors.toMap(metadataFieldName -> metadataFieldName,
214172
(name) -> jsonNode.get(name).asText()));
@@ -220,6 +178,20 @@ private EmbeddingMatch<TextSegment> toEmbeddingMatch(Response response) {
220178

221179
}
222180

181+
/**
182+
* Deletes all keys with the prefix that is used by this embedding store.
183+
*/
184+
public void deleteAll() {
185+
KeyScanArgs args = new KeyScanArgs().match(schema.getPrefix() + "*");
186+
Set<String> keysToDelete = ds.key().scan(args).toMulti().collect().asSet().await().indefinitely();
187+
if (!keysToDelete.isEmpty()) {
188+
Request command = Request.cmd(Command.DEL);
189+
keysToDelete.forEach(command::arg);
190+
ds.getRedis().send(command).await().indefinitely();
191+
LOG.debug("Deleted " + keysToDelete.size() + " keys");
192+
}
193+
}
194+
223195
public static class Builder {
224196

225197
private ReactiveRedisDataSource redisClient;

redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/VectorAlgorithm.java

Lines changed: 0 additions & 6 deletions
This file was deleted.

redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/runtime/RedisEmbeddingStoreConfig.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import java.util.List;
66
import java.util.Optional;
77

8-
import io.quarkiverse.langchain4j.redis.MetricType;
9-
import io.quarkiverse.langchain4j.redis.VectorAlgorithm;
8+
import io.quarkus.redis.datasource.search.DistanceMetric;
9+
import io.quarkus.redis.datasource.search.VectorAlgorithm;
1010
import io.quarkus.runtime.annotations.ConfigRoot;
1111
import io.smallrye.config.ConfigMapping;
1212
import io.smallrye.config.WithDefault;
@@ -51,7 +51,7 @@ public interface RedisEmbeddingStoreConfig {
5151
* Metric used to compute the distance between two vectors.
5252
*/
5353
@WithDefault("COSINE")
54-
MetricType metricType();
54+
DistanceMetric distanceMetric();
5555

5656
/**
5757
* Name of the key that will be used to store the embedding vector.

redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/runtime/RedisEmbeddingStoreRecorder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public RedisEmbeddingStore apply(SyntheticCreationalContext<RedisEmbeddingStore>
4040
.metadataFields(config.metadataFields().orElse(Collections.emptyList()))
4141
.vectorAlgorithm(config.vectorAlgorithm())
4242
.dimension(config.dimension())
43-
.metricType(config.metricType())
43+
.metricType(config.distanceMetric())
4444
.build();
4545
builder.schema(schema);
4646

redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/runtime/RedisSchema.java

Lines changed: 26 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22

33
import java.util.List;
44

5-
import io.quarkiverse.langchain4j.redis.MetricType;
6-
import io.quarkiverse.langchain4j.redis.VectorAlgorithm;
7-
import io.vertx.mutiny.redis.client.Request;
5+
import io.quarkus.redis.datasource.search.CreateArgs;
6+
import io.quarkus.redis.datasource.search.DistanceMetric;
7+
import io.quarkus.redis.datasource.search.FieldOptions;
8+
import io.quarkus.redis.datasource.search.FieldType;
9+
import io.quarkus.redis.datasource.search.VectorAlgorithm;
10+
import io.quarkus.redis.datasource.search.VectorType;
811

912
public class RedisSchema {
1013

@@ -15,7 +18,7 @@ public class RedisSchema {
1518
private List<String> metadataFields;
1619
private VectorAlgorithm vectorAlgorithm;
1720
private Long dimension;
18-
private MetricType metricType;
21+
private DistanceMetric distanceMetric;
1922
private static final String JSON_PATH_PREFIX = "$.";
2023

2124
public RedisSchema(String indexName,
@@ -25,15 +28,15 @@ public RedisSchema(String indexName,
2528
List<String> metadataFields,
2629
VectorAlgorithm vectorAlgorithm,
2730
Long dimension,
28-
MetricType metricType) {
31+
DistanceMetric distanceMetric) {
2932
this.indexName = indexName;
3033
this.prefix = prefix;
3134
this.vectorFieldName = vectorFieldName;
3235
this.scalarFieldName = scalarFieldName;
3336
this.metadataFields = metadataFields;
3437
this.vectorAlgorithm = vectorAlgorithm;
3538
this.dimension = dimension;
36-
this.metricType = metricType;
39+
this.distanceMetric = distanceMetric;
3740
}
3841

3942
public String getIndexName() {
@@ -64,51 +67,34 @@ public Long getDimension() {
6467
return dimension;
6568
}
6669

67-
public MetricType getMetricType() {
68-
return metricType;
70+
public DistanceMetric getDistanceMetric() {
71+
return distanceMetric;
6972
}
7073

71-
public void defineFields(Request args) {
74+
public void defineFields(CreateArgs args) {
7275
defineTextField(args);
7376
defineVectorField(args);
7477
defineMetadataFields(args);
7578
}
7679

77-
private void defineMetadataFields(Request args) {
80+
private void defineMetadataFields(CreateArgs args) {
7881
for (String metadataField : metadataFields) {
79-
args.arg(JSON_PATH_PREFIX + metadataField);
80-
args.arg("AS");
81-
args.arg(metadataField);
82-
args.arg("TEXT");
83-
args.arg("WEIGHT");
84-
args.arg("1.0");
82+
args.indexedField(JSON_PATH_PREFIX + metadataField, metadataField, FieldType.TEXT, new FieldOptions().weight(1.0));
8583
}
8684
}
8785

88-
private void defineTextField(Request args) {
89-
args.arg(JSON_PATH_PREFIX + scalarFieldName);
90-
args.arg("AS");
91-
args.arg(scalarFieldName);
92-
args.arg("TEXT");
93-
args.arg("WEIGHT");
94-
args.arg("1.0");
86+
private void defineTextField(CreateArgs args) {
87+
args.indexedField(JSON_PATH_PREFIX + scalarFieldName, scalarFieldName, FieldType.TEXT, new FieldOptions().weight(1.0));
9588
}
9689

97-
private void defineVectorField(Request args) {
98-
args.arg(JSON_PATH_PREFIX + vectorFieldName);
99-
args.arg("AS");
100-
args.arg(vectorFieldName);
101-
args.arg("VECTOR");
102-
args.arg(vectorAlgorithm.name());
103-
args.arg("8");
104-
args.arg("DIM");
105-
args.arg(dimension);
106-
args.arg("DISTANCE_METRIC");
107-
args.arg(metricType.name());
108-
args.arg("TYPE");
109-
args.arg("FLOAT32");
110-
args.arg("INITIAL_CAP");
111-
args.arg("5");
90+
private void defineVectorField(CreateArgs args) {
91+
args.indexedField(JSON_PATH_PREFIX + vectorFieldName,
92+
vectorFieldName,
93+
FieldType.VECTOR, new FieldOptions()
94+
.vectorAlgorithm(vectorAlgorithm)
95+
.vectorType(VectorType.FLOAT32)
96+
.dimension(dimension.intValue())
97+
.distanceMetric(distanceMetric));
11298
}
11399

114100
public static class Builder {
@@ -119,7 +105,7 @@ public static class Builder {
119105
private List<String> metadataFields;
120106
private VectorAlgorithm vectorAlgorithm;
121107
private Long dimension;
122-
private MetricType metricType;
108+
private DistanceMetric metricType;
123109

124110
public Builder indexName(String indexName) {
125111
this.indexName = indexName;
@@ -160,7 +146,7 @@ public Builder dimension(Long dimension) {
160146
return this;
161147
}
162148

163-
public Builder metricType(MetricType metricType) {
149+
public Builder metricType(DistanceMetric metricType) {
164150
this.metricType = metricType;
165151
return this;
166152
}

0 commit comments

Comments
 (0)