Skip to content

Commit 141d437

Browse files
committed
feat: add rerank() argument to nearText / nearAudio
1 parent fd8705a commit 141d437

File tree

11 files changed

+147
-7
lines changed

11 files changed

+147
-7
lines changed

src/it/java/io/weaviate/integration/DataITest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424
import io.weaviate.client6.v1.api.collections.data.BatchReference;
2525
import io.weaviate.client6.v1.api.collections.data.DeleteManyResponse;
2626
import io.weaviate.client6.v1.api.collections.data.Reference;
27+
import io.weaviate.client6.v1.api.collections.query.Filter;
2728
import io.weaviate.client6.v1.api.collections.query.Metadata;
2829
import io.weaviate.client6.v1.api.collections.query.Metadata.MetadataField;
2930
import io.weaviate.client6.v1.api.collections.query.QueryMetadata;
3031
import io.weaviate.client6.v1.api.collections.query.QueryReference;
31-
import io.weaviate.client6.v1.api.collections.query.Filter;
3232
import io.weaviate.containers.Container;
3333

3434
public class DataITest extends ConcurrentTest {

src/it/java/io/weaviate/integration/SearchITest.java

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121
import io.weaviate.ConcurrentTest;
2222
import io.weaviate.client6.v1.api.WeaviateApiException;
2323
import io.weaviate.client6.v1.api.WeaviateClient;
24+
import io.weaviate.client6.v1.api.collections.Generative;
2425
import io.weaviate.client6.v1.api.collections.ObjectMetadata;
2526
import io.weaviate.client6.v1.api.collections.Property;
2627
import io.weaviate.client6.v1.api.collections.ReferenceProperty;
28+
import io.weaviate.client6.v1.api.collections.Reranker;
2729
import io.weaviate.client6.v1.api.collections.VectorConfig;
2830
import io.weaviate.client6.v1.api.collections.Vectors;
2931
import io.weaviate.client6.v1.api.collections.WeaviateMetadata;
@@ -37,8 +39,10 @@
3739
import io.weaviate.client6.v1.api.collections.query.Metadata;
3840
import io.weaviate.client6.v1.api.collections.query.QueryMetadata;
3941
import io.weaviate.client6.v1.api.collections.query.QueryResponseGroup;
42+
import io.weaviate.client6.v1.api.collections.query.Rerank;
4043
import io.weaviate.client6.v1.api.collections.query.SortBy;
4144
import io.weaviate.client6.v1.api.collections.query.Target;
45+
import io.weaviate.client6.v1.api.collections.rerankers.DummyReranker;
4246
import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw;
4347
import io.weaviate.client6.v1.api.collections.vectorindex.MultiVector;
4448
import io.weaviate.containers.Container;
@@ -52,7 +56,7 @@ public class SearchITest extends ConcurrentTest {
5256
Weaviate.custom()
5357
.withModel2VecUrl(Model2Vec.URL)
5458
.withImageInference(Img2VecNeural.URL, Img2VecNeural.MODULE)
55-
.addModules("generative-dummy")
59+
.addModules(Generative.Kind.DUMMY.jsonValue(), Reranker.Kind.DUMMY.jsonValue())
5660
.build(),
5761
Container.IMG2VEC_NEURAL,
5862
Container.MODEL2VEC);
@@ -741,7 +745,7 @@ public void teset_filterPropertyLength() throws IOException {
741745
// Assertions
742746
Assertions.assertThat(got.objects()).hasSize(2);
743747
}
744-
748+
745749
/**
746750
* Ensure the client respects server's configuration for max gRPC size:
747751
* we create a server with 1-byte message size and try to send a large payload
@@ -768,4 +772,30 @@ public void test_maxGrpcMessageSize() throws Exception {
768772
}).isInstanceOf(io.grpc.StatusRuntimeException.class);
769773
}
770774
}
775+
776+
@Test
777+
public void test_rerankQueries() throws IOException {
778+
// Arrange
779+
var nsThigns = ns("Things");
780+
781+
var things = client.collections.create(nsThigns,
782+
c -> c
783+
.properties(Property.text("title"), Property.integer("price"))
784+
.vectorConfig(VectorConfig.text2vecModel2Vec(
785+
t2v -> t2v.sourceProperties("title", "price")))
786+
.rerankerModules(new DummyReranker()));
787+
788+
things.data.insertMany(
789+
Map.of("title", "Ergonomic chair", "price", 269),
790+
Map.of("title", "Height-adjustable desk", "price", 349));
791+
792+
// Act
793+
var got = things.query.nearText(
794+
"office supplies",
795+
nt -> nt.rerank(Rerank.by("price",
796+
rank -> rank.query("cheaper first"))));
797+
798+
// Assert: ranking not important really, just that the request was valid.
799+
Assertions.assertThat(got.objects()).hasSize(2);
800+
}
771801
}

src/main/java/io/weaviate/client6/v1/api/collections/CollectionConfig.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,19 +276,23 @@ public void write(JsonWriter out, CollectionConfig value) throws IOException {
276276
// Reranker and Generative module configs belong to the "moduleConfig".
277277
var rerankerModules = jsonObject.remove("rerankerModules").getAsJsonArray();
278278
var generativeModule = jsonObject.remove("generativeModule");
279-
if (!rerankerModules.isEmpty() || !generativeModule.isJsonNull()) {
280-
var modules = new JsonObject();
281279

280+
var modules = new JsonObject();
281+
if (!rerankerModules.isEmpty()) {
282282
// Copy configuration for each reranker module.
283283
rerankerModules.forEach(reranker -> {
284284
reranker.getAsJsonObject().entrySet()
285285
.stream().forEach(entry -> modules.add(entry.getKey(), entry.getValue()));
286286
});
287+
}
287288

289+
if (!generativeModule.isJsonNull()) {
288290
// Copy configuration for each generative module.
289291
generativeModule.getAsJsonObject().entrySet()
290292
.stream().forEach(entry -> modules.add(entry.getKey(), entry.getValue()));
293+
}
291294

295+
if (!modules.isEmpty()) {
292296
jsonObject.add("moduleConfig", modules);
293297
}
294298

src/main/java/io/weaviate/client6/v1/api/collections/Reranker.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import com.google.gson.stream.JsonWriter;
1515

1616
import io.weaviate.client6.v1.api.collections.rerankers.CohereReranker;
17+
import io.weaviate.client6.v1.api.collections.rerankers.DummyReranker;
1718
import io.weaviate.client6.v1.api.collections.rerankers.JinaAiReranker;
1819
import io.weaviate.client6.v1.api.collections.rerankers.NvidiaReranker;
1920
import io.weaviate.client6.v1.api.collections.rerankers.TransformersReranker;
@@ -24,6 +25,7 @@
2425

2526
public interface Reranker extends TaggedUnion<Reranker.Kind, Object> {
2627
public enum Kind implements JsonEnum<Kind> {
28+
DUMMY("reranker-dummy"),
2729
JINAAI("reranker-jinaai"),
2830
VOYAGEAI("reranker-voyageai"),
2931
NVIDIA("reranker-nvidia"),
@@ -120,6 +122,11 @@ private final void addAdapter(Gson gson, Reranker.Kind kind, Class<? extends Rer
120122

121123
private final void init(Gson gson) {
122124
addAdapter(gson, Reranker.Kind.COHERE, CohereReranker.class);
125+
addAdapter(gson, Reranker.Kind.JINAAI, JinaAiReranker.class);
126+
addAdapter(gson, Reranker.Kind.NVIDIA, NvidiaReranker.class);
127+
addAdapter(gson, Reranker.Kind.TRANSFORMERS, TransformersReranker.class);
128+
addAdapter(gson, Reranker.Kind.VOYAGEAI, VoyageAiReranker.class);
129+
addAdapter(gson, Reranker.Kind.DUMMY, DummyReranker.class);
123130
}
124131

125132
@SuppressWarnings("unchecked")

src/main/java/io/weaviate/client6/v1/api/collections/query/BaseVectorSearchBuilder.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ abstract class BaseVectorSearchBuilder<SelfT extends BaseVectorSearchBuilder<Sel
66
// Optional query parameters.
77
Float distance;
88
Float certainty;
9+
Rerank rerank;
910

1011
/**
1112
* Discard objects whose vectors are further away
@@ -35,4 +36,13 @@ public SelfT certainty(float certainty) {
3536
this.certainty = certainty;
3637
return (SelfT) this;
3738
}
39+
40+
/**
41+
* Control the ranking of the query results.
42+
*/
43+
@SuppressWarnings("unchecked")
44+
public SelfT rerank(Rerank rerank) {
45+
this.rerank = rerank;
46+
return (SelfT) this;
47+
}
3848
}

src/main/java/io/weaviate/client6/v1/api/collections/query/NearAudio.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch;
1111
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet;
1212

13-
public record NearAudio(Target searchTarget, Float distance, Float certainty, BaseQueryOptions common)
13+
public record NearAudio(Target searchTarget, Float distance, Float certainty, Rerank rerank, BaseQueryOptions common)
1414
implements QueryOperator, AggregateObjectFilter {
1515

1616
public static NearAudio of(String audio) {
@@ -34,6 +34,7 @@ public NearAudio(Builder builder) {
3434
builder.media,
3535
builder.distance,
3636
builder.certainty,
37+
builder.rerank,
3738
builder.baseOptions());
3839
}
3940

@@ -80,6 +81,7 @@ private WeaviateProtoBaseSearch.NearAudioSearch.Builder protoBuilder() {
8081
} else if (distance != null) {
8182
nearAudio.setDistance(distance);
8283
}
84+
8385
return nearAudio;
8486
}
8587
}

src/main/java/io/weaviate/client6/v1/api/collections/query/NearText.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch;
1414
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet;
1515

16-
public record NearText(Target searchTarget, Float distance, Float certainty, Move moveTo,
16+
public record NearText(
17+
Target searchTarget,
18+
Float distance,
19+
Float certainty,
20+
Rerank rerank,
21+
Move moveTo,
1722
Move moveAway,
1823
BaseQueryOptions common) implements QueryOperator, AggregateObjectFilter {
1924

@@ -38,6 +43,7 @@ public NearText(Builder builder) {
3843
builder.searchTarget,
3944
builder.distance,
4045
builder.certainty,
46+
builder.rerank,
4147
builder.moveTo,
4248
builder.moveAway,
4349
builder.baseOptions());

src/main/java/io/weaviate/client6/v1/api/collections/query/QueryOperator.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,14 @@
33
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet;
44

55
public interface QueryOperator {
6+
default BaseQueryOptions common() {
7+
return null;
8+
}
9+
10+
default Rerank rerank() {
11+
return null;
12+
}
13+
614
/** Append QueryOperator to the request message. */
715
void appendTo(WeaviateProtoSearchGet.SearchRequest.Builder req);
816
}

src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@ public static <PropertiesT> WeaviateProtoSearchGet.SearchRequest marshal(
2828
message.setUses125Api(true);
2929
message.setUses123Api(true);
3030
message.setCollection(collection.collectionName());
31+
32+
if (request.operator.common() != null) {
33+
request.operator.common().appendTo(message);
34+
}
35+
if (request.operator.rerank() != null) {
36+
request.operator.rerank().appendTo(message);
37+
}
3138
request.operator.appendTo(message);
3239

3340
if (defaults.tenant() != null) {
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package io.weaviate.client6.v1.api.collections.query;
2+
3+
import java.util.function.Function;
4+
5+
import io.weaviate.client6.v1.internal.ObjectBuilder;
6+
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet;
7+
8+
public record Rerank(String property, String query) {
9+
10+
public static Rerank by(String property) {
11+
return by(property, ObjectBuilder.identity());
12+
}
13+
14+
public static Rerank by(String property, Function<Builder, ObjectBuilder<Rerank>> fn) {
15+
return fn.apply(new Builder(property)).build();
16+
}
17+
18+
void appendTo(WeaviateProtoSearchGet.SearchRequest.Builder req) {
19+
var rerank = WeaviateProtoSearchGet.Rerank.newBuilder()
20+
.setProperty(property);
21+
22+
if (query != null) {
23+
rerank.setQuery(query);
24+
}
25+
req.setRerank(rerank);
26+
}
27+
28+
public Rerank(Builder builder) {
29+
this(builder.property, builder.query);
30+
}
31+
32+
public static class Builder implements ObjectBuilder<Rerank> {
33+
private final String property;
34+
private String query;
35+
36+
public Builder(String property) {
37+
this.property = property;
38+
}
39+
40+
public Builder query(String query) {
41+
this.query = query;
42+
return this;
43+
}
44+
45+
@Override
46+
public Rerank build() {
47+
return new Rerank(this);
48+
}
49+
}
50+
}

0 commit comments

Comments
 (0)