diff --git a/tests/unit_tests/sample/test_sampler.py b/tests/unit_tests/sample/test_sampler.py
new file mode 100644
index 00000000..7eee1aba
--- /dev/null
+++ b/tests/unit_tests/sample/test_sampler.py
@@ -0,0 +1,422 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Optional
+
+import numpy as np
+import pytest
+import torch
+import habana_frameworks.torch  # noqa: F401
+
+from vllm_gaudi.utils import make_tensor_with_pad
+from vllm.v1.sample.logits_processor import LogitsProcessorManager
+from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.sampler import Sampler
+
+LOGITS_DTYPES = [torch.float, torch.bfloat16, torch.float16]
+MAX_NUM_REQS = 256
+VOCAB_SIZE = 1024
+NUM_OUTPUT_TOKENS = 20
+DEVICES = ['hpu']
+MAX_NUM_PROMPT_TOKENS = 64
+
+
+def _create_fake_logits(batch_size: int, vocab_size: int,
+                        dtype: torch.dtype) -> torch.Tensor:
+    fake_logits = torch.full((batch_size, vocab_size), 1e-2, dtype=dtype)
+    return fake_logits
+
+
+def _create_penalty_tensor(batch_size: int, penalty_value: float,
+                           device: torch.device) -> torch.Tensor:
+    return torch.full((batch_size, ),
+                      fill_value=penalty_value,
+                      dtype=torch.float,
+                      device=device)
+
+
+def _create_prompt_tokens_tensor(
+    prompt_token_ids: list[list[int]],
+    vocab_size: int,
+    device: torch.device,
+) -> torch.Tensor:
+    return make_tensor_with_pad(
+        prompt_token_ids,
+        pad=vocab_size,
+        device=device,
+        dtype=torch.int64,
+        pin_memory=False,
+    )
+
+
+def _create_allowed_token_ids(
+    batch_size: int,
+    vocab_size: int,
+    num_allowed_token_ids: int,
+    device: torch.device,
+) -> Optional[torch.Tensor]:
+    mask: Optional[torch.Tensor] = None
+    for i in range(batch_size):
+        if i % 2 == 1:
+            continue
+        if mask is None:
+            mask = torch.zeros((batch_size, vocab_size),
+                               dtype=torch.bool,
+                               device=device)
+        start = min(i, vocab_size - 1)
+        end = min(i + num_allowed_token_ids, vocab_size - 1)
+        mask[i, start:end] = True
+    return mask
+
+
+def _create_bad_words_token_ids(
+        batch_size: int, vocab_size: int,
+        bad_words_lengths: list[tuple[int]]) -> dict[int, list[list[int]]]:
+    bad_words_token_ids = {}
+    for batch_idx in range(batch_size):
+        token_ids_single_batch = []
+        for bad_words_length in bad_words_lengths:
+            token_ids = np.random.choice(vocab_size,
+                                         size=bad_words_length,
+                                         replace=True).tolist()
+            token_ids_single_batch.append(token_ids)
+        bad_words_token_ids[batch_idx] = token_ids_single_batch
+    if batch_size >= 2:
+        # Test no bad_words for some batch
+        no_bad_words_batch_idx = np.random.choice(batch_size)
+        bad_words_token_ids.pop(no_bad_words_batch_idx, None)
+    return bad_words_token_ids
+
+
+def _update_output_token_ids_for_bad_words(
+        metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]:
+    bad_words_last_tokens = {}
+    for batch_idx, bad_words_token_ids in metadata.bad_words_token_ids.items():
+        output_token_ids = metadata.output_token_ids[batch_idx]
+        bad_words_last_token: list[int] = []
+        for i, bad_word_token_ids in enumerate(bad_words_token_ids):
+            if len(bad_word_token_ids) == 1:
+                # Single token id always affects logits
+                bad_words_last_token.append(bad_word_token_ids[0])
+            else:
+                prefix_length = len(bad_word_token_ids) - 1
+                has_bad_words = np.random.choice([True, False])
+                if has_bad_words:
+                    output_token_ids[-prefix_length:] = bad_word_token_ids[:-1]
+                    bad_words_last_token.append(bad_word_token_ids[-1])
+                    break  # Maximum one update to output_token_ids
+                else:  # Make sure no accidental match to bad words
+                    output_token_ids[-1] = (bad_word_token_ids[-2] +
+                                            1) % vocab_size
+        bad_words_last_tokens[batch_idx] = bad_words_last_token
+    return bad_words_last_tokens
+
+
+def _create_default_sampling_metadata(
+    num_output_tokens: int,
+    batch_size: int,
+    vocab_size: int,
+    device: torch.device,
+) -> SamplingMetadata:
+    output_token_ids: list[list[int]] = []
+    prompt_token_ids: list[list[int]] = []
+    for _ in range(batch_size):
+        output_token_ids.append(
+            np.random.randint(0, vocab_size, size=num_output_tokens).tolist())
+        prompt_token_ids.append(
+            np.random.randint(0,
+                              vocab_size,
+                              size=np.random.randint(
+                                  1, MAX_NUM_PROMPT_TOKENS)).tolist())
+    fake_sampling_metadata = SamplingMetadata(
+        temperature=torch.full((batch_size, ), 0.0),
+        all_greedy=True,
+        all_random=False,
+        top_p=None,
+        top_k=None,
+        generators={},
+        max_num_logprobs=0,
+        prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
+                                                      vocab_size, device),
+        output_token_ids=output_token_ids,
+        frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
+        presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
+        repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
+        no_penalties=True,
+        allowed_token_ids_mask=None,
+        bad_words_token_ids={},
+        logitsprocs=LogitsProcessorManager(),
+    )
+    return fake_sampling_metadata
+
+
+def _create_weighted_output_token_list(
+        batch_size: int,
+        vocab_size: int) -> tuple[list[list[int]], list[list[int]]]:
+    """
+    Creates an output token list where each token occurs a distinct
+    number of times.
+
+    For each batch, a random subset of token IDs is selected from the
+    vocabulary. The selected tokens are then added to the output token
+    list, each with a different frequency.
+
+    Returns:
+        tuple[list[list[int]], list[list[int]]]:
+            - The first element is the output token list, where each sublist
+              corresponds to a batch and contains tokens with weighted
+              frequencies.
+            - The second element is a list of distinct token IDs for each
+              batch, ordered by their frequency in the corresponding output
+              list.
+    """
+    output_token_ids: list[list[int]] = []
+    sorted_token_ids_in_output: list[list[int]] = []
+    for _ in range(batch_size):
+        distinct_token_ids = np.random.choice(vocab_size,
+                                              size=np.random.randint(1, 10),
+                                              replace=False).tolist()
+        sorted_token_ids_in_output.append(distinct_token_ids)
+        output_token_ids_for_batch = []
+        for index, token_id in enumerate(distinct_token_ids):
+            output_token_ids_for_batch.extend(
+                [token_id for _ in range(index + 1)])
+        output_token_ids.append(output_token_ids_for_batch)
+    return output_token_ids, sorted_token_ids_in_output
+
+
+@pytest.mark.parametrize("logits_dtype", LOGITS_DTYPES)
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("presence_penalty", [-2.0, 2.0])
+def test_sampler_presence_penalty(device: str, batch_size: int,
+                                  presence_penalty: float,
+                                  logits_dtype: torch.dtype):
+    """
+    Test to verify that if presence penalty is enabled then tokens
+    are penalized as per their presence in the existing output.
+    """
+    torch.set_default_device(device)
+    # Create fake logits where each token is assigned the same
+    # logit value.
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE, logits_dtype)
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+    output_token_ids = sampling_metadata.output_token_ids
+    sampling_metadata.presence_penalties = _create_penalty_tensor(
+        batch_size, presence_penalty, torch.device(device))
+    sampling_metadata.no_penalties = False
+    sampler = Sampler()
+    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
+    assert logits.dtype == logits_dtype
+    assert logits.device.type == device
+    logits = logits.cpu()
+    for batch_idx in range(batch_size):
+        # Since all tokens initially have the same logits, the non-penalized
+        # token ID will be the one with the highest logit value, while the
+        # penalized token ID will be the one with the lowest logit value.
+        non_penalized_token_id = logits[batch_idx].argmax().item()
+        penalized_token_id = logits[batch_idx].argmin().item()
+        if presence_penalty > 0:
+            # If `presence_penalty` is set to a value greater than 0, it
+            # indicates a preference for new tokens over those already
+            # present in the output.
+            # Verify that the penalized token ID exists in the output,
+            # while the non-penalized token ID does not.
+            assert penalized_token_id in output_token_ids[batch_idx]
+            assert non_penalized_token_id not in output_token_ids[batch_idx]
+        elif presence_penalty < 0:
+            # If `presence_penalty` is set to a value less than 0, it indicates
+            # a preference for existing tokens over new ones. Verify that the
+            # non-penalized token ID exists in the output, while the penalized
+            # token ID does not.
+            assert non_penalized_token_id in output_token_ids[batch_idx]
+            assert penalized_token_id not in output_token_ids[batch_idx]
+
+
+@pytest.mark.parametrize("logits_dtype", LOGITS_DTYPES)
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("frequency_penalty", [-2.0, 2.0])
+def test_sampler_frequency_penalty(device: str, batch_size: int,
+                                   frequency_penalty: float,
+                                   logits_dtype: torch.dtype):
+    """
+    Test to verify that if frequency penalty is enabled then tokens are
+    penalized as per their frequency of occurrence.
+    """
+    torch.set_default_device(device)
+    # Create fake logits where each token is assigned the same
+    # logit value.
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE, logits_dtype)
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+    sampling_metadata.frequency_penalties = _create_penalty_tensor(
+        batch_size, frequency_penalty, torch.device(device))
+    output_token_ids, sorted_token_ids_in_output = \
+        _create_weighted_output_token_list(
+            batch_size,
+            VOCAB_SIZE,
+        )
+    sampling_metadata.output_token_ids = output_token_ids
+    sampling_metadata.no_penalties = False
+    sampler = Sampler()
+    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
+    assert logits.dtype == logits_dtype
+    assert logits.device.type == device
+    logits = logits.cpu()
+    for batch_idx in range(batch_size):
+        non_penalized_token_id = logits[batch_idx].argmax().item()
+        penalized_token_id = logits[batch_idx].argmin().item()
+        distinct_sorted_token_ids_in_output = sorted_token_ids_in_output[
+            batch_idx]
+        most_frequent_token_id = distinct_sorted_token_ids_in_output[
+            len(distinct_sorted_token_ids_in_output) - 1]
+        if frequency_penalty > 0:
+            # If `frequency_penalty` is set to > 0, it indicates
+            # a preference for new tokens over existing ones. Verify that the
+            # non-penalized token ID is not present in the output, while the
+            # most penalized token is the one that occurs most frequently in
+            # the output.
+            assert (non_penalized_token_id
+                    not in distinct_sorted_token_ids_in_output)
+            assert penalized_token_id == most_frequent_token_id
+        elif frequency_penalty < 0:
+            # If `frequency_penalty` is set to < 0, it indicates
+            # a preference for existing tokens over new ones. Verify that the
+            # non-penalized token ID is the one that occurs most frequently
+            # in the output, while the penalized token ID is one that has not
+            # yet appeared.
+            assert non_penalized_token_id == most_frequent_token_id
+            assert penalized_token_id not in distinct_sorted_token_ids_in_output
+
+
+@pytest.mark.parametrize("logits_dtype", LOGITS_DTYPES)
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("repetition_penalty", [0.1, 1.9])
+def test_sampler_repetition_penalty(device: str, batch_size: int,
+                                    repetition_penalty: float,
+                                    logits_dtype: torch.dtype):
+    """
+    Test to verify that when the repetition penalty is enabled, tokens
+    are penalized based on their presence in the prompt or the existing
+    output.
+    """
+    torch.set_default_device(device)
+    # Create fake logits where each token is assigned the same
+    # logit value.
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE, logits_dtype)
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+    sampling_metadata.repetition_penalties = _create_penalty_tensor(
+        batch_size, repetition_penalty, torch.device(device))
+    sampling_metadata.no_penalties = False
+    sampler = Sampler()
+    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
+    assert logits.dtype == logits_dtype
+    assert logits.device.type == device
+    logits = logits.cpu()
+    for batch_idx in range(batch_size):
+        non_penalized_token_id = logits[batch_idx].argmax().item()
+        penalized_token_id = logits[batch_idx].argmin().item()
+        prompt_tokens = sampling_metadata.prompt_token_ids[
+            batch_idx][:].tolist()
+        output_tokens = sampling_metadata.output_token_ids[batch_idx]
+        if repetition_penalty > 1.0:
+            # If `repetition_penalty` > 1.0, verify that the non-penalized
+            # token ID has not been seen before, while the penalized token ID
+            # exists either in the prompt or the output.
+            assert (non_penalized_token_id not in prompt_tokens
+                    and non_penalized_token_id not in output_tokens)
+            assert (penalized_token_id in prompt_tokens
+                    or penalized_token_id in output_tokens)
+        elif repetition_penalty < 1.0:
+            # If `repetition_penalty` < 1.0, verify that the penalized
+            # token ID has not been seen before, while the non-penalized
+            # token ID exists either in the prompt or the output.
+            assert (penalized_token_id not in prompt_tokens
+                    and penalized_token_id not in output_tokens)
+            assert (non_penalized_token_id in prompt_tokens
+                    or non_penalized_token_id in output_tokens)
+
+
+@pytest.mark.parametrize("logits_dtype", LOGITS_DTYPES)
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])
+def test_sampler_allowed_token_ids(device: str, batch_size: int,
+                                   num_allowed_token_ids: int,
+                                   logits_dtype: torch.dtype):
+    """
+    Test to verify that if an allowed_token_ids_mask is provided, the
+    masked token IDs have their logits set to -inf while all other
+    token IDs are left untouched.
+    """
+    torch.set_default_device(device)
+    # Create fake logits where each token is assigned the same
+    # logit value.
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE, logits_dtype)
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+    mask = _create_allowed_token_ids(
+        batch_size=batch_size,
+        vocab_size=VOCAB_SIZE,
+        num_allowed_token_ids=num_allowed_token_ids,
+        device=device,
+    )
+    sampling_metadata.allowed_token_ids_mask = mask
+    sampler = Sampler()
+    logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata)
+    assert logits.dtype == logits_dtype
+    assert logits.device.type == device
+    logits = logits.cpu()
+    for batch_idx in range(batch_size):
+        logits_for_req = logits[batch_idx]
+        if batch_idx % 2 == 1:
+            assert torch.all(logits_for_req != -float("inf"))
+            continue
+        for token_id in range(VOCAB_SIZE):
+            start = min(batch_idx, VOCAB_SIZE - 1)
+            end = min(batch_idx + num_allowed_token_ids, VOCAB_SIZE - 1)
+            if token_id >= start and token_id < end:
+                assert logits_for_req[token_id] == -float(
+                    "inf"), f"{batch_idx}, {token_id}"
+            else:
+                assert logits_for_req[token_id] != -float("inf")
+
+
+@pytest.mark.parametrize("logits_dtype", LOGITS_DTYPES)
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("bad_words_lengths", [(1, ), (1, 3), (2, 2)])
+def test_sampler_bad_words(device: str, batch_size: int,
+                           bad_words_lengths: list[tuple[int]],
+                           logits_dtype: torch.dtype):
+    """
+    Test to verify that when the bad words restriction is present, tokens
+    are penalized based on their match with the bad words.
+    """
+    torch.set_default_device(device)
+    # Create fake logits where each token is assigned the same
+    # logit value.
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE, logits_dtype)
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+    sampling_metadata.bad_words_token_ids = _create_bad_words_token_ids(
+        batch_size, VOCAB_SIZE, bad_words_lengths)
+    bad_words_last_tokens = _update_output_token_ids_for_bad_words(
+        sampling_metadata, VOCAB_SIZE)
+    sampler = Sampler()
+    logits = sampler.apply_bad_words(fake_logits, sampling_metadata)
+    assert logits.dtype == logits_dtype
+    assert logits.device.type == device
+    logits = logits.cpu()
+    for batch_idx in range(batch_size):
+        logits_for_req = logits[batch_idx]
+        for token_id in range(VOCAB_SIZE):
+            if (batch_idx in bad_words_last_tokens
+                    and token_id in bad_words_last_tokens[batch_idx]):
+                assert logits_for_req[token_id] == -float("inf")
+            else:
+                assert logits_for_req[token_id] != -float("inf")