Skip to content

Commit 01b544f

Browse files
break out dtype tests
1 parent 1436e24 commit 01b544f

File tree

1 file changed

+37
-4
lines changed

1 file changed

+37
-4
lines changed

tests/integration/test_vectorizers.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def bad_return_type(text: str) -> str:
258258
VoyageAITextVectorizer,
259259
],
260260
)
261-
def test_dtypes(vectorizer_):
261+
def test_default_dtype(vectorizer_):
262262
# test dtype defaults to float32
263263
if issubclass(vectorizer_, CustomTextVectorizer):
264264
vectorizer = vectorizer_(embed=lambda x, input_type=None: [1.0, 2.0, 3.0])
@@ -271,6 +271,23 @@ def test_dtypes(vectorizer_):
271271

272272
assert vectorizer.dtype == "float32"
273273

274+
275+
@pytest.mark.requires_api_keys
276+
@pytest.mark.parametrize(
277+
"vectorizer_",
278+
[
279+
AzureOpenAITextVectorizer,
280+
BedrockTextVectorizer,
281+
CohereTextVectorizer,
282+
CustomTextVectorizer,
283+
HFTextVectorizer,
284+
MistralAITextVectorizer,
285+
OpenAITextVectorizer,
286+
VertexAITextVectorizer,
287+
VoyageAITextVectorizer,
288+
],
289+
)
290+
def test_other_dtypes(vectorizer_):
274291
# test initializing dtype in constructor
275292
for dtype in ["float16", "float32", "float64", "bfloat16"]:
276293
if issubclass(vectorizer_, CustomTextVectorizer):
@@ -287,14 +304,30 @@ def test_dtypes(vectorizer_):
287304

288305
assert vectorizer.dtype == dtype
289306

307+
308+
@pytest.mark.requires_api_keys
309+
@pytest.mark.parametrize(
310+
"vectorizer_",
311+
[
312+
AzureOpenAITextVectorizer,
313+
BedrockTextVectorizer,
314+
CohereTextVectorizer,
315+
HFTextVectorizer,
316+
MistralAITextVectorizer,
317+
OpenAITextVectorizer,
318+
VertexAITextVectorizer,
319+
VoyageAITextVectorizer,
320+
],
321+
)
322+
def test_bad_dtypes(vectorizer_):
290323
with pytest.raises(ValueError):
291-
vectorizer = vectorizer_(dtype="float25")
324+
vectorizer_(dtype="float25")
292325

293326
with pytest.raises(ValueError):
294-
vectorizer = vectorizer_(dtype=7)
327+
vectorizer_(dtype=7)
295328

296329
with pytest.raises(ValueError):
297-
vectorizer = vectorizer_(dtype=None)
330+
vectorizer_(dtype=None)
298331

299332

300333
@pytest.fixture(

0 commit comments

Comments
 (0)