This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit e8f7c98

Disable torchchat format + constrain possible models for distributed
1 parent: fd9f704

3 files changed: 19 additions, 7 deletions

torchchat/distributed/dist_run.py

Lines changed: 2 additions & 3 deletions

@@ -105,7 +105,7 @@ def _build_chat_tokenizer(
     tokenizer_args: TokenizerArgs,
 ) -> SentencePieceProcessor | TiktokenTokenizer:
     """Builds a tokenizer for the given model name"""
-
+
     tokenizer_args = TokenizerArgs.from_args(tokenizer_args)
     tokenizer = tokenizer_args.t
     assert tokenizer is not None, f"Failed to get tokenizer using {tokenconfig=}"
@@ -293,12 +293,11 @@ def _cleanup():
 
 
 def main(
+    model_name,
     builder_args,
     tokenizer_args,
     pipe,
 ):
-    model_name = "llama3"  # args.model_name
-    # print(f"{builder_args.checkpoint_path=}")
     pp_degree = builder_args.pp
 
     rank, world_size = _init_distributed()
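The significant change here: main previously ignored whatever model the caller selected and always ran "llama3"; it now receives model_name from the parent process. A minimal sketch of the intended flow, assuming NAME_TO_DISTRIBUTION_AND_DTYPE (defined elsewhere in dist_run.py and imported in generate.py below) maps each supported model name to its HF distribution and dtype; the entry shown is illustrative, not taken from this commit:

# Illustrative only: the real dict lives in torchchat/distributed/dist_run.py,
# and its exact entries are not shown in this diff.
import torch

NAME_TO_DISTRIBUTION_AND_DTYPE = {
    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),  # assumed entry
}

def main(model_name, builder_args, tokenizer_args, pipe):
    # model_name now arrives as an argument instead of being hardcoded to "llama3"
    distribution, dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
    ...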

torchchat/distributed/generate.py

Lines changed: 16 additions & 4 deletions

@@ -19,6 +19,7 @@
 
 import torch.multiprocessing as mp
 from torchchat.cli.builder import BuilderArgs, TokenizerArgs
+from torchchat.distributed.dist_run import NAME_TO_DISTRIBUTION_AND_DTYPE
 
 
 def _setup_env(world_size: int, rank: int, target: callable, *args, **kwargs):
@@ -33,7 +34,7 @@ def _setup_env(world_size: int, rank: int, target: callable, *args, **kwargs):
 
 
 def _launch_distributed_inference(
-    builder_args: BuilderArgs, tokenizer_args: TokenizerArgs
+    model_name: str, builder_args: BuilderArgs, tokenizer_args: TokenizerArgs
 ) -> tuple[List]:
     # create programmatic elastic launch
     print("Launching distributed inference ...")
@@ -51,7 +52,7 @@ def _launch_distributed_inference(
         pipes.append(server_pipe)
         proc = mp.Process(
            target=partial(_setup_env, num_processes_per_node, rank, main),
-           args=(builder_args, tokenizer_args, client_pipe),
+           args=(model_name, builder_args, tokenizer_args, client_pipe),
        )
        proc.start()
 
@@ -178,6 +179,8 @@ def step(self) -> List[Output]:
 class DistributedGenerator(object):
     def __init__(
         self,
+        # TODO: switch this to torchchat method
+        model_name: str,
         builder_args: BuilderArgs,
         tokenizer_args: TokenizerArgs,
         # TODO: move GeneratorArgs into a different module
@@ -186,13 +189,14 @@ def __init__(
         quantize: bool,
         draft_quantize: bool,
     ):
+        self.model_name = model_name
         self.builder_args = builder_args
         self.generate_args = generator_args
 
         self.check_args()
 
         self.procs, self.pipes = _launch_distributed_inference(
-            builder_args, tokenizer_args
+            model_name, builder_args, tokenizer_args
         )
 
         self.loop = asyncio.new_event_loop()
@@ -236,4 +240,12 @@ def check_args(self):
                 "Currently we only support generate with --distributed"
             )
         elif self.builder_args.tp < 2:
-            raise RuntimeError("TP degree must be at least 2 for distributed inference")
+            raise ValueError("TP degree must be at least 2 for distributed inference")
+        elif self.model_name not in NAME_TO_DISTRIBUTION_AND_DTYPE.keys():
+            raise ValueError(
+                f"Distributed inference currently only supports the following models: {list(NAME_TO_DISTRIBUTION_AND_DTYPE.keys())}"
+            )
+        elif self.builder_args.chpt_from == "torchchat":
+            raise ValueError(
+                "Distributed inference currently only supports HF checkpoints"
+            )
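Taken together, check_args now fails fast on three conditions rather than one, and the TP check raises ValueError instead of RuntimeError, which better fits an invalid-argument failure. A caller-side sketch of the new constraint surface (the accepted model names are whatever keys dist_run.py defines; none are listed in this diff):

from torchchat.distributed.dist_run import NAME_TO_DISTRIBUTION_AND_DTYPE

# The only model names DistributedGenerator will now accept:
print(list(NAME_TO_DISTRIBUTION_AND_DTYPE.keys()))

# Each of these now raises ValueError from check_args():
#   - model_name not in NAME_TO_DISTRIBUTION_AND_DTYPE
#   - builder_args.tp < 2
#   - builder_args.chpt_from == "torchchat"  (only HF checkpoints are supported)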

torchchat/generate.py

Lines changed: 1 addition & 0 deletions

@@ -1233,6 +1233,7 @@ def main(args):
             pass
         else:
             dist_gen = DistributedGenerator(
+                args.model,
                 builder_args,
                 tokenizer_args,
                 generator_args,
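With the call site passing args.model through, the model selected on the command line is finally honored end to end. A hypothetical invocation: the generate subcommand and --distributed appear in this commit's error message, while the --tp flag name is assumed from builder_args.tp and not confirmed by this diff:

python3 torchchat.py generate llama3 --distributed --tp 2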
