Skip to content

Commit 7a05d63

Browse files
committed
Add mergeby scheduling directive.
This commit adds a new scheduling directive called mergeby. This directive specifies whether the iterators of a given variable are merged by Two Finger merge or Galloping. The default strategy is Two Finger merge, which is the same as the old behavior. Galloping merges the iterators with exponential search, which can be more efficient if the iterator sizes are skewed.
1 parent 4c24193 commit 7a05d63

15 files changed

+436
-83
lines changed

include/taco/index_notation/index_notation.h

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,23 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
634634
/// reorder takes a new ordering for a set of index variables that are directly nested in the iteration order
635635
IndexStmt reorder(std::vector<IndexVar> reorderedvars) const;
636636

637+
/// The mergeby transformation specifies how to merge iterators on
638+
/// the given index variable. By default, if an iterator is used for windowing
639+
/// it will be merged with the "gallop" strategy.
640+
/// All other iterators are merged with the "two finger" strategy.
641+
/// The two finger strategy merges by advancing each iterator one at a time,
642+
/// while the gallop strategy implements the exponential search algorithm.
643+
///
644+
/// Preconditions:
645+
/// This command applies to variables involving sparse iterators only;
646+
/// it is a no-op if the variable involves any dense iterators.
647+
/// Any variable can be merged with the two finger strategy, whereas gallop
648+
/// only applies to a variable if its merge lattice has a single point
649+
/// (i.e. an intersection). For example, if a variable involves multiplications
650+
/// only, it can be merged with gallop.
651+
/// Furthermore, all iterators must be ordered for gallop to apply.
652+
IndexStmt mergeby(IndexVar i, MergeStrategy strategy) const;
653+
637654
/// The parallelize
638655
/// transformation tags an index variable for parallel execution. The
639656
/// transformation takes as an argument the type of parallel hardware
@@ -829,13 +846,14 @@ class Forall : public IndexStmt {
829846
Forall() = default;
830847
Forall(const ForallNode*);
831848
Forall(IndexVar indexVar, IndexStmt stmt);
832-
Forall(IndexVar indexVar, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0);
849+
Forall(IndexVar indexVar, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0);
833850

834851
IndexVar getIndexVar() const;
835852
IndexStmt getStmt() const;
836853

837854
ParallelUnit getParallelUnit() const;
838855
OutputRaceStrategy getOutputRaceStrategy() const;
856+
MergeStrategy getMergeStrategy() const;
839857

840858
size_t getUnrollFactor() const;
841859

@@ -844,7 +862,7 @@ class Forall : public IndexStmt {
844862

845863
/// Create a forall index statement.
846864
Forall forall(IndexVar i, IndexStmt stmt);
847-
Forall forall(IndexVar i, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0);
865+
Forall forall(IndexVar i, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0);
848866

849867

850868
/// A where statement has a producer statement that binds a tensor variable in

include/taco/index_notation/index_notation_nodes.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,15 +398,16 @@ struct YieldNode : public IndexStmtNode {
398398
};
399399

400400
struct ForallNode : public IndexStmtNode {
401-
ForallNode(IndexVar indexVar, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0)
402-
: indexVar(indexVar), stmt(stmt), parallel_unit(parallel_unit), output_race_strategy(output_race_strategy), unrollFactor(unrollFactor) {}
401+
ForallNode(IndexVar indexVar, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0)
402+
: indexVar(indexVar), stmt(stmt), merge_strategy(merge_strategy), parallel_unit(parallel_unit), output_race_strategy(output_race_strategy), unrollFactor(unrollFactor) {}
403403

404404
void accept(IndexStmtVisitorStrict* v) const {
405405
v->visit(this);
406406
}
407407

408408
IndexVar indexVar;
409409
IndexStmt stmt;
410+
MergeStrategy merge_strategy;
410411
ParallelUnit parallel_unit;
411412
OutputRaceStrategy output_race_strategy;
412413
size_t unrollFactor = 0;

include/taco/index_notation/transformations.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class AddSuchThatPredicates;
2222
class Parallelize;
2323
class TopoReorder;
2424
class SetAssembleStrategy;
25+
class SetMergeStrategy;
2526

2627
/// A transformation is an optimization that transforms a statement in the
2728
/// concrete index notation into a new statement that computes the same result
@@ -36,6 +37,7 @@ class Transformation {
3637
Transformation(TopoReorder);
3738
Transformation(AddSuchThatPredicates);
3839
Transformation(SetAssembleStrategy);
40+
Transformation(SetMergeStrategy);
3941

4042
IndexStmt apply(IndexStmt stmt, std::string *reason = nullptr) const;
4143

@@ -206,6 +208,25 @@ class SetAssembleStrategy : public TransformationInterface {
206208
/// Print a SetAssembleStrategy command.
207209
std::ostream &operator<<(std::ostream &, const SetAssembleStrategy&);
208210

211+
class SetMergeStrategy : public TransformationInterface {
212+
public:
213+
SetMergeStrategy(IndexVar i, MergeStrategy strategy);
214+
215+
IndexVar geti() const;
216+
MergeStrategy getMergeStrategy() const;
217+
218+
IndexStmt apply(IndexStmt stmt, std::string *reason = nullptr) const;
219+
220+
void print(std::ostream &os) const;
221+
222+
private:
223+
struct Content;
224+
std::shared_ptr<Content> content;
225+
};
226+
227+
/// Print a SetMergeStrategy command.
228+
std::ostream &operator<<(std::ostream &, const SetMergeStrategy&);
229+
209230
// Autoscheduling functions
210231

211232
/**

include/taco/ir_tags.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ enum class AssembleStrategy {
3333
};
3434
extern const char *AssembleStrategy_NAMES[];
3535

36+
/// MergeStrategy::TwoFinger merges iterators by incrementing one at a time
37+
/// MergeStrategy::Gallop merges iterators by exponential search (galloping)
38+
enum class MergeStrategy {
39+
TwoFinger, Gallop
40+
};
41+
extern const char *MergeStrategy_NAMES[];
42+
3643
}
3744

3845
#endif //TACO_IR_TAGS_H

include/taco/lower/lowerer_impl_imperative.h

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -146,15 +146,18 @@ class LowererImplImperative : public LowererImpl {
146146
* \param statement
147147
* A concrete index notation statement to compute at the points in the
148148
* sparse iteration space described by the merge lattice.
149+
* \param mergeStrategy
150+
* A strategy for merging iterators. One of TwoFinger or Gallop.
149151
*
150152
* \return
151153
* IR code to compute the forall loop.
152154
*/
153155
virtual ir::Stmt lowerMergeLattice(MergeLattice lattice, IndexVar coordinateVar,
154156
IndexStmt statement,
155-
const std::set<Access>& reducedAccesses);
157+
const std::set<Access>& reducedAccesses,
158+
MergeStrategy mergeStrategy);
156159

157-
virtual ir::Stmt resolveCoordinate(std::vector<Iterator> mergers, ir::Expr coordinate, bool emitVarDecl);
160+
virtual ir::Stmt resolveCoordinate(std::vector<Iterator> mergers, ir::Expr coordinate, bool emitVarDecl, bool mergeWithMax);
158161

159162
/**
160163
* Lower the merge point at the top of the given lattice to code that iterates
@@ -169,23 +172,29 @@ class LowererImplImperative : public LowererImpl {
169172
* coordinate the merge point is at.
170173
* A concrete index notation statement to compute at the points in the
171174
* sparse iteration space region described by the merge point.
175+
* \param mergestrategy
176+
* A strategy for merging iterators. One of TwoFinger or Gallop.
177+
* Coordinates are combined with MAX instead of MIN when merging with Gallop.
172178
*/
173179
virtual ir::Stmt lowerMergePoint(MergeLattice pointLattice,
174180
ir::Expr coordinate, IndexVar coordinateVar, IndexStmt statement,
175-
const std::set<Access>& reducedAccesses, bool resolvedCoordDeclared);
181+
const std::set<Access>& reducedAccesses, bool resolvedCoordDeclared,
182+
MergeStrategy mergestrategy);
176183

177184
/// Lower a merge lattice to cases.
178185
virtual ir::Stmt lowerMergeCases(ir::Expr coordinate, IndexVar coordinateVar, IndexStmt stmt,
179186
MergeLattice lattice,
180-
const std::set<Access>& reducedAccesses);
187+
const std::set<Access>& reducedAccesses,
188+
MergeStrategy mergeStrategy);
181189

182190
/// Lower a forall loop body.
183191
virtual ir::Stmt lowerForallBody(ir::Expr coordinate, IndexStmt stmt,
184192
std::vector<Iterator> locaters,
185193
std::vector<Iterator> inserters,
186194
std::vector<Iterator> appenders,
187195
MergeLattice caseLattice,
188-
const std::set<Access>& reducedAccesses);
196+
const std::set<Access>& reducedAccesses,
197+
MergeStrategy mergeStrategy);
189198

190199

191200
/// Lower a where statement.
@@ -375,7 +384,7 @@ class LowererImplImperative : public LowererImpl {
375384

376385
/// Conditionally increment iterator position variables.
377386
ir::Stmt codeToIncIteratorVars(ir::Expr coordinate, IndexVar coordinateVar,
378-
std::vector<Iterator> iterators, std::vector<Iterator> mergers);
387+
std::vector<Iterator> iterators, std::vector<Iterator> mergers, MergeStrategy strategy);
379388

380389
ir::Stmt codeToLoadCoordinatesFromPosIterators(std::vector<Iterator> iterators, bool declVars);
381390

@@ -410,7 +419,8 @@ class LowererImplImperative : public LowererImpl {
410419
/// Lowers a merge lattice to cases assuming there are no more loops to be emitted in stmt.
411420
/// Will emit checks for explicit zeros for each mode iterator and each locator in the lattice.
412421
ir::Stmt lowerMergeCasesWithExplicitZeroChecks(ir::Expr coordinate, IndexVar coordinateVar, IndexStmt stmt,
413-
MergeLattice lattice, const std::set<Access>& reducedAccesses);
422+
MergeLattice lattice, const std::set<Access>& reducedAccesses,
423+
MergeStrategy mergeStrategy);
414424

415425
/// Constructs cases comparing the coordVar for each iterator to the resolved coordinate.
416426
/// Returns a vector where coordComparisons[i] corresponds to a case for iters[i]
@@ -444,7 +454,7 @@ class LowererImplImperative : public LowererImpl {
444454
/// The map must be of iterators to exprs of boolean types
445455
std::vector<ir::Stmt> lowerCasesFromMap(std::map<Iterator, ir::Expr> iteratorToCondition,
446456
ir::Expr coordinate, IndexStmt stmt, const MergeLattice& lattice,
447-
const std::set<Access>& reducedAccesses);
457+
const std::set<Access>& reducedAccesses, MergeStrategy mergeStrategy);
448458

449459
/// Constructs an expression which checks if this access is "zero"
450460
ir::Expr constructCheckForAccessZero(Access);

src/codegen/codegen_c.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,28 @@ const string cHeaders =
6262
"int cmp(const void *a, const void *b) {\n"
6363
" return *((const int*)a) - *((const int*)b);\n"
6464
"}\n"
65+
// Advance from arrayStart to the first position p with array[p] >= target, or to arrayEnd if there
66+
// is none, using exponential search: https://en.wikipedia.org/wiki/Exponential_search. An empty range returns arrayStart without reading array.
67+
"int taco_gallop(int *array, int arrayStart, int arrayEnd, int target) {\n"
68+
" if (arrayStart >= arrayEnd || array[arrayStart] >= target) {\n"
69+
" return arrayStart;\n"
70+
" }\n"
71+
" int step = 1;\n"
72+
" int curr = arrayStart;\n"
73+
" while (curr + step < arrayEnd && array[curr + step] < target) {\n"
74+
" curr += step;\n"
75+
" step = step * 2;\n"
76+
" }\n"
77+
"\n"
78+
" step = step / 2;\n"
79+
" while (step > 0) {\n"
80+
" if (curr + step < arrayEnd && array[curr + step] < target) {\n"
81+
" curr += step;\n"
82+
" }\n"
83+
" step = step / 2;\n"
84+
" }\n"
85+
" return curr+1;\n"
86+
"}\n"
6587
"int taco_binarySearchAfter(int *array, int arrayStart, int arrayEnd, int target) {\n"
6688
" if (array[arrayStart] >= target) {\n"
6789
" return arrayStart;\n"

src/index_notation/index_notation.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1907,6 +1907,15 @@ IndexStmt IndexStmt::reorder(std::vector<IndexVar> reorderedvars) const {
19071907
return transformed;
19081908
}
19091909

1910+
IndexStmt IndexStmt::mergeby(IndexVar i, MergeStrategy strategy) const {
1911+
string reason;
1912+
IndexStmt transformed = SetMergeStrategy(i, strategy).apply(*this, &reason);
1913+
if (!transformed.defined()) {
1914+
taco_uerror << reason;
1915+
}
1916+
return transformed;
1917+
}
1918+
19101919
IndexStmt IndexStmt::parallelize(IndexVar i, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy) const {
19111920
string reason;
19121921
IndexStmt transformed = Parallelize(i, parallel_unit, output_race_strategy).apply(*this, &reason);
@@ -2017,7 +2026,7 @@ IndexStmt IndexStmt::unroll(IndexVar i, size_t unrollFactor) const {
20172026

20182027
void visit(const ForallNode* node) {
20192028
if (node->indexVar == i) {
2020-
stmt = Forall(i, rewrite(node->stmt), node->parallel_unit, node->output_race_strategy, unrollFactor);
2029+
stmt = Forall(i, rewrite(node->stmt), node->merge_strategy, node->parallel_unit, node->output_race_strategy, unrollFactor);
20212030
}
20222031
else {
20232032
IndexNotationRewriter::visit(node);
@@ -2125,11 +2134,11 @@ Forall::Forall(const ForallNode* n) : IndexStmt(n) {
21252134
}
21262135

21272136
Forall::Forall(IndexVar indexVar, IndexStmt stmt)
2128-
: Forall(indexVar, stmt, ParallelUnit::NotParallel, OutputRaceStrategy::IgnoreRaces) {
2137+
: Forall(indexVar, stmt, MergeStrategy::TwoFinger, ParallelUnit::NotParallel, OutputRaceStrategy::IgnoreRaces) {
21292138
}
21302139

2131-
Forall::Forall(IndexVar indexVar, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor)
2132-
: Forall(new ForallNode(indexVar, stmt, parallel_unit, output_race_strategy, unrollFactor)) {
2140+
Forall::Forall(IndexVar indexVar, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor)
2141+
: Forall(new ForallNode(indexVar, stmt, merge_strategy, parallel_unit, output_race_strategy, unrollFactor)) {
21332142
}
21342143

21352144
IndexVar Forall::getIndexVar() const {
@@ -2148,6 +2157,10 @@ OutputRaceStrategy Forall::getOutputRaceStrategy() const {
21482157
return getNode(*this)->output_race_strategy;
21492158
}
21502159

2160+
MergeStrategy Forall::getMergeStrategy() const {
2161+
return getNode(*this)->merge_strategy;
2162+
}
2163+
21512164
size_t Forall::getUnrollFactor() const {
21522165
return getNode(*this)->unrollFactor;
21532166
}
@@ -2156,8 +2169,8 @@ Forall forall(IndexVar i, IndexStmt stmt) {
21562169
return Forall(i, stmt);
21572170
}
21582171

2159-
Forall forall(IndexVar i, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor) {
2160-
return Forall(i, stmt, parallel_unit, output_race_strategy, unrollFactor);
2172+
Forall forall(IndexVar i, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor) {
2173+
return Forall(i, stmt, merge_strategy, parallel_unit, output_race_strategy, unrollFactor);
21612174
}
21622175

21632176
template <> bool isa<Forall>(IndexStmt s) {
@@ -3938,7 +3951,7 @@ struct Zero : public IndexNotationRewriterStrict {
39383951
stmt = op;
39393952
}
39403953
else {
3941-
stmt = new ForallNode(op->indexVar, body, op->parallel_unit, op->output_race_strategy, op->unrollFactor);
3954+
stmt = new ForallNode(op->indexVar, body, op->merge_strategy, op->parallel_unit, op->output_race_strategy, op->unrollFactor);
39423955
}
39433956
}
39443957

src/index_notation/index_notation_rewriter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ void IndexNotationRewriter::visit(const ForallNode* op) {
185185
stmt = op;
186186
}
187187
else {
188-
stmt = new ForallNode(op->indexVar, s, op->parallel_unit, op->output_race_strategy, op->unrollFactor);
188+
stmt = new ForallNode(op->indexVar, s, op->merge_strategy, op->parallel_unit, op->output_race_strategy, op->unrollFactor);
189189
}
190190
}
191191

@@ -406,7 +406,7 @@ struct ReplaceIndexVars : public IndexNotationRewriter {
406406
stmt = op;
407407
}
408408
else {
409-
stmt = new ForallNode(iv, s, op->parallel_unit, op->output_race_strategy,
409+
stmt = new ForallNode(iv, s, op->merge_strategy, op->parallel_unit, op->output_race_strategy,
410410
op->unrollFactor);
411411
}
412412
}

0 commit comments

Comments
 (0)