File tree Expand file tree Collapse file tree 1 file changed +22
-5
lines changed Expand file tree Collapse file tree 1 file changed +22
-5
lines changed Original file line number Diff line number Diff line change @@ -114,7 +114,9 @@ static bool needComputeValues(IndexStmt stmt, TensorVar tensor) {
114114
115115 struct ReturnsTrue : public IndexExprRewriterStrict {
116116 void visit (const AccessNode* op) {
117- if (op->isAccessingStructure ) {
117+ if (op->isAccessingStructure || (
118+ op->tensorVar .getFormat ().getModeFormats ().back ().isZeroless () &&
119+ equals (op->tensorVar .getFill (), Literal (false )))) {
118120 expr = op;
119121 }
120122 }
@@ -146,12 +148,27 @@ static bool needComputeValues(IndexStmt stmt, TensorVar tensor) {
146148 }
147149
148150 void visit (const CallNode* op) {
149- for (const auto & arg : op->args ) {
150- if (!rewrite (arg).defined ()) {
151- return ;
151+ const auto annihilator = findProperty<Annihilator>(op->properties );
152+
153+ if (!annihilator.defined () || !annihilator.positions ().empty ()) {
154+ return ;
155+ }
156+
157+ if (equals (annihilator.annihilator (), Literal (false ))) {
158+ for (const auto & arg : op->args ) {
159+ if (!rewrite (arg).defined ()) {
160+ return ;
161+ }
162+ }
163+ expr = op;
164+ } else {
165+ for (const auto & arg : op->args ) {
166+ if (rewrite (arg).defined ()) {
167+ expr = op;
168+ return ;
169+ }
152170 }
153171 }
154- expr = op;
155172 }
156173
157174 void visit (const SqrtNode* op) {}
You can’t perform that action at this time.
0 commit comments