Skip to content

Commit d51e4f3

Browse files
vectorizer typing changes
1 parent a354220 commit d51e4f3

File tree

14 files changed

+658
-429
lines changed

14 files changed

+658
-429
lines changed

redisvl/extensions/llmcache/semantic.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,15 +310,17 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
310310
if not isinstance(prompt, str):
311311
raise TypeError("Prompt must be a string.")
312312

313-
return self._vectorizer.embed(prompt)
313+
result = self._vectorizer.embed(prompt)
314+
return result # type: ignore
314315

315316
async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
316317
"""Converts a text prompt to its vector representation using the
317318
configured vectorizer."""
318319
if not isinstance(prompt, str):
319320
raise TypeError("Prompt must be a string.")
320321

321-
return await self._vectorizer.aembed(prompt)
322+
result = await self._vectorizer.aembed(prompt)
323+
return result # type: ignore
322324

323325
def _check_vector_dims(self, vector: List[float]):
324326
"""Checks the size of the provided vector and raises an error if it

redisvl/extensions/router/semantic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -366,14 +366,14 @@ def __call__(
366366
if not vector:
367367
if not statement:
368368
raise ValueError("Must provide a vector or statement to the router")
369-
vector = self.vectorizer.embed(statement)
369+
vector = self.vectorizer.embed(statement) # type: ignore
370370

371371
aggregation_method = (
372372
aggregation_method or self.routing_config.aggregation_method
373373
)
374374

375375
# perform route classification
376-
top_route_match = self._classify_route(vector, aggregation_method)
376+
top_route_match = self._classify_route(vector, aggregation_method) # type: ignore
377377
return top_route_match
378378

379379
@deprecated_argument("distance_threshold")
@@ -400,7 +400,7 @@ def route_many(
400400
if not vector:
401401
if not statement:
402402
raise ValueError("Must provide a vector or statement to the router")
403-
vector = self.vectorizer.embed(statement)
403+
vector = self.vectorizer.embed(statement) # type: ignore
404404

405405
max_k = max_k or self.routing_config.max_k
406406
aggregation_method = (
@@ -409,7 +409,7 @@ def route_many(
409409

410410
# classify routes
411411
top_route_matches = self._classify_multi_route(
412-
vector, max_k, aggregation_method
412+
vector, max_k, aggregation_method # type: ignore
413413
)
414414

415415
return top_route_matches

redisvl/extensions/session_manager/semantic_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def add_messages(
349349
role=message[ROLE_FIELD_NAME],
350350
content=message[CONTENT_FIELD_NAME],
351351
session_tag=session_tag,
352-
vector_field=content_vector,
352+
vector_field=content_vector, # type: ignore
353353
)
354354

355355
if TOOL_FIELD_NAME in message:

redisvl/utils/vectorize/base.py

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import ABC, abstractmethod
22
from enum import Enum
3-
from typing import Callable, List, Optional
3+
from typing import Callable, List, Optional, Union
44

55
from pydantic import BaseModel, Field, field_validator
66

@@ -49,24 +49,47 @@ def check_dims(cls, value):
4949
return value
5050

5151
@abstractmethod
52-
def embed_many(
52+
def embed(
5353
self,
54-
texts: List[str],
54+
text: str,
5555
preprocess: Optional[Callable] = None,
56-
batch_size: int = 1000,
5756
as_buffer: bool = False,
5857
**kwargs,
59-
) -> List[List[float]]:
58+
) -> Union[List[float], bytes]:
59+
"""Embed a chunk of text.
60+
61+
Args:
62+
text: Text to embed
63+
preprocess: Optional function to preprocess text
64+
as_buffer: If True, returns a bytes object instead of a list
65+
66+
Returns:
67+
Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
68+
object if as_buffer=True
69+
"""
6070
raise NotImplementedError
6171

6272
@abstractmethod
63-
def embed(
73+
def embed_many(
6474
self,
65-
text: str,
75+
texts: List[str],
6676
preprocess: Optional[Callable] = None,
77+
batch_size: int = 1000,
6778
as_buffer: bool = False,
6879
**kwargs,
69-
) -> List[float]:
80+
) -> Union[List[List[float]], List[bytes]]:
81+
"""Embed multiple chunks of text.
82+
83+
Args:
84+
texts: List of texts to embed
85+
preprocess: Optional function to preprocess text
86+
batch_size: Number of texts to process in each batch
87+
as_buffer: If True, returns each embedding as a bytes object
88+
89+
Returns:
90+
Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
91+
or as bytes objects if as_buffer=True
92+
"""
7093
raise NotImplementedError
7194

7295
async def aembed_many(
@@ -76,7 +99,19 @@ async def aembed_many(
7699
batch_size: int = 1000,
77100
as_buffer: bool = False,
78101
**kwargs,
79-
) -> List[List[float]]:
102+
) -> Union[List[List[float]], List[bytes]]:
103+
"""Asynchronously embed multiple chunks of text.
104+
105+
Args:
106+
texts: List of texts to embed
107+
preprocess: Optional function to preprocess text
108+
batch_size: Number of texts to process in each batch
109+
as_buffer: If True, returns each embedding as a bytes object
110+
111+
Returns:
112+
Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
113+
or as bytes objects if as_buffer=True
114+
"""
80115
# Fallback to standard embedding call if no async support
81116
return self.embed_many(texts, preprocess, batch_size, as_buffer, **kwargs)
82117

@@ -86,7 +121,18 @@ async def aembed(
86121
preprocess: Optional[Callable] = None,
87122
as_buffer: bool = False,
88123
**kwargs,
89-
) -> List[float]:
124+
) -> Union[List[float], bytes]:
125+
"""Asynchronously embed a chunk of text.
126+
127+
Args:
128+
text: Text to embed
129+
preprocess: Optional function to preprocess text
130+
as_buffer: If True, returns a bytes object instead of a list
131+
132+
Returns:
133+
Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
134+
object if as_buffer=True
135+
"""
90136
# Fallback to standard embedding call if no async support
91137
return self.embed(text, preprocess, as_buffer, **kwargs)
92138

redisvl/utils/vectorize/text/azureopenai.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import Any, Callable, Dict, List, Optional
2+
from typing import Any, Callable, Dict, List, Optional, Union
33

44
from pydantic import PrivateAttr
55
from tenacity import retry, stop_after_attempt, wait_random_exponential
@@ -178,7 +178,7 @@ def embed_many(
178178
batch_size: int = 10,
179179
as_buffer: bool = False,
180180
**kwargs,
181-
) -> List[List[float]]:
181+
) -> Union[List[List[float]], List[bytes]]:
182182
"""Embed many chunks of texts using the AzureOpenAI API.
183183
184184
Args:
@@ -191,7 +191,8 @@ def embed_many(
191191
to a byte string. Defaults to False.
192192
193193
Returns:
194-
List[List[float]]: List of embeddings.
194+
Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
195+
or as bytes objects if as_buffer=True
195196
196197
Raises:
197198
TypeError: If the wrong input type is passed in for the test.
@@ -226,7 +227,7 @@ def embed(
226227
preprocess: Optional[Callable] = None,
227228
as_buffer: bool = False,
228229
**kwargs,
229-
) -> List[float]:
230+
) -> Union[List[float], bytes]:
230231
"""Embed a chunk of text using the AzureOpenAI API.
231232
232233
Args:
@@ -237,7 +238,8 @@ def embed(
237238
to a byte string. Defaults to False.
238239
239240
Returns:
240-
List[float]: Embedding.
241+
Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
242+
object if as_buffer=True
241243
242244
Raises:
243245
TypeError: If the wrong input type is passed in for the test.
@@ -268,7 +270,7 @@ async def aembed_many(
268270
batch_size: int = 1000,
269271
as_buffer: bool = False,
270272
**kwargs,
271-
) -> List[List[float]]:
273+
) -> Union[List[List[float]], List[bytes]]:
272274
"""Asynchronously embed many chunks of texts using the AzureOpenAI API.
273275
274276
Args:
@@ -281,7 +283,8 @@ async def aembed_many(
281283
to a byte string. Defaults to False.
282284
283285
Returns:
284-
List[List[float]]: List of embeddings.
286+
Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
287+
or as bytes objects if as_buffer=True
285288
286289
Raises:
287290
TypeError: If the wrong input type is passed in for the test.
@@ -316,7 +319,7 @@ async def aembed(
316319
preprocess: Optional[Callable] = None,
317320
as_buffer: bool = False,
318321
**kwargs,
319-
) -> List[float]:
322+
) -> Union[List[float], bytes]:
320323
"""Asynchronously embed a chunk of text using the OpenAI API.
321324
322325
Args:
@@ -327,7 +330,8 @@ async def aembed(
327330
to a byte string. Defaults to False.
328331
329332
Returns:
330-
List[float]: Embedding.
333+
Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
334+
object if as_buffer=True
331335
332336
Raises:
333337
TypeError: If the wrong input type is passed in for the test.

redisvl/utils/vectorize/text/bedrock.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import os
3-
from typing import Any, Callable, Dict, List, Optional
3+
from typing import Any, Callable, Dict, List, Optional, Union
44

55
from pydantic import PrivateAttr
66
from tenacity import retry, stop_after_attempt, wait_random_exponential
@@ -135,16 +135,17 @@ def embed(
135135
preprocess: Optional[Callable] = None,
136136
as_buffer: bool = False,
137137
**kwargs,
138-
) -> List[float]:
139-
"""Embed a chunk of text using Amazon Bedrock.
138+
) -> Union[List[float], bytes]:
139+
"""Embed a chunk of text using the AWS Bedrock Embeddings API.
140140
141141
Args:
142142
text (str): Text to embed.
143143
preprocess (Optional[Callable]): Optional preprocessing function.
144144
as_buffer (bool): Whether to return as byte buffer.
145145
146146
Returns:
147-
List[float]: The embedding vector.
147+
Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
148+
object if as_buffer=True
148149
149150
Raises:
150151
TypeError: If text is not a string.
@@ -177,17 +178,18 @@ def embed_many(
177178
batch_size: int = 10,
178179
as_buffer: bool = False,
179180
**kwargs,
180-
) -> List[List[float]]:
181-
"""Embed multiple texts using Amazon Bedrock.
181+
) -> Union[List[List[float]], List[bytes]]:
182+
"""Embed many chunks of text using the AWS Bedrock Embeddings API.
182183
183184
Args:
184185
texts (List[str]): List of texts to embed.
185186
preprocess (Optional[Callable]): Optional preprocessing function.
186-
batch_size (int): Size of batches for processing.
187+
batch_size (int): Size of batches for processing. Defaults to 10.
187188
as_buffer (bool): Whether to return as byte buffers.
188189
189190
Returns:
190-
List[List[float]]: List of embedding vectors.
191+
Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
192+
or as bytes objects if as_buffer=True
191193
192194
Raises:
193195
TypeError: If texts is not a list of strings.

0 commit comments

Comments
 (0)