Skip to content

Commit b34af77

Browse files
committed
Fixed assemble command with user-defined operators
1 parent cf592ff commit b34af77

File tree

1 file changed

+60
-7
lines changed

1 file changed

+60
-7
lines changed

src/index_notation/transformations.cpp

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -818,8 +818,11 @@ IndexStmt SetAssembleStrategy::apply(IndexStmt stmt, string* reason) const {
818818
const auto resultTensor = resultAccess.getTensorVar();
819819

820820
if (resultTensor != result) {
821+
// TODO: Should check that annihilator of original reduction op equals
822+
// fill value of original result
821823
Access lhs = to<Access>(rewrite(op->lhs));
822-
stmt = (rhs != op->rhs) ? Assignment(lhs, rhs, op->op) : op;
824+
IndexExpr reduceOp = op->op.defined() ? Add() : IndexExpr();
825+
stmt = (rhs != op->rhs) ? Assignment(lhs, rhs, reduceOp) : op;
823826
return;
824827
}
825828

@@ -829,7 +832,7 @@ IndexStmt SetAssembleStrategy::apply(IndexStmt stmt, string* reason) const {
829832
return;
830833
}
831834

832-
queryResults[resultTensor] =
835+
queryResults[resultTensor] =
833836
std::vector<std::vector<TensorVar>>(resultTensor.getOrder());
834837

835838
const auto indices = resultAccess.getIndexVars();
@@ -848,16 +851,16 @@ IndexStmt SetAssembleStrategy::apply(IndexStmt stmt, string* reason) const {
848851
parentCoords.push_back(indices[modeOrdering[i]]);
849852
childCoords.erase(childCoords.begin());
850853

851-
for (const auto& query:
854+
for (const auto& query:
852855
modeFormats[i].getAttrQueries(parentCoords, childCoords)) {
853856
const auto& groupBy = query.getGroupBy();
854857

855858
// TODO: support multiple aggregations in single query
856-
taco_iassert(query.getAttrs().size() == 1);
859+
taco_iassert(query.getAttrs().size() == 1);
857860

858861
std::vector<Dimension> queryDims;
859862
for (const auto& coord : groupBy) {
860-
const auto pos = std::find(groupBy.begin(), groupBy.end(), coord)
863+
const auto pos = std::find(groupBy.begin(), groupBy.end(), coord)
861864
- groupBy.begin();
862865
const auto dim = resultTensor.getType().getShape().getDimension(pos);
863866
queryDims.push_back(dim);
@@ -868,7 +871,7 @@ IndexStmt SetAssembleStrategy::apply(IndexStmt stmt, string* reason) const {
868871
case AttrQuery::COUNT:
869872
{
870873
std::vector<IndexVar> dedupCoords = groupBy;
871-
dedupCoords.insert(dedupCoords.end(), attr.params.begin(),
874+
dedupCoords.insert(dedupCoords.end(), attr.params.begin(),
872875
attr.params.end());
873876
std::vector<Dimension> dedupDims(dedupCoords.size());
874877
TensorVar dedupTmp(modeName + "_dedup", Type(Bool, dedupDims));
@@ -877,7 +880,7 @@ IndexStmt SetAssembleStrategy::apply(IndexStmt stmt, string* reason) const {
877880

878881
const auto resultName = modeName + "_" + attr.label;
879882
TensorVar queryResult(resultName, Type(Int32, queryDims));
880-
epilog = Assignment(queryResult(groupBy),
883+
epilog = Assignment(queryResult(groupBy),
881884
Cast(dedupTmp(dedupCoords), Int()), Add());
882885
for (const auto& coord : util::reverse(dedupCoords)) {
883886
epilog = forall(coord, epilog);
@@ -914,6 +917,56 @@ IndexStmt SetAssembleStrategy::apply(IndexStmt stmt, string* reason) const {
914917

915918
expr = op;
916919
}
920+
921+
void visit(const CallNode* op) {
922+
std::vector<IndexExpr> args;
923+
bool rewritten = false;
924+
for(auto& arg : op->args) {
925+
IndexExpr rewrittenArg = rewrite(arg);
926+
args.push_back(rewrittenArg);
927+
if (arg != rewrittenArg) {
928+
rewritten = true;
929+
}
930+
}
931+
932+
if (rewritten) {
933+
const std::map<IndexExpr, IndexExpr> subs = util::zipToMap(op->args, args);
934+
IterationAlgebra newAlg = replaceAlgIndexExprs(op->iterAlg, subs);
935+
936+
struct InferSymbolic : public IterationAlgebraVisitorStrict {
937+
IndexExpr ret;
938+
939+
IndexExpr infer(IterationAlgebra alg) {
940+
ret = IndexExpr();
941+
alg.accept(this);
942+
return ret;
943+
}
944+
virtual void visit(const RegionNode* op) {
945+
ret = op->expr();
946+
}
947+
948+
virtual void visit(const ComplementNode* op) {
949+
taco_not_supported_yet;
950+
}
951+
952+
virtual void visit(const IntersectNode* op) {
953+
IndexExpr lhs = infer(op->a);
954+
IndexExpr rhs = infer(op->b);
955+
ret = lhs * rhs;
956+
}
957+
958+
virtual void visit(const UnionNode* op) {
959+
IndexExpr lhs = infer(op->a);
960+
IndexExpr rhs = infer(op->b);
961+
ret = lhs + rhs;
962+
}
963+
};
964+
expr = InferSymbolic().infer(newAlg);
965+
}
966+
else {
967+
expr = op;
968+
}
969+
}
917970
};
918971
LowerAttrQuery queryLowerer(getResult(), queryResults, insertedResults);
919972
loweredQueries = queryLowerer.lower(loweredQueries);

0 commit comments

Comments
 (0)