-# Copyright (c) 2023, Tri Dao.
+# Copyright (c) 2024, Tri Dao, Albert Gu.
 
-"""We want triton==2.1.0 for this
+"""We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
 """
 
 import math
@@ -22,20 +22,21 @@ def _selective_scan_update_kernel(
     # Pointers to matrices
     state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
     # Matrix dimensions
-    batch, dim, dstate,
+    batch, nheads, dim, dstate, nheads_ngroups_ratio,
     # Strides
-    stride_state_batch, stride_state_dim, stride_state_dstate,
-    stride_x_batch, stride_x_dim,
-    stride_dt_batch, stride_dt_dim,
-    stride_dt_bias_dim,
-    stride_A_dim, stride_A_dstate,
-    stride_B_batch, stride_B_dstate,
-    stride_C_batch, stride_C_dstate,
-    stride_D_dim,
-    stride_z_batch, stride_z_dim,
-    stride_out_batch, stride_out_dim,
+    stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,
+    stride_x_batch, stride_x_head, stride_x_dim,
+    stride_dt_batch, stride_dt_head, stride_dt_dim,
+    stride_dt_bias_head, stride_dt_bias_dim,
+    stride_A_head, stride_A_dim, stride_A_dstate,
+    stride_B_batch, stride_B_group, stride_B_dstate,
+    stride_C_batch, stride_C_group, stride_C_dstate,
+    stride_D_head, stride_D_dim,
+    stride_z_batch, stride_z_head, stride_z_dim,
+    stride_out_batch, stride_out_head, stride_out_dim,
     # Meta-parameters
     DT_SOFTPLUS: tl.constexpr,
+    TIE_HDIM: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
     HAS_DT_BIAS: tl.constexpr,
     HAS_D: tl.constexpr,
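Every tensor in the kernel signature now carries an explicit head axis, and B/C carry a group axis instead. The stride arguments are plain `Tensor.stride()` values along each named axis; as a quick orientation (not part of the commit, with hypothetical sizes), this is what the wrapper ends up passing for the 4D state:

```python
# Illustration only: stride values for a hypothetical contiguous state tensor.
import torch

batch, nheads, dim, dstate = 2, 8, 64, 16
state = torch.zeros(batch, nheads, dim, dstate)
# For a contiguous tensor these are (nheads*dim*dstate, dim*dstate, dstate, 1).
stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate = state.stride()
print(stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate)
# 8192 1024 16 1
```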
@@ -44,14 +45,18 @@ def _selective_scan_update_kernel(
 ):
     pid_m = tl.program_id(axis=0)
     pid_b = tl.program_id(axis=1)
-    state_ptr += pid_b * stride_state_batch
-    x_ptr += pid_b * stride_x_batch
-    dt_ptr += pid_b * stride_dt_batch
-    B_ptr += pid_b * stride_B_batch
-    C_ptr += pid_b * stride_C_batch
+    pid_h = tl.program_id(axis=2)
+    state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
+    x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
+    dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
+    if HAS_DT_BIAS:
+        dt_bias_ptr += pid_h * stride_dt_bias_head
+    A_ptr += pid_h * stride_A_head
+    B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
+    C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
     if HAS_Z:
-        z_ptr += pid_b * stride_z_batch
-    out_ptr += pid_b * stride_out_batch
+        z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
+    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
 
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
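The launch grid gains a third axis for heads, so each program offsets its pointers by both batch and head. B and C are shared across groups of heads (multi-query style), which the `pid_h // nheads_ngroups_ratio` offset implements. A tiny sketch of the mapping, assuming hypothetical nheads=8 and ngroups=2:

```python
# Illustration only: which B/C group each head index reads.
nheads, ngroups = 8, 2
nheads_ngroups_ratio = nheads // ngroups
print([pid_h // nheads_ngroups_ratio for pid_h in range(nheads)])
# [0, 0, 0, 0, 1, 1, 1, 1]  -> heads 0-3 share group 0, heads 4-7 share group 1
```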
@@ -60,6 +65,8 @@ def _selective_scan_update_kernel(
     dt_ptrs = dt_ptr + offs_m * stride_dt_dim
     if HAS_DT_BIAS:
         dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
+    if HAS_D:
+        D_ptr += pid_h * stride_D_head
     A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
     B_ptrs = B_ptr + offs_n * stride_B_dstate
     C_ptrs = C_ptr + offs_n * stride_C_dstate
@@ -71,21 +78,34 @@ def _selective_scan_update_kernel(
 
     state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
     x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
-    dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
-    if HAS_DT_BIAS:
-        dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
-    if DT_SOFTPLUS:
-        dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
-    A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
-    dA = tl.exp(A * dt[:, None])
+    if not TIE_HDIM:
+        dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+        if HAS_DT_BIAS:
+            dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+        if DT_SOFTPLUS:
+            dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
+        A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
+        dA = tl.exp(A * dt[:, None])
+    else:
+        dt = tl.load(dt_ptr).to(tl.float32)
+        if HAS_DT_BIAS:
+            dt += tl.load(dt_bias_ptr).to(tl.float32)
+        if DT_SOFTPLUS:
+            dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
+        A = tl.load(A_ptr).to(tl.float32)
+        dA = tl.exp(A * dt)  # scalar, not a matrix
+
     B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
     C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
     if HAS_D:
         D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
     if HAS_Z:
         z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
 
-    dB = B[None, :] * dt[:, None]
+    if not TIE_HDIM:
+        dB = B[None, :] * dt[:, None]
+    else:
+        dB = B * dt  # vector of size (dstate,)
     state = state * dA + dB * x[:, None]
     tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
     out = tl.sum(state * C[None, :], axis=1)
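Both branches realize the same recurrence, state <- state * exp(A*dt) + (B*dt) * x; TIE_HDIM merely exploits the case where dt and A are constant across a head's channels, so dA collapses to a scalar and dB to a (dstate,) vector. A plain-PyTorch sketch of the equivalence (illustrative shapes, not kernel code):

```python
import torch

dim, dstate = 4, 3
state = torch.randn(dim, dstate)
x, B = torch.randn(dim), torch.randn(dstate)

# Untied: per-channel dt and per-channel rows of A.
A = -torch.rand(dim, dstate)
dt = torch.rand(dim)
untied = state * torch.exp(A * dt[:, None]) + (B[None, :] * dt[:, None]) * x[:, None]

# Tied: scalar dt and A shared by the whole head; broadcasting gives the same shape.
A_s, dt_s = -torch.rand(()), torch.rand(())
tied = state * torch.exp(A_s * dt_s) + (B * dt_s) * x[:, None]
assert untied.shape == tied.shape == (dim, dstate)
```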
@@ -99,94 +119,145 @@ def _selective_scan_update_kernel(
 def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
     """
     Argument:
-        state: (batch, dim, dstate)
-        x: (batch, dim)
-        dt: (batch, dim)
-        A: (dim, dstate)
-        B: (batch, dstate)
-        C: (batch, dstate)
-        D: (dim,)
-        z: (batch, dim)
-        dt_bias: (dim,)
+        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
+        x: (batch, dim) or (batch, nheads, dim)
+        dt: (batch, dim) or (batch, nheads, dim)
+        A: (dim, dstate) or (nheads, dim, dstate)
+        B: (batch, dstate) or (batch, ngroups, dstate)
+        C: (batch, dstate) or (batch, ngroups, dstate)
+        D: (dim,) or (nheads, dim)
+        z: (batch, dim) or (batch, nheads, dim)
+        dt_bias: (dim,) or (nheads, dim)
     Return:
-        out: (batch, dim)
+        out: (batch, dim) or (batch, nheads, dim)
     """
-    batch, dim, dstate = state.shape
-    assert x.shape == (batch, dim)
+    has_heads = state.dim() > 3
+    if state.dim() == 3:
+        state = state.unsqueeze(1)
+    if x.dim() == 2:
+        x = x.unsqueeze(1)
+    if dt.dim() == 2:
+        dt = dt.unsqueeze(1)
+    if A.dim() == 2:
+        A = A.unsqueeze(0)
+    if B.dim() == 2:
+        B = B.unsqueeze(1)
+    if C.dim() == 2:
+        C = C.unsqueeze(1)
+    if D is not None and D.dim() == 1:
+        D = D.unsqueeze(0)
+    if z is not None and z.dim() == 2:
+        z = z.unsqueeze(1)
+    if dt_bias is not None and dt_bias.dim() == 1:
+        dt_bias = dt_bias.unsqueeze(0)
+    batch, nheads, dim, dstate = state.shape
+    assert x.shape == (batch, nheads, dim)
     assert dt.shape == x.shape
-    assert A.shape == (dim, dstate)
-    assert B.shape == (batch, dstate)
+    assert A.shape == (nheads, dim, dstate)
+    ngroups = B.shape[1]
+    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
+    assert B.shape == (batch, ngroups, dstate)
     assert C.shape == B.shape
     if D is not None:
-        assert D.shape == (dim,)
+        assert D.shape == (nheads, dim)
     if z is not None:
         assert z.shape == x.shape
     if dt_bias is not None:
-        assert dt_bias.shape == (dim,)
+        assert dt_bias.shape == (nheads, dim)
     out = torch.empty_like(x)
-    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)
-    z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0))
+    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
+    z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))
     # We don't want autotune since it will overwrite the state
     # We instead tune by hand.
     BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
                                else ((16, 4) if dstate <= 32 else
                                      ((8, 4) if dstate <= 64 else
                                       ((4, 4) if dstate <= 128 else
                                        ((4, 8))))))
+    tie_hdim = (A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0
+                and (dt_bias is None or dt_bias.stride(-1) == 0))
     with torch.cuda.device(x.device.index):
         _selective_scan_update_kernel[grid](
             state, x, dt, dt_bias, A, B, C, D, z, out,
-            batch, dim, dstate,
-            state.stride(0), state.stride(1), state.stride(2),
-            x.stride(0), x.stride(1),
-            dt.stride(0), dt.stride(1),
-            dt_bias.stride(0) if dt_bias is not None else 0,
-            A.stride(0), A.stride(1),
-            B.stride(0), B.stride(1),
-            C.stride(0), C.stride(1),
-            D.stride(0) if D is not None else 0,
-            z_strides[0], z_strides[1],
-            out.stride(0), out.stride(1),
+            batch, nheads, dim, dstate, nheads // ngroups,
+            state.stride(0), state.stride(1), state.stride(2), state.stride(3),
+            x.stride(0), x.stride(1), x.stride(2),
+            dt.stride(0), dt.stride(1), dt.stride(2),
+            *((dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else (0, 0)),
+            A.stride(0), A.stride(1), A.stride(2),
+            B.stride(0), B.stride(1), B.stride(2),
+            C.stride(0), C.stride(1), C.stride(2),
+            *((D.stride(0), D.stride(1)) if D is not None else (0, 0)),
+            z_strides[0], z_strides[1], z_strides[2],
+            out.stride(0), out.stride(1), out.stride(2),
             dt_softplus,
+            tie_hdim,
             BLOCK_SIZE_M,
             num_warps=num_warps,
         )
+    if not has_heads:
+        out = out.squeeze(1)
     return out
 
 
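For reference (not from the commit), a call-site sketch: legacy no-head callers keep working through the unsqueeze shims, and the `tie_hdim` fast path is triggered purely by zero strides, e.g. when a per-head scalar A, dt, and dt_bias are broadcast with `expand` instead of being materialized. All sizes below are hypothetical:

```python
import torch

batch, nheads, dim, dstate, ngroups = 2, 8, 64, 16, 2
device = "cuda"  # the Triton kernel requires a CUDA device

state = torch.zeros(batch, nheads, dim, dstate, device=device)
x = torch.randn(batch, nheads, dim, device=device)
B = torch.randn(batch, ngroups, dstate, device=device)
C = torch.randn(batch, ngroups, dstate, device=device)

# One dt/A/dt_bias per head, expanded with stride 0 along dim/dstate,
# so the stride checks above see tie_hdim == True.
dt = torch.rand(batch, nheads, device=device)[..., None].expand(batch, nheads, dim)
A = -torch.rand(nheads, device=device)[:, None, None].expand(nheads, dim, dstate)
dt_bias = torch.rand(nheads, device=device)[:, None].expand(nheads, dim)

out = selective_state_update(state, x, dt, A, B, C, dt_bias=dt_bias, dt_softplus=True)
print(out.shape)  # torch.Size([2, 8, 64])
```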
 def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
     """
     Argument:
-        state: (batch, dim, dstate)
-        x: (batch, dim)
-        dt: (batch, dim)
-        A: (dim, dstate)
-        B: (batch, dstate)
-        C: (batch, dstate)
-        D: (dim,)
-        z: (batch, dim)
-        dt_bias: (dim,)
+        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
+        x: (batch, dim) or (batch, nheads, dim)
+        dt: (batch, dim) or (batch, nheads, dim)
+        A: (dim, dstate) or (nheads, dim, dstate)
+        B: (batch, dstate) or (batch, ngroups, dstate)
+        C: (batch, dstate) or (batch, ngroups, dstate)
+        D: (dim,) or (nheads, dim)
+        z: (batch, dim) or (batch, nheads, dim)
+        dt_bias: (dim,) or (nheads, dim)
     Return:
-        out: (batch, dim)
+        out: (batch, dim) or (batch, nheads, dim)
     """
-    batch, dim, dstate = state.shape
-    assert x.shape == (batch, dim)
+    has_heads = state.dim() > 3
+    if state.dim() == 3:
+        state = state.unsqueeze(1)
+    if x.dim() == 2:
+        x = x.unsqueeze(1)
+    if dt.dim() == 2:
+        dt = dt.unsqueeze(1)
+    if A.dim() == 2:
+        A = A.unsqueeze(0)
+    if B.dim() == 2:
+        B = B.unsqueeze(1)
+    if C.dim() == 2:
+        C = C.unsqueeze(1)
+    if D is not None and D.dim() == 1:
+        D = D.unsqueeze(0)
+    if z is not None and z.dim() == 2:
+        z = z.unsqueeze(1)
+    if dt_bias is not None and dt_bias.dim() == 1:
+        dt_bias = dt_bias.unsqueeze(0)
+    batch, nheads, dim, dstate = state.shape
+    assert x.shape == (batch, nheads, dim)
     assert dt.shape == x.shape
-    assert A.shape == (dim, dstate)
-    assert B.shape == (batch, dstate)
+    assert A.shape == (nheads, dim, dstate)
+    ngroups = B.shape[1]
+    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
+    assert B.shape == (batch, ngroups, dstate)
     assert C.shape == B.shape
     if D is not None:
-        assert D.shape == (dim,)
+        assert D.shape == (nheads, dim)
     if z is not None:
         assert z.shape == x.shape
     if dt_bias is not None:
-        assert dt_bias.shape == (dim,)
+        assert dt_bias.shape == (nheads, dim)
         dt = dt + dt_bias
     dt = F.softplus(dt) if dt_softplus else dt
-    dA = torch.exp(rearrange(dt, "b d -> b d 1") * A)  # (batch, dim, dstate)
-    dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n")  # (batch, dim, dstate)
-    state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1"))  # (batch, dim, dstate)
-    out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C)
+    dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * A)  # (batch, nheads, dim, dstate)
+    B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups)  # (batch, nheads, dstate)
+    C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups)  # (batch, nheads, dstate)
+    dB = rearrange(dt, "b h d -> b h d 1") * rearrange(B, "b h n -> b h 1 n")  # (batch, nheads, dim, dstate)
+    state.copy_(state * dA + dB * rearrange(x, "b h d -> b h d 1"))  # (batch, nheads, dim, dstate)
+    out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
     if D is not None:
         out += (x * D).to(out.dtype)
-    return (out if z is None else out * F.silu(z)).to(x.dtype)
+    out = (out if z is None else out * F.silu(z)).to(x.dtype)
+    if not has_heads:
+        out = out.squeeze(1)
+    return out
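A quick parity check between the Triton path and the reference above (a sketch, not part of the commit; requires a CUDA device with the kernel available):

```python
import torch

torch.manual_seed(0)
batch, nheads, dim, dstate, ngroups = 2, 4, 32, 16, 2
state = torch.randn(batch, nheads, dim, dstate, device="cuda")
state_ref = state.clone()
x = torch.randn(batch, nheads, dim, device="cuda")
dt = torch.randn(batch, nheads, dim, device="cuda")
A = -torch.rand(nheads, dim, dstate, device="cuda")
B = torch.randn(batch, ngroups, dstate, device="cuda")
C = torch.randn(batch, ngroups, dstate, device="cuda")

out = selective_state_update(state, x, dt, A, B, C, dt_softplus=True)
out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, dt_softplus=True)
assert torch.allclose(out, out_ref, atol=1e-3)   # same output
assert torch.allclose(state, state_ref, atol=1e-3)  # same in-place state update
```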