Skip to content

Commit 1247c3d

Browse files
committed
feat: add minimum_should_match operator to bm25 and hybrid queries
1 parent f85b97e commit 1247c3d

File tree

2 files changed

+114
-93
lines changed

2 files changed

+114
-93
lines changed

src/collections/query/types.ts

Lines changed: 13 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -84,9 +84,15 @@ export type Bm25QueryProperty<T> = {
8484
weight: number;
8585
};
8686

87+
/**
 * Options selecting the search operator applied to a BM25 (keyword) query.
 *
 * Modelled as a discriminated union on `operator`:
 * - `'and'`: `minimumMatch` may be omitted (it is still accepted so that existing callers keep compiling).
 * - `'or'`: `minimumMatch` is required.
 *
 * NOTE(review): `minimumMatch` is presumably the minimum number of query tokens an object must
 * match under the `'or'` operator — confirm against the server documentation.
 */
export type Bm25OperatorOptions =
  | {
      /** Use the AND operator. */
      operator: 'and';
      /** Optional for `'and'`; kept only for backward compatibility with earlier call sites. */
      minimumMatch?: number;
    }
  | {
      /** Use the OR operator. */
      operator: 'or';
      /** Required for `'or'`. */
      minimumMatch: number;
    };
91+
8792
/** Options specific to the keyword (BM25) part of a search. */
export type Bm25SearchOptions<T> = {
  /**
   * The collection properties the keyword search runs against. Each entry is either a
   * primitive property key or a `Bm25QueryProperty<T>` object.
   */
  queryProperties?: Array<PrimitiveKeys<T> | Bm25QueryProperty<T>>;
  /** The search operator options to apply to the keyword query. */
  operator?: Bm25OperatorOptions;
};
9197

9298
/** Base options available in the `query.bm25` method */
@@ -115,6 +121,7 @@ export type HybridSearchOptions<T> = {
115121
targetVector?: TargetVectorInputType;
116122
/** The specific vector to search for or a specific vector subsearch. If not specified, the query is vectorized and used in the similarity search. */
117123
vector?: NearVectorInputType | HybridNearTextSubSearch | HybridNearVectorSubSearch;
124+
bm25Operator?: Bm25OperatorOptions;
118125
};
119126

120127
/** Base options available in the `query.hybrid` method */
@@ -500,12 +507,12 @@ interface NearVector<T> {
500507
/** All the available methods on the `.query` namespace. */
501508
export interface Query<T>
502509
extends Bm25<T>,
503-
Hybrid<T>,
504-
NearImage<T>,
505-
NearMedia<T>,
506-
NearObject<T>,
507-
NearText<T>,
508-
NearVector<T> {
510+
Hybrid<T>,
511+
NearImage<T>,
512+
NearMedia<T>,
513+
NearObject<T>,
514+
NearText<T>,
515+
NearVector<T> {
509516
/**
510517
* Retrieve an object from the server by its UUID.
511518
*

src/collections/serialize/index.ts

Lines changed: 101 additions & 87 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,8 @@ import {
1515
NearThermalSearch,
1616
NearVector,
1717
NearVideoSearch,
18+
SearchOperatorOptions,
19+
SearchOperatorOptions_Operator,
1820
Targets,
1921
VectorForTarget,
2022
WeightsForTarget,
@@ -115,6 +117,7 @@ import {
115117
import {
116118
BaseHybridOptions,
117119
BaseNearOptions,
120+
Bm25OperatorOptions,
118121
Bm25Options,
119122
Bm25QueryProperty,
120123
Bm25SearchOptions,
@@ -385,10 +388,10 @@ class Aggregate {
385388
text:
386389
metric.kind === 'text'
387390
? AggregateRequest_Aggregation_Text.fromPartial({
388-
count: metric.count,
389-
topOccurencesLimit: metric.minOccurrences,
390-
topOccurences: metric.topOccurrences != undefined,
391-
})
391+
count: metric.count,
392+
topOccurencesLimit: metric.minOccurrences,
393+
topOccurences: metric.topOccurrences != undefined,
394+
})
392395
: undefined,
393396
})
394397
);
@@ -522,27 +525,27 @@ class Search {
522525
returnAllNonrefProperties: nonRefProperties === undefined,
523526
refProperties: refProperties
524527
? refProperties.map((property) => {
525-
return {
526-
referenceProperty: property.linkOn,
527-
properties: Search.queryProperties(property.returnProperties as any),
528-
metadata: Search.metadata(property.includeVector, property.returnMetadata),
529-
targetCollection: property.targetCollection ? property.targetCollection : '',
530-
};
531-
})
528+
return {
529+
referenceProperty: property.linkOn,
530+
properties: Search.queryProperties(property.returnProperties as any),
531+
metadata: Search.metadata(property.includeVector, property.returnMetadata),
532+
targetCollection: property.targetCollection ? property.targetCollection : '',
533+
};
534+
})
532535
: [],
533536
objectProperties: objectProperties
534537
? objectProperties.map((property) => {
535-
const objProps = property.properties.filter(
536-
(property) => typeof property !== 'string'
537-
) as unknown; // cannot get types to work currently :(
538-
return {
539-
propName: property.name,
540-
primitiveProperties: property.properties.filter(
541-
(property) => typeof property === 'string'
542-
) as string[],
543-
objectProperties: (objProps as QueryNested<T>[]).map(resolveObjectProperty),
544-
};
545-
})
538+
const objProps = property.properties.filter(
539+
(property) => typeof property !== 'string'
540+
) as unknown; // cannot get types to work currently :(
541+
return {
542+
propName: property.name,
543+
primitiveProperties: property.properties.filter(
544+
(property) => typeof property === 'string'
545+
) as string[],
546+
objectProperties: (objProps as QueryNested<T>[]).map(resolveObjectProperty),
547+
};
548+
})
546549
: [],
547550
};
548551
};
@@ -914,30 +917,30 @@ export class Serialize {
914917

915918
return args.supportsSingleGrouped
916919
? GenerativeSearch.fromPartial({
917-
single: opts?.singlePrompt
918-
? GenerativeSearch_Single.fromPartial({
919-
prompt: singlePrompt,
920-
debug: singlePromptDebug,
921-
queries: opts.config ? [await Serialize.generativeQuery(opts.config, singleOpts)] : undefined,
922-
})
923-
: undefined,
924-
grouped: opts?.groupedTask
925-
? GenerativeSearch_Grouped.fromPartial({
926-
task: groupedTask,
927-
queries: opts.config
928-
? [await Serialize.generativeQuery(opts.config, groupedOpts)]
929-
: undefined,
930-
properties: groupedProperties
931-
? TextArray.fromPartial({ values: groupedProperties as string[] })
932-
: undefined,
933-
})
934-
: undefined,
935-
})
920+
single: opts?.singlePrompt
921+
? GenerativeSearch_Single.fromPartial({
922+
prompt: singlePrompt,
923+
debug: singlePromptDebug,
924+
queries: opts.config ? [await Serialize.generativeQuery(opts.config, singleOpts)] : undefined,
925+
})
926+
: undefined,
927+
grouped: opts?.groupedTask
928+
? GenerativeSearch_Grouped.fromPartial({
929+
task: groupedTask,
930+
queries: opts.config
931+
? [await Serialize.generativeQuery(opts.config, groupedOpts)]
932+
: undefined,
933+
properties: groupedProperties
934+
? TextArray.fromPartial({ values: groupedProperties as string[] })
935+
: undefined,
936+
})
937+
: undefined,
938+
})
936939
: GenerativeSearch.fromPartial({
937-
singleResponsePrompt: singlePrompt,
938-
groupedResponseTask: groupedTask,
939-
groupedProperties: groupedProperties as string[],
940-
});
940+
singleResponsePrompt: singlePrompt,
941+
groupedResponseTask: groupedTask,
942+
groupedProperties: groupedProperties as string[],
943+
});
941944
};
942945

943946
public static isSinglePrompt(arg?: string | SinglePrompt): arg is SinglePrompt {
@@ -960,10 +963,20 @@ export class Serialize {
960963
});
961964
};
962965

966+
/**
 * Converts the user-facing BM25 operator options into a `SearchOperatorOptions` message.
 *
 * @param searchOperator The operator options supplied by the caller, if any.
 * @returns The message built with `SearchOperatorOptions.fromPartial`, or `undefined` when no
 *   operator options were supplied.
 */
private static bm25SearchOperator = (
  searchOperator?: Bm25OperatorOptions
): SearchOperatorOptions | undefined => {
  if (!searchOperator) {
    return undefined;
  }
  if (searchOperator.operator === 'and') {
    // `minimumOrTokensMatch` is only sent together with the OR operator, so it is left unset here.
    return SearchOperatorOptions.fromPartial({
      operator: SearchOperatorOptions_Operator.OPERATOR_AND,
    });
  }
  return SearchOperatorOptions.fromPartial({
    operator: SearchOperatorOptions_Operator.OPERATOR_OR,
    minimumOrTokensMatch: searchOperator.minimumMatch,
  });
};
974+
963975
/**
 * Builds the `BM25` message for a keyword search.
 *
 * @param args The query string together with the optional BM25 search options.
 * @returns The message produced by `BM25.fromPartial`, carrying the query, the serialized
 *   query properties and the serialized search operator.
 */
public static bm25Search = <T>(args: { query: string } & Bm25SearchOptions<T>): BM25 => {
  const { query, queryProperties, operator } = args;
  return BM25.fromPartial({
    query,
    properties: this.bm25QueryProperties(queryProperties),
    searchOperator: this.bm25SearchOperator(operator),
  });
};
969982

@@ -1005,13 +1018,13 @@ export class Serialize {
10051018
return vectorBytes !== undefined
10061019
? { vectorBytes, targetVectors, targets }
10071020
: {
1008-
targetVectors,
1009-
targets,
1010-
nearVector: NearVector.fromPartial({
1011-
vectorForTargets,
1012-
vectorPerTarget,
1013-
}),
1014-
};
1021+
targetVectors,
1022+
targets,
1023+
nearVector: NearVector.fromPartial({
1024+
vectorForTargets,
1025+
vectorPerTarget,
1026+
}),
1027+
};
10151028
} else if (Serialize.isHybridNearTextSearch(vector)) {
10161029
const { targetVectors, targets } = Serialize.targetVector(args);
10171030
return {
@@ -1074,6 +1087,7 @@ export class Serialize {
10741087
vectorBytes: vectorBytes,
10751088
vectorDistance: args.maxVectorDistance,
10761089
fusionType: fusionType(args.fusionType),
1090+
bm25SearchOperator: this.bm25SearchOperator(args.bm25Operator),
10771091
targetVectors,
10781092
targets,
10791093
nearText,
@@ -1165,17 +1179,17 @@ export class Serialize {
11651179
targetVectors,
11661180
moveAway: args.moveAway
11671181
? NearTextSearch_Move.fromPartial({
1168-
concepts: args.moveAway.concepts,
1169-
force: args.moveAway.force,
1170-
uuids: args.moveAway.objects,
1171-
})
1182+
concepts: args.moveAway.concepts,
1183+
force: args.moveAway.force,
1184+
uuids: args.moveAway.objects,
1185+
})
11721186
: undefined,
11731187
moveTo: args.moveTo
11741188
? NearTextSearch_Move.fromPartial({
1175-
concepts: args.moveTo.concepts,
1176-
force: args.moveTo.force,
1177-
uuids: args.moveTo.objects,
1178-
})
1189+
concepts: args.moveTo.concepts,
1190+
force: args.moveTo.force,
1191+
uuids: args.moveTo.objects,
1192+
})
11791193
: undefined,
11801194
});
11811195
};
@@ -1231,18 +1245,18 @@ export class Serialize {
12311245
} else if (TargetVectorInputGuards.isSingle(args.targetVector)) {
12321246
return args.supportsTargets
12331247
? {
1234-
targets: Targets.fromPartial({
1235-
targetVectors: [args.targetVector],
1236-
}),
1237-
}
1248+
targets: Targets.fromPartial({
1249+
targetVectors: [args.targetVector],
1250+
}),
1251+
}
12381252
: { targetVectors: [args.targetVector] };
12391253
} else if (TargetVectorInputGuards.isMulti(args.targetVector)) {
12401254
return args.supportsTargets
12411255
? {
1242-
targets: Targets.fromPartial({
1243-
targetVectors: args.targetVector,
1244-
}),
1245-
}
1256+
targets: Targets.fromPartial({
1257+
targetVectors: args.targetVector,
1258+
}),
1259+
}
12461260
: { targetVectors: args.targetVector };
12471261
} else {
12481262
return { targets: Serialize.targets(args.targetVector, args.supportsWeightsForTargets) };
@@ -1287,22 +1301,22 @@ export class Serialize {
12871301
.reduce((acc, { target, vector }) => {
12881302
return ArrayInputGuards.is2DArray(vector)
12891303
? acc.concat(
1290-
vector.map((v) => ({ name: target, vectorBytes: Serialize.vectorToBytes(v), vectors: [] }))
1291-
)
1304+
vector.map((v) => ({ name: target, vectorBytes: Serialize.vectorToBytes(v), vectors: [] }))
1305+
)
12921306
: acc.concat([{ name: target, vectorBytes: Serialize.vectorToBytes(vector), vectors: [] }]);
12931307
}, [] as VectorForTarget[]);
12941308
return args.targetVector !== undefined
12951309
? {
1296-
...Serialize.targetVector(args),
1297-
vectorForTargets,
1298-
}
1310+
...Serialize.targetVector(args),
1311+
vectorForTargets,
1312+
}
12991313
: {
1300-
targetVectors: undefined,
1301-
targets: Targets.fromPartial({
1302-
targetVectors: vectorForTargets.map((v) => v.name),
1303-
}),
1304-
vectorForTargets,
1305-
};
1314+
targetVectors: undefined,
1315+
targets: Targets.fromPartial({
1316+
targetVectors: vectorForTargets.map((v) => v.name),
1317+
}),
1318+
vectorForTargets,
1319+
};
13061320
} else {
13071321
const vectorPerTarget: Record<string, Uint8Array> = {};
13081322
Object.entries(args.vector).forEach(([k, v]) => {
@@ -1321,15 +1335,15 @@ export class Serialize {
13211335
} else {
13221336
return args.supportsTargets
13231337
? {
1324-
targets: Targets.fromPartial({
1325-
targetVectors: Object.keys(vectorPerTarget),
1326-
}),
1327-
vectorPerTarget,
1328-
}
1329-
: {
1338+
targets: Targets.fromPartial({
13301339
targetVectors: Object.keys(vectorPerTarget),
1331-
vectorPerTarget,
1332-
};
1340+
}),
1341+
vectorPerTarget,
1342+
}
1343+
: {
1344+
targetVectors: Object.keys(vectorPerTarget),
1345+
vectorPerTarget,
1346+
};
13331347
}
13341348
}
13351349
} else {

0 commit comments

Comments
 (0)