File tree: 4 files changed (+13 -1 lines changed)
server/text_generation_server/models Expand file tree Collapse file tree 4 files changed +13
-1
lines changed Original file line number Diff line number Diff line change @@ -489,10 +489,14 @@ def forward(
489
489
cu_seqlens ,
490
490
cu_seqlens_q ,
491
491
max_s ,
492
+ inputs_embeds : Optional [torch .Tensor ] = None ,
492
493
past_key_values : Optional [torch .Tensor ] = None ,
493
494
pre_allocate_past_size : Optional [int ] = None ,
494
495
lm_head_indices : Optional [torch .Tensor ] = None ,
495
496
):
497
+ if inputs_embeds is not None :
498
+ raise ValueError ("input_embeds not yet supported for flash llama" )
499
+
496
500
hidden_states , present = self .model (
497
501
input_ids ,
498
502
position_ids ,
Original file line number Diff line number Diff line change @@ -398,10 +398,14 @@ def forward(
398
398
cu_seqlens ,
399
399
cu_seqlens_q ,
400
400
max_s ,
401
+ inputs_embeds : Optional [torch .Tensor ] = None ,
401
402
past_key_values : Optional [torch .Tensor ] = None ,
402
403
pre_allocate_past_size : Optional [int ] = None ,
403
404
lm_head_indices : Optional [torch .Tensor ] = None ,
404
405
):
406
+ if inputs_embeds is not None :
407
+ raise ValueError ("input_embeds not yet supported for flash neox" )
408
+
405
409
hidden_states , present = self .gpt_neox (
406
410
input_ids ,
407
411
position_ids ,
Original file line number Diff line number Diff line change @@ -634,10 +634,14 @@ def forward(
634
634
cu_seqlens ,
635
635
cu_seqlens_q ,
636
636
max_s ,
637
+ inputs_embeds : Optional [torch .Tensor ] = None ,
637
638
past_key_values : Optional [torch .Tensor ] = None ,
638
639
pre_allocate_past_size : Optional [int ] = None ,
639
640
lm_head_indices : Optional [torch .Tensor ] = None ,
640
641
):
642
+ if inputs_embeds is not None :
643
+ raise ValueError ("input_embeds not yet supported for flash rw (falcon)" )
644
+
641
645
hidden_states , present = self .transformer (
642
646
input_ids ,
643
647
position_ids ,
Original file line number Diff line number Diff line change @@ -36,7 +36,7 @@ class FlashCausalLMBatch(Batch):
36
36
input_ids : torch .Tensor
37
37
position_ids : torch .Tensor
38
38
# shape is [sum(seq_lengths), embedding_size]
39
- inputs_embeds : torch .Tensor
39
+ inputs_embeds : Optional [ torch .Tensor ]
40
40
# cumulative sequence lengths
41
41
cu_seqlens : torch .Tensor
42
42
# cumulative query sequence lengths, only used in decode
You can’t perform that action at this time.
0 commit comments