
Commit 6b772b0

fix: Add input_embeds arg to all flash model impls

This means that recent changes to support input_embeds for santacoder (bigcode_gpt) don't break usage of the other flash model impls.
1 parent 1e8ed0c commit 6b772b0
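
The same three-line guard lands in each of the flash llama, neox, and rw modeling files below: forward gains an inputs_embeds parameter so every flash impl shares the santacoder calling convention, and any non-None value is rejected for now. A minimal sketch of the pattern follows; the class and signature are heavily simplified, and only the inputs_embeds parameter and the guard mirror this commit — everything else is illustrative:

from typing import Optional

import torch


class FlashLlamaForCausalLM:  # simplified stand-in for the real module
    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,  # new: accepted but unused
        past_key_values: Optional[torch.Tensor] = None,
    ):
        # Guard from this commit: fail with a clear ValueError rather than
        # breaking on an unexpected argument from the shared caller.
        if inputs_embeds is not None:
            raise ValueError("input_embeds not yet supported for flash llama")
        # ... the real model would run the transformer and LM head here ...
        return input_ids


# The token-ID path is unaffected; passing embeddings now fails loudly:
model = FlashLlamaForCausalLM()
model.forward(torch.tensor([1, 2]), torch.tensor([0, 1]))
try:
    model.forward(torch.tensor([1]), torch.tensor([0]),
                  inputs_embeds=torch.zeros(1, 4096))
except ValueError as err:
    print(err)  # input_embeds not yet supported for flash llama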

File tree

4 files changed: +13 -1 lines changed


server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

Lines changed: 4 additions & 0 deletions
@@ -489,10 +489,14 @@ def forward(
         cu_seqlens,
         cu_seqlens_q,
         max_s,
+        inputs_embeds: Optional[torch.Tensor] = None,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
     ):
+        if inputs_embeds is not None:
+            raise ValueError("input_embeds not yet supported for flash llama")
+
         hidden_states, present = self.model(
             input_ids,
             position_ids,

server/text_generation_server/models/custom_modeling/flash_neox_modeling.py

Lines changed: 4 additions & 0 deletions
@@ -398,10 +398,14 @@ def forward(
         cu_seqlens,
         cu_seqlens_q,
         max_s,
+        inputs_embeds: Optional[torch.Tensor] = None,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
     ):
+        if inputs_embeds is not None:
+            raise ValueError("input_embeds not yet supported for flash neox")
+
         hidden_states, present = self.gpt_neox(
             input_ids,
             position_ids,

server/text_generation_server/models/custom_modeling/flash_rw_modeling.py

Lines changed: 4 additions & 0 deletions
@@ -634,10 +634,14 @@ def forward(
         cu_seqlens,
         cu_seqlens_q,
         max_s,
+        inputs_embeds: Optional[torch.Tensor] = None,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
     ):
+        if inputs_embeds is not None:
+            raise ValueError("input_embeds not yet supported for flash rw (falcon)")
+
         hidden_states, present = self.transformer(
             input_ids,
             position_ids,

server/text_generation_server/models/flash_causal_lm.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ class FlashCausalLMBatch(Batch):
     input_ids: torch.Tensor
     position_ids: torch.Tensor
     # shape is [sum(seq_lengths), embedding_size]
-    inputs_embeds: torch.Tensor
+    inputs_embeds: Optional[torch.Tensor]
     # cumulative sequence lengths
     cu_seqlens: torch.Tensor
     # cumulative query sequence lengths, only used in decode
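
The batch-side change above is the counterpart: inputs_embeds on FlashCausalLMBatch becomes Optional, presumably because only santacoder (bigcode_gpt) batches actually populate it while batches for the other flash impls leave it unset. A runnable sketch with the field set trimmed to the relevant ones; the real class carries many more fields:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class FlashCausalLMBatch:  # trimmed stand-in for the real batch class
    input_ids: torch.Tensor
    position_ids: torch.Tensor
    # shape is [sum(seq_lengths), embedding_size]; presumably None for
    # every flash impl except santacoder (bigcode_gpt)
    inputs_embeds: Optional[torch.Tensor]
    # cumulative sequence lengths
    cu_seqlens: torch.Tensor


# A non-santacoder batch can now simply carry None instead of being
# forced to materialize an embeddings tensor:
batch = FlashCausalLMBatch(
    input_ids=torch.tensor([5, 6, 7]),
    position_ids=torch.tensor([0, 1, 2]),
    inputs_embeds=None,
    cu_seqlens=torch.tensor([0, 3]),
)
assert batch.inputs_embeds is None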
