From 42ab5e259365a99f62ff81bb1ee54ca83291e9ad Mon Sep 17 00:00:00 2001 From: ruisizhang123 Date: Thu, 28 Aug 2025 18:02:35 -0700 Subject: [PATCH] add simplefsdp's autobucketing pass entry --- torchtitan/train.py | 33 +++++++++++++-------------------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index e4da1b1f6..8269fad24 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -8,6 +8,7 @@ import os import time from datetime import timedelta +from functools import partial from typing import Any, Generator, Iterable, Optional import torch @@ -125,32 +126,24 @@ def __init__(self, job_config: JobConfig): # allow configuring inductor comms optimizations from torchtitan commandline if job_config.experimental.enable_simplefsdp_passes: - try: - from torch._inductor.simple_fsdp.bucket import bucket_fsdp_all_gather_concat_on_scheduler_ir - except ImportError: - print("Must use pytorch from unlanded https://github.com/pytorch/pytorch/pull/160282, e.g. torchtitan_conda_prod:5e4101faa448c2ee6b62ddd76ee08e8c") - raise - - # Configs from Ruisi + # enable simplefsdp's autobucketing and reorder passes (original code in https://github.com/pytorch/pytorch/pull/160282) + from autoparallel.auto_bucketing import ( + simple_fsdp_autobucketing_reordering_pass, simplefsdp_autobucketing_config + ) - # set to 0.1 if you want to make bucketing more efficient with mixed dtype collectives - torch._inductor.config.simplefsdp.relax_ratio = 0 torch._inductor.config.allow_buffer_reuse = False - torch._inductor.config.simplefsdp.estimate_ir = False - torch._inductor.config.simplefsdp.estimate_verbose = False - torch._inductor.config.simplefsdp.save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast_"+job_config.model.flavor+".pkl" - # set to True after the first communication estimation results are saved. This would reduce decision making time. - torch._inductor.config.simplefsdp.load_cache = False - torch._inductor.config.simplefsdp.enable_bucket_ir = True - torch._inductor.config.simplefsdp.enable_reorder_ir = True - torch._inductor.config.simplefsdp.simplefsdp_only = False # False for 2d True for 1d - torch._inductor.config.simplefsdp.peak_memory_offset = 0 - torch._inductor.config.simplefsdp.bucketing_type = "auto" + torch._inductor.config.reorder_for_compute_comm_overlap = True + simple_fsdp_autobucketing_reordering_pass = partial( + simple_fsdp_autobucketing_reordering_pass, + configs=simplefsdp_autobucketing_config, + ) + torch._inductor.config.reorder_for_compute_comm_overlap_passes = [ + simple_fsdp_autobucketing_reordering_pass + ] # Don't use both sets of passes at the same time! torch._inductor.config.bucket_all_gathers_fx = "none" torch._inductor.config.bucket_reduce_scatters_fx = "none" - torch._inductor.config.reorder_for_compute_comm_overlap = False else: torch._inductor.config.bucket_all_gathers_fx = ( job_config.experimental.bucket_all_gathers_fx