@@ -52,10 +52,11 @@ class BarrierCounter:
5252 phase : gl .tensor
5353 num_barriers : gl .constexpr
5454
55+ @gluon .constexpr_function
5556 def __init__ (self , index , phase , num_barriers ):
5657 self .index = index
5758 self .phase = phase
58- self .num_barriers = num_barriers
59+ self .num_barriers = gl . constexpr ( num_barriers )
5960
6061 @gluon .must_use_result
6162 @gluon .jit
@@ -79,6 +80,7 @@ class ChannelType:
7980 num_buffers : gl .constexpr
8081 num_consumers : gl .constexpr
8182
83+ @gluon .constexpr_function
8284 def __init__ (self , mem , ready_bars , empty_bars , num_buffers , num_consumers ):
8385 self .mem = mem
8486 self .ready_bars = ready_bars
@@ -143,6 +145,7 @@ class Producer:
143145 channel : ChannelType
144146 counter : BarrierCounter
145147
148+ @gluon .constexpr_function
146149 def __init__ (self , channel , counter ):
147150 self .channel = channel
148151 self .counter = counter
@@ -158,6 +161,7 @@ class Consumer:
158161 channel : ChannelType
159162 counter : BarrierCounter
160163
164+ @gluon .constexpr_function
161165 def __init__ (self , channel , counter ):
162166 self .channel = channel
163167 self .counter = counter
@@ -234,6 +238,7 @@ class AttentionConfig:
234238 num_kv_buffers : gl .constexpr
235239 use_exp2_turnstile : gl .constexpr
236240
241+ @gluon .constexpr_function
237242 def __init__ (self , qk_scale , Z , H , N_CTX , BLOCK_M , BLOCK_N , HEAD_DIM , GROUP_SIZE_N , NUM_SMS , STAGE , dtype ,
238243 num_warps ):
239244 self .qk_scale = qk_scale
@@ -250,7 +255,7 @@ def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE
250255 self .num_warps = gl .constexpr (num_warps )
251256
252257 self .SPLIT_D_FACTOR = gl .constexpr (2 )
253- self .SPLIT_EXP_FACTOR = 256 // HEAD_DIM
258+ self .SPLIT_EXP_FACTOR = gl . constexpr ( 256 // HEAD_DIM )
254259 self .SPLIT_QK_LOAD_FACTOR = gl .constexpr (2 if STAGE == 1 else 1 )
255260 self .SPLIT_M = gl .constexpr (self .BLOCK_M // 2 )
256261 self .SPLIT_D = gl .constexpr (self .HEAD_DIM // self .SPLIT_D_FACTOR )
@@ -305,6 +310,7 @@ class ProgramScheduler:
305310 num_pid_in_group : gl .tensor
306311 num_tiles : gl .tensor
307312
313+ @gluon .constexpr_function
308314 def __init__ (self , config , start_pid , num_pid_n , num_pid_in_group , num_tiles ):
309315 self .config = config
310316 self .start_pid = start_pid
@@ -339,6 +345,7 @@ class AttentionProgram:
339345 offset_y : gl .tensor
340346 qo_offset_y : gl .tensor
341347
348+ @gluon .constexpr_function
342349 def __init__ (self , config , start_m , off_hz , offset_y , qo_offset_y ):
343350 self .config = config
344351 self .start_m = start_m
0 commit comments