Skip to content
9 changes: 6 additions & 3 deletions src/llmcompressor/core/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,13 @@ def initialize(
:return: List of data returned from initialization of modifiers
:rtype: List[Any]
"""
self.state.update(**kwargs)
if self.initialized_: # TODO: do not initialize twice
return
if self.initialized_:
raise ValueError(
"Initialize was called twice. To update state values after "
"initialization, please use `active_session().state.update()`"
)

self.state.update(**kwargs)
logger.debug("Initializing compression lifecycle")
self.recipe_container.append(recipe, recipe_stage, recipe_args)
self.modifiers = self.recipe_container.get_modifiers()
Expand Down
4 changes: 3 additions & 1 deletion src/llmcompressor/transformers/finetune/session_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,9 @@ def create_optimizer(self):
len(self.train_dataset) / total_batch_size
)

initialize(optimizer=self.optimizer, steps_per_epoch=self.total_steps_per_epoch)
active_session().state.update(
optimizer=self.optimizer, steps_per_epoch=self.total_steps_per_epoch
)

return self.optimizer

Expand Down