@@ -143,25 +143,6 @@ def initialize_session(
 
         train_data = self.get_train_dataloader()
 
-        # calculate total_steps_per_epoch
-        # n_gpu handled internally by dataloader
-        total_batch_size = (
-            self.args.per_device_train_batch_size
-            * self.args.gradient_accumulation_steps
-        )
-        if isinstance(self.train_dataset, IterableDataset):
-            logger.warning(
-                "Training is being run with a streamed dataset, "
-                "steps_per_epoch cannot be determined and will default to "
-                "1. LLM Compressor modifiers utilizing this statistic may not "
-                "behave as expected. "
-            )
-            self.total_steps_per_epoch = 1
-        else:
-            self.total_steps_per_epoch = math.ceil(
-                len(self.train_dataset) / total_batch_size
-            )
-
         self.accelerator.wait_for_everyone()
         with summon_full_params_context(self.model, offload_to_cpu=True):
             active_session().initialize(
@@ -175,7 +156,6 @@ def initialize_session(
                 copy_data=False,
                 attach_optim_callbacks=True,
                 fsdp_active=self.is_fsdp_enabled,
-                steps_per_epoch=self.total_steps_per_epoch,
                 metadata=self.metadata,
             )
 
@@ -219,6 +199,29 @@ def create_optimizer(self):
         self._check_super_defined("create_optimizer")
         super().create_optimizer()
 
+        # n_gpu handled internally by dataloader
+        total_batch_size = (
+            self.args.per_device_train_batch_size
+            * self.args.gradient_accumulation_steps
+        )
+
+        if isinstance(self.train_dataset, IterableDataset):
+            logger.warning(
+                "Training is being run with a streamed dataset, "
+                "steps_per_epoch cannot be determined and will default to "
+                "1. LLM Compressor modifiers utilizing this statistic may not "
+                "behave as expected. "
+            )
+            self.total_steps_per_epoch = 1
+        else:
+            self.total_steps_per_epoch = math.ceil(
+                len(self.train_dataset) / total_batch_size
+            )
+
+        active_session().initialize(
+            optimizer=self.optimizer, steps_per_epoch=self.total_steps_per_epoch
+        )
+
         return self.optimizer
 
     def create_scheduler(
@@ -255,7 +258,7 @@ def training_step(
         """
         self._check_super_defined("training_step")
 
-        callbacks.batch_start(batch_data=inputs)
+        callbacks.batch_start(batch_data=inputs, global_step=self.state.epoch)
         model_outputs = super().training_step(
             model=model, inputs=inputs, num_items_in_batch=num_items_in_batch
         )
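For reference, a minimal sketch of the `steps_per_epoch` arithmetic this diff relocates into `create_optimizer`. The batch-size numbers and dataset length below are hypothetical, chosen only to make the computation concrete; at runtime they come from `TrainingArguments` and the training dataset:

```python
import math

# Hypothetical values for illustration; not taken from this PR.
per_device_train_batch_size = 4   # self.args.per_device_train_batch_size
gradient_accumulation_steps = 8   # self.args.gradient_accumulation_steps
dataset_len = 1000                # len(self.train_dataset) for a map-style dataset

# n_gpu is handled internally by the dataloader, so the effective batch
# size per optimizer step is the per-device size times accumulation steps.
total_batch_size = per_device_train_batch_size * gradient_accumulation_steps  # 32

# A partial final batch still costs an optimizer step, hence the ceil.
total_steps_per_epoch = math.ceil(dataset_len / total_batch_size)  # ceil(31.25) == 32
```

For a streamed `IterableDataset` no length is available, which is why the diff falls back to `1` and logs a warning; modifiers that schedule against `steps_per_epoch` will see that placeholder value instead of a real count.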