Skip to content

Commit 8e50870

Browse files
committed
Add config for running simple-fsdp bucketing/reordering passes
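To use it, just add `--experimental.enable_simplefsdp_passes`. Do not combine it with the other `bucket_*` or `reorder_*` options: when this flag is set, `bucket_all_gathers_fx` and `bucket_reduce_scatters_fx` are forced to "none" and `reorder_for_compute_comm_overlap` is forced to False.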
1 parent 3f04d22 commit 8e50870

File tree

2 files changed

+45
-15
lines changed

2 files changed

+45
-15
lines changed

torchtitan/config_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,8 @@ class Experimental:
694694

695695
autop_force_bf16: bool = False
696696

697+
enable_simplefsdp_passes: bool = False
698+
697699
@dataclass
698700
class Validation:
699701
enabled: bool = False

torchtitan/train.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -124,21 +124,49 @@ def __init__(self, job_config: JobConfig):
124124
torch._inductor.config.allow_buffer_reuse = False
125125

126126
# allow configuring inductor comms optimizations from torchtitan commandline
127-
torch._inductor.config.bucket_all_gathers_fx = (
128-
job_config.experimental.bucket_all_gathers_fx
129-
)
130-
torch._inductor.config.bucket_reduce_scatters_fx = (
131-
job_config.experimental.bucket_reduce_scatters_fx
132-
)
133-
torch._inductor.config.reorder_for_compute_comm_overlap = (
134-
job_config.experimental.reorder_for_compute_comm_overlap
135-
)
136-
torch._inductor.config.reorder_for_compute_comm_overlap_passes = (
137-
job_config.experimental.reorder_for_compute_comm_overlap_passes
138-
)
139-
torch._inductor.config.reorder_prefetch_limit = (
140-
job_config.experimental.reorder_prefetch_limit
141-
)
127+
if job_config.experimental.enable_simplefsdp_passes:
128+
try:
129+
from torch._inductor.simple_fsdp.bucket import bucket_fsdp_all_gather_concat_on_scheduler_ir
130+
except ImportError:
131+
print("Must use pytorch from unlanded https://github.com/pytorch/pytorch/pull/160282, e.g. torchtitan_conda_prod:5e4101faa448c2ee6b62ddd76ee08e8c")
132+
raise
133+
134+
# Configs from Ruisi
135+
136+
# set to 0.1 if you want to make bucketing more efficient with mixed dtype collectives
137+
torch._inductor.config.simplefsdp.relax_ratio = 0
138+
torch._inductor.config.allow_buffer_reuse = False
139+
torch._inductor.config.simplefsdp.estimate_ir = False
140+
torch._inductor.config.simplefsdp.estimate_verbose = False
141+
torch._inductor.config.simplefsdp.save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast_"+job_config.model.flavor+".pkl"
142+
# set to True after the first communication estimation results are saved. This would reduce decision making time.
143+
torch._inductor.config.simplefsdp.load_cache = False
144+
torch._inductor.config.simplefsdp.enable_bucket_ir = True
145+
torch._inductor.config.simplefsdp.enable_reorder_ir = True
146+
torch._inductor.config.simplefsdp.simplefsdp_only = True # False for 2d True for 1d
147+
torch._inductor.config.simplefsdp.peak_memory_offset = 0
148+
torch._inductor.config.simplefsdp.bucketing_type = "auto"
149+
150+
# Don't use both sets of passes at the same time!
151+
torch._inductor.config.bucket_all_gathers_fx = "none"
152+
torch._inductor.config.bucket_reduce_scatters_fx = "none"
153+
torch._inductor.config.reorder_for_compute_comm_overlap = False
154+
else:
155+
torch._inductor.config.bucket_all_gathers_fx = (
156+
job_config.experimental.bucket_all_gathers_fx
157+
)
158+
torch._inductor.config.bucket_reduce_scatters_fx = (
159+
job_config.experimental.bucket_reduce_scatters_fx
160+
)
161+
torch._inductor.config.reorder_for_compute_comm_overlap = (
162+
job_config.experimental.reorder_for_compute_comm_overlap
163+
)
164+
torch._inductor.config.reorder_for_compute_comm_overlap_passes = (
165+
job_config.experimental.reorder_for_compute_comm_overlap_passes
166+
)
167+
torch._inductor.config.reorder_prefetch_limit = (
168+
job_config.experimental.reorder_prefetch_limit
169+
)
142170

143171
# Set random seed, and maybe enable deterministic mode
144172
# (mainly for debugging, expect perf loss).

0 commit comments

Comments
 (0)