Skip to content
This repository was archived by the owner on Apr 23, 2021. It is now read-only.

Commit 3ce7f94

Browse files
bondhugulatensorflower-gardener
authored andcommitted
Make isValidSymbol more powerful
The check in isValidSymbol, as far as a DimOp result went, checked if the dim op was on a top-level memref. However, any alloc'ed, view, or subview memref would be fine as long as the corresponding dimension of that memref is either a static one or was in turn created using a valid symbol in the case of dynamic dimensions. Reported-by: Jose Gomez Signed-off-by: Uday Bondhugula <[email protected]> Closes #252 COPYBARA_INTEGRATE_REVIEW=#252 from bondhugula:symbol 7b57dc3 PiperOrigin-RevId: 282097114
1 parent 71da5c2 commit 3ce7f94

File tree

7 files changed

+80
-13
lines changed

7 files changed

+80
-13
lines changed

include/mlir/Dialect/AffineOps/AffineOps.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@ class FlatAffineConstraints;
3939
class OpBuilder;
4040

4141
/// A utility function to check if a value is defined at the top level of a
42-
/// function. A value defined at the top level is always a valid symbol.
43-
bool isTopLevelSymbol(Value *value);
42+
/// function. A value of index type defined at the top level is always a valid
43+
/// symbol.
44+
bool isTopLevelValue(Value *value);
4445

4546
class AffineOpsDialect : public Dialect {
4647
public:

include/mlir/Dialect/StandardOps/Ops.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@ def AllocOp : Std_Op<"alloc"> {
204204
operand_range getSymbolicOperands() {
205205
return {operand_begin() + getType().getNumDynamicDims(), operand_end()};
206206
}
207+
208+
/// Returns the dynamic sizes for this alloc operation if specified.
209+
operand_range getDynamicSizes() { return getOperands(); }
207210
}];
208211

209212
let hasCanonicalizer = 1;

include/mlir/IR/StandardTypes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,10 @@ class ShapedType : public Type {
226226
/// Otherwise, abort.
227227
int64_t getDimSize(int64_t i) const;
228228

229+
/// Returns the position of the dynamic dimension relative to just the dynamic
230+
/// dimensions, given its `index` within the shape.
231+
unsigned getDynamicDimIndex(unsigned index) const;
232+
229233
/// Get the total amount of bits occupied by a value of this type. This does
230234
/// not take into account any memory layout or widening constraints, e.g. a
231235
/// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice

lib/Analysis/AffineStructures.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -822,7 +822,7 @@ void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value *id) {
822822
return;
823823

824824
// Caller is expected to fully compose map/operands if necessary.
825-
assert((isTopLevelSymbol(id) || isForInductionVar(id)) &&
825+
assert((isTopLevelValue(id) || isForInductionVar(id)) &&
826826
"non-terminal symbol / loop IV expected");
827827
// Outer loop IVs could be used in forOp's bounds.
828828
if (auto loop = getForInductionVarOwner(id)) {

lib/Dialect/AffineOps/AffineOps.cpp

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,9 @@ static bool isFunctionRegion(Region *region) {
105105
}
106106

107107
/// A utility function to check if a value is defined at the top level of a
108-
/// function. A value defined at the top level is always a valid symbol.
109-
bool mlir::isTopLevelSymbol(Value *value) {
108+
/// function. A value of index type defined at the top level is always a valid
109+
/// symbol.
110+
bool mlir::isTopLevelValue(Value *value) {
110111
if (auto *arg = dyn_cast<BlockArgument>(value))
111112
return isFunctionRegion(arg->getOwner()->getParent());
112113
return isFunctionRegion(value->getDefiningOp()->getParentRegion());
@@ -130,13 +131,46 @@ bool mlir::isValidDim(Value *value) {
130131
// The dim op is okay if its operand memref/tensor is defined at the top
131132
// level.
132133
if (auto dimOp = dyn_cast<DimOp>(op))
133-
return isTopLevelSymbol(dimOp.getOperand());
134+
return isTopLevelValue(dimOp.getOperand());
134135
return false;
135136
}
136137
// This value is a block argument (which also includes 'affine.for' loop IVs).
137138
return true;
138139
}
139140

141+
/// Returns true if the 'index' dimension of the `memref` defined by
142+
/// `memrefDefOp` is a statically shaped one or defined using a valid symbol.
143+
template <typename AnyMemRefDefOp>
144+
bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index) {
145+
auto memRefType = memrefDefOp.getType();
146+
// Statically shaped.
147+
if (!ShapedType::isDynamic(memRefType.getDimSize(index)))
148+
return true;
149+
// Get the position of the dimension among dynamic dimensions;
150+
unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
151+
return isValidSymbol(
152+
*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos));
153+
}
154+
155+
/// Returns true if the result of the dim op is a valid symbol.
156+
static bool isDimOpValidSymbol(DimOp dimOp) {
157+
// The dim op is okay if its operand memref/tensor is defined at the top
158+
// level.
159+
if (isTopLevelValue(dimOp.getOperand()))
160+
return true;
161+
162+
// The dim op is also okay if its operand memref/tensor is a view/subview
163+
// whose corresponding size is a valid symbol.
164+
unsigned index = dimOp.getIndex();
165+
if (auto viewOp = dyn_cast<ViewOp>(dimOp.getOperand()->getDefiningOp()))
166+
return isMemRefSizeValidSymbol<ViewOp>(viewOp, index);
167+
if (auto subViewOp = dyn_cast<SubViewOp>(dimOp.getOperand()->getDefiningOp()))
168+
return isMemRefSizeValidSymbol<SubViewOp>(subViewOp, index);
169+
if (auto allocOp = dyn_cast<AllocOp>(dimOp.getOperand()->getDefiningOp()))
170+
return isMemRefSizeValidSymbol<AllocOp>(allocOp, index);
171+
return false;
172+
}
173+
140174
// Value can be used as a symbol if it is a constant, or it is defined at
141175
// the top level, or it is a result of affine apply operation with symbol
142176
// arguments.
@@ -152,14 +186,12 @@ bool mlir::isValidSymbol(Value *value) {
152186
// Affine apply operation is ok if all of its operands are ok.
153187
if (auto applyOp = dyn_cast<AffineApplyOp>(op))
154188
return applyOp.isValidSymbol();
155-
// The dim op is okay if its operand memref/tensor is defined at the top
156-
// level.
157-
if (auto dimOp = dyn_cast<DimOp>(op))
158-
return isTopLevelSymbol(dimOp.getOperand());
159-
return false;
189+
if (auto dimOp = dyn_cast<DimOp>(op)) {
190+
return isDimOpValidSymbol(dimOp);
191+
}
160192
}
161-
// Otherwise, check that the value is a top level symbol.
162-
return isTopLevelSymbol(value);
193+
// Otherwise, check that the value is a top level value.
194+
return isTopLevelValue(value);
163195
}
164196

165197
// Returns true if 'value' is a valid index to an affine operation (e.g.

lib/IR/StandardTypes.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,12 @@ int64_t ShapedType::getDimSize(int64_t i) const {
152152
return getShape()[i];
153153
}
154154

155+
unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
156+
assert(index < getRank() && "invalid index");
157+
assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
158+
return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
159+
}
160+
155161
/// Get the number of bits require to store a value of the given shaped type.
156162
/// Compute the value recursively since tensors are allowed to have vectors as
157163
/// elements.

test/AffineOps/ops.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,24 @@ func @affine_min(%arg0 : index, %arg1 : index, %arg2 : index) {
7878
%3 = affine.min ()[] -> (77, 78, 79) ()[]
7979
return
8080
}
81+
82+
// -----
83+
84+
func @valid_symbols(%arg0: index, %arg1: index, %arg2: index) {
85+
%c0 = constant 1 : index
86+
%c1 = constant 0 : index
87+
%0 = alloc(%arg0, %arg1) : memref<?x?xf32>
88+
affine.for %arg3 = 0 to %arg2 step 768 {
89+
%13 = dim %0, 1 : memref<?x?xf32>
90+
affine.for %arg4 = 0 to %13 step 264 {
91+
%18 = dim %0, 0 : memref<?x?xf32>
92+
%20 = std.subview %0[%c0, %c0][%18,%arg4][%c1,%c1] : memref<?x?xf32>
93+
to memref<?x?xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
94+
%24 = dim %20, 0 : memref<?x?xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
95+
affine.for %arg5 = 0 to %24 step 768 {
96+
"foo"() : () -> ()
97+
}
98+
}
99+
}
100+
return
101+
}

0 commit comments

Comments
 (0)