
Commit 588fbf0

Multimodal embedding proposal
1 parent 28ae9d3 commit 588fbf0

File tree

2 files changed: 350 additions, 0 deletions


redisvl/utils/vectorize/multimidal/__init__.py

Whitespace-only changes.
Lines changed: 350 additions & 0 deletions
@@ -0,0 +1,350 @@
import os
from typing import Any, Callable, Dict, List, Optional, Union

from pydantic import PrivateAttr, HttpUrl
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tenacity.retry import retry_if_not_exception_type

from redisvl.utils.utils import deprecated_argument
from redisvl.utils.vectorize.base import BaseMultimodalVectorizer
from PIL import Image

# ignore that voyageai isn't imported
# mypy: disable-error-code="name-defined"


class VoyageAIMultimodalVectorizer(BaseMultimodalVectorizer):
    """The VoyageAIMultimodalVectorizer class utilizes VoyageAI's API to generate
    embeddings for text or image data.

    This vectorizer is designed to interact with VoyageAI's /multimodalembeddings API,
    requiring an API key for authentication. The key can be provided directly in the
    `api_config` dictionary or through the `VOYAGE_API_KEY` environment variable.
    Users must obtain an API key from VoyageAI's website (https://dash.voyageai.com/).
    Additionally, the `voyageai` python client must be installed with
    `pip install voyageai`.

    The vectorizer supports both synchronous and asynchronous operations, allows for
    batch processing of content, and offers flexibility in handling preprocessing tasks.

    .. code-block:: python

        from redisvl.utils.vectorize import VoyageAIMultimodalVectorizer

        vectorizer = VoyageAIMultimodalVectorizer(
            model="voyage-multimodal-3",
            api_config={"api_key": "your-voyageai-api-key"}  # OR set VOYAGE_API_KEY in your env
        )
        query_embedding = vectorizer.embed(
            content=["your input query text here"],
            input_type="query"
        )
        doc_embeddings = vectorizer.embed_many(
            contents=[["your document text"], ["more document text"]],
            input_type="document"
        )

    """
    _client: Any = PrivateAttr()
    _aclient: Any = PrivateAttr()

    def __init__(
        self,
        model: str,
        api_config: Optional[Dict] = None,
        dtype: str = "float32",
        **kwargs,
    ):
        """Initialize the VoyageAI multimodal vectorizer.

        Visit https://docs.voyageai.com/docs/multimodal-embeddings to learn about
        multimodal embeddings and check the available models.

        Args:
            model (str): Model to use for embedding (e.g. "voyage-multimodal-3").
            api_config (Optional[Dict], optional): Dictionary containing the API key.
                Defaults to None.
            dtype (str): the default datatype to use when embedding content as byte arrays.
                Used when setting `as_buffer=True` in calls to embed() and embed_many().
                Defaults to 'float32'.

        Raises:
            ImportError: If the voyageai library is not installed.
            ValueError: If the API key is not provided.

        """
        super().__init__(model=model, dtype=dtype)
        # Init client
        self._initialize_client(api_config, **kwargs)
        # Set model dimensions after init
        self.dims = self._set_model_dims()
    def _initialize_client(self, api_config: Optional[Dict], **kwargs):
        """
        Setup the VoyageAI clients using the provided API key or an
        environment variable.
        """
        if api_config is None:
            api_config = {}

        # Dynamic import of the voyageai module
        try:
            from voyageai import AsyncClient, Client
        except ImportError:
            raise ImportError(
                "VoyageAI vectorizer requires the voyageai library. "
                "Please install with `pip install voyageai`"
            )

        # Fetch the API key from api_config, falling back to the environment variable
        api_key = api_config.get("api_key") or os.getenv("VOYAGE_API_KEY")
        if not api_key:
            raise ValueError(
                "VoyageAI API key is required. "
                "Provide it in api_config or set the VOYAGE_API_KEY environment variable."
            )
        self._client = Client(api_key=api_key, **kwargs)
        self._aclient = AsyncClient(api_key=api_key, **kwargs)

    def _set_model_dims(self) -> int:
        """Determine the embedding dimensionality by embedding a sample input."""
        try:
            embedding = self.embed(["dimension check"], input_type="document")
        except (KeyError, IndexError) as ke:
            raise ValueError(f"Unexpected response from the VoyageAI API: {str(ke)}")
        except Exception as e:  # pylint: disable=broad-except
            # fall back (TODO get more specific)
            raise ValueError(f"Error setting embedding model dimensions: {str(e)}")
        return len(embedding)
    @deprecated_argument("dtype")
    def embed(
        self,
        content: List[Union[str, HttpUrl, Image.Image]],
        preprocess: Optional[Callable] = None,
        as_buffer: bool = False,
        **kwargs,
    ) -> Union[List[float], bytes]:
        """Embed a single piece of multimodal content using the VoyageAI Embeddings API.

        Can provide the embedding `input_type` as a `kwarg` to this method that specifies
        the type of input you're giving to the model. For retrieval/search use cases, we
        recommend specifying this argument when encoding queries or documents to enhance
        retrieval quality. Embeddings generated with and without the input_type argument
        are compatible.

        Supported input types are ``document`` and ``query``.

        When hydrating your Redis DB, the documents you want to search over
        should be embedded with input_type="document" and when you are
        querying the database, you should set the input_type="query".

        Args:
            content (List[Union[str, HttpUrl, Image.Image]]): The content to embed, given
                as a list of text strings, image URLs, and/or PIL images.
            preprocess (Optional[Callable], optional): Optional preprocessing callable to
                perform before vectorization. Defaults to None.
            as_buffer (bool, optional): Whether to convert the raw embedding
                to a byte string. Defaults to False.
            input_type (str): Specifies the type of input passed to the model.
            truncation (bool): Whether to truncate the input to fit within the context
                length. Check https://docs.voyageai.com/docs/multimodal-embeddings

        Returns:
            Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
                object if as_buffer=True

        Raises:
            TypeError: If an invalid input_type is provided.
        """
        return self.embed_many(
            contents=[content], preprocess=preprocess, as_buffer=as_buffer, **kwargs
        )[0]
    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_not_exception_type(TypeError),
    )
    @deprecated_argument("dtype")
    def embed_many(
        self,
        contents: List[List[Union[str, HttpUrl, Image.Image]]],
        preprocess: Optional[Callable] = None,
        batch_size: int = 10,
        as_buffer: bool = False,
        **kwargs,
    ) -> Union[List[List[float]], List[bytes]]:
        """Embed many pieces of multimodal content using the VoyageAI Embeddings API.

        Can provide the embedding `input_type` as a `kwarg` to this method that specifies
        the type of input you're giving to the model. For retrieval/search use cases, we
        recommend specifying this argument when encoding queries or documents to enhance
        retrieval quality. Embeddings generated with and without the input_type argument
        are compatible.

        Supported input types are ``document`` and ``query``.

        When hydrating your Redis DB, the documents you want to search over
        should be embedded with input_type="document" and when you are
        querying the database, you should set the input_type="query".

        Args:
            contents (List[List[Union[str, HttpUrl, Image.Image]]]): List of content
                chunks to embed.
            preprocess (Optional[Callable], optional): Optional preprocessing callable to
                perform before vectorization. Defaults to None.
            batch_size (int, optional): Batch size of inputs to use when creating
                embeddings. Defaults to 10.
            as_buffer (bool, optional): Whether to convert the raw embedding
                to a byte string. Defaults to False.
            input_type (str): Specifies the type of input passed to the model.
            truncation (bool): Whether to truncate the input to fit within the context
                length. Check https://docs.voyageai.com/docs/multimodal-embeddings

        Returns:
            Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
                or as bytes objects if as_buffer=True

        Raises:
            TypeError: If an invalid input_type is provided.

        """
        input_type = kwargs.pop("input_type", None)
        truncation = kwargs.pop("truncation", None)
        dtype = kwargs.pop("dtype", self.dtype)

        if not isinstance(contents, list):
            raise TypeError("Must pass in a list of content lists to embed.")
        if input_type is not None and input_type not in ["document", "query"]:
            raise TypeError(
                "Must pass in an allowed value for voyageai embedding input_type. "
                "See https://docs.voyageai.com/docs/multimodal-embeddings."
            )

        if truncation is not None and not isinstance(truncation, bool):
            raise TypeError("Truncation (optional) parameter is a bool.")
        if truncation is not None:
            # Forward the user-supplied truncation setting to the API call
            kwargs["truncation"] = truncation

        if batch_size is None:
            batch_size = 10

        embeddings: List = []
        for batch in self.batchify(contents, batch_size, preprocess):
            response = self._client.multimodal_embed(
                inputs=batch, model=self.model, input_type=input_type, **kwargs
            )
            embeddings += [
                self._process_embedding(embedding, as_buffer, dtype)
                for embedding in response.embeddings
            ]
        return embeddings
    @deprecated_argument("dtype")
    async def aembed_many(
        self,
        contents: List[List[Union[str, HttpUrl, Image.Image]]],
        preprocess: Optional[Callable] = None,
        batch_size: int = 10,
        as_buffer: bool = False,
        **kwargs,
    ) -> Union[List[List[float]], List[bytes]]:
        """Asynchronously embed many pieces of multimodal content using the VoyageAI
        Embeddings API.

        Can provide the embedding `input_type` as a `kwarg` to this method that specifies
        the type of input you're giving to the model. For retrieval/search use cases, we
        recommend specifying this argument when encoding queries or documents to enhance
        retrieval quality. Embeddings generated with and without the input_type argument
        are compatible.

        Supported input types are ``document`` and ``query``.

        When hydrating your Redis DB, the documents you want to search over
        should be embedded with input_type="document" and when you are
        querying the database, you should set the input_type="query".

        Args:
            contents (List[List[Union[str, HttpUrl, Image.Image]]]): List of content
                chunks to embed.
            preprocess (Optional[Callable], optional): Optional preprocessing callable to
                perform before vectorization. Defaults to None.
            batch_size (int, optional): Batch size of inputs to use when creating
                embeddings. Defaults to 10.
            as_buffer (bool, optional): Whether to convert the raw embedding
                to a byte string. Defaults to False.
            input_type (str): Specifies the type of input passed to the model.
            truncation (bool): Whether to truncate the input to fit within the context
                length. Check https://docs.voyageai.com/docs/multimodal-embeddings

        Returns:
            Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
                or as bytes objects if as_buffer=True

        Raises:
            TypeError: If an invalid input_type is provided.

        """
        input_type = kwargs.pop("input_type", None)
        truncation = kwargs.pop("truncation", None)
        dtype = kwargs.pop("dtype", self.dtype)

        if not isinstance(contents, list):
            raise TypeError("Must pass in a list of content lists to embed.")
        if input_type is not None and input_type not in ["document", "query"]:
            raise TypeError(
                "Must pass in an allowed value for voyageai embedding input_type. "
                "See https://docs.voyageai.com/docs/multimodal-embeddings."
            )

        if truncation is not None and not isinstance(truncation, bool):
            raise TypeError("Truncation (optional) parameter is a bool.")
        if truncation is not None:
            # Forward the user-supplied truncation setting to the API call
            kwargs["truncation"] = truncation

        if batch_size is None:
            batch_size = 10

        embeddings: List = []
        for batch in self.batchify(contents, batch_size, preprocess):
            response = await self._aclient.multimodal_embed(
                inputs=batch, model=self.model, input_type=input_type, **kwargs
            )
            embeddings += [
                self._process_embedding(embedding, as_buffer, dtype)
                for embedding in response.embeddings
            ]
        return embeddings
    @deprecated_argument("dtype")
    async def aembed(
        self,
        content: List[Union[str, HttpUrl, Image.Image]],
        preprocess: Optional[Callable] = None,
        as_buffer: bool = False,
        **kwargs,
    ) -> Union[List[float], bytes]:
        """Asynchronously embed a single piece of multimodal content using the VoyageAI
        Embeddings API.

        Can provide the embedding `input_type` as a `kwarg` to this method that specifies
        the type of input you're giving to the model. For retrieval/search use cases, we
        recommend specifying this argument when encoding queries or documents to enhance
        retrieval quality. Embeddings generated with and without the input_type argument
        are compatible.

        Supported input types are ``document`` and ``query``.

        When hydrating your Redis DB, the documents you want to search over
        should be embedded with input_type="document" and when you are
        querying the database, you should set the input_type="query".

        Args:
            content (List[Union[str, HttpUrl, Image.Image]]): The content to embed, given
                as a list of text strings, image URLs, and/or PIL images.
            preprocess (Optional[Callable], optional): Optional preprocessing callable to
                perform before vectorization. Defaults to None.
            as_buffer (bool, optional): Whether to convert the raw embedding
                to a byte string. Defaults to False.
            input_type (str): Specifies the type of input passed to the model.
            truncation (bool): Whether to truncate the input to fit within the context
                length. Check https://docs.voyageai.com/docs/multimodal-embeddings

        Returns:
            Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
                object if as_buffer=True

        Raises:
            TypeError: If an invalid input_type is provided.
        """
        result = await self.aembed_many(
            contents=[content], preprocess=preprocess, as_buffer=as_buffer, **kwargs
        )
        return result[0]
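
A minimal usage sketch of the proposed vectorizer, following the document/query workflow described in the docstrings. It assumes the `redisvl.utils.vectorize` import path shown in the class docstring, the `voyage-multimodal-3` model, a `VOYAGE_API_KEY` set in the environment, and a local image file `photo.jpg`; these names are illustrative and not part of the commit.

    # Illustrative sketch (not part of the diff): embedding mixed text + image content.
    from PIL import Image

    from redisvl.utils.vectorize import VoyageAIMultimodalVectorizer  # proposed import path

    # API key is read from the VOYAGE_API_KEY environment variable.
    vectorizer = VoyageAIMultimodalVectorizer(model="voyage-multimodal-3")

    # Each content item is a list mixing text strings, image URLs, and/or PIL images.
    docs = [
        ["A caption describing the first product photo", Image.open("photo.jpg")],
        ["A text-only product description"],
    ]

    # Hydration: embed documents as byte buffers suitable for storing in Redis.
    doc_vectors = vectorizer.embed_many(
        contents=docs, input_type="document", as_buffer=True
    )

    # Query time: embed the search query with input_type="query".
    query_vector = vectorizer.embed(content=["red running shoes"], input_type="query")

    print(len(doc_vectors), "document vectors; query dims:", len(query_vector))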
