This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 3f6fa2d

Enforce tp>=2
1 parent 3836928 commit 3f6fa2d

3 files changed: 26 additions and 19 deletions

torchchat/cli/cli.py

Lines changed: 4 additions & 7 deletions
@@ -399,8 +399,7 @@ def _add_distributed_args(parser) -> None:
     parser.add_argument(
         "--distributed",
         action="store_true",
-        help=argparse.SUPPRESS,
-        # "Whether to enable distributed inference",
+        help="Whether to enable distributed inference",
     )
     parser.add_argument(
         "--dcp-dir",
@@ -414,16 +413,14 @@ def _add_distributed_args(parser) -> None:
         "--pipeline-parallel",
         type=int,
         default=1,
-        help=argparse.SUPPRESS,
-        # "Pipeline parallel degree",
+        help="Pipeline parallel degree",
     )
     parser.add_argument(
         "--tp",
         "--tensor-parallel",
         type=int,
-        default=1,
-        help=argparse.SUPPRESS,
-        # "Tensor parallel degree",
+        default=2,
+        help="Tensor parallel degree",
     )
     parser.add_argument(
         "--chpt-from",

torchchat/distributed/dist_run.py

Lines changed: 20 additions & 12 deletions
@@ -12,7 +12,7 @@
 import os
 from enum import auto, Enum
 from pathlib import Path
-from types import SimpleNamespace, MethodType
+from types import MethodType, SimpleNamespace
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
@@ -71,21 +71,26 @@ def _init_distributed():
 
 
 def _create_device_mesh(pp_degree, tp_degree):
-    return dist.init_device_mesh("cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp"))
+    return dist.init_device_mesh(
+        "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
+    )
 
 
 def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace:
     return SimpleNamespace(**dictionary)
 
+
 def _patch_tokenizer(tokenizer):
     """Patch the tokenizer to support decoding of token ids."""
     if isinstance(tokenizer, TiktokenTokenizer):
         # Patch tiktokenizer to allow a list of sequences.
-        #TODO: Upstream to tokenizer modules
+        # TODO: Upstream to tokenizer modules
         old_decode = tokenizer.decode
 
-        def decode(self, token_ids: List[int|List[int]], *args, **kwargs) -> str | List[str]:
-            if len(token_ids)<1:
+        def decode(
+            self, token_ids: List[int | List[int]], *args, **kwargs
+        ) -> str | List[str]:
+            if len(token_ids) < 1:
                 return ""
             if isinstance(token_ids[0], list):
                 return [old_decode(t, *args, **kwargs) for t in token_ids]
@@ -95,6 +100,7 @@ def decode(self, token_ids: List[int | List[int]], *args, **kwargs) -> str | List[
         tokenizer.decode = MethodType(decode, tokenizer)
     return tokenizer
 
+
 def _build_chat_tokenizer(
     model_name: str,
     model_base_name: Optional[str] = None,
@@ -221,7 +227,7 @@ def _create_padded_prompts(
 
 def _batch_decode_next_tokens(
     output: torch.Tensor,
-    pos: List[int]=None,
+    pos: List[int] = None,
     temperature: float = 1.0,
     topk: int = 10,
 ) -> torch.Tensor:
@@ -388,7 +394,7 @@ def main(args, pipe):
     # Batch size. Since we push batches dynamically through the pipeline rather
     # than chunking them, this is effectively micro-batch size in pipeline
     # sense. Thus it is interchangeable with micro-batch size below.
-    batch_size = 1# len(prompt)
+    batch_size = 1  # len(prompt)
     seqlen_prefill = 1024  # sequence length
     dim = 4096  # embedding dimension
 
@@ -463,7 +469,9 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
                 raise ValueError(f"Unknown command: {command}")
             else:
                 prompt = command
-                assert len(prompt) == batch_size, f"Expecting {batch_size=} prompts but got {len(prompt)=}"
+                assert (
+                    len(prompt) == batch_size
+                ), f"Expecting {batch_size=} prompts but got {len(prompt)=}"
             logger.info(f"{color.green}Prompt: {prompt}{color.reset}")
 
             start_pos = 0
@@ -508,7 +516,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
             logger.info(f"{color.green}Decoding...{prompt_lengths=}{color.reset}")
             new_token = _batch_decode_next_tokens(output, prompt_lengths)
             res.append(new_token)
-            #TODO: Move to a separate decoding thread
+            # TODO: Move to a separate decoding thread
            resp = _decode_in_flight(new_token, tokenizer, tp_rank)
             pipe.send((resp, new_token.tolist()))
         else:
@@ -539,7 +547,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         command = pipe.recv()
         assert isinstance(command, str)
         if command == "stop":
-                break
+            break
         elif command == "step":
             pass
         else:
@@ -573,7 +581,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         if pp_rank == last_pp_rank:
             new_token = _batch_decode_next_tokens(output)
             res.append(new_token)
-            #TODO: Move to a separate decoding thread
+            # TODO: Move to a separate decoding thread
             resp = _decode_in_flight(new_token, tokenizer, tp_rank)
             pipe.send((resp, new_token))
         else:
@@ -602,7 +610,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     for prompt_text, response_text in zip(prompt, responses):
        logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")
        logger.info(f"Response: {color.red}{response_text} {color.reset}")
-
+
     # Cleanup
     _cleanup()
     logger.info(
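
The _patch_tokenizer change above is a MethodType monkey-patch: a replacement decode is bound onto the tokenizer instance so that a batch (a list of token-id lists) can be decoded in one call. The sketch below reproduces the pattern against a hypothetical DummyTokenizer, since TiktokenTokenizer itself does not appear in this diff; it is an illustration of the technique, not the torchchat code.

# Standalone sketch of the monkey-patch pattern (Python 3.10+ for the `|` unions).
from types import MethodType
from typing import List


class DummyTokenizer:
    # Stand-in for TiktokenTokenizer, purely for illustration.
    def decode(self, token_ids: List[int]) -> str:
        return " ".join(f"<{t}>" for t in token_ids)


tokenizer = DummyTokenizer()
old_decode = tokenizer.decode  # keep the original bound method


def decode(
    self, token_ids: List[int | List[int]], *args, **kwargs
) -> str | List[str]:
    if len(token_ids) < 1:
        return ""
    if isinstance(token_ids[0], list):
        # A batch of sequences: decode each one with the original method.
        return [old_decode(t, *args, **kwargs) for t in token_ids]
    return old_decode(token_ids, *args, **kwargs)


# Bind the new decode onto this instance, as the patch does.
tokenizer.decode = MethodType(decode, tokenizer)

print(tokenizer.decode([1, 2, 3]))         # "<1> <2> <3>"
print(tokenizer.decode([[1, 2], [3, 4]]))  # ["<1> <2>", "<3> <4>"]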

torchchat/distributed/generate.py

Lines changed: 2 additions & 0 deletions
@@ -231,3 +231,5 @@ def check_args(self):
            raise NotImplementedError(
                "Currently we only support generate with --distributed"
            )
+        elif self.builder_args.tp < 2:
+            raise RuntimeError("TP degree must be at least 2 for distributed inference")
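
Together with the CLI default above, this makes a tensor-parallel degree of at least 2 a hard requirement for the distributed path. A minimal sketch of the resulting behavior, using a hypothetical BuilderArgs stand-in (only the tp < 2 rule itself comes from this commit):

# Hedged sketch of the new validation; FakeBuilderArgs is hypothetical.
from dataclasses import dataclass


@dataclass
class FakeBuilderArgs:
    distributed: bool = True
    tp: int = 2


def check_tp(builder_args: FakeBuilderArgs) -> None:
    # Mirrors the guard added in check_args: reject tp < 2 for distributed runs.
    if builder_args.distributed and builder_args.tp < 2:
        raise RuntimeError("TP degree must be at least 2 for distributed inference")


check_tp(FakeBuilderArgs(tp=2))  # passes
try:
    check_tp(FakeBuilderArgs(tp=1))
except RuntimeError as err:
    print(err)  # TP degree must be at least 2 for distributed inference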
