Skip to content

Commit cb4731d

Browse files
Merge pull request #339 from RawnH/workspace_reuse
Workspace reuse
2 parents f051a8f + 8471869 commit cb4731d

File tree

18 files changed

+698
-51
lines changed

18 files changed

+698
-51
lines changed

include/taco/index_notation/index_notation.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -953,6 +953,10 @@ std::vector<TensorVar> getArguments(IndexStmt stmt);
953953
/// Returns the temporaries in the index statement, in the order they appear.
954954
std::vector<TensorVar> getTemporaries(IndexStmt stmt);
955955

956+
// [Olivia]
957+
/// Returns the temporaries in the index statement, in the order they appear.
958+
std::map<Forall, Where> getTemporaryLocations(IndexStmt stmt);
959+
956960
/// Returns the tensors in the index statement.
957961
std::vector<TensorVar> getTensorVars(IndexStmt stmt);
958962

include/taco/ir/ir.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ enum class IRNodeType {
6565
BlankLine,
6666
Print,
6767
GetProperty,
68-
Break
68+
Break,
69+
Sort
6970
};
7071

7172
enum class TensorProperty {
@@ -725,6 +726,13 @@ struct Break : public StmtNode<Break> {
725726
static const IRNodeType _type_info = IRNodeType::Break;
726727
};
727728

729+
struct Sort : public StmtNode<Sort> {
730+
std::vector<Expr> args;
731+
static Stmt make(std::vector<Expr> args);
732+
733+
static const IRNodeType _type_info = IRNodeType::Sort;
734+
};
735+
728736
/** A print statement.
729737
* Takes in a printf-style format string and Exprs to pass
730738
* for the values.

include/taco/ir/ir_printer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ class IRPrinter : public IRVisitorStrict {
6868
virtual void visit(const Break*);
6969
virtual void visit(const Print*);
7070
virtual void visit(const GetProperty*);
71+
virtual void visit(const Sort*);
7172

7273
std::ostream &stream;
7374
int indent;

include/taco/ir/ir_rewriter.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ class IRRewriter : public IRVisitorStrict {
6868
virtual void visit(const Break* op);
6969
virtual void visit(const Print* op);
7070
virtual void visit(const GetProperty* op);
71+
virtual void visit(const Sort *op);
7172
};
7273

7374
}}

include/taco/ir/ir_visitor.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ struct BlankLine;
4848
struct Break;
4949
struct Print;
5050
struct GetProperty;
51+
struct Sort;
5152

5253
/// Extend this class to visit every node in the IR.
5354
class IRVisitorStrict {
@@ -98,6 +99,7 @@ class IRVisitorStrict {
9899
virtual void visit(const Break*) = 0;
99100
virtual void visit(const Print*) = 0;
100101
virtual void visit(const GetProperty*) = 0;
102+
virtual void visit(const Sort*) = 0;
101103
};
102104

103105

@@ -151,6 +153,7 @@ class IRVisitor : public IRVisitorStrict {
151153
virtual void visit(const Break* op);
152154
virtual void visit(const Print* op);
153155
virtual void visit(const GetProperty* op);
156+
virtual void visit(const Sort* op);
154157
};
155158

156159
}}

include/taco/lower/lowerer_impl.h

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,16 @@ class LowererImpl : public util::Uncopyable {
8181
std::set<Access> reducedAccesses,
8282
ir::Stmt recoveryStmt);
8383

84+
/// Lower a forall that iterates over all the coordinates in the forall index
85+
/// var's dimension, and locates tensor positions from the locate iterators.
86+
virtual ir::Stmt lowerForallDenseAcceleration(Forall forall,
87+
std::vector<Iterator> locaters,
88+
std::vector<Iterator> inserters,
89+
std::vector<Iterator> appenders,
90+
std::set<Access> reducedAccesses,
91+
ir::Stmt recoveryStmt);
92+
93+
8494
/// Lower a forall that iterates over the coordinates in the iterator, and
8595
/// locates tensor positions from the locate iterators.
8696
virtual ir::Stmt lowerForallCoordinate(Forall forall, Iterator iterator,
@@ -333,17 +343,29 @@ class LowererImpl : public util::Uncopyable {
333343
ir::Stmt codeToInitializeIteratorVars(std::vector<Iterator> iterators, std::vector<Iterator> rangers, std::vector<Iterator> mergers, ir::Expr coord, IndexVar coordinateVar);
334344
ir::Stmt codeToInitializeIteratorVar(Iterator iterator, std::vector<Iterator> iterators, std::vector<Iterator> rangers, std::vector<Iterator> mergers, ir::Expr coordinate, IndexVar coordinateVar);
335345

346+
/// Returns true iff the temporary used in the where statement is dense and sparse iteration over that
347+
/// temporary can be automaticallty supported by the compiler.
348+
bool canAccelerateDenseTemp(Where where);
349+
350+
/// Initializes a temporary workspace
351+
std::vector<ir::Stmt> codeToInitializeTemporary(Where where);
352+
353+
/// Gets the size of a temporary tensorVar in the where statement
354+
ir::Expr getTemporarySize(Where where);
355+
356+
/// Initializes helper arrays to give dense workspaces sparse acceleration
357+
std::vector<ir::Stmt> codeToInitializeDenseAcceleratorArrays(Where where);
336358

337359
/// Recovers a derived indexvar from an underived variable.
338360
ir::Stmt codeToRecoverDerivedIndexVar(IndexVar underived, IndexVar indexVar, bool emitVarDecl);
339361

340-
/// Conditionally increment iterator position variables.
362+
/// Conditionally increment iterator position variables.
341363
ir::Stmt codeToIncIteratorVars(ir::Expr coordinate, IndexVar coordinateVar,
342364
std::vector<Iterator> iterators, std::vector<Iterator> mergers);
343365

344366
ir::Stmt codeToLoadCoordinatesFromPosIterators(std::vector<Iterator> iterators, bool declVars);
345367

346-
/// Create statements to append coordinate to result modes.
368+
/// Create statements to append coordinate to result modes.
347369
ir::Stmt appendCoordinate(std::vector<Iterator> appenders, ir::Expr coord);
348370

349371
/// Create statements to append positions to result modes.
@@ -363,6 +385,9 @@ class LowererImpl : public util::Uncopyable {
363385
int markAssignsAtomicDepth = 0;
364386
ParallelUnit atomicParallelUnit;
365387

388+
/// Map used to hoist temporary workspace initialization
389+
std::map<Forall, Where> temporaryInitialization;
390+
366391
/// Map from tensor variables in index notation to variables in the IR
367392
std::map<TensorVar, ir::Expr> tensorVars;
368393

@@ -371,6 +396,15 @@ class LowererImpl : public util::Uncopyable {
371396
};
372397
std::map<TensorVar, TemporaryArrays> temporaryArrays;
373398

399+
/// Map form temporary to indexList var if accelerating dense workspace
400+
std::map<TensorVar, ir::Expr> tempToIndexList;
401+
402+
/// Map form temporary to indexListSize if accelerating dense workspace
403+
std::map<TensorVar, ir::Expr> tempToIndexListSize;
404+
405+
/// Map form temporary to bitGuard var if accelerating dense workspace
406+
std::map<TensorVar, ir::Expr> tempToBitGuard;
407+
374408
/// Map from result tensors to variables tracking values array capacity.
375409
std::map<ir::Expr, ir::Expr> capacityVars;
376410

src/codegen/codegen_c.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,14 +182,15 @@ class CodeGen_C::FindVars : public IRVisitor {
182182

183183
virtual void visit(const Var *op) {
184184
if (varMap.count(op) == 0) {
185-
varMap[op] = codeGen->genUniqueName(op->name);
185+
varMap[op] = op->is_ptr? op->name : codeGen->genUniqueName(op->name);
186186
}
187187
}
188188

189189
virtual void visit(const VarDecl *op) {
190190
if (!util::contains(localVars, op->var)) {
191191
localVars.push_back(op->var);
192192
}
193+
op->var.accept(this);
193194
op->rhs.accept(this);
194195
}
195196

src/codegen/codegen_cuda.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ class CodeGen_CUDA::FindVars : public IRVisitor {
240240

241241
virtual void visit(const Var *op) {
242242
if (varMap.count(op) == 0 && !inBlock) {
243-
varMap[op] = codeGen->genUniqueName(op->name);
243+
varMap[op] = op->is_ptr? op->name : codeGen->genUniqueName(op->name);
244244
}
245245
}
246246

src/index_notation/index_notation.cpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2118,8 +2118,23 @@ bool isConcreteNotation(IndexStmt stmt, std::string* reason) {
21182118
return;
21192119
}
21202120

2121+
// Handles derived vars on RHS with underived vars on LHS.
2122+
Assignment assignPtrWrapper = Assignment(op);
2123+
std::vector<IndexVar> possibleReductionVars = assignPtrWrapper.getReductionVars();
2124+
std::vector<IndexVar> freeVars = assignPtrWrapper.getFreeVars();
2125+
std::set<IndexVar> freeVarsSet(freeVars.begin(), freeVars.end());
2126+
2127+
int numReductionVars = 0;
2128+
for(const auto& reductionVar : possibleReductionVars) {
2129+
std::vector<IndexVar> underivedParents = provGraph.getUnderivedAncestors(reductionVar);
2130+
for(const auto& parent : underivedParents) {
2131+
if(!util::contains(freeVarsSet, parent)) {
2132+
++numReductionVars;
2133+
}
2134+
}
2135+
}
21212136
// allow introducing precompute loops where we set a temporary to values instead of +=
2122-
if (Assignment(op).getReductionVars().size() > 0 &&
2137+
if (numReductionVars > 0 &&
21232138
op->op == IndexExpr() && !inWhereProducer) {
21242139
*reason = "reduction variables in concrete notation must be dominated "
21252140
"by compound assignments (such as +=)";
@@ -2342,6 +2357,22 @@ vector<TensorVar> getArguments(IndexStmt stmt) {
23422357
return result;
23432358
}
23442359

2360+
std::map<Forall, Where> getTemporaryLocations(IndexStmt stmt) {
2361+
map<Forall, Where> temporaryLocs;
2362+
Forall f = Forall();
2363+
match(stmt,
2364+
function<void(const ForallNode*, Matcher*)>([&](const ForallNode* op, Matcher* ctx) {
2365+
f = op;
2366+
ctx->match(op->stmt);
2367+
}),
2368+
function<void(const WhereNode*, Matcher*)>([&](const WhereNode* w, Matcher* ctx) {
2369+
if (!(f == IndexStmt()))
2370+
temporaryLocs.insert({f, Where(w)});
2371+
})
2372+
);
2373+
return temporaryLocs;
2374+
}
2375+
23452376
std::vector<TensorVar> getTemporaries(IndexStmt stmt) {
23462377
vector<TensorVar> temporaries;
23472378
bool firstAssignment = true;

src/index_notation/index_notation_printer.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,12 @@ void IndexNotationPrinter::visit(const NegNode* op) {
8181
Precedence precedence = Precedence::NEG;
8282
bool parenthesize = precedence > parentPrecedence;
8383
parentPrecedence = precedence;
84-
os << "-";
84+
if(op->getDataType().isBool()) {
85+
os << "!";
86+
} else {
87+
os << "-";
88+
}
89+
8590
if (parenthesize) {
8691
os << "(";
8792
}

0 commit comments

Comments
 (0)