Commit bc8372e

[Bugfix] Fix erroneous randomly generated cases in bad word testing (#22170)
Signed-off-by: phantomlei <[email protected]>
1 parent 8d17fa6 commit bc8372e

1 file changed: tests/v1/sample/test_sampler.py (30 additions, 4 deletions)
@@ -90,6 +90,27 @@ def _create_bad_words_token_ids(
     return bad_words_token_ids
 
 
+# Returns the last token of every bad word whose prefix (all tokens
+# but the last) equals `given_prefix`.
+def _collect_suffixes_with_same_prefix(
+        given_prefix: list[int],
+        bad_words_token_ids: list[list[int]]) -> list[int]:
+    return [bwt[-1] for bwt in bad_words_token_ids if bwt[:-1] == given_prefix]
+
+
+# Generate a valid token id that is not the first token of any bad word
+def _generate_valid_token_id(bad_words_token_ids: list[list[int]],
+                             vocab_size: int) -> int:
+    forbidden_start_tokens = set()
+    for bad_word in bad_words_token_ids:
+        forbidden_start_tokens.add(bad_word[0])
+    # Get a safe token that's not in forbidden starts
+    safe_token_candidates = list(
+        set(range(vocab_size)) - forbidden_start_tokens)
+    # Pick a random safe token
+    return np.random.choice(safe_token_candidates)
+
+
 def _update_output_token_ids_for_bad_words(
         metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]:
     bad_words_last_tokens = {}
@@ -104,12 +125,17 @@ def _update_output_token_ids_for_bad_words(
             prefix_length = len(bad_word_token_ids) - 1
             has_bad_words = np.random.choice([True, False])
             if has_bad_words:
-                output_token_ids[-prefix_length:] = bad_word_token_ids[:-1]
-                bad_words_last_token.append(bad_word_token_ids[-1])
+                prefix = bad_word_token_ids[:-1]
+                output_token_ids[-prefix_length:] = prefix
+                # Collect the last tokens of all bad words
+                # that share this prefix
+                bad_words_last_token.extend(
+                    _collect_suffixes_with_same_prefix(
+                        prefix, bad_words_token_ids))
                 break  # Maximum one update to output_token_ids
             else:  # Make sure no accidental match to bad words
-                output_token_ids[-1] = (bad_word_token_ids[-2] +
-                                        1) % vocab_size
+                output_token_ids[-1] = _generate_valid_token_id(
+                    bad_words_token_ids, vocab_size)
         bad_words_last_tokens[batch_idx] = bad_words_last_token
     return bad_words_last_tokens
 
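Why the prefix collection matters: if several randomly generated bad words share the same prefix, planting that prefix in the output means the sampler must mask the last token of every one of them, so the test has to expect all of those tokens, not just the last token of the word that was planted. A minimal sketch of the corrected expectation (token values made up for illustration):

# Two hypothetical bad words share the prefix [4, 8].
bad_words_token_ids = [[4, 8, 1], [4, 8, 9], [2, 5]]
prefix = [4, 8]

# Old expectation: only the planted word's last token, e.g. [1].
# New expectation: the last token of every bad word with this prefix.
expected = [bwt[-1] for bwt in bad_words_token_ids if bwt[:-1] == prefix]
assert expected == [1, 9]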

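The else branch had a similar hole: (bad_word_token_ids[-2] + 1) % vocab_size only steers away from the current bad word, and the incremented token can coincidentally start another bad word. A small sketch of such a collision, with made-up values:

import numpy as np

vocab_size = 10
bad_words_token_ids = [[5, 9], [6, 2]]

# Old scheme, while considering bad word [5, 9]:
last = (bad_words_token_ids[0][-2] + 1) % vocab_size
assert last == 6  # accidentally the prefix of bad word [6, 2]

# New scheme: choose any token that does not start a bad word.
forbidden = {bw[0] for bw in bad_words_token_ids}
safe = int(np.random.choice(list(set(range(vocab_size)) - forbidden)))
assert safe not in forbidden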