Commit 074216a

Better Tests for Better Engineering: metrics/throughput.py

Authored by ilyas409 and committed by facebook-github-bot
1 parent 35b14b0

Summary:
Wrote better tests for checkpoint restoration between jobs with and without batch size stages (BSS):
- BSS job -> non-BSS job
- non-BSS job -> BSS job
- BSS job -> BSS job

To better simulate the exhibited code behavior, each test now creates the previous module, checkpoints it, creates a new module, and restores it from the previous checkpoint.

The new tests also caught a bug: the hooks need to be registered regardless of whether BSS is used, to cover the case where an online training job without BSS resumes from an offline checkpoint.

Reviewed By: burak-turk

Differential Revision: D72567969

fbshipit-source-id: 78b394596fb38a28ba8617ca1830503bab70199c
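The checkpoint-restoration lifecycle described in the summary, condensed into a rough sketch. The constructor arguments and the update()/state_dict()/load_state_dict() calls are lifted from the test diffs below; the import paths are assumptions, so adjust them to wherever ThroughputMetric and BatchSizeStage live in your checkout.

    # Sketch of the BSS -> non-BSS restoration path the new tests simulate.
    # Import paths are assumptions; see the diffs below for the real definitions.
    from torchrec.metrics.metrics_config import BatchSizeStage
    from torchrec.metrics.throughput import ThroughputMetric

    # 1) "Previous" job: uses batch_size_stages (BSS) and gets checkpointed.
    prev_job = ThroughputMetric(
        batch_size=32,
        world_size=4,
        window_seconds=100,
        batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)],
    )
    for _ in range(10):
        prev_job.update()
    checkpoint = prev_job.state_dict()

    # 2) "Current" job: restored from that checkpoint, here without BSS.
    curr_job = ThroughputMetric(
        batch_size=32,
        world_size=4,
        window_seconds=100,
        batch_size_stages=None,
    )
    curr_job.load_state_dict(checkpoint)  # must not raise and must not create _num_batch

The new tests in test_throughput.py run this lifecycle for each BSS/non-BSS combination and assert on _num_batch afterwards.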

File tree: 3 files changed, +90 -21 lines


torchrec/metrics/tests/test_metric_module.py

Lines changed: 17 additions & 0 deletions
@@ -586,3 +586,20 @@ def test_save_and_load_state_dict(self) -> None:
         self.assertEqual(throughput_metric._num_batch, 100)
         # Make sure num_batch is correctly synchronized
         self.assertEqual(throughput_metric._num_batch, 100)
+
+        # Load the same checkpoint into a module that doesn't use BSS
+
+        no_bss_metric_module = generate_metric_module(
+            TestMetricModule,
+            metrics_config=DefaultMetricsConfig,
+            batch_size=128,
+            world_size=1,
+            my_rank=0,
+            state_metrics_mapping={},
+            device=torch.device("cpu"),
+            batch_size_stages=None,
+        )
+
+        no_bss_metric_module.load_state_dict(state_dict)
+        # Make sure num_batch wasn't created on the throughput module (and no exception was thrown above)
+        self.assertFalse(hasattr(no_bss_metric_module.throughput_metric, "_num_batch"))

torchrec/metrics/tests/test_throughput.py

Lines changed: 70 additions & 19 deletions
@@ -310,6 +310,7 @@ def test_state_dict_hook_no_batch_size_stages(self) -> None:
         Verifies that the state_dict_hook does not add the 'num_batch' key when
         batch_size_stages is None.
         """
+        # Hook-only test
         throughput_metric = ThroughputMetric(
             batch_size=32,
             world_size=4,
@@ -321,41 +322,84 @@ def test_state_dict_hook_no_batch_size_stages(self) -> None:
         ThroughputMetric.state_dict_hook(throughput_metric, state_dict, prefix, {})
         self.assertNotIn(f"{prefix}num_batch", state_dict)
 
-    def test_load_state_dict_hook_restores_value(self) -> None:
+        # Lifecycle test
+
+        num_updates = 10
+        prev_job_throughput_metric = ThroughputMetric(
+            batch_size=32,
+            world_size=4,
+            window_seconds=100,
+            batch_size_stages=None,
+        )
+        for _ in range(num_updates):
+            prev_job_throughput_metric.update()
+        prev_state_dict = prev_job_throughput_metric.state_dict()
+
+        curr_job_throughput_metric = ThroughputMetric(
+            batch_size=32,
+            world_size=4,
+            window_seconds=100,
+            batch_size_stages=None,
+        )
+
+        curr_job_throughput_metric.load_state_dict(prev_state_dict)
+        # Make sure _num_batch is not present as an argument of the class
+        self.assertFalse(hasattr(curr_job_throughput_metric, "_num_batch"))
+
+    def test_load_state_dict_hook_resumes_from_checkpoint_with_bss_from_bss(
+        self,
+    ) -> None:
         """
         Checks that the load_state_dict_hook correctly restores the 'num_batch' value
         from the state_dict.
         """
-        throughput_metric = ThroughputMetric(
+        num_updates = 10
+        prev_job_throughput_metric = ThroughputMetric(
             batch_size=32,
             world_size=4,
             window_seconds=100,
             batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)],
         )
-        state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
-        prefix: str = "test_prefix_"
-        state_dict[f"{prefix}num_batch"] = torch.tensor(10, dtype=torch.long)
-        throughput_metric.load_state_dict_hook(state_dict, prefix, {}, True, [], [], [])
-        self.assertEqual(throughput_metric._num_batch, 10)
+        for _ in range(num_updates):
+            prev_job_throughput_metric.update()
+        prev_state_dict = prev_job_throughput_metric.state_dict()
+
+        curr_job_throughput_metric = ThroughputMetric(
+            batch_size=32,
+            world_size=4,
+            window_seconds=100,
+            batch_size_stages=[BatchSizeStage(1024, 1), BatchSizeStage(2048, None)],
+        )
+
+        curr_job_throughput_metric.load_state_dict(prev_state_dict)
+        self.assertEqual(curr_job_throughput_metric._num_batch, num_updates)
 
     def test_load_state_dict_hook_resumes_from_checkpoint_without_bss(self) -> None:
         """
         Verifies that the load_state_dict_hook correctly handles the case where a
         previously checkpointed job used the batch_size_stages, but a subsequent job,
         restored from a checkpoint, isn't using them.
         """
-        throughput_metric = ThroughputMetric(
+
+        prev_job_throughput_metric = ThroughputMetric(
+            batch_size=32,
+            world_size=4,
+            window_seconds=100,
+            batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)],
+        )
+
+        prev_state_dict = prev_job_throughput_metric.state_dict()
+
+        curr_job_throughput_metric = ThroughputMetric(
             batch_size=32,
             world_size=4,
             window_seconds=100,
             batch_size_stages=None,  # No batch_size_stages
         )
-        state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
-        prefix: str = "test_prefix_"
-        state_dict[f"{prefix}num_batch"] = torch.tensor(10, dtype=torch.long)
-        throughput_metric.load_state_dict_hook(state_dict, prefix, {}, True, [], [], [])
 
-        self.assertFalse(hasattr(throughput_metric, "_num_batch"))
+        curr_job_throughput_metric.load_state_dict(prev_state_dict)
+
+        self.assertFalse(hasattr(curr_job_throughput_metric, "_num_batch"))
 
     def test_load_state_dict_hook_resumes_from_checkpoint_with_bss_without_key(
         self,
@@ -365,15 +409,22 @@ def test_load_state_dict_hook_resumes_from_checkpoint_with_bss_without_key(
         previously checkpointed job didn't use batch_size_stages, but a subsequent job,
         restored from a checkpoint, is using them.
         """
-        throughput_metric = ThroughputMetric(
+        prev_job_throughput_metric = ThroughputMetric(
+            batch_size=32,
+            world_size=4,
+            window_seconds=100,
+            batch_size_stages=None,  # No batch_size_stages
+        )
+        prev_state_dict = prev_job_throughput_metric.state_dict()
+
+        curr_job_throughput_metric = ThroughputMetric(
             batch_size=32,
             world_size=4,
             window_seconds=100,
             batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)],
         )
-        # Empty state_dict
-        state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
-        prefix: str = "test_prefix_"
-        throughput_metric.load_state_dict_hook(state_dict, prefix, {}, True, [], [], [])
+
+        curr_job_throughput_metric.load_state_dict(prev_state_dict)
+
         # Expecting 0
-        self.assertEqual(throughput_metric._num_batch, 0)
+        self.assertEqual(curr_job_throughput_metric._num_batch, 0)

torchrec/metrics/throughput.py

Lines changed: 3 additions & 2 deletions
@@ -115,8 +115,9 @@ def __init__(
         if self._batch_size_stages is not None:
             # Keep track of the number of batches if using batch_size_stages
            self._num_batch: int = 0
-            self._register_load_state_dict_pre_hook(self.load_state_dict_hook)
-            self.register_state_dict_post_hook(self.state_dict_hook)
+
+        self._register_load_state_dict_pre_hook(self.load_state_dict_hook)
+        self.register_state_dict_post_hook(self.state_dict_hook)
 
         self.register_buffer("total_examples", torch.tensor(0, dtype=torch.long))
         self.register_buffer("warmup_examples", torch.tensor(0, dtype=torch.long))
