Skip to content

Commit c57b2a0

Browse files
committed
[MLIR][GPU] Make max flat work group size for ROCDL kernels configurable
While the default value for the amdgpu-flat-work-group-size attribute, "1, 256", matches the defaults from Clang, some users of the ROCDL dialect, namely Tensorflow, use larger workgroups, such as 1024. Therefore, instead of hardcoding this value, we add a rocdl.max_flat_work_group_size attribute that can be set on GPU kernels to override the default value. Reviewed By: whchung Differential Revision: https://reviews.llvm.org/D115741
1 parent 100863c commit c57b2a0

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include "llvm/IR/IRBuilder.h"
2020
#include "llvm/IR/IntrinsicsAMDGPU.h"
21+
#include "llvm/Support/raw_ostream.h"
2122

2223
using namespace mlir;
2324
using namespace mlir::LLVM;
@@ -71,15 +72,34 @@ class ROCDLDialectLLVMIRTranslationInterface
7172

7273
// For GPU kernels,
7374
// 1. Insert AMDGPU_KERNEL calling convention.
74-
// 2. Insert amdgpu-flat-workgroup-size(1, 256) attribute.
75+
// 2. Insert amdgpu-flat-work-group-size(1, 256) attribute unless the user
76+
// has overriden this value - 256 is the default in clang
7577
// 3. Insert amdgpu-implicitarg-num-bytes=56 (which must be set on OpenCL
7678
// and HIP kernels per Clang)
7779
llvm::Function *llvmFunc =
7880
moduleTranslation.lookupFunction(func.getName());
7981
llvmFunc->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
80-
llvmFunc->addFnAttr("amdgpu-flat-work-group-size", "1, 256");
82+
if (!llvmFunc->hasFnAttribute("amdgpu-flat-work-group-size")) {
83+
llvmFunc->addFnAttr("amdgpu-flat-work-group-size", "1, 256");
84+
}
8185
llvmFunc->addFnAttr("amdgpu-implicitarg-num-bytes", "56");
8286
}
87+
// Override flat-work-group-size
88+
if ("rocdl.max_flat_work_group_size" == attribute.getName()) {
89+
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
90+
if (!func)
91+
return failure();
92+
auto value = attribute.getValue().dyn_cast<IntegerAttr>();
93+
if (!value)
94+
return failure();
95+
96+
llvm::Function *llvmFunc =
97+
moduleTranslation.lookupFunction(func.getName());
98+
llvm::SmallString<8> llvmAttrValue;
99+
llvm::raw_svector_ostream attrValueStream(llvmAttrValue);
100+
attrValueStream << "1, " << value.getInt();
101+
llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
102+
}
83103
return success();
84104
}
85105
};

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,15 @@ llvm.func @rocdl_special_regs() -> i32 {
3030
}
3131

3232
llvm.func @kernel_func() attributes {rocdl.kernel} {
33-
// CHECK-LABEL: amdgpu_kernel void @kernel_func
33+
// CHECK-LABEL: amdgpu_kernel void @kernel_func()
34+
// CHECK: #[[$KERNEL_ATTRS:[0-9]+]]
35+
llvm.return
36+
}
37+
38+
llvm.func @kernel_func_workgroups()
39+
attributes {rocdl.kernel, rocdl.max_flat_work_group_size = 1024 : index} {
40+
// CHECK-LABEL: amdgpu_kernel void @kernel_func_workgroups()
41+
// CHECK: #[[$KERNEL_WORKGROUP_ATTRS:[0-9]+]]
3442
llvm.return
3543
}
3644

@@ -177,3 +185,5 @@ llvm.func @rocdl.mubuf(%rsrc : vector<4xi32>, %vindex : i32,
177185
llvm.return
178186
}
179187

188+
// CHECK-DAG: attributes #[[$KERNEL_ATTRS]] = { "amdgpu-flat-work-group-size"="1, 256" "amdgpu-implicitarg-num-bytes"="56" }
189+
// CHECK-DAG: attributes #[[$KERNEL_WORKGROUP_ATTRS]] = { "amdgpu-flat-work-group-size"="1, 1024"

0 commit comments

Comments
 (0)