Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion QEfficient/cloud/finetune_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from pathlib import Path
from typing import Any, Dict, List, Tuple

from QEfficient.finetune.experimental.core.callbacks import replace_progress_callback
from peft import get_peft_model

from QEfficient.finetune.experimental.core.callbacks import TrainingLogger, replace_progress_callback
from QEfficient.finetune.experimental.core.component_registry import ComponentFactory
from QEfficient.finetune.experimental.core.config_manager import (
ConfigManager,
Expand All @@ -29,6 +31,7 @@
from QEfficient.finetune.experimental.core.utils.training_config_utils import prepare_training_config

logger = Logger(__name__)
# Rank-aware training logger: read the global rank from the environment
# (RANK is set by torch.distributed launchers; absent it, assume a
# single-process run at rank 0). Hardcoding rank=0 would make every DDP
# worker believe it is rank 0 and write duplicate, interleaved log lines.
train_logger = TrainingLogger(rank=int(os.environ.get("RANK", "0")))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In DDP case, this will fail I think. Please check. I believe we can't hardcode 0 here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will change this


# Try importing QAIC-specific module, proceed without it if it's unavailable
try:
Expand Down Expand Up @@ -98,6 +101,7 @@ def get_trainer(self):

def _setup_environment(self) -> None:
    """Set up environment variables for output directories.

    Also caches this process's distributed rank on ``self.rank`` — read
    from the RANK environment variable (set by torch.distributed
    launchers), defaulting to 0 for single-process runs — so later steps
    can restrict logging to rank 0.
    """
    self.rank = int(os.environ.get("RANK", "0"))
    os.environ["OUTPUT_DIR"] = str(self.output_dir)
    os.environ["TRACKIO_DIR"] = str(self.output_dir / "trackio_logs")
    os.environ["TENSORBOARD_LOGGING_DIR"] = str(self.output_dir)
Expand Down Expand Up @@ -245,10 +249,21 @@ def _create_trainer(
dependencies = {}
if peft_config is not None:
dependencies["peft_config"] = peft_config
if self.rank == 0:
model_configuration = get_peft_model(model, peft_config)
trainable_params, all_param = model_configuration.get_nb_trainable_parameters()
pct = (trainable_params / all_param) * 100
model_configuration.unload() # Removing the peft adapters
train_logger._write(f"TRAINING INFO: Model has {all_param / 1e6:.4f} Million params.")
train_logger._write(
f"TRAINING INFO: Trainable params: {trainable_params} || "
f"all params: {all_param} || trainable%: {pct:.4f}"
)
trainer_cls, args_cls, additional_kwargs = ComponentFactory.create_trainer_config(trainer_type, **dependencies)

# Clean up training config: remove fields that shouldn't be passed to TrainingArguments
training_config.pop("device", None)
training_config.pop("log_file_name", None)
# Note: torch_dtype was already converted to fp16/bf16 flag in prepare_training_config
training_config.pop("deepspeed_config", None)
training_config.pop("torch_dtype", None)
Expand All @@ -271,6 +286,11 @@ def _create_trainer(
subset_eval_indices = list(range(0, int(num_samples - num_samples * split_ratio)))
eval_dataset = eval_dataset.select(subset_eval_indices)
train_dataset = train_dataset.select(subset_train_indices)
# Logging the number of training and evaluation samples
if self.rank == 0:
train_logger._write(f"TRAINING INFO: Length of Training Dataset is {len(train_dataset)}")
train_logger._write(f"TRAINING INFO: Length of Evaluation Dataset is {len(eval_dataset)}")

trainer = trainer_cls(
model=model,
processing_class=tokenizer,
Expand Down
120 changes: 120 additions & 0 deletions QEfficient/finetune/experimental/core/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
# -----------------------------------------------------------------------------

import json
import math
import os
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional

from transformers import (
Expand All @@ -20,6 +24,8 @@
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState

from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry
from QEfficient.finetune.experimental.core.config_manager import ConfigManager
from QEfficient.finetune.experimental.core.logger import Logger
from QEfficient.finetune.experimental.core.utils.profiler_utils import (
get_op_verifier_ctx,
init_qaic_profiling,
Expand All @@ -31,6 +37,109 @@
registry.callback("default_flow")(DefaultFlowCallback)
registry.callback("tensorboard")(TensorBoardCallback)

logger = Logger(__name__)

# Setting the path for dumping the log file (timestamped so reruns do not
# clobber earlier logs).
# NOTE(review): ConfigManager() is instantiated at import time here, so the
# training config must already be loadable whenever this module is imported
# — confirm this holds for every entry point that imports callbacks.
output_dir = Path(ConfigManager().config.training["output_dir"])
log_file = os.path.join(output_dir, f"training_logs_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt")


@registry.callback("train_logger")
class TrainingLogger(TrainerCallback):
    """
    A [`TrainerCallback`] that logs per-epoch time, training metrics
    (loss and perplexity) and evaluation metrics to a plain-text log file.

    All writes are rank-guarded: every process may instantiate the callback
    under DDP, but only global rank 0 produces output.
    """

    def __init__(self, log_file=log_file, rank=None):
        """
        Args:
            log_file: Path of the text file logs are appended to. Defaults
                to the module-level timestamped path derived from the
                configured output directory.
            rank: Distributed rank of this process. When None (the default)
                it is read from the RANK environment variable (set by
                torch.distributed launchers), falling back to 0 for
                single-process runs — hardcoding 0 would be wrong under DDP.
        """
        self.rank = int(os.environ.get("RANK", "0")) if rank is None else rank
        self.log_file = log_file
        # Ensure the log directory exists (skip when log_file has no
        # directory component — makedirs("") would raise).
        log_dir = os.path.dirname(self.log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        self.epoch_start_time = None  # stamped in on_epoch_begin
        self.best_eval_loss = float("inf")  # lowest eval loss seen so far

    # ----------------------------------------------------
    # Safe write to log (only rank 0)
    # ----------------------------------------------------
    def write(self, text):
        """Public entry point for callers outside this class (delegates to
        the internal `_write`, which callbacks use directly)."""
        self._write(text)

    def _write(self, text):
        """Append `text` to the log file and mirror it to the console logger.

        No-op on non-zero ranks. File-system errors are reported as warnings
        instead of aborting training.
        """
        if self.rank != 0:
            return
        logger.log_rank_zero(text)
        try:
            with open(self.log_file, "a") as f:
                f.write(text + "\n")
                f.flush()
                # fsync so the log survives a crash/kill mid-training.
                os.fsync(f.fileno())
        except OSError as e:
            logger.log_rank_zero(
                f"Warning: could not write to log file {self.log_file}: {e}", level="warning"
            )

    # ----------------------------------------------------
    # EPOCH BEGIN
    # ----------------------------------------------------
    def on_epoch_begin(self, args, state, control, **kwargs):
        """Record the epoch start time and log a 1-based epoch banner."""
        if self.rank != 0:
            return

        # state.epoch is 0-based at epoch start; +1 for human-friendly
        # numbering (epochs are reported as 1..num_train_epochs).
        epoch = int(state.epoch) + 1
        self.epoch_start_time = time.time()
        if state.is_world_process_zero:
            self._write(f"TRAINING INFO: Starting epoch {epoch}/{int(args.num_train_epochs)}")

    # ----------------------------------------------------
    # EVALUATION
    # ----------------------------------------------------
    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Log eval loss/perplexity and track the best eval loss so far."""
        if self.rank != 0:
            return

        # By the time on_evaluate fires, the Trainer has already advanced
        # state.epoch past the finished epoch, so no +1 is needed here —
        # this matches the numbering produced by on_epoch_begin.
        epoch = int(state.epoch)

        # Prefer the metrics dict handed over by the Trainer; fall back to
        # the most recent eval entry in the log history.
        eval_loss = (metrics or {}).get("eval_loss")
        if eval_loss is None:
            for entry in reversed(state.log_history):
                if "eval_loss" in entry:
                    eval_loss = entry["eval_loss"]
                    break
        if eval_loss is None:
            # Nothing to report (evaluation produced no loss) — previously
            # this fell through and crashed formatting None with ":.4f".
            return

        # Perplexity; math.exp overflows for large losses, so clamp to inf.
        try:
            eval_metric = math.exp(eval_loss)
        except OverflowError:
            eval_metric = float("inf")

        # Track best eval loss.
        if eval_loss < self.best_eval_loss:
            self.best_eval_loss = eval_loss
            if state.is_world_process_zero:
                self._write(f"EVALUATION INFO: Best eval loss on epoch {epoch} is {eval_loss:.4f}")
        if state.is_world_process_zero:
            self._write(f"EVALUATION INFO: Epoch {epoch}: Eval Loss: {eval_loss:.4f} || Eval metric: {eval_metric:.4f}")

    # ----------------------------------------------------
    # EPOCH END — TRAIN LOSS + METRIC + TIME
    # ----------------------------------------------------
    def on_epoch_end(self, args, state, control, **kwargs):
        """Log the finished epoch's train loss, perplexity and wall time."""
        if self.rank != 0:
            return

        # state.epoch has already been incremented by the Trainer at epoch
        # end, so it is effectively 1-based here without adding 1.
        epoch = int(state.epoch)
        # Guard against on_epoch_begin never having run on this instance
        # (e.g. callback attached mid-training) — None would raise here.
        epoch_time = time.time() - self.epoch_start_time if self.epoch_start_time is not None else 0.0

        # Extract the last recorded train loss from the log history.
        train_loss = None
        for entry in reversed(state.log_history):
            if "loss" in entry:
                train_loss = entry["loss"]
                break

        # Compute perplexity safely: exp() can overflow for large losses.
        train_metric = None
        if train_loss is not None:
            try:
                train_metric = math.exp(train_loss)
            except OverflowError:
                train_metric = float("inf")

        if state.is_world_process_zero:
            if train_loss is not None:
                self._write(
                    f"TRAINING INFO: Epoch {epoch}: "
                    f" Train epoch loss: {train_loss:.4f} || "
                    f" Train metric: {train_metric} || "
                    f" Epoch time {epoch_time:.2f} sec"
                )
            else:
                # No loss logged this epoch (e.g. logging_steps exceeds the
                # steps per epoch); still record the epoch duration instead
                # of crashing on formatting None.
                self._write(f"TRAINING INFO: Epoch {epoch}:  Epoch time {epoch_time:.2f} sec")
        state.log_history.append({"train/epoch_time_sec": epoch_time, "epoch": state.epoch})
        control.should_log = True


@registry.callback("enhanced_progressbar")
class EnhancedProgressCallback(ProgressCallback):
Expand Down Expand Up @@ -233,3 +342,14 @@ def replace_progress_callback(trainer: Any, callbacks: list[Any], logger: Any =
import warnings

warnings.warn(f"Could not add enhanced progress callback: {e}")
try:
# Add Train Logger
train_logger = ComponentFactory.create_callback("train_logger")
trainer.add_callback(train_logger)
except Exception as e:
if logger:
logger.log_rank_zero(f"Warning: Could not add train logger callback: {e}", level="warning")
else:
import warnings

warnings.warn(f"Could not add train warning callback: {e}")
4 changes: 4 additions & 0 deletions QEfficient/finetune/experimental/core/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,10 @@ class TrainingConfig:
default="./training_results",
metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
)
log_file_name: str = field(
default="training_logs.txt",
metadata={"help": "The log_file output name."},
)
overwrite_output_dir: bool = field(
default=False,
metadata={"help": "Whether to overwrite the output directory."},
Expand Down