Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 13 additions & 20 deletions torchtitan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import time
from datetime import timedelta
from functools import partial
from typing import Any, Generator, Iterable, Optional

import torch
Expand Down Expand Up @@ -125,32 +126,24 @@ def __init__(self, job_config: JobConfig):

# allow configuring inductor comms optimizations from torchtitan commandline
if job_config.experimental.enable_simplefsdp_passes:
try:
from torch._inductor.simple_fsdp.bucket import bucket_fsdp_all_gather_concat_on_scheduler_ir
except ImportError:
print("Must use pytorch from unlanded https://github.com/pytorch/pytorch/pull/160282, e.g. torchtitan_conda_prod:5e4101faa448c2ee6b62ddd76ee08e8c")
raise

# Configs from Ruisi
# enable simplefsdp's autobucketing and reorder passes (original code in https://github.com/pytorch/pytorch/pull/160282)
from autoparallel.auto_bucketing import (
simple_fsdp_autobucketing_reordering_pass, simplefsdp_autobucketing_config
)

# set to 0.1 if you want to make bucketing more efficient with mixed dtype collectives
torch._inductor.config.simplefsdp.relax_ratio = 0
torch._inductor.config.allow_buffer_reuse = False
torch._inductor.config.simplefsdp.estimate_ir = False
torch._inductor.config.simplefsdp.estimate_verbose = False
torch._inductor.config.simplefsdp.save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast_"+job_config.model.flavor+".pkl"
# set to True after the first communication estimation results are saved. This would reduce decision making time.
torch._inductor.config.simplefsdp.load_cache = False
torch._inductor.config.simplefsdp.enable_bucket_ir = True
torch._inductor.config.simplefsdp.enable_reorder_ir = True
torch._inductor.config.simplefsdp.simplefsdp_only = False # False for 2d True for 1d
torch._inductor.config.simplefsdp.peak_memory_offset = 0
torch._inductor.config.simplefsdp.bucketing_type = "auto"
torch._inductor.config.reorder_for_compute_comm_overlap = True
simple_fsdp_autobucketing_reordering_pass = partial(
simple_fsdp_autobucketing_reordering_pass,
configs=simplefsdp_autobucketing_config,
)
torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
simple_fsdp_autobucketing_reordering_pass
]

# Don't use both sets of passes at the same time!
torch._inductor.config.bucket_all_gathers_fx = "none"
torch._inductor.config.bucket_reduce_scatters_fx = "none"
torch._inductor.config.reorder_for_compute_comm_overlap = False
else:
torch._inductor.config.bucket_all_gathers_fx = (
job_config.experimental.bucket_all_gathers_fx
Expand Down
Loading