
Commit 7d6ebcf

Fix for truncation option for models whose tokenizer adds bos_token

Notably includes Llama models.

Signed-off-by: Nick Hill <[email protected]>

1 parent 6488bb4 commit 7d6ebcf

File tree: 5 files changed, +43 -15 lines changed

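For context on the diffs below: tokenizers that set add_bos_token (notably Llama's) prepend bos_token_id to every encoding, so keeping only the last `truncate` tokens silently drops the BOS. A minimal sketch of the failure and of the fix, with made-up token ids and truncate value (not taken from the repository):

# Hypothetical ids; a real Llama tokenizer would produce its own.
bos_token_id = 1
tokenized_input = [1, 450, 4996, 17354, 1701, 29916]   # BOS followed by five prompt tokens

truncate = 3
tokenized_input = tokenized_input[-truncate:]           # [17354, 1701, 29916] -- BOS is gone

# The fix: write BOS back into the first retained slot so the model still
# sees a sequence that starts the way it did during training.
tokenized_input[0] = bos_token_id                       # [1, 1701, 29916]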

server/text_generation_server/models/causal_lm.py

Lines changed: 15 additions & 7 deletions

@@ -157,13 +157,21 @@ def from_pb(
 
         # Mask out truncated tokens
         # (input_texts aren't truncated, only input_lengths are)
-        for i in truncate_indices:
-            input_length = input_lengths[i]
-            attention_mask[i, :-input_length-padding_right_offset] = 0
-            if inputs_embeds is not None:
-                inputs_embeds[i, :-input_length, :] = 0
-            else:
-                input_ids[i, :-input_length] = tokenizer.pad_token_id
+        if truncate_indices:
+            add_bos_token = getattr(tokenizer, "add_bos_token", False)
+            for i in truncate_indices:
+                input_length = input_lengths[i]
+                attention_mask[i, :-input_length-padding_right_offset] = 0
+                if inputs_embeds is not None:
+                    inputs_embeds[i, :-input_length, :] = 0
+                    if add_bos_token:
+                        p = prefix_ids.get(i)
+                        orig_length = input_length if p is None else input_length - p.shape[0]
+                        inputs_embeds[i, -orig_length] = prefix_cache.bos_embedding
+                else:
+                    input_ids[i, :-input_length] = tokenizer.pad_token_id
+                    if add_bos_token:
+                        input_ids[i, -input_length] = tokenizer.bos_token_id
 
         if use_position_ids:
             # Fix up position ids
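A rough illustration of the index arithmetic in this hunk, assuming a left-padded batch with made-up ids (padding_right_offset is set to 0 to keep the example short):

import torch

pad_token_id, bos_token_id = 0, 1                      # assumed ids, not from the repo
input_ids = torch.tensor([[0, 0, 7, 8, 9, 4, 5, 6]])   # one left-padded request
input_length = 3                                        # request truncated to its last 3 tokens
padding_right_offset = 0

attention_mask = torch.ones_like(input_ids)
attention_mask[0, :-input_length - padding_right_offset] = 0   # mask everything left of the kept tokens
input_ids[0, :-input_length] = pad_token_id                    # blank out the truncated prefix
input_ids[0, -input_length] = bos_token_id                     # restore BOS at the first kept position

print(input_ids)   # tensor([[0, 0, 0, 0, 0, 1, 5, 6]])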

server/text_generation_server/models/flash_causal_lm.py

Lines changed: 4 additions & 0 deletions

@@ -107,6 +107,10 @@ def from_pb(
 
             tokenized_input = tokenized_input[-input_length:]
 
+            # Fill in bos token in truncation case if needed
+            if r.truncate and getattr(tokenizer, "add_bos_token", False):
+                tokenized_input[0] = tokenizer.bos_token_id
+
             input_lengths.append(input_length)
 
             tokenized_input = torch.tensor(tokenized_input, device=device)
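The same guard, sketched as stand-alone code with a Hugging Face tokenizer; the model name, prompt, and truncate value are illustrative only:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")   # any tokenizer with add_bos_token=True behaves the same way

truncate = 16
tokenized_input = tokenizer("a prompt long enough to need truncating ...")["input_ids"]

input_length = min(len(tokenized_input), truncate)
tokenized_input = tokenized_input[-input_length:]   # keep only the last `truncate` tokens

# Fill in bos token in truncation case if needed
if truncate and getattr(tokenizer, "add_bos_token", False):
    tokenized_input[0] = tokenizer.bos_token_id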

server/text_generation_server/models/model.py

Lines changed: 5 additions & 1 deletion

@@ -48,9 +48,10 @@ def __init__(self, engine: BaseInferenceEngine, dtype: torch.dtype):
 
         if prompt_prefix_supported:
             # Set up prefix cache
+            bos_token_id = getattr(self.tokenizer, "bos_token_id", None)
             decoder_start_token_id = self.model.config.decoder_start_token_id
             if decoder_start_token_id is None:
-                decoder_start_token_id = self.tokenizer.bos_token_id
+                decoder_start_token_id = bos_token_id
             self.prefix_cache = PrefixCache(
                 device=self.device,
                 dtype=dtype,
@@ -59,6 +60,9 @@ def __init__(self, engine: BaseInferenceEngine, dtype: torch.dtype):
                 decoder_start_tok_embedding=self.word_embeddings(
                     torch.tensor([decoder_start_token_id], device=self.device, dtype=torch.long)
                 ) if decoder_start_token_id is not None else None,
+                bos_embedding=self.word_embeddings(
+                    torch.tensor([bos_token_id], device=self.device, dtype=torch.long)
+                ) if bos_token_id is not None else None,
             )
         else:
             self.prefix_cache = None
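Roughly what the new bos_embedding argument holds, with a toy embedding table standing in for self.word_embeddings (vocabulary and hidden sizes are assumed):

import torch

vocab_size, hidden_size = 32, 8
word_embeddings = torch.nn.Embedding(vocab_size, hidden_size)

bos_token_id = 1
bos_embedding = word_embeddings(torch.tensor([bos_token_id], dtype=torch.long))
print(bos_embedding.shape)   # torch.Size([1, 8]); a single row, later written into inputs_embeds[i, -orig_length]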

server/text_generation_server/models/seq2seq_lm.py

Lines changed: 14 additions & 6 deletions

@@ -165,12 +165,20 @@ def from_pb(
 
         # Mask out truncated tokens
         # (input_texts aren't truncated, only input_lengths are)
-        for i in truncate_indices:
-            input_length = input_lengths[i]
-            attention_mask[i, :-input_length] = 0
-            input_ids[i, :-input_length] = tokenizer.pad_token_id
-            if inputs_embeds is not None:
-                inputs_embeds[i, :-input_length, :] = 0
+        if truncate_indices:
+            for i in truncate_indices:
+                add_bos_token = getattr(tokenizer, "add_bos_token", False)
+                input_length = input_lengths[i]
+                attention_mask[i, :-input_length] = 0
+                input_ids[i, :-input_length] = tokenizer.pad_token_id
+                if add_bos_token:
+                    input_ids[i, -input_length] = tokenizer.bos_token_id
+                if inputs_embeds is not None:
+                    inputs_embeds[i, :-input_length, :] = 0
+                    if add_bos_token:
+                        p = encoder_prefix_ids.get(i)
+                        orig_length = input_length if p is None else input_length - p.shape[0]
+                        inputs_embeds[i, -orig_length] = prefix_cache.bos_embedding
 
         if decoder_prefix_ids:
             # Construct decoder embeddings and attention mask
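A hedged note on the orig_length arithmetic above, with made-up lengths: when a cached prompt prefix was prepended for request i, input_length counts the prefix tokens too, so -orig_length points at the first of the user's own tokens rather than the start of the prefix.

input_length = 10   # kept tokens for request i, including an injected prompt prefix
prefix_len = 4      # p.shape[0] when encoder_prefix_ids.get(i) is not None

orig_length = input_length - prefix_len   # 6 tokens came from the user's (truncated) text
# inputs_embeds[i, -orig_length] is the first user token after the prefix,
# which is where prefix_cache.bos_embedding is written.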

server/text_generation_server/prompt_cache.py

Lines changed: 5 additions & 1 deletion

@@ -139,6 +139,7 @@ def move_node_to_head(self, cache_node: PromptCacheNode):
         cache_node.prev = None
         self.head = cache_node
 
+
 class PrefixCache:
     """Holds the cache of injectable prompts for a single model.
     """
@@ -149,15 +150,18 @@ def __init__(
         max_length: int,
         encoder_decoder: bool,
         decoder_start_tok_embedding: torch.Tensor,
+        bos_embedding: torch.Tensor,
     ):
         self.max_length = max_length
         self.embed_size = decoder_start_tok_embedding.shape[1] \
-            if decoder_start_tok_embedding is not None else None
+            if decoder_start_tok_embedding is not None else \
+            (bos_embedding.shape[1] if bos_embedding is not None else None)
         self.device: torch.device = device
         self.dtype = dtype
 
         self.is_encoder_decoder = encoder_decoder
         self.decoder_start_tok_embedding = decoder_start_tok_embedding
+        self.bos_embedding = bos_embedding
 
         self.cache_map: Dict[str, PromptCacheNode] = {}
         self.cache_dll: DoublyLinkedList = DoublyLinkedList()
