@@ -105,8 +105,9 @@ static bool isFunctionRegion(Region *region) {
105105}
106106
107107// / A utility function to check if a value is defined at the top level of a
108- // / function. A value defined at the top level is always a valid symbol.
109- bool mlir::isTopLevelSymbol (Value *value) {
108+ // / function. A value of index type defined at the top level is always a valid
109+ // / symbol.
110+ bool mlir::isTopLevelValue (Value *value) {
110111 if (auto *arg = dyn_cast<BlockArgument>(value))
111112 return isFunctionRegion (arg->getOwner ()->getParent ());
112113 return isFunctionRegion (value->getDefiningOp ()->getParentRegion ());
@@ -130,13 +131,46 @@ bool mlir::isValidDim(Value *value) {
130131 // The dim op is okay if its operand memref/tensor is defined at the top
131132 // level.
132133 if (auto dimOp = dyn_cast<DimOp>(op))
133- return isTopLevelSymbol (dimOp.getOperand ());
134+ return isTopLevelValue (dimOp.getOperand ());
134135 return false ;
135136 }
136137 // This value is a block argument (which also includes 'affine.for' loop IVs).
137138 return true ;
138139}
139140
141+ // / Returns true if the 'index' dimension of the `memref` defined by
142+ // / `memrefDefOp` is a statically shaped one or defined using a valid symbol.
143+ template <typename AnyMemRefDefOp>
144+ bool isMemRefSizeValidSymbol (AnyMemRefDefOp memrefDefOp, unsigned index) {
145+ auto memRefType = memrefDefOp.getType ();
146+ // Statically shaped.
147+ if (!ShapedType::isDynamic (memRefType.getDimSize (index)))
148+ return true ;
149+ // Get the position of the dimension among dynamic dimensions;
150+ unsigned dynamicDimPos = memRefType.getDynamicDimIndex (index);
151+ return isValidSymbol (
152+ *(memrefDefOp.getDynamicSizes ().begin () + dynamicDimPos));
153+ }
154+
155+ // / Returns true if the result of the dim op is a valid symbol.
156+ static bool isDimOpValidSymbol (DimOp dimOp) {
157+ // The dim op is okay if its operand memref/tensor is defined at the top
158+ // level.
159+ if (isTopLevelValue (dimOp.getOperand ()))
160+ return true ;
161+
162+ // The dim op is also okay if its operand memref/tensor is a view/subview
163+ // whose corresponding size is a valid symbol.
164+ unsigned index = dimOp.getIndex ();
165+ if (auto viewOp = dyn_cast<ViewOp>(dimOp.getOperand ()->getDefiningOp ()))
166+ return isMemRefSizeValidSymbol<ViewOp>(viewOp, index);
167+ if (auto subViewOp = dyn_cast<SubViewOp>(dimOp.getOperand ()->getDefiningOp ()))
168+ return isMemRefSizeValidSymbol<SubViewOp>(subViewOp, index);
169+ if (auto allocOp = dyn_cast<AllocOp>(dimOp.getOperand ()->getDefiningOp ()))
170+ return isMemRefSizeValidSymbol<AllocOp>(allocOp, index);
171+ return false ;
172+ }
173+
140174// Value can be used as a symbol if it is a constant, or it is defined at
141175// the top level, or it is a result of affine apply operation with symbol
142176// arguments.
@@ -152,14 +186,12 @@ bool mlir::isValidSymbol(Value *value) {
152186 // Affine apply operation is ok if all of its operands are ok.
153187 if (auto applyOp = dyn_cast<AffineApplyOp>(op))
154188 return applyOp.isValidSymbol ();
155- // The dim op is okay if its operand memref/tensor is defined at the top
156- // level.
157- if (auto dimOp = dyn_cast<DimOp>(op))
158- return isTopLevelSymbol (dimOp.getOperand ());
159- return false ;
189+ if (auto dimOp = dyn_cast<DimOp>(op)) {
190+ return isDimOpValidSymbol (dimOp);
191+ }
160192 }
161- // Otherwise, check that the value is a top level symbol .
162- return isTopLevelSymbol (value);
193+ // Otherwise, check that the value is a top level value .
194+ return isTopLevelValue (value);
163195}
164196
165197// Returns true if 'value' is a valid index to an affine operation (e.g.
0 commit comments