@@ -258,7 +258,7 @@ def bad_return_type(text: str) -> str:
258258 VoyageAITextVectorizer ,
259259 ],
260260)
261- def test_dtypes (vectorizer_ ):
261+ def test_default_dtype (vectorizer_ ):
262262 # test dtype defaults to float32
263263 if issubclass (vectorizer_ , CustomTextVectorizer ):
264264 vectorizer = vectorizer_ (embed = lambda x , input_type = None : [1.0 , 2.0 , 3.0 ])
@@ -271,6 +271,23 @@ def test_dtypes(vectorizer_):
271271
272272 assert vectorizer .dtype == "float32"
273273
274+
275+ @pytest .mark .requires_api_keys
276+ @pytest .mark .parametrize (
277+ "vectorizer_" ,
278+ [
279+ AzureOpenAITextVectorizer ,
280+ BedrockTextVectorizer ,
281+ CohereTextVectorizer ,
282+ CustomTextVectorizer ,
283+ HFTextVectorizer ,
284+ MistralAITextVectorizer ,
285+ OpenAITextVectorizer ,
286+ VertexAITextVectorizer ,
287+ VoyageAITextVectorizer ,
288+ ],
289+ )
290+ def test_other_dtypes (vectorizer_ ):
274291 # test initializing dtype in constructor
275292 for dtype in ["float16" , "float32" , "float64" , "bfloat16" ]:
276293 if issubclass (vectorizer_ , CustomTextVectorizer ):
@@ -287,14 +304,30 @@ def test_dtypes(vectorizer_):
287304
288305 assert vectorizer .dtype == dtype
289306
307+
308+ @pytest .mark .requires_api_keys
309+ @pytest .mark .parametrize (
310+ "vectorizer_" ,
311+ [
312+ AzureOpenAITextVectorizer ,
313+ BedrockTextVectorizer ,
314+ CohereTextVectorizer ,
315+ HFTextVectorizer ,
316+ MistralAITextVectorizer ,
317+ OpenAITextVectorizer ,
318+ VertexAITextVectorizer ,
319+ VoyageAITextVectorizer ,
320+ ],
321+ )
322+ def test_bad_dtypes (vectorizer_ ):
290323 with pytest .raises (ValueError ):
291- vectorizer = vectorizer_ (dtype = "float25" )
324+ vectorizer_ (dtype = "float25" )
292325
293326 with pytest .raises (ValueError ):
294- vectorizer = vectorizer_ (dtype = 7 )
327+ vectorizer_ (dtype = 7 )
295328
296329 with pytest .raises (ValueError ):
297- vectorizer = vectorizer_ (dtype = None )
330+ vectorizer_ (dtype = None )
298331
299332
300333@pytest .fixture (
0 commit comments