@@ -46,7 +46,7 @@ def check_num_checkpoints(self):
4646 """
4747
4848 checkpoint_files = os .listdir (self ._checkpoint_dir )
49- return len (checkpoint_files ) == len ( self . _profile_models )
49+ return len (checkpoint_files ) == 1
5050
5151 def check_loading_checkpoints (self ):
5252 """
@@ -74,7 +74,7 @@ def check_interrupt_handling(self):
7474 """
7575
7676 checkpoint_files = os .listdir (self ._checkpoint_dir )
77- if len (checkpoint_files ) != 2 :
77+ if len (checkpoint_files ) != 1 :
7878 return False
7979
8080 with open (self ._analyzer_log , 'r' ) as f :
@@ -85,8 +85,8 @@ def check_interrupt_handling(self):
8585 if log_contents .find (token ) == - 1 :
8686 return False
8787
88- # check that 2nd model is profiled once
89- token = f"Profiling { self ._profile_models [1 ]} "
88+ # check that 1st model is profiled twice
89+ token = f"Profiling { self ._profile_models [0 ]} "
9090 token_idx = 0
9191 found_count = 0
9292 while True :
@@ -95,7 +95,7 @@ def check_interrupt_handling(self):
9595 break
9696 found_count += 1
9797
98- return found_count == 1
98+ return found_count == 2
9999
100100 def check_early_exit (self ):
101101 """
@@ -117,7 +117,7 @@ def check_early_exit(self):
117117 return True
118118
119119 def check_continue_after_checkpoint (self ,
120- expected_resnet_count = 3 ,
120+ expected_resnet_count = 4 ,
121121 expected_vgg_count = 2 ):
122122 """
123123 Check that the 2nd model onwards have been run the correct
@@ -136,8 +136,6 @@ def check_continue_after_checkpoint(self,
136136
137137 # resnet50 libtorch normally has 4 runs:
138138 # ([2 models, one of which is default] x [2 concurrencies])
139- # but 1 was checkpointed from the previous interrupted run, so it
140- # will do the remaining 3
141139 #
142140 # vgg19 will have 2 runs:
143141 # ([2 models, one of which is default] x [1 concurrency])
0 commit comments