@@ -81,6 +81,16 @@ class LowererImpl : public util::Uncopyable {
8181 std::set<Access> reducedAccesses,
8282 ir::Stmt recoveryStmt);
8383
84+ // / Lower a forall that iterates over all the coordinates in the forall index
85+ // / var's dimension, and locates tensor positions from the locate iterators.
86+ virtual ir::Stmt lowerForallDenseAcceleration (Forall forall,
87+ std::vector<Iterator> locaters,
88+ std::vector<Iterator> inserters,
89+ std::vector<Iterator> appenders,
90+ std::set<Access> reducedAccesses,
91+ ir::Stmt recoveryStmt);
92+
93+
8494 // / Lower a forall that iterates over the coordinates in the iterator, and
8595 // / locates tensor positions from the locate iterators.
8696 virtual ir::Stmt lowerForallCoordinate (Forall forall, Iterator iterator,
@@ -333,17 +343,29 @@ class LowererImpl : public util::Uncopyable {
333343 ir::Stmt codeToInitializeIteratorVars (std::vector<Iterator> iterators, std::vector<Iterator> rangers, std::vector<Iterator> mergers, ir::Expr coord, IndexVar coordinateVar);
334344 ir::Stmt codeToInitializeIteratorVar (Iterator iterator, std::vector<Iterator> iterators, std::vector<Iterator> rangers, std::vector<Iterator> mergers, ir::Expr coordinate, IndexVar coordinateVar);
335345
346+ // / Returns true iff the temporary used in the where statement is dense and sparse iteration over that
347+ // / temporary can be automaticallty supported by the compiler.
348+ bool canAccelerateDenseTemp (Where where);
349+
350+ // / Initializes a temporary workspace
351+ std::vector<ir::Stmt> codeToInitializeTemporary (Where where);
352+
353+ // / Gets the size of a temporary tensorVar in the where statement
354+ ir::Expr getTemporarySize (Where where);
355+
356+ // / Initializes helper arrays to give dense workspaces sparse acceleration
357+ std::vector<ir::Stmt> codeToInitializeDenseAcceleratorArrays (Where where);
336358
337359 // / Recovers a derived indexvar from an underived variable.
338360 ir::Stmt codeToRecoverDerivedIndexVar (IndexVar underived, IndexVar indexVar, bool emitVarDecl);
339361
340- // / Conditionally increment iterator position variables.
362+ // / Conditionally increment iterator position variables.
341363 ir::Stmt codeToIncIteratorVars (ir::Expr coordinate, IndexVar coordinateVar,
342364 std::vector<Iterator> iterators, std::vector<Iterator> mergers);
343365
344366 ir::Stmt codeToLoadCoordinatesFromPosIterators (std::vector<Iterator> iterators, bool declVars);
345367
346- // / Create statements to append coordinate to result modes.
368+ // / Create statements to append coordinate to result modes.
347369 ir::Stmt appendCoordinate (std::vector<Iterator> appenders, ir::Expr coord);
348370
349371 // / Create statements to append positions to result modes.
@@ -363,6 +385,9 @@ class LowererImpl : public util::Uncopyable {
363385 int markAssignsAtomicDepth = 0 ;
364386 ParallelUnit atomicParallelUnit;
365387
388+ // / Map used to hoist temporary workspace initialization
389+ std::map<Forall, Where> temporaryInitialization;
390+
366391 // / Map from tensor variables in index notation to variables in the IR
367392 std::map<TensorVar, ir::Expr> tensorVars;
368393
@@ -371,6 +396,15 @@ class LowererImpl : public util::Uncopyable {
371396 };
372397 std::map<TensorVar, TemporaryArrays> temporaryArrays;
373398
399+ // / Map form temporary to indexList var if accelerating dense workspace
400+ std::map<TensorVar, ir::Expr> tempToIndexList;
401+
402+ // / Map form temporary to indexListSize if accelerating dense workspace
403+ std::map<TensorVar, ir::Expr> tempToIndexListSize;
404+
405+ // / Map form temporary to bitGuard var if accelerating dense workspace
406+ std::map<TensorVar, ir::Expr> tempToBitGuard;
407+
374408 // / Map from result tensors to variables tracking values array capacity.
375409 std::map<ir::Expr, ir::Expr> capacityVars;
376410
0 commit comments