Skip to content

Commit 55f958c

Browse files
committed
force VSIM payloads as vectors, unless named
1 parent c310c6a commit 55f958c

File tree

5 files changed

+78
-17
lines changed

5 files changed

+78
-17
lines changed

src/NRedisStack/Search/HybridSearchQuery.Command.cs

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ internal int GetOwnArgsCount(IReadOnlyDictionary<string, object>? parameters)
2424
{
2525
int count = _search.GetOwnArgsCount() + _vsim.GetOwnArgsCount(); // note index is not included here
2626

27-
2827
if (_combiner is not null)
2928
{
3029
count += 1 + _combiner.GetOwnArgsCount();
@@ -126,9 +125,15 @@ internal int GetOwnArgsCount(IReadOnlyDictionary<string, object>? parameters)
126125

127126
if (_pagingOffset >= 0) count += 3;
128127

128+
bool forceVsimParam = _vsim.VectorData is { ForceParameter: true };
129129
if (parameters is not null)
130130
{
131131
count += (parameters.Count + 1) * 2;
132+
if (forceVsimParam) count += 2;
133+
}
134+
else if (forceVsimParam)
135+
{
136+
count += 4;
132137
}
133138

134139
if (_explainScore) count++;
@@ -144,10 +149,25 @@ internal int GetOwnArgsCount(IReadOnlyDictionary<string, object>? parameters)
144149
return count;
145150
}
146151

152+
/// <summary>
/// Picks a parameter name for the vector payload that does not collide with any caller-supplied parameter.
/// </summary>
/// <param name="parameters">The caller-supplied parameters, or <c>null</c> when there are none.</param>
/// <returns>
/// <c>"v"</c> when that name is unused; otherwise the first unused name of the form <c>"v0"</c>, <c>"v1"</c>, ...
/// </returns>
/// <exception cref="InvalidOperationException">None of the candidate names was reported as free.</exception>
private static string InventVectorParameterName(IReadOnlyDictionary<string, object>? parameters)
{
    const string BaseName = "v";

    if (parameters is not null && parameters.ContainsKey(BaseName))
    {
        // Count + 1 suffixed candidates are probed against a dictionary holding Count keys,
        // so a consistent dictionary must leave at least one of them unused.
        int lastSuffix = parameters.Count;
        int suffix = 0;
        while (suffix <= lastSuffix)
        {
            string candidate = $"{BaseName}{suffix}";
            if (!parameters.ContainsKey(candidate))
            {
                return candidate;
            }

            suffix++;
        }

        // Only reachable when the dictionary reports inconsistent contents.
        throw new InvalidOperationException("Unable to create parameter for vector");
    }

    return BaseName;
}
165+
147166
internal void AddOwnArgs(List<object> args, IReadOnlyDictionary<string, object>? parameters)
148167
{
149168
_search.AddOwnArgs(args);
150-
_vsim.AddOwnArgs(args);
169+
string? forcedVsimName = _vsim.VectorData is { ForceParameter: true } ? InventVectorParameterName(parameters) : null;
170+
_vsim.AddOwnArgs(args, forcedVsimName);
151171

152172
if (_combiner is not null)
153173
{
@@ -313,7 +333,14 @@ static void AddApply(in ApplyExpression expr, List<object> args)
313333
if (parameters is not null)
314334
{
315335
args.Add("PARAMS");
316-
args.Add(parameters.Count * 2);
336+
var pairs = parameters.Count;
337+
if (forcedVsimName is not null) pairs++;
338+
args.Add(pairs * 2);
339+
if (forcedVsimName is not null)
340+
{
341+
args.Add(forcedVsimName);
342+
args.Add(_vsim.VectorData!.AsRedisValue());
343+
}
317344
if (parameters is Dictionary<string, object> typed)
318345
{
319346
foreach (var entry in typed) // avoid allocating enumerator
@@ -339,6 +366,13 @@ static object Wrap(object value)
339366
_ => value,
340367
};
341368
}
369+
else if (forcedVsimName is not null)
370+
{
371+
args.Add("PARAMS");
372+
args.Add(2);
373+
args.Add(forcedVsimName);
374+
args.Add(_vsim.VectorData!.AsRedisValue());
375+
}
342376

343377
if (_explainScore) args.Add("EXPLAINSCORE");
344378
if (_timeout > TimeSpan.Zero)

src/NRedisStack/Search/HybridSearchQuery.VectorSearchConfig.cs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System.Runtime.CompilerServices;
2+
using StackExchange.Redis;
23

34
namespace NRedisStack.Search;
45

@@ -103,13 +104,20 @@ internal int GetOwnArgsCount()
103104
return count;
104105
}
105106

106-
internal void AddOwnArgs(List<object> args)
107+
internal void AddOwnArgs(List<object> args, string? forcedParameterName)
107108
{
108109
if (HasValue)
109110
{
110111
args.Add("VSIM");
111112
args.Add(_fieldName);
112-
args.Add(_vectorData.GetSingleArg());
113+
if (forcedParameterName is not null)
114+
{
115+
args.Add("$" + forcedParameterName);
116+
}
117+
else
118+
{
119+
args.Add(_vectorData.GetSingleArg());
120+
}
113121

114122
_method?.AddOwnArgs(args);
115123
if (_filter != null)

src/NRedisStack/Search/VectorData.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ public override void Dispose()
2828
if (tmp is not null) ArrayPool<byte>.Shared.Return(tmp);
2929
}
3030
/// <summary>
/// Exposes the first <c>byteLength</c> bytes of <c>Array</c> as a <see cref="RedisValue"/>.
/// </summary>
public override RedisValue AsRedisValue()
{
    var payload = new ReadOnlyMemory<byte>(Array, 0, byteLength);
    return (RedisValue)payload;
}
31+
32+
/// <summary>
/// Always <c>true</c> for this payload kind.
/// </summary>
/// <remarks>
/// In 8.4 forcing is not required, but this is likely to change, so the problem is avoided immediately.
/// The narrower rule that was left disabled would force only when the payload is non-empty and its
/// first byte is '$'.
/// </remarks>
internal override bool ForceParameter
{
    get { return true; }
}
3134
}
3235

3336
public abstract void Dispose();
@@ -102,6 +105,9 @@ private protected VectorData()
102105
/// <summary>
/// Vector payload that wraps the supplied <see cref="ReadOnlyMemory{T}"/> of bytes as-is.
/// </summary>
private sealed class VectorDataRaw(ReadOnlyMemory<byte> bytes) : VectorData
{
    /// <summary>
    /// Always <c>true</c> for this payload kind.
    /// </summary>
    /// <remarks>
    /// In 8.4 forcing is not required, but this is likely to change, so the problem is avoided immediately.
    /// The narrower rule that was left disabled would force only when the payload is non-empty and its
    /// first byte is '$'.
    /// </remarks>
    internal override bool ForceParameter
    {
        get { return true; }
    }

    /// <summary>Converts the wrapped bytes to a <see cref="RedisValue"/>.</summary>
    public override RedisValue AsRedisValue()
    {
        return (RedisValue)bytes;
    }
}
106112

107113
private sealed class VectorParameter : VectorData
@@ -122,4 +128,6 @@ public VectorParameter(string name)
122128

123129
/// <summary>
/// Throws to signal that the current operation is unavailable on big-endian CPUs.
/// </summary>
/// <exception cref="PlatformNotSupportedException">Always thrown.</exception>
private protected static void ThrowBigEndian()
{
    throw new PlatformNotSupportedException("Big-endian CPUs are not currently supported for this operation");
}
131+
132+
/// <summary>
/// Indicates whether a hybrid search query should send this vector through an invented
/// <c>PARAMS</c> entry (referenced as <c>$name</c> in <c>VSIM</c>) rather than inline.
/// Defaults to <c>false</c>; derived payload types may override it.
/// </summary>
internal virtual bool ForceParameter
{
    get { return false; }
}
125133
}

tests/NRedisStack.Tests/Search/HybridSearchIntegrationTests.cs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ public async Task TestSearch(string endpointId)
137137
public enum Scenario
138138
{
139139
Simple,
140+
ForcedDollarValue,
141+
ForcedDollarValueWithParameters,
140142
NoSort,
141143
Apply,
142144
LinearNoScore,
@@ -232,6 +234,10 @@ public async Task TestSearchScenarios(string endpointId, Scenario scenario)
232234

233235
var hash = (await api.DB.HashGetAllAsync($"{api.Index}_entry2")).ToDictionary(k => k.Name, v => v.Value);
234236
var vec = (byte[])hash["vector1"]!;
237+
if (scenario is Scenario.ForcedDollarValue or Scenario.ForcedDollarValueWithParameters)
238+
{
239+
vec[0] = (byte)'$';
240+
}
235241
var text = (string)hash["text1"]!;
236242
string[] fields = ["@text1", HybridSearchQuery.Fields.Key, HybridSearchQuery.Fields.Score];
237243
var query = new HybridSearchQuery()
@@ -242,7 +248,7 @@ public async Task TestSearchScenarios(string endpointId, Scenario scenario)
242248
#pragma warning disable CS0612
243249
query = scenario switch
244250
{
245-
Scenario.Simple => query,
251+
Scenario.Simple or Scenario.ForcedDollarValue => query,
246252
Scenario.SearchWithAlias => query.Search(new(text, scoreAlias: "score_alias")),
247253
Scenario.SearchWithSimpleScorer => query.Search(new(text, scorer: Scorer.TfIdf)),
248254
Scenario.SearchWithComplexScorer => query.Search(new(text, scorer: Scorer.BM25StdTanh(7))),
@@ -293,7 +299,7 @@ public async Task TestSearchScenarios(string endpointId, Scenario scenario)
293299
Scenario.ReduceMulti => query.GroupBy("@tag1").Reduce(Reducers.Count().As("count"),
294300
Reducers.Min("@numeric1").As("min"), Reducers.Max("@numeric1").As("max")),
295301
Scenario.ParamVsim => query.VectorSearch("@vector1", VectorData.Parameter("$v")),
296-
Scenario.ParamSearch => query.Search("$q"),
302+
Scenario.ParamSearch or Scenario.ForcedDollarValueWithParameters => query.Search("$q"),
297303
Scenario.ParamPreFilter =>
298304
query.VectorSearch(new("@vector1", VectorData.Raw(vec), filter: "@numeric1!=$n")),
299305
Scenario.ParamPostFilter => query.ReturnFields([.. fields, "@numeric1"]).Filter("@numeric1!=$n"),
@@ -309,7 +315,7 @@ public async Task TestSearchScenarios(string endpointId, Scenario scenario)
309315
Scenario.ParamPostFilter or Scenario.ParamPreFilter => new Dictionary<string, object>() { ["n"] = 42 },
310316
Scenario.ParamMultiPostFilter or Scenario.ParamMultiPreFilter => new Dictionary<string, object>()
311317
{ ["n"] = 42, ["t"] = "foo" },
312-
Scenario.ParamSearch => new Dictionary<string, object>() { ["q"] = text },
318+
Scenario.ParamSearch or Scenario.ForcedDollarValueWithParameters => new Dictionary<string, object>() { ["q"] = text },
313319
Scenario.ParamVsim => new Dictionary<string, object>() { ["v"] = VectorData.Raw(vec) },
314320
_ => null,
315321
};

tests/NRedisStack.Tests/Search/HybridSearchUnitTests.cs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ public void BasicVectorSearch()
156156
byte[] blob = [];
157157
query.VectorSearch("vfield", VectorData.Raw(blob));
158158

159-
object[] expected = [Index, "VSIM", "vfield", ""];
159+
object[] expected = [Index, "VSIM", "vfield", "$v", "PARAMS", 2, "v", ""];
160160
Assert.Equivalent(expected, GetArgs(query));
161161
}
162162

@@ -170,7 +170,7 @@ public void BasicNonZeroLengthVectorSearch()
170170
HybridSearchQuery query = new();
171171
query.VectorSearch("vfield", SomeRandomDataHere);
172172

173-
object[] expected = [Index, "VSIM", "vfield", SomeRandomVectorValue];
173+
object[] expected = [Index, "VSIM", "vfield", "$v", "PARAMS", 2, "v", SomeRandomVectorValue];
174174
Assert.Equivalent(expected, GetArgs(query));
175175
}
176176

@@ -189,7 +189,7 @@ public void BasicVectorSearch_WithKNN(bool withScoreAlias, bool withDistanceAlia
189189
query.VectorSearch(searchConfig);
190190

191191
object[] expected =
192-
[Index, "VSIM", "vField", SomeRandomVectorValue, "KNN", withDistanceAlias ? 4 : 2, "K", 10];
192+
[Index, "VSIM", "vField", "$v", "KNN", withDistanceAlias ? 4 : 2, "K", 10];
193193
if (withDistanceAlias)
194194
{
195195
expected = [.. expected, "YIELD_DISTANCE_AS", "my_distance_alias"];
@@ -200,6 +200,7 @@ public void BasicVectorSearch_WithKNN(bool withScoreAlias, bool withDistanceAlia
200200
expected = [.. expected, "YIELD_SCORE_AS", "my_score_alias"];
201201
}
202202

203+
expected = [..expected, "PARAMS", 2, "v", SomeRandomVectorValue];
203204
Assert.Equivalent(expected, GetArgs(query));
204205
}
205206

@@ -221,7 +222,7 @@ public void BasicVectorSearch_WithKNN_WithEF(bool withScoreAlias, bool withDista
221222

222223
object[] expected =
223224
[
224-
Index, "VSIM", "vfield", SomeRandomVectorValue, "KNN", withDistanceAlias ? 6 : 4, "K", 16,
225+
Index, "VSIM", "vfield", "$v", "KNN", withDistanceAlias ? 6 : 4, "K", 16,
225226
"EF_RUNTIME", 100
226227
];
227228
if (withDistanceAlias)
@@ -234,6 +235,8 @@ public void BasicVectorSearch_WithKNN_WithEF(bool withScoreAlias, bool withDista
234235
expected = [.. expected, "YIELD_SCORE_AS", "my_score_alias"];
235236
}
236237

238+
expected = [.. expected, "PARAMS", 2, "v", SomeRandomVectorValue];
239+
237240
Assert.Equivalent(expected, GetArgs(query));
238241
}
239242

@@ -253,7 +256,7 @@ public void BasicVectorSearch_WithRange(bool withScoreAlias, bool withDistanceAl
253256

254257
object[] expected =
255258
[
256-
Index, "VSIM", "vfield", SomeRandomVectorValue, "RANGE", withDistanceAlias ? 4 : 2, "RADIUS",
259+
Index, "VSIM", "vfield", "$v", "RANGE", withDistanceAlias ? 4 : 2, "RADIUS",
257260
4.2
258261
];
259262
if (withDistanceAlias)
@@ -266,6 +269,7 @@ public void BasicVectorSearch_WithRange(bool withScoreAlias, bool withDistanceAl
266269
expected = [.. expected, "YIELD_SCORE_AS", "my_score_alias"];
267270
}
268271

272+
expected = [.. expected, "PARAMS", 2, "v", SomeRandomVectorValue];
269273
Assert.Equivalent(expected, GetArgs(query));
270274
}
271275

@@ -286,7 +290,7 @@ public void BasicVectorSearch_WithRange_WithEpsilon(bool withScoreAlias, bool wi
286290

287291
object[] expected =
288292
[
289-
Index, "VSIM", "vfield", SomeRandomVectorValue, "RANGE", withDistanceAlias ? 6 : 4, "RADIUS",
293+
Index, "VSIM", "vfield", "$v", "RANGE", withDistanceAlias ? 6 : 4, "RADIUS",
290294
4.2, "EPSILON", 0.06
291295
];
292296
if (withDistanceAlias)
@@ -299,6 +303,7 @@ public void BasicVectorSearch_WithRange_WithEpsilon(bool withScoreAlias, bool wi
299303
expected = [.. expected, "YIELD_SCORE_AS", "my_score_alias"];
300304
}
301305

306+
expected = [.. expected, "PARAMS", 2, "v", SomeRandomVectorValue];
302307
Assert.Equivalent(expected, GetArgs(query));
303308
}
304309

@@ -310,7 +315,7 @@ public void BasicVectorSearch_WithFilter_NoPolicy()
310315

311316
object[] expected =
312317
[
313-
Index, "VSIM", "vfield", SomeRandomVectorValue, "FILTER", "@foo:bar"
318+
Index, "VSIM", "vfield", "$v", "FILTER", "@foo:bar", "PARAMS", 2, "v", SomeRandomVectorValue
314319
];
315320

316321
Assert.Equivalent(expected, GetArgs(query));
@@ -702,12 +707,12 @@ public void MakeMeOneWithEverything()
702707
[
703708
Index, "SEARCH", "foo", "SCORER", "BM25STD.TANH", "BM25STD_TANH_FACTOR", 5, "YIELD_SCORE_AS",
704709
"text_score_alias", "VSIM", "bar",
705-
"AACAPwAAAEAAAEBA", "KNN", 6, "K", 10, "EF_RUNTIME", 100, "YIELD_DISTANCE_AS", "vector_distance_alias", "FILTER",
710+
"$v", "KNN", 6, "K", 10, "EF_RUNTIME", 100, "YIELD_DISTANCE_AS", "vector_distance_alias", "FILTER",
706711
"@foo:bar", "YIELD_SCORE_AS", "vector_score_alias", "COMBINE", "RRF", 4, "WINDOW", 10, "CONSTANT", 0.5,
707712
"YIELD_SCORE_AS", "my_combined_alias", "LOAD", 2, "field1", "field2", "GROUPBY", 1, "field1", "REDUCE",
708713
"QUANTILE", 2, "@field3", 0.5, "AS", "reducer_alias", "APPLY", "@field1 + @field2", "AS", "apply_alias",
709714
"SORTBY", 3, "field1", "field2", "DESC", "FILTER", "@field1:bar", "LIMIT", 12, 54,
710-
"PARAMS", 4, "x", 42, "y", "abc",
715+
"PARAMS", 6, "v", "AACAPwAAAEAAAEBA", "x", 42, "y", "abc",
711716
"EXPLAINSCORE", "TIMEOUT", 1000,
712717
"WITHCURSOR", "COUNT", 10, "MAXIDLE", 10000
713718
];

0 commit comments

Comments
 (0)