
Commit 1d7c3ab

helunwencser authored and facebook-github-bot committed
fix llama runner (pytorch#6256)
Summary:

Pull Request resolved: pytorch#6256

imported-using-ghimport

Test Plan: Imported from OSS

Run the following command and make sure it generates the right result:

```
python -m examples.models.llama2.runner.eager \
  --checkpoint /home/lunwenh/models/1B/consolidated.00.pth \
  --params /home/lunwenh/models/1B/params.json \
  --max_len 128 \
  --tokenizer /home/lunwenh/models/1B/tokenizer.model \
  --prompt "<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a good assistant<|eot_id|><|start_header_id|>user<|end_header_id|> What is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
```

```
Response: The capital of France is Paris.
Tokens: [791, 6864, 315, 9822, 374, 12366, 13]
```

Reviewed By: mergennachin

Differential Revision: D64442223

Pulled By: helunwencser

fbshipit-source-id: 99ee56f73e472a7243b8896a35ee092f287edf6b
1 parent 2c8b14c commit 1d7c3ab

File tree

1 file changed

+2
-2
lines changed

examples/models/llama2/runner/generation.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -45,9 +45,9 @@ def sample_top_p(probs, p):


 def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
     if temperature > 0:
-        probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
+        probs = torch.softmax(logits / temperature, dim=-1)
         return sample_top_p(probs, top_p).item()
-    return torch.argmax(logits[:, -1], dim=-1).item()
+    return torch.argmax(logits, dim=-1).item()


 class LlamaRunner(ABC):
```
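For context, the change makes `next_token` operate directly on a vocabulary-sized logits tensor, assuming the caller has already selected the final position (the `logits[:, -1]` indexing that was removed). A minimal NumPy sketch of the same sampling logic, with top-p (nucleus) sampling spelled out; the function names and the explicit `rng` parameter are illustrative, not taken from the PR:

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over a 1-D logits vector.
    e = np.exp(x - x.max())
    return e / e.sum()

def sample_top_p(probs, p, rng):
    # Nucleus sampling: keep the smallest set of highest-probability
    # tokens whose cumulative mass reaches p, renormalize, then sample.
    order = np.argsort(probs)[::-1]
    sorted_probs = probs[order]
    cum = np.cumsum(sorted_probs)
    cutoff = int(np.searchsorted(cum, p)) + 1
    kept = sorted_probs[:cutoff]
    kept = kept / kept.sum()
    return int(order[rng.choice(cutoff, p=kept)])

def next_token(logits, temperature, top_p, rng):
    # logits: 1-D array over the vocabulary for the last position only;
    # the caller is expected to have already done the [:, -1] selection.
    if temperature > 0:
        probs = softmax(logits / temperature)
        return sample_top_p(probs, top_p, rng)
    return int(np.argmax(logits))
```

With `temperature == 0` this reduces to greedy decoding (plain argmax), matching the second hunk of the diff.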
