Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 4b8cdcb

Browse files
committed
Re-enable args.model_name in dist_run.py
1 parent 99606ab commit 4b8cdcb

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

dist_run.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -306,7 +306,7 @@ def _cleanup():
306306

307307

308308
def main(args):
309-
model_name = "llama3" # args.model_name
309+
model_name = args.model_name
310310
pp_degree = args.pp
311311

312312
rank, world_size = _init_distributed()
@@ -592,14 +592,14 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
592592

593593
if __name__ == "__main__":
594594
parser = argparse.ArgumentParser()
595-
"""parser.add_argument(
595+
parser.add_argument(
596596
"model_name",
597597
type=str,
598598
default="llama3",
599599
help="Name of the model to load",
600-
# choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
600+
choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
601601
)
602-
"""
602+
603603
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
604604
parser.add_argument(
605605
"--ntokens",

0 commit comments

Comments
 (0)