|
| 1 | +import pytest |
| 2 | +import random |
| 3 | +from typing import Tuple |
| 4 | +from unittest.mock import patch |
| 5 | + |
| 6 | +import torch |
| 7 | + |
| 8 | +from vllm.model_executor.layers.sampler import Sampler |
| 9 | +from vllm.model_executor.utils import set_random_seed |
| 10 | +from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata |
| 11 | +from vllm.worker.worker import Worker |
| 12 | + |
| 13 | + |
| 14 | +class MockLogitsSampler(Sampler): |
| 15 | + |
| 16 | + def __init__(self, vocab_size: int, fake_logits: torch.Tensor): |
| 17 | + super().__init__(vocab_size=vocab_size) |
| 18 | + self.fake_logits = fake_logits |
| 19 | + |
| 20 | + def forward(self, *args, **kwargs): |
| 21 | + with patch("vllm.model_executor.layers.sampler._prune_hidden_states", |
| 22 | + lambda x, y: x): |
| 23 | + with patch("vllm.model_executor.layers.sampler._get_logits", |
| 24 | + lambda *args, **kwargs: self.fake_logits): |
| 25 | + return super().forward(*args, **kwargs) |
| 26 | + |
| 27 | + |
| 28 | +def _prepare_test( |
| 29 | + batch_size: int |
| 30 | +) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]: |
| 31 | + vocab_size = 32000 |
| 32 | + input_tensor = torch.rand((batch_size, 1024), |
| 33 | + device="cuda", |
| 34 | + dtype=torch.float16) |
| 35 | + fake_logits = torch.full((batch_size, vocab_size), |
| 36 | + 1e-2, |
| 37 | + device=input_tensor.device, |
| 38 | + dtype=input_tensor.dtype) |
| 39 | + sampler = MockLogitsSampler(32000, fake_logits) |
| 40 | + worker = Worker(None, None, None) |
| 41 | + worker.block_size = 16 |
| 42 | + return input_tensor, fake_logits, sampler, worker |
| 43 | + |
| 44 | + |
| 45 | +RANDOM_SEEDS = list(range(128)) |
| 46 | + |
| 47 | + |
| 48 | +@pytest.mark.parametrize("seed", RANDOM_SEEDS) |
| 49 | +def test_sampler_all_greedy(seed: int): |
| 50 | + set_random_seed(seed) |
| 51 | + batch_size = random.randint(1, 256) |
| 52 | + input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size) |
| 53 | + |
| 54 | + seq_group_metadata_list = [] |
| 55 | + for i in range(batch_size): |
| 56 | + seq_group_metadata_list.append( |
| 57 | + SequenceGroupMetadata( |
| 58 | + request_id=f"test_{i}", |
| 59 | + is_prompt=True, |
| 60 | + seq_data={0: SequenceData([1, 2, 3])}, |
| 61 | + sampling_params=SamplingParams(temperature=0, ), |
| 62 | + block_tables={0: [1]}, |
| 63 | + )) |
| 64 | + |
| 65 | + _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list) |
| 66 | + sampler_output = sampler(embedding=None, |
| 67 | + hidden_states=input_tensor, |
| 68 | + input_metadata=input_metadata) |
| 69 | + expected = torch.argmax(fake_logits, dim=-1) |
| 70 | + for i, sequence_output in enumerate(sampler_output): |
| 71 | + for nth_output in sequence_output: |
| 72 | + assert nth_output.output_token == expected[i].item() |
| 73 | + |
| 74 | + |
| 75 | +@pytest.mark.parametrize("seed", RANDOM_SEEDS) |
| 76 | +def test_sampler_all_random(seed: int): |
| 77 | + set_random_seed(seed) |
| 78 | + batch_size = random.randint(1, 256) |
| 79 | + input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size) |
| 80 | + |
| 81 | + for i in range(batch_size): |
| 82 | + fake_logits[i, i] = 1e2 |
| 83 | + |
| 84 | + seq_group_metadata_list = [] |
| 85 | + for i in range(batch_size): |
| 86 | + seq_group_metadata_list.append( |
| 87 | + SequenceGroupMetadata( |
| 88 | + request_id=f"test_{i}", |
| 89 | + is_prompt=True, |
| 90 | + seq_data={0: SequenceData([1, 2, 3])}, |
| 91 | + sampling_params=SamplingParams( |
| 92 | + temperature=1.0, |
| 93 | + n=random.randint(1, 10), |
| 94 | + ), |
| 95 | + block_tables={0: [1]}, |
| 96 | + )) |
| 97 | + |
| 98 | + _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list) |
| 99 | + sampler_output = sampler(embedding=None, |
| 100 | + hidden_states=input_tensor, |
| 101 | + input_metadata=input_metadata) |
| 102 | + for i, sequence_output in enumerate(sampler_output): |
| 103 | + for nth_output in sequence_output: |
| 104 | + assert nth_output.output_token == i |
| 105 | + |
| 106 | + |
| 107 | +@pytest.mark.parametrize("seed", RANDOM_SEEDS) |
| 108 | +def test_sampler_all_beam(seed: int): |
| 109 | + set_random_seed(seed) |
| 110 | + batch_size = random.randint(1, 256) |
| 111 | + input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size) |
| 112 | + |
| 113 | + seq_group_metadata_list = [] |
| 114 | + for i in range(batch_size): |
| 115 | + seq_group_metadata_list.append( |
| 116 | + SequenceGroupMetadata( |
| 117 | + request_id=f"test_{i}", |
| 118 | + is_prompt=True, |
| 119 | + seq_data={0: SequenceData([1, 2, 3])}, |
| 120 | + sampling_params=SamplingParams( |
| 121 | + temperature=0, |
| 122 | + best_of=2, |
| 123 | + use_beam_search=True, |
| 124 | + ), |
| 125 | + block_tables={0: [1]}, |
| 126 | + )) |
| 127 | + |
| 128 | + _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list) |
| 129 | + sampler(embedding=None, |
| 130 | + hidden_states=input_tensor, |
| 131 | + input_metadata=input_metadata) |
| 132 | + # no assertion here as I am not sure how to determine whether |
| 133 | + # the outputs are expected - in other words, this just tests |
| 134 | + # whether there are no exceptions in the sampler |
| 135 | + # when handling an all-beam search case. |
| 136 | + |
| 137 | + |
| 138 | +@pytest.mark.parametrize("seed", RANDOM_SEEDS) |
| 139 | +def test_sampler_mixed(seed: int): |
| 140 | + set_random_seed(seed) |
| 141 | + batch_size = random.randint(1, 256) |
| 142 | + input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size) |
| 143 | + |
| 144 | + seq_group_metadata_list = [] |
| 145 | + expected_tokens = [] |
| 146 | + for i in range(batch_size): |
| 147 | + n = 1 |
| 148 | + sampling_type = random.randint(0, 2) |
| 149 | + if sampling_type == 0: |
| 150 | + sampling_params = SamplingParams(temperature=0) |
| 151 | + elif sampling_type == 1: |
| 152 | + n = random.randint(1, 10) |
| 153 | + sampling_params = SamplingParams( |
| 154 | + temperature=random.random() + 0.1, |
| 155 | + top_p=min(random.random() + 0.1, 1), |
| 156 | + top_k=random.randint(0, 10) or -1, |
| 157 | + n=n, |
| 158 | + presence_penalty=random.randint(0, 1), |
| 159 | + ) |
| 160 | + else: |
| 161 | + sampling_params = SamplingParams(temperature=0, |
| 162 | + use_beam_search=True, |
| 163 | + best_of=2) |
| 164 | + for idx in range(n): |
| 165 | + fake_logits[i, i + idx] = 1e2 |
| 166 | + expected_tokens.append(i + idx) |
| 167 | + seq_group_metadata_list.append( |
| 168 | + SequenceGroupMetadata( |
| 169 | + request_id=f"test_{i}", |
| 170 | + is_prompt=True, |
| 171 | + seq_data={0: SequenceData([1, 2, 3])}, |
| 172 | + sampling_params=sampling_params, |
| 173 | + block_tables={0: [1]}, |
| 174 | + )) |
| 175 | + |
| 176 | + _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list) |
| 177 | + sampler_output = sampler(embedding=None, |
| 178 | + hidden_states=input_tensor, |
| 179 | + input_metadata=input_metadata) |
| 180 | + for i, sequence_output in enumerate(sampler_output): |
| 181 | + if seq_group_metadata_list[i].sampling_params.use_beam_search: |
| 182 | + continue |
| 183 | + for nth_output in sequence_output: |
| 184 | + assert nth_output.output_token in expected_tokens |
0 commit comments