Skip to content

Commit 0d7bec5

Browse files
bangtianliuziereis
authored andcommitted
[Codegen] ukernel support for argmax bf16 (iree-org#20768)
This PR mainly add two changes relevant to ukernel support for argmax op. **Change 1: BF16 support** this PR adds the uKernels iree_uk_amdgpu_argmax_bf16i64 and iree_uk_amdgpu_argmax_bf16i32. These implementations are adapted from the existing float versions, as low-level bf16 support typically involves converting to float32 for computation. **Change 2: Support Returning the Maximum Value** Previously, the compiler restricted argmax lowering to pure-index-only cases using the following check: ```C++ // If max value is being used, it is not a pure argmax. if (!genericOp.getResults()[0].use_empty()) { return false; } ``` This check has been removed to enable the microkernel to support both: - Returning only the index (pure argmax) - Returning both the maximum value and its index To support this at the ukernel level, I added a writeValue boolean flag to control whether the value output should be written. Once this LLVM PR is integrated: llvm/llvm-project#140775, I plan to further simplify the implementation by checking whether outputBufferVal is nullptr instead of relying on an explicit flag. The PR also includes a corresponding test. In addition, I performed manual correctness checks to validate the behavior, I put my scripts here: https://github.com/bangtianliu/work-scripts/tree/master/argmax. Issue: iree-org#20650 --------- Signed-off-by: Bangtian Liu <[email protected]> Signed-off-by: Thomas Ziereis <[email protected]>
1 parent 63d6fa9 commit 0d7bec5

File tree

13 files changed

+665
-39
lines changed

13 files changed

+665
-39
lines changed

compiler/plugins/target/ROCM/builtins/ukernel/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ gpu_archs = [
3737

3838
# Element type combinations for the argmax ukernel.
3939
argmax_types = [
40+
"bf16i32",
41+
"bf16i64",
4042
"f16i32",
4143
"f16i64",
4244
"f32i32",

compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,102 @@ if(NOT IREE_TARGET_BACKEND_ROCM)
1414
return()
1515
endif()
1616

17+
iree_amdgpu_bitcode_library(
18+
NAME
19+
iree_uk_amdgpu_argmax_bf16i32_gfx90a
20+
GPU_ARCH
21+
gfx90a
22+
SRCS
23+
"common.h"
24+
"iree_uk_amdgpu_argmax_bf16i32.c"
25+
OUT
26+
"iree_uk_amdgpu_argmax_bf16i32.gfx90a.bc"
27+
)
28+
29+
iree_amdgpu_bitcode_library(
30+
NAME
31+
iree_uk_amdgpu_argmax_bf16i32_gfx942
32+
GPU_ARCH
33+
gfx942
34+
SRCS
35+
"common.h"
36+
"iree_uk_amdgpu_argmax_bf16i32.c"
37+
OUT
38+
"iree_uk_amdgpu_argmax_bf16i32.gfx942.bc"
39+
)
40+
41+
iree_amdgpu_bitcode_library(
42+
NAME
43+
iree_uk_amdgpu_argmax_bf16i32_gfx1030
44+
GPU_ARCH
45+
gfx1030
46+
SRCS
47+
"common.h"
48+
"iree_uk_amdgpu_argmax_bf16i32.c"
49+
OUT
50+
"iree_uk_amdgpu_argmax_bf16i32.gfx1030.bc"
51+
)
52+
53+
iree_amdgpu_bitcode_library(
54+
NAME
55+
iree_uk_amdgpu_argmax_bf16i32_gfx1100
56+
GPU_ARCH
57+
gfx1100
58+
SRCS
59+
"common.h"
60+
"iree_uk_amdgpu_argmax_bf16i32.c"
61+
OUT
62+
"iree_uk_amdgpu_argmax_bf16i32.gfx1100.bc"
63+
)
64+
65+
iree_amdgpu_bitcode_library(
66+
NAME
67+
iree_uk_amdgpu_argmax_bf16i64_gfx90a
68+
GPU_ARCH
69+
gfx90a
70+
SRCS
71+
"common.h"
72+
"iree_uk_amdgpu_argmax_bf16i64.c"
73+
OUT
74+
"iree_uk_amdgpu_argmax_bf16i64.gfx90a.bc"
75+
)
76+
77+
iree_amdgpu_bitcode_library(
78+
NAME
79+
iree_uk_amdgpu_argmax_bf16i64_gfx942
80+
GPU_ARCH
81+
gfx942
82+
SRCS
83+
"common.h"
84+
"iree_uk_amdgpu_argmax_bf16i64.c"
85+
OUT
86+
"iree_uk_amdgpu_argmax_bf16i64.gfx942.bc"
87+
)
88+
89+
iree_amdgpu_bitcode_library(
90+
NAME
91+
iree_uk_amdgpu_argmax_bf16i64_gfx1030
92+
GPU_ARCH
93+
gfx1030
94+
SRCS
95+
"common.h"
96+
"iree_uk_amdgpu_argmax_bf16i64.c"
97+
OUT
98+
"iree_uk_amdgpu_argmax_bf16i64.gfx1030.bc"
99+
)
100+
101+
iree_amdgpu_bitcode_library(
102+
NAME
103+
iree_uk_amdgpu_argmax_bf16i64_gfx1100
104+
GPU_ARCH
105+
gfx1100
106+
SRCS
107+
"common.h"
108+
"iree_uk_amdgpu_argmax_bf16i64.c"
109+
OUT
110+
"iree_uk_amdgpu_argmax_bf16i64.gfx1100.bc"
111+
)
112+
17113
iree_amdgpu_bitcode_library(
18114
NAME
19115
iree_uk_amdgpu_argmax_f16i32_gfx90a
@@ -222,6 +318,14 @@ iree_c_embed_data(
222318
NAME
223319
iree_uk_amdgpu_bitcode
224320
SRCS
321+
"iree_uk_amdgpu_argmax_bf16i32.gfx1030.bc"
322+
"iree_uk_amdgpu_argmax_bf16i32.gfx1100.bc"
323+
"iree_uk_amdgpu_argmax_bf16i32.gfx90a.bc"
324+
"iree_uk_amdgpu_argmax_bf16i32.gfx942.bc"
325+
"iree_uk_amdgpu_argmax_bf16i64.gfx1030.bc"
326+
"iree_uk_amdgpu_argmax_bf16i64.gfx1100.bc"
327+
"iree_uk_amdgpu_argmax_bf16i64.gfx90a.bc"
328+
"iree_uk_amdgpu_argmax_bf16i64.gfx942.bc"
225329
"iree_uk_amdgpu_argmax_f16i32.gfx1030.bc"
226330
"iree_uk_amdgpu_argmax_f16i32.gfx1100.bc"
227331
"iree_uk_amdgpu_argmax_f16i32.gfx90a.bc"
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// Copyright 2025 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "compiler/plugins/target/ROCM/builtins/ukernel/common.h"
8+
9+
[[clang::always_inline]] void iree_uk_amdgpu_argmax_bf16i32(
10+
const __bf16 *inputBuffer, int64_t input_offset, __bf16 *outputBufferVal,
11+
int64_t output_val_offset, int32_t *outputBufferIdx,
12+
int64_t output_idx_offset, int64_t reductionSize, bool writeValue) {
13+
// NOTE:
14+
// We convert bf16 inputs to f32 before computation because HIP/OCKL and
15+
// Clang/LLVM do not currently support native arithmetic or comparisons on
16+
// bf16. In practice, these operations are internally performed by first
17+
// converting bf16 to float.
18+
const int warpSize = __builtin_amdgcn_wavefrontsize();
19+
int32_t laneID = __builtin_amdgcn_workitem_id_x();
20+
// Set identity value to handle problem non divisible by subgroupSize.
21+
float laneMax = laneID >= reductionSize
22+
? -FLT_MAX
23+
: (float)(inputBuffer[input_offset + laneID]);
24+
int32_t laneResult = laneID;
25+
26+
// NOTE: On F32 kernels with clang, reductionSize/blockDim.x has numerical
27+
// inaccuracy.
28+
int32_t numBatches = (reductionSize + warpSize - 1) / warpSize;
29+
for (int i = 1; i < numBatches; ++i) {
30+
int32_t idx = warpSize * i + laneID;
31+
float newIn = idx >= reductionSize
32+
? -FLT_MAX
33+
: (float)(inputBuffer[input_offset + idx]);
34+
if (newIn == laneMax)
35+
continue;
36+
laneMax = __builtin_fmaxf(newIn, laneMax);
37+
laneResult = newIn == laneMax ? idx : laneResult;
38+
}
39+
40+
// Final reduction with one subgroup
41+
// NOTE: __ockl_wfred_max_f32 has correctness issue on gfx1100 documented
42+
// on https://github.com/iree-org/iree/issues/16112.
43+
float wgMax = laneMax;
44+
for (int i = 1; i < warpSize; i *= 2) {
45+
wgMax = __builtin_fmaxf(__shfl_xor_f(wgMax, i), wgMax);
46+
}
47+
// Check if there are multiple max value holders.
48+
uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax);
49+
// if there is only one max value holder, write and exit.
50+
if (__builtin_popcountll(laneHasMaxValmask) == 1) {
51+
if (wgMax == laneMax) {
52+
if (writeValue) {
53+
outputBufferVal[output_val_offset] = (__bf16)wgMax;
54+
}
55+
outputBufferIdx[output_idx_offset] = laneResult;
56+
}
57+
} else {
58+
// if there are multiple max value holder, find smallest index (argmax
59+
// semantics).
60+
int32_t indexVal = wgMax == laneMax ? laneResult : __INT32_MAX__;
61+
laneResult = __ockl_wfred_min_i32(indexVal);
62+
if (laneID == 0) {
63+
if (writeValue) {
64+
outputBufferVal[output_val_offset] = (__bf16)wgMax;
65+
}
66+
outputBufferIdx[output_idx_offset] = laneResult;
67+
}
68+
}
69+
// TODO(bjacob): this fence should be on the caller side. Move to TileAndFuse?
70+
__threadfence_block();
71+
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// Copyright 2025 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "compiler/plugins/target/ROCM/builtins/ukernel/common.h"
8+
9+
[[clang::always_inline]] void iree_uk_amdgpu_argmax_bf16i64(
10+
const __bf16 *inputBuffer, int64_t input_offset, __bf16 *outputBufferVal,
11+
int64_t output_val_offset, int64_t *outputBufferIdx,
12+
int64_t output_idx_offset, int64_t reductionSize, bool writeValue) {
13+
// NOTE:
14+
// We convert bf16 inputs to f32 before computation because HIP/OCKL and
15+
// Clang/LLVM do not currently support native arithmetic or comparisons on
16+
// bf16. In practice, these operations are internally performed by first
17+
// converting bf16 to float.
18+
const int warpSize = __builtin_amdgcn_wavefrontsize();
19+
int32_t laneID = __builtin_amdgcn_workitem_id_x();
20+
// Set identity value to handle problem non divisible by subgroupSize.
21+
float laneMax = laneID >= reductionSize
22+
? -FLT_MAX
23+
: (float)(inputBuffer[input_offset + laneID]);
24+
int64_t laneResult = laneID;
25+
26+
// NOTE: On F32 kernels with clang, reductionSize/blockDim.x has numerical
27+
// inaccuracy.
28+
int32_t numBatches = (reductionSize + warpSize - 1) / warpSize;
29+
for (int i = 1; i < numBatches; ++i) {
30+
int32_t idx = warpSize * i + laneID;
31+
float newIn = idx >= reductionSize
32+
? -FLT_MAX
33+
: (float)(inputBuffer[input_offset + idx]);
34+
if (newIn == laneMax)
35+
continue;
36+
laneMax = __builtin_fmaxf(newIn, laneMax);
37+
laneResult = newIn == laneMax ? idx : laneResult;
38+
}
39+
40+
// Final reduction with one subgroup
41+
// NOTE: __ockl_wfred_max_f32 has correctness issue on gfx1100 documented on
42+
// https://github.com/iree-org/iree/issues/16112.
43+
float wgMax = laneMax;
44+
for (int i = 1; i < warpSize; i *= 2) {
45+
wgMax = __builtin_fmaxf(__shfl_xor_f(wgMax, i), wgMax);
46+
}
47+
// Check if there are multiple max value holders.
48+
uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax);
49+
// if there is only one max value holder, write and exit.
50+
if (__builtin_popcountll(laneHasMaxValmask) == 1) {
51+
if (wgMax == laneMax) {
52+
if (writeValue) {
53+
outputBufferVal[output_val_offset] = (__bf16)wgMax;
54+
}
55+
outputBufferIdx[output_idx_offset] = laneResult;
56+
}
57+
} else {
58+
// if there are multiple max value holder, find smallest index (argmax
59+
// semantics).
60+
int64_t indexVal = wgMax == laneMax ? laneResult : INT64_MAX;
61+
laneResult = __ockl_wfred_min_i64(indexVal);
62+
if (laneID == 0) {
63+
if (writeValue) {
64+
outputBufferVal[output_val_offset] = (__bf16)wgMax;
65+
}
66+
outputBufferIdx[output_idx_offset] = laneResult;
67+
}
68+
}
69+
// TODO(bjacob): this fence should be on the caller side. Move to TileAndFuse?
70+
__threadfence_block();
71+
}

compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i32.c

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66

77
#include "compiler/plugins/target/ROCM/builtins/ukernel/common.h"
88

9-
[[clang::always_inline]] void
10-
iree_uk_amdgpu_argmax_f16i32(const _Float16 *inputBuffer, int64_t input_offset,
11-
int32_t *outputBuffer, int64_t output_offset,
12-
int64_t reductionSize) {
9+
[[clang::always_inline]] void iree_uk_amdgpu_argmax_f16i32(
10+
const _Float16 *inputBuffer, int64_t input_offset,
11+
_Float16 *outputBufferVal, int64_t output_val_offset,
12+
int32_t *outputBufferIdx, int64_t output_idx_offset, int64_t reductionSize,
13+
bool writeValue) {
1314
const int warpSize = __builtin_amdgcn_wavefrontsize();
1415
_Float16 NEG_F16_MAX = (_Float16)(-65504.0f);
1516
int32_t laneID = __builtin_amdgcn_workitem_id_x();
@@ -36,15 +37,21 @@ iree_uk_amdgpu_argmax_f16i32(const _Float16 *inputBuffer, int64_t input_offset,
3637
// if there is only one max value holder, write and exit.
3738
if (__builtin_popcountll(laneHasMaxValmask) == 1) {
3839
if (wgMax == laneMax) {
39-
outputBuffer[output_offset] = laneResult;
40+
if (writeValue) {
41+
outputBufferVal[output_val_offset] = wgMax;
42+
}
43+
outputBufferIdx[output_idx_offset] = laneResult;
4044
}
4145
} else {
4246
// if there are multiple max value holder, find smallest index (argmax
4347
// semantics).
4448
int32_t indexVal = wgMax == laneMax ? laneResult : __INT32_MAX__;
4549
laneResult = __ockl_wfred_min_i32(indexVal);
4650
if (laneID == 0) {
47-
outputBuffer[output_offset] = laneResult;
51+
if (writeValue) {
52+
outputBufferVal[output_val_offset] = wgMax;
53+
}
54+
outputBufferIdx[output_idx_offset] = laneResult;
4855
}
4956
}
5057
// TODO(bjacob): this fence should be on the caller side. Move to TileAndFuse?

compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i64.c

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66

77
#include "compiler/plugins/target/ROCM/builtins/ukernel/common.h"
88

9-
[[clang::always_inline]] void
10-
iree_uk_amdgpu_argmax_f16i64(const _Float16 *inputBuffer, int64_t input_offset,
11-
int64_t *outputBuffer, int64_t output_offset,
12-
int64_t reductionSize) {
9+
[[clang::always_inline]] void iree_uk_amdgpu_argmax_f16i64(
10+
const _Float16 *inputBuffer, int64_t input_offset,
11+
_Float16 *outputBufferVal, int64_t output_val_offset,
12+
int64_t *outputBufferIdx, int64_t output_idx_offset, int64_t reductionSize,
13+
bool writeValue) {
1314
const int warpSize = __builtin_amdgcn_wavefrontsize();
1415
_Float16 NEG_F16_MAX = (_Float16)(-65504.0f);
1516
int32_t laneID = __builtin_amdgcn_workitem_id_x();
@@ -37,15 +38,21 @@ iree_uk_amdgpu_argmax_f16i64(const _Float16 *inputBuffer, int64_t input_offset,
3738
// if there is only one max value holder, write and exit.
3839
if (__builtin_popcountll(laneHasMaxValmask) == 1) {
3940
if (wgMax == laneMax) {
40-
outputBuffer[output_offset] = laneResult;
41+
if (writeValue) {
42+
outputBufferVal[output_val_offset] = wgMax;
43+
}
44+
outputBufferIdx[output_idx_offset] = laneResult;
4145
}
4246
} else {
4347
// if there are multiple max value holder, find smallest index (argmax
4448
// semantics).
4549
int64_t indexVal = wgMax == laneMax ? laneResult : INT64_MAX;
4650
laneResult = __ockl_wfred_min_i64(indexVal);
4751
if (laneID == 0) {
48-
outputBuffer[output_offset] = laneResult;
52+
if (writeValue) {
53+
outputBufferVal[output_val_offset] = wgMax;
54+
}
55+
outputBufferIdx[output_idx_offset] = laneResult;
4956
}
5057
}
5158
// TODO(bjacob): this fence should be on the caller side. Move to TileAndFuse?

compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i32.c

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66

77
#include "compiler/plugins/target/ROCM/builtins/ukernel/common.h"
88

9-
[[clang::always_inline]] void
10-
iree_uk_amdgpu_argmax_f32i32(const float *inputBuffer, int64_t input_offset,
11-
int32_t *outputBuffer, int64_t output_offset,
12-
int64_t reductionSize) {
9+
[[clang::always_inline]] void iree_uk_amdgpu_argmax_f32i32(
10+
const float *inputBuffer, int64_t input_offset, float *outputBufferVal,
11+
int64_t output_val_offset, int32_t *outputBufferIdx,
12+
int64_t output_idx_offset, int64_t reductionSize, bool writeValue) {
1313
const int warpSize = __builtin_amdgcn_wavefrontsize();
1414
int32_t laneID = __builtin_amdgcn_workitem_id_x();
1515
// Set identity value to handle problem non divisible by subgroupSize.
@@ -42,15 +42,21 @@ iree_uk_amdgpu_argmax_f32i32(const float *inputBuffer, int64_t input_offset,
4242
// if there is only one max value holder, write and exit.
4343
if (__builtin_popcountll(laneHasMaxValmask) == 1) {
4444
if (wgMax == laneMax) {
45-
outputBuffer[output_offset] = laneResult;
45+
if (writeValue) {
46+
outputBufferVal[output_val_offset] = wgMax;
47+
}
48+
outputBufferIdx[output_idx_offset] = laneResult;
4649
}
4750
} else {
4851
// if there are multiple max value holder, find smallest index (argmax
4952
// semantics).
5053
int32_t indexVal = wgMax == laneMax ? laneResult : __INT32_MAX__;
5154
laneResult = __ockl_wfred_min_i32(indexVal);
5255
if (laneID == 0) {
53-
outputBuffer[output_offset] = laneResult;
56+
if (writeValue) {
57+
outputBufferVal[output_val_offset] = wgMax;
58+
}
59+
outputBufferIdx[output_idx_offset] = laneResult;
5460
}
5561
}
5662
// TODO(bjacob): this fence should be on the caller side. Move to TileAndFuse?

0 commit comments

Comments
 (0)