Skip to content

Commit fdb41a2

Browse files
[mlir][tensor] Implement ReifyRankedShapedTypeOpInterface on GenerateOp
Differential Revision: https://reviews.llvm.org/D121520
1 parent 07d5339 commit fdb41a2

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
372372

373373
def Tensor_GenerateOp : Tensor_Op<"generate",
374374
[RecursiveSideEffects,
375+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
375376
SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
376377
string summary = "Creates a dynamically sized tensor from elements";
377378
string description = [{

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,21 @@ OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
525525
// GenerateOp
526526
//===----------------------------------------------------------------------===//
527527

528+
LogicalResult GenerateOp::reifyResultShapes(
529+
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
530+
reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
531+
int idx = 0;
532+
for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
533+
if (getType().isDynamicDim(dim)) {
534+
reifiedReturnShapes[0][dim] = getOperand(idx++);
535+
} else {
536+
reifiedReturnShapes[0][dim] = builder.create<arith::ConstantIndexOp>(
537+
getLoc(), getType().getDimSize(dim));
538+
}
539+
}
540+
return success();
541+
}
542+
528543
LogicalResult GenerateOp::verify() {
529544
// Ensure that the tensor type has as many dynamic dimensions as are specified
530545
// by the operands.

0 commit comments

Comments
 (0)