This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 1faa052

[WIP] Move distributed inference into its own generator
1 parent e8bb076 commit 1faa052

File tree

5 files changed (+179, −86 lines)

torchchat/cli/builder.py

Lines changed: 1 addition & 71 deletions
@@ -508,72 +508,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
 
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
-
-
-import importlib.util
-import subprocess
-
-
-def run_script(script_path, *args):
-    # Construct the command to run the script
-    cmd = [sys.executable, script_path] + list(args)
-
-    # Run the script as a subprocess
-    process = subprocess.Popen(
-        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True
-    )
-
-    # Stream the output in real-time
-    for line in process.stdout:
-        print(line, end="")
-    for line in process.stderr:
-        print(line, end="", file=sys.stderr)
-
-    # Wait for the process to complete and get the return code
-    return_code = process.wait()
-    if return_code != 0:
-        raise subprocess.CalledProcessError(return_code, cmd)
-
-
-def _launch_distributed_inference(builder_args: BuilderArgs) -> None:
-    # create programmatic elastic launch
-    print("Launching distributed inference ...")
-
-    num_processes_per_node = 4  # builder_args.num_gpus + 1
-
-    lc = launcher.LaunchConfig(
-        min_nodes=1,
-        max_nodes=1,
-        nproc_per_node=num_processes_per_node,
-        # run_id=str(uuid.uuid4()),
-        rdzv_backend="c10d",
-        rdzv_endpoint="localhost:29401",
-        max_restarts=0,
-        monitor_interval=1,
-    )
-
-    train_file_path = Path(__file__).parent.parent.parent / "dist_run.py"
-    print(f"train_file_path: {train_file_path}")
-    # import argparse
-
-    # parser2 = argparse.ArgumentParser()
-
-    # args = parser2.parse_args()
-    args = []
-    print(f"args: {args}")
-
-    elastic_launch(
-        config=lc,
-        entrypoint=run_script,
-    )(train_file_path, *args)
-    print(
-        f"Done launching distributed inference on **4 ** {builder_args.num_gpus} GPUs."
-    )
-    # role=role, *args, **kwargs)
-
-    # assert False, "distributed inference is not supported yet"
-    # pass
-
+
 
 def _initialize_model(
     builder_args: BuilderArgs,

@@ -583,11 +518,6 @@ def _initialize_model(
     support_tensor_subclass: bool = True,
 ) -> Model:
     print("Loading model...")
-    if builder_args.distributed:
-        # we part ways here with torchchat cli and move into dist inference
-        _launch_distributed_inference(builder_args)
-        return None
-
     if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
         print("Setting gguf_kwargs for generate.")
         is_dso = builder_args.dso_path is not None

torchchat/cli/cli.py

Lines changed: 16 additions & 0 deletions
@@ -409,6 +409,22 @@ def _add_distributed_args(parser) -> None:
         help=argparse.SUPPRESS,
         # "Use the specified model checkpoint directory",
     )
+    parser.add_argument(
+        "--pp",
+        "--pipeline-parallel",
+        type=int,
+        default=1,
+        help=argparse.SUPPRESS,
+        # "Pipeline parallel degree",
+    )
+    parser.add_argument(
+        "--tp",
+        "--tensor-parallel",
+        type=int,
+        default=1,
+        help=argparse.SUPPRESS,
+        # "Tensor parallel degree",
+    )
 
 
 # Add CLI Args related to custom model inputs
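Together, the two new flags describe a two-dimensional parallelism layout: --pp pipeline stages times --tp tensor-parallel ranks per stage. As a point of reference only, such degrees are commonly turned into a device mesh as in the sketch below; the mesh construction is an assumption for illustration, not code from this commit.

# Illustrative sketch (assumed, not part of this commit): map the --pp/--tp
# degrees onto a 2-D device mesh. Requires a launcher that sets WORLD_SIZE
# (e.g. torchrun) with exactly pp * tp ranks.
import os

from torch.distributed.device_mesh import init_device_mesh


def build_mesh(pp: int, tp: int):
    world_size = int(os.environ["WORLD_SIZE"])  # set by torchrun / elastic agent
    assert pp * tp == world_size, "pp * tp must equal the number of ranks"
    mesh = init_device_mesh("cuda", (pp, tp), mesh_dim_names=("pp", "tp"))
    # 1-D sub-meshes, one per parallelism dimension
    return mesh["pp"], mesh["tp"]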
File renamed without changes.

torchchat/distributed/generate.py

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from abc import abstractmethod
+from typing import List, Optional
+from dataclasses import dataclass
+from pathlib import Path
+from torchchat.cli.builder import BuilderArgs, TokenizerArgs
+
+
+import importlib.util
+import subprocess
+
+
+def run_script(script_path, *args):
+    # Construct the command to run the script
+    cmd = [sys.executable, script_path] + list(args)
+
+    # Run the script as a subprocess
+    process = subprocess.Popen(
+        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True
+    )
+
+    # Stream the output in real-time
+    for line in process.stdout:
+        print(line, end="")
+    for line in process.stderr:
+        print(line, end="", file=sys.stderr)
+
+    # Wait for the process to complete and get the return code
+    return_code = process.wait()
+    if return_code != 0:
+        raise subprocess.CalledProcessError(return_code, cmd)
+
+
+def _launch_distributed_inference(builder_args: BuilderArgs) -> None:
+    # create programmatic elastic launch
+    print("Launching distributed inference ...")
+
+    num_processes_per_node = 4  # builder_args.num_gpus + 1
+
+    lc = launcher.LaunchConfig(
+        min_nodes=1,
+        max_nodes=1,
+        nproc_per_node=num_processes_per_node,
+        # run_id=str(uuid.uuid4()),
+        rdzv_backend="c10d",
+        rdzv_endpoint="localhost:29401",
+        max_restarts=0,
+        monitor_interval=1,
+    )
+
+    # train_file_path = Path(__file__).parent.parent.parent / "dist_run.py"
+    # print(f"train_file_path: {train_file_path}")
+    # import argparse
+
+    # parser2 = argparse.ArgumentParser()
+
+    # args = parser2.parse_args()
+    args = []
+    print(f"args: {args}")
+
+    from dist_run import main
+
+    elastic_launch(
+        config=lc,
+        entrypoint=run_script,
+    )(main, *args)
+    print(
+        f"Done launching distributed inference on **4 ** {builder_args.num_gpus} GPUs."
+    )
+    # role=role, *args, **kwargs)
+
+    # assert False, "distributed inference is not supported yet"
+    # pass
+
+@dataclass
+class Output:
+    request_id: int
+    is_finished: bool = False
+    output: Optional[str] = None
+
+class Generator(object):
+
+    @abstractmethod
+    def add_request(self, request_id: int, prompt: str):
+        raise NotImplementedError()
+
+    def step(self) -> List[Output]:
+        raise NotImplementedError()
+
+
+class DistributedGenerator(Generator):
+    def __init__(
+        self,
+        builder_args: BuilderArgs,
+        speculative_builder_args: BuilderArgs,
+        tokenizer_args: TokenizerArgs,
+        #TODO: move GeneratorArgs into a different module
+        # generator_args: GeneratorArgs,
+        profile: Optional[Path],
+        quantize: bool,
+        draft_quantize: bool,
+    ):
+        self.requests = {}
+        # if builder_args.distributed:
+        #     # we part ways here with torchchat cli and move into dist inference
+        _launch_distributed_inference(builder_args)
+        #     return None
+
+
+    def add_request(self, request_id: int, prompt: str):
+        assert request_id not in self.requests
+        self.requests[request_id] = prompt
+
+
+    def step(self) -> List[Output]:
+        outputs = []
+        for request_id, prompt in self.requests.items():
+            outputs.append(Output(request_id, is_finished=True, output=prompt))
+
+        for output in outputs:
+            if output.is_finished:
+                del self.requests[output.request_id]
+
+        return outputs
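As committed, the new module still references sys, launcher, and elastic_launch without importing them, and the elastic entrypoint is handed the main function even though run_script expects a script path. Below is a self-contained sketch of the launch this is building toward, using the stock torch.distributed.launcher API; the imports, the dist_run.py path, and the simplified run_script are assumptions for illustration, not part of the commit.

# Minimal sketch (assumptions, not from this commit) of a programmatic elastic
# launch that re-executes dist_run.py once per local rank.
import subprocess
import sys
from pathlib import Path

from torch.distributed import launcher  # provides launcher.LaunchConfig
from torch.distributed.launcher import elastic_launch


def run_script(script_path, *args):
    # Re-run the target script with the current interpreter; raise on failure.
    subprocess.run([sys.executable, str(script_path), *args], check=True)


def launch_dist_run(nproc_per_node: int = 4) -> None:
    lc = launcher.LaunchConfig(
        min_nodes=1,
        max_nodes=1,
        nproc_per_node=nproc_per_node,
        rdzv_backend="c10d",
        rdzv_endpoint="localhost:29401",
        max_restarts=0,
        monitor_interval=1,
    )
    # Path assumed: dist_run.py at the repository root, three levels up.
    script_path = Path(__file__).parent.parent.parent / "dist_run.py"
    # elastic_launch runs the entrypoint once per local rank, with rendezvous
    # environment variables (RANK, WORLD_SIZE, MASTER_ADDR, ...) already set.
    elastic_launch(config=lc, entrypoint=run_script)(str(script_path))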

torchchat/generate.py

Lines changed: 34 additions & 15 deletions
@@ -31,6 +31,7 @@
     TokenizerArgs,
 )
 from torchchat.model import Model, ModelType
+from torchchat.distributed.generate import DistributedGenerator
 from torchchat.utils.build_utils import device_sync, set_precision
 from torchchat.utils.device_info import get_device_info
 

@@ -1215,19 +1216,37 @@ def main(args):
     speculative_builder_args = BuilderArgs.from_speculative_args(args)
     tokenizer_args = TokenizerArgs.from_args(args)
     generator_args = GeneratorArgs.from_args(args)
-    gen = Generator(
-        builder_args,
-        speculative_builder_args,
-        tokenizer_args,
-        generator_args,
-        args.profile,
-        args.quantize,
-        args.draft_quantize,
-    )
-    if torch.cuda.is_available():
-        torch.cuda.reset_peak_memory_stats()
-    if builder_args.distributed:
+    if not builder_args.distributed:
+        gen = Generator(
+            builder_args,
+            speculative_builder_args,
+            tokenizer_args,
+            generator_args,
+            args.profile,
+            args.quantize,
+            args.draft_quantize,
+        )
+        if torch.cuda.is_available():
+            torch.cuda.reset_peak_memory_stats()
+
+
+        for _ in gen.chat(generator_args):
+            pass
+    else:
+        dist_gen = DistributedGenerator(
+            builder_args,
+            speculative_builder_args,
+            tokenizer_args,
+            # generator_args,
+            args.profile,
+            args.quantize,
+            args.draft_quantize,
+        )
+
+        dist_gen.add_request(0, "Tell me a joke")
+        dist_gen.add_request(1, "Tell me another joke")
 
-        return
-    for _ in gen.chat(generator_args):
-        pass
+        outputs = dist_gen.step()
+        while len(outputs):
+            print(outputs)
+            outputs = dist_gen.step()
