|
1 | 1 | # pylint: disable=protected-access
|
2 |
| -import pytest |
3 | 2 | import random
|
4 | 3 | from typing import Tuple
|
5 | 4 | from unittest.mock import patch
|
6 | 5 |
|
| 6 | +import pytest |
7 | 7 | import torch
|
8 | 8 |
|
9 | 9 | from vllm.model_executor.layers.sampler import Sampler
|
@@ -69,7 +69,7 @@ def test_sampler_all_greedy(seed: int):
|
69 | 69 | input_metadata=input_metadata)
|
70 | 70 | expected = torch.argmax(fake_logits, dim=-1)
|
71 | 71 | for i, sequence_output in enumerate(sampler_output):
|
72 |
| - for nth_output in sequence_output: |
| 72 | + for nth_output in sequence_output.samples: |
73 | 73 | assert nth_output.output_token == expected[i].item()
|
74 | 74 |
|
75 | 75 |
|
@@ -101,7 +101,7 @@ def test_sampler_all_random(seed: int):
|
101 | 101 | hidden_states=input_tensor,
|
102 | 102 | input_metadata=input_metadata)
|
103 | 103 | for i, sequence_output in enumerate(sampler_output):
|
104 |
| - for nth_output in sequence_output: |
| 104 | + for nth_output in sequence_output.samples: |
105 | 105 | assert nth_output.output_token == i
|
106 | 106 |
|
107 | 107 |
|
@@ -181,5 +181,5 @@ def test_sampler_mixed(seed: int):
|
181 | 181 | for i, sequence_output in enumerate(sampler_output):
|
182 | 182 | if seq_group_metadata_list[i].sampling_params.use_beam_search:
|
183 | 183 | continue
|
184 |
| - for nth_output in sequence_output: |
| 184 | + for nth_output in sequence_output.samples: |
185 | 185 | assert nth_output.output_token in expected_tokens
|
0 commit comments