|
6 | 6 | import static io.qdrant.client.QueryFactory.fusion; |
7 | 7 | import static io.qdrant.client.QueryFactory.nearest; |
8 | 8 | import static io.qdrant.client.QueryFactory.orderBy; |
| 9 | +import static io.qdrant.client.QueryFactory.sample; |
9 | 10 | import static io.qdrant.client.TargetVectorFactory.targetVector; |
10 | 11 | import static io.qdrant.client.ValueFactory.value; |
11 | 12 | import static io.qdrant.client.VectorFactory.vector; |
|
38 | 39 | import io.qdrant.client.grpc.Points.PointsUpdateOperation.ClearPayload; |
39 | 40 | import io.qdrant.client.grpc.Points.PointsUpdateOperation.UpdateVectors; |
40 | 41 | import io.qdrant.client.grpc.Points.PrefetchQuery; |
| 42 | +import io.qdrant.client.grpc.Points.QueryPointGroups; |
41 | 43 | import io.qdrant.client.grpc.Points.QueryPoints; |
42 | 44 | import io.qdrant.client.grpc.Points.RecommendPointGroups; |
43 | 45 | import io.qdrant.client.grpc.Points.RecommendPoints; |
44 | 46 | import io.qdrant.client.grpc.Points.RetrievedPoint; |
| 47 | +import io.qdrant.client.grpc.Points.Sample; |
45 | 48 | import io.qdrant.client.grpc.Points.ScoredPoint; |
46 | 49 | import io.qdrant.client.grpc.Points.ScrollPoints; |
47 | 50 | import io.qdrant.client.grpc.Points.ScrollResponse; |
|
50 | 53 | import io.qdrant.client.grpc.Points.UpdateResult; |
51 | 54 | import io.qdrant.client.grpc.Points.UpdateStatus; |
52 | 55 | import io.qdrant.client.grpc.Points.Vectors; |
| 56 | +import java.util.Arrays; |
53 | 57 | import java.util.List; |
54 | 58 | import java.util.concurrent.ExecutionException; |
55 | 59 | import java.util.concurrent.TimeUnit; |
@@ -596,7 +600,7 @@ public void batchPointUpdate() throws ExecutionException, InterruptedException { |
596 | 600 | createAndSeedCollection(testName); |
597 | 601 |
|
598 | 602 | List<PointsUpdateOperation> operations = |
599 | | - List.of( |
| 603 | + Arrays.asList( |
600 | 604 | PointsUpdateOperation.newBuilder() |
601 | 605 | .setClearPayload( |
602 | 606 | ClearPayload.newBuilder() |
@@ -757,6 +761,58 @@ public void queryWithPrefetchAndFusion() throws ExecutionException, InterruptedE |
757 | 761 | assertEquals(2, points.size()); |
758 | 762 | } |
759 | 763 |
|
| 764 | + @Test |
| 765 | + public void queryWithSampling() throws ExecutionException, InterruptedException { |
| 766 | + createAndSeedCollection(testName); |
| 767 | + |
| 768 | + List<ScoredPoint> points = |
| 769 | + client |
| 770 | + .queryAsync( |
| 771 | + QueryPoints.newBuilder() |
| 772 | + .setCollectionName(testName) |
| 773 | + .setQuery(sample(Sample.Random)) |
| 774 | + .setLimit(1) |
| 775 | + .build()) |
| 776 | + .get(); |
| 777 | + |
| 778 | + assertEquals(1, points.size()); |
| 779 | + } |
| 780 | + |
| 781 | + @Test |
| 782 | + public void queryGroups() throws ExecutionException, InterruptedException { |
| 783 | + createAndSeedCollection(testName); |
| 784 | + |
| 785 | + client |
| 786 | + .upsertAsync( |
| 787 | + testName, |
| 788 | + ImmutableList.of( |
| 789 | + PointStruct.newBuilder() |
| 790 | + .setId(id(10)) |
| 791 | + .setVectors(VectorsFactory.vectors(30f, 31f)) |
| 792 | + .putAllPayload(ImmutableMap.of("foo", value("hello"))) |
| 793 | + .build())) |
| 794 | + .get(); |
| 795 | + // 3 points in total, 2 with "foo" = "hello" and 1 with "foo" = "goodbye" |
| 796 | + |
| 797 | + List<PointGroup> groups = |
| 798 | + client |
| 799 | + .queryGroupsAsync( |
| 800 | + QueryPointGroups.newBuilder() |
| 801 | + .setCollectionName(testName) |
| 802 | + .setQuery(nearest(ImmutableList.of(10.4f, 11.4f))) |
| 803 | + .setGroupBy("foo") |
| 804 | + .setGroupSize(2) |
| 805 | + .setLimit(10) |
| 806 | + .build()) |
| 807 | + .get(); |
| 808 | + |
| 809 | + assertEquals(2, groups.size()); |
| 810 | + // A group with 2 hits because of 2 points with "foo" = "hello" |
| 811 | + assertEquals(1, groups.stream().filter(g -> g.getHitsCount() == 2).count()); |
| 812 | + // A group with 1 hit because of 1 point with "foo" = "goodbye" |
| 813 | + assertEquals(1, groups.stream().filter(g -> g.getHitsCount() == 1).count()); |
| 814 | + } |
| 815 | + |
760 | 816 | private void createAndSeedCollection(String collectionName) |
761 | 817 | throws ExecutionException, InterruptedException { |
762 | 818 | CreateCollection request = |
|
0 commit comments