[Distributed] Fix new token's shape #1254
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1254
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 1f8ff93 with merge base 8fcb3ba.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
dist_run.py (Outdated)

```diff
-        # Make a 2D tensor with ids on row dimension
-        unsqueezed = torch.unsqueeze(token, 1)
-        token_str = tokenizer.decode(unsqueezed.tolist())
+        token_str = tokenizer.decode(token.tolist())
```
I find that this does not work as-is.
However, adding a one-liner before the tokenizer.decode line (421):

```python
token = token.squeeze(1)
```

makes it work on both llama2 and llama3.
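As a minimal sanity check of the shape effect of that one-liner (the token values here are hypothetical, not from the PR):

```python
import torch

token = torch.tensor([[306], [4966]])  # hypothetical ids, shape (batch_size, 1)
token = token.squeeze(1)               # drop the trailing dim -> shape (batch_size,)
print(token.tolist())                  # [306, 4966] -- a flat 1D list for decode
```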
Ah, makes sense. Tokenizer difference :)
It seems I cannot unconditionally squeeze the tensor.
If I do, some tokenizers will output:

responses ====>>>> is Christmasiving

instead of

responses ====>>>> ['is', 'Christmas', 'iving']
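A minimal sketch of the two decode behaviors described above, assuming a SentencePiece model file is available (the model path and ids below are hypothetical):

```python
import sentencepiece

sp = sentencepiece.SentencePieceProcessor(model_file="tokenizer.model")

ids = [[306], [4966], [2347]]  # hypothetical token ids, shape (batch_size, 1)

# Decoding the 2D list returns one string per row,
# e.g. ['is', 'Christmas', 'iving']
print(sp.decode(ids))

# Decoding the squeezed 1D list returns a single concatenated string,
# e.g. 'is Christmasiving'
print(sp.decode([row[0] for row in ids]))
```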
I am using tokenizer = sentencepiece.SentencePieceProcessor.
Okay, I am adding an `if` there:

```python
# `token` is a tensor of shape (batch_size, 1).
# For TiktokenTokenizer, we need to squeeze it to 1D.
# For SentencePieceProcessor, we don't.
if isinstance(tokenizer, TiktokenTokenizer):
    token = torch.squeeze(token, dim=1)
token_str = tokenizer.decode(token.tolist())
```
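With this guard, Tiktoken-based models receive the flat 1D list of ids their decode expects, while SentencePiece models keep the 2D (batch_size, 1) list and decode to one string per row, matching the responses ====>>>> ['is', 'Christmas', 'iving'] output above.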
lessw2020 left a comment
Thanks for adding!
I find that there is a missing line for this PR to work, but with that line added, it is verified working on llama2 and llama3.
Stamping to land; please add in the squeeze line.
Force-pushed from 08b3f09 to 4cdf355.
Force-pushed from 4cdf355 to 1f8ff93.
Issue

The TP-only case is broken. The error suggests that in the decoding phase, our input_ids (i.e. new_tokens) is flattened rather than being 2D (batch_size, 1). The flattening happens at the torch.argmax call that samples the next token.
Fix

The fix is simple: we just add a keepdim=True flag to torch.argmax. With that, the unsqueeze op in decode_in_flight can also be saved.
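A minimal sketch of the shape behavior (the (2, 32000) logits shape here is illustrative, not taken from the PR):

```python
import torch

# Hypothetical last-position logits: (batch_size, vocab_size)
logits = torch.randn(2, 32000)

# Without keepdim, argmax drops the reduced dimension: shape (2,)
flat_token = torch.argmax(logits, dim=-1)
print(flat_token.shape)  # torch.Size([2])

# With keepdim=True, the reduced dimension is kept: shape (2, 1),
# so the later torch.unsqueeze(token, 1) in decode_in_flight is unnecessary.
new_token = torch.argmax(logits, dim=-1, keepdim=True)
print(new_token.shape)  # torch.Size([2, 1])
```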