Skip to content

Commit d9811bc

Browse files
authored
Merge pull request #130 from weaviate/feat/target-vectors
2 parents 427a3a6 + c246b5e commit d9811bc

File tree

14 files changed

+656
-275
lines changed

14 files changed

+656
-275
lines changed

src/Weaviate.Client.Tests/Integration/TestCollectionAggregate.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ int expectedLen
267267
Assert.True(obj.Vectors.ContainsKey("default"));
268268
await collectionClient.Data.Insert(new { text = text2 });
269269

270-
var nearVector = obj.Vectors["default"].Cast<float>().ToArray();
270+
var nearVector = obj.Vectors["default"];
271271
var metrics = new[]
272272
{
273273
Metrics

src/Weaviate.Client.Tests/Integration/TestMultiVector.cs

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -155,20 +155,15 @@ public async Task Test_MultiVector_SelfProvided()
155155

156156
Assert.Equal(1UL, await collection.Count());
157157

158-
var objs = await collection.Query.NearVector(
159-
Vectors.Create(1f, 2f),
160-
targetVector: ["regular"]
161-
);
158+
var objs = await collection.Query.NearVector(new[] { 1f, 2f }, targetVector: ["regular"]);
162159
Assert.Single(objs);
163160

164161
objs = await collection.Query.NearVector(
165-
Vectors.Create(
166-
new[,]
167-
{
168-
{ 1f, 2f },
169-
{ 3f, 4f },
170-
}
171-
),
162+
new[,]
163+
{
164+
{ 1f, 2f },
165+
{ 3f, 4f },
166+
},
172167
targetVector: ["colbert"]
173168
);
174169
Assert.Single(objs);
@@ -181,7 +176,7 @@ public async Task Test_MultiVector_SelfProvided()
181176
// Assert.Single(objs);
182177

183178
objs = await collection.Query.NearVector(
184-
Vectors.Create(
179+
Vector.Create(
185180
"colbert",
186181
new[,]
187182
{
Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,343 @@
1+
namespace Weaviate.Client.Tests.Integration;
2+
3+
using System;
4+
using System.Collections.Generic;
5+
using System.Linq;
6+
using System.Threading.Tasks;
7+
using Weaviate.Client.Models;
8+
using Xunit;
9+
10+
public class TestNamedVectorMultiTarget : IntegrationTests
11+
{
12+
[Fact]
13+
public async Task Test_NamedVector_MultiTargetVectorPerTarget()
14+
{
15+
var dummy = await CollectionFactory();
16+
if (dummy.WeaviateVersion < Version.Parse("1.26.0"))
17+
{
18+
Assert.Skip("Named vectors are not supported in versions lower than 1.26.0");
19+
}
20+
21+
var collection = await CollectionFactory(
22+
vectorConfig: new[]
23+
{
24+
Configure.Vectors.SelfProvided(name: "first"),
25+
Configure.Vectors.SelfProvided(name: "second"),
26+
}
27+
);
28+
29+
var uuid1 = await collection.Data.Insert(
30+
new { },
31+
vectors: new Vectors
32+
{
33+
{ "first", new[] { 1f, 0f } },
34+
{ "second", new[] { 0f, 1f, 0f } },
35+
}
36+
);
37+
var uuid2 = await collection.Data.Insert(
38+
new { },
39+
vectors: new Vectors
40+
{
41+
{ "first", new[] { 0f, 1f } },
42+
{ "second", new[] { 1f, 0f, 0f } },
43+
}
44+
);
45+
46+
var objs = await collection.Query.NearVector(
47+
new Vectors { { "first", new[] { 1f, 0f } }, { "second", new[] { 1f, 0f, 0f } } },
48+
targetVector: ["first", "second"]
49+
);
50+
var ids = objs.Select(o => o.ID!.Value).OrderBy(x => x).ToList();
51+
var expected = new[] { uuid1, uuid2 }.OrderBy(x => x).ToList();
52+
Assert.Equal(expected, ids);
53+
}
54+
55+
public static TheoryData<Vectors, string[]> MultiInputCombinations =>
56+
new(
57+
(
58+
new Vectors
59+
{
60+
{ "first", new[] { 0f, 1f } },
61+
{
62+
"second",
63+
new float[,]
64+
{
65+
{ 1f, 0f, 0f },
66+
{ 0f, 0f, 1f },
67+
}
68+
},
69+
},
70+
new[] { "first", "second" }
71+
),
72+
(
73+
new Vectors
74+
{
75+
{ "first", new[] { 0f, 1f } },
76+
{
77+
"second",
78+
new float[,]
79+
{
80+
{ 1f, 0f, 0f },
81+
{ 0f, 0f, 1f },
82+
}
83+
},
84+
},
85+
new[] { "first", "second" }
86+
),
87+
(
88+
new Vectors
89+
{
90+
{
91+
"first",
92+
new float[,]
93+
{
94+
{ 0f, 1f },
95+
{ 0f, 1f },
96+
}
97+
},
98+
{ "second", new[] { 1f, 0f, 0f } },
99+
},
100+
new[] { "first", "second" }
101+
),
102+
(
103+
new Vectors
104+
{
105+
{
106+
"first",
107+
new float[,]
108+
{
109+
{ 0f, 1f },
110+
{ 0f, 1f },
111+
}
112+
},
113+
{
114+
"second",
115+
new float[,]
116+
{
117+
{ 1f, 0f, 0f },
118+
{ 0f, 0f, 1f },
119+
}
120+
},
121+
},
122+
new[] { "first", "second" }
123+
),
124+
(
125+
new Vectors
126+
{
127+
{
128+
"first",
129+
new float[,]
130+
{
131+
{ 0f, 1f },
132+
{ 0f, 1f },
133+
}
134+
},
135+
{
136+
"second",
137+
new float[,]
138+
{
139+
{ 1f, 0f, 0f },
140+
{ 0f, 0f, 1f },
141+
}
142+
},
143+
},
144+
new[] { "second", "first" }
145+
)
146+
);
147+
148+
[Theory]
149+
[MemberData(nameof(MultiInputCombinations))]
150+
public async Task Test_SameTargetVector_MultipleInputCombinations(
151+
Vectors nearVector,
152+
string[] targetVector
153+
)
154+
{
155+
var dummy = await CollectionFactory();
156+
if (dummy.WeaviateVersion < Version.Parse("1.27.0"))
157+
{
158+
Assert.Skip("Multi vector per target is not supported in versions lower than 1.27.0");
159+
}
160+
161+
var collection = await CollectionFactory(
162+
properties: Array.Empty<Property>(),
163+
vectorConfig: new[]
164+
{
165+
Configure.Vectors.SelfProvided(name: "first"),
166+
Configure.Vectors.SelfProvided(name: "second"),
167+
}
168+
);
169+
170+
var uuid1 = await collection.Data.Insert(
171+
new { },
172+
vectors: new Vectors
173+
{
174+
{ "first", new[] { 1f, 0f } },
175+
{ "second", new[] { 0f, 1f, 0f } },
176+
}
177+
);
178+
var uuid2 = await collection.Data.Insert(
179+
new { },
180+
vectors: new Vectors
181+
{
182+
{ "first", new[] { 0f, 1f } },
183+
{ "second", new[] { 1f, 0f, 0f } },
184+
}
185+
);
186+
187+
var objs = await collection.Query.NearVector(nearVector, targetVector: targetVector);
188+
var ids = objs.Select(o => o.ID!.Value).OrderBy(x => x).ToList();
189+
var expected = new[] { uuid2, uuid1 }.OrderBy(x => x).ToList();
190+
Assert.Equal(expected, ids);
191+
}
192+
193+
public static IEnumerable<object[]> MultiTargetVectorsWithDistances =>
194+
new List<object[]>
195+
{
196+
new object[] { TargetVectors.Sum(["first", "second"]), new float[] { 1.0f, 3.0f } },
197+
new object[]
198+
{
199+
TargetVectors.ManualWeights(("first", 1.0f), ("second", [1.0f, 1.0f])),
200+
new float[] { 1.0f, 3.0f },
201+
},
202+
new object[]
203+
{
204+
TargetVectors.ManualWeights(("first", 1.0f), ("second", [1.0f, 2.0f])),
205+
new float[] { 2.0f, 4.0f },
206+
},
207+
new object[]
208+
{
209+
TargetVectors.ManualWeights(("second", [1.0f, 2.0f]), ("first", 1.0f)),
210+
new float[] { 2.0f, 4.0f },
211+
},
212+
};
213+
214+
[Theory]
215+
[MemberData(nameof(MultiTargetVectorsWithDistances))]
216+
public async Task Test_SameTargetVector_MultipleInput(
217+
TargetVectors targetVector,
218+
float[] expectedDistances
219+
)
220+
{
221+
var dummy = await CollectionFactory();
222+
if (dummy.WeaviateVersion < Version.Parse("1.26.0"))
223+
{
224+
Assert.Skip("Named vectors are not supported in versions lower than 1.26.0");
225+
}
226+
227+
var collection = await CollectionFactory(
228+
properties: Array.Empty<Property>(),
229+
vectorConfig: new[]
230+
{
231+
Configure.Vectors.SelfProvided(name: "first"),
232+
Configure.Vectors.SelfProvided(name: "second"),
233+
}
234+
);
235+
236+
var inserts = (
237+
await collection.Data.InsertMany(
238+
new BatchInsertRequest<object>(
239+
Data: new { },
240+
Vectors: new Vectors
241+
{
242+
{ "first", new[] { 1f, 0f } },
243+
{ "second", new[] { 0f, 1f, 0f } },
244+
}
245+
),
246+
new BatchInsertRequest<object>(
247+
Data: new { },
248+
Vectors: new Vectors
249+
{
250+
{ "first", new[] { 0f, 1f } },
251+
{ "second", new[] { 1f, 0f, 0f } },
252+
}
253+
)
254+
)
255+
).ToList();
256+
257+
var results = (
258+
await collection.Query.FetchObjects(returnMetadata: MetadataOptions.All)
259+
).ToList();
260+
261+
var uuid1 = results[0].ID!.Value;
262+
var uuid2 = results[1].ID!.Value;
263+
264+
var objs = await collection.Query.NearVector(
265+
new Vectors
266+
{
267+
{ "first", new[] { 0f, 1f } },
268+
{
269+
"second",
270+
new[,]
271+
{
272+
{ 1f, 0f, 0f },
273+
{ 0f, 0f, 1f },
274+
}
275+
},
276+
},
277+
targetVector: targetVector,
278+
returnMetadata: MetadataOptions.All
279+
);
280+
var ids = objs.Select(o => o.ID!.Value).OrderBy(x => x).ToList();
281+
var expected = new[] { uuid1, uuid2 }.OrderBy(x => x).ToList();
282+
Assert.Equal(expected, ids);
283+
Assert.Equal(expectedDistances.Length, objs.Count());
284+
Assert.Equal(expectedDistances[0], objs.ElementAt(0).Metadata.Distance);
285+
Assert.Equal(expectedDistances[1], objs.ElementAt(1).Metadata.Distance);
286+
}
287+
288+
public static IEnumerable<object[]> MultiTargetVectors =>
289+
new List<object[]>
290+
{
291+
new object[] { (TargetVectors)new[] { "first", "second" } },
292+
new object[] { TargetVectors.Sum(new[] { "first", "second" }) },
293+
new object[] { TargetVectors.Minimum(new[] { "first", "second" }) },
294+
new object[] { TargetVectors.Average(new[] { "first", "second" }) },
295+
new object[] { TargetVectors.ManualWeights(("first", 1.2), ("second", 0.7)) },
296+
new object[] { TargetVectors.RelativeScore(("first", 1.2), ("second", 0.7)) },
297+
};
298+
299+
[Theory]
300+
[MemberData(nameof(MultiTargetVectors))]
301+
public async Task Test_NamedVector_MultiTarget(string[] targetVector)
302+
{
303+
var dummy = await CollectionFactory();
304+
if (dummy.WeaviateVersion < Version.Parse("1.26.0"))
305+
{
306+
Assert.Skip("Named vectors are not supported in versions lower than 1.26.0");
307+
}
308+
309+
var collection = await CollectionFactory(
310+
properties: Array.Empty<Property>(),
311+
vectorConfig: new[]
312+
{
313+
Configure.Vectors.SelfProvided(name: "first"),
314+
Configure.Vectors.SelfProvided(name: "second"),
315+
}
316+
);
317+
318+
var uuid1 = await collection.Data.Insert(
319+
new { },
320+
vectors: new Vectors
321+
{
322+
{ "first", new[] { 1f, 0f, 0f } },
323+
{ "second", new[] { 0f, 1f, 0f } },
324+
}
325+
);
326+
var uuid2 = await collection.Data.Insert(
327+
new { },
328+
vectors: new Vectors
329+
{
330+
{ "first", new[] { 0f, 1f, 0f } },
331+
{ "second", new[] { 1f, 0f, 0f } },
332+
}
333+
);
334+
335+
var objs = await collection.Query.NearVector(
336+
new[] { 1f, 0f, 0f },
337+
targetVector: targetVector
338+
);
339+
var ids = objs.Select(o => o.ID!.Value).OrderBy(x => x).ToList();
340+
var expected = new[] { uuid1, uuid2 }.OrderBy(x => x).ToList();
341+
Assert.Equal(expected, ids);
342+
}
343+
}

0 commit comments

Comments
 (0)