@@ -818,8 +818,11 @@ IndexStmt SetAssembleStrategy::apply(IndexStmt stmt, string* reason) const {
818818 const auto resultTensor = resultAccess.getTensorVar ();
819819
820820 if (resultTensor != result) {
821+ // TODO: Should check that annihilator of original reduction op equals
822+ // fill value of original result
821823 Access lhs = to<Access>(rewrite (op->lhs ));
822- stmt = (rhs != op->rhs ) ? Assignment (lhs, rhs, op->op ) : op;
824+ IndexExpr reduceOp = op->op .defined () ? Add () : IndexExpr ();
825+ stmt = (rhs != op->rhs ) ? Assignment (lhs, rhs, reduceOp) : op;
823826 return ;
824827 }
825828
@@ -829,7 +832,7 @@ IndexStmt SetAssembleStrategy::apply(IndexStmt stmt, string* reason) const {
829832 return ;
830833 }
831834
832- queryResults[resultTensor] =
835+ queryResults[resultTensor] =
833836 std::vector<std::vector<TensorVar>>(resultTensor.getOrder ());
834837
835838 const auto indices = resultAccess.getIndexVars ();
@@ -848,16 +851,16 @@ IndexStmt SetAssembleStrategy::apply(IndexStmt stmt, string* reason) const {
848851 parentCoords.push_back (indices[modeOrdering[i]]);
849852 childCoords.erase (childCoords.begin ());
850853
851- for (const auto & query:
854+ for (const auto & query:
852855 modeFormats[i].getAttrQueries (parentCoords, childCoords)) {
853856 const auto & groupBy = query.getGroupBy ();
854857
855858 // TODO: support multiple aggregations in single query
856- taco_iassert (query.getAttrs ().size () == 1 );
859+ taco_iassert (query.getAttrs ().size () == 1 );
857860
858861 std::vector<Dimension> queryDims;
859862 for (const auto & coord : groupBy) {
860- const auto pos = std::find (groupBy.begin (), groupBy.end (), coord)
863+ const auto pos = std::find (groupBy.begin (), groupBy.end (), coord)
861864 - groupBy.begin ();
862865 const auto dim = resultTensor.getType ().getShape ().getDimension (pos);
863866 queryDims.push_back (dim);
@@ -868,7 +871,7 @@ IndexStmt SetAssembleStrategy::apply(IndexStmt stmt, string* reason) const {
868871 case AttrQuery::COUNT:
869872 {
870873 std::vector<IndexVar> dedupCoords = groupBy;
871- dedupCoords.insert (dedupCoords.end (), attr.params .begin (),
874+ dedupCoords.insert (dedupCoords.end (), attr.params .begin (),
872875 attr.params .end ());
873876 std::vector<Dimension> dedupDims (dedupCoords.size ());
874877 TensorVar dedupTmp (modeName + " _dedup" , Type (Bool, dedupDims));
@@ -877,7 +880,7 @@ IndexStmt SetAssembleStrategy::apply(IndexStmt stmt, string* reason) const {
877880
878881 const auto resultName = modeName + " _" + attr.label ;
879882 TensorVar queryResult (resultName, Type (Int32, queryDims));
880- epilog = Assignment (queryResult (groupBy),
883+ epilog = Assignment (queryResult (groupBy),
881884 Cast (dedupTmp (dedupCoords), Int ()), Add ());
882885 for (const auto & coord : util::reverse (dedupCoords)) {
883886 epilog = forall (coord, epilog);
@@ -914,6 +917,56 @@ IndexStmt SetAssembleStrategy::apply(IndexStmt stmt, string* reason) const {
914917
915918 expr = op;
916919 }
920+
921+ void visit (const CallNode* op) {
922+ std::vector<IndexExpr> args;
923+ bool rewritten = false ;
924+ for (auto & arg : op->args ) {
925+ IndexExpr rewrittenArg = rewrite (arg);
926+ args.push_back (rewrittenArg);
927+ if (arg != rewrittenArg) {
928+ rewritten = true ;
929+ }
930+ }
931+
932+ if (rewritten) {
933+ const std::map<IndexExpr, IndexExpr> subs = util::zipToMap (op->args , args);
934+ IterationAlgebra newAlg = replaceAlgIndexExprs (op->iterAlg , subs);
935+
936+ struct InferSymbolic : public IterationAlgebraVisitorStrict {
937+ IndexExpr ret;
938+
939+ IndexExpr infer (IterationAlgebra alg) {
940+ ret = IndexExpr ();
941+ alg.accept (this );
942+ return ret;
943+ }
944+ virtual void visit (const RegionNode* op) {
945+ ret = op->expr ();
946+ }
947+
948+ virtual void visit (const ComplementNode* op) {
949+ taco_not_supported_yet;
950+ }
951+
952+ virtual void visit (const IntersectNode* op) {
953+ IndexExpr lhs = infer (op->a );
954+ IndexExpr rhs = infer (op->b );
955+ ret = lhs * rhs;
956+ }
957+
958+ virtual void visit (const UnionNode* op) {
959+ IndexExpr lhs = infer (op->a );
960+ IndexExpr rhs = infer (op->b );
961+ ret = lhs + rhs;
962+ }
963+ };
964+ expr = InferSymbolic ().infer (newAlg);
965+ }
966+ else {
967+ expr = op;
968+ }
969+ }
917970 };
918971 LowerAttrQuery queryLowerer (getResult (), queryResults, insertedResults);
919972 loweredQueries = queryLowerer.lower (loweredQueries);
0 commit comments