
Commit 3836928

use prompt parameter for dist generation

1 parent adcf232 · commit 3836928

File tree: 4 files changed (+44, -39 lines)

torchchat/cli/builder.py

Lines changed: 0 additions & 6 deletions
@@ -58,8 +58,6 @@ class BuilderArgs:
     precision: torch.dtype = torch.float32
     setup_caches: bool = False
     distributed: bool = False
-    num_gpus: int = 1
-    num_nodes: int = 1
     pp: int = 1
     tp: int = 1
     chpt_from: str = "hf"
@@ -165,8 +163,6 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         dtype = name_to_dtype(args.dtype, args.device)
         # distributed args
         distributed = getattr(args, "distributed", False)
-        num_gpus = getattr(args, "num_gpus", 1)
-        num_nodes = getattr(args, "num_nodes", 1)
         pp = getattr(args, "pp", 1)
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
@@ -184,8 +180,6 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             precision=dtype,
             setup_caches=(output_dso_path or output_pte_path),
             distributed=distributed,
-            num_gpus=num_gpus,
-            num_nodes=num_nodes,
             pp=pp,
             tp=tp,
             chpt_from=chpt_from,
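
With num_gpus and num_nodes removed from BuilderArgs, the node-local process count now follows from the pipeline- and tensor-parallel degrees (see the pp * tp change in torchchat/distributed/generate.py below). A minimal sketch of that relationship, using a hypothetical stand-in dataclass rather than the real BuilderArgs:

# Hedged sketch: _DistArgs is a stand-in for illustration, not the real BuilderArgs.
from dataclasses import dataclass


@dataclass
class _DistArgs:
    distributed: bool = False
    pp: int = 1  # pipeline-parallel degree
    tp: int = 1  # tensor-parallel degree

    @property
    def num_processes_per_node(self) -> int:
        # One worker process per (pipeline stage, tensor shard) pair.
        return self.pp * self.tp


assert _DistArgs(distributed=True, pp=2, tp=2).num_processes_per_node == 4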

torchchat/cli/cli.py

Lines changed: 2 additions & 2 deletions
@@ -448,13 +448,13 @@ def _add_custom_model_args(parser) -> None:
         "--params-path",
         type=Path,
         default=None,
-        help= "Use the specified parameter file, instead of one specified under torchchat.model_params",
+        help="Use the specified parameter file, instead of one specified under torchchat.model_params",
     )
     parser.add_argument(
         "--tokenizer-path",
         type=Path,
         default=None,
-        help= "Use the specified model tokenizer file, instead of the one downloaded from HuggingFace",
+        help="Use the specified model tokenizer file, instead of the one downloaded from HuggingFace",
     )
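The two options above, shown as a standalone argparse sketch; the parser here is hypothetical, but the flag names and help strings are taken from the diff:

import argparse
from pathlib import Path

# Hypothetical parser for illustration only; flags and help text match the diff.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--params-path",
    type=Path,
    default=None,
    help="Use the specified parameter file, instead of one specified under torchchat.model_params",
)
parser.add_argument(
    "--tokenizer-path",
    type=Path,
    default=None,
    help="Use the specified model tokenizer file, instead of the one downloaded from HuggingFace",
)
print(parser.parse_args([]))  # both options default to None when not supplied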

torchchat/distributed/generate.py

Lines changed: 41 additions & 30 deletions
@@ -3,25 +3,25 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import asyncio
+import atexit
+import importlib.util
+import subprocess
+import threading
 from abc import abstractmethod
 from collections import deque
 from dataclasses import dataclass
 from functools import partial
 from os import environ
 from pathlib import Path
-from torchchat.cli.builder import BuilderArgs, TokenizerArgs
 from typing import List, Optional
 from uuid import uuid4
 
-import asyncio
-import atexit
 import torch.multiprocessing as mp
-import threading
-import importlib.util
-import subprocess
+from torchchat.cli.builder import BuilderArgs, TokenizerArgs
 
 
-def _setup_env(world_size:int, rank:int, target: callable, *args, **kwargs):
+def _setup_env(world_size: int, rank: int, target: callable, *args, **kwargs):
     environ["MASTER_ADDR"] = "localhost"
     environ["MASTER_PORT"] = "29500"
     environ["RDZV_BACKEND"] = "c10d"
@@ -36,10 +36,11 @@ def _launch_distributed_inference(builder_args: BuilderArgs) -> None:
     # create programmatic elastic launch
     print("Launching distributed inference ...")
 
-    num_processes_per_node = 4 # builder_args.num_gpus + 1
+    num_processes_per_node = builder_args.pp * builder_args.tp
 
     from torchchat.distributed.dist_run import main
-    mp.set_start_method('spawn')
+
+    mp.set_start_method("spawn")
 
     pipes = []
     procs = []
@@ -48,25 +49,24 @@ def _launch_distributed_inference(builder_args: BuilderArgs) -> None:
         pipes.append(server_pipe)
         proc = mp.Process(
             target=partial(_setup_env, num_processes_per_node, rank, main),
-            args=(builder_args, client_pipe)
+            args=(builder_args, client_pipe),
         )
         proc.start()
 
-
     for pipe in pipes:
         response = pipe.recv()
 
-    print(
-        f"Done launching distributed inference on **4 ** {builder_args.num_gpus} GPUs."
-    )
+    print(f"Done launching distributed inference on {num_processes_per_node} GPUs.")
     return procs, pipes
 
+
 @dataclass
 class Output:
     is_finished: bool = False
     text: Optional[str] = None
     token: Optional[list] = None
 
+
 @dataclass
 class Request:
     request_id: int
@@ -84,7 +84,7 @@ def __init__(
         generator_args,
         pipes,
         loop,
-        ):
+    ):
         self.builder_args = builder_args
         self.generator_args = generator_args
         self.requests = {}
@@ -107,7 +107,7 @@ def process_requests_loop(self):
             if req == "stop":
                 break
            self.requests = {req.request_id: req.prompt}
-
+
            responses = {}
            running = True
            while running:
@@ -128,17 +128,17 @@ async def wait_for_request(self, req: Request) -> Output:
            yield output
        del self.req_to_states[req.request_id]
        del self.req_to_results[req.request_id]
-
+
    def step(self) -> List[Output]:
        responses = []
-        #TODO: Implement a scheduler to handle the requests
+        # TODO: Implement a scheduler to handle the requests
        if len(self.in_flight_requests) > 0:
-            #Receive decoded token
+            # Receive decoded token
            for p in self.pipes:
                p.send("step")
            for p in self.pipes:
                responses.append(p.recv())
-
+
        else:
            # Send requests to backend
            self.in_flight_batch_order = list(self.requests.keys())
@@ -148,25 +148,26 @@ def step(self) -> List[Output]:
            self.in_flight_requests = self.requests
            self.requests = {}
            self.current_step = 0
-            #Receive first token
+            # Receive first token
            for p in self.pipes:
                responses.append(p.recv())
-        responses = responses[0]
+        # Filter out None responses from in-between stages
+        responses = [r for r in responses if r is not None][0]
        outputs = []
        for k, v in zip(self.in_flight_batch_order, zip(responses[0], responses[1])):
            text, token_ids = v
            outputs.append(
                Output(
-                    is_finished=self.current_step>=self.generator_args.max_new_tokens,
+                    is_finished=self.current_step >= self.generator_args.max_new_tokens,
                    text=text,
                    token=token_ids,
-                    )
                )
+            )
        if self.current_step >= self.generator_args.max_new_tokens:
            for p in self.pipes:
                p.send("stop")
            self.in_flight_requests = []
-
+
        self.current_step += 1
 
        return outputs
@@ -177,24 +178,28 @@ def __init__(
        self,
        builder_args: BuilderArgs,
        tokenizer_args: TokenizerArgs,
-        #TODO: move GeneratorArgs into a different module
+        # TODO: move GeneratorArgs into a different module
        generator_args,
        profile: Optional[Path],
        quantize: bool,
        draft_quantize: bool,
-        ):
+    ):
        self.builder_args = builder_args
        self.generate_args = generator_args
-
+
+        self.check_args()
+
        self.procs, self.pipes = _launch_distributed_inference(builder_args)
 
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
 
        self.scheduler = Scheduler(builder_args, generator_args, self.pipes, self.loop)
 
-        #TODO: Mode into process and use pipe or queue for comm
-        self.scheduler_thread = threading.Thread(target=self.scheduler.process_requests_loop)
+        # TODO: Mode into process and use pipe or queue for comm
+        self.scheduler_thread = threading.Thread(
+            target=self.scheduler.process_requests_loop
+        )
        self.scheduler_thread.start()
 
        atexit.register(self.shutdown)
@@ -220,3 +225,9 @@ def generate(self, text):
                running &= not output.is_finished
 
                yield output
+
+    def check_args(self):
+        if self.generate_args.chat_mode:
+            raise NotImplementedError(
+                "Currently we only support generate with --distributed"
+            )
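
The launch path above now spawns pp * tp worker processes, each paired with a pipe back to the parent. Below is a minimal, self-contained sketch of that pattern under stated assumptions: it uses the stdlib multiprocessing module (torch.multiprocessing mirrors its API), worker is a placeholder for torchchat.distributed.dist_run.main, and only the environment variables visible in the diff are set.

import multiprocessing as mp
from functools import partial
from os import environ


def _setup_env(world_size: int, rank: int, target, *args, **kwargs):
    # Same rendezvous variables as shown in the diff; the real helper also
    # records rank/world-size details before calling the target.
    environ["MASTER_ADDR"] = "localhost"
    environ["MASTER_PORT"] = "29500"
    environ["RDZV_BACKEND"] = "c10d"
    return target(*args, **kwargs)


def worker(client_pipe):
    # Placeholder for dist_run.main: a real worker would build its model shard,
    # then report back over the pipe.
    client_pipe.send("ready")


def launch(pp: int, tp: int):
    # Process count derived from the parallelism degrees, as in the new code
    # (previously a hard-coded 4 tied to num_gpus).
    num_processes_per_node = pp * tp
    mp.set_start_method("spawn", force=True)
    pipes, procs = [], []
    for rank in range(num_processes_per_node):
        server_pipe, client_pipe = mp.Pipe()
        pipes.append(server_pipe)
        proc = mp.Process(
            target=partial(_setup_env, num_processes_per_node, rank, worker),
            args=(client_pipe,),
        )
        proc.start()
        procs.append(proc)
    for pipe in pipes:
        pipe.recv()  # block until every rank has reported in
    print(f"Done launching distributed inference on {num_processes_per_node} GPUs.")
    return procs, pipes


if __name__ == "__main__":
    launch(pp=1, tp=2)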

torchchat/generate.py

Lines changed: 1 addition & 1 deletion
@@ -1243,7 +1243,7 @@ def main(args):
         )
 
         response = ""
-        for output in dist_gen.generate("Tell me a joke"):
+        for output in dist_gen.generate(generator_args.prompt):
             response += output.text
 
         print(f"Model output: {response}")
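
The call site now forwards the user-supplied prompt instead of a hard-coded string. A hedged sketch of the resulting pattern (dist_gen and generator_args stand in for the objects main() builds; this is not the full function):

def run_distributed_prompt(dist_gen, generator_args) -> str:
    # Stream Output chunks from the distributed generator and accumulate the
    # text, mirroring the updated loop in main().
    response = ""
    for output in dist_gen.generate(generator_args.prompt):
        response += output.text
    return response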
