Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 0 additions & 100 deletions src/llmcompressor/core/helpers.py

This file was deleted.

50 changes: 0 additions & 50 deletions src/llmcompressor/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
from loguru import logger

from llmcompressor.core.events import EventType
from llmcompressor.core.helpers import log_model_info, should_log_model_info
from llmcompressor.core.lifecycle import CompressionLifecycle
from llmcompressor.core.state import ModifiedState, State
from llmcompressor.metrics import BaseLogger, LoggerManager
from llmcompressor.recipe import Recipe

__all__ = [
Expand Down Expand Up @@ -90,7 +88,6 @@ def initialize(
start: float | None = None,
steps_per_epoch: int | None = None,
batches_per_step: int | None = None,
loggers: LoggerManager | list[BaseLogger] | None = None,
**kwargs,
) -> ModifiedState:
"""
Expand Down Expand Up @@ -118,8 +115,6 @@ def initialize(
compression
:param batches_per_step: the number of batches per step to use for
compression
:param loggers: the metrics manager to setup logging important info
and milestones to, also accepts a list of BaseLogger(s)
:param kwargs: additional kwargs to pass to the lifecycle's initialize method
:return: the modified state of the session after initializing
"""
Expand All @@ -139,7 +134,6 @@ def initialize(
start=start,
steps_per_epoch=steps_per_epoch,
batches_per_step=batches_per_step,
loggers=loggers,
**kwargs,
)

Expand Down Expand Up @@ -194,16 +188,6 @@ def event(
modifier_data=mod_data,
)

def log(self, event_type: EventType, loss: Any | None = None):
"""
Log model and loss information for the current event type

:param event_type: the event type to log for
:param loss: the loss to log if any
"""
self._log_model_info()
self._log_loss(event_type=event_type, loss=loss)

def reset(self):
"""
Reset the session to its initial state
Expand All @@ -227,37 +211,3 @@ def get_serialized_recipe(self) -> str | None:
return recipe.yaml()

logger.warning("Recipe not found in session - it may have been reset")

def _log_model_info(self):
# Log model level logs if cadence reached
current_index = self._lifecycle.global_step

if (
should_log_model_info(
model=self.state.model,
loggers=self.state.loggers,
current_log_step=current_index,
last_log_step=self.state._last_log_step,
)
and self.state.loggers.frequency_manager.is_epoch_frequency_manager
):
log_model_info(
state=self.state,
current_log_step=current_index,
)
# update last log epoch
self.state.loggers.log_written(current_index)

def _log_loss(self, event_type: EventType, loss: Any):
if event_type != EventType.LOSS_CALCULATED:
# only log loss when loss is calculated
return

current_index = self._lifecycle.global_step

# always log loss if available
if loss is not None:
loss = loss if isinstance(loss, dict) else {"loss": loss}
self.state.loggers.metric.log_scalars(
tag="Loss", values=loss, step=current_index
)
6 changes: 3 additions & 3 deletions src/llmcompressor/core/session_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Generator, Optional

from loguru import logger

from llmcompressor.core.events import EventType
from llmcompressor.core.session import CompressionSession
from llmcompressor.core.state import ModifiedState
Expand Down Expand Up @@ -108,8 +110,7 @@ def loss_calculated(cls, loss: Optional[Any] = None, **kwargs) -> ModifiedState:
:param kwargs: additional kwargs to pass to the current session's event method
:return: the modified state of the active session after invoking the event
"""
# log loss if loss calculated
active_session()._log_loss(event_type=EventType.LOSS_CALCULATED, loss=loss)
logger.debug(f"Calculated loss: {loss}")
return cls.event(EventType.LOSS_CALCULATED, loss=loss, **kwargs)

@classmethod
Expand Down Expand Up @@ -140,7 +141,6 @@ def batch_end(cls, **kwargs) -> ModifiedState:
:param kwargs: additional kwargs to pass to the current session's event method
:return: the modified state of the active session after invoking the event
"""
active_session()._log_model_info()
return cls.event(EventType.BATCH_END, **kwargs)

@classmethod
Expand Down
28 changes: 0 additions & 28 deletions src/llmcompressor/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
import torch
from loguru import logger

from llmcompressor.metrics import BaseLogger, LoggerManager

__all__ = ["State", "Data", "Hardware", "ModifiedState"]


Expand Down Expand Up @@ -94,11 +92,6 @@ class State:
:type data: Data
:param hardware: Hardware instance holding info about the target hardware being used
:type hardware: Hardware
:param loggers: LoggerManager instance holding all the loggers to log
:type loggers: Optional[LoggerManager]
:param model_log_cadence: The cadence to log model information w.r.t epochs.
If 1, logs every epoch. If 2, logs every other epoch, etc. Default is 1.
:type model_log_cadence: Optional[float]
"""

model: Any = None
Expand All @@ -109,9 +102,6 @@ class State:
batch_data: Any = None
data: Data = field(default_factory=Data)
hardware: Hardware = field(default_factory=Hardware)
loggers: LoggerManager | None = None
model_log_cadence: float | None = None
_last_log_step: float | int | None = None
loss_masks: list[torch.Tensor] | None = None
current_batch_idx: int = -1

Expand Down Expand Up @@ -141,8 +131,6 @@ def update(
start: float = None,
steps_per_epoch: int = None,
batches_per_step: int = None,
loggers: LoggerManager | list[BaseLogger] | None = None,
model_log_cadence: float | None = None,
**kwargs,
) -> dict:
"""
Expand Down Expand Up @@ -172,12 +160,6 @@ def update(
:type steps_per_epoch: int
:param batches_per_step: The batches per step to update the state with
:type batches_per_step: int
:param loggers: The metrics manager to setup logging important info and
milestones to, also accepts a list of BaseLogger(s)
:type loggers: Union[None, LoggerManager, List[BaseLogger]]
:param model_log_cadence: The cadence to log model information w.r.t epochs.
If 1, logs every epoch. If 2, logs every other epoch, etc. Default is 1.
:type model_log_cadence: Optional[float]
:param kwargs: Additional keyword arguments to update the state with
:return: The updated state as a dictionary
:rtype: Dict
Expand All @@ -197,8 +179,6 @@ def update(
"start": start,
"steps_per_epoch": steps_per_epoch,
"batches_per_step": batches_per_step,
"loggers": loggers,
"model_log_cadence": model_log_cadence,
"kwargs": kwargs,
},
)
Expand All @@ -222,14 +202,6 @@ def update(
if "device" in kwargs:
self.hardware.device = kwargs["device"]

loggers = loggers or []
if isinstance(loggers, list):
loggers = LoggerManager(loggers)
self.loggers = loggers

if model_log_cadence is not None:
self.model_log_cadence = model_log_cadence

return kwargs


Expand Down
12 changes: 0 additions & 12 deletions src/llmcompressor/metrics/__init__.py

This file was deleted.

Loading