Skip to content

Commit 2b0684c

Browse files
omkar-334, dhuangnm, dsikka, and brian-dellabetta
authored
Remove training loggers and all related code (#2414)
SUMMARY: Fixes #2409. cc @kylesayrs This PR removes training loggers and all related code, replacing their functionality with `loguru`. It also removes other helper functions and `FrequencyManager`. TEST PLAN: most tests are passing, but getting stuck at the GPTQ test. --------- Signed-off-by: Dan Huang <dahuang@redhat.com> Signed-off-by: Omkar Kabde <omkarkabde@gmail.com> Co-authored-by: dhuangnm <74931910+dhuangnm@users.noreply.github.com> Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com> Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
1 parent 7951987 commit 2b0684c

File tree

13 files changed

+3
-2094
lines changed

13 files changed

+3
-2094
lines changed

src/llmcompressor/core/helpers.py

Lines changed: 0 additions & 100 deletions
This file was deleted.

src/llmcompressor/core/session.py

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@
1212
from loguru import logger
1313

1414
from llmcompressor.core.events import EventType
15-
from llmcompressor.core.helpers import log_model_info, should_log_model_info
1615
from llmcompressor.core.lifecycle import CompressionLifecycle
1716
from llmcompressor.core.state import ModifiedState, State
18-
from llmcompressor.metrics import BaseLogger, LoggerManager
1917
from llmcompressor.recipe import Recipe
2018

2119
__all__ = [
@@ -90,7 +88,6 @@ def initialize(
9088
start: float | None = None,
9189
steps_per_epoch: int | None = None,
9290
batches_per_step: int | None = None,
93-
loggers: LoggerManager | list[BaseLogger] | None = None,
9491
**kwargs,
9592
) -> ModifiedState:
9693
"""
@@ -118,8 +115,6 @@ def initialize(
118115
compression
119116
:param batches_per_step: the number of batches per step to use for
120117
compression
121-
:param loggers: the metrics manager to setup logging important info
122-
and milestones to, also accepts a list of BaseLogger(s)
123118
:param kwargs: additional kwargs to pass to the lifecycle's initialize method
124119
:return: the modified state of the session after initializing
125120
"""
@@ -139,7 +134,6 @@ def initialize(
139134
start=start,
140135
steps_per_epoch=steps_per_epoch,
141136
batches_per_step=batches_per_step,
142-
loggers=loggers,
143137
**kwargs,
144138
)
145139

@@ -194,16 +188,6 @@ def event(
194188
modifier_data=mod_data,
195189
)
196190

197-
def log(self, event_type: EventType, loss: Any | None = None):
198-
"""
199-
Log model and loss information for the current event type
200-
201-
:param event_type: the event type to log for
202-
:param loss: the loss to log if any
203-
"""
204-
self._log_model_info()
205-
self._log_loss(event_type=event_type, loss=loss)
206-
207191
def reset(self):
208192
"""
209193
Reset the session to its initial state
@@ -227,37 +211,3 @@ def get_serialized_recipe(self) -> str | None:
227211
return recipe.yaml()
228212

229213
logger.warning("Recipe not found in session - it may have been reset")
230-
231-
def _log_model_info(self):
232-
# Log model level logs if cadence reached
233-
current_index = self._lifecycle.global_step
234-
235-
if (
236-
should_log_model_info(
237-
model=self.state.model,
238-
loggers=self.state.loggers,
239-
current_log_step=current_index,
240-
last_log_step=self.state._last_log_step,
241-
)
242-
and self.state.loggers.frequency_manager.is_epoch_frequency_manager
243-
):
244-
log_model_info(
245-
state=self.state,
246-
current_log_step=current_index,
247-
)
248-
# update last log epoch
249-
self.state.loggers.log_written(current_index)
250-
251-
def _log_loss(self, event_type: EventType, loss: Any):
252-
if event_type != EventType.LOSS_CALCULATED:
253-
# only log loss when loss is calculated
254-
return
255-
256-
current_index = self._lifecycle.global_step
257-
258-
# always log loss if available
259-
if loss is not None:
260-
loss = loss if isinstance(loss, dict) else {"loss": loss}
261-
self.state.loggers.metric.log_scalars(
262-
tag="Loss", values=loss, step=current_index
263-
)

src/llmcompressor/core/session_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from contextlib import contextmanager
1010
from typing import TYPE_CHECKING, Any, Generator, Optional
1111

12+
from loguru import logger
13+
1214
from llmcompressor.core.events import EventType
1315
from llmcompressor.core.session import CompressionSession
1416
from llmcompressor.core.state import ModifiedState
@@ -108,8 +110,7 @@ def loss_calculated(cls, loss: Optional[Any] = None, **kwargs) -> ModifiedState:
108110
:param kwargs: additional kwargs to pass to the current session's event method
109111
:return: the modified state of the active session after invoking the event
110112
"""
111-
# log loss if loss calculated
112-
active_session()._log_loss(event_type=EventType.LOSS_CALCULATED, loss=loss)
113+
logger.debug(f"Calculated loss: {loss}")
113114
return cls.event(EventType.LOSS_CALCULATED, loss=loss, **kwargs)
114115

115116
@classmethod
@@ -140,7 +141,6 @@ def batch_end(cls, **kwargs) -> ModifiedState:
140141
:param kwargs: additional kwargs to pass to the current session's event method
141142
:return: the modified state of the active session after invoking the event
142143
"""
143-
active_session()._log_model_info()
144144
return cls.event(EventType.BATCH_END, **kwargs)
145145

146146
@classmethod

src/llmcompressor/core/state.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
import torch
1313
from loguru import logger
1414

15-
from llmcompressor.metrics import BaseLogger, LoggerManager
16-
1715
__all__ = ["State", "Data", "Hardware", "ModifiedState"]
1816

1917

@@ -94,11 +92,6 @@ class State:
9492
:type data: Data
9593
:param hardware: Hardware instance holding info about the target hardware being used
9694
:type hardware: Hardware
97-
:param loggers: LoggerManager instance holding all the loggers to log
98-
:type loggers: Optional[LoggerManager]
99-
:param model_log_cadence: The cadence to log model information w.r.t epochs.
100-
If 1, logs every epoch. If 2, logs every other epoch, etc. Default is 1.
101-
:type model_log_cadence: Optional[float]
10295
"""
10396

10497
model: Any = None
@@ -109,9 +102,6 @@ class State:
109102
batch_data: Any = None
110103
data: Data = field(default_factory=Data)
111104
hardware: Hardware = field(default_factory=Hardware)
112-
loggers: LoggerManager | None = None
113-
model_log_cadence: float | None = None
114-
_last_log_step: float | int | None = None
115105
loss_masks: list[torch.Tensor] | None = None
116106
current_batch_idx: int = -1
117107

@@ -141,8 +131,6 @@ def update(
141131
start: float = None,
142132
steps_per_epoch: int = None,
143133
batches_per_step: int = None,
144-
loggers: LoggerManager | list[BaseLogger] | None = None,
145-
model_log_cadence: float | None = None,
146134
**kwargs,
147135
) -> dict:
148136
"""
@@ -172,12 +160,6 @@ def update(
172160
:type steps_per_epoch: int
173161
:param batches_per_step: The batches per step to update the state with
174162
:type batches_per_step: int
175-
:param loggers: The metrics manager to setup logging important info and
176-
milestones to, also accepts a list of BaseLogger(s)
177-
:type loggers: Union[None, LoggerManager, List[BaseLogger]]
178-
:param model_log_cadence: The cadence to log model information w.r.t epochs.
179-
If 1, logs every epoch. If 2, logs every other epoch, etc. Default is 1.
180-
:type model_log_cadence: Optional[float]
181163
:param kwargs: Additional keyword arguments to update the state with
182164
:return: The updated state as a dictionary
183165
:rtype: Dict
@@ -197,8 +179,6 @@ def update(
197179
"start": start,
198180
"steps_per_epoch": steps_per_epoch,
199181
"batches_per_step": batches_per_step,
200-
"loggers": loggers,
201-
"model_log_cadence": model_log_cadence,
202182
"kwargs": kwargs,
203183
},
204184
)
@@ -222,14 +202,6 @@ def update(
222202
if "device" in kwargs:
223203
self.hardware.device = kwargs["device"]
224204

225-
loggers = loggers or []
226-
if isinstance(loggers, list):
227-
loggers = LoggerManager(loggers)
228-
self.loggers = loggers
229-
230-
if model_log_cadence is not None:
231-
self.model_log_cadence = model_log_cadence
232-
233205
return kwargs
234206

235207

src/llmcompressor/metrics/__init__.py

Lines changed: 0 additions & 12 deletions
This file was deleted.

0 commit comments

Comments (0)