@@ -90,6 +90,27 @@ def _create_bad_words_token_ids(
     return bad_words_token_ids
 
 
+# Returns the last token of every bad word sequence whose prefix
+# (all tokens except the last) equals `given_prefix`.
+def _collect_suffixes_with_same_prefix(
+        given_prefix: list[int],
+        bad_words_token_ids: list[list[int]]) -> list[int]:
+    return [bwt[-1] for bwt in bad_words_token_ids if bwt[:-1] == given_prefix]
+
+
+# Generate a valid token id that does not start any bad word sequence
+def _generate_valid_token_id(bad_words_token_ids: list[list[int]],
+                             vocab_size: int) -> int:
+    forbidden_start_tokens = set()
+    for bad_word in bad_words_token_ids:
+        forbidden_start_tokens.add(bad_word[0])
+    # Get a safe token that's not in forbidden starts
+    safe_token_candidates = list(
+        set(range(vocab_size)) - forbidden_start_tokens)
+    # Pick a random safe token
+    return np.random.choice(safe_token_candidates)
+
+
 def _update_output_token_ids_for_bad_words(
         metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]:
     bad_words_last_tokens = {}
@@ -104,12 +125,17 @@ def _update_output_token_ids_for_bad_words(
                 prefix_length = len(bad_word_token_ids) - 1
                 has_bad_words = np.random.choice([True, False])
                 if has_bad_words:
-                    output_token_ids[-prefix_length:] = bad_word_token_ids[:-1]
-                    bad_words_last_token.append(bad_word_token_ids[-1])
+                    prefix = bad_word_token_ids[:-1]
+                    output_token_ids[-prefix_length:] = prefix
+                    # Collect the last tokens of all bad words
+                    # (including this one) that share this prefix
+                    bad_words_last_token.extend(
+                        _collect_suffixes_with_same_prefix(
+                            prefix, bad_words_token_ids))
                     break  # Maximum one update to output_token_ids
                 else:  # Make sure no accidental match to bad words
-                    output_token_ids[-1] = (bad_word_token_ids[-2] +
-                                            1) % vocab_size
+                    output_token_ids[-1] = _generate_valid_token_id(
+                        bad_words_token_ids, vocab_size)
         bad_words_last_tokens[batch_idx] = bad_words_last_token
     return bad_words_last_tokens
 
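The motivation for `_collect_suffixes_with_same_prefix` is that two bad words can share the same prefix, so once that prefix has been planted in the output, every matching last token must be masked, not just the one belonging to the bad word that planted it. A minimal sketch of this behavior, using a hypothetical bad word list (the token ids below are illustration only, not from the test suite):

```python
def _collect_suffixes_with_same_prefix(
        given_prefix: list[int],
        bad_words_token_ids: list[list[int]]) -> list[int]:
    return [bwt[-1] for bwt in bad_words_token_ids if bwt[:-1] == given_prefix]

# [5, 6] and [5, 7] share the prefix [5]; the 1-token word [9] has an
# empty prefix and is ignored. Both 6 and 7 must therefore be masked
# once the output ends in 5.
bad_words_token_ids = [[5, 6], [5, 7], [9]]
assert _collect_suffixes_with_same_prefix([5], bad_words_token_ids) == [6, 7]
```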
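Likewise, the switch from `(bad_word_token_ids[-2] + 1) % vocab_size` to `_generate_valid_token_id` closes a collision window: incrementing one bad word's second-to-last token can land exactly on the first token of another bad word, accidentally forming a prefix match. A sketch of the failure mode, again under hypothetical token ids:

```python
import numpy as np

vocab_size = 32
# Hypothetical collision: 4 + 1 == 5, the first token of the second bad word.
bad_words_token_ids = [[4, 10], [5, 11]]

old_token = (bad_words_token_ids[0][-2] + 1) % vocab_size
assert old_token == bad_words_token_ids[1][0]  # accidental prefix match

# The new helper instead samples only from tokens that start no bad word.
forbidden_start_tokens = {bw[0] for bw in bad_words_token_ids}
safe_token_candidates = list(set(range(vocab_size)) - forbidden_start_tokens)
new_token = int(np.random.choice(safe_token_candidates))
assert new_token not in forbidden_start_tokens
```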