4141from dataclasses import dataclass
4242from typing import Any , Dict , Optional
4343
44+ import torch .distributed as dist
4445from loguru import logger
4546
# Public API of this module; `logger` is re-exported from loguru so callers
# can `from <module> import logger` directly.
__all__ = ["LoggerConfig", "configure_logger", "logger", "configure_distributed_logger"]
4748
4849
4950# used by `support_log_once``
@dataclass
class LoggerConfig:
    """Settings consumed by configure_logger(); env vars override these fields."""

    # disables logging entirely (presumably gated inside configure_logger —
    # the check itself is in code elided from this view; confirm)
    disabled: bool = False
    # level for the console (stdout) sink; a falsy value skips adding it
    console_log_level: Optional[str] = "INFO"
    # path for an optional file sink; None means no file logging
    log_file: Optional[str] = None
    # level for the file sink (used when log_file is set)
    log_file_level: Optional[str] = None
    # disables the custom METRIC level output (usage not visible here — verify)
    metrics_disabled: bool = False
    # torch.distributed rank prefixed onto log lines; None when not distributed
    rank: Optional[int] = None
6263
63- def configure_logger (config : Optional [LoggerConfig ] = None ) -> None :
# Module-level shared config: used as the default for configure_logger(), and
# rebound by configure_logger() after each call so subsequent calls reuse the
# most recently applied settings.
LOGGER_CONFIG = LoggerConfig()
66+
67+
68+ def configure_logger (
69+ logger_config : LoggerConfig = LOGGER_CONFIG , clear_loggers : bool = False
70+ ):
6471 """
6572 Configure the logger for LLM Compressor.
6673
@@ -72,13 +79,10 @@ def configure_logger(config: Optional[LoggerConfig] = None) -> None:
7279 :param config: The configuration for the logger to use.
7380 :type config: LoggerConfig
7481 """
75- logger_config = config or LoggerConfig ()
7682
7783 # env vars get priority
7884 if (disabled := os .getenv ("LLM_COMPRESSOR_LOG_DISABLED" )) is not None :
7985 logger_config .disabled = disabled .lower () == "true"
80- if (clear_loggers := os .getenv ("LLM_COMPRESSOR_CLEAR_LOGGERS" )) is not None :
81- logger_config .clear_loggers = clear_loggers .lower () == "true"
8286 if (console_log_level := os .getenv ("LLM_COMPRESSOR_LOG_LEVEL" )) is not None :
8387 logger_config .console_log_level = console_log_level .upper ()
8488 if (log_file := os .getenv ("LLM_COMPRESSOR_LOG_FILE" )) is not None :
@@ -92,15 +96,22 @@ def configure_logger(config: Optional[LoggerConfig] = None) -> None:
9296
9397 logger .enable ("llmcompressor" )
9498
95- if logger_config .clear_loggers :
99+ # reset logger configuration
100+ if clear_loggers :
96101 logger .remove ()
97102
103+ # set format (optionally adding rank)
104+ format = "{time:YYYY-MM-DDTHH:mm:ss.SSSS} | {function} | {level} - {message}"
105+ if logger_config .rank is not None :
106+ logger .configure (extra = {"rank" : dist .get_rank ()})
107+ format = "[Rank {extra[rank]}] " + format
108+
98109 if logger_config .console_log_level :
99110 # log as a human readable string with the time, function, level, and message
100111 logger .add (
101112 sys .stdout ,
102113 level = logger_config .console_log_level .upper (),
103- format = "{time} | {function} | {level} - {message}" ,
114+ format = format ,
104115 filter = support_log_once ,
105116 )
106117
@@ -112,6 +123,7 @@ def configure_logger(config: Optional[LoggerConfig] = None) -> None:
112123 log_file ,
113124 level = log_file_level .upper (),
114125 serialize = True ,
126+ format = format ,
115127 filter = support_log_once ,
116128 )
117129
@@ -121,6 +133,10 @@ def configure_logger(config: Optional[LoggerConfig] = None) -> None:
121133 # initialize metric logger on loguru
122134 logger .level ("METRIC" , no = 38 , color = "<yellow>" , icon = "📈" )
123135
136+ # set global value for later calls
137+ global LOGGER_CONFIG
138+ LOGGER_CONFIG = logger_config
139+
124140
125141def support_log_once (record : Dict [str , Any ]) -> bool :
126142 """
@@ -146,14 +162,11 @@ def support_log_once(record: Dict[str, Any]) -> bool:
146162 return True
147163
148164
def configure_distributed_logger(logger_config: Optional[LoggerConfig] = None) -> None:
    """
    Configure logging for a torch.distributed run, tagging records with rank.

    Stores the current process rank on the config, then reconfigures the
    logger from scratch (clear_loggers=True) so the rank-prefixed format
    replaces any previously installed sinks.

    :param logger_config: config to update; defaults to the current global
        LOGGER_CONFIG. Resolved at call time via a None sentinel — a
        `=LOGGER_CONFIG` default would bind the instance at definition time
        and go stale once configure_logger rebinds the global.
    """
    if logger_config is None:
        logger_config = LOGGER_CONFIG
    # NOTE(review): dist.get_rank() raises RuntimeError if the default process
    # group is not initialized — callers must run dist.init_process_group first
    logger_config.rank = dist.get_rank()
    configure_logger(logger_config, clear_loggers=True)
168+
169+
# Invoke logger setup on import with the default LoggerConfig: console logging
# enabled at INFO, file logging disabled.
configure_logger()
0 commit comments