Skip to content

Commit d3a5bd9

Browse files
authored
Fix sampler test (#1379)
1 parent e8ef4c0 commit d3a5bd9

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/samplers/test_sampler.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,9 @@
11
# pylint: disable=protected-access
2-
import pytest
32
import random
43
from typing import Tuple
54
from unittest.mock import patch
65

6+
import pytest
77
import torch
88

99
from vllm.model_executor.layers.sampler import Sampler
@@ -69,7 +69,7 @@ def test_sampler_all_greedy(seed: int):
6969
input_metadata=input_metadata)
7070
expected = torch.argmax(fake_logits, dim=-1)
7171
for i, sequence_output in enumerate(sampler_output):
72-
for nth_output in sequence_output:
72+
for nth_output in sequence_output.samples:
7373
assert nth_output.output_token == expected[i].item()
7474

7575

@@ -101,7 +101,7 @@ def test_sampler_all_random(seed: int):
101101
hidden_states=input_tensor,
102102
input_metadata=input_metadata)
103103
for i, sequence_output in enumerate(sampler_output):
104-
for nth_output in sequence_output:
104+
for nth_output in sequence_output.samples:
105105
assert nth_output.output_token == i
106106

107107

@@ -181,5 +181,5 @@ def test_sampler_mixed(seed: int):
181181
for i, sequence_output in enumerate(sampler_output):
182182
if seq_group_metadata_list[i].sampling_params.use_beam_search:
183183
continue
184-
for nth_output in sequence_output:
184+
for nth_output in sequence_output.samples:
185185
assert nth_output.output_token in expected_tokens

0 commit comments

Comments (0)