@@ -310,6 +310,7 @@ def test_state_dict_hook_no_batch_size_stages(self) -> None:
         Verifies that the state_dict_hook does not add the 'num_batch' key when
         batch_size_stages is None.
         """
+        # Hook-only test
         throughput_metric = ThroughputMetric(
            batch_size=32,
            world_size=4,
@@ -321,41 +322,84 @@ def test_state_dict_hook_no_batch_size_stages(self) -> None:
         ThroughputMetric.state_dict_hook(throughput_metric, state_dict, prefix, {})
         self.assertNotIn(f"{prefix}num_batch", state_dict)
 
-    def test_load_state_dict_hook_restores_value(self) -> None:
+        # Lifecycle test
+
+        num_updates = 10
+        prev_job_throughput_metric = ThroughputMetric(
+            batch_size=32,
+            world_size=4,
+            window_seconds=100,
+            batch_size_stages=None,
+        )
+        for _ in range(num_updates):
+            prev_job_throughput_metric.update()
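+        # Snapshot the previous job's state, emulating a checkpoint save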
+        prev_state_dict = prev_job_throughput_metric.state_dict()
+
+        curr_job_throughput_metric = ThroughputMetric(
+            batch_size=32,
+            world_size=4,
+            window_seconds=100,
+            batch_size_stages=None,
+        )
+
+        curr_job_throughput_metric.load_state_dict(prev_state_dict)
+        # Make sure _num_batch is not present as an attribute of the class
+        self.assertFalse(hasattr(curr_job_throughput_metric, "_num_batch"))
+
+    def test_load_state_dict_hook_resumes_from_checkpoint_with_bss_from_bss(
+        self,
+    ) -> None:
         """
         Checks that the load_state_dict_hook correctly restores the 'num_batch' value
         from the state_dict.
         """
-        throughput_metric = ThroughputMetric(
+        num_updates = 10
+        prev_job_throughput_metric = ThroughputMetric(
             batch_size=32,
             world_size=4,
             window_seconds=100,
             batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)],
         )
-        state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
-        prefix: str = "test_prefix_"
-        state_dict[f"{prefix}num_batch"] = torch.tensor(10, dtype=torch.long)
-        throughput_metric.load_state_dict_hook(state_dict, prefix, {}, True, [], [], [])
-        self.assertEqual(throughput_metric._num_batch, 10)
+        for _ in range(num_updates):
+            prev_job_throughput_metric.update()
+        prev_state_dict = prev_job_throughput_metric.state_dict()
+
+        curr_job_throughput_metric = ThroughputMetric(
+            batch_size=32,
+            world_size=4,
+            window_seconds=100,
+            batch_size_stages=[BatchSizeStage(1024, 1), BatchSizeStage(2048, None)],
+        )
+
+        curr_job_throughput_metric.load_state_dict(prev_state_dict)
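+        # The restored counter should equal the number of updates in the previous job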
+        self.assertEqual(curr_job_throughput_metric._num_batch, num_updates)
 
     def test_load_state_dict_hook_resumes_from_checkpoint_without_bss(self) -> None:
         """
         Verifies that the load_state_dict_hook correctly handles the case where a
         previously checkpointed job used the batch_size_stages, but a subsequent job,
         restored from a checkpoint, isn't using them.
         """
-        throughput_metric = ThroughputMetric(
+
+        prev_job_throughput_metric = ThroughputMetric(
+            batch_size=32,
+            world_size=4,
+            window_seconds=100,
+            batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)],
+        )
+
+        prev_state_dict = prev_job_throughput_metric.state_dict()
+
+        curr_job_throughput_metric = ThroughputMetric(
             batch_size=32,
             world_size=4,
             window_seconds=100,
             batch_size_stages=None,  # No batch_size_stages
         )
-        state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
-        prefix: str = "test_prefix_"
-        state_dict[f"{prefix}num_batch"] = torch.tensor(10, dtype=torch.long)
-        throughput_metric.load_state_dict_hook(state_dict, prefix, {}, True, [], [], [])
 
-        self.assertFalse(hasattr(throughput_metric, "_num_batch"))
+        curr_job_throughput_metric.load_state_dict(prev_state_dict)
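+        # With batch_size_stages disabled, _num_batch should not be restored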
+
+        self.assertFalse(hasattr(curr_job_throughput_metric, "_num_batch"))
 
     def test_load_state_dict_hook_resumes_from_checkpoint_with_bss_without_key(
         self,
@@ -365,15 +409,22 @@ def test_load_state_dict_hook_resumes_from_checkpoint_with_bss_without_key(
         previously checkpointed job didn't use batch_size_stages, but a subsequent job,
         restored from a checkpoint, is using them.
         """
-        throughput_metric = ThroughputMetric(
+        prev_job_throughput_metric = ThroughputMetric(
+            batch_size=32,
+            world_size=4,
+            window_seconds=100,
+            batch_size_stages=None,  # No batch_size_stages
+        )
+        prev_state_dict = prev_job_throughput_metric.state_dict()
+
+        curr_job_throughput_metric = ThroughputMetric(
             batch_size=32,
             world_size=4,
             window_seconds=100,
             batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)],
         )
-        # Empty state_dict
-        state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
-        prefix: str = "test_prefix_"
-        throughput_metric.load_state_dict_hook(state_dict, prefix, {}, True, [], [], [])
+
+        curr_job_throughput_metric.load_state_dict(prev_state_dict)
+
         # Expecting 0
-        self.assertEqual(throughput_metric._num_batch, 0)
+        self.assertEqual(curr_job_throughput_metric._num_batch, 0)