Skip to content

Commit 18110ed

Browse files
fix issue with vectorizer default name
1 parent da13fb8 commit 18110ed

File tree

2 files changed

+10
-11
lines changed

2 files changed

+10
-11
lines changed

redisvl/utils/vectorize/text/voyageai.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,17 @@ class VoyageAITextVectorizer(BaseVectorizer):
4848
_aclient: Any = PrivateAttr()
4949

5050
def __init__(
51-
self, model: str, api_config: Optional[Dict] = None, dtype: str = "float32"
51+
self,
52+
model: str = "voyage-large-2",
53+
api_config: Optional[Dict] = None,
54+
dtype: str = "float32",
5255
):
5356
"""Initialize the VoyageAI vectorizer.
5457
5558
Visit https://docs.voyageai.com/docs/embeddings to learn about embeddings and check the available models.
5659
5760
Args:
58-
model (str): Model to use for embedding.
61+
model (str): Model to use for embedding. Defaults to "voyage-large-2".
5962
api_config (Optional[Dict], optional): Dictionary containing the API key.
6063
Defaults to None.
6164
dtype (str): the default datatype to use when embedding text as byte arrays.

tests/integration/test_vectorizers.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ def test_dtypes(vector_class, skip_vectorizer):
273273
)
274274
else:
275275
vectorizer = vector_class()
276+
276277
assert vectorizer.dtype == "float32"
277278

278279
# test initializing dtype in constructor
@@ -288,6 +289,7 @@ def test_dtypes(vector_class, skip_vectorizer):
288289
)
289290
else:
290291
vectorizer = vector_class(dtype=dtype)
292+
291293
assert vectorizer.dtype == dtype
292294

293295
# test validation of dtype on init
@@ -317,13 +319,7 @@ def avectorizer(request, skip_vectorizer):
317319
if skip_vectorizer:
318320
pytest.skip("Skipping vectorizer instantiation...")
319321

320-
if request.param == OpenAITextVectorizer:
321-
return request.param()
322-
elif request.param == BedrockTextVectorizer:
323-
return request.param()
324-
elif request.param == MistralAITextVectorizer:
325-
return request.param()
326-
elif request.param == CustomTextVectorizer:
322+
if request.param == CustomTextVectorizer:
327323

328324
def embed_func(text):
329325
return [1.1, 2.2, 3.3, 4.4]
@@ -337,8 +333,8 @@ async def aembed_many_func(texts):
337333
return request.param(
338334
embed=embed_func, aembed=aembed_func, aembed_many=aembed_many_func
339335
)
340-
elif request.param == VoyageAITextVectorizer:
341-
return request.param(model="voyage-large-2")
336+
else:
337+
return request.param()
342338

343339

344340
@pytest.mark.asyncio

0 commit comments

Comments
 (0)