diff --git a/pom.xml b/pom.xml index e4730037a..68630eefb 100644 --- a/pom.xml +++ b/pom.xml @@ -410,7 +410,7 @@ sign-artifacts - verify + deploy sign diff --git a/src/it/java/io/weaviate/integration/CollectionsITest.java b/src/it/java/io/weaviate/integration/CollectionsITest.java index e9d1afa87..4906d48d5 100644 --- a/src/it/java/io/weaviate/integration/CollectionsITest.java +++ b/src/it/java/io/weaviate/integration/CollectionsITest.java @@ -13,6 +13,7 @@ import io.weaviate.client6.v1.api.collections.DataType; import io.weaviate.client6.v1.api.collections.InvertedIndex; import io.weaviate.client6.v1.api.collections.Property; +import io.weaviate.client6.v1.api.collections.Quantization; import io.weaviate.client6.v1.api.collections.ReferenceProperty; import io.weaviate.client6.v1.api.collections.Replication; import io.weaviate.client6.v1.api.collections.VectorConfig; @@ -194,7 +195,7 @@ public void testInvalidCollectionName() throws IOException { } @Test - public void testNestedProperties() throws IOException, Exception { + public void testNestedProperties() throws IOException { var nsBuildings = ns("Buildings"); client.collections.create( @@ -227,4 +228,27 @@ public void testNestedProperties() throws IOException, Exception { .extracting(Property::dataTypes).extracting(types -> types.get(0)) .containsExactly(DataType.INT, DataType.NUMBER); } + + @Test + public void test_updateQuantization() throws IOException { + // Arrange + var nsThings = ns("Things"); + + var things = client.collections.create(nsThings, + c -> c.vectorConfig(VectorConfig.selfProvided( + self -> self.quantization(Quantization.uncompressed())))); + + // Act + things.config.update( + c -> c.vectorConfig(VectorConfig.selfProvided( + self -> self.quantization(Quantization.bq())))); + + // Assert + var config = things.config.get(); + Assertions.assertThat(config).get() + .extracting(CollectionConfig::vectors) + .extracting("default", InstanceOfAssertFactories.type(VectorConfig.class)) + .extracting(VectorConfig::quantization) + .returns(Quantization.Kind.BQ, Quantization::_kind); + } } diff --git a/src/it/java/io/weaviate/integration/DataITest.java b/src/it/java/io/weaviate/integration/DataITest.java index 246784cf5..d8ae6401c 100644 --- a/src/it/java/io/weaviate/integration/DataITest.java +++ b/src/it/java/io/weaviate/integration/DataITest.java @@ -24,11 +24,11 @@ import io.weaviate.client6.v1.api.collections.data.BatchReference; import io.weaviate.client6.v1.api.collections.data.DeleteManyResponse; import io.weaviate.client6.v1.api.collections.data.Reference; +import io.weaviate.client6.v1.api.collections.query.Filter; import io.weaviate.client6.v1.api.collections.query.Metadata; import io.weaviate.client6.v1.api.collections.query.Metadata.MetadataField; import io.weaviate.client6.v1.api.collections.query.QueryMetadata; import io.weaviate.client6.v1.api.collections.query.QueryReference; -import io.weaviate.client6.v1.api.collections.query.Filter; import io.weaviate.containers.Container; public class DataITest extends ConcurrentTest { diff --git a/src/it/java/io/weaviate/integration/SearchITest.java b/src/it/java/io/weaviate/integration/SearchITest.java index 7359057ae..3d3e73964 100644 --- a/src/it/java/io/weaviate/integration/SearchITest.java +++ b/src/it/java/io/weaviate/integration/SearchITest.java @@ -21,9 +21,11 @@ import io.weaviate.ConcurrentTest; import io.weaviate.client6.v1.api.WeaviateApiException; import io.weaviate.client6.v1.api.WeaviateClient; +import io.weaviate.client6.v1.api.collections.Generative; import io.weaviate.client6.v1.api.collections.ObjectMetadata; import io.weaviate.client6.v1.api.collections.Property; import io.weaviate.client6.v1.api.collections.ReferenceProperty; +import io.weaviate.client6.v1.api.collections.Reranker; import io.weaviate.client6.v1.api.collections.VectorConfig; import io.weaviate.client6.v1.api.collections.Vectors; import io.weaviate.client6.v1.api.collections.WeaviateMetadata; @@ -37,8 +39,10 @@ import io.weaviate.client6.v1.api.collections.query.Metadata; import io.weaviate.client6.v1.api.collections.query.QueryMetadata; import io.weaviate.client6.v1.api.collections.query.QueryResponseGroup; +import io.weaviate.client6.v1.api.collections.query.Rerank; import io.weaviate.client6.v1.api.collections.query.SortBy; import io.weaviate.client6.v1.api.collections.query.Target; +import io.weaviate.client6.v1.api.collections.rerankers.DummyReranker; import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw; import io.weaviate.client6.v1.api.collections.vectorindex.MultiVector; import io.weaviate.containers.Container; @@ -52,7 +56,7 @@ public class SearchITest extends ConcurrentTest { Weaviate.custom() .withModel2VecUrl(Model2Vec.URL) .withImageInference(Img2VecNeural.URL, Img2VecNeural.MODULE) - .addModules("generative-dummy") + .addModules(Generative.Kind.DUMMY.jsonValue(), Reranker.Kind.DUMMY.jsonValue()) .build(), Container.IMG2VEC_NEURAL, Container.MODEL2VEC); @@ -741,7 +745,7 @@ public void teset_filterPropertyLength() throws IOException { // Assertions Assertions.assertThat(got.objects()).hasSize(2); } - + /** * Ensure the client respects server's configuration for max gRPC size: * we create a server with 1-byte message size and try to send a large payload @@ -768,4 +772,30 @@ public void test_maxGrpcMessageSize() throws Exception { }).isInstanceOf(io.grpc.StatusRuntimeException.class); } } + + @Test + public void test_rerankQueries() throws IOException { + // Arrange + var nsThigns = ns("Things"); + + var things = client.collections.create(nsThigns, + c -> c + .properties(Property.text("title"), Property.integer("price")) + .vectorConfig(VectorConfig.text2vecModel2Vec( + t2v -> t2v.sourceProperties("title", "price"))) + .rerankerModules(new DummyReranker())); + + things.data.insertMany( + Map.of("title", "Ergonomic chair", "price", 269), + Map.of("title", "Height-adjustable desk", "price", 349)); + + // Act + var got = things.query.nearText( + "office supplies", + nt -> nt.rerank(Rerank.by("price", + rank -> rank.query("cheaper first")))); + + // Assert: ranking not important really, just that the request was valid. + Assertions.assertThat(got.objects()).hasSize(2); + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionConfig.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionConfig.java index 634bd9713..03c66841b 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionConfig.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionConfig.java @@ -276,19 +276,23 @@ public void write(JsonWriter out, CollectionConfig value) throws IOException { // Reranker and Generative module configs belong to the "moduleConfig". var rerankerModules = jsonObject.remove("rerankerModules").getAsJsonArray(); var generativeModule = jsonObject.remove("generativeModule"); - if (!rerankerModules.isEmpty() || !generativeModule.isJsonNull()) { - var modules = new JsonObject(); + var modules = new JsonObject(); + if (!rerankerModules.isEmpty()) { // Copy configuration for each reranker module. rerankerModules.forEach(reranker -> { reranker.getAsJsonObject().entrySet() .stream().forEach(entry -> modules.add(entry.getKey(), entry.getValue())); }); + } + if (!generativeModule.isJsonNull()) { // Copy configuration for each generative module. generativeModule.getAsJsonObject().entrySet() .stream().forEach(entry -> modules.add(entry.getKey(), entry.getValue())); + } + if (!modules.isEmpty()) { jsonObject.add("moduleConfig", modules); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Quantization.java b/src/main/java/io/weaviate/client6/v1/api/collections/Quantization.java index cd5fba0c6..c60a67c43 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/Quantization.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Quantization.java @@ -19,9 +19,10 @@ import io.weaviate.client6.v1.api.collections.quantizers.SQ; import io.weaviate.client6.v1.api.collections.quantizers.Uncompressed; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.TaggedUnion; import io.weaviate.client6.v1.internal.json.JsonEnum; -public interface Quantization { +public interface Quantization extends TaggedUnion { public enum Kind implements JsonEnum { UNCOMPRESSED("skipDefaultQuantization"), @@ -112,6 +113,46 @@ public static Quantization rq(Function> fn) { return RQ.of(fn); } + default BQ asBQ() { + return _as(Quantization.Kind.BQ); + } + + default RQ asRQ() { + return _as(Quantization.Kind.RQ); + } + + default PQ asPQ() { + return _as(Quantization.Kind.PQ); + } + + default SQ asSQ() { + return _as(Quantization.Kind.SQ); + } + + default Uncompressed asUncompressed() { + return _as(Quantization.Kind.UNCOMPRESSED); + } + + default boolean isBQ() { + return _is(Quantization.Kind.BQ); + } + + default boolean isRQ() { + return _is(Quantization.Kind.RQ); + } + + default boolean isPQ() { + return _is(Quantization.Kind.PQ); + } + + default boolean isSQ() { + return _is(Quantization.Kind.SQ); + } + + default boolean isUncompressed() { + return _is(Quantization.Kind.UNCOMPRESSED); + } + public static enum CustomTypeAdapterFactory implements TypeAdapterFactory { INSTANCE; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Reranker.java b/src/main/java/io/weaviate/client6/v1/api/collections/Reranker.java index 7358cdd21..5f38fc31e 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/Reranker.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Reranker.java @@ -14,6 +14,7 @@ import com.google.gson.stream.JsonWriter; import io.weaviate.client6.v1.api.collections.rerankers.CohereReranker; +import io.weaviate.client6.v1.api.collections.rerankers.DummyReranker; import io.weaviate.client6.v1.api.collections.rerankers.JinaAiReranker; import io.weaviate.client6.v1.api.collections.rerankers.NvidiaReranker; import io.weaviate.client6.v1.api.collections.rerankers.TransformersReranker; @@ -24,6 +25,7 @@ public interface Reranker extends TaggedUnion { public enum Kind implements JsonEnum { + DUMMY("reranker-dummy"), JINAAI("reranker-jinaai"), VOYAGEAI("reranker-voyageai"), NVIDIA("reranker-nvidia"), @@ -120,6 +122,11 @@ private final void addAdapter(Gson gson, Reranker.Kind kind, Class returnMetrics) { @@ -29,6 +32,7 @@ public static Aggregation of(AggregateObjectFilter objectFilter, Function metrics = new ArrayList<>(); private Integer objectLimit; private boolean includeTotalCount = false; @@ -55,6 +60,24 @@ public final Builder includeTotalCount(boolean include) { return this; } + /** + * Filter result set using traditional filtering operators: {@code eq}, + * {@code gte}, {@code like}, etc. + * Subsequent calls to {@link #filter} aggregate with an AND operator. + */ + public final Builder filters(Filter filter) { + this.whereFilter = this.whereFilter == null + ? filter + : Filter.and(this.whereFilter, filter); + return this; + } + + /** Combine several conditions using with an AND operator. */ + public final Builder filters(Filter... filters) { + Arrays.stream(filters).map(this::filters); + return this; + } + @SafeVarargs public final Builder metrics(PropertyAggregation... metrics) { this.metrics = Arrays.asList(metrics); @@ -80,6 +103,12 @@ public void appendTo(WeaviateProtoAggregate.AggregateRequest.Builder req) { req.setObjectLimit(objectLimit); } + if (whereFilter != null) { + var protoFilters = WeaviateProtoBase.Filters.newBuilder(); + whereFilter.appendTo(protoFilters); + req.setFilters(protoFilters); + } + for (final var metric : returnMetrics) { var aggregation = WeaviateProtoAggregate.AggregateRequest.Aggregation.newBuilder(); metric.appendTo(aggregation); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/config/UpdateCollectionRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/config/UpdateCollectionRequest.java index f0093492b..c706da0c1 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/config/UpdateCollectionRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/config/UpdateCollectionRequest.java @@ -9,6 +9,7 @@ import io.weaviate.client6.v1.api.collections.Generative; import io.weaviate.client6.v1.api.collections.InvertedIndex; import io.weaviate.client6.v1.api.collections.MultiTenancy; +import io.weaviate.client6.v1.api.collections.Quantization; import io.weaviate.client6.v1.api.collections.Replication; import io.weaviate.client6.v1.api.collections.Reranker; import io.weaviate.client6.v1.api.collections.VectorConfig; @@ -17,21 +18,49 @@ import io.weaviate.client6.v1.internal.rest.Endpoint; import io.weaviate.client6.v1.internal.rest.SimpleEndpoint; -public record UpdateCollectionRequest(CollectionConfig collection) { +public record UpdateCollectionRequest(CollectionConfig updated, CollectionConfig original) { public static final Endpoint _ENDPOINT = SimpleEndpoint.sideEffect( request -> "PUT", - request -> "/schema/" + request.collection.collectionName(), + request -> "/schema/" + request.updated.collectionName(), request -> Collections.emptyMap(), - request -> JSON.serialize(request.collection)); + request -> { + var json = JSON.serialize(request.updated); + + // Workaround: when doing updates, the server *insists* that + // "skipDefaultQuantization" property remains unchanged for each vector, + // even in cases when it is irrelevant [shrug]. + // To mitigate that we will set that field to its original value for all + // vectors which were present in the original configuration. + var jsonObject = JSON.toJsonElement(json).getAsJsonObject(); + var vectorsAny = jsonObject.get("vectorConfig"); + if (request.original.vectors() != null && !request.original.vectors().isEmpty() + && vectorsAny != null && vectorsAny.isJsonObject()) { + var vectors = vectorsAny.getAsJsonObject(); + for (var origVector : request.original.vectors().entrySet()) { + var vectorName = origVector.getKey(); + var origQuantization = origVector.getValue().quantization(); + if (vectors.has(vectorName)) { + vectors + .get(vectorName).getAsJsonObject() + .get("vectorIndexConfig").getAsJsonObject() + .addProperty(Quantization.Kind.UNCOMPRESSED.jsonValue(), origQuantization.isUncompressed()); + } + } + + json = jsonObject.toString(); + } + + return json; + }); - public static UpdateCollectionRequest of(CollectionConfig collection, + public static UpdateCollectionRequest of(CollectionConfig original, Function> fn) { - return fn.apply(new Builder(collection)).build(); + return fn.apply(new Builder(original)).build(); } public UpdateCollectionRequest(Builder builder) { - this(builder.newCollection.build()); + this(builder.newCollection.build(), builder.currentCollection); } public static class Builder implements ObjectBuilder { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeTask.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeTask.java index 84138aa05..aff3548d3 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeTask.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeTask.java @@ -58,7 +58,7 @@ void appendTo(WeaviateProtoGenerative.GenerativeSearch.Builder req) { } } - public record Single(String prompt, boolean debug, List providers) { + public record Single(String prompt, boolean debug, boolean returnMetadata, List providers) { public static Single of(String prompt) { return of(prompt, ObjectBuilder.identity()); } @@ -68,13 +68,17 @@ public static Single of(String prompt, Function> } public Single(Builder builder) { - this(builder.prompt, builder.debug, builder.providers); + this(builder.prompt, + builder.debug, + builder.returnMetadata, + builder.providers); } public static class Builder implements ObjectBuilder { private final String prompt; private final List providers = new ArrayList<>(); private boolean debug = false; + private boolean returnMetadata = false; public Builder(String prompt) { this.prompt = prompt; @@ -85,6 +89,16 @@ public Builder debug(boolean enable) { return this; } + /** + * Return generative provider metadata alongside the query result. Metadata is + * only available if {@link #generativeProvider(GenerativeProvider)} is set + * explicitly.. + */ + public Builder metadata(boolean enable) { + this.returnMetadata = enable; + return this; + } + public Builder generativeProvider(GenerativeProvider provider) { providers.clear(); // Protobuf allows `repeated` but the server expects there to be 1. providers.add(provider); @@ -102,6 +116,7 @@ public void appendTo(WeaviateProtoGenerative.GenerativeSearch.Builder req) { .map(provider -> { var proto = WeaviateProtoGenerative.GenerativeProvider.newBuilder(); provider.appendTo(proto); + proto.setReturnMetadata(returnMetadata); return proto.build(); }) .toList(); @@ -114,7 +129,8 @@ public void appendTo(WeaviateProtoGenerative.GenerativeSearch.Builder req) { } } - public record Grouped(String prompt, boolean debug, List properties, List providers) { + public record Grouped(String prompt, boolean debug, boolean returnMetadata, List properties, + List providers) { public static Grouped of(String prompt) { return of(prompt, ObjectBuilder.identity()); } @@ -124,7 +140,12 @@ public static Grouped of(String prompt, Function } public Grouped(Builder builder) { - this(builder.prompt, builder.debug, builder.properties, builder.providers); + this( + builder.prompt, + builder.debug, + builder.returnMetadata, + builder.properties, + builder.providers); } public static class Builder implements ObjectBuilder { @@ -132,6 +153,7 @@ public static class Builder implements ObjectBuilder { private final List providers = new ArrayList<>(); private final List properties = new ArrayList<>(); private boolean debug = false; + private boolean returnMetadata = false; public Builder(String prompt) { this.prompt = prompt; @@ -152,6 +174,16 @@ public Builder generativeProvider(GenerativeProvider provider) { return this; } + /** + * Return generative provider metadata alongside the query result. Metadata is + * only available if {@link #generativeProvider(GenerativeProvider)} is set + * explicitly.. + */ + public Builder metadata(boolean enable) { + this.returnMetadata = enable; + return this; + } + public Builder debug(boolean enable) { this.debug = enable; return this; @@ -179,6 +211,7 @@ public void appendTo(WeaviateProtoGenerative.GenerativeSearch.Builder req) { .map(provider -> { var proto = WeaviateProtoGenerative.GenerativeProvider.newBuilder(); provider.appendTo(proto); + proto.setReturnMetadata(returnMetadata); return proto.build(); }) .toList(); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/pagination/Paginator.java b/src/main/java/io/weaviate/client6/v1/api/collections/pagination/Paginator.java index 0d4d53413..23d86fb1a 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/pagination/Paginator.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/pagination/Paginator.java @@ -10,11 +10,11 @@ import io.weaviate.client6.v1.api.collections.WeaviateObject; import io.weaviate.client6.v1.api.collections.query.FetchObjects; +import io.weaviate.client6.v1.api.collections.query.Filter; import io.weaviate.client6.v1.api.collections.query.Metadata; import io.weaviate.client6.v1.api.collections.query.QueryMetadata; import io.weaviate.client6.v1.api.collections.query.QueryReference; import io.weaviate.client6.v1.api.collections.query.WeaviateQueryClient; -import io.weaviate.client6.v1.api.collections.query.Filter; import io.weaviate.client6.v1.internal.ObjectBuilder; public class Paginator implements Iterable> { @@ -94,6 +94,21 @@ public final Builder filters(Filter... filters) { return applyQueryOption(q -> q.filters(filters)); } + /** Include default vector. */ + public final Builder includeVector() { + return applyQueryOption(q -> q.includeVector()); + } + + /** Include one or more named vectors in the metadata response. */ + public final Builder includeVector(String... vectors) { + return applyQueryOption(q -> q.includeVector(vectors)); + } + + /** Include one or more named vectors in the metadata response. */ + public final Builder includeVector(List vectors) { + return applyQueryOption(q -> q.includeVector(vectors)); + } + public final Builder returnProperties(String... properties) { return applyQueryOption(q -> q.returnProperties(properties)); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/BaseQueryOptions.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/BaseQueryOptions.java index 1786ce74d..81d2c71fb 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/BaseQueryOptions.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/BaseQueryOptions.java @@ -131,7 +131,9 @@ protected SelfT generate(Function { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearText.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearText.java index 9878edb76..85509acd4 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearText.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearText.java @@ -13,7 +13,12 @@ import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; -public record NearText(Target searchTarget, Float distance, Float certainty, Move moveTo, +public record NearText( + Target searchTarget, + Float distance, + Float certainty, + Rerank rerank, + Move moveTo, Move moveAway, BaseQueryOptions common) implements QueryOperator, AggregateObjectFilter { @@ -38,6 +43,7 @@ public NearText(Builder builder) { builder.searchTarget, builder.distance, builder.certainty, + builder.rerank, builder.moveTo, builder.moveAway, builder.baseOptions()); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearThermal.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearThermal.java index d85ef8b93..dc91ed03e 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearThermal.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearThermal.java @@ -10,7 +10,11 @@ import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; -public record NearThermal(Target searchTarget, Float distance, Float certainty, BaseQueryOptions common) +public record NearThermal(Target searchTarget, + Float distance, + Float certainty, + Rerank rerank, + BaseQueryOptions common) implements QueryOperator, AggregateObjectFilter { public static NearThermal of(String thermal) { @@ -34,6 +38,7 @@ public NearThermal(Builder builder) { builder.media, builder.distance, builder.certainty, + builder.rerank, builder.baseOptions()); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVector.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVector.java index e8323a640..134c019c9 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVector.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVector.java @@ -8,9 +8,29 @@ import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; -public record NearVector(NearVectorTarget searchTarget, Float distance, Float certainty, BaseQueryOptions common) +public record NearVector(NearVectorTarget searchTarget, + Float distance, + Float certainty, + Rerank rerank, + BaseQueryOptions common) implements QueryOperator, AggregateObjectFilter { + public static final NearVector of(float[] vector) { + return of(vector, ObjectBuilder.identity()); + } + + public static final NearVector of(float[] vector, Function> fn) { + return fn.apply(new Builder(Target.vector(vector))).build(); + } + + public static final NearVector of(float[][] vector) { + return of(vector, ObjectBuilder.identity()); + } + + public static final NearVector of(float[][] vector, Function> fn) { + return fn.apply(new Builder(Target.vector(vector))).build(); + } + public static final NearVector of(NearVectorTarget searchTarget) { return of(searchTarget, ObjectBuilder.identity()); } @@ -20,7 +40,11 @@ public static final NearVector of(NearVectorTarget searchTarget, Function { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVideo.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVideo.java index f4a1b8922..fb8974216 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVideo.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVideo.java @@ -10,7 +10,12 @@ import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; -public record NearVideo(Target searchTarget, Float distance, Float certainty, BaseQueryOptions common) +public record NearVideo( + Target searchTarget, + Float distance, + Float certainty, + Rerank rerank, + BaseQueryOptions common) implements QueryOperator, AggregateObjectFilter { public static NearVideo of(String video) { @@ -34,6 +39,7 @@ public NearVideo(Builder builder) { builder.media, builder.distance, builder.certainty, + builder.rerank, builder.baseOptions()); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryOperator.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryOperator.java index a2fe89d11..886fd5b19 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryOperator.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryOperator.java @@ -3,6 +3,14 @@ import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; public interface QueryOperator { + default BaseQueryOptions common() { + return null; + } + + default Rerank rerank() { + return null; + } + /** Append QueryOperator to the request message. */ void appendTo(WeaviateProtoSearchGet.SearchRequest.Builder req); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java index fcc553f9e..625dde30d 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java @@ -28,6 +28,13 @@ public static WeaviateProtoSearchGet.SearchRequest marshal( message.setUses125Api(true); message.setUses123Api(true); message.setCollection(collection.collectionName()); + + if (request.operator.common() != null) { + request.operator.common().appendTo(message); + } + if (request.operator.rerank() != null) { + request.operator.rerank().appendTo(message); + } request.operator.appendTo(message); if (defaults.tenant() != null) { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java index 98889c492..283045c9b 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java @@ -132,6 +132,12 @@ static WeaviateObject unmarsh } } metadataBuilder.vectors(vectors); + + if (metadataResult.getVectorBytes() != null && !metadataResult.getVectorBytes().isEmpty()) { + var unnamed = ByteStringUtil.decodeVectorSingle(metadataResult.getVectorBytes()); + metadataBuilder.vectors(Vectors.of(unnamed)); + } + metadata = metadataBuilder.build(); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/Rerank.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/Rerank.java new file mode 100644 index 000000000..76b3fb6ec --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/Rerank.java @@ -0,0 +1,50 @@ +package io.weaviate.client6.v1.api.collections.query; + +import java.util.function.Function; + +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; + +public record Rerank(String property, String query) { + + public static Rerank by(String property) { + return by(property, ObjectBuilder.identity()); + } + + public static Rerank by(String property, Function> fn) { + return fn.apply(new Builder(property)).build(); + } + + void appendTo(WeaviateProtoSearchGet.SearchRequest.Builder req) { + var rerank = WeaviateProtoSearchGet.Rerank.newBuilder() + .setProperty(property); + + if (query != null) { + rerank.setQuery(query); + } + req.setRerank(rerank); + } + + public Rerank(Builder builder) { + this(builder.property, builder.query); + } + + public static class Builder implements ObjectBuilder { + private final String property; + private String query; + + public Builder(String property) { + this.property = property; + } + + public Builder query(String query) { + this.query = query; + return this; + } + + @Override + public Rerank build() { + return new Rerank(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/Target.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/Target.java index aa62e7e8b..03d113197 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/Target.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/Target.java @@ -101,27 +101,27 @@ static VectorTarget vector(String vectorName, float weight, float[][] vector) { return new VectorTarget(vectorName, weight, vector); } - static Target combine(CombinationMethod combinationMethod, VectorTarget... vectorTargets) { + static NearVectorTarget combine(CombinationMethod combinationMethod, VectorTarget... vectorTargets) { return new CombinedVectorTarget(combinationMethod, Arrays.asList(vectorTargets)); } - static Target sum(VectorTarget... vectorTargets) { + static NearVectorTarget sum(VectorTarget... vectorTargets) { return combine(CombinationMethod.SUM, vectorTargets); } - static Target min(VectorTarget... vectorTargets) { + static NearVectorTarget min(VectorTarget... vectorTargets) { return combine(CombinationMethod.MIN, vectorTargets); } - static Target average(VectorTarget... vectorTargets) { + static NearVectorTarget average(VectorTarget... vectorTargets) { return combine(CombinationMethod.AVERAGE, vectorTargets); } - static Target relativeScore(VectorTarget... vectorTargets) { + static NearVectorTarget relativeScore(VectorTarget... vectorTargets) { return combine(CombinationMethod.RELATIVE_SCORE, vectorTargets); } - static Target manualWeights(VectorTarget... vectorTargets) { + static NearVectorTarget manualWeights(VectorTarget... vectorTargets) { return combine(CombinationMethod.MANUAL_WEIGHTS, vectorTargets); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/rerankers/DummyReranker.java b/src/main/java/io/weaviate/client6/v1/api/collections/rerankers/DummyReranker.java new file mode 100644 index 000000000..ce692ed3a --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/rerankers/DummyReranker.java @@ -0,0 +1,16 @@ +package io.weaviate.client6.v1.api.collections.rerankers; + +import io.weaviate.client6.v1.api.collections.Reranker; + +public record DummyReranker() implements Reranker { + + @Override + public Kind _kind() { + return Reranker.Kind.DUMMY; + } + + @Override + public Object _self() { + return this; + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java b/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java index 344409495..5702250b9 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java +++ b/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java @@ -4,6 +4,8 @@ import com.google.gson.Gson; import com.google.gson.GsonBuilder; +import com.google.gson.JsonElement; +import com.google.gson.JsonParser; import com.google.gson.ToNumberPolicy; import com.google.gson.reflect.TypeToken; @@ -85,6 +87,10 @@ public static final String serialize(Object value) { return serialize(value, TypeToken.get(value.getClass())); } + public static final JsonElement toJsonElement(String json) { + return JsonParser.parseString(json); + } + public static final String serialize(Object value, TypeToken typeToken) { return gson.toJson(value, typeToken.getType()); }