Skip to content

Commit fc1f677

Browse files
authored
[BugFix][V1] Fix overhead related to bad_words sampling when not in use (#14894)
Signed-off-by: Nick Hill <[email protected]>
1 parent f6137ad commit fc1f677

File tree

3 files changed

+10
-7
lines changed

3 files changed

+10
-7
lines changed

tests/v1/worker/test_gpu_input_batch.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,9 @@ def _construct_expected_sampling_metadata(
124124
if req.sampling_params.allowed_token_ids:
125125
allowed_token_ids_mask[index_in_input_batch][
126126
req.sampling_params.allowed_token_ids] = True
127-
bad_words_token_ids[
128-
index_in_input_batch] = req.sampling_params.bad_words_token_ids
127+
if req.sampling_params.bad_words_token_ids:
128+
bad_words_token_ids[
129+
index_in_input_batch] = req.sampling_params.bad_words_token_ids
129130

130131
return SamplingMetadata(
131132
temperature=torch.tensor(temperature, dtype=torch.float,

vllm/sampling_params.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ class SamplingParams(
235235

236236
# Fields used for bad words
237237
bad_words: Optional[list[str]] = None
238-
_bad_words_token_ids: list[list[int]] = msgspec.field(default_factory=list)
238+
_bad_words_token_ids: Optional[list[list[int]]] = None
239239

240240
@staticmethod
241241
def from_optional(
@@ -464,8 +464,9 @@ def update_from_generation_config(
464464
self.stop_token_ids = list(eos_ids)
465465

466466
def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
467-
if self.bad_words is None:
467+
if not self.bad_words:
468468
return
469+
self._bad_words_token_ids = []
469470
for bad_word in self.bad_words:
470471
# To prohibit words both at the beginning
471472
# and in the middle of text
@@ -516,7 +517,7 @@ def all_stop_token_ids(self) -> set[int]:
516517
return self._all_stop_token_ids
517518

518519
@property
519-
def bad_words_token_ids(self) -> list[list[int]]:
520+
def bad_words_token_ids(self) -> Optional[list[list[int]]]:
520521
# For internal use only. Backward compatibility not guaranteed
521522
return self._bad_words_token_ids
522523

vllm/v1/worker/gpu_input_batch.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,8 +324,9 @@ def add_request(
324324
self.allowed_token_ids_mask_cpu_tensor[req_index][
325325
sampling_params.allowed_token_ids] = False
326326

327-
self.bad_words_token_ids[
328-
req_index] = sampling_params.bad_words_token_ids
327+
if sampling_params.bad_words_token_ids:
328+
self.bad_words_token_ids[
329+
req_index] = sampling_params.bad_words_token_ids
329330

330331
# Add request lora ID
331332
if request.lora_request:

0 commit comments

Comments (0)