This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 7b4d5c5

Integrate distributed inference without introducing abstraction

1 parent d7b681a

File tree

6 files changed: +622 −58 lines

torchchat/cli/builder.py

Lines changed: 121 additions & 2 deletions
@@ -14,10 +14,17 @@
 import torch
 import torch._dynamo.config
 import torch._inductor.config
-import torch.nn as nn
+import torch.distributed as dist
 
-from torchchat.model import Model, ModelArgs, ModelType
+from torchchat.distributed.utils import (
+    Color as color,
+    CUDATrackTime,
+    init_distributed,
+    GPUMemoryMonitor,
+)
+from torchchat.distributed.logging_utils import SingletonLogger
 
+from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
 from torchchat.utils.build_utils import (
     device_sync,
@@ -28,6 +35,7 @@
 from torchchat.utils.measure_time import measure_time
 from torchchat.utils.quantize import quantize_model
 
+
 from torchtune.models.convert_weights import meta_to_tune
 
 from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
@@ -598,6 +606,117 @@ def do_nothing(max_batch_size, max_seq_length):
             model = PTEModel(config, builder_args.pte_path)
         except Exception:
             raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
+    elif builder_args.distributed:
+        # Using params_table to identify the model to load, for example "Meta-Llama-3.1-8B".
+        # TODO: This is a hacky way to please the distributed loading API and needs to be replaced.
+        NAME_TO_DISTRIBUTION = {
+            "Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B-Instruct",
+            "Meta-Llama-3.1-8B": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+            "Meta-Llama-3-70B": "meta-llama/Meta-Llama-3-70B-Instruct",
+            "Meta-Llama-3.1-70B": "meta-llama/Meta-Llama-3.1-70B-Instruct",
+        }
+        # TODO: Use information in builder_args directly to build model and load weights
+        assert builder_args.params_table
+        try:
+            distribution = NAME_TO_DISTRIBUTION[builder_args.params_table]
+        except KeyError as e:
+            print(
+                f"Unknown params_table: {builder_args.params_table}. "
+                f"Supported model names are: {', '.join(NAME_TO_DISTRIBUTION.keys())}"
+            )
+            raise e
+
+        pp_degree = builder_args.pp
+        tp_degree = builder_args.tp
+
+        init_distributed()
+        rank = dist.get_rank()
+        torch.cuda.set_device(rank % torch.cuda.device_count())
+
+        logger = SingletonLogger.get_logger()
+
+        gpu_memory_monitor = GPUMemoryMonitor("cuda")
+        logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
+
+        # Model-level config
+        if builder_args.params_table:
+            model_config = ModelArgs.from_table(builder_args.params_table)
+        else:
+            raise NotImplementedError()
+        # Transformer-level config
+        config = TransformerArgs.from_params(model_config.transformer_args["text"])
+        logger.info(f"Transformer Config: {config}")
+
+        # TODO: Move into head of file after solving circular import
+        from torchchat.distributed.checkpoint_utils import (
+            load_model_weights,
+        )
+
+        # Validate pipeline degree
+        assert config.n_layers % pp_degree == 0
+
+        # Create device mesh
+        device_mesh = dist.init_device_mesh(
+            "cuda",
+            (pp_degree, tp_degree),
+            mesh_dim_names=("pp", "tp"),
+        )
+        tp_mesh = device_mesh["tp"]
+        pp_mesh = device_mesh["pp"]
+        logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
+
+        pp_rank = pp_mesh.get_local_rank()
+        logger.info(f"{pp_degree=}, {tp_degree=}")
+
+        # Assuming same number of GPUs per node
+        device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
+
+        # Fill in PP configs
+        config.stage_idx = pp_rank
+        config.n_stages = pp_degree
+
+        with torch.device("meta"):
+            # TODO: we should create model instead of Transformer
+            model = Transformer(config)
+
+        # Distribute model on TP mesh
+        # (Surprisingly, this works even though model is on meta device and mesh is of
+        # cuda devices)
+        model.distribute(tp_mesh)
+        if rank == 0:
+            logger.info(f"Model: {model}")
+
+        # Load weights
+        logger.info(f"Loading weights for {pp_rank=} on {device=}")
+        with CUDATrackTime() as timer:
+            load_model_weights(model, distribution, device, config, builder_args.chpt_from)
+
+        logger.info(
+            f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+        )
+
+        # Setup KV caches (after model distribution)
+        # The number of cache lanes is the same as the maximum number of
+        # micro-batches that can be "in flight" in parallel -- imagine each
+        # micro-batch takes 1 "pipeline lane"; they need distinct KV cache spaces.
+        # When decoding is done for certain micro-batches, we can reuse the KV cache
+        # lanes.
+        # TODO: bump up the lane count
+        pipeline_lanes = 1
+        seqlen_prefill = 1024
+        with device:
+            model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)
+
+        # info on stage size and params
+        # stage_size = get_module_size(model)
+        # stage_size_formatted = bytes_to_readable(stage_size)
+        # stage_num_params = get_num_params(model)
+        # logger.info(
+        #     f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}"
+        # )
+        model.eval()
+
+        model.text_transformer_args = None
+        model.config.model_type = model_config.model_type
+        model.device_mesh = device_mesh
     else:
         with measure_time("Time to load model: {time:.02f} seconds"):
            model = _load_model(builder_args)
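
Note on the device-mesh block above: init_device_mesh with mesh_dim_names=("pp", "tp") lays global ranks out row-major over the (pp_degree, tp_degree) grid, and the n_layers % pp_degree == 0 assertion suggests each pipeline stage owns an equal slice of layers. A minimal sketch of that rank-to-stage mapping, not part of the commit, assuming a 4-GPU world with pp=2 and tp=2:

    # Illustrative sketch (assumed pp=2, tp=2, 4 ranks), mirroring the
    # mesh_dim_names=("pp", "tp") layout used in builder.py.
    pp_degree, tp_degree = 2, 2
    n_layers = 32  # e.g. an 8B Llama 3 config; must be divisible by pp_degree

    for rank in range(pp_degree * tp_degree):
        pp_rank = rank // tp_degree   # which pipeline stage this rank serves
        tp_rank = rank % tp_degree    # position inside that stage's tensor-parallel group
        first = pp_rank * (n_layers // pp_degree)
        last = first + n_layers // pp_degree - 1
        print(f"rank {rank}: stage {pp_rank} (layers {first}-{last}), tp shard {tp_rank}")
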

torchchat/distributed/checkpoint_utils.py

Lines changed: 32 additions & 0 deletions
@@ -17,6 +17,7 @@
 from torch.distributed._tensor import DTensor
 from torchchat.distributed.dtensor_utils import convert_to_dtensor
 from torchchat.cli.builder import BuilderArgs, _load_checkpoint
+from torchchat.model import ModelArgs
 
 
 _DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"
@@ -450,3 +451,34 @@ def load_weights_from_torchchat_format(stage_module, distribution, device, model
     # Fill state dict into stage module
     stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)
     logger.info(f"Successfully loaded {len(updated_states)} weights into stage module")
+
+
+def load_model_weights(
+    stage_module: torch.nn.Module,
+    distribution: str,
+    device: torch.device,
+    model_config: ModelArgs,
+    chpt_from: str,
+):
+    """Load the weights from the safetensor file(s) into the model stage.
+    Model config is needed b/c we permute wq and wk weights based on attn heads.
+
+    Args:
+        stage_module (torch.nn.Module): The model stage to load the weights into.
+        distribution (str): The distribution name, e.g. "meta-llama/Meta-Llama-3-8B-Instruct".
+        device (torch.device): The device to load the weights onto.
+        model_config (ModelArgs): The model config.
+        chpt_from (str): The checkpoint format to load the weights from, e.g. "torchchat" or "hf".
+    """
+    if chpt_from == "hf":
+        # This format stands for: index file + multiple binary files
+        load_weights_from_hf_format(stage_module, distribution, device, model_config)
+    elif chpt_from == "torchchat":
+        # This format stands for:
+        # single binary file, OR
+        # multiple binary files without index files.
+        load_weights_from_torchchat_format(
+            stage_module, distribution, device, model_config
+        )
+    else:
+        raise ValueError(f"Unknown checkpoint format: {chpt_from}")
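
For context on the loading path above: the stage module arrives here still on the meta device (builder.py constructs it under torch.device("meta")), so load_state_dict(..., strict=False, assign=True) replaces the storage-less meta parameters with the loaded tensors rather than copying into them, and strict=False tolerates keys owned by other pipeline stages. A small self-contained sketch of that pattern with a toy module (illustrative only, not torchchat code):

    import torch
    import torch.nn as nn

    # Stand-in for a pipeline stage built on the meta device:
    # parameters exist but own no storage yet.
    with torch.device("meta"):
        stage = nn.Linear(4, 4)

    # Tensors as they might come out of a checkpoint shard.
    state_dict = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}

    # assign=True swaps in the loaded tensors (copying into meta storage would fail);
    # strict=False ignores keys that belong to other stages.
    stage.load_state_dict(state_dict, strict=False, assign=True)
    print(stage.weight.device)  # now a real (cpu) tensor
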

torchchat/distributed/utils.py

Lines changed: 12 additions & 2 deletions
@@ -6,15 +6,15 @@
 
 import itertools
 import os
+import time
 from dataclasses import dataclass
 from datetime import timedelta
-import time
+from os import environ
 from typing import Optional
 
 
 import torch
 
-
 from torchchat.distributed.logging_utils import SingletonLogger
 logger = SingletonLogger.get_logger()
 
@@ -257,3 +257,13 @@ def get_device_info(
             f"with {self.device_capacity_gib:.2f}GiB memory"
         )
         return device_info
+
+def setup_env(world_size: int, rank: int, target: callable, *args, **kwargs):
+    environ["MASTER_ADDR"] = "localhost"
+    environ["MASTER_PORT"] = "29500"
+    environ["RDZV_BACKEND"] = "c10d"
+    environ["WORLD_SIZE"] = str(world_size)
+    environ["RANK"] = str(rank)
+    environ["LOCALRANK"] = str(rank)
+
+    return target(*args, **kwargs)
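
setup_env above just seeds the rendezvous environment variables and then invokes the given target, which makes it convenient for single-node smoke tests launched with torch.multiprocessing.spawn. A hedged usage sketch, not part of this commit, assuming two local processes and the gloo backend (the _worker and _run helpers are illustrative names):

    import torch.distributed as dist
    import torch.multiprocessing as mp

    from torchchat.distributed.utils import setup_env


    def _run() -> None:
        # MASTER_ADDR/PORT, WORLD_SIZE, and RANK are already set by setup_env,
        # so the default env:// rendezvous needs no extra arguments.
        dist.init_process_group(backend="gloo")
        print(f"rank {dist.get_rank()} of {dist.get_world_size()} is up")
        dist.destroy_process_group()


    def _worker(rank: int, world_size: int) -> None:
        # setup_env fills the environment for this rank, then calls the target.
        setup_env(world_size, rank, _run)


    if __name__ == "__main__":
        world_size = 2
        mp.spawn(_worker, args=(world_size,), nprocs=world_size)
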
