@@ -1,9 +1,9 @@
 # pylint: disable=protected-access
-import pytest
 import random
 from typing import Tuple
 from unittest.mock import patch
 
+import pytest
 import torch
 
 from vllm.model_executor.layers.sampler import Sampler
@@ -69,7 +69,7 @@ def test_sampler_all_greedy(seed: int):
                              input_metadata=input_metadata)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
-        for nth_output in sequence_output:
+        for nth_output in sequence_output.samples:
             assert nth_output.output_token == expected[i].item()
 
 
@@ -101,7 +101,7 @@ def test_sampler_all_random(seed: int):
                              hidden_states=input_tensor,
                              input_metadata=input_metadata)
     for i, sequence_output in enumerate(sampler_output):
-        for nth_output in sequence_output:
+        for nth_output in sequence_output.samples:
             assert nth_output.output_token == i
 
 
@@ -181,5 +181,5 @@ def test_sampler_mixed(seed: int):
     for i, sequence_output in enumerate(sampler_output):
         if seq_group_metadata_list[i].sampling_params.use_beam_search:
             continue
-        for nth_output in sequence_output:
+        for nth_output in sequence_output.samples:
             assert nth_output.output_token in expected_tokens
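The test updates above reflect that each element of the sampler's output is no longer directly iterable; the sampled tokens now live under a `.samples` attribute, and each sample exposes `.output_token`. Below is a minimal sketch of the structure the updated assertions assume. Only `.samples` and `.output_token` come from this diff; the class names and the extra `logprobs` field are illustrative assumptions, not vLLM's actual API.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class SampleSketch:
    """Illustrative stand-in for one sampled token; the tests only
    read .output_token."""
    output_token: int
    logprobs: Dict[int, float] = field(default_factory=dict)  # assumed extra field


@dataclass
class SequenceGroupOutputSketch:
    """Illustrative stand-in for one element of the sampler's output list;
    the updated tests iterate its .samples attribute."""
    samples: List[SampleSketch]


# How the updated assertions walk the structure:
sampler_output = [SequenceGroupOutputSketch(samples=[SampleSketch(output_token=7)])]
expected_tokens = {7}
for sequence_output in sampler_output:
    for nth_output in sequence_output.samples:
        assert nth_output.output_token in expected_tokens
```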