
Commit af23d2e

[refactor] support compile model and loss separately
1 parent 46a32e7 commit af23d2e

31 files changed, +116 -53 lines

docs/float8.md

Lines changed: 3 additions & 3 deletions
@@ -11,22 +11,22 @@ USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
 
 For float8 with tensorwise scaling, launch training job with the following command (or alternatively set configs in toml files)
 ```
-CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --training.compile
+CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --compile.enable
 ```
 * `--model.converters="float8"`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul.
 * `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth.
 * `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduces, one per parameter.
 * `--float8.filter_fqns="..."` (optional): a comma-separated list of fully qualified names of modules not to convert to float8 training. Example: `--float8.filter_fqns="attention.wk,attention.wv"`. You can determine which layers to convert by looking at the microbenchmarks in the [performance section](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) of the torchao documentation for the float8 recipe you're using.
 * **Auto-filter**: add `"auto_filter_small_kn"` as one of the `--float8.filter_fqns=...` to enable automatic module filtering, which skips converting linear layers that are not large enough to benefit from float8 training: the GEMM has to be big enough that the speedup from using FP8 tensor cores outweighs the overhead of creating dynamically quantized inputs. The conversion thresholds are based on microbenchmarks measured on NVIDIA H100 GPUs, where (K,N) represents the linear layer weight shape. For best performance, you should still manually filter out layers that are too small to benefit from float8 training.
-* `--training.compile` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels
+* `--compile.enable` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels
 
 For float8 with rowwise scaling, launch training job with the following command (or alternatively set configs in toml files)
 ```
 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.recipe_name rowwise --training.compile
 ```
 * `--model.converters="float8"`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul.
 * `--float8.recipe_name="rowwise"`: use the rowwise scaling recipe for higher accuracy compared to tensorwise scaling
-* `--training.compile` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels
+* `--compile.enable` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels
 
 For parallelisms: for float8 with tensorwise scaling, we support float8 all-gather for FSDP (optional) and for TP (by default for `Float8Linear`). For float8 with rowwise scaling, all distributed communication is done in high precision.
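The flags above map one-to-one onto `JobConfig` fields. As a minimal sketch, the same tensorwise setup could be expressed programmatically; the import path and the list-valued `model.converters` field are assumptions inferred from this commit's file layout, not confirmed by the diff:

```python
# Sketch only: field names mirror the CLI flags above; the import path and the
# exact type of model.converters are assumptions, not confirmed by this diff.
from torchtitan.config.job_config import JobConfig

job_config = JobConfig()
job_config.model.converters = ["float8"]            # --model.converters="float8"
job_config.float8.enable_fsdp_float8_all_gather = True
job_config.float8.precompute_float8_dynamic_scale_for_fsdp = True
job_config.compile.enable = True                    # replaces the old --training.compile
job_config.compile.components = ["model", "loss"]   # the default selection
```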

scripts/estimate/estimation.py

Lines changed: 2 additions & 2 deletions
@@ -33,9 +33,9 @@ def estimate_memory(job_config: JobConfig):
     # Get the world size
     world_size = int(os.environ["WORLD_SIZE"])
 
-    if job_config.training.compile or job_config.parallelism.enable_compiled_autograd:
+    if job_config.compile.enable or job_config.parallelism.enable_compiled_autograd:
         logger.info("Compile mode is not supported yet. Switching to eager mode.")
-        job_config.training.compile = False
+        job_config.compile.enable = False
         job_config.parallelism.enable_compiled_autograd = False
 
     # init fake pg

scripts/estimate/run_memory_estimation.sh

Lines changed: 1 addition & 1 deletion
@@ -23,4 +23,4 @@ fi
 # Export WORLD_SIZE and LOCAL_RANK
 export WORLD_SIZE=$((NGPU * NNODES))
 export LOCAL_RANK=0
-python -m scripts.estimate.estimation --job.config_file ${CONFIG_FILE} --memory_estimation.enabled $overrides
+python -m scripts.estimate.estimation --job.config_file ${CONFIG_FILE} --memory_estimation.enable $overrides

tests/integration_tests.py

Lines changed: 6 additions & 6 deletions
@@ -57,7 +57,7 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--training.compile",
+                    "--compile.enable",
                 ],
             ],
             "1D compile",
@@ -66,7 +66,7 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--training.compile",
+                    "--compile.enable",
                     "--activation_checkpoint.mode selective",
                     "--activation_checkpoint.selective_ac_option op",
                 ],
@@ -86,7 +86,7 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--training.compile",
+                    "--compile.enable",
                     "--parallelism.tensor_parallel_degree 2",
                 ],
             ],
@@ -97,7 +97,7 @@ def build_test_list():
         # OverrideDefinitions(
         #     [
         #         [
-        #             "--training.compile",
+        #             "--compile.enable",
         #             "--parallelism.tensor_parallel_degree 2",
         #             "--parallelism.enable_async_tensor_parallel",
         #         ],
@@ -267,7 +267,7 @@ def build_test_list():
                     "--parallelism.pipeline_parallel_degree 2",
                     "--parallelism.data_parallel_shard_degree 2",
                     "--parallelism.tensor_parallel_degree 2",
-                    "--training.compile",
+                    "--compile.enable",
                 ],
             ],
             "PP+DP+TP 3D test with torch.compile",
@@ -464,7 +464,7 @@ def build_test_list():
         # OverrideDefinitions(
         #     [
         #         [
-        #             "--memory_estimation.enabled",
+        #             "--memory_estimation.enable",
         #         ]
         #     ],
         #     "FSDP2 Memory Tracking and Estimation",

tests/integration_tests_h100.py

Lines changed: 3 additions & 3 deletions
@@ -32,7 +32,7 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--training.compile",
+                    "--compile.enable",
                     "--parallelism.tensor_parallel_degree 2",
                     "--parallelism.enable_async_tensor_parallel",
                 ],
@@ -54,7 +54,7 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--training.compile",
+                    "--compile.enable",
                     "--parallelism.data_parallel_shard_degree=2",
                     "--parallelism.tensor_parallel_degree=2",
                     "--parallelism.pipeline_parallel_degree=2",
@@ -71,7 +71,7 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--training.compile",
+                    "--compile.enable",
                     "--parallelism.data_parallel_shard_degree=2",
                     "--parallelism.data_parallel_replicate_degree=2",
                     "--parallelism.context_parallel_degree=2",

torchtitan/components/loss.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor
 
 def build_cross_entropy_loss(job_config: JobConfig):
     loss_fn = cross_entropy_loss
-    if job_config.training.compile:
+    if job_config.compile.enable and "loss" in job_config.compile.components:
         logger.info("Compiling the loss function with torch.compile")
         loss_fn = torch.compile(loss_fn)
     return loss_fn
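With this gate, the loss can stay eager while the model is compiled, or vice versa. A minimal usage sketch follows; the behavior comes straight from the diff, but constructing `JobConfig` directly like this is an assumption:

```python
# Hypothetical usage sketch; the JobConfig import path is assumed from this
# commit's file layout, while build_cross_entropy_loss is the function above.
from torchtitan.config.job_config import JobConfig
from torchtitan.components.loss import build_cross_entropy_loss

cfg = JobConfig()
cfg.compile.enable = True

cfg.compile.components = ["model"]          # "loss" omitted
loss_fn = build_cross_entropy_loss(cfg)     # returns the eager cross_entropy_loss

cfg.compile.components = ["model", "loss"]  # the default
loss_fn = build_cross_entropy_loss(cfg)     # wraps the loss in torch.compile
```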

torchtitan/components/quantization/float8.py

Lines changed: 5 additions & 1 deletion
@@ -28,8 +28,12 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         self.enabled = False
 
         float8_config: Float8 = job_config.float8
+        compile_config = job_config.compile
+        model_compile_enabled = (
+            compile_config.enable and "model" in compile_config.components
+        )
         if has_cuda_capability(8, 9) or (
-            float8_config.emulate and not job_config.training.compile
+            float8_config.emulate and not model_compile_enabled
         ):
             pass
         else:

torchtitan/components/quantization/mx.py

Lines changed: 5 additions & 2 deletions
@@ -52,9 +52,12 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         ), "MXFP8 is only supported on SM100 or later architectures"
 
         # TP not yet supported with torch.compile
+
+        model_compile_enabled = (
+            job_config.compile.enable and "model" in job_config.compile.components
+        )
         assert not (
-            job_config.training.compile
-            and job_config.parallelism.tensor_parallel_degree > 1
+            model_compile_enabled and job_config.parallelism.tensor_parallel_degree > 1
         ), "TP not yet supported with torch.compile for mxfp8"
 
         # For MoE training with mxfp8, token group sizes must be multiples of 32
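Both quantization converters now compute the same `model_compile_enabled` gate inline. As a sketch, that gate could live in a single hypothetical helper (not part of this commit; the `Compile` field names come from this commit's job_config.py):

```python
# Hypothetical helper consolidating the gate duplicated in float8.py and mx.py
# above; is_component_compiled is an illustrative name, not part of the commit.
from torchtitan.config.job_config import JobConfig


def is_component_compiled(job_config: JobConfig, component: str) -> bool:
    """Return True if torch.compile is enabled and `component` is selected."""
    compile_config = job_config.compile
    return compile_config.enable and component in compile_config.components


# Usage mirroring the converters' checks:
# model_compile_enabled = is_component_compiled(job_config, "model")
```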

torchtitan/config/job_config.py

Lines changed: 13 additions & 4 deletions
@@ -218,9 +218,6 @@ class Training:
     This feature only takes effect when data_parallel_shard_degree > 1
     """
 
-    compile: bool = False
-    """Whether to compile the model"""
-
     gc_freq: int = 50
     """Python garbage control scheduling interval, in steps"""
 
@@ -550,6 +547,17 @@ class ActivationCheckpoint:
     """
 
 
+@dataclass
+class Compile:
+    enable: bool = False
+    """Whether to apply torch.compile"""
+
+    components: list[Literal["model", "loss"]] = field(
+        default_factory=lambda: ["model", "loss"]
+    )
+    """Which components to compile"""
+
+
 @dataclass
 class Float8:
     enable_fsdp_float8_all_gather: bool = False
@@ -630,7 +638,7 @@ class Comm:
 
 @dataclass
 class MemoryEstimation:
-    enabled: bool = False
+    enable: bool = False
     """Whether to estimate memory usage for FSDP"""
 
     disable_fake_mode: bool = False
@@ -747,6 +755,7 @@ class JobConfig:
     activation_checkpoint: ActivationCheckpoint = field(
         default_factory=ActivationCheckpoint
     )
+    compile: Compile = field(default_factory=Compile)
     float8: Float8 = field(default_factory=Float8)
     mx: MX = field(default_factory=MX)
     comm: Comm = field(default_factory=Comm)
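The new `Compile` dataclass replaces the old `Training.compile` boolean, and `MemoryEstimation.enabled` is renamed `enable` for consistency. A self-contained sketch of the dataclass's behavior, with the definition copied from the diff above:

```python
from dataclasses import dataclass, field
from typing import Literal


@dataclass
class Compile:
    enable: bool = False
    """Whether to apply torch.compile"""

    components: list[Literal["model", "loss"]] = field(
        default_factory=lambda: ["model", "loss"]
    )
    """Which components to compile"""


# Default: compiling is off, but both components are selected once enabled.
cfg = Compile()
assert cfg.enable is False and cfg.components == ["model", "loss"]

# Compile only the loss, leaving the model eager.
cfg = Compile(enable=True, components=["loss"])
assert "loss" in cfg.components and "model" not in cfg.components
```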

torchtitan/experiments/flux/loss.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ def mse_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
 
 def build_mse_loss(job_config: JobConfig):
     loss_fn = mse_loss
-    if job_config.training.compile:
+    if job_config.compile.enable and "loss" in job_config.compile.components:
         logger.info("Compiling the loss function with torch.compile")
         loss_fn = torch.compile(loss_fn)
     return loss_fn
