@@ -8,6 +8,7 @@
 import os
 import time
 from datetime import timedelta
+from functools import partial
 from typing import Any, Generator, Iterable, Optional
 
 import torch
@@ -125,32 +126,24 @@ def __init__(self, job_config: JobConfig):
 
         # allow configuring inductor comms optimizations from torchtitan commandline
         if job_config.experimental.enable_simplefsdp_passes:
-            try:
-                from torch._inductor.simple_fsdp.bucket import bucket_fsdp_all_gather_concat_on_scheduler_ir
-            except ImportError:
-                print("Must use pytorch from unlanded https://github.com/pytorch/pytorch/pull/160282, e.g. torchtitan_conda_prod:5e4101faa448c2ee6b62ddd76ee08e8c")
-                raise
-
-            # Configs from Ruisi
+            # enable simplefsdp's autobucketing and reorder passes (original code in https://github.com/pytorch/pytorch/pull/160282)
+            from autoparallel.auto_bucketing import (
+                simple_fsdp_autobucketing_reordering_pass, simplefsdp_autobucketing_config
+            )
 
-            # set to 0.1 if you want to make bucketing more efficient with mixed dtype collectives
-            torch._inductor.config.simplefsdp.relax_ratio = 0
             torch._inductor.config.allow_buffer_reuse = False
-            torch._inductor.config.simplefsdp.estimate_ir = False
-            torch._inductor.config.simplefsdp.estimate_verbose = False
-            torch._inductor.config.simplefsdp.save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast_"+job_config.model.flavor+".pkl"
-            # set to True after the first communication estimation results are saved. This would reduce decision making time.
-            torch._inductor.config.simplefsdp.load_cache = False
-            torch._inductor.config.simplefsdp.enable_bucket_ir = True
-            torch._inductor.config.simplefsdp.enable_reorder_ir = True
-            torch._inductor.config.simplefsdp.simplefsdp_only = False # False for 2d True for 1d
-            torch._inductor.config.simplefsdp.peak_memory_offset = 0
-            torch._inductor.config.simplefsdp.bucketing_type = "auto"
+            torch._inductor.config.reorder_for_compute_comm_overlap = True
+            simple_fsdp_autobucketing_reordering_pass = partial(
+                simple_fsdp_autobucketing_reordering_pass,
+                configs=simplefsdp_autobucketing_config,
+            )
+            torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
+                simple_fsdp_autobucketing_reordering_pass
+            ]
 
             # Don't use both sets of passes at the same time!
             torch._inductor.config.bucket_all_gathers_fx = "none"
             torch._inductor.config.bucket_reduce_scatters_fx = "none"
-            torch._inductor.config.reorder_for_compute_comm_overlap = False
         else:
             torch._inductor.config.bucket_all_gathers_fx = (
                 job_config.experimental.bucket_all_gathers_fx
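For context on the hook this change relies on, below is a minimal standalone sketch of how inductor picks up a custom compute/comm reordering pass. It is an illustration, not code from this PR: it assumes that entries in `reorder_for_compute_comm_overlap_passes` follow the signature of the built-in passes in `torch/_inductor/comms.py` (a list of scheduler nodes in, a reordered list out, which may vary across PyTorch versions), and `noop_reordering_pass` with its `configs` dict is a hypothetical stand-in for `simple_fsdp_autobucketing_reordering_pass`.

```python
# Minimal sketch (not from this PR) of registering a custom inductor
# compute/comm reordering pass. Pass signature assumed from the built-in
# passes in torch/_inductor/comms.py; it may differ across PyTorch versions.
from functools import partial

import torch
import torch._inductor.config  # make the inductor config module available


def noop_reordering_pass(snodes, configs=None):
    # Hypothetical stand-in for simple_fsdp_autobucketing_reordering_pass:
    # a real pass would bucket collectives and reorder `snodes` for
    # compute/communication overlap, guided by `configs`.
    return snodes


# Inductor only applies the registered pass list when this flag is enabled.
torch._inductor.config.reorder_for_compute_comm_overlap = True

# Bind configuration up front (as the PR does with functools.partial) so
# inductor can call the pass with its standard argument(s) later.
torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
    partial(noop_reordering_pass, configs={"max_bucket_mb": 256}),  # hypothetical config
]
```

Binding the config with `functools.partial` at registration time keeps the callable compatible with whatever arguments inductor supplies when the pass list runs during compilation.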