-
Notifications
You must be signed in to change notification settings - Fork 76
Add old stack logging support to new stack #889
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: ft_experimental
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,7 +6,11 @@ | |
| # ----------------------------------------------------------------------------- | ||
|
|
||
| import json | ||
| import math | ||
| import os | ||
| import time | ||
| from datetime import datetime | ||
| from pathlib import Path | ||
| from typing import Any, Dict, Optional | ||
|
|
||
| from transformers import ( | ||
|
|
@@ -20,6 +24,8 @@ | |
| from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState | ||
|
|
||
| from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry | ||
| from QEfficient.finetune.experimental.core.config_manager import ConfigManager | ||
| from QEfficient.finetune.experimental.core.logger import Logger | ||
| from QEfficient.finetune.experimental.core.utils.profiler_utils import ( | ||
| get_op_verifier_ctx, | ||
| init_qaic_profiling, | ||
|
|
@@ -31,6 +37,109 @@ | |
| registry.callback("default_flow")(DefaultFlowCallback) | ||
| registry.callback("tensorboard")(TensorBoardCallback) | ||
|
|
||
| logger = Logger(__name__) | ||
|
|
||
| # Setting the path for dumping the log file | ||
| output_dir = Path(ConfigManager().config.training["output_dir"]) | ||
| log_file = os.path.join(output_dir, f"training_logs_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt") | ||
|
|
||
|
|
||
| @registry.callback("train_logger") | ||
| class TrainingLogger(TrainerCallback): | ||
| """ | ||
| A [`TrainerCallback`] that logs per epoch time, training metric (perplexity),training loss, evaluation metrics and loss etc. | ||
| These are only logged for rank = 0. | ||
| """ | ||
|
|
||
| def __init__(self, log_file=log_file, rank=0): | ||
| self.rank = rank # rank-safe logging (only rank 0) | ||
| # Log file setup | ||
| self.log_file = log_file | ||
| # Ensure directory exists | ||
| os.makedirs(os.path.dirname(self.log_file), exist_ok=True) | ||
| self.epoch_start_time = None | ||
| self.best_eval_loss = float("inf") | ||
|
|
||
| # ---------------------------------------------------- | ||
| # Safe write to log (only rank 0) | ||
| # ---------------------------------------------------- | ||
| def _write(self, text): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Usually single underscore at the front is for private methods. But _write method is called outside function at |
||
| if self.rank != 0: | ||
| return | ||
| logger.log_rank_zero(text) | ||
| with open(self.log_file, "a") as f: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be better to put inside try block, to catch any write errors |
||
| f.write(text + "\n") | ||
| f.flush() | ||
| os.fsync(f.fileno()) | ||
|
|
||
| # ---------------------------------------------------- | ||
| # EPOCH BEGIN | ||
| # ---------------------------------------------------- | ||
| def on_epoch_begin(self, args, state, control, **kwargs): | ||
| if self.rank != 0: | ||
| return | ||
|
|
||
| epoch = int(state.epoch) + 1 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here it is +1 , but in other methods it is just
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here, we increment by one to make epoch 0 to 1, for better logging
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So logging will start from EPOCH 1, is it?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we are running it for 5 epochs, instead of representing the logs from epochs 0- 4 we just represent it as epochs 1-5. The information is logged from 1st epoch or epoch 0 itself as usual. |
||
| self.epoch_start_time = time.time() | ||
| if state.is_world_process_zero: | ||
| self._write(f"TRAINING INFO: Starting epoch {epoch}/{int(args.num_train_epochs)}") | ||
|
|
||
| # ---------------------------------------------------- | ||
| # EVALUATION | ||
| # ---------------------------------------------------- | ||
| def on_evaluate(self, args, state, control, metrics, **kwargs): | ||
| if self.rank != 0: | ||
| return | ||
|
|
||
| epoch = int(state.epoch) | ||
| eval_loss = None | ||
| eval_metric = None | ||
|
|
||
| for entry in reversed(state.log_history): | ||
| if "eval_loss" in entry: | ||
| eval_loss = entry["eval_loss"] | ||
| break | ||
| if eval_loss is not None: | ||
| eval_metric = math.exp(eval_loss) | ||
| # Track best eval loss | ||
| if eval_loss is not None and eval_loss < self.best_eval_loss: | ||
| self.best_eval_loss = eval_loss | ||
| if state.is_world_process_zero: | ||
| self._write(f"EVALUATION INFO: Best eval loss on epoch {epoch} is {eval_loss:.4f}") | ||
| if state.is_world_process_zero: | ||
| self._write(f"EVALUATION INFO: Epoch {epoch}: Eval Loss: {eval_loss:.4f} || Eval metric: {eval_metric:.4f}") | ||
|
|
||
| # ---------------------------------------------------- | ||
| # EPOCH END — TRAIN LOSS + METRIC + TIME | ||
| # ---------------------------------------------------- | ||
| def on_epoch_end(self, args, state, control, **kwargs): | ||
| if self.rank != 0: | ||
| return | ||
|
|
||
| epoch = int(state.epoch) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CHeck this
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here, since this is at the end of the epoch, the state.epoch is already incremented by 1 i.e. epoch 0 after the training of 0th epoch is done, state.epoch value is changes to one so need not add +1 here
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the same reason for not incrementing state.epoch by 1 in eval logs as well, as it is done after train epoch is completed and at that point the state.epoch value is incremented. |
||
| epoch_time = time.time() - self.epoch_start_time | ||
|
|
||
| # Extract the last recorded train loss | ||
| train_loss = None | ||
| for entry in reversed(state.log_history): | ||
| if "loss" in entry: | ||
| train_loss = entry["loss"] | ||
| break | ||
|
|
||
| # Compute perplexity safely | ||
| train_metric = None | ||
| if train_loss is not None: | ||
| train_metric = math.exp(train_loss) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Verify the train_metric values, check if there is a step wise match, wrt to the old FT stack. Use the same sdk, and same seed and data_seed on both stacks, for reproducibility
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also use try block and handle in case metric value overflows
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, will add this check |
||
| if state.is_world_process_zero: | ||
| self._write( | ||
| f"TRAINING INFO: Epoch {epoch}: " | ||
| f" Train epoch loss: {train_loss:.4f} || " | ||
| f" Train metric: {train_metric} || " | ||
| f" Epoch time {epoch_time:.2f} sec" | ||
| ) | ||
| state.log_history.append({"train/epoch_time_sec": epoch_time, "epoch": state.epoch}) | ||
| control.should_log = True | ||
|
|
||
|
|
||
| @registry.callback("enhanced_progressbar") | ||
| class EnhancedProgressCallback(ProgressCallback): | ||
|
|
@@ -233,3 +342,14 @@ def replace_progress_callback(trainer: Any, callbacks: list[Any], logger: Any = | |
| import warnings | ||
|
|
||
| warnings.warn(f"Could not add enhanced progress callback: {e}") | ||
| try: | ||
| # Add Train Logger | ||
| train_logger = ComponentFactory.create_callback("train_logger") | ||
| trainer.add_callback(train_logger) | ||
| except Exception as e: | ||
| if logger: | ||
| logger.log_rank_zero(f"Warning: Could not add train logger callback: {e}", level="warning") | ||
| else: | ||
| import warnings | ||
|
|
||
| warnings.warn(f"Could not add train logger callback: {e}") ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In DDP case, this will fail I think. Please check. I believe we can't hardcode 0 here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will change this