Skip to content

Commit 8afcc87

Browse files
committed
Add old stack logging support to new stack
Added the following support for easy visualization of training and validation statistics: 1. train_logger callback function which captures the per-epoch time, per-epoch loss metric and per-epoch perplexity 2. This function also captures the number of trainable parameters and the number of samples in the training and eval datasets 3. All of these are logged into a log file whose name can be given as input by the user by setting the flag --log_file_path in the input config .yaml file. Signed-off-by: Anusha Bhamidipati <abhamidi@qti.qualcomm.com>
1 parent 492838e commit 8afcc87

File tree

3 files changed

+155
-4
lines changed

3 files changed

+155
-4
lines changed

QEfficient/cloud/finetune_experimental.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
from pathlib import Path
1515
from typing import Any, Dict, List, Tuple
1616

17-
from QEfficient.finetune.experimental.core.callbacks import replace_progress_callback
17+
from peft import get_peft_model
18+
19+
from QEfficient.finetune.experimental.core.callbacks import TrainingLogger, replace_progress_callback
1820
from QEfficient.finetune.experimental.core.component_registry import ComponentFactory
19-
from QEfficient.finetune.experimental.core.config_manager import (
20-
ConfigManager,
21-
)
21+
from QEfficient.finetune.experimental.core.config_manager import ConfigManager
2222
from QEfficient.finetune.experimental.core.dataset import SFTDataset # noqa: F401
2323
from QEfficient.finetune.experimental.core.logger import Logger
2424
from QEfficient.finetune.experimental.core.model import HFModel # noqa: F401
@@ -29,6 +29,7 @@
2929
from QEfficient.finetune.experimental.core.utils.training_config_utils import prepare_training_config
3030

3131
logger = Logger(__name__)
32+
train_logger = TrainingLogger()
3233

3334
# Try importing QAIC-specific module, proceed without it if it's unavailable
3435
try:
@@ -98,6 +99,7 @@ def get_trainer(self):
9899

99100
def _setup_environment(self) -> None:
100101
"""Set up environment variables for output directories."""
102+
self.rank = int(os.environ.get("RANK", "0"))
101103
os.environ["OUTPUT_DIR"] = str(self.output_dir)
102104
os.environ["TRACKIO_DIR"] = str(self.output_dir / "trackio_logs")
103105
os.environ["TENSORBOARD_LOGGING_DIR"] = str(self.output_dir)
@@ -245,10 +247,21 @@ def _create_trainer(
245247
dependencies = {}
246248
if peft_config is not None:
247249
dependencies["peft_config"] = peft_config
250+
if self.rank == 0:
251+
model_configuration = get_peft_model(model, peft_config)
252+
trainable_params, all_param = model_configuration.get_nb_trainable_parameters()
253+
pct = (trainable_params / all_param) * 100
254+
model_configuration.unload() # Removing the peft adapters
255+
train_logger.write(f"TRAINING INFO: Model has {all_param / 1e6:.4f} Million params.")
256+
train_logger.write(
257+
f"TRAINING INFO: Trainable params: {trainable_params} || "
258+
f"all params: {all_param} || trainable%: {pct:.4f}"
259+
)
248260
trainer_cls, args_cls, additional_kwargs = ComponentFactory.create_trainer_config(trainer_type, **dependencies)
249261

250262
# Clean up training config: remove fields that shouldn't be passed to TrainingArguments
251263
training_config.pop("device", None)
264+
training_config.pop("log_file_name", None)
252265
# Note: torch_dtype was already converted to fp16/bf16 flag in prepare_training_config
253266
training_config.pop("deepspeed_config", None)
254267
training_config.pop("torch_dtype", None)
@@ -271,6 +284,11 @@ def _create_trainer(
271284
subset_eval_indices = list(range(0, int(num_samples - num_samples * split_ratio)))
272285
eval_dataset = eval_dataset.select(subset_eval_indices)
273286
train_dataset = train_dataset.select(subset_train_indices)
287+
# Logging the number of training and evaluation samples
288+
if self.rank == 0:
289+
train_logger.write(f"TRAINING INFO: Length of Training Dataset is {len(train_dataset)}")
290+
train_logger.write(f"TRAINING INFO: Length of Evaluation Dataset is {len(eval_dataset)}")
291+
274292
trainer = trainer_cls(
275293
model=model,
276294
processing_class=tokenizer,

QEfficient/finetune/experimental/core/callbacks.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,14 @@
66
# -----------------------------------------------------------------------------
77

88
import json
9+
import logging
10+
import math
911
import os
12+
import time
13+
from pathlib import Path
1014
from typing import Any, Dict, Optional
1115

16+
import torch.distributed as dist
1217
from transformers import (
1318
DefaultFlowCallback,
1419
EarlyStoppingCallback,
@@ -20,6 +25,8 @@
2025
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
2126

2227
from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry
28+
from QEfficient.finetune.experimental.core.config_manager import ConfigManager
29+
from QEfficient.finetune.experimental.core.logger import Logger
2330
from QEfficient.finetune.experimental.core.utils.profiler_utils import (
2431
get_op_verifier_ctx,
2532
init_qaic_profiling,
@@ -32,6 +39,117 @@
3239
registry.callback("tensorboard")(TensorBoardCallback)
3340

3441

def is_main_process() -> bool:
    """Return True when this process should perform logging.

    Only rank 0 logs in a distributed run; a non-distributed run
    (torch.distributed never initialized) is always the main process.
    """
    if dist.is_initialized():
        return dist.get_rank() == 0
    return True
45+
logger = Logger(__name__)
# Extracting the user input name for the log file.
# NOTE(review): this read happens at module import time — it assumes
# ConfigManager is a singleton already populated with the parsed training
# config before this module is imported; verify import order in the CLI.
log_file_name = ConfigManager().config.training["log_file_name"]
50+
51+
52+
@registry.callback("train_logger")
53+
class TrainingLogger(TrainerCallback):
54+
"""
55+
A [`TrainerCallback`] that logs per epoch time, training metric (perplexity),training loss, evaluation metrics and loss etc.
56+
These are only logged for rank = 0.
57+
"""
58+
59+
def __init__(self, log_file: str | None = log_file_name):
60+
# Log file setup
61+
output_dir = Path(ConfigManager().config.training["output_dir"])
62+
self.log_file = os.path.join(output_dir, log_file)
63+
# Ensure directory exists
64+
os.makedirs(os.path.dirname(self.log_file), exist_ok=True)
65+
self.epoch_start_time = None
66+
self.best_eval_loss = float("inf")
67+
68+
# ----------------------------------------------------
69+
# Safe write to log (only rank 0)
70+
# ----------------------------------------------------
71+
def write(self, text):
72+
if not is_main_process():
73+
return
74+
logger.log_rank_zero(text)
75+
try:
76+
with open(self.log_file, "a") as f:
77+
f.write(text + "\n")
78+
f.flush()
79+
os.fsync(f.fileno())
80+
81+
except OSError:
82+
logging.exception("Failed to write to log file: %s", self.log_file)
83+
84+
# ----------------------------------------------------
85+
# EPOCH BEGIN
86+
# ----------------------------------------------------
87+
def on_epoch_begin(self, args, state, control, **kwargs):
88+
if not is_main_process():
89+
return
90+
91+
epoch = int(state.epoch) + 1
92+
self.epoch_start_time = time.time()
93+
if state.is_world_process_zero:
94+
self.write(f"TRAINING INFO: Starting epoch {epoch}/{int(args.num_train_epochs)}")
95+
96+
# ----------------------------------------------------
97+
# EVALUATION
98+
# ----------------------------------------------------
99+
def on_evaluate(self, args, state, control, metrics, **kwargs):
100+
if not is_main_process():
101+
return
102+
103+
epoch = int(state.epoch)
104+
eval_loss = None
105+
eval_metric = None
106+
107+
for entry in reversed(state.log_history):
108+
if "eval_loss" in entry:
109+
eval_loss = entry["eval_loss"]
110+
break
111+
if eval_loss is not None:
112+
eval_metric = math.exp(eval_loss)
113+
# Track best eval loss
114+
if eval_loss is not None and eval_loss < self.best_eval_loss:
115+
self.best_eval_loss = eval_loss
116+
if state.is_world_process_zero:
117+
self.write(f"EVALUATION INFO: Best eval loss on epoch {epoch} is {eval_loss:.4f}")
118+
if state.is_world_process_zero:
119+
self.write(f"EVALUATION INFO: Epoch {epoch}: Eval Loss: {eval_loss:.4f} || Eval metric: {eval_metric:.4f}")
120+
121+
# ----------------------------------------------------
122+
# EPOCH END — TRAIN LOSS + METRIC + TIME
123+
# ----------------------------------------------------
124+
def on_epoch_end(self, args, state, control, **kwargs):
125+
if not is_main_process():
126+
return
127+
128+
epoch = int(state.epoch)
129+
epoch_time = time.time() - self.epoch_start_time
130+
131+
# Extract the last recorded train loss
132+
train_loss = None
133+
for entry in reversed(state.log_history):
134+
if "loss" in entry:
135+
train_loss = entry["loss"]
136+
break
137+
# Compute perplexity safely
138+
train_metric = None
139+
if train_loss is not None:
140+
train_metric = math.exp(train_loss)
141+
142+
if state.is_world_process_zero:
143+
self.write(
144+
f"TRAINING INFO: Epoch {epoch}: "
145+
f" Train epoch loss: {train_loss:.4f} || "
146+
f" Train metric: {train_metric} || "
147+
f" Epoch time {epoch_time:.2f} sec"
148+
)
149+
state.log_history.append({"train/epoch_time_sec": epoch_time, "epoch": state.epoch})
150+
control.should_log = True
151+
152+
35153
@registry.callback("enhanced_progressbar")
36154
class EnhancedProgressCallback(ProgressCallback):
37155
"""
@@ -233,3 +351,14 @@ def replace_progress_callback(trainer: Any, callbacks: list[Any], logger: Any =
233351
import warnings
234352

235353
warnings.warn(f"Could not add enhanced progress callback: {e}")
354+
try:
355+
# Add Train Logger
356+
train_logger = ComponentFactory.create_callback("train_logger")
357+
trainer.add_callback(train_logger)
358+
except Exception as e:
359+
if logger:
360+
logger.log_rank_zero(f"Warning: Could not add train logger callback: {e}", level="warning")
361+
else:
362+
import warnings
363+
364+
warnings.warn(f"Could not add train warning callback: {e}")

QEfficient/finetune/experimental/core/config_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,10 @@ class TrainingConfig:
330330
default="./training_results",
331331
metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
332332
)
333+
log_file_name: str = field(
334+
default="training_logs.txt",
335+
metadata={"help": "The log_file output name."},
336+
)
333337
overwrite_output_dir: bool = field(
334338
default=False,
335339
metadata={"help": "Whether to overwrite the output directory."},

0 commit comments

Comments
 (0)