Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 4774eaf

Browse files
authored
[Distributed] Follow upstream TransformerArgs changes (#1161)
1 parent 0f2849b commit 4774eaf

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

dist_run.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@
3333
)
3434
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
3535
from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
36-
from torchchat.model import ModelArgs, Transformer
36+
from torchchat.model import ModelArgs, Transformer, TransformerArgs
3737
from torchchat.utils.build_utils import set_precision
3838

3939
try:
@@ -239,8 +239,11 @@ def main(args):
239239
distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
240240
logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")
241241

242-
config = ModelArgs.from_name(distribution).transformer_args["text"]
243-
logger.info(f"Chat Model Config: {config}")
242+
# Model-level config
243+
model_config = ModelArgs.from_name(distribution)
244+
# Transformer-level config
245+
config = TransformerArgs.from_params(model_config.transformer_args["text"])
246+
logger.info(f"Transformer Config: {config}")
244247

245248
tokenizer = _build_chat_tokenizer(model_name)
246249

@@ -282,6 +285,7 @@ def main(args):
282285
config.n_stages = pp_degree
283286

284287
with device:
288+
# TODO: we should create model instead of Transformer
285289
model = Transformer(config)
286290

287291
# Distribute model on TP mesh

0 commit comments

Comments (0)