
Commit 68b61db

tjohnson31415 authored and njhill committed
feat: implement support for prompt prefixes for flash_causal_lm and santacoder
- add support for prompt embedding injection in flash_causal_lm.py
- add inputs_embeds support to the Flash Santacoder custom modeling code
- modify the discovery of the embedding layer from the model to make it work for GPTBigCode models

Signed-off-by: Travis Johnson <[email protected]>
1 parent 45bfd01 commit 68b61db
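
For context: a prompt prefix here is a cached tensor of soft-prompt embeddings with shape [prefix_length, hidden_size] that is spliced in front of the embedded prompt tokens before the first transformer layer. A minimal sketch of the idea, assuming illustrative names (wte, prefix_embeds) rather than the exact code in this commit:

import torch

hidden_size = 8
wte = torch.nn.Embedding(100, hidden_size)        # token embedding table
prefix_embeds = torch.randn(4, hidden_size)       # cached soft prompt, [prefix_length, hidden_size]

input_ids = torch.tensor([11, 12, 13])            # the user's tokenized prompt
token_embeds = wte(input_ids)                     # [3, hidden_size]

# With a prefix, the model is fed inputs_embeds instead of input_ids
inputs_embeds = torch.cat([prefix_embeds, token_embeds], dim=0)   # [7, hidden_size]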

File tree

3 files changed (+89, -45 lines):
- server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
- server/text_generation_server/models/flash_causal_lm.py
- server/text_generation_server/models/model.py

server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py

Lines changed: 13 additions & 1 deletion
@@ -327,10 +327,20 @@ def forward(
         cu_seqlens,
         cu_seqlens_q,
         max_s,
+        inputs_embeds: Optional[torch.Tensor] = None,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
     ):
-        hidden_states = self.wte(input_ids) + self.wpe(position_ids)
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time"
+            )
+
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds + self.wpe(position_ids)
+            # TODO: support TP for the position embeddings
+        else:
+            hidden_states = self.wte(input_ids) + self.wpe(position_ids)

         if self.process_group.size() > 1:
             torch.distributed.all_reduce(hidden_states, group=self.process_group)
@@ -396,6 +406,7 @@ def forward(
         cu_seqlens,
         cu_seqlens_q,
         max_s,
+        inputs_embeds: Optional[torch.Tensor] = None,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -406,6 +417,7 @@ def forward(
             cu_seqlens,
             cu_seqlens_q,
             max_s,
+            inputs_embeds,
             past_key_values,
             pre_allocate_past_size,
         )
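
A hedged sketch of how a caller might exercise the new argument. The parameters before cu_seqlens (input_ids, position_ids) are assumed from the removed wte/wpe line, and `model` stands in for a FlashSantacoder instance; this is not code from the commit:

# Token-id path (unchanged): embeddings come from the model's own wte lookup
out = model(
    input_ids, position_ids, cu_seqlens, cu_seqlens_q, max_s,
)

# Prefix path: pre-computed embeddings are passed instead, input_ids is None
out = model(
    None, position_ids, cu_seqlens, cu_seqlens_q, max_s,
    inputs_embeds=inputs_embeds,
)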

server/text_generation_server/models/flash_causal_lm.py

Lines changed: 71 additions & 44 deletions
@@ -31,8 +31,12 @@ class FlashCausalLMBatch(Batch):
     requests: List[generate_pb2.Request]

     # Decoder values
+    # tensors have sequences from the batch concatenated
+    # shape is [sum(seq_lengths)]
     input_ids: torch.Tensor
     position_ids: torch.Tensor
+    # shape is [sum(seq_lengths), embedding_size]
+    inputs_embeds: torch.Tensor
     # cumulative sequence lengths
     cu_seqlens: torch.Tensor
     # cumulative query sequence lengths, only used in decode
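
To make the packed layout in these comments concrete: sequences are concatenated with no padding between them, and cu_seqlens records the boundaries. A small illustrative example with made-up values:

# Two requests with 3 and 5 tokens respectively
input_ids    = torch.tensor([5, 7, 9,   5, 4, 8, 15, 16])     # shape [sum(seq_lengths)] == [8]
position_ids = torch.tensor([0, 1, 2,   0, 1, 2, 3, 4])
cu_seqlens   = torch.tensor([0, 3, 8], dtype=torch.int32)     # cumulative sequence lengths

second_request_ids = input_ids[cu_seqlens[1]:cu_seqlens[2]]   # -> [5, 4, 8, 15, 16]
# inputs_embeds, when present, has shape [sum(seq_lengths), embedding_size]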
@@ -68,77 +72,97 @@ def from_pb(
     ) -> Tuple[Optional["FlashCausalLMBatch"], List[GenerateError]]:
         errors = []
         batch_inputs = []
+        requests = pb.requests
+
+        # track indices of valid requests that have prefixes
+        i = 0
+        prefix_ids = {}
+        # compute sequence lengths in this loop too
+        # if there is a prefix, input_lengths will include its length
+        input_lengths = []
         max_seqlen = 0
-        for r in pb.requests:
+        # Cumulative length
+        cu_seqlens = [0]
+        cumulative_length = 0
+        for r in requests:
+            input_length = r.input_length
+            # TODO: Also fail depending on the model type for ones that don't
+            # have input_embeds implemented?
             if r.prefix_id:
-                message = f"Prompt prefixes not yet supported with flash attention (request #{r.id})"
-                logging.error(message)
-                # Exclude this request from the batch, return an error
-                errors.append(GenerateError(request_id=r.id, message=message))
-                continue
+                try:
+                    prefix_embeds = prefix_cache.get(r.prefix_id)
+                except Exception:
+                    message = f"Prefix lookup error for request #{r.id}, prefix id {r.prefix_id}"
+                    logging.error(message)
+                    # Exclude this request from the batch, return an error
+                    errors.append(GenerateError(request_id=r.id, message=message))
+                    continue
+                prefix_ids[i] = prefix_embeds
+                input_length += prefix_embeds.shape[0]
             batch_inputs.append(r.inputs)
-            max_seqlen = max(max_seqlen, r.input_length)
+            input_lengths.append(input_length)
+            max_seqlen = max(max_seqlen, input_length)
+            cumulative_length += input_length
+            cu_seqlens.append(cumulative_length)
+            i += 1

+        # remove errored requests
         if errors:
             requests = [r for r in pb.requests if not any(r.id == er.request_id for er in errors)]
+            # early exit if no requests are valid
             if not requests:
                 return None, errors

+        # return as lists to avoid unnecessary padding;
+        # sequences will be concatenated across the batch
         batch_tokenized_inputs = tokenizer(
             batch_inputs, truncation=True, max_length=max_seqlen, return_token_type_ids=False
         )["input_ids"]

+        # Process inputs to generate the needed tensors
         input_ids = []
         position_ids = []
-        cu_seqlens = [0]
-
-        input_lengths = []
         all_input_ids_tensor = []
-
         next_token_choosers = []
-
-        # Cumulative length
-        cumulative_length = 0
-
-        # Parse batch
-        requests = pb.requests
-        for r, tokenized_input in zip(requests, batch_tokenized_inputs):
-            input_length = r.input_length
-
-            tokenized_input = tokenized_input[-input_length:]
-
-            # Fill in bos token in truncation case if needed
-            if r.truncate and getattr(tokenizer, "add_bos_token", False):
-                tokenized_input[0] = tokenizer.bos_token_id
-
-            input_lengths.append(input_length)
-
+        for r, tokenized_input, input_length in zip(requests, batch_tokenized_inputs, input_lengths):
+            if r.truncate:
+                tokenized_input = tokenized_input[-r.input_length:]
+                # Fill in bos token in truncation case if needed
+                if getattr(tokenizer, "add_bos_token", False):
+                    tokenized_input[0] = tokenizer.bos_token_id
             tokenized_input = torch.tensor(tokenized_input, device=device)
-            input_ids.append(tokenized_input)
-
-            # Position ids
-            position_ids.append(torch.arange(0, input_length, dtype=torch.int32))
-
-            # Add cumulative lengths of all previous inputs
-            cu_seqlens.append(cumulative_length + input_length)
-
+            # LHS pad for prefix, if it exists; RHS pad to max output
+            padded_input_ids = F.pad(tokenized_input, (input_length - r.input_length, r.max_output_length))
+            all_input_ids_tensor.append(padded_input_ids)
+            # input_ids needs prefix padding but not output padding
+            input_ids.append(tokenized_input if input_length == r.input_length else padded_input_ids[:input_length])
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, r.details.logprobs, tokenizer, device)
             )
-            all_input_ids_tensor.append(F.pad(tokenized_input, (0, r.max_output_length)))
-
-            cumulative_length += input_length
-
+            position_ids.append(torch.arange(0, input_length, dtype=torch.int32))
         input_ids = torch.cat(input_ids)
-        position_ids = torch.cat(position_ids).to(device, non_blocking=True)
-        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
+
+        # convert all requests to embeddings if any request has a prefix_id
+        if prefix_ids:
+            # TODO: Handle TP distributed embeddings layer
+            inputs_embeds = embeddings_lookup(input_ids)
+            input_ids = None
+            # fill in the prefix embeddings into the space that we already
+            # allocated due to the padding in input_ids
+            for i, p in prefix_ids.items():
+                start = cu_seqlens[i]
+                prefix_length = p.shape[0]
+                inputs_embeds[start:start+prefix_length, :] = p
+        else:
+            inputs_embeds = None

         return cls(
             batch_id=pb.id,
             requests=requests,
             input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlens=cu_seqlens,
+            inputs_embeds=inputs_embeds,
+            position_ids=torch.cat(position_ids).to(device, non_blocking=True),
+            cu_seqlens=torch.tensor(cu_seqlens, dtype=torch.int32, device=device),
             cu_seqlens_q=None,
             max_seqlen=max_seqlen,
             past_key_values=None,
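
A worked example of the padding arithmetic in this hunk, assuming one request with 3 prompt tokens, a cached prefix of length 2, and max_output_length of 4 (the zeros on the left are placeholders that the prefix embeddings overwrite after the embedding lookup):

import torch
import torch.nn.functional as F

tokenized_input = torch.tensor([7, 9, 11])             # r.input_length == 3
prefix_length = 2                                      # prefix_embeds.shape[0]
input_length = 3 + prefix_length                       # length used for cu_seqlens / position_ids

# LHS pad reserves room for the prefix, RHS pad reserves room for generated tokens
padded_input_ids = F.pad(tokenized_input, (prefix_length, 4))   # -> [0, 0, 7, 9, 11, 0, 0, 0, 0]
input_ids_entry = padded_input_ids[:input_length]               # -> [0, 0, 7, 9, 11]

# After the embedding lookup, rows cu_seqlens[i] : cu_seqlens[i] + prefix_length of
# inputs_embeds are overwritten with the cached prefix embeddings.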
@@ -195,6 +219,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
             batch_id=batches[0].batch_id,
             requests=requests,
             input_ids=torch.cat(input_ids),
+            inputs_embeds=None,
             position_ids=torch.cat(position_ids),
             cu_seqlens=torch.cat(cu_seqlens),
             cu_seqlens_q=torch.arange(len(requests) + 1, device=device, dtype=torch.int32),
@@ -345,6 +370,7 @@ def generate_token(
             batch.cu_seqlens,
             batch.cu_seqlens_q,
             batch.max_seqlen,
+            batch.inputs_embeds,
             past_key_values,
             prealloc_length,
         )
@@ -410,6 +436,7 @@ def _process_prefill(
         # Create final next batch tensors
         batch.input_ids = torch.cat(next_batch_input_ids) \
             if batch_size > 1 else next_batch_input_ids[0].view(1)
+        batch.inputs_embeds = None

         batch.cu_seqlens_q = torch.arange(
             batch_size + 1, device=self.device, dtype=torch.int32

server/text_generation_server/models/model.py

Lines changed: 5 additions & 0 deletions
@@ -152,6 +152,11 @@ def get_indices_to_keep(
         return next_batch_keep_indices

     def _setup_prompt_encoder(self) -> bool:
+        # this is the most common name for the word embedding module for transformers models
+        if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'wte'):
+            self.word_embeddings = self.model.transformer.wte
+            return True
+
         vocab_size = getattr(self.model.config, "vocab_size", None)

         if vocab_size is not None and hasattr(self.model, "named_children"):
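
The short-circuit above handles models like GPTBigCode (santacoder) that expose the token embedding as model.transformer.wte, while the existing code below it falls back to scanning submodules against the vocab size. A hedged sketch of that two-step discovery pattern, not the exact fallback implementation:

import torch

def find_word_embeddings(model):
    # Preferred: the common transformers layout, e.g. GPT-2 / GPTBigCode
    if hasattr(model, "transformer") and hasattr(model.transformer, "wte"):
        return model.transformer.wte

    # Fallback: any nn.Embedding whose first dimension matches the vocab size
    vocab_size = getattr(model.config, "vocab_size", None)
    if vocab_size is None:
        return None
    for module in model.modules():
        if isinstance(module, torch.nn.Embedding) and module.num_embeddings == vocab_size:
            return module
    return None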
