@@ -31,8 +31,12 @@ class FlashCausalLMBatch(Batch):
     requests: List[generate_pb2.Request]

     # Decoder values
+    # tensors have sequences from the batch concatenated
+    # shape is [sum(seq_lengths)]
     input_ids: torch.Tensor
     position_ids: torch.Tensor
+    # shape is [sum(seq_lengths), embedding_size]
+    inputs_embeds: torch.Tensor
     # cumulative sequence lengths
     cu_seqlens: torch.Tensor
     # cumulative query sequence lengths, only used in decode
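For orientation, here is a minimal sketch (with made-up sequence lengths; none of these names come from the model code) of how the concatenated layout relates to cu_seqlens:

    import torch

    # hypothetical batch of three sequences with lengths 4, 2 and 3
    seq_lengths = [4, 2, 3]

    # input_ids holds all tokens back to back: shape [sum(seq_lengths)] == [9]
    input_ids = torch.arange(sum(seq_lengths))

    # cu_seqlens holds each sequence's start offset plus the total length
    cu_seqlens = torch.tensor([0, 4, 6, 9], dtype=torch.int32)

    # tokens of sequence i live in input_ids[cu_seqlens[i]:cu_seqlens[i + 1]]
    second_sequence = input_ids[cu_seqlens[1]:cu_seqlens[2]]  # 2 tokens

inputs_embeds follows the same layout, with an extra embedding_size dimension per token.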
@@ -68,77 +72,97 @@ def from_pb(
     ) -> Tuple[Optional["FlashCausalLMBatch"], List[GenerateError]]:
         errors = []
         batch_inputs = []
+        requests = pb.requests
+
+        # track indices of valid requests that have prefixes
+        i = 0
+        prefix_ids = {}
+        # compute sequence lengths in this loop too
+        # if there is a prefix, input_lengths will include its length
+        input_lengths = []
         max_seqlen = 0
-        for r in pb.requests:
+        # Cumulative length
+        cu_seqlens = [0]
+        cumulative_length = 0
+        for r in requests:
+            input_length = r.input_length
+            # TODO: Also fail depending on the model type for ones that don't
+            # have input_embeds implemented?
             if r.prefix_id:
-                message = f"Prompt prefixes not yet supported with flash attention (request #{r.id})"
-                logging.error(message)
-                # Exclude this request from the batch, return an error
-                errors.append(GenerateError(request_id=r.id, message=message))
-                continue
+                try:
+                    prefix_embeds = prefix_cache.get(r.prefix_id)
+                except Exception:
+                    message = f"Prefix lookup error for request #{r.id}, prefix id {r.prefix_id}"
+                    logging.error(message)
+                    # Exclude this request from the batch, return an error
+                    errors.append(GenerateError(request_id=r.id, message=message))
+                    continue
+                prefix_ids[i] = prefix_embeds
+                input_length += prefix_embeds.shape[0]
             batch_inputs.append(r.inputs)
-            max_seqlen = max(max_seqlen, r.input_length)
+            input_lengths.append(input_length)
+            max_seqlen = max(max_seqlen, input_length)
+            cumulative_length += input_length
+            cu_seqlens.append(cumulative_length)
+            i += 1

+        # remove errored requests
         if errors:
             requests = [r for r in pb.requests if not any(r.id == er.request_id for er in errors)]
+            # early exit if no requests are valid
             if not requests:
                 return None, errors

+        # return as lists to avoid unnecessary padding;
+        # sequences will be concatenated across the batch
         batch_tokenized_inputs = tokenizer(
             batch_inputs, truncation=True, max_length=max_seqlen, return_token_type_ids=False
         )["input_ids"]

+        # Process inputs to generate the needed tensors
         input_ids = []
         position_ids = []
-        cu_seqlens = [0]
-
-        input_lengths = []
         all_input_ids_tensor = []
-
         next_token_choosers = []
-
-        # Cumulative length
-        cumulative_length = 0
-
-        # Parse batch
-        requests = pb.requests
-        for r, tokenized_input in zip(requests, batch_tokenized_inputs):
-            input_length = r.input_length
-
-            tokenized_input = tokenized_input[-input_length:]
-
-            # Fill in bos token in truncation case if needed
-            if r.truncate and getattr(tokenizer, "add_bos_token", False):
-                tokenized_input[0] = tokenizer.bos_token_id
-
-            input_lengths.append(input_length)
-
+        for r, tokenized_input, input_length in zip(requests, batch_tokenized_inputs, input_lengths):
+            if r.truncate:
+                tokenized_input = tokenized_input[-r.input_length:]
+                # Fill in bos token in truncation case if needed
+                if getattr(tokenizer, "add_bos_token", False):
+                    tokenized_input[0] = tokenizer.bos_token_id
             tokenized_input = torch.tensor(tokenized_input, device=device)
-            input_ids.append(tokenized_input)
-
-            # Position ids
-            position_ids.append(torch.arange(0, input_length, dtype=torch.int32))
-
-            # Add cumulative lengths of all previous inputs
-            cu_seqlens.append(cumulative_length + input_length)
-
+            # LHS pad for prefix, if it exists; RHS pad to max output
+            padded_input_ids = F.pad(tokenized_input, (input_length - r.input_length, r.max_output_length))
+            all_input_ids_tensor.append(padded_input_ids)
+            # input_ids needs prefix padding but not output padding
+            input_ids.append(tokenized_input if input_length == r.input_length else padded_input_ids[:input_length])
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, r.details.logprobs, tokenizer, device)
             )
-            all_input_ids_tensor.append(F.pad(tokenized_input, (0, r.max_output_length)))
-
-            cumulative_length += input_length
-
+            position_ids.append(torch.arange(0, input_length, dtype=torch.int32))
         input_ids = torch.cat(input_ids)
-        position_ids = torch.cat(position_ids).to(device, non_blocking=True)
-        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
+
+        # convert all requests to embeddings if any request has a prefix_id
+        if prefix_ids:
+            # TODO: Handle TP distributed embeddings layer
+            inputs_embeds = embeddings_lookup(input_ids)
+            input_ids = None
+            # fill in the prefix embeddings into the space that we already
+            # allocated due to the padding in input_ids
+            for i, p in prefix_ids.items():
+                start = cu_seqlens[i]
+                prefix_length = p.shape[0]
+                inputs_embeds[start:start + prefix_length, :] = p
+        else:
+            inputs_embeds = None

         return cls(
             batch_id=pb.id,
             requests=requests,
             input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlens=cu_seqlens,
+            inputs_embeds=inputs_embeds,
+            position_ids=torch.cat(position_ids).to(device, non_blocking=True),
+            cu_seqlens=torch.tensor(cu_seqlens, dtype=torch.int32, device=device),
             cu_seqlens_q=None,
             max_seqlen=max_seqlen,
             past_key_values=None,
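To make the padding-and-fill logic above easier to follow, here is a compressed single-sequence sketch with invented sizes and a stand-in embedding table (embeddings_lookup and prefix_cache are not reproduced); only the shapes and offsets mirror the diff:

    import torch
    import torch.nn.functional as F

    embed_dim = 8
    prefix_embeds = torch.randn(3, embed_dim)          # cached prefix of 3 virtual tokens
    tokenized_input = torch.tensor([11, 12, 13, 14])   # 4 real prompt tokens
    input_length = prefix_embeds.shape[0] + tokenized_input.shape[0]  # 7

    # LHS pad reserves slots for the prefix, RHS pad reserves slots for generated tokens
    max_output_length = 2
    padded_input_ids = F.pad(tokenized_input, (prefix_embeds.shape[0], max_output_length))

    # stand-in for the model's embedding lookup
    embedding_table = torch.randn(100, embed_dim)
    inputs_embeds = embedding_table[padded_input_ids[:input_length]]

    # overwrite the reserved left-padding slots with the prefix embeddings;
    # in the batch case the start offset is cu_seqlens[i] rather than 0
    inputs_embeds[:prefix_embeds.shape[0], :] = prefix_embeds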
@@ -195,6 +219,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
             batch_id=batches[0].batch_id,
             requests=requests,
             input_ids=torch.cat(input_ids),
+            inputs_embeds=None,
             position_ids=torch.cat(position_ids),
             cu_seqlens=torch.cat(cu_seqlens),
             cu_seqlens_q=torch.arange(len(requests) + 1, device=device, dtype=torch.int32),
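Batches are only concatenated after they have run prefill, so there are no prefix embeddings left to carry over, hence inputs_embeds=None here (my reading of the surrounding code). The pre-existing cu_seqlens_q line encodes that every sequence contributes exactly one query token per decode step; a tiny illustration, assuming three live requests:

    import torch

    num_requests = 3
    # one new query token per sequence -> cumulative query lengths 0, 1, 2, 3
    cu_seqlens_q = torch.arange(num_requests + 1, dtype=torch.int32)
    assert cu_seqlens_q.tolist() == [0, 1, 2, 3]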
@@ -345,6 +370,7 @@ def generate_token(
             batch.cu_seqlens,
             batch.cu_seqlens_q,
             batch.max_seqlen,
+            batch.inputs_embeds,
             past_key_values,
             prealloc_length,
         )
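The forward call now receives batch.inputs_embeds in addition to batch.input_ids. The model-side change is not part of this hunk; a plausible, purely hypothetical pattern is to skip the token-embedding lookup whenever precomputed embeddings are provided:

    from typing import Optional

    import torch

    def embed_inputs(embed_tokens: torch.nn.Embedding,
                     input_ids: Optional[torch.Tensor],
                     inputs_embeds: Optional[torch.Tensor]) -> torch.Tensor:
        # hypothetical helper: prefer precomputed embeddings, otherwise look up token ids
        if inputs_embeds is not None:
            return inputs_embeds
        return embed_tokens(input_ids)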
@@ -410,6 +436,7 @@ def _process_prefill(
         # Create final next batch tensors
         batch.input_ids = torch.cat(next_batch_input_ids) \
             if batch_size > 1 else next_batch_input_ids[0].view(1)
+        batch.inputs_embeds = None

         batch.cu_seqlens_q = torch.arange(
             batch_size + 1, device=self.device, dtype=torch.int32