Skip to content

Commit f7ec976

Browse files
committed
Relaxed condition for omitting computation of values for workspaces
1 parent 14567d6 commit f7ec976

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

src/lower/lowerer_impl.cpp

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,9 @@ static bool needComputeValues(IndexStmt stmt, TensorVar tensor) {
114114

115115
struct ReturnsTrue : public IndexExprRewriterStrict {
116116
void visit(const AccessNode* op) {
117-
if (op->isAccessingStructure) {
117+
if (op->isAccessingStructure || (
118+
op->tensorVar.getFormat().getModeFormats().back().isZeroless() &&
119+
equals(op->tensorVar.getFill(), Literal(false)))) {
118120
expr = op;
119121
}
120122
}
@@ -146,12 +148,27 @@ static bool needComputeValues(IndexStmt stmt, TensorVar tensor) {
146148
}
147149

148150
void visit(const CallNode* op) {
149-
for (const auto& arg : op->args) {
150-
if (!rewrite(arg).defined()) {
151-
return;
151+
const auto annihilator = findProperty<Annihilator>(op->properties);
152+
153+
if (!annihilator.defined() || !annihilator.positions().empty()) {
154+
return;
155+
}
156+
157+
if (equals(annihilator.annihilator(), Literal(false))) {
158+
for (const auto& arg : op->args) {
159+
if (!rewrite(arg).defined()) {
160+
return;
161+
}
162+
}
163+
expr = op;
164+
} else {
165+
for (const auto& arg : op->args) {
166+
if (rewrite(arg).defined()) {
167+
expr = op;
168+
return;
169+
}
152170
}
153171
}
154-
expr = op;
155172
}
156173

157174
void visit(const SqrtNode* op) {}

0 commit comments

Comments
 (0)