 #include "taco/util/env.h"
 #include "taco/util/collections.h"
 #include "taco/cuda.h"
-#include <taco/index_notation/transformations.h>
+#include "taco/index_notation/transformations.h"
+#include "taco/index_notation/index_notation_visitor.h"
+#include "taco/index_notation/index_notation_nodes.h"
 
 using namespace std;
 using namespace taco;
@@ -112,6 +114,11 @@ static void printUsageInfo() {
112114 " long, longlong, float, double, complexfloat, complexdouble"
113115 " Examples: A:uint16, b:long and D:complexfloat." );
114116 cout << endl;
117+ printFlag (" s=\" <command>(<params>)\" " ,
118+ " Specify a scheduling command to apply to the generated code. "
119+ " Parameters take the form of a comma-delimited list. "
120+ " Examples: split(i,i0,i1,16), precompute(A(i,j)*x(j),i,i)." );
121+ cout << endl;
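+  // A full invocation using this flag might look like the following
+  // (hypothetical expression and format arguments, shown for illustration):
+  //   taco "y(i)=A(i,j)*x(j)" -f=A:ds -s="split(i,i0,i1,32)"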
   printFlag("c",
             "Generate compute kernel that simultaneously does assembly.");
   cout << endl;
@@ -201,6 +208,261 @@ static void printCommandLine(ostream& os, int argc, char* argv[]) {
   }
 }
 
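+// Reads whitespace-delimited scheduling commands from `in` and applies each
+// one to `stmt` in order. Returns true if any command targets GPU hardware,
+// in which case the caller must enable CUDA code generation.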
+static bool setSchedulingCommands(istream& in, parser::Parser& parser, IndexStmt& stmt) {
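+  // Resolve an index-variable name through the statement's provenance graph,
+  // so variables introduced by earlier scheduling commands are also visible.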
+  auto findVar = [&stmt](string name) {
+    ProvenanceGraph graph(stmt);
+    for (auto v : graph.getAllIndexVars()) {
+      if (v.getName() == name) {
+        return v;
+      }
+    }
+
+    // Thrown as a const char*, which the precompute handler below catches.
+    throw "Index variable not defined in statement.";
+  };
+
+  bool isGPU = false;
+
+  while (true) {
+    string command;
+    in >> command;
+
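+    // Dispatch on the command name; each branch reads its own
+    // whitespace-delimited parameters from the stream.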
+    if (command == "pos") {
+      string i, ipos;
+      in >> i;
+      in >> ipos;
+
+      string tensor;
+      in >> tensor;
+
+      for (auto a : getArgumentAccesses(stmt)) {
+        if (a.getTensorVar().getName() == tensor) {
+          IndexVar derived(ipos);
+          stmt = stmt.pos(findVar(i), derived, a);
+          goto end;
+        }
+      }
+
+    } else if (command == "fuse") {
+      string i, j, f;
+      in >> i;
+      in >> j;
+      in >> f;
+
+      IndexVar fused(f);
+      stmt = stmt.fuse(findVar(i), findVar(j), fused);
+
+    } else if (command == "split") {
+      string i, i1, i2;
+      in >> i;
+      in >> i1;
+      in >> i2;
+
+      size_t splitFactor;
+      in >> splitFactor;
+
+      IndexVar split1(i1);
+      IndexVar split2(i2);
+      stmt = stmt.split(findVar(i), split1, split2, splitFactor);
+
+    // } else if (command == "divide") {
+    //   string i, i1, i2;
+    //   in >> i;
+    //   in >> i1;
+    //   in >> i2;
+
+    //   size_t divideFactor;
+    //   in >> divideFactor;
+
+    //   IndexVar divide1(i1);
+    //   IndexVar divide2(i2);
+    //   stmt = stmt.divide(findVar(i), divide1, divide2, divideFactor);
+
+    } else if (command == "precompute") {
+      string exprStr, i, iw;
+      in >> exprStr;
+      in >> i;
+      in >> iw;
+
+      IndexVar orig = findVar(i);
+      IndexVar pre;
+      try {
+        // Reuse the workspace variable if it already exists in the statement;
+        // otherwise findVar throws and a fresh variable is created.
+        pre = findVar(iw);
+      } catch (const char* e) {
+        pre = IndexVar(iw);
+      }
+
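+      // Visitor that locates the sub-expression whose printed form (with all
+      // whitespace removed) matches the string passed on the command line.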
+      struct GetExpr : public IndexNotationVisitor {
+        using IndexNotationVisitor::visit;
+
+        string exprStr;
+        IndexExpr expr;
+
+        void setExprStr(string input) {
+          exprStr = input;
+          exprStr.erase(remove(exprStr.begin(), exprStr.end(), ' '), exprStr.end());
+        }
+
+        string toString(IndexExpr e) {
+          stringstream tempStream;
+          tempStream << e;
+          string tempStr = tempStream.str();
+          tempStr.erase(remove(tempStr.begin(), tempStr.end(), ' '), tempStr.end());
+          return tempStr;
+        }
+
+        void visit(const AccessNode* node) {
+          IndexExpr currentExpr(node);
+          if (toString(currentExpr) == exprStr) {
+            expr = currentExpr;
+          }
+          else {
+            IndexNotationVisitor::visit(node);
+          }
+        }
+
+        void visit(const UnaryExprNode* node) {
+          IndexExpr currentExpr(node);
+          if (toString(currentExpr) == exprStr) {
+            expr = currentExpr;
+          }
+          else {
+            IndexNotationVisitor::visit(node);
+          }
+        }
+
+        void visit(const BinaryExprNode* node) {
+          IndexExpr currentExpr(node);
+          if (toString(currentExpr) == exprStr) {
+            expr = currentExpr;
+          }
+          else {
+            IndexNotationVisitor::visit(node);
+          }
+        }
+      };
+
+      GetExpr visitor;
+      visitor.setExprStr(exprStr);
+      stmt.accept(&visitor);
+
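+      // Size the workspace by the original variable's dimension when the
+      // statement records a domain for it; otherwise fall back to a dimension
+      // derived from the variable itself.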
+      Dimension dim;
+      auto domains = stmt.getIndexVarDomains();
+      auto it = domains.find(orig);
+      if (it != domains.end()) {
+        dim = it->second;
+      } else {
+        dim = Dimension(orig);
+      }
+
+      TensorVar workspace("workspace", Type(Float64, {dim}), Dense);
+      stmt = stmt.precompute(visitor.expr, orig, pre, workspace);
+
+    } else if (command == "reorder") {
+      string line;
+      getline(in, line);
+      stringstream temp;
+      temp << line;
+
+      vector<IndexVar> reorderedVars;
+      string var;
+      while (temp >> var) {
+        reorderedVars.push_back(findVar(var));
+      }
+
+      stmt = stmt.reorder(reorderedVars);
+
+    } else if (command == "bound") {
+      string i, i1;
+      in >> i;
+      in >> i1;
+
+      size_t bound;
+      in >> bound;
+
+      string type;
+      in >> type;
+
+      BoundType bound_type;
+      if (type == "MinExact") {
+        bound_type = BoundType::MinExact;
+      } else if (type == "MinConstraint") {
+        bound_type = BoundType::MinConstraint;
+      } else if (type == "MaxExact") {
+        bound_type = BoundType::MaxExact;
+      } else if (type == "MaxConstraint") {
+        bound_type = BoundType::MaxConstraint;
+      } else {
+        taco_uerror << "Bound type not defined.";
+        goto end;
+      }
+
+      IndexVar bound1(i1);
+      stmt = stmt.bound(findVar(i), bound1, bound, bound_type);
+
+    } else if (command == "unroll") {
+      string i;
+      in >> i;
+
+      size_t unrollFactor;
+      in >> unrollFactor;
+
+      stmt = stmt.unroll(findVar(i), unrollFactor);
+
+    } else if (command == "parallelize") {
+      string i, unit, strategy;
+      in >> i;
+      in >> unit;
+      in >> strategy;
+
+      ParallelUnit parallel_unit;
+      if (unit == "NotParallel") {
+        parallel_unit = ParallelUnit::NotParallel;
+      } else if (unit == "GPUBlock") {
+        parallel_unit = ParallelUnit::GPUBlock;
+        isGPU = true;
+      } else if (unit == "GPUWarp") {
+        parallel_unit = ParallelUnit::GPUWarp;
+        isGPU = true;
+      } else if (unit == "GPUThread") {
+        parallel_unit = ParallelUnit::GPUThread;
+        isGPU = true;
+      } else if (unit == "CPUThread") {
+        parallel_unit = ParallelUnit::CPUThread;
+      } else if (unit == "CPUVector") {
+        parallel_unit = ParallelUnit::CPUVector;
+      } else {
+        taco_uerror << "Parallel hardware not defined.";
+        goto end;
+      }
+
+      OutputRaceStrategy output_race_strategy;
+      if (strategy == "IgnoreRaces") {
+        output_race_strategy = OutputRaceStrategy::IgnoreRaces;
+      } else if (strategy == "NoRaces") {
+        output_race_strategy = OutputRaceStrategy::NoRaces;
+      } else if (strategy == "Atomics") {
+        output_race_strategy = OutputRaceStrategy::Atomics;
+      } else if (strategy == "Temporary") {
+        output_race_strategy = OutputRaceStrategy::Temporary;
+      } else if (strategy == "ParallelReduction") {
+        output_race_strategy = OutputRaceStrategy::ParallelReduction;
+      } else {
+        taco_uerror << "Race strategy not defined.";
+        goto end;
+      }
+
+      stmt = stmt.parallelize(findVar(i), parallel_unit, output_race_strategy);
+
+    } else {
+      break;
+    }
+
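+    // Jump target: branches use goto to skip the rest of the current
+    // iteration and move on to the next command.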
+    end:;
+  }
+
+  return isGPU;
+}
+
 int main(int argc, char* argv[]) {
   if (argc < 2) {
     printUsageInfo();
@@ -228,6 +490,8 @@ int main(int argc, char* argv[]) {
   bool readKernels = false;
   bool cuda = false;
 
+  bool setSchedule = false;
+
   ParallelSchedule sched = ParallelSchedule::Static;
   int chunkSize = 0;
   int nthreads = 0;
@@ -256,6 +520,8 @@ int main(int argc, char* argv[]) {
 
   vector<string> kernelFilenames;
 
+  vector<string> scheduleCommands;
+
   for (int i = 1; i < argc; i++) {
     string arg = argv[i];
     vector<string> argparts = util::split(arg, "=");
@@ -543,6 +809,30 @@ int main(int argc, char* argv[]) {
     else if ("-print-kernels" == argName) {
       printKernels = true;
     }
+    else if ("-s" == argName) {
+      setSchedule = true;
+      bool insideCall = false;
+      bool parsingExpr = false;
+
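+      // Normalize the command string in place: the call's opening '(', the
+      // top-level commas, and the closing ')' all become spaces so the
+      // istream-based parser can read whitespace-delimited tokens. Parens
+      // nested inside an argument (e.g. a precompute expression) are kept.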
+      std::replace_if(argValue.begin(), argValue.end(), [&insideCall, &parsingExpr](char c) {
+        if (c == '(') {
+          if (insideCall) {
+            parsingExpr = true;  // need to handle precompute case specially
+          } else {
+            insideCall = true;
+            return true;
+          }
+        } else if (c == ',') {
+          return !parsingExpr;
+        } else if (c == ')') {
+          bool previous = parsingExpr;
+          parsingExpr = false;
+          return !previous;
+        }
+        return false;
+      }, ' ');
+      scheduleCommands.push_back(argValue);
+    }
     else {
       if (exprStr.size() != 0) {
         printUsageInfo();
@@ -623,16 +913,6 @@ int main(int argc, char* argv[]) {
     }
   }
 
-  if (cuda) {
-    if (!CUDA_BUILT && benchmark) {
-      return reportError("TACO must be built for CUDA (cmake -DCUDA=ON ..) to benchmark", 2);
-    }
-    set_CUDA_codegen_enabled(true);
-  }
-  else {
-    set_CUDA_codegen_enabled(false);
-  }
-
   ir::Stmt assemble;
   ir::Stmt compute;
   ir::Stmt evaluate;
@@ -645,6 +925,26 @@ int main(int argc, char* argv[]) {
     stmt = reorderLoopsTopologically(stmt);
     stmt = insertTemporaries(stmt);
     stmt = parallelizeOuterLoop(stmt);
+
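+    // Apply any user-specified scheduling commands after the default
+    // lowering passes; a GPU parallelize command forces CUDA codegen below.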
+    if (setSchedule) {
+      stringstream scheduleStream;
+      for (string command : scheduleCommands) {
+        scheduleStream << command << endl;
+      }
+
+      cuda |= setSchedulingCommands(scheduleStream, parser, stmt);
+    }
+
+    if (cuda) {
+      if (!CUDA_BUILT && benchmark) {
+        return reportError("TACO must be built for CUDA (cmake -DCUDA=ON ..) to benchmark", 2);
+      }
+      set_CUDA_codegen_enabled(true);
+    }
+    else {
+      set_CUDA_codegen_enabled(false);
+    }
+
     stmt = scalarPromote(stmt);
     if (printConcrete) {
       cout << stmt << endl;
@@ -749,7 +1049,7 @@ int main(int argc, char* argv[]) {
7491049 " * For both, the `_COO_pos` arrays contain two elements, where the first is 0\n "
7501050 " * and the second is the number of nonzeros in the tensor.\n "
7511051 " */" ;
752-
1052+
7531053 vector<ir::Stmt> packs;
7541054 for (auto a : getArgumentAccesses (stmt)) {
7551055 TensorVar tensor = a.getTensorVar ();