
Commit 755ce8f

allow saving config to file (#1904)
Here's an example output JSON file (nested sections compacted onto one line each):

```json
{
  "job": {"config_file": "./torchtitan/models/llama3/train_configs/debug_model.toml", "dump_folder": "./outputs", "description": "Llama 3 debug training", "print_config": false, "save_config_folder": "config", "custom_config_module": ""},
  "profiling": {"enable_profiling": false, "save_traces_folder": "profile_trace", "profile_freq": 10, "profiler_active": 1, "profiler_warmup": 3, "enable_memory_snapshot": false, "save_memory_snapshot_folder": "memory_snapshot"},
  "metrics": {"log_freq": 1, "enable_tensorboard": false, "disable_color_printing": false, "save_tb_folder": "tb", "save_for_all_ranks": false, "enable_wandb": false},
  "model": {"name": "llama3", "flavor": "debugmodel", "hf_assets_path": "./tests/assets/tokenizer", "tokenizer_path": null, "converters": [], "print_after_conversion": false},
  "optimizer": {"name": "AdamW", "lr": 0.0008, "beta1": 0.9, "beta2": 0.95, "eps": 1e-08, "weight_decay": 0.1, "implementation": "fused", "early_step_in_backward": false},
  "lr_scheduler": {"warmup_steps": 2, "decay_ratio": 0.8, "decay_type": "linear", "min_lr_factor": 0.0},
  "training": {"dataset": "c4_test", "dataset_path": null, "local_batch_size": 8, "global_batch_size": -1, "seq_len": 2048, "max_norm": 1.0, "steps": 10, "enable_cpu_offload": false, "dtype": "float32", "mixed_precision_param": "bfloat16", "mixed_precision_reduce": "float32", "gc_freq": 50, "gc_debug": false, "seed": null, "deterministic": false, "debug_moe_force_load_balance": false},
  "parallelism": {"data_parallel_replicate_degree": 1, "enable_compiled_autograd": false, "data_parallel_shard_degree": -1, "fsdp_reshard_after_forward": "default", "tensor_parallel_degree": 1, "disable_loss_parallel": false, "enable_async_tensor_parallel": false, "pipeline_parallel_degree": 1, "module_fqns_per_model_part": null, "pipeline_parallel_first_stage_less_layers": 1, "pipeline_parallel_last_stage_less_layers": 1, "pipeline_parallel_layers_per_stage": null, "pipeline_parallel_schedule": "Interleaved1F1B", "pipeline_parallel_schedule_csv": "", "pipeline_parallel_microbatch_size": 1, "context_parallel_degree": 1, "context_parallel_rotate_method": "allgather", "expert_parallel_degree": 1, "expert_tensor_parallel_degree": 1},
  "checkpoint": {"enable": false, "enable_ft_dataloader_checkpoints": true, "folder": "checkpoint", "interval": 10, "initial_load_path": null, "initial_load_model_only": true, "initial_load_in_hf": false, "initial_load_in_hf_quantized": false, "last_save_model_only": false, "last_save_in_hf": false, "export_dtype": "float32", "async_mode": "disabled", "keep_latest_k": 10, "load_step": -1, "exclude_from_loading": [], "enable_first_step_checkpoint": false, "create_seed_checkpoint": false, "load_only": false},
  "activation_checkpoint": {"mode": "selective", "selective_ac_option": "2", "per_op_sac_force_recompute_mm_shapes_by_fqns": ["moe.router.gate"], "early_stop": false, "memory_budget": 0.5, "visualize_memory_budget_pareto": false},
  "compile": {"enable": false, "components": ["model", "loss"], "backend": "inductor"},
  "quantize": {
    "linear": {
      "float8": {"enable_fsdp_float8_all_gather": false, "precompute_float8_dynamic_scale_for_fsdp": false, "recipe_name": null, "filter_fqns": ["output"], "emulate": false},
      "mx": {"mxfp8_dim1_cast_kernel_choice": "triton", "recipe_name": "mxfp8_cublas", "filter_fqns": ["output"]}
    },
    "grouped_mm": {
      "float8": {"fqns": []},
      "mx": {"recipe_name": "mxfp8", "fqns": []}
    }
  },
  "comm": {"init_timeout_seconds": 300, "train_timeout_seconds": 100, "trace_buf_size": 20000, "save_traces_folder": "comm_traces", "save_traces_file_prefix": "rank_"},
  "memory_estimation": {"enable": false, "disable_fake_mode": false},
  "fault_tolerance": {"enable": false, "process_group": "gloo", "process_group_timeout_ms": 10000, "replica_id": 0, "group_size": 0, "min_replica_size": 1, "semi_sync_method": null},
  "experimental": {"custom_import": "", "custom_args_module": ""},
  "validation": {"enable": false, "dataset": "c4_validation", "dataset_path": null, "local_batch_size": 8, "seq_len": 2048, "freq": 5, "steps": 10}
}
```
1 parent: 75d4e4d · commit: 755ce8f

File tree

2 files changed: +35 −6 lines changed

torchtitan/config/job_config.py

Lines changed: 32 additions & 3 deletions
```diff
@@ -4,23 +4,33 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import json
+
+import os
 from dataclasses import asdict, dataclass, field
 from typing import Any, Literal
 
+import torch
+
+from torchtitan.tools.logging import logger
+
 
 @dataclass
 class Job:
     config_file: str | None = None
-    """Job config file"""
+    """File to read job configs from"""
 
-    dump_folder: str = "./torchtitan/outputs"
+    dump_folder: str = "./outputs"
     """Folder to dump job outputs"""
 
     description: str = "default job"
     """Description of the job"""
 
     print_config: bool = False
-    """Print the configs to terminal"""
+    """Print the job configs to terminal"""
+
+    save_config_folder: str | None = None
+    """Folder to save a job_config.json file"""
 
     custom_config_module: str = ""
     """
@@ -908,3 +918,22 @@ class JobConfig:
 
     def to_dict(self) -> dict[str, Any]:
         return asdict(self)
+
+    def maybe_log(self) -> None:
+        if self.job.print_config:
+            logger.info(f"Running with configs: {self.to_dict()}")
+
+        if self.job.save_config_folder is not None:
+            config_file = os.path.join(
+                self.job.dump_folder, self.job.save_config_folder, "job_config.json"
+            )
+            if torch.distributed.is_initialized():
+                if torch.distributed.get_rank() == 0:
+                    os.makedirs(os.path.dirname(config_file), exist_ok=True)
+                    with open(config_file, "w") as f:
+                        json.dump(self.to_dict(), f, indent=2)
+                    logger.info(f"Saved job configs to {config_file}")
+            else:
+                logger.warning(
+                    "Job configs logging is disabled due to torch.distributed not initialized."
+                )
```
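Stripped of the torchtitan specifics, the save branch of `maybe_log` is a rank-0 write-once pattern: every rank computes the path, but only rank 0 creates the folder and writes the JSON. A standalone sketch of that pattern; `DemoConfig` here is an illustrative stand-in, not torchtitan's `JobConfig`:

```python
# Standalone sketch of the rank-0 config-dump pattern used in maybe_log.
# DemoConfig is illustrative; torchtitan serializes its full JobConfig.
import json
import os
from dataclasses import asdict, dataclass

import torch.distributed as dist


@dataclass
class DemoConfig:
    dump_folder: str = "./outputs"
    save_config_folder: str = "config"


def save_config(cfg: DemoConfig) -> None:
    config_file = os.path.join(
        cfg.dump_folder, cfg.save_config_folder, "job_config.json"
    )
    if dist.is_initialized():
        # Only rank 0 writes, so all ranks share one authoritative dump.
        if dist.get_rank() == 0:
            os.makedirs(os.path.dirname(config_file), exist_ok=True)
            with open(config_file, "w") as f:
                json.dump(asdict(cfg), f, indent=2)
    else:
        print("skipped: torch.distributed is not initialized")
```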

torchtitan/train.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -78,9 +78,6 @@ def __init__(self, job_config: JobConfig):
         if job_config.experimental.custom_import:
             importlib.import_module(job_config.experimental.custom_import)
 
-        if job_config.job.print_config:
-            logger.info(f"Running with args: {job_config.to_dict()}")
-
         device_module, device_type = utils.device_module, utils.device_type
         self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
         # Device has to be set before creating TorchFT manager.
@@ -92,6 +89,9 @@ def __init__(self, job_config: JobConfig):
             enable_cpu_backend=job_config.training.enable_cpu_offload,
             base_folder=job_config.job.dump_folder,
         )
+
+        job_config.maybe_log()
+
         world_size = int(os.environ["WORLD_SIZE"])
         parallelism_config = job_config.parallelism
         self.parallel_dims = parallel_dims = self._create_parallel_dims(
```
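Note the placement: `job_config.maybe_log()` is called only after the call that sets up distributed state, which matches the guard inside `maybe_log`; it needs `torch.distributed.is_initialized()` and `get_rank()` to restrict the write to rank 0. Invoking it at the old location, before process-group initialization, would always fall into the warning branch and never save the file.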
