@@ -16,7 +16,12 @@
 import torch.distributed as dist
 from torchrec.metrics.metrics_config import DefaultTaskInfo
 from torchrec.metrics.model_utils import parse_task_model_outputs
-from torchrec.metrics.rec_metric import RecComputeMode, RecMetric, RecTaskInfo
+from torchrec.metrics.rec_metric import (
+    RecComputeMode,
+    RecMetric,
+    RecMetricException,
+    RecTaskInfo,
+)
 from torchrec.metrics.test_utils import (
     gen_test_batch,
     gen_test_tasks,
@@ -327,3 +332,48 @@ def test_mtml_empty_update(self) -> None:
         self.assertEqual(
             qps._metrics_computations[1].num_examples, (step + 2) // 2 * batch_size
         )
+
+    def test_tower_qps_update_with_invalid_tensors(self) -> None:
+        warmup_steps = 2
+        batch_size = 128
+        task_names = ["t1", "t2"]
+        tasks = gen_test_tasks(task_names)
+        qps = TowerQPSMetric(
+            world_size=1,
+            my_rank=0,
+            batch_size=batch_size,
+            tasks=tasks,
+            warmup_steps=warmup_steps,
+            compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+            compute_on_all_ranks=False,
+            should_validate_update=True,
+            window_size=200,
+        )
+
+        with self.assertRaisesRegex(
+            RecMetricException,
+            "Failed to convert labels to tensor for fused computation",
+        ):
+            qps.update(
+                predictions=torch.ones(batch_size),
+                labels={
+                    "key_0": torch.rand(batch_size),
+                    "key_1": torch.rand(batch_size),
+                    "key_2": torch.rand(batch_size),
+                },
+                weights=torch.rand(batch_size),
+            )
+
+        with self.assertRaisesRegex(
+            RecMetricException,
+            "Failed to convert weights to tensor for fused computation",
+        ):
+            qps.update(
+                predictions=torch.ones(batch_size),
+                labels=torch.rand(batch_size),
+                weights={
+                    "key_0": torch.rand(batch_size),
+                    "key_1": torch.rand(batch_size),
+                    "key_2": torch.rand(batch_size),
+                },
+            )
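For contrast, a minimal sketch of an update the fused-tasks path should accept, assuming RecMetric stacks dict inputs by task name in FUSED_TASKS_COMPUTATION mode; it reuses the qps, tasks, and batch_size names set up in the new test above:

    # Sketch only: dict inputs keyed by the actual task names ("t1", "t2")
    # should stack cleanly under FUSED_TASKS_COMPUTATION, so no
    # RecMetricException is expected here.
    qps.update(
        predictions={task.name: torch.ones(batch_size) for task in tasks},
        labels={task.name: torch.rand(batch_size) for task in tasks},
        weights={task.name: torch.ones(batch_size) for task in tasks},
    )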