Commit 481f97a

add simplefsdp's autobucketing pass entry
1 parent 714cc5b commit 481f97a

File tree

2 files changed: +37 -19 lines changed


torchtitan/experiments/auto_parallel/parallelize_llama.py

Lines changed: 21 additions & 0 deletions
@@ -20,6 +20,27 @@
 from torchtitan.tools.logging import logger
 
 
+class simplefsdp_autobucketing_config:
+    """
+    Config for simplefsdp's autobucketing pass, which gives good performance by default.
+    To make the results tunable, we expose the following parameters:
+    - relax_ratio: relax comp time to include more comm in one bucket;
+      with this config, comp is updated as comp * (1 + relax_ratio)
+    - peak_memory_offset: relax peak_memory to include more comm in one bucket;
+      with this config, peak_memory is updated as (peak_memory + peak_memory_offset)
+    - load_cache: set to True to load cache from save_estimation_path
+    - enable_bucket_ir: set to True to bucket all_gather/reduce_scatter
+    - enable_reorder_ir: set to True to reorder all_gather/reduce_scatter
+    """
+
+    relax_ratio = 0
+    peak_memory_offset = 0
+    load_cache = False
+    save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast.pkl"
+    enable_bucket_ir = True
+    enable_reorder_ir = True
+
+
 def parallelize_llama(
     model,
     world_mesh: DeviceMesh,
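For context (not part of this diff): the attributes above are plain class-level values, so a training script can override them before compilation. A minimal, hypothetical tuning sketch; the override values and the cache path below are illustrative, not recommendations from this commit:

    # Hypothetical tuning of simplefsdp's autobucketing config (illustrative values only).
    from torchtitan.experiments.auto_parallel.parallelize_llama import (
        simplefsdp_autobucketing_config,
    )

    # Allow ~10% extra compute time per bucket so more collectives fit into one bucket.
    simplefsdp_autobucketing_config.relax_ratio = 0.1
    # Leave the peak-memory budget unchanged.
    simplefsdp_autobucketing_config.peak_memory_offset = 0
    # After a first run has saved communication estimates, reuse them to shorten
    # bucketing decision time (the path below is a placeholder).
    simplefsdp_autobucketing_config.save_estimation_path = "/tmp/simplefsdp_estimation.pkl"
    simplefsdp_autobucketing_config.load_cache = True
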

torchtitan/train.py

Lines changed: 16 additions & 19 deletions
@@ -8,6 +8,7 @@
 import os
 import time
 from datetime import timedelta
+from functools import partial
 from typing import Any, Generator, Iterable, Optional
 
 import torch
@@ -125,32 +126,28 @@ def __init__(self, job_config: JobConfig):
 
         # allow configuring inductor comms optimizations from torchtitan commandline
         if job_config.experimental.enable_simplefsdp_passes:
-            try:
-                from torch._inductor.simple_fsdp.bucket import bucket_fsdp_all_gather_concat_on_scheduler_ir
-            except ImportError:
-                print("Must use pytorch from unlanded https://github.com/pytorch/pytorch/pull/160282, e.g. torchtitan_conda_prod:5e4101faa448c2ee6b62ddd76ee08e8c")
-                raise
+            # enable simplefsdp's autobucketing and reorder passes (original code in https://github.com/pytorch/pytorch/pull/160282)
+            from autoparallel.auto_bucketing import (
+                simple_fsdp_autobucketing_reordering_pass,
+            )
 
-            # Configs from Ruisi
+            from torchtitan.experiments.auto_parallel.parallelize_llama import (
+                simplefsdp_autobucketing_config,
+            )
 
-            # set to 0.1 if you want to make bucketing more efficient with mixed dtype collectives
-            torch._inductor.config.simplefsdp.relax_ratio = 0
             torch._inductor.config.allow_buffer_reuse = False
-            torch._inductor.config.simplefsdp.estimate_ir = False
-            torch._inductor.config.simplefsdp.estimate_verbose = False
-            torch._inductor.config.simplefsdp.save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast_"+job_config.model.flavor+".pkl"
-            # set to True after the first communication estimation results are saved. This would reduce decision making time.
-            torch._inductor.config.simplefsdp.load_cache = False
-            torch._inductor.config.simplefsdp.enable_bucket_ir = True
-            torch._inductor.config.simplefsdp.enable_reorder_ir = True
-            torch._inductor.config.simplefsdp.simplefsdp_only = False  # False for 2d True for 1d
-            torch._inductor.config.simplefsdp.peak_memory_offset = 0
-            torch._inductor.config.simplefsdp.bucketing_type = "auto"
+            torch._inductor.config.reorder_for_compute_comm_overlap = True
+            simple_fsdp_autobucketing_reordering_pass = partial(
+                simple_fsdp_autobucketing_reordering_pass,
+                configs=simplefsdp_autobucketing_config,
+            )
+            torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
+                simple_fsdp_autobucketing_reordering_pass
+            ]
 
             # Don't use both sets of passes at the same time!
             torch._inductor.config.bucket_all_gathers_fx = "none"
             torch._inductor.config.bucket_reduce_scatters_fx = "none"
-            torch._inductor.config.reorder_for_compute_comm_overlap = False
         else:
             torch._inductor.config.bucket_all_gathers_fx = (
                 job_config.experimental.bucket_all_gathers_fx
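Taken together, the train.py changes register simplefsdp's autobucketing pass with inductor's compute/communication overlap scheduler. A condensed sketch of the resulting wiring, assuming the autoparallel package providing auto_bucketing (from the PR linked above) is installed:

    from functools import partial

    import torch
    from autoparallel.auto_bucketing import simple_fsdp_autobucketing_reordering_pass

    from torchtitan.experiments.auto_parallel.parallelize_llama import (
        simplefsdp_autobucketing_config,
    )

    # Bind torchtitan's config to the pass and hand it to inductor as a
    # compute/communication reordering pass.
    torch._inductor.config.allow_buffer_reuse = False
    torch._inductor.config.reorder_for_compute_comm_overlap = True
    torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
        partial(
            simple_fsdp_autobucketing_reordering_pass,
            configs=simplefsdp_autobucketing_config,
        )
    ]

    # Disable the fx-level bucketing passes so both sets of passes are not
    # used at the same time.
    torch._inductor.config.bucket_all_gathers_fx = "none"
    torch._inductor.config.bucket_reduce_scatters_fx = "none"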
