@@ -932,6 +932,7 @@ static constexpr IntrinsicHandler handlers[]{
932
932
{{{" count" , asAddr}, {" count_rate" , asAddr}, {" count_max" , asAddr}}},
933
933
/* isElemental=*/ false },
934
934
{" tand" , &I::genTand},
935
+ {" this_grid" , &I::genThisGrid, {}, /* isElemental=*/ false },
935
936
{" threadfence" , &I::genThreadFence, {}, /* isElemental=*/ false },
936
937
{" threadfence_block" , &I::genThreadFenceBlock, {}, /* isElemental=*/ false },
937
938
{" threadfence_system" , &I::genThreadFenceSystem, {}, /* isElemental=*/ false },
@@ -8109,6 +8110,90 @@ mlir::Value IntrinsicLibrary::genTand(mlir::Type resultType,
8109
8110
return getRuntimeCallGenerator (" tan" , ftype)(builder, loc, {arg});
8110
8111
}
8111
8112
8113
+ // THIS_GRID
8114
+ mlir::Value IntrinsicLibrary::genThisGrid (mlir::Type resultType,
8115
+ llvm::ArrayRef<mlir::Value> args) {
8116
+ assert (args.size () == 0 );
8117
+ auto recTy = mlir::cast<fir::RecordType>(resultType);
8118
+ assert (recTy && " RecordType expepected" );
8119
+ mlir::Value res = builder.create <fir::AllocaOp>(loc, resultType);
8120
+ mlir::Type i32Ty = builder.getI32Type ();
8121
+
8122
+ mlir::Value threadIdX = builder.create <mlir::NVVM::ThreadIdXOp>(loc, i32Ty);
8123
+ mlir::Value threadIdY = builder.create <mlir::NVVM::ThreadIdYOp>(loc, i32Ty);
8124
+ mlir::Value threadIdZ = builder.create <mlir::NVVM::ThreadIdZOp>(loc, i32Ty);
8125
+
8126
+ mlir::Value blockIdX = builder.create <mlir::NVVM::BlockIdXOp>(loc, i32Ty);
8127
+ mlir::Value blockIdY = builder.create <mlir::NVVM::BlockIdYOp>(loc, i32Ty);
8128
+ mlir::Value blockIdZ = builder.create <mlir::NVVM::BlockIdZOp>(loc, i32Ty);
8129
+
8130
+ mlir::Value blockDimX = builder.create <mlir::NVVM::BlockDimXOp>(loc, i32Ty);
8131
+ mlir::Value blockDimY = builder.create <mlir::NVVM::BlockDimYOp>(loc, i32Ty);
8132
+ mlir::Value blockDimZ = builder.create <mlir::NVVM::BlockDimZOp>(loc, i32Ty);
8133
+ mlir::Value gridDimX = builder.create <mlir::NVVM::GridDimXOp>(loc, i32Ty);
8134
+ mlir::Value gridDimY = builder.create <mlir::NVVM::GridDimYOp>(loc, i32Ty);
8135
+ mlir::Value gridDimZ = builder.create <mlir::NVVM::GridDimZOp>(loc, i32Ty);
8136
+
8137
+ // this_grid.size = ((blockDim.z * gridDim.z) * (blockDim.y * gridDim.y)) *
8138
+ // (blockDim.x * gridDim.x);
8139
+ mlir::Value resZ =
8140
+ builder.create <mlir::arith::MulIOp>(loc, blockDimZ, gridDimZ);
8141
+ mlir::Value resY =
8142
+ builder.create <mlir::arith::MulIOp>(loc, blockDimY, gridDimY);
8143
+ mlir::Value resX =
8144
+ builder.create <mlir::arith::MulIOp>(loc, blockDimX, gridDimX);
8145
+ mlir::Value resZY = builder.create <mlir::arith::MulIOp>(loc, resZ, resY);
8146
+ mlir::Value size = builder.create <mlir::arith::MulIOp>(loc, resZY, resX);
8147
+
8148
+ // tmp = ((blockIdx.z * gridDim.y * gridDim.x) + (blockIdx.y * gridDim.x)) +
8149
+ // blockIdx.x;
8150
+ // this_group.rank = tmp * ((blockDim.x * blockDim.y) * blockDim.z) +
8151
+ // ((threadIdx.z * blockDim.y) * blockDim.x) +
8152
+ // (threadIdx.y * blockDim.x) + threadIdx.x + 1;
8153
+ mlir::Value r1 = builder.create <mlir::arith::MulIOp>(loc, blockIdZ, gridDimY);
8154
+ mlir::Value r2 = builder.create <mlir::arith::MulIOp>(loc, r1, gridDimX);
8155
+ mlir::Value r3 = builder.create <mlir::arith::MulIOp>(loc, blockIdY, gridDimX);
8156
+ mlir::Value r2r3 = builder.create <mlir::arith::AddIOp>(loc, r2, r3);
8157
+ mlir::Value tmp = builder.create <mlir::arith::AddIOp>(loc, r2r3, blockIdX);
8158
+
8159
+ mlir::Value bXbY =
8160
+ builder.create <mlir::arith::MulIOp>(loc, blockDimX, blockDimY);
8161
+ mlir::Value bXbYbZ =
8162
+ builder.create <mlir::arith::MulIOp>(loc, bXbY, blockDimZ);
8163
+ mlir::Value tZbY =
8164
+ builder.create <mlir::arith::MulIOp>(loc, threadIdZ, blockDimY);
8165
+ mlir::Value tZbYbX =
8166
+ builder.create <mlir::arith::MulIOp>(loc, tZbY, blockDimX);
8167
+ mlir::Value tYbX =
8168
+ builder.create <mlir::arith::MulIOp>(loc, threadIdY, blockDimX);
8169
+ mlir::Value rank = builder.create <mlir::arith::MulIOp>(loc, tmp, bXbYbZ);
8170
+ rank = builder.create <mlir::arith::AddIOp>(loc, rank, tZbYbX);
8171
+ rank = builder.create <mlir::arith::AddIOp>(loc, rank, tYbX);
8172
+ rank = builder.create <mlir::arith::AddIOp>(loc, rank, threadIdX);
8173
+ mlir::Value one = builder.createIntegerConstant (loc, i32Ty, 1 );
8174
+ rank = builder.create <mlir::arith::AddIOp>(loc, rank, one);
8175
+
8176
+ auto sizeFieldName = recTy.getTypeList ()[1 ].first ;
8177
+ mlir::Type sizeFieldTy = recTy.getTypeList ()[1 ].second ;
8178
+ mlir::Type fieldIndexType = fir::FieldType::get (resultType.getContext ());
8179
+ mlir::Value sizeFieldIndex = builder.create <fir::FieldIndexOp>(
8180
+ loc, fieldIndexType, sizeFieldName, recTy,
8181
+ /* typeParams=*/ mlir::ValueRange{});
8182
+ mlir::Value sizeCoord = builder.create <fir::CoordinateOp>(
8183
+ loc, builder.getRefType (sizeFieldTy), res, sizeFieldIndex);
8184
+ builder.create <fir::StoreOp>(loc, size, sizeCoord);
8185
+
8186
+ auto rankFieldName = recTy.getTypeList ()[2 ].first ;
8187
+ mlir::Type rankFieldTy = recTy.getTypeList ()[2 ].second ;
8188
+ mlir::Value rankFieldIndex = builder.create <fir::FieldIndexOp>(
8189
+ loc, fieldIndexType, rankFieldName, recTy,
8190
+ /* typeParams=*/ mlir::ValueRange{});
8191
+ mlir::Value rankCoord = builder.create <fir::CoordinateOp>(
8192
+ loc, builder.getRefType (rankFieldTy), res, rankFieldIndex);
8193
+ builder.create <fir::StoreOp>(loc, rank, rankCoord);
8194
+ return res;
8195
+ }
8196
+
8112
8197
// TRAILZ
8113
8198
mlir::Value IntrinsicLibrary::genTrailz (mlir::Type resultType,
8114
8199
llvm::ArrayRef<mlir::Value> args) {
0 commit comments