Commit 2a25f4d

fix float8 after the HSDP PR (#575)
1 parent f2a1551 commit 2a25f4d

File tree

1 file changed (+1, −2 lines)


torchtitan/float8.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -49,8 +49,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
 
         # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
         enable_fsdp_float8_all_gather = (
-            parallel_dims.dp_enabled
-            and parallel_dims.dp_type == "fsdp"
+            parallel_dims.dp_shard_enabled
             and float8_config.enable_fsdp_float8_all_gather
         )
         scaling_type_input = ScalingType(float8_config.scaling_type_input)
```
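The logic of the change can be sketched in isolation: before the HSDP PR, data parallelism was a single dimension with a `dp_type` string, so the old condition checked `dp_enabled and dp_type == "fsdp"`; afterwards, data parallelism splits into replicate and shard dimensions, and float8 all-gather only cares about the sharded one. The `ParallelDims` stand-in below is a hypothetical minimal mock, not the real torchtitan class; only the `dp_shard_enabled` property mirrors the attribute the commit switches to.

```python
from dataclasses import dataclass


@dataclass
class ParallelDims:
    """Hypothetical stand-in for torchtitan's ParallelDims, for illustration.

    After the HSDP PR, data parallelism is factored into a replicate
    dimension and a shard dimension instead of a single dp_type string.
    """
    dp_replicate: int = 1
    dp_shard: int = 1

    @property
    def dp_shard_enabled(self) -> bool:
        # A sharded data-parallel dimension exists (plain FSDP or HSDP).
        return self.dp_shard > 1


def enable_fsdp_float8_all_gather(parallel_dims: ParallelDims,
                                  config_flag: bool) -> bool:
    # The condition after this commit: requires only a sharded DP
    # dimension, so it covers both FSDP and HSDP (replicate x shard),
    # where the old dp_type == "fsdp" check excluded HSDP.
    return parallel_dims.dp_shard_enabled and config_flag


# HSDP with 2-way replicate and 4-way shard still enables all-gather.
print(enable_fsdp_float8_all_gather(ParallelDims(dp_replicate=2, dp_shard=4), True))  # prints True
```

With the old check, an HSDP setup would have silently disabled float8 all-gather because its `dp_type` was no longer `"fsdp"`; keying off the shard dimension alone restores the intended behavior.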

0 commit comments
