Skip to content

Commit 6301574

Browse files
committed
[mlir][SparseTensor] Enable VLA ops in index value generation
Current index value generation uses fixed-length vector ops; this patch adds an alternative codegen path compatible with scalable vectors by using `LLVM::StepVectorOp`. Differential Revision: https://reviews.llvm.org/D124454
1 parent 515f890 commit 6301574

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1818
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1919
#include "mlir/Dialect/Func/IR/FuncOps.h"
20+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
2021
#include "mlir/Dialect/Linalg/IR/Linalg.h"
2122
#include "mlir/Dialect/Linalg/Utils/Utils.h"
2223
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -889,11 +890,18 @@ static Value genIndexValue(Merger &merger, CodeGen &codegen,
889890
VectorType vtp = vectorType(codegen, itype);
890891
ival = rewriter.create<vector::BroadcastOp>(loc, vtp, ival);
891892
if (idx == ldx) {
892-
SmallVector<APInt, 4> integers;
893-
for (unsigned i = 0; i < vl; i++)
894-
integers.push_back(APInt(/*width=*/64, i));
895-
auto values = DenseElementsAttr::get(vtp, integers);
896-
Value incr = rewriter.create<arith::ConstantOp>(loc, vtp, values);
893+
Value incr;
894+
if (vtp.isScalable()) {
895+
Type stepvty = vectorType(codegen, rewriter.getI64Type());
896+
Value stepv = rewriter.create<LLVM::StepVectorOp>(loc, stepvty);
897+
incr = rewriter.create<arith::IndexCastOp>(loc, vtp, stepv);
898+
} else {
899+
SmallVector<APInt, 4> integers;
900+
for (unsigned i = 0; i < vl; i++)
901+
integers.push_back(APInt(/*width=*/64, i));
902+
auto values = DenseElementsAttr::get(vtp, integers);
903+
incr = rewriter.create<arith::ConstantOp>(loc, vtp, values);
904+
}
897905
ival = rewriter.create<arith::AddIOp>(loc, ival, incr);
898906
}
899907
}

0 commit comments

Comments
 (0)