
[AMD] Fixed make_desc lowering - i.e., findEncodingFromUsers #9585

Merged
antiagainst merged 5 commits into triton-lang:main from ravil-mobile:ravil/make-desc-fix
Mar 6, 2026

Conversation

@ravil-mobile
Contributor

@ravil-mobile ravil-mobile commented Feb 26, 2026

The PR fixes the findEncodingFromUsers function used in make_desc op lowering by taking into account information about value uses in all basic blocks.

Closes https://github.com/ROCm/triton-internal/issues/1598

cc @antiagainst

Comment on lines 90 to 95
    if (!sharedEnc) {
      // TODO: add an extra pass to assign layout to descriptors
      sharedEnc = findEncodingFromUsers(op);
      if (!sharedEnc)
        return rewriter.notifyMatchFailure(op, "Descriptor has no layout.");
    }
Collaborator

looks like a very fragile solution. might be worth doing a proper fix?

Contributor Author

@ravil-mobile ravil-mobile Feb 26, 2026


@ThomasRaoux Well, the author is @yangshuxin. I believe he is working on a proper solution, which may take a while. Meanwhile, this PR fixes the logic of findEncodingFromUsers, which in its current implementation doesn't conform to the language semantics.

The kernel that was failing on GFX1250 is:

    @triton.jit
    def batched_gemm_2d_tma_kernel(a_ptr, b_ptr, c_ptr,  #
                                   B, M, N, K,  #
                                   dtype: tl.constexpr,  #
                                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,  #
                                   NUM_SMS: tl.constexpr):
        start_pid = tl.program_id(axis=0)
        num_tiles_m = tl.cdiv(M, BLOCK_M)
        num_tiles_n = tl.cdiv(N, BLOCK_N)
        k_tiles = tl.cdiv(K, BLOCK_K)
        num_tiles_per_batch = num_tiles_m * num_tiles_n
        num_tiles = B * num_tiles_per_batch
        tiles_per_SM = num_tiles // NUM_SMS
        if start_pid < num_tiles % NUM_SMS:
            tiles_per_SM += 1
        tile_id = start_pid - NUM_SMS
        ki = -1
        tile_m = 0
        tile_n = 0
        tile_b = 0
        offs_m = 0
        offs_n = 0
        offs_b = 0
        a_desc = tl.make_tensor_descriptor(a_ptr + offs_b * (M * K), [M, K], [K, 1], [BLOCK_M, BLOCK_K])
        b_desc = tl.make_tensor_descriptor(b_ptr + offs_b * (N * K), [N, K], [K, 1], [BLOCK_N, BLOCK_K])
        c_desc = tl.make_tensor_descriptor(c_ptr + offs_b * (M * N), [M, N], [N, 1], [BLOCK_M, BLOCK_N])
        accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for _ in range(k_tiles * tiles_per_SM):
            ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
            if ki == 0:
                tile_id += NUM_SMS
                tile_b = tile_id // num_tiles_per_batch
                tile_m = (tile_id // num_tiles_n) % num_tiles_m
                tile_n = tile_id % num_tiles_n
                offs_b = tile_b
                offs_m = tile_m * BLOCK_M
                offs_n = tile_n * BLOCK_N
                a_desc = tl.make_tensor_descriptor(a_ptr + offs_b * (M * K), [M, K], [K, 1], [BLOCK_M, BLOCK_K])
                b_desc = tl.make_tensor_descriptor(b_ptr + offs_b * (N * K), [N, K], [K, 1], [BLOCK_N, BLOCK_K])
                c_desc = tl.make_tensor_descriptor(c_ptr + offs_b * (M * N), [M, N], [N, 1], [BLOCK_M, BLOCK_N])
            offs_k = ki * BLOCK_K
            a = a_desc.load([offs_m, offs_k])
            b = b_desc.load([offs_n, offs_k])
            accumulator = tl.dot(a, b.T, accumulator)
            if ki == k_tiles - 1:
                c = accumulator.to(dtype)
                c_desc.store([offs_m, offs_n], c)
                accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

The current upstream implementation assumes that the tensor descriptor definition and all of its uses are in the same basic block, which is not always true. In the kernel above, the descriptors are defined before the loop but reassigned inside an if and loaded in the loop body, so their uses span multiple blocks.
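For illustration only, the difference between a same-block-only user scan and a cross-block search can be sketched with a tiny def-use model. This is not the actual MLIR code touched by the PR; the Value/Use structures and function names below are invented for the sketch.

```python
# Hypothetical, simplified def-use model; illustrates why a user search
# must consider all basic blocks, not just the defining one.
# (Names are invented, not the actual Triton/MLIR API.)
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Use:
    encoding: Optional[str]  # a user op may or may not pin an encoding
    block: str               # basic block the user op lives in


@dataclass
class Value:
    block: str               # block where the value is defined
    uses: List[Use] = field(default_factory=list)


def find_encoding_same_block_only(v: Value) -> Optional[str]:
    # Old (buggy) behavior: only users in the defining block are considered.
    for use in v.uses:
        if use.block == v.block and use.encoding is not None:
            return use.encoding
    return None


def find_encoding_all_blocks(v: Value) -> Optional[str]:
    # Fixed behavior: users in any basic block are considered.
    for use in v.uses:
        if use.encoding is not None:
            return use.encoding
    return None


# A descriptor defined in the entry block but only loaded inside a loop body:
desc = Value(block="entry",
             uses=[Use(encoding=None, block="entry"),
                   Use(encoding="shared", block="loop.body")])
```

With this shape of IR, the same-block-only search finds no encoding (and the lowering reports "Descriptor has no layout."), while the cross-block search recovers the encoding from the load in the loop body.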

Member


Yes, this is a temporary stopgap to avoid crashes; it will be replaced with a more proper implementation very soon.


@antiagainst antiagainst marked this pull request as ready for review March 6, 2026 02:29
@antiagainst antiagainst requested a review from zhanglx13 as a code owner March 6, 2026 02:29
@antiagainst antiagainst merged commit 4b986a0 into triton-lang:main Mar 6, 2026
9 checks passed