6 changes: 3 additions & 3 deletions docs/float8.md
@@ -11,22 +11,22 @@ USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git

For float8 with tensorwise scaling, launch a training job with the following command (or alternatively set the configs in the toml file; a sketch of the equivalent TOML follows the flag list below)
```
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --training.compile
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --compile.enable
```
* `--model.converters="float8"`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul.
* `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth.
* `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of many small all-reduces, one per parameter.
* `--float8.filter_fqns="..."` (optional): a comma separated list of fully qualified names of modules not to convert to float8 training. Example: `--float8.filter_fqns="attention.wk,attention.wv"`. You can determine which layers to convert by looking at the microbenchmarks in the [performance section](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) of the torchao documentation for the float8 recipe you're using.
* **Auto-filter**: add `"auto_filter_small_kn"` as one of the `--float8.filter_fqns=...` to enable automatic module filtering, which automatically skips converting linear layers that are not large enough to benefit from float8 training, since the GEMM has to be big enough that the speedup from using FP8 tensor cores outweighs the overhead of creating dynamically quantized inputs. The thresholds for conversion are based on microbenchmarks measured on NVIDIA H100 GPUs, where (K,N) represents the linear layer weight shape. For best performance, you should still manually filter out layers that are too small to benefit from float8 training.
* `--training.compile` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels
* `--compile.enable` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels
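The same settings can also be written in the training TOML file instead of being passed on the command line. A minimal sketch of the tensorwise setup, assuming the `[model]` table accepts a `converters` list and that the remaining keys mirror the CLI flags above (exact field names are defined in `torchtitan/config/job_config.py`):
```toml
[model]
converters = ["float8"]  # swap nn.Linear with Float8Linear

[float8]
enable_fsdp_float8_all_gather = true            # all-gather weights in float8
precompute_float8_dynamic_scale_for_fsdp = true # one fused all-reduce for scales

[compile]
enable = true                  # required for competitive performance
components = ["model", "loss"]
```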

For float8 with rowwise scaling, launch a training job with the following command (or alternatively set the configs in the toml file; see the sketch after the flag list below)
```
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.recipe_name rowwise --training.compile
```
* `--model.converters="float8"`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul.
* `--float8.recipe_name="rowwise"`: use the rowwise scaling recipe for higher accuracy compared to tensorwise scaling
* `--training.compile` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels
* `--compile.enable` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels
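Similarly, a hedged TOML sketch of the rowwise setup (key names assumed to mirror the CLI flags above):
```toml
[model]
converters = ["float8"]

[float8]
recipe_name = "rowwise"  # higher accuracy than tensorwise scaling

[compile]
enable = true
```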

As for parallelisms, float8 with tensorwise scaling supports float8 all-gather for FSDP (optional) and for TP (by default for `Float8Linear`). For float8 with rowwise scaling, all distributed communication is done in high precision.

4 changes: 2 additions & 2 deletions scripts/estimate/estimation.py
@@ -33,9 +33,9 @@ def estimate_memory(job_config: JobConfig):
# Get the world size
world_size = int(os.environ["WORLD_SIZE"])

if job_config.training.compile or job_config.parallelism.enable_compiled_autograd:
if job_config.compile.enable or job_config.parallelism.enable_compiled_autograd:
logger.info("Compile mode is not supported yet. Switching to eager mode.")
job_config.training.compile = False
job_config.compile.enable = False
job_config.parallelism.enable_compiled_autograd = False

# init fake pg
2 changes: 1 addition & 1 deletion scripts/estimate/run_memory_estimation.sh
@@ -23,4 +23,4 @@ fi
# Export WORLD_SIZE and LOCAL_RANK
export WORLD_SIZE=$((NGPU * NNODES))
export LOCAL_RANK=0
python -m scripts.estimate.estimation --job.config_file ${CONFIG_FILE} --memory_estimation.enabled $overrides
python -m scripts.estimate.estimation --job.config_file ${CONFIG_FILE} --memory_estimation.enable $overrides
12 changes: 6 additions & 6 deletions tests/integration_tests.py
@@ -57,7 +57,7 @@ def build_test_list():
OverrideDefinitions(
[
[
"--training.compile",
"--compile.enable",
],
],
"1D compile",
@@ -66,7 +66,7 @@ def build_test_list():
OverrideDefinitions(
[
[
"--training.compile",
"--compile.enable",
"--activation_checkpoint.mode selective",
"--activation_checkpoint.selective_ac_option op",
],
@@ -86,7 +86,7 @@ def build_test_list():
OverrideDefinitions(
[
[
"--training.compile",
"--compile.enable",
"--parallelism.tensor_parallel_degree 2",
],
],
@@ -97,7 +97,7 @@ def build_test_list():
# OverrideDefinitions(
# [
# [
# "--training.compile",
# "--compile.enable",
# "--parallelism.tensor_parallel_degree 2",
# "--parallelism.enable_async_tensor_parallel",
# ],
@@ -267,7 +267,7 @@ def build_test_list():
"--parallelism.pipeline_parallel_degree 2",
"--parallelism.data_parallel_shard_degree 2",
"--parallelism.tensor_parallel_degree 2",
"--training.compile",
"--compile.enable",
],
],
"PP+DP+TP 3D test with torch.compile",
@@ -464,7 +464,7 @@ def build_test_list():
# OverrideDefinitions(
# [
# [
# "--memory_estimation.enabled",
# "--memory_estimation.enable",
# ]
# ],
# "FSDP2 Memory Tracking and Estimation",
6 changes: 3 additions & 3 deletions tests/integration_tests_h100.py
@@ -32,7 +32,7 @@ def build_test_list():
OverrideDefinitions(
[
[
"--training.compile",
"--compile.enable",
"--parallelism.tensor_parallel_degree 2",
"--parallelism.enable_async_tensor_parallel",
],
@@ -54,7 +54,7 @@ def build_test_list():
OverrideDefinitions(
[
[
"--training.compile",
"--compile.enable",
"--parallelism.data_parallel_shard_degree=2",
"--parallelism.tensor_parallel_degree=2",
"--parallelism.pipeline_parallel_degree=2",
@@ -71,7 +71,7 @@ def build_test_list():
OverrideDefinitions(
[
[
"--training.compile",
"--compile.enable",
"--parallelism.data_parallel_shard_degree=2",
"--parallelism.data_parallel_replicate_degree=2",
"--parallelism.context_parallel_degree=2",
2 changes: 1 addition & 1 deletion torchtitan/components/loss.py
@@ -24,7 +24,7 @@ def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor

def build_cross_entropy_loss(job_config: JobConfig):
loss_fn = cross_entropy_loss
if job_config.training.compile:
if job_config.compile.enable and "loss" in job_config.compile.components:
logger.info("Compiling the loss function with torch.compile")
loss_fn = torch.compile(loss_fn)
return loss_fn
6 changes: 5 additions & 1 deletion torchtitan/components/quantization/float8.py
@@ -28,8 +28,12 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
self.enabled = False

float8_config: Float8 = job_config.float8
compile_config = job_config.compile
model_compile_enabled = (
compile_config.enable and "model" in compile_config.components
)
if has_cuda_capability(8, 9) or (
float8_config.emulate and not job_config.training.compile
float8_config.emulate and not model_compile_enabled
):
pass
else:
7 changes: 5 additions & 2 deletions torchtitan/components/quantization/mx.py
@@ -52,9 +52,12 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
), "MXFP8 is only supported on SM100 or architectures"

# TP not yet supported with torch.compile

model_compile_enabled = (
job_config.compile.enable and "model" in job_config.compile.components
)
assert not (
job_config.training.compile
and job_config.parallelism.tensor_parallel_degree > 1
model_compile_enabled and job_config.parallelism.tensor_parallel_degree > 1
), "TP not yet supported with torch.compile for mxfp8"

# For MoE training with mxfp8, token group sizes must be multiples of 32
17 changes: 13 additions & 4 deletions torchtitan/config/job_config.py
@@ -218,9 +218,6 @@ class Training:
This feature only takes effect when data_parallel_shard_degree > 1
"""

compile: bool = False
"""Whether to compile the model"""

gc_freq: int = 50
"""Python garbage control scheduling interval, in steps"""

@@ -550,6 +547,17 @@ class ActivationCheckpoint:
"""


@dataclass
class Compile:
enable: bool = False
"""Whether to apply torch.compile"""

components: list[Literal["model", "loss"]] = field(
default_factory=lambda: ["model", "loss"]
)
"""Which components to compile"""


@dataclass
class Float8:
enable_fsdp_float8_all_gather: bool = False
@@ -630,7 +638,7 @@ class Comm:

@dataclass
class MemoryEstimation:
enabled: bool = False
enable: bool = False
"""Whether to estimate memory usage for FSDP"""

disable_fake_mode: bool = False
@@ -747,6 +755,7 @@ class JobConfig:
activation_checkpoint: ActivationCheckpoint = field(
default_factory=ActivationCheckpoint
)
compile: Compile = field(default_factory=Compile)
float8: Float8 = field(default_factory=Float8)
mx: MX = field(default_factory=MX)
comm: Comm = field(default_factory=Comm)
2 changes: 1 addition & 1 deletion torchtitan/experiments/flux/loss.py
@@ -21,7 +21,7 @@ def mse_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:

def build_mse_loss(job_config: JobConfig):
loss_fn = mse_loss
if job_config.training.compile:
if job_config.compile.enable and "loss" in job_config.compile.components:
logger.info("Compiling the loss function with torch.compile")
loss_fn = torch.compile(loss_fn)
return loss_fn
1 change: 0 additions & 1 deletion torchtitan/experiments/flux/train_configs/debug_model.toml
@@ -36,7 +36,6 @@ decay_ratio = 0.0 # no decay, stay stable during training
local_batch_size = 4
max_norm = 2.0 # grad norm clipping
steps = 10
compile = false
dataset = "cc12m-test"
classifier_free_guidance_prob = 0.447
img_size = 256
@@ -35,7 +35,6 @@ decay_ratio = 0.0 # no decay
local_batch_size = 32
max_norm = 1.0 # grad norm clipping
steps = 30_000
compile = false
dataset = "cc12m-wds"
classifier_free_guidance_prob = 0.447
img_size = 256
@@ -35,7 +35,6 @@ decay_ratio = 0.0 # no decay
local_batch_size = 64
max_norm = 1.0 # grad norm clipping
steps = 30_000
compile = false
dataset = "cc12m-wds"
classifier_free_guidance_prob = 0.447
img_size = 256
2 changes: 2 additions & 0 deletions torchtitan/experiments/forge/job_config.py
@@ -11,6 +11,7 @@
ActivationCheckpoint,
Checkpoint,
Comm,
Compile,
Float8,
LRScheduler,
Model,
@@ -31,6 +32,7 @@ class ForgeJobConfig:
activation_checkpoint: ActivationCheckpoint = field(
default_factory=ActivationCheckpoint
)
compile: Compile = field(default_factory=Compile)
float8: Float8 = field(default_factory=Float8)
comm: Comm = field(default_factory=Comm)

9 changes: 6 additions & 3 deletions torchtitan/experiments/llama4/infra/parallelize.py
@@ -62,12 +62,15 @@ def parallelize_llama(
):
raise NotImplementedError("CP support for FlexAttention is still in progress.")

model_compile_enabled = (
job_config.compile.enable and "model" in job_config.compile.components
)
if parallel_dims.tp_enabled:
if (
job_config.parallelism.enable_async_tensor_parallel
and not job_config.training.compile
and not model_compile_enabled
):
raise RuntimeError("Async TP requires --training.compile")
raise RuntimeError("Async TP requires torch.compile")

enable_float8_linear = "float8" in job_config.model.converters
float8_is_rowwise = job_config.float8.recipe_name in (
@@ -107,7 +110,7 @@
apply_ac(model, job_config.activation_checkpoint)

# turn on per-TransformerBlock compile after AC wrapping and before FSDP
if job_config.training.compile:
if model_compile_enabled:
# NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
torch._dynamo.config.capture_scalar_outputs = True
apply_compile(model)
5 changes: 4 additions & 1 deletion torchtitan/experiments/llama4/train_configs/debug_model.toml
@@ -41,7 +41,6 @@ local_batch_size = 8
seq_len = 2048
max_norm = 1.0 # grad norm clipping
steps = 10
compile = false
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)

[parallelism]
@@ -68,6 +67,10 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
mode = "selective" # ["none", "selective", "full"]
selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy

[compile]
enable=false
components = ["model", "loss"]

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
@@ -34,7 +34,6 @@ local_batch_size = 1
seq_len = 8192
max_norm = 1.0 # grad norm clipping
steps = 3000
compile = false
dataset = "c4"

[parallelism]
@@ -60,6 +59,10 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
[activation_checkpoint]
mode = "full" # ["none", "selective", "full"]

[compile]
enable=false
components = ["model", "loss"]

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
@@ -34,7 +34,6 @@ local_batch_size = 8
seq_len = 8192
max_norm = 1.0 # grad norm clipping
steps = 3000
compile = false
dataset = "c4"

[parallelism]
@@ -58,6 +57,10 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
[activation_checkpoint]
mode = "full" # ["none", "selective", "full"]

[compile]
enable=false
components = ["model", "loss"]

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
5 changes: 4 additions & 1 deletion torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml
@@ -31,7 +31,6 @@ local_batch_size = 4
seq_len = 4096
max_norm = 1.0 # grad norm clipping
steps = 10
compile = false
dataset = "c4"

[parallelism]
@@ -53,6 +52,10 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
mode = "selective" # ["none", "selective", "full"]
selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy

[compile]
enable=false
components = ["model", "loss"]

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
2 changes: 1 addition & 1 deletion torchtitan/experiments/simple_fsdp/README.md
@@ -13,7 +13,7 @@ This folder includes an experimental frontend implementation for [SimpleFSDP: Si
### Enable SimpleFSDP Training

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.name llama3_simple_fsdp --training.compile
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.name llama3_simple_fsdp --compile.enable
```

### Composability Support
9 changes: 6 additions & 3 deletions torchtitan/experiments/simple_fsdp/parallelize.py
@@ -37,12 +37,15 @@ def parallelize_llama(
({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
"""

model_compile_enabled = (
job_config.compile.enable and "model" in job_config.compile.components
)
if parallel_dims.tp_enabled:
if (
job_config.parallelism.enable_async_tensor_parallel
and not job_config.training.compile
and not model_compile_enabled
):
raise RuntimeError("Async TP requires --training.compile")
raise RuntimeError("Async TP requires torch.compile")

enable_float8_linear = "float8" in job_config.model.converters
float8_is_rowwise = job_config.float8.recipe_name in (
@@ -99,7 +102,7 @@
)
logger.info("Applied Data Parallel (dp mode=%s) to the model", dp_mode)

if job_config.training.compile:
if model_compile_enabled:
torch._inductor.config.reorder_for_peak_memory = False
model = torch.compile(model, fullgraph=True)
