Commit 1e8ed0c

Simplify handling of truncation with add_bos_token
In the non-flash causal_lm and seq2seq_lm cases, move the truncation / bos_token insertion logic before the embeddings lookup so that special handling of the BOS embedding isn't needed. Also update the changelog with recent updates.

Signed-off-by: Nick Hill <[email protected]>
1 parent 68b61db commit 1e8ed0c
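The heart of the change: truncation masking and BOS insertion now happen directly on the token ids, before any embedding lookup, so the looked-up embeddings are already correct and no cached BOS embedding is needed. A minimal sketch of that pattern with a toy single-request batch (tensor names mirror the diff; the values, pad/BOS ids and offsets are made up for illustration):

```python
import torch

pad_id, bos_id = 0, 1
padding_right_offset = 2   # attention-mask slots reserved for tokens generated later
orig_input_length = 3      # length the request was truncated to

# One request whose prompt tokenized to 5 tokens; the attention mask is
# allocated with extra room on the right for future generated tokens.
all_input_ids = torch.tensor([[11, 12, 13, 14, 15]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0]])

i = 0  # index of the truncated request within the batch
# Keep attention only on the last `orig_input_length` prompt tokens.
attention_mask[i, :-orig_input_length - padding_right_offset] = 0
# Overwrite the masked-out ids with padding.
all_input_ids[i, :-orig_input_length] = pad_id
# If the tokenizer normally adds BOS, make the first kept token BOS.
all_input_ids[i, -orig_input_length] = bos_id

print(all_input_ids)   # tensor([[ 0,  0,  1, 14, 15]])
print(attention_mask)  # tensor([[0, 0, 1, 1, 1, 0, 0]])
```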

File tree

server/text_generation_server/models/causal_lm.py
server/text_generation_server/models/model.py
server/text_generation_server/models/seq2seq_lm.py
server/text_generation_server/prompt_cache.py

4 files changed (+26, -44 lines)


server/text_generation_server/models/causal_lm.py

Lines changed: 12 additions & 18 deletions
@@ -140,6 +140,18 @@ def from_pb(
         # Copy tokenizer attention_mask into fully allocated attention_mask
         attention_mask[:, :tokenize_length] = tokenized_inputs["attention_mask"]
 
+        # Mask out truncated tokens
+        # (input_texts aren't truncated, only input_lengths are)
+        if truncate_indices:
+            add_bos_token = getattr(tokenizer, "add_bos_token", False)
+            for i in truncate_indices:
+                orig_input_length = requests[i].input_length
+                attention_mask[i, :-orig_input_length-padding_right_offset] = 0
+                all_input_ids[i, :-orig_input_length] = tokenizer.pad_token_id
+                if add_bos_token:
+                    # Ensure that first non-virtual token is set to BOS
+                    all_input_ids[i, -orig_input_length] = tokenizer.bos_token_id
+
         if prefix_ids:
             # Get input embeddings
             inputs_embeds = embeddings_lookup(all_input_ids)
@@ -155,24 +167,6 @@ def from_pb(
             input_ids = all_input_ids
             inputs_embeds = None
 
-        # Mask out truncated tokens
-        # (input_texts aren't truncated, only input_lengths are)
-        if truncate_indices:
-            add_bos_token = getattr(tokenizer, "add_bos_token", False)
-            for i in truncate_indices:
-                input_length = input_lengths[i]
-                attention_mask[i, :-input_length-padding_right_offset] = 0
-                if inputs_embeds is not None:
-                    inputs_embeds[i, :-input_length, :] = 0
-                    if add_bos_token:
-                        p = prefix_ids.get(i)
-                        orig_length = input_length if p is None else input_length - p.shape[0]
-                        inputs_embeds[i, -orig_length] = prefix_cache.bos_embedding
-                else:
-                    input_ids[i, :-input_length] = tokenizer.pad_token_id
-                    if add_bos_token:
-                        input_ids[i, -input_length] = tokenizer.bos_token_id
-
         if use_position_ids:
             # Fix up position ids
             sliced_attention_mask = attention_mask[:, :-padding_right_offset]
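The surviving context at the end of this hunk slices off the right padding before "fixing up" position ids. The repository's exact fix-up isn't shown in the diff, but a common way to derive position ids from a left-padded (and now partially re-masked) attention mask is sketched below; treat the details as an assumption rather than this codebase's implementation:

```python
import torch

# Hypothetical illustration: positions count only attended tokens, so padded
# or truncated slots on the left don't shift the real token positions.
sliced_attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                      [1, 1, 1, 1, 1]])

position_ids = sliced_attention_mask.cumsum(dim=-1) - 1
position_ids.masked_fill_(sliced_attention_mask == 0, 0)

print(position_ids)
# tensor([[0, 0, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```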

server/text_generation_server/models/model.py

Lines changed: 1 addition & 5 deletions
@@ -48,10 +48,9 @@ def __init__(self, engine: BaseInferenceEngine, dtype: torch.dtype):
 
         if prompt_prefix_supported:
             # Set up prefix cache
-            bos_token_id = getattr(self.tokenizer, "bos_token_id", None)
             decoder_start_token_id = self.model.config.decoder_start_token_id
             if decoder_start_token_id is None:
-                decoder_start_token_id = bos_token_id
+                decoder_start_token_id = self.tokenizer.bos_token_id
             self.prefix_cache = PrefixCache(
                 device=self.device,
                 dtype=dtype,
@@ -60,9 +59,6 @@ def __init__(self, engine: BaseInferenceEngine, dtype: torch.dtype):
                 decoder_start_tok_embedding=self.word_embeddings(
                     torch.tensor([decoder_start_token_id], device=self.device, dtype=torch.long)
                 ) if decoder_start_token_id is not None else None,
-                bos_embedding=self.word_embeddings(
-                    torch.tensor([bos_token_id], device=self.device, dtype=torch.long)
-                ) if bos_token_id is not None else None,
             )
         else:
             self.prefix_cache = None

server/text_generation_server/models/seq2seq_lm.py

Lines changed: 12 additions & 17 deletions
@@ -150,6 +150,18 @@ def from_pb(
         input_ids = tokenized_inputs["input_ids"]
         attention_mask = tokenized_inputs["attention_mask"]
 
+        # Mask out truncated tokens
+        # (input_texts aren't truncated, only input_lengths are)
+        if truncate_indices:
+            add_bos_token = getattr(tokenizer, "add_bos_token", False)
+            for i in truncate_indices:
+                orig_input_length = requests[i].input_length
+                attention_mask[i, :-orig_input_length] = 0
+                input_ids[i, :-orig_input_length] = tokenizer.pad_token_id
+                if add_bos_token:
+                    # Ensure that first non-virtual token is set to BOS
+                    input_ids[i, -orig_input_length] = tokenizer.bos_token_id
+
         if encoder_prefix_ids:
             # Get input embeddings
             inputs_embeds = embeddings_lookup(input_ids)
@@ -163,23 +175,6 @@ def from_pb(
         else:
             inputs_embeds = None
 
-        # Mask out truncated tokens
-        # (input_texts aren't truncated, only input_lengths are)
-        if truncate_indices:
-            for i in truncate_indices:
-                add_bos_token = getattr(tokenizer, "add_bos_token", False)
-                input_length = input_lengths[i]
-                attention_mask[i, :-input_length] = 0
-                input_ids[i, :-input_length] = tokenizer.pad_token_id
-                if add_bos_token:
-                    input_ids[i, -input_length] = tokenizer.bos_token_id
-                if inputs_embeds is not None:
-                    inputs_embeds[i, :-input_length, :] = 0
-                    if add_bos_token:
-                        p = encoder_prefix_ids.get(i)
-                        orig_length = input_length if p is None else input_length - p.shape[0]
-                        inputs_embeds[i, -orig_length] = prefix_cache.bos_embedding
-
         if decoder_prefix_ids:
             # Construct decoder embeddings and attention mask
             start_tok_embedding = prefix_cache.decoder_start_tok_embedding
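The new comment "Ensure that first non-virtual token is set to BOS" refers to the prompt-prefix ("virtual token") embeddings that a cached prefix occupies in front of the real input embeddings. The deleted code located that first real token with `orig_length = input_length - p.shape[0]`; the index arithmetic, sketched with a toy left-padded row, looks like this (layout illustrative only, not the repository's data structures):

```python
# Toy left-padded row: 2 pad slots, 3 virtual prefix slots, 4 real tokens.
prefix_len = 3
orig_length = 4                          # real (non-virtual) tokens only
input_length = prefix_len + orig_length  # what the old code called input_length

row = ["pad", "pad", "v0", "v1", "v2", "BOS", "tok", "tok", "tok"]

# -input_length is the first virtual slot, -orig_length the first real token,
# which is exactly the slot the old code overwrote with the BOS embedding and
# the new code now fills with bos_token_id before the embedding lookup.
assert row[-input_length] == "v0"
assert row[-orig_length] == "BOS"
```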

server/text_generation_server/prompt_cache.py

Lines changed: 1 addition & 4 deletions
@@ -150,18 +150,15 @@ def __init__(
         max_length: int,
         encoder_decoder: bool,
         decoder_start_tok_embedding: torch.Tensor,
-        bos_embedding: torch.Tensor,
     ):
         self.max_length = max_length
         self.embed_size = decoder_start_tok_embedding.shape[1] \
-            if decoder_start_tok_embedding is not None else \
-            (bos_embedding.shape[1] if bos_embedding is not None else None)
+            if decoder_start_tok_embedding is not None else None
         self.device: torch.device = device
         self.dtype = dtype
 
         self.is_encoder_decoder = encoder_decoder
         self.decoder_start_tok_embedding = decoder_start_tok_embedding
-        self.bos_embedding = bos_embedding
 
         self.cache_map: Dict[str, PromptCacheNode] = {}
         self.cache_dll: DoublyLinkedList = DoublyLinkedList()
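With the ids fixed up before the lookup, the embedding matrix itself produces the BOS row, which is why PrefixCache no longer needs to carry a precomputed bos_embedding. A small self-contained check of that equivalence (toy embedding table and ids; not the repository's code):

```python
import torch

vocab_size, embed_size = 10, 4
pad_id, bos_id = 0, 1
word_embeddings = torch.nn.Embedding(vocab_size, embed_size)

input_ids = torch.tensor([[5, 6, 7, 8]])
orig_input_length = 2   # truncated length for this request

# New approach: write pad/BOS ids first, then do a single embedding lookup.
input_ids[0, :-orig_input_length] = pad_id
input_ids[0, -orig_input_length] = bos_id
inputs_embeds = word_embeddings(input_ids)

# Old approach (removed by this commit): patch a cached BOS embedding in
# after the lookup. Both paths yield the same row of the embedding matrix.
bos_embedding = word_embeddings(torch.tensor([bos_id]))
assert torch.equal(inputs_embeds[0, -orig_input_length], bos_embedding[0])
```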
