Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit d5bca9b

Browse files
committed
Read distribution from model_config
1 parent 08a8e03 commit d5bca9b

File tree

1 file changed

+5
-18
lines changed

1 file changed

+5
-18
lines changed

torchchat/cli/builder.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class BuilderArgs:
6464
pp: int = 1
6565
tp: int = 1
6666
chpt_from: str = "hf"
67+
distribution_path: Optional[str] = None
6768
is_chat_model: bool = False
6869
prefill_possible: bool = False
6970
dynamic_shapes: bool = False
@@ -129,6 +130,8 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
129130
model_config.transformer_params_key or model_config.name.split("/")[-1]
130131
)
131132

133+
distribution_path = model_config.distribution_path
134+
132135
dso_path = getattr(args, "dso_path", None)
133136
pte_path = getattr(args, "pte_path", None)
134137
aoti_package_path = getattr(args, "aoti_package_path", None)
@@ -194,6 +197,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
194197
pp=pp,
195198
tp=tp,
196199
chpt_from=chpt_from,
200+
distribution_path=distribution_path,
197201
is_chat_model=is_chat_model,
198202
dynamic_shapes=getattr(args, "dynamic_shapes", False),
199203
max_seq_length=getattr(args, "max_seq_length", None),
@@ -607,23 +611,6 @@ def do_nothing(max_batch_size, max_seq_length):
607611
except Exception:
608612
raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
609613
elif builder_args.distributed:
610-
# Using params_table to identify the model to load, for example "Meta-Llama-3.1-8B".
611-
#TODO This is a hacky way to please the distributed loading api and needs to be replaced
612-
NAME_TO_DISTRIBUTION = {
613-
"Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B-Instruct",
614-
"Meta-Llama-3.1-8B": "meta-llama/Meta-Llama-3.1-8B-Instruct",
615-
"Meta-Llama-3-70B": "meta-llama/Meta-Llama-3-70B-Instruct",
616-
"Meta-Llama-3.1-70B": "meta-llama/Meta-Llama-3.1-70B-Instruct",
617-
618-
}
619-
# TODO: Use information in builder_args directly to build model and load weights
620-
assert builder_args.params_table
621-
try:
622-
distribution = NAME_TO_DISTRIBUTION[builder_args.params_table]
623-
except KeyError as e:
624-
print(f"Unknown params_table: {builder_args.params_table}. Supported model names are: llama3.1, llama3, llama2-7b-chat")
625-
raise e
626-
627614
pp_degree = builder_args.pp
628615
tp_degree = builder_args.tp
629616

@@ -687,7 +674,7 @@ def do_nothing(max_batch_size, max_seq_length):
687674
# Load weights
688675
logger.info(f"Loading weights for {pp_rank=} on {device=}")
689676
with CUDATrackTime() as timer:
690-
load_model_weights(model, distribution, device, config, builder_args.chpt_from)
677+
load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
691678

692679
logger.info(
693680
f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"

0 commit comments

Comments
 (0)