Skip to content

Commit 1505153

Browse files
committed
Merge branch 'new-test-target' into test-target
g especially if it merges an updated upstream into a topic branch. merge
2 parents 8af953b + 58567e2 commit 1505153

29 files changed

+2259
-1885
lines changed

include/taco/format.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ class ModeFormat {
9797
/// Properties of a mode format
9898
enum Property {
9999
FULL, NOT_FULL, ORDERED, NOT_ORDERED, UNIQUE, NOT_UNIQUE, BRANCHLESS,
100-
NOT_BRANCHLESS, COMPACT, NOT_COMPACT, ZEROLESS, NOT_ZEROLESS
100+
NOT_BRANCHLESS, COMPACT, NOT_COMPACT, ZEROLESS, NOT_ZEROLESS, PADDED,
101+
NOT_PADDED
101102
};
102103

103104
/// Instantiates an undefined mode format
@@ -129,6 +130,7 @@ class ModeFormat {
129130
bool isBranchless() const;
130131
bool isCompact() const;
131132
bool isZeroless() const;
133+
bool isPadded() const;
132134

133135
/// Returns true if a mode format has a specific capability, false otherwise
134136
bool hasCoordValIter() const;

include/taco/index_notation/index_notation.h

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,23 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
637637
/// reorder takes a new ordering for a set of index variables that are directly nested in the iteration order
638638
IndexStmt reorder(std::vector<IndexVar> reorderedvars) const;
639639

640+
/// The mergeby transformation specifies how to merge iterators on
641+
/// the given index variable. By default, if an iterator is used for windowing
642+
/// it will be merged with the "gallop" strategy.
643+
/// All other iterators are merged with the "two finger" strategy.
644+
/// The two finger strategy merges by advancing each iterator one at a time,
645+
/// while the gallop strategy implements the exponential search algorithm.
646+
///
647+
/// Preconditions:
648+
/// This command applies to variables involving sparse iterators only;
649+
/// it is a no-op if the variable invovles any dense iterators.
650+
/// Any variable can be merged with the two finger strategy, whereas gallop
651+
/// only applies to a variable if its merge lattice has a single point
652+
/// (i.e. an intersection). For example, if a variable involves multiplications
653+
/// only, it can be merged with gallop.
654+
/// Furthermore, all iterators must be ordered for gallop to apply.
655+
IndexStmt mergeby(IndexVar i, MergeStrategy strategy) const;
656+
640657
/// The parallelize
641658
/// transformation tags an index variable for parallel execution. The
642659
/// transformation takes as an argument the type of parallel hardware
@@ -835,13 +852,14 @@ class Forall : public IndexStmt {
835852
Forall() = default;
836853
Forall(const ForallNode*);
837854
Forall(IndexVar indexVar, IndexStmt stmt);
838-
Forall(IndexVar indexVar, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0);
855+
Forall(IndexVar indexVar, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0);
839856

840857
IndexVar getIndexVar() const;
841858
IndexStmt getStmt() const;
842859

843860
ParallelUnit getParallelUnit() const;
844861
OutputRaceStrategy getOutputRaceStrategy() const;
862+
MergeStrategy getMergeStrategy() const;
845863

846864
size_t getUnrollFactor() const;
847865

@@ -850,7 +868,7 @@ class Forall : public IndexStmt {
850868

851869
/// Create a forall index statement.
852870
Forall forall(IndexVar i, IndexStmt stmt);
853-
Forall forall(IndexVar i, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0);
871+
Forall forall(IndexVar i, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0);
854872

855873

856874
/// A where statment has a producer statement that binds a tensor variable in

include/taco/index_notation/index_notation_nodes.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,15 +398,16 @@ struct YieldNode : public IndexStmtNode {
398398
};
399399

400400
struct ForallNode : public IndexStmtNode {
401-
ForallNode(IndexVar indexVar, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0)
402-
: indexVar(indexVar), stmt(stmt), parallel_unit(parallel_unit), output_race_strategy(output_race_strategy), unrollFactor(unrollFactor) {}
401+
ForallNode(IndexVar indexVar, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0)
402+
: indexVar(indexVar), stmt(stmt), merge_strategy(merge_strategy), parallel_unit(parallel_unit), output_race_strategy(output_race_strategy), unrollFactor(unrollFactor) {}
403403

404404
void accept(IndexStmtVisitorStrict* v) const {
405405
v->visit(this);
406406
}
407407

408408
IndexVar indexVar;
409409
IndexStmt stmt;
410+
MergeStrategy merge_strategy;
410411
ParallelUnit parallel_unit;
411412
OutputRaceStrategy output_race_strategy;
412413
size_t unrollFactor = 0;

include/taco/index_notation/transformations.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class AddSuchThatPredicates;
2222
class Parallelize;
2323
class TopoReorder;
2424
class SetAssembleStrategy;
25+
class SetMergeStrategy;
2526

2627
/// A transformation is an optimization that transforms a statement in the
2728
/// concrete index notation into a new statement that computes the same result
@@ -36,6 +37,7 @@ class Transformation {
3637
Transformation(TopoReorder);
3738
Transformation(AddSuchThatPredicates);
3839
Transformation(SetAssembleStrategy);
40+
Transformation(SetMergeStrategy);
3941

4042
IndexStmt apply(IndexStmt stmt, std::string *reason = nullptr) const;
4143

@@ -206,6 +208,25 @@ class SetAssembleStrategy : public TransformationInterface {
206208
/// Print a SetAssembleStrategy command.
207209
std::ostream &operator<<(std::ostream &, const SetAssembleStrategy&);
208210

211+
class SetMergeStrategy : public TransformationInterface {
212+
public:
213+
SetMergeStrategy(IndexVar i, MergeStrategy strategy);
214+
215+
IndexVar geti() const;
216+
MergeStrategy getMergeStrategy() const;
217+
218+
IndexStmt apply(IndexStmt stmt, std::string *reason = nullptr) const;
219+
220+
void print(std::ostream &os) const;
221+
222+
private:
223+
struct Content;
224+
std::shared_ptr<Content> content;
225+
};
226+
227+
/// Print a SetMergeStrategy command.
228+
std::ostream &operator<<(std::ostream &, const SetMergeStrategy&);
229+
209230
// Autoscheduling functions
210231

211232
/**

include/taco/ir_tags.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ enum class AssembleStrategy {
3333
};
3434
extern const char *AssembleStrategy_NAMES[];
3535

36+
/// MergeStrategy::TwoFinger merges iterators by incrementing one at a time
37+
/// MergeStrategy::Galloping merges iterators by exponential search (galloping)
38+
enum class MergeStrategy {
39+
TwoFinger, Gallop
40+
};
41+
extern const char *MergeStrategy_NAMES[];
42+
3643
}
3744

3845
#endif //TACO_IR_TAGS_H

include/taco/lower/lowerer_impl_imperative.h

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -146,15 +146,18 @@ class LowererImplImperative : public LowererImpl {
146146
* \param statement
147147
* A concrete index notation statement to compute at the points in the
148148
* sparse iteration space described by the merge lattice.
149+
* \param mergeStrategy
150+
* A strategy for merging iterators. One of TwoFinger or Gallop.
149151
*
150152
* \return
151153
* IR code to compute the forall loop.
152154
*/
153155
virtual ir::Stmt lowerMergeLattice(MergeLattice lattice, IndexVar coordinateVar,
154156
IndexStmt statement,
155-
const std::set<Access>& reducedAccesses);
157+
const std::set<Access>& reducedAccesses,
158+
MergeStrategy mergeStrategy);
156159

157-
virtual ir::Stmt resolveCoordinate(std::vector<Iterator> mergers, ir::Expr coordinate, bool emitVarDecl);
160+
virtual ir::Stmt resolveCoordinate(std::vector<Iterator> mergers, ir::Expr coordinate, bool emitVarDecl, bool mergeWithMax);
158161

159162
/**
160163
* Lower the merge point at the top of the given lattice to code that iterates
@@ -169,23 +172,29 @@ class LowererImplImperative : public LowererImpl {
169172
* coordinate the merge point is at.
170173
* A concrete index notation statement to compute at the points in the
171174
* sparse iteration space region described by the merge point.
175+
* \param mergeWithMax
176+
* A boolean indicating whether coordinates should be combined with MAX instead of MIN.
177+
* MAX is needed when the iterators are merged with the Gallop strategy.
172178
*/
173179
virtual ir::Stmt lowerMergePoint(MergeLattice pointLattice,
174180
ir::Expr coordinate, IndexVar coordinateVar, IndexStmt statement,
175-
const std::set<Access>& reducedAccesses, bool resolvedCoordDeclared);
181+
const std::set<Access>& reducedAccesses, bool resolvedCoordDeclared,
182+
MergeStrategy mergestrategy);
176183

177184
/// Lower a merge lattice to cases.
178185
virtual ir::Stmt lowerMergeCases(ir::Expr coordinate, IndexVar coordinateVar, IndexStmt stmt,
179186
MergeLattice lattice,
180-
const std::set<Access>& reducedAccesses);
187+
const std::set<Access>& reducedAccesses,
188+
MergeStrategy mergeStrategy);
181189

182190
/// Lower a forall loop body.
183191
virtual ir::Stmt lowerForallBody(ir::Expr coordinate, IndexStmt stmt,
184192
std::vector<Iterator> locaters,
185193
std::vector<Iterator> inserters,
186194
std::vector<Iterator> appenders,
187195
MergeLattice caseLattice,
188-
const std::set<Access>& reducedAccesses);
196+
const std::set<Access>& reducedAccesses,
197+
MergeStrategy mergeStrategy);
189198

190199

191200
/// Lower a where statement.
@@ -375,7 +384,7 @@ class LowererImplImperative : public LowererImpl {
375384

376385
/// Conditionally increment iterator position variables.
377386
ir::Stmt codeToIncIteratorVars(ir::Expr coordinate, IndexVar coordinateVar,
378-
std::vector<Iterator> iterators, std::vector<Iterator> mergers);
387+
std::vector<Iterator> iterators, std::vector<Iterator> mergers, MergeStrategy strategy);
379388

380389
ir::Stmt codeToLoadCoordinatesFromPosIterators(std::vector<Iterator> iterators, bool declVars);
381390

@@ -410,7 +419,8 @@ class LowererImplImperative : public LowererImpl {
410419
/// Lowers a merge lattice to cases assuming there are no more loops to be emitted in stmt.
411420
/// Will emit checks for explicit zeros for each mode iterator and each locator in the lattice.
412421
ir::Stmt lowerMergeCasesWithExplicitZeroChecks(ir::Expr coordinate, IndexVar coordinateVar, IndexStmt stmt,
413-
MergeLattice lattice, const std::set<Access>& reducedAccesses);
422+
MergeLattice lattice, const std::set<Access>& reducedAccesses,
423+
MergeStrategy mergeStrategy);
414424

415425
/// Constructs cases comparing the coordVar for each iterator to the resolved coordinate.
416426
/// Returns a vector where coordComparisons[i] corresponds to a case for iters[i]
@@ -444,7 +454,7 @@ class LowererImplImperative : public LowererImpl {
444454
/// The map must be of iterators to exprs of boolean types
445455
std::vector<ir::Stmt> lowerCasesFromMap(std::map<Iterator, ir::Expr> iteratorToCondition,
446456
ir::Expr coordinate, IndexStmt stmt, const MergeLattice& lattice,
447-
const std::set<Access>& reducedAccesses);
457+
const std::set<Access>& reducedAccesses, MergeStrategy mergeStrategy);
448458

449459
/// Constructs an expression which checks if this access is "zero"
450460
ir::Expr constructCheckForAccessZero(Access);

include/taco/lower/mode_format_impl.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,10 @@ class ModeFormatImpl {
106106
public:
107107
ModeFormatImpl(std::string name, bool isFull, bool isOrdered, bool isUnique,
108108
bool isBranchless, bool isCompact, bool isZeroless,
109-
bool hasCoordValIter, bool hasCoordPosIter, bool hasLocate,
110-
bool hasInsert, bool hasAppend, bool hasSeqInsertEdge,
111-
bool hasInsertCoord, bool isYieldPosPure);
109+
bool isPadded, bool hasCoordValIter, bool hasCoordPosIter,
110+
bool hasLocate, bool hasInsert, bool hasAppend,
111+
bool hasSeqInsertEdge, bool hasInsertCoord,
112+
bool isYieldPosPure);
112113

113114
virtual ~ModeFormatImpl();
114115

@@ -246,6 +247,7 @@ class ModeFormatImpl {
246247
const bool isBranchless;
247248
const bool isCompact;
248249
const bool isZeroless;
250+
const bool isPadded;
249251

250252
const bool hasCoordValIter;
251253
const bool hasCoordPosIter;

include/taco/lower/mode_format_singleton.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@ class SingletonModeFormat : public ModeFormatImpl {
1010
using ModeFormatImpl::getInsertCoord;
1111

1212
SingletonModeFormat();
13-
SingletonModeFormat(bool isFull, bool isOrdered,
14-
bool isUnique, bool isZeroless, long long allocSize = DEFAULT_ALLOC_SIZE);
13+
SingletonModeFormat(bool isFull, bool isOrdered, bool isUnique,
14+
bool isZeroless, bool isPadded,
15+
long long allocSize = DEFAULT_ALLOC_SIZE);
1516

1617
~SingletonModeFormat() override {}
1718

src/codegen/codegen_c.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,28 @@ const string cHeaders =
6262
"int cmp(const void *a, const void *b) {\n"
6363
" return *((const int*)a) - *((const int*)b);\n"
6464
"}\n"
65+
// Increment arrayStart until array[arrayStart] >= target or arrayStart >= arrayEnd
66+
// using an exponential search algorithm: https://en.wikipedia.org/wiki/Exponential_search.
67+
"int taco_gallop(int *array, int arrayStart, int arrayEnd, int target) {\n"
68+
" if (array[arrayStart] >= target || arrayStart >= arrayEnd) {\n"
69+
" return arrayStart;\n"
70+
" }\n"
71+
" int step = 1;\n"
72+
" int curr = arrayStart;\n"
73+
" while (curr + step < arrayEnd && array[curr + step] < target) {\n"
74+
" curr += step;\n"
75+
" step = step * 2;\n"
76+
" }\n"
77+
"\n"
78+
" step = step / 2;\n"
79+
" while (step > 0) {\n"
80+
" if (curr + step < arrayEnd && array[curr + step] < target) {\n"
81+
" curr += step;\n"
82+
" }\n"
83+
" step = step / 2;\n"
84+
" }\n"
85+
" return curr+1;\n"
86+
"}\n"
6587
"int taco_binarySearchAfter(int *array, int arrayStart, int arrayEnd, int target) {\n"
6688
" if (array[arrayStart] >= target) {\n"
6789
" return arrayStart;\n"

src/format.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,11 @@ bool ModeFormat::hasProperties(const std::vector<Property>& properties) const {
187187
return false;
188188
}
189189
break;
190+
case PADDED:
191+
if (!isPadded()) {
192+
return false;
193+
}
194+
break;
190195
case NOT_FULL:
191196
if (isFull()) {
192197
return false;
@@ -217,6 +222,11 @@ bool ModeFormat::hasProperties(const std::vector<Property>& properties) const {
217222
return false;
218223
}
219224
break;
225+
case NOT_PADDED:
226+
if (isPadded()) {
227+
return false;
228+
}
229+
break;
220230
}
221231
}
222232
return true;
@@ -252,6 +262,11 @@ bool ModeFormat::isZeroless() const {
252262
return impl->isZeroless;
253263
}
254264

265+
bool ModeFormat::isPadded() const {
266+
taco_iassert(defined());
267+
return impl->isPadded;
268+
}
269+
255270
bool ModeFormat::hasCoordValIter() const {
256271
taco_iassert(defined());
257272
return impl->hasCoordValIter;

0 commit comments

Comments
 (0)