File tree Expand file tree Collapse file tree 3 files changed +10
-7
lines changed Expand file tree Collapse file tree 3 files changed +10
-7
lines changed Original file line number Diff line number Diff line change @@ -124,8 +124,9 @@ def _construct_expected_sampling_metadata(
124
124
if req .sampling_params .allowed_token_ids :
125
125
allowed_token_ids_mask [index_in_input_batch ][
126
126
req .sampling_params .allowed_token_ids ] = True
127
- bad_words_token_ids [
128
- index_in_input_batch ] = req .sampling_params .bad_words_token_ids
127
+ if req .sampling_params .bad_words_token_ids :
128
+ bad_words_token_ids [
129
+ index_in_input_batch ] = req .sampling_params .bad_words_token_ids
129
130
130
131
return SamplingMetadata (
131
132
temperature = torch .tensor (temperature , dtype = torch .float ,
Original file line number Diff line number Diff line change @@ -235,7 +235,7 @@ class SamplingParams(
235
235
236
236
# Fields used for bad words
237
237
bad_words : Optional [list [str ]] = None
238
- _bad_words_token_ids : list [list [int ]] = msgspec . field ( default_factory = list )
238
+ _bad_words_token_ids : Optional [ list [list [int ]]] = None
239
239
240
240
@staticmethod
241
241
def from_optional (
@@ -464,8 +464,9 @@ def update_from_generation_config(
464
464
self .stop_token_ids = list (eos_ids )
465
465
466
466
def update_from_tokenizer (self , tokenizer : AnyTokenizer ) -> None :
467
- if self .bad_words is None :
467
+ if not self .bad_words :
468
468
return
469
+ self ._bad_words_token_ids = []
469
470
for bad_word in self .bad_words :
470
471
# To prohibit words both at the beginning
471
472
# and in the middle of text
@@ -516,7 +517,7 @@ def all_stop_token_ids(self) -> set[int]:
516
517
return self ._all_stop_token_ids
517
518
518
519
@property
519
- def bad_words_token_ids (self ) -> list [list [int ]]:
520
+ def bad_words_token_ids (self ) -> Optional [ list [list [int ] ]]:
520
521
# For internal use only. Backward compatibility not guaranteed
521
522
return self ._bad_words_token_ids
522
523
Original file line number Diff line number Diff line change @@ -324,8 +324,9 @@ def add_request(
324
324
self .allowed_token_ids_mask_cpu_tensor [req_index ][
325
325
sampling_params .allowed_token_ids ] = False
326
326
327
- self .bad_words_token_ids [
328
- req_index ] = sampling_params .bad_words_token_ids
327
+ if sampling_params .bad_words_token_ids :
328
+ self .bad_words_token_ids [
329
+ req_index ] = sampling_params .bad_words_token_ids
329
330
330
331
# Add request lora ID
331
332
if request .lora_request :
You can’t perform that action at this time.
0 commit comments