Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 28d7836

Browse files
committed
Make sure speculative decoding is disable for pp >1 and remark this in the comments as well
1 parent 10fb55a commit 28d7836

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

torchchat/generate.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,6 +1237,9 @@ def __init__(
12371237
quantize: bool,
12381238
draft_quantize: bool,
12391239
):
1240+
1241+
is_speculative = speculative_builder_args.checkpoint_path is not None
1242+
assert is_speculative == False, "Distributed inference with pp > 1 does not support speculative inference yet."
12401243
super().__init__(
12411244
builder_args,
12421245
speculative_builder_args,
@@ -1449,8 +1452,9 @@ def decode_one_token(
14491452
"""
14501453
Decodes a single token.
14511454
1455+
# TODO: implement speculative decoding with pp>1
14521456
Returns:
1453-
Tuple[torch.Tensor, Optional[torch.Tensor]]: A tuple containing the decoded token and its probability.
1457+
Tuple[torch.Tensor, None]: A tuple containing the decoded token and None.
14541458
"""
14551459
if self.builder_args.pp == 1:
14561460
return super().decode_one_token(
@@ -1511,9 +1515,7 @@ def sample(
15111515
return (idx_next, None)
15121516
probs = self.logits_to_probs(logits[0, -1], temperature, top_k)
15131517
idx_next = self.multinomial_sample_one_no_sync(probs)
1514-
if self.builder_args.pp == 1:
1515-
dist.broadcast(idx_next, src=0)
1516-
dist.broadcast(probs, src=0)
1518+
15171519
return idx_next, probs
15181520

15191521

0 commit comments

Comments
 (0)