Skip to content

Commit 977c2c3

Browse files
committed
feat: support multi2vec-nvidia vectorizer
1 parent bbac850 commit 977c2c3

File tree

4 files changed

+190
-53
lines changed

4 files changed

+190
-53
lines changed

src/collections/config/types/vectorizer.ts

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ type Text2VecPalmVectorizer = 'text2vec-palm';
1919

2020
export type Vectorizer =
2121
| 'img2vec-neural'
22+
| 'multi2vec-nvidia'
2223
| 'multi2vec-clip'
2324
| 'multi2vec-cohere'
2425
| 'multi2vec-bind'
@@ -65,6 +66,32 @@ export type Multi2VecField = {
6566
weight?: number;
6667
};
6768

69+
/** The configuration for multi-media vectorization using the NVIDIA module.
70+
*
71+
* See the [documentation](https://weaviate.io/developers/weaviate/model-providers/nvidia/embeddings-multimodal) for detailed usage.
72+
*/
73+
export type Multi2VecNvidiaConfig = {
74+
/** The model to use. Defaults to `None`, which uses the server-defined default. */
75+
model?: string;
76+
/** The base URL where API requests should go. */
77+
baseURL?: string;
78+
/** Whether to apply truncation. */
79+
truncation?: boolean;
80+
/** Format in which the embeddings are encoded. Defaults to `None`, so the embeddings are represented as a list of floating-point numbers. */
81+
output_encoding?: string;
82+
/** The image fields used when vectorizing. */
83+
imageFields?: string[];
84+
/** The text fields used when vectorizing. */
85+
textFields?: string[];
86+
/** The weights of the fields used for vectorization. */
87+
weights?: {
88+
/** The weights of the image fields. */
89+
imageFields?: number[];
90+
/** The weights of the text fields. */
91+
textFields?: number[];
92+
};
93+
};
94+
6895
/** The configuration for multi-media vectorization using the CLIP module.
6996
*
7097
* See the [documentation](https://weaviate.io/developers/weaviate/model-providers/transformers/embeddings-multimodal) for detailed usage.
@@ -569,6 +596,8 @@ export type VectorizerConfig =
569596

570597
export type VectorizerConfigType<V> = V extends 'img2vec-neural'
571598
? Img2VecNeuralConfig | undefined
599+
: V extends 'multi2vec-nvidia'
600+
? Multi2VecNvidiaConfig | undefined
572601
: V extends 'multi2vec-clip'
573602
? Multi2VecClipConfig | undefined
574603
: V extends 'multi2vec-cohere'

src/collections/configure/types/vectorizer.ts

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,11 @@ export type VectorConfigUpdate<N extends string | undefined, I extends VectorInd
6464

6565
export type VectorizersConfigCreate<T, V> = V extends undefined
6666
?
67-
| VectorConfigCreate<PrimitiveKeys<T>, string | undefined, VectorIndexType, Vectorizer>
68-
| VectorConfigCreate<PrimitiveKeys<T>, string, VectorIndexType, Vectorizer>[]
67+
| VectorConfigCreate<PrimitiveKeys<T>, string | undefined, VectorIndexType, Vectorizer>
68+
| VectorConfigCreate<PrimitiveKeys<T>, string, VectorIndexType, Vectorizer>[]
6969
:
70-
| VectorConfigCreate<PrimitiveKeys<T>, (keyof V & string) | undefined, VectorIndexType, Vectorizer>
71-
| VectorConfigCreate<PrimitiveKeys<T>, keyof V & string, VectorIndexType, Vectorizer>[];
70+
| VectorConfigCreate<PrimitiveKeys<T>, (keyof V & string) | undefined, VectorIndexType, Vectorizer>
71+
| VectorConfigCreate<PrimitiveKeys<T>, keyof V & string, VectorIndexType, Vectorizer>[];
7272

7373
export type VectorizersConfigAdd<T> =
7474
| VectorConfigCreate<PrimitiveKeys<T>, string, VectorIndexType, Vectorizer>
@@ -112,6 +112,38 @@ export type ConfigureTextMultiVectorizerOptions<
112112

113113
export type Img2VecNeuralConfigCreate = Img2VecNeuralConfig;
114114

115+
// model: Optional[str] = None,
116+
// truncation: Optional[bool] = None,
117+
// output_encoding: Optional[str],
118+
// vectorize_collection_name: bool = True,
119+
// base_url: Optional[AnyHttpUrl] = None,
120+
// image_fields: Optional[Union[List[str], List[Multi2VecField]]] = None,
121+
// text_fields: Optional[Union[List[str], List[Multi2VecField]]] = None,
122+
123+
// model: The model to use. Defaults to `None`, which uses the server-defined default.
124+
// output_encoding: Format in which the embeddings are encoded. Defaults to `None`, so the embeddings are represented as a list of floating-point numbers.
125+
// vectorize_collection_name: Whether to vectorize the collection name. Defaults to `True`.
126+
// base_url: The base URL to use where API requests should go. Defaults to `None`, which uses the server-defined default.
127+
// image_fields: The image fields to use in vectorization.
128+
// text_fields: The text fields to use in vectorization.
129+
130+
131+
/** The configuration for the `multi2vec-nvidia` vectorizer. */
132+
export type Multi2VecNvidiaConfigCreate = {
133+
/** The model to use. Defaults to `None`, which uses the server-defined default. */
134+
model?: string;
135+
/** The base URL where API requests should go. */
136+
baseURL?: string;
137+
/** Whether to apply truncation. */
138+
truncation?: boolean;
139+
/** Format in which the embeddings are encoded. Defaults to `None`, so the embeddings are represented as a list of floating-point numbers. */
140+
outputEncoding?: string;
141+
/** The image fields to use in vectorization. Can be string of `Multi2VecField` type. If string, weight 0 will be assumed. */
142+
imageFields?: string[] | Multi2VecField[];
143+
/** The text fields to use in vectorization. Can be string of `Multi2VecField` type. If string, weight 0 will be assumed. */
144+
textFields?: string[] | Multi2VecField[];
145+
};
146+
115147
/** The configuration for the `multi2vec-clip` vectorizer. */
116148
export type Multi2VecClipConfigCreate = {
117149
/** The image fields to use in vectorization. Can be string of `Multi2VecField` type. If string, weight 0 will be assumed. */
@@ -261,6 +293,8 @@ export type Text2MultiVecJinaAIConfigCreate = Text2MultiVecJinaAIConfig;
261293

262294
export type VectorizerConfigCreateType<V> = V extends 'img2vec-neural'
263295
? Img2VecNeuralConfigCreate | undefined
296+
: V extends 'multi2vec-nvidia'
297+
? Multi2VecNvidiaConfigCreate | undefined
264298
: V extends 'multi2vec-clip'
265299
? Multi2VecClipConfigCreate | undefined
266300
: V extends 'multi2vec-cohere'

src/collections/configure/unit.test.ts

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,45 @@ describe('Unit testing of the vectorizer factory class', () => {
639639
},
640640
});
641641
});
642+
it('should create the correct Multi2VecNvidiaConfig type with all values and weights', () => {
643+
const config = configure.vectors.multi2VecNvidia({
644+
name: 'test',
645+
model: 'model-id',
646+
outputEncoding: 'base64',
647+
truncation: true,
648+
baseURL: 'example.com',
649+
imageFields: [
650+
{ name: 'field1', weight: 0.1 },
651+
{ name: 'field2', weight: 0.2 },
652+
],
653+
textFields: [
654+
{ name: 'field3', weight: 0.3 },
655+
{ name: 'field4', weight: 0.4 },
656+
],
657+
});
658+
expect(config).toEqual<VectorConfigCreate<never, 'test', 'hnsw', 'multi2vec-nvidia'>>({
659+
name: 'test',
660+
vectorIndex: {
661+
name: 'hnsw',
662+
config: undefined,
663+
},
664+
vectorizer: {
665+
name: 'multi2vec-nvidia',
666+
config: {
667+
output_encoding: 'base64',
668+
truncation: true,
669+
baseURL: 'example.com',
670+
imageFields: ['field1', 'field2'],
671+
textFields: ['field3', 'field4'],
672+
model: 'model-id',
673+
weights: {
674+
imageFields: [0.1, 0.2],
675+
textFields: [0.3, 0.4],
676+
},
677+
},
678+
},
679+
});
680+
});
642681

643682
it('should create the correct Multi2VecJinaAIConfig type with defaults', () => {
644683
const config = configure.vectors.multi2VecJinaAI();

src/collections/configure/vectorizer.ts

Lines changed: 84 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import {
44
Multi2VecBindConfig,
55
Multi2VecClipConfig,
66
Multi2VecField,
7+
Multi2VecNvidiaConfig,
78
Multi2VecPalmConfig,
89
Multi2VecVoyageAIConfig,
910
VectorIndexType,
@@ -37,20 +38,20 @@ const makeVectorIndex = (opts?: {
3738
}
3839
conf = conf
3940
? {
40-
...conf,
41-
multiVector: conf.multiVector
42-
? {
43-
...conf.multiVector,
44-
encoding: conf.multiVector.encoding
45-
? { ...conf.multiVector.encoding, ...opts.encoding }
46-
: opts.encoding,
47-
}
48-
: vectorIndex.multiVector.multiVector({ encoding: opts.encoding }),
49-
}
41+
...conf,
42+
multiVector: conf.multiVector
43+
? {
44+
...conf.multiVector,
45+
encoding: conf.multiVector.encoding
46+
? { ...conf.multiVector.encoding, ...opts.encoding }
47+
: opts.encoding,
48+
}
49+
: vectorIndex.multiVector.multiVector({ encoding: opts.encoding }),
50+
}
5051
: {
51-
multiVector: vectorIndex.multiVector.multiVector({ encoding: opts.encoding }),
52-
type: 'hnsw',
53-
};
52+
multiVector: vectorIndex.multiVector.multiVector({ encoding: opts.encoding }),
53+
type: 'hnsw',
54+
};
5455
}
5556
if (opts?.quantizer) {
5657
if (!conf) {
@@ -199,16 +200,16 @@ const legacyVectors = {
199200
Object.keys(config).length === 0
200201
? undefined
201202
: {
202-
...config,
203-
audioFields: audioFields?.map((f) => f.name),
204-
depthFields: depthFields?.map((f) => f.name),
205-
imageFields: imageFields?.map((f) => f.name),
206-
IMUFields: IMUFields?.map((f) => f.name),
207-
textFields: textFields?.map((f) => f.name),
208-
thermalFields: thermalFields?.map((f) => f.name),
209-
videoFields: videoFields?.map((f) => f.name),
210-
weights: Object.keys(weights).length === 0 ? undefined : weights,
211-
},
203+
...config,
204+
audioFields: audioFields?.map((f) => f.name),
205+
depthFields: depthFields?.map((f) => f.name),
206+
imageFields: imageFields?.map((f) => f.name),
207+
IMUFields: IMUFields?.map((f) => f.name),
208+
textFields: textFields?.map((f) => f.name),
209+
thermalFields: thermalFields?.map((f) => f.name),
210+
videoFields: videoFields?.map((f) => f.name),
211+
weights: Object.keys(weights).length === 0 ? undefined : weights,
212+
},
212213
},
213214
});
214215
},
@@ -238,11 +239,11 @@ const legacyVectors = {
238239
Object.keys(config).length === 0
239240
? undefined
240241
: {
241-
...config,
242-
imageFields: imageFields?.map((f) => f.name),
243-
textFields: textFields?.map((f) => f.name),
244-
weights: Object.keys(weights).length === 0 ? undefined : weights,
245-
},
242+
...config,
243+
imageFields: imageFields?.map((f) => f.name),
244+
textFields: textFields?.map((f) => f.name),
245+
weights: Object.keys(weights).length === 0 ? undefined : weights,
246+
},
246247
},
247248
});
248249
},
@@ -272,11 +273,11 @@ const legacyVectors = {
272273
Object.keys(config).length === 0
273274
? undefined
274275
: {
275-
...config,
276-
imageFields: imageFields?.map((f) => f.name),
277-
textFields: textFields?.map((f) => f.name),
278-
weights: Object.keys(weights).length === 0 ? undefined : weights,
279-
},
276+
...config,
277+
imageFields: imageFields?.map((f) => f.name),
278+
textFields: textFields?.map((f) => f.name),
279+
weights: Object.keys(weights).length === 0 ? undefined : weights,
280+
},
280281
},
281282
});
282283
},
@@ -307,11 +308,11 @@ const legacyVectors = {
307308
Object.keys(config).length === 0
308309
? undefined
309310
: {
310-
...config,
311-
imageFields: imageFields?.map((f) => f.name),
312-
textFields: textFields?.map((f) => f.name),
313-
weights: Object.keys(weights).length === 0 ? undefined : weights,
314-
},
311+
...config,
312+
imageFields: imageFields?.map((f) => f.name),
313+
textFields: textFields?.map((f) => f.name),
314+
weights: Object.keys(weights).length === 0 ? undefined : weights,
315+
},
315316
},
316317
});
317318
},
@@ -411,11 +412,11 @@ const legacyVectors = {
411412
Object.keys(config).length === 0
412413
? undefined
413414
: {
414-
...config,
415-
imageFields: imageFields?.map((f) => f.name),
416-
textFields: textFields?.map((f) => f.name),
417-
weights: Object.keys(weights).length === 0 ? undefined : weights,
418-
},
415+
...config,
416+
imageFields: imageFields?.map((f) => f.name),
417+
textFields: textFields?.map((f) => f.name),
418+
weights: Object.keys(weights).length === 0 ? undefined : weights,
419+
},
419420
},
420421
});
421422
},
@@ -842,9 +843,9 @@ const __vectors_shaded = {
842843
legacyVectors.text2VecGoogle(
843844
opts
844845
? {
845-
...opts,
846-
...(opts?.modelId || opts?.model ? { modelId: opts?.modelId || opts?.model } : undefined),
847-
}
846+
...opts,
847+
...(opts?.modelId || opts?.model ? { modelId: opts?.modelId || opts?.model } : undefined),
848+
}
848849
: undefined
849850
),
850851
text2VecOpenAI: <T, N extends string | undefined = undefined, I extends VectorIndexType = 'hnsw'>(
@@ -908,9 +909,43 @@ export const vectorizer = legacyVectors;
908909
// Remove deprecated vectorizers and module configuration parameters:
909910
// - PaLM vectorizers are called -Google now.
910911
// - __vectors_shaded hide/rename some parameters
911-
export const vectors = (({ text2VecPalm, multi2VecPalm, ...rest }) => ({ ...rest, ...__vectors_shaded }))(
912-
legacyVectors
913-
);
912+
export const vectors = (({ text2VecPalm, multi2VecPalm, ...rest }) => ({
913+
...rest,
914+
...__vectors_shaded,
915+
916+
/**
917+
* Create a `VectorConfigCreate` object with the vectorizer set to `'multi2vec-nvidia'`.
918+
*
919+
* See the [documentation](https://weaviate.io/developers/weaviate/model-providers/nvidia/embeddings-multimodal) for detailed usage.
920+
*
921+
* @param {ConfigureNonTextVectorizerOptions<N, I, 'multi2vec-nvidia'>} [opts] The configuration options for the `multi2vec-nvidia` vectorizer.
922+
* @returns {VectorConfigCreate<PrimitiveKeys<T>[], N, I, 'multi2vec-nvidia'>} The configuration object.
923+
*/
924+
multi2VecNvidia: <N extends string | undefined = undefined, I extends VectorIndexType = 'hnsw'>(
925+
opts?: ConfigureNonTextVectorizerOptions<N, I, 'multi2vec-nvidia'>,
926+
): VectorConfigCreate<never, N, I, 'multi2vec-nvidia'> => {
927+
const { name, quantizer, vectorIndexConfig, outputEncoding, ...config } = opts || {};
928+
const imageFields = config.imageFields?.map(mapMulti2VecField);
929+
const textFields = config.textFields?.map(mapMulti2VecField);
930+
let weights: Multi2VecNvidiaConfig['weights'] = {};
931+
weights = formatMulti2VecFields(weights, 'imageFields', imageFields);
932+
weights = formatMulti2VecFields(weights, 'textFields', textFields);
933+
return makeVectorizer(name, {
934+
quantizer,
935+
vectorIndexConfig,
936+
vectorizerConfig: {
937+
name: 'multi2vec-nvidia',
938+
config: {
939+
...config,
940+
output_encoding: outputEncoding,
941+
imageFields: imageFields?.map((f) => f.name),
942+
textFields: textFields?.map((f) => f.name),
943+
weights: Object.keys(weights).length === 0 ? undefined : weights,
944+
},
945+
},
946+
});
947+
},
948+
}))(legacyVectors);
914949

915950
export const multiVectors = {
916951
/**

0 commit comments

Comments
 (0)