@@ -170,6 +170,23 @@ def test_batch_embedding(model_name: str):
         delete_model_cache(model.model._model_dir)
 
 
+@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
+def test_batch_inference_size_same_as_single_inference(model_name: str):
+    is_ci = os.getenv("CI")
+
+    model = LateInteractionTextEmbedding(model_name=model_name)
+    docs_to_embed = [
+        "short document",
+        "A bit longer document, which should not affect the size"
+    ]
+    result = list(model.embed(docs_to_embed, batch_size=1))
+    result_2 = list(model.embed(docs_to_embed, batch_size=2))
+    assert len(result[0]) == len(result_2[0])
+
+    if is_ci:
+        delete_model_cache(model.model._model_dir)
+
+
 @pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
 def test_single_embedding(model_name: str):
     is_ci = os.getenv("CI")
@@ -219,17 +236,16 @@ def test_parallel_processing(token_dim: int, model_name: str):
 
     docs = ["hello world", "flag embedding"] * 100
     embeddings = list(model.embed(docs, batch_size=10, parallel=2))
-    embeddings = np.stack(embeddings, axis=0)
 
     embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None))
-    embeddings_2 = np.stack(embeddings_2, axis=0)
 
     embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0))
-    embeddings_3 = np.stack(embeddings_3, axis=0)
 
-    assert embeddings.shape[0] == len(docs) and embeddings.shape[-1] == token_dim
-    assert np.allclose(embeddings, embeddings_2, atol=1e-3)
-    assert np.allclose(embeddings, embeddings_3, atol=1e-3)
+    assert len(embeddings) == len(docs) and embeddings[0].shape[-1] == token_dim
+
+    for i in range(len(embeddings)):
+        assert np.allclose(embeddings[i], embeddings_2[i], atol=1e-3)
+        assert np.allclose(embeddings[i], embeddings_3[i], atol=1e-3)
 
     if is_ci:
         delete_model_cache(model.model._model_dir)
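
For context, here is a minimal standalone sketch of the invariant the new test pins down, assuming fastembed's `LateInteractionTextEmbedding` API exactly as it appears in the diff above: a late-interaction (ColBERT-style) model yields one token-level matrix per document, and that matrix's shape should not depend on `batch_size`.

```python
# Minimal sketch of the invariant asserted by the new test, assuming the
# fastembed LateInteractionTextEmbedding API used in the diff above.
from fastembed import LateInteractionTextEmbedding

model = LateInteractionTextEmbedding(model_name="answerdotai/answerai-colbert-small-v1")
docs = ["short document", "A bit longer document, which should not affect the size"]

# embed() yields one (num_tokens, token_dim) array per document; the row
# count for a given document should be identical regardless of batch_size.
per_doc = list(model.embed(docs, batch_size=1))
batched = list(model.embed(docs, batch_size=2))
assert all(a.shape == b.shape for a, b in zip(per_doc, batched))
```

The second hunk follows from the same property: documents of different lengths produce matrices with different row counts, so `np.stack` cannot combine them into one array; the rewritten test compares the per-document matrices element-wise in a loop instead.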