diff --git a/docs/float8.md b/docs/float8.md index 5d90e0617..1a7277ff3 100644 --- a/docs/float8.md +++ b/docs/float8.md @@ -11,14 +11,14 @@ USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git For float8 with tensorwise scaling, launch training job with the following command (or alternatively set configs in toml files) ``` -CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --training.compile +CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --compile.enable ``` * `--model.converters="float8"`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul. * `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth. * `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduce for each parameter. * `--float8.filter_fqns="..."` (optional): a comma separated list of fully qualified names of modules not to convert to float8 training. Example: `--float8.filter_fqns="attention.wk,attention.wv"`. You can determine which layers to convert by looking at the microbenchmarks in the [performance section](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) of the torchao documentation for the float8 recipe you're using. * **Auto-filter**: add `"auto_filter_small_kn"` as one of the `--float8.filter_fqns=...` to to enable automatic module filtering, which will automatically not convert linear layers are not large enough to benefit from float8 training, since the GEMM has to be big enough that the speedup from using FP8 tensorcores is greater than the overhead of creating dynamically quantized inputs. The thresholds for conversion are based on microbenchmarks measured on NVIDIA H100 GPUs, where (K,N) represents the linear layer weight shape. For best performance, you should still manually filter out layers that are too small to benefit from float8 training. -* `--training.compile` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels +* `--compile.enable` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels For float8 with rowwise scaling, launch training job with the following command (or alternatively set configs in toml files) ``` @@ -26,7 +26,7 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_trai ``` * `--model.converters="float8"`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul. * `--float8.recipe_name="rowwise"`: use the rowwise scaling recipe for higher accuracy compared to tensorwise scaling -* `--training.compile` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels +* `--compile.enable` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels For parallelisms, for float8 with tensorwise scaling we support float8 all-gather for FSDP (optional) and for TP (by default for `Float8Linear`). For float8 with rowwise scaling, all distributed communication is done in high precision. 
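> Reviewer note on the `auto_filter_small_kn` option documented above: the filter boils down to a shape check on each linear layer's weight. The sketch below is illustrative only — the threshold values are hypothetical placeholders rather than torchao's actual H100-derived cutoffs, and `should_convert_to_float8` is not a real torchao or torchtitan function.

```python
import torch.nn as nn

# Hypothetical thresholds for illustration; torchao derives its real cutoffs
# from microbenchmarks on NVIDIA H100 GPUs, and they differ per float8 recipe.
_MIN_K = 4096  # in_features
_MIN_N = 4096  # out_features


def should_convert_to_float8(fqn: str, module: nn.Module) -> bool:
    """Return True if `module` is a linear layer large enough that the FP8
    tensor-core speedup outweighs the dynamic-quantization overhead."""
    if not isinstance(module, nn.Linear):
        return False
    # nn.Linear stores its weight as (out_features, in_features) == (N, K)
    n, k = module.weight.shape
    return k >= _MIN_K and n >= _MIN_N
```

> Layers rejected by such a filter are simply left as plain `nn.Linear` and keep running their matmuls in the original high precision.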
diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index 510cc394f..8103ae0b5 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -33,9 +33,9 @@ def estimate_memory(job_config: JobConfig): # Get the world size world_size = int(os.environ["WORLD_SIZE"]) - if job_config.training.compile or job_config.parallelism.enable_compiled_autograd: + if job_config.compile.enable or job_config.parallelism.enable_compiled_autograd: logger.info("Compile mode is not supported yet. Switching to eager mode.") - job_config.training.compile = False + job_config.compile.enable = False job_config.parallelism.enable_compiled_autograd = False # init fake pg diff --git a/scripts/estimate/run_memory_estimation.sh b/scripts/estimate/run_memory_estimation.sh index e8f9ecc88..9d766a07a 100755 --- a/scripts/estimate/run_memory_estimation.sh +++ b/scripts/estimate/run_memory_estimation.sh @@ -23,4 +23,4 @@ fi # Export WORLD_SIZE and LOCAL_RANK export WORLD_SIZE=$((NGPU * NNODES)) export LOCAL_RANK=0 -python -m scripts.estimate.estimation --job.config_file ${CONFIG_FILE} --memory_estimation.enabled $overrides +python -m scripts.estimate.estimation --job.config_file ${CONFIG_FILE} --memory_estimation.enable $overrides diff --git a/tests/integration_tests.py b/tests/integration_tests.py index f7512836c..73ded4548 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -57,7 +57,7 @@ def build_test_list(): OverrideDefinitions( [ [ - "--training.compile", + "--compile.enable", ], ], "1D compile", @@ -66,7 +66,7 @@ def build_test_list(): OverrideDefinitions( [ [ - "--training.compile", + "--compile.enable", "--activation_checkpoint.mode selective", "--activation_checkpoint.selective_ac_option op", ], @@ -86,7 +86,7 @@ def build_test_list(): OverrideDefinitions( [ [ - "--training.compile", + "--compile.enable", "--parallelism.tensor_parallel_degree 2", ], ], @@ -97,7 +97,7 @@ def build_test_list(): # OverrideDefinitions( # [ # [ - # "--training.compile", + # "--compile.enable", # "--parallelism.tensor_parallel_degree 2", # "--parallelism.enable_async_tensor_parallel", # ], @@ -267,7 +267,7 @@ def build_test_list(): "--parallelism.pipeline_parallel_degree 2", "--parallelism.data_parallel_shard_degree 2", "--parallelism.tensor_parallel_degree 2", - "--training.compile", + "--compile.enable", ], ], "PP+DP+TP 3D test with torch.compile", @@ -464,7 +464,7 @@ def build_test_list(): # OverrideDefinitions( # [ # [ - # "--memory_estimation.enabled", + # "--memory_estimation.enable", # ] # ], # "FSDP2 Memory Tracking and Estimation", diff --git a/tests/integration_tests_h100.py b/tests/integration_tests_h100.py index 29c11476b..b45d6f315 100755 --- a/tests/integration_tests_h100.py +++ b/tests/integration_tests_h100.py @@ -32,7 +32,7 @@ def build_test_list(): OverrideDefinitions( [ [ - "--training.compile", + "--compile.enable", "--parallelism.tensor_parallel_degree 2", "--parallelism.enable_async_tensor_parallel", ], @@ -54,7 +54,7 @@ def build_test_list(): OverrideDefinitions( [ [ - "--training.compile", + "--compile.enable", "--parallelism.data_parallel_shard_degree=2", "--parallelism.tensor_parallel_degree=2", "--parallelism.pipeline_parallel_degree=2", @@ -71,7 +71,7 @@ def build_test_list(): OverrideDefinitions( [ [ - "--training.compile", + "--compile.enable", "--parallelism.data_parallel_shard_degree=2", "--parallelism.data_parallel_replicate_degree=2", "--parallelism.context_parallel_degree=2", diff --git a/torchtitan/components/loss.py 
b/torchtitan/components/loss.py index 6aa1dd569..84ae78683 100644 --- a/torchtitan/components/loss.py +++ b/torchtitan/components/loss.py @@ -24,7 +24,7 @@ def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor def build_cross_entropy_loss(job_config: JobConfig): loss_fn = cross_entropy_loss - if job_config.training.compile: + if job_config.compile.enable and "loss" in job_config.compile.components: logger.info("Compiling the loss function with torch.compile") loss_fn = torch.compile(loss_fn) return loss_fn diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index 362925815..22134c65f 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -28,8 +28,12 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): self.enabled = False float8_config: Float8 = job_config.float8 + compile_config = job_config.compile + model_compile_enabled = ( + compile_config.enable and "model" in compile_config.components + ) if has_cuda_capability(8, 9) or ( - float8_config.emulate and not job_config.training.compile + float8_config.emulate and not model_compile_enabled ): pass else: diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py index 15c74b7fd..84216dadb 100644 --- a/torchtitan/components/quantization/mx.py +++ b/torchtitan/components/quantization/mx.py @@ -52,9 +52,12 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): ), "MXFP8 is only supported on SM100 or architectures" # TP not yet supported with torch.compile + + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) assert not ( - job_config.training.compile - and job_config.parallelism.tensor_parallel_degree > 1 + model_compile_enabled and job_config.parallelism.tensor_parallel_degree > 1 ), "TP not yet supported with torch.compile for mxfp8" # For MoE training with mxfp8, token group sizes must be multiples of 32 diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 9a78451fc..a688cdada 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -218,9 +218,6 @@ class Training: This feature only takes effect when data_parallel_shard_degree > 1 """ - compile: bool = False - """Whether to compile the model""" - gc_freq: int = 50 """Python garbage control scheduling interval, in steps""" @@ -550,6 +547,17 @@ class ActivationCheckpoint: """ +@dataclass +class Compile: + enable: bool = False + """Whether to apply torch.compile""" + + components: list[Literal["model", "loss"]] = field( + default_factory=lambda: ["model", "loss"] + ) + """Which components to compile""" + + @dataclass class Float8: enable_fsdp_float8_all_gather: bool = False @@ -630,7 +638,7 @@ class Comm: @dataclass class MemoryEstimation: - enabled: bool = False + enable: bool = False """Whether to estimate memory usage for FSDP""" disable_fake_mode: bool = False @@ -747,6 +755,7 @@ class JobConfig: activation_checkpoint: ActivationCheckpoint = field( default_factory=ActivationCheckpoint ) + compile: Compile = field(default_factory=Compile) float8: Float8 = field(default_factory=Float8) mx: MX = field(default_factory=MX) comm: Comm = field(default_factory=Comm) diff --git a/torchtitan/experiments/flux/loss.py b/torchtitan/experiments/flux/loss.py index 9159b40b8..6bf93f9d0 100644 --- a/torchtitan/experiments/flux/loss.py +++ b/torchtitan/experiments/flux/loss.py @@ 
-21,7 +21,7 @@ def mse_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: def build_mse_loss(job_config: JobConfig): loss_fn = mse_loss - if job_config.training.compile: + if job_config.compile.enable and "loss" in job_config.compile.components: logger.info("Compiling the loss function with torch.compile") loss_fn = torch.compile(loss_fn) return loss_fn diff --git a/torchtitan/experiments/flux/train_configs/debug_model.toml b/torchtitan/experiments/flux/train_configs/debug_model.toml index e565b23bd..9be99b042 100644 --- a/torchtitan/experiments/flux/train_configs/debug_model.toml +++ b/torchtitan/experiments/flux/train_configs/debug_model.toml @@ -36,7 +36,6 @@ decay_ratio = 0.0 # no decay, stay stable during training local_batch_size = 4 max_norm = 2.0 # grad norm clipping steps = 10 -compile = false dataset = "cc12m-test" classifier_free_guidance_prob = 0.447 img_size = 256 diff --git a/torchtitan/experiments/flux/train_configs/flux_dev_model.toml b/torchtitan/experiments/flux/train_configs/flux_dev_model.toml index 5fbdcb6fc..083ad7977 100644 --- a/torchtitan/experiments/flux/train_configs/flux_dev_model.toml +++ b/torchtitan/experiments/flux/train_configs/flux_dev_model.toml @@ -35,7 +35,6 @@ decay_ratio = 0.0 # no decay local_batch_size = 32 max_norm = 1.0 # grad norm clipping steps = 30_000 -compile = false dataset = "cc12m-wds" classifier_free_guidance_prob = 0.447 img_size = 256 diff --git a/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml b/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml index d479710e6..0a9cce71c 100644 --- a/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml +++ b/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml @@ -35,7 +35,6 @@ decay_ratio = 0.0 # no decay local_batch_size = 64 max_norm = 1.0 # grad norm clipping steps = 30_000 -compile = false dataset = "cc12m-wds" classifier_free_guidance_prob = 0.447 img_size = 256 diff --git a/torchtitan/experiments/forge/job_config.py b/torchtitan/experiments/forge/job_config.py index 56602e352..f65488b01 100644 --- a/torchtitan/experiments/forge/job_config.py +++ b/torchtitan/experiments/forge/job_config.py @@ -11,6 +11,7 @@ ActivationCheckpoint, Checkpoint, Comm, + Compile, Float8, LRScheduler, Model, @@ -31,6 +32,7 @@ class ForgeJobConfig: activation_checkpoint: ActivationCheckpoint = field( default_factory=ActivationCheckpoint ) + compile: Compile = field(default_factory=Compile) float8: Float8 = field(default_factory=Float8) comm: Comm = field(default_factory=Comm) diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index 35a72167d..a716c7890 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -62,12 +62,15 @@ def parallelize_llama( ): raise NotImplementedError("CP support for FlexAttention is still in progress.") + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) if parallel_dims.tp_enabled: if ( job_config.parallelism.enable_async_tensor_parallel - and not job_config.training.compile + and not model_compile_enabled ): - raise RuntimeError("Async TP requires --training.compile") + raise RuntimeError("Async TP requires torch.compile") enable_float8_linear = "float8" in job_config.model.converters float8_is_rowwise = job_config.float8.recipe_name in ( @@ -107,7 +110,7 @@ def parallelize_llama( apply_ac(model, job_config.activation_checkpoint) # turn 
on per-TransformerBlock compile after AC wrapping and before FSDP - if job_config.training.compile: + if model_compile_enabled: # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE torch._dynamo.config.capture_scalar_outputs = True apply_compile(model) diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml index f445b2ad7..0bdb16ecb 100644 --- a/torchtitan/experiments/llama4/train_configs/debug_model.toml +++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml @@ -41,7 +41,6 @@ local_batch_size = 8 seq_len = 2048 max_norm = 1.0 # grad norm clipping steps = 10 -compile = false dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) [parallelism] @@ -68,6 +67,10 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] mode = "selective" # ["none", "selective", "full"] selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy +[compile] +enable=false +components = ["model", "loss"] + [float8] enable_fsdp_float8_all_gather = false precompute_float8_dynamic_scale_for_fsdp = false diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml index cb69e63e2..c40437b37 100644 --- a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml @@ -34,7 +34,6 @@ local_batch_size = 1 seq_len = 8192 max_norm = 1.0 # grad norm clipping steps = 3000 -compile = false dataset = "c4" [parallelism] @@ -60,6 +59,10 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = "full" # ["none", "selective", "full"] +[compile] +enable=false +components = ["model", "loss"] + [float8] enable_fsdp_float8_all_gather = false precompute_float8_dynamic_scale_for_fsdp = false diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml index 4e7416fd2..ab718cf6f 100644 --- a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml @@ -34,7 +34,6 @@ local_batch_size = 8 seq_len = 8192 max_norm = 1.0 # grad norm clipping steps = 3000 -compile = false dataset = "c4" [parallelism] @@ -58,6 +57,10 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = "full" # ["none", "selective", "full"] +[compile] +enable=false +components = ["model", "loss"] + [float8] enable_fsdp_float8_all_gather = false precompute_float8_dynamic_scale_for_fsdp = false diff --git a/torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml b/torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml index 5c73423af..38dc25949 100644 --- a/torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml +++ b/torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml @@ -31,7 +31,6 @@ local_batch_size = 4 seq_len = 4096 max_norm = 1.0 # grad norm clipping steps = 10 -compile = false dataset = "c4" [parallelism] @@ -53,6 +52,10 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] mode = "selective" # ["none", "selective", "full"] selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy +[compile] +enable=false +components = ["model", "loss"] + [float8] enable_fsdp_float8_all_gather = false 
precompute_float8_dynamic_scale_for_fsdp = false diff --git a/torchtitan/experiments/simple_fsdp/README.md b/torchtitan/experiments/simple_fsdp/README.md index 82eeec877..43edc6d80 100644 --- a/torchtitan/experiments/simple_fsdp/README.md +++ b/torchtitan/experiments/simple_fsdp/README.md @@ -13,7 +13,7 @@ This folder includes an experimental frontend implementation for [SimpleFSDP: Si ### Enable SimpleFSDP Training ```bash -CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.name llama3_simple_fsdp --training.compile +CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.name llama3_simple_fsdp --compile.enable ``` ### Composability Support diff --git a/torchtitan/experiments/simple_fsdp/parallelize.py b/torchtitan/experiments/simple_fsdp/parallelize.py index ef02a4bf6..4d909e4fe 100644 --- a/torchtitan/experiments/simple_fsdp/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/parallelize.py @@ -37,12 +37,15 @@ def parallelize_llama( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) if parallel_dims.tp_enabled: if ( job_config.parallelism.enable_async_tensor_parallel - and not job_config.training.compile + and not model_compile_enabled ): - raise RuntimeError("Async TP requires --training.compile") + raise RuntimeError("Async TP requires torch.compile") enable_float8_linear = "float8" in job_config.model.converters float8_is_rowwise = job_config.float8.recipe_name in ( @@ -99,7 +102,7 @@ def parallelize_llama( ) logger.info("Applied Data Parallel (dp mode=%s) to the model", dp_mode) - if job_config.training.compile: + if model_compile_enabled: torch._inductor.config.reorder_for_peak_memory = False model = torch.compile(model, fullgraph=True) diff --git a/torchtitan/experiments/simple_fsdp/tests/integration_tests.py b/torchtitan/experiments/simple_fsdp/tests/integration_tests.py index d0579adcd..45e99aa6c 100755 --- a/torchtitan/experiments/simple_fsdp/tests/integration_tests.py +++ b/torchtitan/experiments/simple_fsdp/tests/integration_tests.py @@ -220,7 +220,7 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str): for idx, override_arg in enumerate(test_flavor.override_args): cmd = ( f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_train.sh " - f"--model.name llama3_simple_fsdp --training.compile " + f"--model.name llama3_simple_fsdp --compile.enable " ) # dump compile trace for debugging purpose cmd = f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + cmd diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 1a64d34a1..1aedd73ad 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -51,10 +51,13 @@ def parallelize_deepseekv3( ): raise NotImplementedError("CP support for FlexAttention is still in progress.") + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) if parallel_dims.tp_enabled: if ( job_config.parallelism.enable_async_tensor_parallel - and not job_config.training.compile + and not model_compile_enabled ): raise RuntimeError("Async TP requires --training.compile") @@ -97,7 +100,7 @@ def parallelize_deepseekv3( if job_config.activation_checkpoint.mode != "none": apply_ac(model, job_config.activation_checkpoint) - if 
job_config.training.compile: + if model_compile_enabled: # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE torch._dynamo.config.capture_scalar_outputs = True apply_compile(model) diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml index bb564bd38..79a15bd2e 100644 --- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml +++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml @@ -41,7 +41,6 @@ local_batch_size = 8 seq_len = 2048 max_norm = 1.0 # grad norm clipping steps = 10 -compile = false dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) [parallelism] @@ -66,7 +65,11 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = "selective" # ["none", "selective", "full"] -selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy +selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable=false +components = ["model", "loss"] [float8] enable_fsdp_float8_all_gather = false diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index 15ce11bd0..3ef6e67fc 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -39,7 +39,6 @@ local_batch_size = 8 seq_len = 4096 max_norm = 1.0 # grad norm clipping steps = 1000 -compile = false dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) [parallelism] @@ -50,7 +49,7 @@ tensor_parallel_degree = 1 enable_async_tensor_parallel = false pipeline_parallel_degree = 1 pipeline_parallel_schedule = "Interleaved1F1B" -expert_parallel_degree = 1 +expert_parallel_degree = 8 expert_tensor_parallel_degree = 1 [checkpoint] @@ -62,7 +61,12 @@ export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" [activation_checkpoint] -mode = "full" # ["none", "selective", "full"] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable=true +components = ["loss"] # ["model", "loss"] [float8] enable_fsdp_float8_all_gather = false diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index 614719dd2..23dc315d0 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -46,7 +46,7 @@ dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 8 +tensor_parallel_degree = 1 enable_async_tensor_parallel = false pipeline_parallel_degree = 1 pipeline_parallel_schedule = "Interleaved1F1B" @@ -62,7 +62,12 @@ export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" [activation_checkpoint] -mode = "full" # ["none", "selective", "full"] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable=true +components = ["loss"] # ["model", "loss"] [float8] 
enable_fsdp_float8_all_gather = false diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 6d9bf60c1..2e2e81302 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -63,12 +63,15 @@ def parallelize_llama( ): raise NotImplementedError("CP support for FlexAttention is still in progress.") + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) if parallel_dims.tp_enabled: if ( job_config.parallelism.enable_async_tensor_parallel - and not job_config.training.compile + and not model_compile_enabled ): - raise RuntimeError("Async TP requires --training.compile") + raise RuntimeError("Async TP requires torch.compile") enable_float8_linear = "float8" in job_config.model.converters float8_is_rowwise = job_config.float8.recipe_name in ( @@ -93,7 +96,7 @@ def parallelize_llama( apply_ac(model, job_config.activation_checkpoint) # turn on per-TransformerBlock compile after AC wrapping and before FSDP - if job_config.training.compile: + if model_compile_enabled: apply_compile(model) if parallel_dims.fsdp_enabled: @@ -129,7 +132,7 @@ def parallelize_llama( apply_ddp( model, world_mesh, - enable_compile=job_config.training.compile, + enable_compile=model_compile_enabled, enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd, ) diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml index 0607268a7..d446027f4 100644 --- a/torchtitan/models/llama3/train_configs/debug_model.toml +++ b/torchtitan/models/llama3/train_configs/debug_model.toml @@ -43,7 +43,6 @@ local_batch_size = 8 seq_len = 2048 max_norm = 1.0 # grad norm clipping steps = 10 -compile = false dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) [parallelism] @@ -67,6 +66,10 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] mode = "selective" # ["none", "selective", "full"] selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy +[compile] +enable=false +components = ["model", "loss"] + [float8] enable_fsdp_float8_all_gather = false precompute_float8_dynamic_scale_for_fsdp = false diff --git a/torchtitan/models/llama3/train_configs/llama3_405b.toml b/torchtitan/models/llama3/train_configs/llama3_405b.toml index 471ed981b..5895f7f25 100644 --- a/torchtitan/models/llama3/train_configs/llama3_405b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_405b.toml @@ -34,7 +34,6 @@ local_batch_size = 2 seq_len = 8192 max_norm = 1.0 # grad norm clipping steps = 3000 -compile = true dataset = "c4" [parallelism] @@ -56,6 +55,10 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = "full" # ["none", "selective", "full"] +[compile] +enable=true +components = ["model", "loss"] + [float8] enable_fsdp_float8_all_gather = true precompute_float8_dynamic_scale_for_fsdp = true diff --git a/torchtitan/models/llama3/train_configs/llama3_70b.toml b/torchtitan/models/llama3/train_configs/llama3_70b.toml index 8a3f2018e..9a2eddd09 100644 --- a/torchtitan/models/llama3/train_configs/llama3_70b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_70b.toml @@ -34,7 +34,6 @@ local_batch_size = 8 seq_len = 8192 max_norm = 1.0 # grad norm clipping steps = 1000 -compile = false dataset = "c4" [parallelism] @@ -55,6 +54,10 @@ async_mode = "disabled" # ["disabled", "async", 
"async_with_pinned_mem"] [activation_checkpoint] mode = "full" +[compile] +enable=false +components = ["model", "loss"] + [float8] enable_fsdp_float8_all_gather = false precompute_float8_dynamic_scale_for_fsdp = false diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml index 038f9b33f..d9a9c331f 100644 --- a/torchtitan/models/llama3/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml @@ -34,7 +34,6 @@ local_batch_size = 1 seq_len = 8192 max_norm = 1.0 # grad norm clipping steps = 1000 -compile = false dataset = "c4" [parallelism] @@ -52,6 +51,10 @@ last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] +[compile] +enable=false +components = ["model", "loss"] + [activation_checkpoint] mode = "selective" # ["none", "selective", "full"] selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy