
Commit ee1da3a

jeffkbkim authored and facebook-github-bot committed
test_tower_qps_needed_coverage fix (#3264)
Summary: Pull Request resolved: #3264. This patch adds an additional unit test to increase test_tower_qps unit test coverage.
Reviewed By: kausv
Differential Revision: D79810524
fbshipit-source-id: eee0ab962d6d07764aaa2a875e81bd98830ceace
1 parent 8d8bccb commit ee1da3a

File tree

1 file changed: +51 −1 lines changed


torchrec/metrics/tests/test_tower_qps.py

Lines changed: 51 additions & 1 deletion
@@ -16,7 +16,12 @@
 import torch.distributed as dist
 from torchrec.metrics.metrics_config import DefaultTaskInfo
 from torchrec.metrics.model_utils import parse_task_model_outputs
-from torchrec.metrics.rec_metric import RecComputeMode, RecMetric, RecTaskInfo
+from torchrec.metrics.rec_metric import (
+    RecComputeMode,
+    RecMetric,
+    RecMetricException,
+    RecTaskInfo,
+)
 from torchrec.metrics.test_utils import (
     gen_test_batch,
     gen_test_tasks,
@@ -327,3 +332,48 @@ def test_mtml_empty_update(self) -> None:
             self.assertEqual(
                 qps._metrics_computations[1].num_examples, (step + 2) // 2 * batch_size
             )
+
+    def test_tower_qps_update_with_invalid_tensors(self) -> None:
+        warmup_steps = 2
+        batch_size = 128
+        task_names = ["t1", "t2"]
+        tasks = gen_test_tasks(task_names)
+        qps = TowerQPSMetric(
+            world_size=1,
+            my_rank=0,
+            batch_size=batch_size,
+            tasks=tasks,
+            warmup_steps=warmup_steps,
+            compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+            compute_on_all_ranks=False,
+            should_validate_update=True,
+            window_size=200,
+        )
+
+        with self.assertRaisesRegex(
+            RecMetricException,
+            "Failed to convert labels to tensor for fused computation",
+        ):
+            qps.update(
+                predictions=torch.ones(batch_size),
+                labels={
+                    "key_0": torch.rand(batch_size),
+                    "key_1": torch.rand(batch_size),
+                    "key_2": torch.rand(batch_size),
+                },
+                weights=torch.rand(batch_size),
+            )
+
+        with self.assertRaisesRegex(
+            RecMetricException,
+            "Failed to convert weights to tensor for fused computation",
+        ):
+            qps.update(
+                predictions=torch.ones(batch_size),
+                labels=torch.rand(batch_size),
+                weights={
+                    "key_0": torch.rand(batch_size),
+                    "key_1": torch.rand(batch_size),
+                    "key_2": torch.rand(batch_size),
+                },
+            )
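
For context on what the new test exercises: in FUSED_TASKS_COMPUTATION mode the metric has to assemble one tensor with a row per task, so a labels or weights dict whose keys do not match the task names cannot be converted, and the metric raises RecMetricException. The sketch below is a minimal, self-contained illustration of that failure pattern only; fuse_values is a hypothetical helper, not TorchRec's actual API, and ValueError stands in for RecMetricException.

from typing import Dict, List, Union

import torch


def fuse_values(
    task_names: List[str],
    values: Union[torch.Tensor, Dict[str, torch.Tensor]],
) -> torch.Tensor:
    # Hypothetical helper: build one (num_tasks x batch) tensor for a fused update.
    if isinstance(values, torch.Tensor):
        # A single tensor is shared across all tasks.
        return values.expand(len(task_names), -1)
    try:
        # A dict must be keyed by task name so each task gets its own row.
        return torch.stack([values[name] for name in task_names])
    except KeyError as exc:
        raise ValueError(
            f"Failed to convert values to tensor for fused computation: missing {exc}"
        ) from exc


batch_size = 128
# Mirrors the new test: keys "key_0"/"key_1"/"key_2" do not match task names "t1"/"t2".
bad_labels = {f"key_{i}": torch.rand(batch_size) for i in range(3)}
try:
    fuse_values(["t1", "t2"], bad_labels)
except ValueError as err:
    print(err)  # Failed to convert values to tensor for fused computation: missing 't1'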
