Skip to content

Commit f1f194b

Browse files
yangtetrisYang Bai
andauthored
[mlir][vector] fix: unroll vector.from_elements in gpu pipelines (llvm#154774)
### Problem PR llvm#142944 introduced a new canonicalization pattern which caused failures in the following GPU-related integration tests: - mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir - mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir The issue occurs because the new canonicalization pattern can generate multi-dimensional `vector.from_elements` operations (rank > 1), but the GPU lowering pipelines were not equipped to handle these during the conversion to LLVM. ### Fix This PR adds `vector::populateVectorFromElementsLoweringPatterns` to the GPU lowering passes that are integrated in `gpu-lower-to-nvvm-pipeline`: - `GpuToLLVMConversionPass`: the general GPU-to-LLVM conversion pass. - `LowerGpuOpsToNVVMOpsPass`: the NVVM-specific lowering pass. Co-authored-by: Yang Bai <[email protected]>
1 parent 418fb50 commit f1f194b

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,9 @@ void GpuToLLVMConversionPass::runOnOperation() {
532532
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
533533
vector::populateVectorTransferLoweringPatterns(patterns,
534534
/*maxTransferRank=*/1);
535+
// Transform N-D vector.from_elements to 1-D vector.from_elements before
536+
// conversion.
537+
vector::populateVectorFromElementsLoweringPatterns(patterns);
535538
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
536539
return signalPassFailure();
537540
}

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir/Dialect/Math/IR/Math.h"
2828
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2929
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
30+
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
3031
#include "mlir/Transforms/DialectConversion.h"
3132
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
3233

@@ -369,6 +370,9 @@ struct LowerGpuOpsToNVVMOpsPass final
369370
{
370371
RewritePatternSet patterns(m.getContext());
371372
populateGpuRewritePatterns(patterns);
373+
// Transform N-D vector.from_elements to 1-D vector.from_elements before
374+
// conversion.
375+
vector::populateVectorFromElementsLoweringPatterns(patterns);
372376
if (failed(applyPatternsGreedily(m, std::move(patterns))))
373377
return signalPassFailure();
374378
}

0 commit comments

Comments
 (0)