Skip to content

Commit 2b2bd51

Browse files
authored
[flang][cuda] Inline this_grid call for cooperative groups (llvm#145796)
1 parent f63bc84 commit 2b2bd51

File tree

5 files changed

+179
-5
lines changed

5 files changed

+179
-5
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,7 @@ struct IntrinsicLibrary {
442442
llvm::ArrayRef<fir::ExtendedValue>);
443443
fir::ExtendedValue genTranspose(mlir::Type,
444444
llvm::ArrayRef<fir::ExtendedValue>);
445+
mlir::Value genThisGrid(mlir::Type, llvm::ArrayRef<mlir::Value>);
445446
void genThreadFence(llvm::ArrayRef<fir::ExtendedValue>);
446447
void genThreadFenceBlock(llvm::ArrayRef<fir::ExtendedValue>);
447448
void genThreadFenceSystem(llvm::ArrayRef<fir::ExtendedValue>);

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,7 @@ static constexpr IntrinsicHandler handlers[]{
932932
{{{"count", asAddr}, {"count_rate", asAddr}, {"count_max", asAddr}}},
933933
/*isElemental=*/false},
934934
{"tand", &I::genTand},
935+
{"this_grid", &I::genThisGrid, {}, /*isElemental=*/false},
935936
{"threadfence", &I::genThreadFence, {}, /*isElemental=*/false},
936937
{"threadfence_block", &I::genThreadFenceBlock, {}, /*isElemental=*/false},
937938
{"threadfence_system", &I::genThreadFenceSystem, {}, /*isElemental=*/false},
@@ -8109,6 +8110,90 @@ mlir::Value IntrinsicLibrary::genTand(mlir::Type resultType,
81098110
return getRuntimeCallGenerator("tan", ftype)(builder, loc, {arg});
81108111
}
81118112

8113+
// THIS_GRID
8114+
mlir::Value IntrinsicLibrary::genThisGrid(mlir::Type resultType,
8115+
llvm::ArrayRef<mlir::Value> args) {
8116+
assert(args.size() == 0);
8117+
auto recTy = mlir::cast<fir::RecordType>(resultType);
8118+
assert(recTy && "RecordType expepected");
8119+
mlir::Value res = builder.create<fir::AllocaOp>(loc, resultType);
8120+
mlir::Type i32Ty = builder.getI32Type();
8121+
8122+
mlir::Value threadIdX = builder.create<mlir::NVVM::ThreadIdXOp>(loc, i32Ty);
8123+
mlir::Value threadIdY = builder.create<mlir::NVVM::ThreadIdYOp>(loc, i32Ty);
8124+
mlir::Value threadIdZ = builder.create<mlir::NVVM::ThreadIdZOp>(loc, i32Ty);
8125+
8126+
mlir::Value blockIdX = builder.create<mlir::NVVM::BlockIdXOp>(loc, i32Ty);
8127+
mlir::Value blockIdY = builder.create<mlir::NVVM::BlockIdYOp>(loc, i32Ty);
8128+
mlir::Value blockIdZ = builder.create<mlir::NVVM::BlockIdZOp>(loc, i32Ty);
8129+
8130+
mlir::Value blockDimX = builder.create<mlir::NVVM::BlockDimXOp>(loc, i32Ty);
8131+
mlir::Value blockDimY = builder.create<mlir::NVVM::BlockDimYOp>(loc, i32Ty);
8132+
mlir::Value blockDimZ = builder.create<mlir::NVVM::BlockDimZOp>(loc, i32Ty);
8133+
mlir::Value gridDimX = builder.create<mlir::NVVM::GridDimXOp>(loc, i32Ty);
8134+
mlir::Value gridDimY = builder.create<mlir::NVVM::GridDimYOp>(loc, i32Ty);
8135+
mlir::Value gridDimZ = builder.create<mlir::NVVM::GridDimZOp>(loc, i32Ty);
8136+
8137+
// this_grid.size = ((blockDim.z * gridDim.z) * (blockDim.y * gridDim.y)) *
8138+
// (blockDim.x * gridDim.x);
8139+
mlir::Value resZ =
8140+
builder.create<mlir::arith::MulIOp>(loc, blockDimZ, gridDimZ);
8141+
mlir::Value resY =
8142+
builder.create<mlir::arith::MulIOp>(loc, blockDimY, gridDimY);
8143+
mlir::Value resX =
8144+
builder.create<mlir::arith::MulIOp>(loc, blockDimX, gridDimX);
8145+
mlir::Value resZY = builder.create<mlir::arith::MulIOp>(loc, resZ, resY);
8146+
mlir::Value size = builder.create<mlir::arith::MulIOp>(loc, resZY, resX);
8147+
8148+
// tmp = ((blockIdx.z * gridDim.y * gridDim.x) + (blockIdx.y * gridDim.x)) +
8149+
// blockIdx.x;
8150+
// this_group.rank = tmp * ((blockDim.x * blockDim.y) * blockDim.z) +
8151+
// ((threadIdx.z * blockDim.y) * blockDim.x) +
8152+
// (threadIdx.y * blockDim.x) + threadIdx.x + 1;
8153+
mlir::Value r1 = builder.create<mlir::arith::MulIOp>(loc, blockIdZ, gridDimY);
8154+
mlir::Value r2 = builder.create<mlir::arith::MulIOp>(loc, r1, gridDimX);
8155+
mlir::Value r3 = builder.create<mlir::arith::MulIOp>(loc, blockIdY, gridDimX);
8156+
mlir::Value r2r3 = builder.create<mlir::arith::AddIOp>(loc, r2, r3);
8157+
mlir::Value tmp = builder.create<mlir::arith::AddIOp>(loc, r2r3, blockIdX);
8158+
8159+
mlir::Value bXbY =
8160+
builder.create<mlir::arith::MulIOp>(loc, blockDimX, blockDimY);
8161+
mlir::Value bXbYbZ =
8162+
builder.create<mlir::arith::MulIOp>(loc, bXbY, blockDimZ);
8163+
mlir::Value tZbY =
8164+
builder.create<mlir::arith::MulIOp>(loc, threadIdZ, blockDimY);
8165+
mlir::Value tZbYbX =
8166+
builder.create<mlir::arith::MulIOp>(loc, tZbY, blockDimX);
8167+
mlir::Value tYbX =
8168+
builder.create<mlir::arith::MulIOp>(loc, threadIdY, blockDimX);
8169+
mlir::Value rank = builder.create<mlir::arith::MulIOp>(loc, tmp, bXbYbZ);
8170+
rank = builder.create<mlir::arith::AddIOp>(loc, rank, tZbYbX);
8171+
rank = builder.create<mlir::arith::AddIOp>(loc, rank, tYbX);
8172+
rank = builder.create<mlir::arith::AddIOp>(loc, rank, threadIdX);
8173+
mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
8174+
rank = builder.create<mlir::arith::AddIOp>(loc, rank, one);
8175+
8176+
auto sizeFieldName = recTy.getTypeList()[1].first;
8177+
mlir::Type sizeFieldTy = recTy.getTypeList()[1].second;
8178+
mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext());
8179+
mlir::Value sizeFieldIndex = builder.create<fir::FieldIndexOp>(
8180+
loc, fieldIndexType, sizeFieldName, recTy,
8181+
/*typeParams=*/mlir::ValueRange{});
8182+
mlir::Value sizeCoord = builder.create<fir::CoordinateOp>(
8183+
loc, builder.getRefType(sizeFieldTy), res, sizeFieldIndex);
8184+
builder.create<fir::StoreOp>(loc, size, sizeCoord);
8185+
8186+
auto rankFieldName = recTy.getTypeList()[2].first;
8187+
mlir::Type rankFieldTy = recTy.getTypeList()[2].second;
8188+
mlir::Value rankFieldIndex = builder.create<fir::FieldIndexOp>(
8189+
loc, fieldIndexType, rankFieldName, recTy,
8190+
/*typeParams=*/mlir::ValueRange{});
8191+
mlir::Value rankCoord = builder.create<fir::CoordinateOp>(
8192+
loc, builder.getRefType(rankFieldTy), res, rankFieldIndex);
8193+
builder.create<fir::StoreOp>(loc, rank, rankCoord);
8194+
return res;
8195+
}
8196+
81128197
// TRAILZ
81138198
mlir::Value IntrinsicLibrary::genTrailz(mlir::Type resultType,
81148199
llvm::ArrayRef<mlir::Value> args) {

flang/module/cooperative_groups.f90

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
!===-- module/cooperative_groups.f90 ---------------------------------------===!
2+
!
3+
! Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
! See https://llvm.org/LICENSE.txt for license information.
5+
! SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
!
7+
!===------------------------------------------------------------------------===!
8+
9+
! CUDA Fortran cooperative groups
10+
11+
module cooperative_groups
12+
13+
use, intrinsic :: __fortran_builtins, only: c_devptr => __builtin_c_devptr
14+
15+
implicit none
16+
17+
type :: grid_group
18+
type(c_devptr), private :: handle
19+
integer(4) :: size
20+
integer(4) :: rank
21+
end type grid_group
22+
23+
interface
24+
attributes(device) function this_grid()
25+
import
26+
type(grid_group) :: this_grid
27+
end function
28+
end interface
29+
30+
end module
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
2+
3+
! Test CUDA Fortran procedures available in cooperative_groups module.
4+
5+
attributes(grid_global) subroutine g1()
6+
use cooperative_groups
7+
type(grid_group) :: gg
8+
gg = this_grid()
9+
end subroutine
10+
11+
! CHECK: %{{.*}} = fir.alloca !fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>
12+
! CHECK: %[[RES:.*]] = fir.alloca !fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>
13+
! CHECK: %[[THREAD_ID_X:.*]] = nvvm.read.ptx.sreg.tid.x : i32
14+
! CHECK: %[[THREAD_ID_Y:.*]] = nvvm.read.ptx.sreg.tid.y : i32
15+
! CHECK: %[[THREAD_ID_Z:.*]] = nvvm.read.ptx.sreg.tid.z : i32
16+
! CHECK: %[[BLOCK_ID_X:.*]] = nvvm.read.ptx.sreg.ctaid.x : i32
17+
! CHECK: %[[BLOCK_ID_Y:.*]] = nvvm.read.ptx.sreg.ctaid.y : i32
18+
! CHECK: %[[BLOCK_ID_Z:.*]] = nvvm.read.ptx.sreg.ctaid.z : i32
19+
! CHECK: %[[BLOCK_DIM_X:.*]] = nvvm.read.ptx.sreg.ntid.x : i32
20+
! CHECK: %[[BLOCK_DIM_Y:.*]] = nvvm.read.ptx.sreg.ntid.y : i32
21+
! CHECK: %[[BLOCK_DIM_Z:.*]] = nvvm.read.ptx.sreg.ntid.z : i32
22+
! CHECK: %[[GRID_DIM_X:.*]] = nvvm.read.ptx.sreg.nctaid.x : i32
23+
! CHECK: %[[GRID_DIM_Y:.*]] = nvvm.read.ptx.sreg.nctaid.y : i32
24+
! CHECK: %[[GRID_DIM_Z:.*]] = nvvm.read.ptx.sreg.nctaid.z : i32
25+
26+
! CHECK: %[[R1:.*]] = arith.muli %[[BLOCK_DIM_Z]], %[[GRID_DIM_Z]] : i32
27+
! CHECK: %[[R2:.*]] = arith.muli %[[BLOCK_DIM_Y]], %[[GRID_DIM_Y]] : i32
28+
! CHECK: %[[R3:.*]] = arith.muli %[[BLOCK_DIM_X]], %[[GRID_DIM_X]] : i32
29+
! CHECK: %[[R4:.*]] = arith.muli %[[R1]], %[[R2]] : i32
30+
! CHECK: %[[SIZE:.*]] = arith.muli %[[R4]], %[[R3]] : i32
31+
32+
! CHECK: %[[R1:.*]] = arith.muli %[[BLOCK_ID_Z]], %[[GRID_DIM_Y]] : i32
33+
! CHECK: %[[R2:.*]] = arith.muli %[[R1]], %[[GRID_DIM_X]] : i32
34+
! CHECK: %[[R3:.*]] = arith.muli %[[BLOCK_ID_Y]], %[[GRID_DIM_X]] : i32
35+
! CHECK: %[[R4:.*]] = arith.addi %[[R2]], %[[R3]] : i32
36+
! CHECK: %[[TMP:.*]] = arith.addi %[[R4]], %[[BLOCK_ID_X]] : i32
37+
38+
! CHECK: %[[R1:.*]] = arith.muli %[[BLOCK_DIM_X]], %[[BLOCK_DIM_Y]] : i32
39+
! CHECK: %[[R2:.*]] = arith.muli %[[R1]], %[[BLOCK_DIM_Z]] : i32
40+
! CHECK: %[[R3:.*]] = arith.muli %[[THREAD_ID_Z]], %[[BLOCK_DIM_Y]] : i32
41+
! CHECK: %[[R4:.*]] = arith.muli %[[R3]], %[[BLOCK_DIM_X]] : i32
42+
! CHECK: %[[R5:.*]] = arith.muli %[[THREAD_ID_Y]], %[[BLOCK_DIM_X]] : i32
43+
! CHECK: %[[RES0:.*]] = arith.muli %[[TMP]], %[[R2]] : i32
44+
! CHECK: %[[RES1:.*]] = arith.addi %[[RES0]], %[[R4]] : i32
45+
! CHECK: %[[RES2:.*]] = arith.addi %[[RES1]], %[[R5]] : i32
46+
! CHECK: %[[RES3:.*]] = arith.addi %[[RES2]], %[[THREAD_ID_X]] : i32
47+
! CHECK: %[[ONE:.*]] = arith.constant 1 : i32
48+
! CHECK: %[[RANK:.*]] = arith.addi %[[RES3]], %[[ONE]] : i32
49+
! CHECK: %[[COORD_SIZE:.*]] = fir.coordinate_of %[[RES]], size : (!fir.ref<!fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
50+
! CHECK: fir.store %[[SIZE]] to %[[COORD_SIZE]] : !fir.ref<i32>
51+
! CHECK: %[[COORD_RANK:.*]] = fir.coordinate_of %[[RES]], rank : (!fir.ref<!fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
52+
! CHECK: fir.store %[[RANK]] to %[[COORD_RANK]] : !fir.ref<i32>

flang/tools/f18/CMakeLists.txt

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ set(MODULES
1515
"mma"
1616
"__cuda_builtins"
1717
"__cuda_device"
18+
"cooperative_groups"
1819
"cudadevice"
1920
"ieee_arithmetic"
2021
"ieee_exceptions"
@@ -60,12 +61,17 @@ if (NOT CMAKE_CROSSCOMPILING)
6061
elseif(${filename} STREQUAL "__ppc_intrinsics" OR
6162
${filename} STREQUAL "mma")
6263
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__ppc_types.mod)
63-
elseif(${filename} STREQUAL "__cuda_device")
64+
elseif(${filename} STREQUAL "__cuda_device" OR
65+
${filename} STREQUAL "cudadevice" OR
66+
${filename} STREQUAL "cooperative_groups")
6467
set(opts -fc1 -xcuda)
65-
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_builtins.mod)
66-
elseif(${filename} STREQUAL "cudadevice")
67-
set(opts -fc1 -xcuda)
68-
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_device.mod)
68+
if(${filename} STREQUAL "__cuda_device")
69+
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_builtins.mod)
70+
elseif(${filename} STREQUAL "cudadevice")
71+
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_device.mod)
72+
elseif(${filename} STREQUAL "cooperative_groups")
73+
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/cudadevice.mod)
74+
endif()
6975
else()
7076
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__fortran_builtins.mod)
7177
if(${filename} STREQUAL "iso_fortran_env")

0 commit comments

Comments
 (0)