@@ -124,21 +124,49 @@ def __init__(self, job_config: JobConfig):
         torch._inductor.config.allow_buffer_reuse = False
 
         # allow configuring inductor comms optimizations from torchtitan commandline
-        torch._inductor.config.bucket_all_gathers_fx = (
-            job_config.experimental.bucket_all_gathers_fx
-        )
-        torch._inductor.config.bucket_reduce_scatters_fx = (
-            job_config.experimental.bucket_reduce_scatters_fx
-        )
-        torch._inductor.config.reorder_for_compute_comm_overlap = (
-            job_config.experimental.reorder_for_compute_comm_overlap
-        )
-        torch._inductor.config.reorder_for_compute_comm_overlap_passes = (
-            job_config.experimental.reorder_for_compute_comm_overlap_passes
-        )
-        torch._inductor.config.reorder_prefetch_limit = (
-            job_config.experimental.reorder_prefetch_limit
-        )
+        if job_config.experimental.enable_simplefsdp_passes:
+            try:
+                from torch._inductor.simple_fsdp.bucket import bucket_fsdp_all_gather_concat_on_scheduler_ir
+            except ImportError:
+                print("Must use pytorch from unlanded https://github.com/pytorch/pytorch/pull/160282, e.g. torchtitan_conda_prod:5e4101faa448c2ee6b62ddd76ee08e8c")
+                raise
+
+            # Configs from Ruisi
+
+            # set to 0.1 to make bucketing more efficient with mixed-dtype collectives
+            torch._inductor.config.simplefsdp.relax_ratio = 0
+            torch._inductor.config.allow_buffer_reuse = False
+            torch._inductor.config.simplefsdp.estimate_ir = False
+            torch._inductor.config.simplefsdp.estimate_verbose = False
+            torch._inductor.config.simplefsdp.save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast_" + job_config.model.flavor + ".pkl"
+            # set to True after the first communication estimation results are saved; this reduces decision-making time
+            torch._inductor.config.simplefsdp.load_cache = False
+            torch._inductor.config.simplefsdp.enable_bucket_ir = True
+            torch._inductor.config.simplefsdp.enable_reorder_ir = True
+            torch._inductor.config.simplefsdp.simplefsdp_only = True  # True for 1D (FSDP-only), False for 2D parallelism
+            torch._inductor.config.simplefsdp.peak_memory_offset = 0
+            torch._inductor.config.simplefsdp.bucketing_type = "auto"
+
+            # Don't use both sets of passes at the same time!
+            torch._inductor.config.bucket_all_gathers_fx = "none"
+            torch._inductor.config.bucket_reduce_scatters_fx = "none"
+            torch._inductor.config.reorder_for_compute_comm_overlap = False
+        else:
+            torch._inductor.config.bucket_all_gathers_fx = (
+                job_config.experimental.bucket_all_gathers_fx
+            )
+            torch._inductor.config.bucket_reduce_scatters_fx = (
+                job_config.experimental.bucket_reduce_scatters_fx
+            )
+            torch._inductor.config.reorder_for_compute_comm_overlap = (
+                job_config.experimental.reorder_for_compute_comm_overlap
+            )
+            torch._inductor.config.reorder_for_compute_comm_overlap_passes = (
+                job_config.experimental.reorder_for_compute_comm_overlap_passes
+            )
+            torch._inductor.config.reorder_prefetch_limit = (
+                job_config.experimental.reorder_prefetch_limit
+            )
 
         # Set random seed, and maybe enable deterministic mode
         # (mainly for debugging, expect perf loss).
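
The import guard above hard-fails when the interpreter lacks the unlanded `torch._inductor.simple_fsdp` module. A launcher that wants to degrade gracefully instead of crashing can probe for the module first. This is a minimal sketch, not part of the PR: `simplefsdp_available` is a hypothetical helper, and the module path is assumed from the import in the diff.

```python
import importlib.util

def simplefsdp_available() -> bool:
    # find_spec returns None when the submodule is missing (it is only
    # present in builds that include pytorch/pytorch#160282), so callers
    # can branch on support without triggering an ImportError.
    return importlib.util.find_spec("torch._inductor.simple_fsdp") is not None
```

A launcher could then fall back to the fx-pass branch when the module is absent, mirroring the `else` arm of the diff.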
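The `save_estimation_path` / `load_cache` pair implies a two-phase workflow: the first run measures collective latencies and saves them to the pickle, and later runs reload that file to shorten bucketing decisions. Below is a hedged sketch of automating the flip, assuming a PyTorch build that ships the `simplefsdp` config namespace from the PR; `configure_estimation_cache` is a hypothetical helper.

```python
import os

import torch

def configure_estimation_cache(path: str) -> None:
    # Hypothetical helper: reuse saved communication estimates when the
    # pickle from a previous run exists, otherwise re-estimate and save.
    # Assumes torch._inductor.config.simplefsdp exists (unlanded PR #160282).
    torch._inductor.config.simplefsdp.save_estimation_path = path
    torch._inductor.config.simplefsdp.load_cache = os.path.exists(path)
```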