Commit abaca04

distributed logging

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 234cd91 commit abaca04

5 files changed: +56 −32 lines changed

src/llmcompressor/entrypoints/utils.py
Lines changed: 5 additions & 1 deletion

@@ -10,7 +10,7 @@
 import os
 from pathlib import PosixPath
 
-from compressed_tensors.offload import from_accelerate
+from compressed_tensors.offload import from_accelerate, is_distributed
 from loguru import logger
 from transformers import (
     AutoConfig,
@@ -26,6 +26,7 @@
     RecipeArguments,
 )
 from llmcompressor.core import reset_session
+from llmcompressor.logger import configure_distributed_logger
 from llmcompressor.pytorch.model_load.helpers import parse_dtype
 from llmcompressor.transformers.compression.compressed_tensors_utils import (
     modify_save_pretrained,
@@ -52,6 +53,9 @@ def pre_process(
     Raises:
         FileNotFoundError: If the model or processor path is invalid.
     """
+    # Detect distributed, update logger
+    if is_distributed():
+        configure_distributed_logger()
 
     # Initialize model
     if isinstance(model_args.model, (str, PosixPath)):
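
With this hook, any entrypoint that goes through pre_process switches to rank-prefixed logging as soon as a torch.distributed process group is active. A minimal sketch of the same gating pattern in isolation, assuming is_distributed() simply reports whether a process group has been initialized:

    # sketch only; mirrors the pre_process guard added above
    from compressed_tensors.offload import is_distributed

    from llmcompressor.logger import configure_distributed_logger, logger

    if is_distributed():  # assumed true only after init_process_group()
        configure_distributed_logger()

    logger.info("loading model")  # gains a "[Rank N]" prefix when distributed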

src/llmcompressor/logger.py
Lines changed: 30 additions & 17 deletions

@@ -41,9 +41,10 @@
 from dataclasses import dataclass
 from typing import Any, Dict, Optional
 
+import torch.distributed as dist
 from loguru import logger
 
-__all__ = ["LoggerConfig", "configure_logger", "logger"]
+__all__ = ["LoggerConfig", "configure_logger", "logger", "configure_distributed_logger"]
 
 
 # used by `support_log_once`
@@ -53,14 +54,20 @@
 @dataclass
 class LoggerConfig:
     disabled: bool = False
-    clear_loggers: bool = True
     console_log_level: Optional[str] = "INFO"
     log_file: Optional[str] = None
     log_file_level: Optional[str] = None
     metrics_disabled: bool = False
+    rank: Optional[int] = None
 
 
-def configure_logger(config: Optional[LoggerConfig] = None) -> None:
+# global config
+LOGGER_CONFIG = LoggerConfig()
+
+
+def configure_logger(
+    logger_config: LoggerConfig = LOGGER_CONFIG, clear_loggers: bool = False
+):
     """
     Configure the logger for LLM Compressor.
 
@@ -72,13 +79,10 @@ def configure_logger(config: Optional[LoggerConfig] = None) -> None:
     :param config: The configuration for the logger to use.
     :type config: LoggerConfig
     """
-    logger_config = config or LoggerConfig()
 
     # env vars get priority
     if (disabled := os.getenv("LLM_COMPRESSOR_LOG_DISABLED")) is not None:
         logger_config.disabled = disabled.lower() == "true"
-    if (clear_loggers := os.getenv("LLM_COMPRESSOR_CLEAR_LOGGERS")) is not None:
-        logger_config.clear_loggers = clear_loggers.lower() == "true"
     if (console_log_level := os.getenv("LLM_COMPRESSOR_LOG_LEVEL")) is not None:
         logger_config.console_log_level = console_log_level.upper()
     if (log_file := os.getenv("LLM_COMPRESSOR_LOG_FILE")) is not None:
@@ -92,15 +96,22 @@ def configure_logger(config: Optional[LoggerConfig] = None) -> None:
 
     logger.enable("llmcompressor")
 
-    if logger_config.clear_loggers:
+    # reset logger configuration
+    if clear_loggers:
         logger.remove()
 
+    # set format (optionally adding rank)
+    format = "{time:YYYY-MM-DDTHH:mm:ss.SSSS} | {function} | {level} - {message}"
+    if logger_config.rank is not None:
+        logger.configure(extra={"rank": dist.get_rank()})
+        format = "[Rank {extra[rank]}] " + format
+
     if logger_config.console_log_level:
         # log as a human readable string with the time, function, level, and message
         logger.add(
             sys.stdout,
             level=logger_config.console_log_level.upper(),
-            format="{time} | {function} | {level} - {message}",
+            format=format,
             filter=support_log_once,
         )
 
@@ -112,6 +123,7 @@ def configure_logger(config: Optional[LoggerConfig] = None) -> None:
             log_file,
             level=log_file_level.upper(),
             serialize=True,
+            format=format,
             filter=support_log_once,
         )
 
@@ -121,6 +133,10 @@ def configure_logger(config: Optional[LoggerConfig] = None) -> None:
         # initialize metric logger on loguru
         logger.level("METRIC", no=38, color="<yellow>", icon="📈")
 
+    # set global value for later calls
+    global LOGGER_CONFIG
+    LOGGER_CONFIG = logger_config
+
 
 def support_log_once(record: Dict[str, Any]) -> bool:
     """
@@ -146,14 +162,11 @@ def support_log_once(record: Dict[str, Any]) -> bool:
     return True
 
 
+def configure_distributed_logger(logger_config: LoggerConfig = LOGGER_CONFIG):
+    logger_config.rank = dist.get_rank()
+    configure_logger(logger_config, clear_loggers=True)
+
+
 # invoke logger setup on import with default values enabling console logging with INFO
 # and disabling file logging
-configure_logger(
-    config=LoggerConfig(
-        disabled=False,
-        clear_loggers=True,
-        console_log_level="INFO",
-        log_file=None,
-        log_file_level=None,
-    )
-)
+configure_logger()
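
Taken together, LoggerConfig now carries an optional rank, configure_logger builds a rank-aware format string, and configure_distributed_logger records the current rank before rebuilding the sinks. A minimal sketch of how this looks from a multi-process launch, assuming a launcher such as torchrun has set the usual rendezvous environment variables:

    # sketch: save as example.py, run with `torchrun --nproc_per_node=2 example.py`
    import torch.distributed as dist

    from llmcompressor.logger import configure_distributed_logger, logger

    dist.init_process_group(backend="gloo")  # gloo keeps the sketch CPU-only
    configure_distributed_logger()           # stores dist.get_rank(), re-adds sinks

    # each process now emits lines shaped like:
    # [Rank 1] 2025-01-01T12:00:00.0000 | <module> | INFO - hello
    logger.info("hello")

    dist.destroy_process_group()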

src/llmcompressor/modifiers/pruning/sparsegpt/base.py
Lines changed: 1 addition & 1 deletion

@@ -141,7 +141,7 @@ def compress_modules(self):
                 dampening_frac=self.dampening_frac,
                 preserve_sparsity_mask=self.preserve_sparsity_mask,
             )
-            comp_logger.set_loss(loss)
+            comp_logger.set_results(name="SGPT", loss=loss)
 
             update_offload_parameter(module, "weight", sparsified_weight)
 

src/llmcompressor/modifiers/quantization/gptq/base.py
Lines changed: 2 additions & 2 deletions

@@ -302,7 +302,7 @@ def compress_module_list(self, module_list):
             num_samples = self._num_samples[module]
             quant_args = getattr_chain(module, "quantization_scheme.weights")
 
-            logger.info(f"Quantizing {name} using {num_samples} samples")
+            logger.info(f"Quantizing {name} using {int(num_samples)} samples")
             with (
                 torch.no_grad(),
                 align_module_device(module),
@@ -316,7 +316,7 @@ def compress_module_list(self, module_list):
                     blocksize=self.block_size,
                     percdamp=self.dampening_frac,
                 )
-                comp_logger.set_loss(loss)
+                comp_logger.set_results(name="GPTQ", loss=loss)
 
                 for attr, val in q_param_dict.items():
                     update_offload_parameter(module, attr, val)

src/llmcompressor/utils/metric_logging.py
Lines changed: 18 additions & 11 deletions

@@ -6,7 +6,7 @@
 """
 
 import time
-from typing import Iterable
+from typing import Iterable, Optional
 
 import torch
 from compressed_tensors.offload import is_distributed
@@ -28,27 +28,34 @@ class CompressionLogger:
     def __init__(self, module: torch.nn.Module):
         self.module = module
         self.start_tick = None
-        self.loss = None
 
-    def set_loss(self, loss: float):
-        self.loss = loss
+        self._name = None
+        self._loss = None
+
+    def set_results(
+        self,
+        name: Optional[str] = None,
+        loss: Optional[float] = None,
+    ):
+        self._name = name
+        self._loss = loss
 
     def __enter__(self) -> "CompressionLogger":
         self.start_tick = time.time()
         return self
 
     def __exit__(self, _exc_type, _exc_val, _exc_tb):
         stop_tick = time.time()
-        patch = logger.patch(lambda r: r.update(function="compress"))
 
-        if self.start_tick is not None:
-            patch.log("METRIC", f"time {(stop_tick - self.start_tick):.2f}s")
-        if self.loss is not None:
-            patch.log("METRIC", f"error {self.loss:.2f}")
+        patch = logger.patch(lambda r: r.update(function=(self._name or "compress")))
+
+        patch.log("METRIC", f"time {(stop_tick - self.start_tick):.2f}s")
+        if self._loss is not None:
+            patch.log("METRIC", f"error {self._loss:.2f}")
 
         for device_id in _get_visible_devices():
-            max_memory = torch.cuda.max_memory_allocated(device_id)
-            used_memory = torch.cuda.get_device_properties(device_id).total_memory
+            used_memory = torch.cuda.max_memory_allocated(device_id)
+            max_memory = torch.cuda.get_device_properties(device_id).total_memory
             perc_used = 100 * used_memory / max_memory
             patch.log(
                 "METRIC",
