Skip to content

Commit a7c6582

Browse files
tests: Added task propagation to parallel
1 parent 16aebc0 commit a7c6582

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

tests/test_text_multitask_embeddings.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,15 +199,19 @@ def test_parallel_processing():
199199

200200
model = TextEmbedding(model_name=model_name)
201201

202-
embeddings_1 = list(model.embed(docs, batch_size=10, parallel=None))
202+
task_id = Task.SEPARATION
203+
embeddings_1 = list(model.embed(docs, batch_size=10, parallel=None, task_id=task_id))
203204
embeddings_1 = np.stack(embeddings_1, axis=0)
204205

205-
embeddings_2 = list(model.embed(docs, batch_size=10, parallel=1))
206+
embeddings_2 = list(model.embed(docs, batch_size=10, parallel=1, task_id=task_id))
206207
embeddings_2 = np.stack(embeddings_2, axis=0)
207208

208209
assert embeddings_1.shape[0] == len(docs) and embeddings_1.shape[-1] == dim
209210
assert np.allclose(embeddings_1, embeddings_2, atol=1e-4)
210211

212+
canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"]
213+
assert np.allclose(embeddings_2[:2, : canonical_vector.shape[1]], canonical_vector, atol=1e-4)
214+
211215
if is_ci:
212216
delete_model_cache(model.model._model_dir)
213217

0 commit comments

Comments
 (0)