Commit 40dd0c4

[Frontend] Make sure aggregate members are added to the cache key (#8528)
Each aggregate class tracks its callable members, and when the aggregate is referenced by name, the cache keys of all its members are computed. This does require `def __init__` to be marked as `@constexpr_function`.
1 parent 7578e3e commit 40dd0c4
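To make the mechanism concrete, here is a minimal sketch of the pattern the changed files follow. The decorators match those used in the diffs below; the `Counter` class, the `kernel`, and the import paths are illustrative assumptions, not code from this commit.

    from triton.experimental import gluon
    from triton.experimental.gluon import language as gl
    from triton.language.core import _aggregate as aggregate  # import style of the gluon examples (assumed)


    @aggregate
    class Counter:
        value: gl.tensor
        limit: gl.constexpr

        # __init__ runs as plain Python at compile time, so the constexpr
        # field must be wrapped explicitly.
        @gluon.constexpr_function
        def __init__(self, value, limit):
            self.value = value
            self.limit = gl.constexpr(limit)

        @gluon.jit
        def bump(self):
            return Counter((self.value + 1) % self.limit, self.limit)


    @gluon.jit
    def kernel(x):
        # Referencing Counter by name folds the cache keys of Counter.__init__
        # and Counter.bump into kernel's cache key, so editing either method
        # triggers recompilation.
        c = Counter(x, 4)
        c = c.bump()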

7 files changed (+38 −12 lines)

python/examples/gluon/01-attention-forward.py

Lines changed: 9 additions & 2 deletions
@@ -52,10 +52,11 @@ class BarrierCounter:
     phase: gl.tensor
     num_barriers: gl.constexpr
 
+    @gluon.constexpr_function
     def __init__(self, index, phase, num_barriers):
         self.index = index
         self.phase = phase
-        self.num_barriers = num_barriers
+        self.num_barriers = gl.constexpr(num_barriers)
 
     @gluon.must_use_result
     @gluon.jit
@@ -79,6 +80,7 @@ class ChannelType:
     num_buffers: gl.constexpr
     num_consumers: gl.constexpr
 
+    @gluon.constexpr_function
     def __init__(self, mem, ready_bars, empty_bars, num_buffers, num_consumers):
         self.mem = mem
         self.ready_bars = ready_bars
@@ -143,6 +145,7 @@ class Producer:
     channel: ChannelType
     counter: BarrierCounter
 
+    @gluon.constexpr_function
     def __init__(self, channel, counter):
         self.channel = channel
         self.counter = counter
@@ -158,6 +161,7 @@ class Consumer:
     channel: ChannelType
     counter: BarrierCounter
 
+    @gluon.constexpr_function
     def __init__(self, channel, counter):
         self.channel = channel
         self.counter = counter
@@ -234,6 +238,7 @@ class AttentionConfig:
     num_kv_buffers: gl.constexpr
     use_exp2_turnstile: gl.constexpr
 
+    @gluon.constexpr_function
     def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE_N, NUM_SMS, STAGE, dtype,
                  num_warps):
         self.qk_scale = qk_scale
@@ -250,7 +255,7 @@ def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE
         self.num_warps = gl.constexpr(num_warps)
 
         self.SPLIT_D_FACTOR = gl.constexpr(2)
-        self.SPLIT_EXP_FACTOR = 256 // HEAD_DIM
+        self.SPLIT_EXP_FACTOR = gl.constexpr(256 // HEAD_DIM)
         self.SPLIT_QK_LOAD_FACTOR = gl.constexpr(2 if STAGE == 1 else 1)
         self.SPLIT_M = gl.constexpr(self.BLOCK_M // 2)
         self.SPLIT_D = gl.constexpr(self.HEAD_DIM // self.SPLIT_D_FACTOR)
@@ -305,6 +310,7 @@ class ProgramScheduler:
     num_pid_in_group: gl.tensor
     num_tiles: gl.tensor
 
+    @gluon.constexpr_function
     def __init__(self, config, start_pid, num_pid_n, num_pid_in_group, num_tiles):
         self.config = config
         self.start_pid = start_pid
@@ -339,6 +345,7 @@ class AttentionProgram:
     offset_y: gl.tensor
     qo_offset_y: gl.tensor
 
+    @gluon.constexpr_function
     def __init__(self, config, start_m, off_hz, offset_y, qo_offset_y):
         self.config = config
         self.start_m = start_m
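A note on the `gl.constexpr(...)` wraps that accompany the decorator in these hunks (a presumed rationale; the commit message does not spell it out): once `__init__` is a `@gluon.constexpr_function`, it executes as ordinary Python at compile time, so its arguments arrive as plain Python values, and fields annotated `gl.constexpr` must be wrapped by hand where the JIT frontend previously did it implicitly:

    @gluon.constexpr_function
    def __init__(self, index, phase, num_barriers):
        self.index = index                              # gl.tensor field, stored as-is
        self.phase = phase                              # gl.tensor field, stored as-is
        self.num_barriers = gl.constexpr(num_barriers)  # num_barriers is a plain int here;
                                                        # the constexpr field needs an explicit wrap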

python/triton/experimental/gluon/language/nvidia/blackwell/float2.py

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@ def _fma_f32x2(a, b, c):
 class Float2Tensor:
     value: ttgl.tensor
 
+    @constexpr_function
     def __init__(self, value: ttgl.tensor):
         self.value = value
 

python/triton/language/core.py

Lines changed: 12 additions & 6 deletions
@@ -1544,13 +1544,15 @@ def _get_instance(this_cls):
     def __new__(this_cls, *args, _semantic=None, _generator=None, **kwargs):
         # Call into the user-defined constructor.
         instance = this_cls._get_instance()
-        if isinstance(cls.__init__, JITCallable):
-            raise ValueError(f"{cls.__name__}.__init__ cannot be a @triton.jit function")
         extra_kwargs = {}
-        if "_semantic" in inspect.signature(cls.__init__).parameters:
-            extra_kwargs["_semantic"] = _semantic
-        if "_generator" in inspect.signature(cls.__init__).parameters:
-            extra_kwargs["_generator"] = _generator
+        if isinstance(cls.__init__, JITCallable):
+            # raise ValueError(f"{cls.__name__}.__init__ cannot be a @triton.jit function")
+            pass
+        else:
+            if "_semantic" in inspect.signature(cls.__init__).parameters:
+                extra_kwargs["_semantic"] = _semantic
+            if "_generator" in inspect.signature(cls.__init__).parameters:
+                extra_kwargs["_generator"] = _generator
         cls.__init__(instance, *args, **extra_kwargs, **kwargs)
 
         # Require that the user-defined constructor initialized all fields.
@@ -1577,11 +1579,15 @@ def type(self):
         return _aggregate_type(aggregate_value,
                                [(name, getattr(self, name).type) for name in cls.__annotations__.keys()])
 
+    hash_attrs = [cls.__init__]
+
     for (name, member) in inspect.getmembers(cls):
         if inspect.isfunction(member) or inspect.ismethod(member) or isinstance(member, JITCallable):
             if name != "__init__":
                 setattr(aggregate_value, name, member)
+                hash_attrs.append(member)
 
+    aggregate_value.hash_attrs = hash_attrs
     aggregate_value.__name__ = cls.__name__
     aggregate_value.__module__ = cls.__module__
     aggregate_value.__qualname__ = cls.__qualname__
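A runnable toy of the member collection this hunk performs, in plain Python with no Triton and the `JITCallable` branch omitted; `Demo` is a hypothetical class:

    import inspect


    class Demo:
        def __init__(self, x):
            self.x = x

        def double(self):
            return self.x * 2


    # Mirrors the loop above: __init__ seeds the list, and every other
    # function member is appended, so all callables feed the cache key.
    hash_attrs = [Demo.__init__]
    for name, member in inspect.getmembers(Demo):
        if inspect.isfunction(member) and name != "__init__":
            hash_attrs.append(member)

    print([f.__name__ for f in hash_attrs])  # ['__init__', 'double']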

python/triton/runtime/jit.py

Lines changed: 5 additions & 0 deletions
@@ -122,6 +122,11 @@ def record_reference(self, val, var_dict=None, name=None):
         if val is None or type(val) is ModuleType:
             return
 
+        if getattr(val, "__triton_aggregate__", False):
+            for attr in val.hash_attrs:
+                self.record_reference(attr)
+            return
+
         if getattr(val, "__triton_builtin__", False):
             return
 
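`record_reference` is part of the dependency walk that computes a kernel's cache key. With this change, meeting an aggregate recurses into each tracked callable, so the key changes whenever any member changes. Below is a simplified stand-in for that walk, not the real `JITFunction` internals:

    # Simplified model of the dependency walk (assumed shape, for illustration).
    def collect_keys(val, keys):
        if getattr(val, "__triton_aggregate__", False):
            # An aggregate contributes the cache keys of all its callable
            # members, mirroring the record_reference hunk above.
            for attr in val.hash_attrs:
                collect_keys(attr, keys)
            return
        key = getattr(val, "cache_key", None)  # JIT functions expose a cache_key
        if key is not None:
            keys.append(key)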

python/tutorials/gluon/07-persistence.py

Lines changed: 5 additions & 1 deletion
@@ -97,6 +97,7 @@ class WGMMA:
     acc: Union[warpgroup_mma_accumulator, gl.tensor]
     use_acc: gl.tensor
 
+    @gluon.constexpr_function
     def __init__(self, acc, use_acc):
         self.acc = acc
         self.use_acc = use_acc
@@ -136,12 +137,13 @@ class MMAv5:
     counter: gl.tensor
     reg_layout: gl.constexpr
 
+    @gluon.constexpr_function
     def __init__(self, use_acc, acc_tmem, bar, counter, reg_layout):
         self.use_acc = use_acc
         self.acc_tmem = acc_tmem
         self.bar = bar
         self.counter = counter
-        self.reg_layout = reg_layout
+        self.reg_layout = gl.constexpr(reg_layout)
 
     @gluon.jit
     def initialize(dtype: gl.constexpr, BLOCK_M: gl.constexpr, BLOCK_N: gl.constexpr, num_warps: gl.constexpr):
@@ -342,6 +344,7 @@ class PersistentTileScheduler:
     pid_end: gl.tensor
     num_pid_m: gl.tensor
 
+    @gluon.constexpr_function
     def __init__(self, pid_start, pid_end, num_pid_m):
         self.pid_start = pid_start
         self.pid_end = pid_end
@@ -523,6 +526,7 @@ class GroupedPersistentTileSchedulerImpl:
     num_pid_in_group: gl.tensor
     num_pid: gl.tensor
 
+    @gluon.constexpr_function
     def __init__(self, start_pid, num_pid_m, num_pid_in_group, num_pid):
         self.start_pid = start_pid
         self.num_pid_m = num_pid_m

python/tutorials/gluon/08-warp-specialization.py

Lines changed: 5 additions & 3 deletions
@@ -400,6 +400,7 @@ class PartitionArgs:
     SUBTILE_FACTOR: gl.constexpr
     num_warps: gl.constexpr
 
+    @gluon.constexpr_function
     def __init__(self, a_desc, b_desc, c_desc, a_bufs, b_bufs, load_empty_bars, load_ready_bars, acc_bufs,
                  acc_empty_bars, acc_ready_bars, SUBTILE_FACTOR, num_warps):
         self.a_desc = a_desc
@@ -412,8 +413,8 @@ def __init__(self, a_desc, b_desc, c_desc, a_bufs, b_bufs, load_empty_bars, load
         self.acc_bufs = acc_bufs
         self.acc_empty_bars = acc_empty_bars
         self.acc_ready_bars = acc_ready_bars
-        self.SUBTILE_FACTOR = SUBTILE_FACTOR
-        self.num_warps = num_warps
+        self.SUBTILE_FACTOR = gl.constexpr(SUBTILE_FACTOR)
+        self.num_warps = gl.constexpr(num_warps)
 
 
 # Counter abstraction for tracking barrier index and phase.
@@ -423,10 +424,11 @@ class Counter:
     phase: gl.tensor
     num_barriers: gl.constexpr
 
+    @gluon.constexpr_function
     def __init__(self, index, phase, num_barriers):
         self.index = index
         self.phase = phase
-        self.num_barriers = num_barriers
+        self.num_barriers = gl.constexpr(num_barriers)
 
     @gluon.jit
     def create(phase, num_barriers: gl.constexpr):

third_party/amd/python/examples/gluon/f16_gemm_gfx1250.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ class PersistentTileScheduler:
     pid_end: ttgl.tensor
     num_pid_m: ttgl.tensor
 
+    @gluon.constexpr_function
     def __init__(self, pid_start, pid_end, num_pid_m):
         self.pid_start = pid_start
         self.pid_end = pid_end
