This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit adcf232

Added generate method and placeholder scheduler
1 parent 11f29fc commit adcf232

File tree: 5 files changed, +134 additions, -79 deletions


torchchat/cli/builder.py

Lines changed: 0 additions & 3 deletions
@@ -63,7 +63,6 @@ class BuilderArgs:
     pp: int = 1
     tp: int = 1
     chpt_from: str = "hf"
-    ntokens: int = 40
     is_chat_model: bool = False
     prefill_possible: bool = False
     dynamic_shapes: bool = False
@@ -171,7 +170,6 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         pp = getattr(args, "pp", 1)
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
-        ntokens = getattr(args, "ntokens", 40)
         return cls(
             checkpoint_dir=checkpoint_dir,
             checkpoint_path=checkpoint_path,
@@ -191,7 +189,6 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             pp=pp,
             tp=tp,
             chpt_from=chpt_from,
-            ntokens=ntokens,
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),

torchchat/cli/cli.py

Lines changed: 0 additions & 7 deletions
@@ -425,13 +425,6 @@ def _add_distributed_args(parser) -> None:
         help=argparse.SUPPRESS,
         # "Tensor parallel degree",
     )
-
-    parser.add_argument(
-        "--ntokens",
-        type=int,
-        default=40,
-        help="Number of tokens to generate",
-    )
     parser.add_argument(
         "--chpt-from",
         type=str,
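
With --ntokens gone, the stopping criterion for a request comes from the generation arguments instead: the new Scheduler.step below compares the current step against generator_args.max_new_tokens. A minimal sketch of that check in isolation; the dataclass and its default value here are illustrative stand-ins, not torchchat's actual GeneratorArgs:

# Illustrative stand-in: the token budget now lives on the generator
# arguments rather than on BuilderArgs.ntokens / the removed --ntokens flag.
from dataclasses import dataclass


@dataclass
class GeneratorArgsSketch:
    max_new_tokens: int = 200  # illustrative default, not taken from this commit


def request_is_finished(current_step: int, args: GeneratorArgsSketch) -> bool:
    # Same comparison the new Scheduler.step performs for each emitted Output.
    return current_step >= args.max_new_tokens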

torchchat/distributed/dist_run.py

Lines changed: 5 additions & 7 deletions
@@ -388,7 +388,7 @@ def main(args, pipe):
     # Batch size. Since we push batches dynamically through the pipeline rather
     # than chunking them, this is effectively micro-batch size in pipeline
     # sense. Thus it is interchangeable with micro-batch size below.
-    batch_size = 4# len(prompt)
+    batch_size = 1# len(prompt)
     seqlen_prefill = 1024 # sequence length
     dim = 4096 # embedding dimension

@@ -410,9 +410,6 @@ def main(args, pipe):
     logger.info(
         f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}"
     )
-
-    # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
-    input_pos = torch.arange(seqlen_prefill, device=device)
     model.eval()

     # Helper function to get example inputs and outputs for the stages.
@@ -470,6 +467,8 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     logger.info(f"{color.green}Prompt: {prompt}{color.reset}")

     start_pos = 0
+    # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
+    input_pos = torch.arange(seqlen_prefill, device=device)

     # encode the prompt
     input_ids = _encode_strings(
@@ -511,9 +510,8 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
             res.append(new_token)
             #TODO: Move to a separate decoding thread
             resp = _decode_in_flight(new_token, tokenizer, tp_rank)
-            pipe.send(resp)
+            pipe.send((resp, new_token.tolist()))
         else:
-            logger.info(f"sending None {tp_rank=}")
             pipe.send(None)

         # seqlen = 1 now
@@ -577,7 +575,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
             res.append(new_token)
             #TODO: Move to a separate decoding thread
             resp = _decode_in_flight(new_token, tokenizer, tp_rank)
-            pipe.send(resp)
+            pipe.send((resp, new_token))
         else:
             pipe.send(None)

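
After this change the pipe protocol between the last pipeline stage and the generator process carries a (decoded_text, token_ids) pair per step instead of a bare string, with None still meaning "nothing to report". A self-contained sketch of that protocol using the standard multiprocessing module (dist_run.py itself uses torch.multiprocessing; the worker payloads here are placeholders, not real model output):

# Sketch of the worker-to-generator pipe protocol assumed by the new
# Scheduler.step: each message is (decoded_text, token_ids) or None.
import multiprocessing as mp


def worker(pipe) -> None:
    # Stand-in for the last pipeline stage: send decoded text plus raw token
    # ids for a few steps, then None to signal there is nothing more to report.
    for step in range(3):
        decoded_text = [f"token-{step} "]   # placeholder for _decode_in_flight output
        token_ids = [[100 + step]]          # placeholder for new_token.tolist()
        pipe.send((decoded_text, token_ids))
    pipe.send(None)


if __name__ == "__main__":
    parent_conn, child_conn = mp.Pipe()
    proc = mp.Process(target=worker, args=(child_conn,))
    proc.start()
    while True:
        msg = parent_conn.recv()
        if msg is None:
            break
        text, tokens = msg                  # unpack the (decoded_text, token_ids) pair
        print(text, tokens)
    proc.join()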

torchchat/distributed/generate.py

Lines changed: 124 additions & 45 deletions
@@ -4,15 +4,19 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 from abc import abstractmethod
-from typing import List, Optional
+from collections import deque
 from dataclasses import dataclass
-from pathlib import Path
+from functools import partial
 from os import environ
+from pathlib import Path
 from torchchat.cli.builder import BuilderArgs, TokenizerArgs
-from functools import partial
+from typing import List, Optional
+from uuid import uuid4

+import asyncio
 import atexit
 import torch.multiprocessing as mp
+import threading
 import importlib.util
 import subprocess

@@ -51,7 +55,6 @@ def _launch_distributed_inference(builder_args: BuilderArgs) -> None:

     for pipe in pipes:
         response = pipe.recv()
-        print(f"Received: {response=}")

     print(
         f"Done launching distributed inference on **4 ** {builder_args.num_gpus} GPUs."
@@ -60,56 +63,72 @@ def _launch_distributed_inference(builder_args: BuilderArgs) -> None:

 @dataclass
 class Output:
-    request_id: int
     is_finished: bool = False
-    output: Optional[str] = None
-
-class Generator(object):
+    text: Optional[str] = None
+    token: Optional[list] = None

-    @abstractmethod
-    def add_request(self, request_id: int, prompt: str):
-        raise NotImplementedError()
+@dataclass
+class Request:
+    request_id: int
+    prompt: str

-    def step(self) -> List[Output]:
-        raise NotImplementedError()
+    @classmethod
+    def new_request(cls, prompt):
+        return cls(request_id=uuid4().int, prompt=prompt)


-class DistributedGenerator(Generator):
+class Scheduler(object):
     def __init__(
         self,
-        builder_args: BuilderArgs,
-        speculative_builder_args: BuilderArgs,
-        tokenizer_args: TokenizerArgs,
-        #TODO: move GeneratorArgs into a different module
-        # generator_args: GeneratorArgs,
-        profile: Optional[Path],
-        quantize: bool,
-        draft_quantize: bool,
+        builder_args,
+        generator_args,
+        pipes,
+        loop,
     ):
         self.builder_args = builder_args
+        self.generator_args = generator_args
         self.requests = {}
         self.in_flight_requests = {}
-        # For now we have a static batch order we save separately
         self.in_flight_batch_order = []
-        # if builder_args.distributed:
-        # we part ways here with torchchat cli and move into dist inference
-        self.procs, self.pipes = _launch_distributed_inference(builder_args)
-        self.current_step = 0
-
-        atexit.register(self.shutdown)
-
-    def shutdown(self):
-        for p in self.pipes:
-            p.send("stop")
-        for p in self.procs:
-            p.kill()
-
-    #TODO: Replace against (async) generate
-    def add_request(self, request_id: int, prompt: str):
-        assert request_id not in self.requests
-        self.requests[request_id] = prompt
-
-
+        self.pipes = pipes
+        self.req_to_states = {}
+        self.req_to_results = {}
+        self.request_queue = mp.Queue()
+        self.loop = loop
+
+    def schedule_request(self, req: Request):
+        self.req_to_states[req.request_id] = asyncio.Event()
+        self.req_to_results[req.request_id] = deque()
+        self.request_queue.put(req)
+
+    def process_requests_loop(self):
+        while True:
+            req = self.request_queue.get()
+            if req == "stop":
+                break
+            self.requests = {req.request_id: req.prompt}
+
+            responses = {}
+            running = True
+            while running:
+                outputs = self.step()
+                self.req_to_results[req.request_id].append(outputs[0])
+
+                self.loop.call_soon_threadsafe(self.req_to_states[req.request_id].set)
+
+                running &= not outputs[0].is_finished
+
+    async def wait_for_request(self, req: Request) -> Output:
+        is_finished = False
+        while not is_finished:
+            await self.req_to_states[req.request_id].wait()
+            while len(self.req_to_results[req.request_id]):
+                output = self.req_to_results[req.request_id].popleft()
+                is_finished |= output.is_finished
+                yield output
+        del self.req_to_states[req.request_id]
+        del self.req_to_results[req.request_id]
+
     def step(self) -> List[Output]:
         responses = []
         #TODO: Implement a scheduler to handle the requests
@@ -132,12 +151,72 @@ def step(self) -> List[Output]:
         #Receive first token
         for p in self.pipes:
             responses.append(p.recv())
-
         responses = responses[0]
         outputs = []
-        for k, v in zip(self.in_flight_batch_order, responses):
-            outputs.append(Output(k, is_finished=self.current_step>=self.builder_args.ntokens, output=v))
+        for k, v in zip(self.in_flight_batch_order, zip(responses[0], responses[1])):
+            text, token_ids = v
+            outputs.append(
+                Output(
+                    is_finished=self.current_step>=self.generator_args.max_new_tokens,
+                    text=text,
+                    token=token_ids,
+                )
+            )
+        if self.current_step >= self.generator_args.max_new_tokens:
+            for p in self.pipes:
+                p.send("stop")
+            self.in_flight_requests = []

         self.current_step += 1

         return outputs
+
+
+class DistributedGenerator(object):
+    def __init__(
+        self,
+        builder_args: BuilderArgs,
+        tokenizer_args: TokenizerArgs,
+        #TODO: move GeneratorArgs into a different module
+        generator_args,
+        profile: Optional[Path],
+        quantize: bool,
+        draft_quantize: bool,
+    ):
+        self.builder_args = builder_args
+        self.generate_args = generator_args
+
+        self.procs, self.pipes = _launch_distributed_inference(builder_args)
+
+        self.loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(self.loop)
+
+        self.scheduler = Scheduler(builder_args, generator_args, self.pipes, self.loop)
+
+        #TODO: Mode into process and use pipe or queue for comm
+        self.scheduler_thread = threading.Thread(target=self.scheduler.process_requests_loop)
+        self.scheduler_thread.start()
+
+        atexit.register(self.shutdown)
+
+    def shutdown(self):
+        self.scheduler.request_queue.put("stop")
+        self.scheduler_thread.join()
+
+        for p in self.pipes:
+            p.send("stop")
+        for p in self.procs:
+            p.kill()
+
+    def generate(self, text):
+        req = Request.new_request(text)
+        self.scheduler.schedule_request(req)
+
+        generator = self.scheduler.wait_for_request(req)
+
+        running = True
+        while running:
+            output = self.loop.run_until_complete(generator.__anext__())
+            running &= not output.is_finished
+
+            yield output
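
The heart of the placeholder scheduler is a thread-to-asyncio handoff: the scheduler thread appends each Output to a per-request deque and wakes the awaiting coroutine with loop.call_soon_threadsafe(event.set), while wait_for_request drains the deque until an Output is marked finished. A stripped-down, self-contained sketch of that pattern (the strings and the explicit event.clear() are illustrative; this is not the torchchat Scheduler itself):

# Minimal model of the producer/consumer handoff used by Scheduler:
# a worker thread queues results and wakes an asyncio consumer thread-safely.
import asyncio
import threading
from collections import deque

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

results = deque()          # stand-in for req_to_results[request_id]
ready = asyncio.Event()    # stand-in for req_to_states[request_id]


def scheduler_thread() -> None:
    # Stand-in for Scheduler.process_requests_loop / step(): produce a few
    # (text, is_finished) results and signal the event loop each time.
    for step, text in enumerate(["Hello", " world", "!"]):
        results.append((text, step == 2))
        loop.call_soon_threadsafe(ready.set)


async def consume() -> None:
    # Stand-in for Scheduler.wait_for_request: drain results until finished.
    finished = False
    while not finished:
        await ready.wait()
        ready.clear()
        while results:
            text, finished = results.popleft()
            print(text, end="", flush=True)
    print()


t = threading.Thread(target=scheduler_thread)
t.start()
loop.run_until_complete(consume())
t.join()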

torchchat/generate.py

Lines changed: 5 additions & 17 deletions
@@ -1235,30 +1235,18 @@ def main(args):
     else:
         dist_gen = DistributedGenerator(
             builder_args,
-            speculative_builder_args,
             tokenizer_args,
-            # generator_args,
+            generator_args,
             args.profile,
             args.quantize,
             args.draft_quantize,
         )

-        dist_gen.add_request(0, "Tell me a joke")
-        dist_gen.add_request(1, "Tell me another joke")
-        dist_gen.add_request(2, "Who is this Santa")
-        dist_gen.add_request(3, "What did the fish say to the duck")
-
-        responses = {}
+        response = ""
+        for output in dist_gen.generate("Tell me a joke"):
+            response += output.text

-        running = True
-        while running:
-            outputs = dist_gen.step()
-            for o in outputs:
-                responses[o.request_id] = responses.get(o.request_id, "") + o.output
-                running &= not o.is_finished
-
-        print(responses)
-
+        print(f"Model output: {response}")
         dist_gen.shutdown()

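
Because generate yields Output objects as tokens arrive, the same call can stream text to the console instead of accumulating it first. A hedged variant of the loop above, assuming output.text holds the decoded chunk exactly as in the accumulation version:

# Streaming variant of the loop in main(): print each decoded chunk as it
# arrives rather than building one response string.
for output in dist_gen.generate("Tell me a joke"):
    print(output.text, end="", flush=True)
print()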

0 commit comments
