@@ -119,6 +119,7 @@ def __init__(
     def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
         """
         hidden_states: (B, L, D)
+        cu_seqlens: one-dimensional tensor of cumulative sequence lengths, as in the flash-attn varlen API; used only when packing variable-length sequences into a single sequence, i.e. batch_size B=1
         Returns: same shape as hidden_states
         """
         batch, seqlen, dim = hidden_states.shape
@@ -157,7 +158,7 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
                 self.D.float(),
                 delta_bias=self.dt_proj.bias.float(),
                 delta_softplus=True,
-                cu_seqlens=cu_seqlens[0] if cu_seqlens is not None else None,
+                cu_seqlens=cu_seqlens,
             )
         else:
             x, z = xz.chunk(2, dim=1)
@@ -166,12 +167,12 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
             if cu_seqlens is not None:
                 padded_x = x
                 count = 0
-                for idx in cu_seqlens[0][1:-1].tolist():
+                for idx in cu_seqlens[1:-1].tolist():
                     padded_idx = idx + count * (self.d_conv - 1)
                     padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], self.d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2)
                     count = count + 1
                 x = padded_x
-                assert x.shape[2] == (self.d_conv - 1) * len(cu_seqlens[0][1:-1]) + z.shape[2]
+                # assert x.shape[2] == (self.d_conv - 1) * len(cu_seqlens[1:-1]) + z.shape[2]

             # Compute short convolution
             if conv_state is not None:
@@ -192,13 +193,13 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
             # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences
             if cu_seqlens is not None:
                 mask = []
-                for seq_len in (cu_seqlens[0][1:] - cu_seqlens[0][:-1]).tolist():
+                for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist():
                     mask.extend([True] * seq_len)
                     mask.extend([False] * (self.d_conv - 1))
                 mask = mask[:-(self.d_conv - 1)]
-                assert x.shape[2] == len(mask)
+                # assert x.shape[2] == len(mask)
                 x = x[:, :, mask]
-                assert x.shape[2] == z.shape[2]
+                # assert x.shape[2] == z.shape[2]

             # We're careful here about the layout, to avoid extra transposes.
             # We want dt to have d as the slowest moving dimension
@@ -222,7 +223,7 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
                 delta_bias=self.dt_proj.bias.float(),
                 delta_softplus=True,
                 return_last_state=ssm_state is not None,
-                cu_seqlens=cu_seqlens[0] if cu_seqlens is not None else None,
+                cu_seqlens=cu_seqlens,
             )
             if ssm_state is not None:
                 y, last_state = y
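
For reference, a minimal usage sketch of the varlen path this diff enables. The `cu_seqlens` convention follows the flash-attn varlen API (prefix sums of the packed sequence lengths); `mamba_layer` is a hypothetical instance of the module patched above, not something defined in this commit.

```python
import torch
import torch.nn.functional as F

# Pack three variable-length sequences into a single sequence (batch_size B=1)
# and build cu_seqlens as cumulative sequence lengths, flash-attn varlen style.
seq_lens = torch.tensor([5, 3, 7])
cu_seqlens = F.pad(seq_lens.cumsum(0), (1, 0)).to(torch.int32)  # tensor([0, 5, 8, 15])

d_model = 64
hidden_states = torch.randn(1, int(seq_lens.sum()), d_model)    # (B=1, L=15, D)

# mamba_layer is assumed to be a Mamba block using the patched forward();
# the output has the same shape as hidden_states, and tokens from different
# packed sequences do not mix through the conv or the scan.
# out = mamba_layer(hidden_states, cu_seqlens=cu_seqlens)
```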
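And a self-contained sketch of the pad-then-mask trick that the Step 1/Step 2 blocks above apply around the short convolution. This uses a plain `F.conv1d` stand-in rather than `causal_conv1d_fn` (so the fused silu activation is omitted), the helper name is made up, and `d_conv > 1` is assumed, as in the real module.

```python
import torch
import torch.nn.functional as F

def depthwise_causal_conv1d_packed(x, weight, cu_seqlens, d_conv):
    """Sketch of the padding/masking steps above.

    x: (1, D, L) with several sequences packed along L.
    weight: (D, 1, d_conv) depthwise conv kernel; d_conv > 1 assumed.
    cu_seqlens: cumulative sequence lengths, e.g. tensor([0, 5, 8, 15]).
    """
    # Step 1: insert (d_conv - 1) zeros at every internal sequence boundary so
    # the causal convolution cannot leak context across packed sequences.
    count = 0
    for idx in cu_seqlens[1:-1].tolist():
        padded_idx = idx + count * (d_conv - 1)
        pad = torch.zeros(1, x.shape[1], d_conv - 1, dtype=x.dtype, device=x.device)
        x = torch.cat((x[:, :, :padded_idx], pad, x[:, :, padded_idx:]), dim=2)
        count += 1

    # Depthwise causal conv1d: left-pad by d_conv - 1, one filter per channel.
    y = F.conv1d(F.pad(x, (d_conv - 1, 0)), weight, groups=weight.shape[0])

    # Step 2: drop the positions that correspond to the inserted zeros, so the
    # output length matches the original packed length again.
    mask = []
    for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist():
        mask.extend([True] * seq_len)
        mask.extend([False] * (d_conv - 1))
    mask = mask[:-(d_conv - 1)]
    return y[:, :, mask]
```

Up to the omitted activation, this is equivalent to running the causal conv on each packed sequence separately, which is why the diff can keep everything in a single (B=1, D, L) tensor.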