
Commit d77f483

[mlir][gpu] Relax restriction on mma load/store op
These ops can support more complex layouts as long as the innermost dimension is contiguous.

Differential Revision: https://reviews.llvm.org/D122452
1 parent b62ea9b commit d77f483
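For illustration, a minimal MLIR sketch of the kind of IR the verifier accepts after this change (adapted from the tests added below; the function name and constant values are illustrative, not part of the commit): the source memref has a non-identity, strided layout, but its innermost stride is 1, so it can feed the load directly with a matching leadDimension.

// Row stride of 64 elements, unit innermost stride; previously rejected
// because the layout map is not the identity.
func @load_strided_sketch(%src : memref<32x32xf16, affine_map<(d0, d1) -> (d0 * 64 + d1)>>) {
  %i = arith.constant 0 : index
  // leadDimension reflects the 64-element row stride of the layout.
  %a = gpu.subgroup_mma_load_matrix %src[%i, %i] {leadDimension = 64 : index} : memref<32x32xf16, affine_map<(d0, d1) -> (d0 * 64 + d1)>> -> !gpu.mma_matrix<16x16xf16, "AOp">
  return
}

A memref whose innermost stride is not 1 (for example the #layout_map_col_major cases in invalid.mlir below) still fails verification, now with the "most minor dim must have unit stride" error.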

5 files changed: 44 additions, 17 deletions

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 2 additions & 1 deletion
@@ -65,7 +65,8 @@ getMemrefConstantHorizontalStride(ShapedType type) {
     return 0;
   int64_t offset = 0;
   SmallVector<int64_t, 2> strides;
-  if (failed(getStridesAndOffset(memrefType, strides, offset)))
+  if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
+      strides.back() != 1)
     return llvm::None;
   int64_t stride = strides[strides.size() - 2];
   if (stride == ShapedType::kDynamicStrideOrOffset)

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 18 additions & 13 deletions
@@ -1068,6 +1068,17 @@ static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
 // GPU_SubgroupMmaLoadMatrixOp
 //===----------------------------------------------------------------------===//
 
+/// Return true if the last dimension of the MemRefType has unit stride. Also
+/// return true for memrefs with no strides.
+static bool isLastMemrefDimUnitStride(MemRefType type) {
+  int64_t offset;
+  SmallVector<int64_t> strides;
+  if (failed(getStridesAndOffset(type, strides, offset))) {
+    return false;
+  }
+  return strides.back() == 1;
+}
+
 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
   auto srcType = srcMemref().getType();
   auto resType = res().getType();
@@ -1076,8 +1087,9 @@ LogicalResult SubgroupMmaLoadMatrixOp::verify() {
   auto srcMemrefType = srcType.cast<MemRefType>();
   auto srcMemSpace = srcMemrefType.getMemorySpaceAsInt();
 
-  if (!srcMemrefType.getLayout().isIdentity())
-    return emitError("expected identity layout map for source memref");
+  if (!isLastMemrefDimUnitStride(srcMemrefType))
+    return emitError(
+        "expected source memref most minor dim must have unit stride");
 
   if (srcMemSpace != kGenericMemorySpace && srcMemSpace != kSharedMemorySpace &&
       srcMemSpace != kGlobalMemorySpace)
@@ -1102,8 +1114,10 @@ LogicalResult SubgroupMmaStoreMatrixOp::verify() {
   auto srcMatrixType = srcType.cast<gpu::MMAMatrixType>();
   auto dstMemrefType = dstType.cast<MemRefType>();
   auto dstMemSpace = dstMemrefType.getMemorySpaceAsInt();
-  if (!dstMemrefType.getLayout().isIdentity())
-    return emitError("expected identity layout map for destination memref");
+
+  if (!isLastMemrefDimUnitStride(dstMemrefType))
+    return emitError(
+        "expected destination memref most minor dim must have unit stride");
 
   if (dstMemSpace != kGenericMemorySpace && dstMemSpace != kSharedMemorySpace &&
       dstMemSpace != kGlobalMemorySpace)
@@ -1232,15 +1246,6 @@ void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // GPU_DeviceAsyncCopyOp
 //===----------------------------------------------------------------------===//
 
-/// Return true if the last dimension of the MemRefType has unit stride. Also
-/// return true for memrefs with no strides.
-static bool isLastMemrefDimUnitStride(MemRefType type) {
-  int64_t offset;
-  SmallVector<int64_t> strides;
-  auto successStrides = getStridesAndOffset(type, strides, offset);
-  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
-}
-
 LogicalResult DeviceAsyncCopyOp::verify() {
   auto srcMemref = src().getType().cast<MemRefType>();
   auto dstMemref = dst().getType().cast<MemRefType>();

mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir

Lines changed: 19 additions & 0 deletions
@@ -151,3 +151,22 @@ func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %arg2:
   vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
   return
 }
+
+// CHECK-LABEL: func @matmul_memref_strided
+//   CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 32 : index} : memref<2x16x16xf16, #{{.*}}> -> !gpu.mma_matrix<16x16xf16, "AOp">
+//   CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]]] {leadDimension = 0 : index} : memref<16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+//   CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : memref<2x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<2x16x16xf16>
+func @matmul_memref_strided(%arg0: memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = #map4, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
+  return
+}

mlir/test/Dialect/GPU/invalid.mlir

Lines changed: 2 additions & 2 deletions
@@ -491,7 +491,7 @@ func @mmamatrix_invalid_element_type(){
 func @mmaLoadOp_identity_layout(){
   %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, #layout_map_col_major, 3>
   %i = arith.constant 16 : index
-  // expected-error @+1 {{expected identity layout map for source memref}}
+  // expected-error @+1 {{expected source memref most minor dim must have unit stride}}
   %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, #layout_map_col_major, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
   return
 }
@@ -514,7 +514,7 @@ func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
   %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, #layout_map_col_major, 3>
   %i = arith.constant 16 : index
   %j = arith.constant 16 : index
-  // expected-error @+1 {{expected identity layout map for destination memref}}
+  // expected-error @+1 {{expected destination memref most minor dim must have unit stride}}
   gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16,#layout_map_col_major, 3>
   return
 }

mlir/test/Dialect/GPU/ops.mlir

Lines changed: 3 additions & 1 deletion
@@ -227,7 +227,7 @@ module attributes {gpu.container_module} {
     return
   }
 
-  func @mmamatrix_valid_element_type(){
+  func @mmamatrix_valid_element_type(%src : memref<32x32xf16, affine_map<(d0, d1) -> (d0 * 64 + d1)>>){
     // CHECK-LABEL: func @mmamatrix_valid_element_type
     %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
     // CHECK: %[[wg:.*]] = memref.alloca()
@@ -237,6 +237,8 @@ module attributes {gpu.container_module} {
     // CHECK: %[[cst:.*]] = arith.constant 1.000000e+00 : f32
     %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
     // CHECK: gpu.subgroup_mma_load_matrix %[[wg]][%[[i]], %[[i]]] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    %s = gpu.subgroup_mma_load_matrix %src[%i, %i] {leadDimension = 64 : index} : memref<32x32xf16, affine_map<(d0, d1) -> (d0 * 64 + d1)>> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    // CHECK: gpu.subgroup_mma_load_matrix %{{.*}}[%[[i]], %[[i]]] {leadDimension = 64 : index} : memref<32x32xf16, #{{.*}}> -> !gpu.mma_matrix<16x16xf16, "AOp">
     %1 = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf32, "COp">
     // CHECK: gpu.subgroup_mma_elementwise addf %{{.*}}, %{{.*}} : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
     %2 = gpu.subgroup_mma_elementwise addf %1, %1 : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
