@@ -68,14 +68,19 @@ class Mamba2AttentionMetadata:
     query_start_loc: torch.Tensor
     seq_lens: torch.Tensor
 
-    has_initial_states: torch.Tensor
     prep_initial_states: bool
     chunk_size: int
-    seq_idx: torch.Tensor
-    chunk_indices: torch.Tensor
-    chunk_offsets: torch.Tensor
+
+    # The following tensors only contain prefill requests and will be None if
+    # the batch has no prefill request.
+    has_initial_states_p: Optional[torch.Tensor]
+    seq_idx_p: Optional[torch.Tensor]
+    chunk_indices_p: Optional[torch.Tensor]
+    chunk_offsets_p: Optional[torch.Tensor]
 
     state_indices_tensor: torch.Tensor  # shape: [batch,]
+
+    # The following attributes are for triton implementation of causal_conv1d
     nums_dict: Optional[dict] = None
     cu_seqlen: Optional[int] = None
     batch_ptr: Optional[torch.tensor] = None
@@ -115,11 +120,11 @@ def build(self,
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
 
-        seq_idx = None
-        chunk_indices, chunk_offsets = None, None
+        seq_idx_p = None
+        chunk_indices_p, chunk_offsets_p = None, None
         # Need flags to indicate if there are initial states
         # currently we really only support the FlashAttention backend
-        has_initial_states = None
+        has_initial_states_p = None
         prep_initial_states = False
 
         state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
@@ -135,25 +140,25 @@ def build(self,
                 common_attn_metadata.
                 num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
             prep_initial_states = torch.any(has_initial_states_cpu).item()
-            has_initial_states = has_initial_states_cpu.to(
+            has_initial_states_p = has_initial_states_cpu.to(
                 query_start_loc.device)
 
             query_start_loc_p = common_attn_metadata.query_start_loc[
                 -num_prefills - 1:] - num_decode_tokens
 
-            seq_idx = torch.repeat_interleave(torch.arange(
+            seq_idx_p = torch.repeat_interleave(torch.arange(
                 num_prefills,
                 dtype=torch.int32,
                 device=query_start_loc_p.device),
-                                              query_start_loc_p.diff(),
-                                              output_size=num_prefill_tokens)
-            seq_idx.unsqueeze_(0)
+                                                query_start_loc_p.diff(),
+                                                output_size=num_prefill_tokens)
+            seq_idx_p.unsqueeze_(0)
 
             # We compute metadata for chunked prefill once at the top level
             # model forward and reuse them in mamba layers. If not needed,
             # they will be ignored inside mamba kernels.
             if prep_initial_states:
-                chunk_indices, chunk_offsets = (
+                chunk_indices_p, chunk_offsets_p = (
                     _query_start_loc_to_chunk_indices_offsets(
                         query_start_loc_p, self.chunk_size,
                         num_prefill_tokens))
@@ -173,12 +178,12 @@ def build(self,
             num_decode_tokens=num_decode_tokens,
             query_start_loc=query_start_loc,
             seq_lens=seq_lens,
-            has_initial_states=has_initial_states,
             prep_initial_states=prep_initial_states,
             chunk_size=self.chunk_size,
-            seq_idx=seq_idx,
-            chunk_indices=chunk_indices,
-            chunk_offsets=chunk_offsets,
+            has_initial_states_p=has_initial_states_p,
+            seq_idx_p=seq_idx_p,
+            chunk_indices_p=chunk_indices_p,
+            chunk_offsets_p=chunk_offsets_p,
             state_indices_tensor=state_indices_tensor,
         )
         return attn_metadata
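
As a quick illustration of the seq_idx_p construction in the hunk above, the following standalone sketch uses made-up sizes (two prefill requests contributing 3 and 2 tokens; none of these numbers come from the patch) to show what the repeat_interleave call produces:

import torch

# Hypothetical prefill batch: 2 requests contributing 3 and 2 tokens.
num_prefills = 2
num_prefill_tokens = 5
# Cumulative token offsets for the prefill requests, as query_start_loc_p
# would hold after the decode tokens are subtracted out.
query_start_loc_p = torch.tensor([0, 3, 5], dtype=torch.int32)

# Each prefill token gets labeled with the index of the request it belongs to.
seq_idx_p = torch.repeat_interleave(
    torch.arange(num_prefills, dtype=torch.int32),
    query_start_loc_p.diff(),
    output_size=num_prefill_tokens)
seq_idx_p.unsqueeze_(0)

print(seq_idx_p)  # tensor([[0, 0, 0, 1, 1]], dtype=torch.int32)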