Skip to content

Commit 12a6e45

Browse files
committed
Add comments
1 parent 4817e49 commit 12a6e45

File tree

4 files changed

+34
-16
lines changed

4 files changed

+34
-16
lines changed

include/taco/index_notation/index_notation.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,16 @@ struct SuchThatNode;
7171
class IndexExprVisitorStrict;
7272
class IndexStmtVisitorStrict;
7373

74+
/// Describe the relation between indexVar sets of lhs and rhs in an Assignment node.
75+
/// equal: lhs = rhs
76+
/// none: lhs and rhs are mutually exclusive. And lhs and rhs are not empty sets.
77+
/// lcr: rhs is a proper subset of lhs. (lhs contains rhs)
78+
/// rcl: lhs is a proper subset of rhs. (rhs contains lhs)
79+
/// inter: lhs and rhs share common elements but are not equal or empty.
7480
enum IndexSetRel {
7581
equal, none, lcr, rcl, inter
7682
};
83+
7784
/// Return true if the index statement is of the given subtype. The subtypes
7885
/// are Assignment, Forall, Where, Sequence, and Multi.
7986
template <typename SubType> bool isa(IndexExpr);
@@ -823,7 +830,7 @@ class Assignment : public IndexStmt {
823830
/// Return the reduction index variables i nthe assign
824831
std::vector<IndexVar> getReductionVars() const;
825832

826-
/// Return the set relation of indexVars in lhs and rhd
833+
/// Return the set relation of indexVars in lhs and rhs
827834
IndexSetRel getIndexSetRel() const;
828835

829836
typedef AssignmentNode Node;

src/index_notation/transformations.cpp

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -433,13 +433,16 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
433433
);
434434

435435
IndexSetRel rel = a.getIndexSetRel();
436+
/// The reduceOp depends on the relation between indexVar sets of rhs and lhs. For rcl and inter, reduceOp
437+
/// must be +=. For lcr, reduceOp must be =. For none and equal, reduceOp can't be decided at this stage.
436438
switch (rel) {
437-
case none: a = Assignment(a.getLhs(), a.getRhs());break; // =
438-
case rcl: a = Assignment(a.getLhs(), a.getRhs(), Add());break; // +=
439-
case lcr: a = Assignment(a.getLhs(), a.getRhs());break; // =
440-
case inter: a = Assignment(a.getLhs(), a.getRhs(), Add());break; // +=
441-
case equal: a = Assignment(a.getLhs(), a.getRhs());break;// = OR +=
442-
}return a;
439+
case none: a = Assignment(a.getLhs(), a.getRhs());break;
440+
case rcl: a = Assignment(a.getLhs(), a.getRhs(), Add());break;
441+
case lcr: a = Assignment(a.getLhs(), a.getRhs());break;
442+
case inter: a = Assignment(a.getLhs(), a.getRhs(), Add());break;
443+
case equal: a = Assignment(a.getLhs(), a.getRhs());break;
444+
}
445+
return a;
443446
}
444447

445448
Assignment getProducerAssignment(TensorVar& ws,
@@ -450,14 +453,15 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
450453

451454
auto a = ws(iw_vars) = replace(e, substitutions);
452455
IndexSetRel rel = a.getIndexSetRel();
456+
/// The reduceOp depends on the relation between indexVar sets of rhs and lhs. For rcl and inter, reduceOp
457+
/// must be +=. For lcr, reduceOp must be =. For none and equal, reduceOp can't be decided at this stage.
453458
switch (rel) {
454-
case none: a = Assignment(a.getLhs(), a.getRhs());break; // =
455-
case rcl: a = Assignment(a.getLhs(), a.getRhs(), Add());break; // +=
456-
case lcr: a = Assignment(a.getLhs(), a.getRhs());break; // =
457-
case inter: a = Assignment(a.getLhs(), a.getRhs(), Add());break; // +=
458-
case equal: a = Assignment(a.getLhs(), a.getRhs());break;// = OR +=
459+
case none: a = Assignment(a.getLhs(), a.getRhs());break;
460+
case rcl: a = Assignment(a.getLhs(), a.getRhs(), Add());break;
461+
case lcr: a = Assignment(a.getLhs(), a.getRhs());break;
462+
case inter: a = Assignment(a.getLhs(), a.getRhs(), Add());break;
463+
case equal: a = Assignment(a.getLhs(), a.getRhs());break;
459464
}
460-
461465
return a;
462466
}
463467

@@ -565,6 +569,9 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
565569
}
566570
};
567571

572+
/// RedundantVisitor uses Forall Context to determine reduceOp for none and equal.
573+
/// We assume += is used if a workspace is accessed multiple times, otherwise =.
574+
/// Forall Context describes the related indexVars of the given indexVar at a specific stage. `ctx_stack` implements such concept.
568575
struct RedundantVisitor: public IndexNotationVisitor {
569576
using IndexNotationVisitor::visit;
570577

@@ -608,6 +615,9 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
608615
Assignment a(node->lhs, node->rhs, node->op);
609616
vector<IndexVar> freeVars = a.getLhs().getIndexVars();
610617
set<IndexVar> seen(freeVars.begin(), freeVars.end());
618+
619+
/// For equal, if some indexVar in lhs has sibling in ctx stack, reduceOp will be +=.
620+
bool is_equal = (a.getIndexSetRel() == equal);
611621
bool has_sibling = false;
612622
match(a.getRhs(),
613623
std::function<void(const AccessNode*)>([&](const AccessNode* op) {
@@ -619,11 +629,12 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
619629
}
620630
}
621631
}));
622-
bool is_equal = (a.getIndexSetRel() == equal);
623-
bool is_none = (a.getIndexSetRel() == none);
624632
if (is_equal && has_sibling) {
625633
to_change.push_back(a);
626634
}
635+
636+
/// For none, if ctx_stack except the top contains indexVars in lhs, reduceOp will be +=.
637+
bool is_none = (a.getIndexSetRel() == none);
627638
bool has_outside = true;
628639
for (auto & var : seen) {
629640
for (auto &svar: ctx_stack) {

test/tests-scheduling-eval.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ IndexStmt scheduleSpGEMMCPU(IndexStmt stmt, bool doPrecompute) {
7878
stmt = stmt.precompute(assign.getRhs(), j, j, w);
7979
}
8080
stmt = stmt.assemble(result, AssembleStrategy::Insert, true);
81-
//stmt = stmt.assemble(result, AssembleStrategy::Append, true);
8281
auto qi_stmt = stmt.as<Assemble>().getQueries();
8382
IndexVar qi;
8483
if (isa<Where>(qi_stmt)) {

test/tests-workspaces.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,7 @@ TEST(workspaces, tile_dotProduct_3) {
594594

595595

596596
stmt = stmt.concretize();
597+
597598
A.compile(stmt);
598599
A.assemble();
599600
A.compute();

0 commit comments

Comments
 (0)