Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 11f29fc

Browse files
committed
Add initial generator interface to dist inference
1 parent 1faa052 commit 11f29fc

File tree

5 files changed

+247
-183
lines changed

5 files changed

+247
-183
lines changed

torchchat/cli/builder.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -16,12 +16,9 @@
1616
import torch._inductor.config
1717
import torch.nn as nn
1818

19-
from torch.distributed import launcher
20-
2119
from torch.distributed.device_mesh import DeviceMesh
2220
from torch.distributed.elastic.multiprocessing.errors import record
2321
from torch.distributed.elastic.utils.distributed import get_free_port
24-
from torch.distributed.launcher.api import elastic_launch
2522

2623
from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
2724

@@ -65,6 +62,8 @@ class BuilderArgs:
6562
num_nodes: int = 1
6663
pp: int = 1
6764
tp: int = 1
65+
chpt_from: str = "hf"
66+
ntokens: int = 40
6867
is_chat_model: bool = False
6968
prefill_possible: bool = False
7069
dynamic_shapes: bool = False
@@ -171,6 +170,8 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
171170
num_nodes = getattr(args, "num_nodes", 1)
172171
pp = getattr(args, "pp", 1)
173172
tp = getattr(args, "tp", 1)
173+
chpt_from = getattr(args, "chpt_from", "hf")
174+
ntokens = getattr(args, "ntokens", 40)
174175
return cls(
175176
checkpoint_dir=checkpoint_dir,
176177
checkpoint_path=checkpoint_path,
@@ -189,6 +190,8 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
189190
num_nodes=num_nodes,
190191
pp=pp,
191192
tp=tp,
193+
chpt_from=chpt_from,
194+
ntokens=ntokens,
192195
is_chat_model=is_chat_model,
193196
dynamic_shapes=getattr(args, "dynamic_shapes", False),
194197
max_seq_length=getattr(args, "max_seq_length", None),
@@ -508,7 +511,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
508511

509512
model = model.to(device=builder_args.device, dtype=builder_args.precision)
510513
return model.eval()
511-
514+
512515

513516
def _initialize_model(
514517
builder_args: BuilderArgs,

torchchat/cli/cli.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -426,6 +426,20 @@ def _add_distributed_args(parser) -> None:
426426
# "Tensor parallel degree",
427427
)
428428

429+
parser.add_argument(
430+
"--ntokens",
431+
type=int,
432+
default=40,
433+
help="Number of tokens to generate",
434+
)
435+
parser.add_argument(
436+
"--chpt-from",
437+
type=str,
438+
default="hf", # TODO: change to torchchat once we support it well
439+
help="Checkpoint format to load from",
440+
choices=["hf", "torchchat"],
441+
)
442+
429443

430444
# Add CLI Args related to custom model inputs
431445
def _add_custom_model_args(parser) -> None:

0 commit comments

Comments
 (0)