Commit 87d8e30

Merge pull request #314 from jwshi21/master
add scheduling language selection
2 parents fb4e6de + 61d136f commit 87d8e30

1 file changed: +312 −12 lines

tools/taco.cpp (312 additions, 12 deletions)

@@ -27,7 +27,9 @@
 #include "taco/util/env.h"
 #include "taco/util/collections.h"
 #include "taco/cuda.h"
-#include <taco/index_notation/transformations.h>
+#include "taco/index_notation/transformations.h"
+#include "taco/index_notation/index_notation_visitor.h"
+#include "taco/index_notation/index_notation_nodes.h"

 using namespace std;
 using namespace taco;
@@ -112,6 +114,11 @@ static void printUsageInfo() {
             "long, longlong, float, double, complexfloat, complexdouble"
             "Examples: A:uint16, b:long and D:complexfloat.");
   cout << endl;
+  printFlag("s=\"<command>(<params>)\"",
+            "Specify a scheduling command to apply to the generated code. "
+            "Parameters take the form of a comma-delimited list. "
+            "Examples: split(i,i0,i1,16), precompute(A(i,j)*x(j),i,i).");
+  cout << endl;
   printFlag("c",
             "Generate compute kernel that simultaneously does assembly.");
   cout << endl;
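
(Illustrative note, not part of the diff: based on the help text above, an invocation using the new flag might look like

  taco "y(i) = A(i,j) * x(j)" -f=A:ds -s="split(i,i0,i1,16)" -s="parallelize(i0,CPUThread,NoRaces)"

where each -s option contributes one scheduling command. The expression and the -f format argument are the usual taco CLI inputs and are assumptions here, not taken from this commit.)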
@@ -201,6 +208,261 @@ static void printCommandLine(ostream& os, int argc, char* argv[]) {
   }
 }

+static bool setSchedulingCommands(istream& in, parser::Parser& parser, IndexStmt& stmt) {
+  auto findVar = [&stmt](string name) {
+    ProvenanceGraph graph(stmt);
+    for (auto v : graph.getAllIndexVars()) {
+      if (v.getName() == name) {
+        return v;
+      }
+    }
+
+    throw "Index variable not defined in statement.";
+  };
+
+  bool isGPU = false;
+
+  while (true) {
+    string command;
+    in >> command;
+
+    if (command == "pos") {
+      string i, ipos;
+      in >> i;
+      in >> ipos;
+
+      string tensor;
+      in >> tensor;
+
+      for (auto a : getArgumentAccesses(stmt)) {
+        if (a.getTensorVar().getName() == tensor) {
+          IndexVar derived(ipos);
+          stmt = stmt.pos(findVar(i), derived, a);
+          goto end;
+        }
+      }
+
+    } else if (command == "fuse") {
+      string i, j, f;
+      in >> i;
+      in >> j;
+      in >> f;
+
+      IndexVar fused(f);
+      stmt = stmt.fuse(findVar(i), findVar(j), fused);
+
+    } else if (command == "split") {
+      string i, i1, i2;
+      in >> i;
+      in >> i1;
+      in >> i2;
+
+      size_t splitFactor;
+      in >> splitFactor;
+
+      IndexVar split1(i1);
+      IndexVar split2(i2);
+      stmt = stmt.split(findVar(i), split1, split2, splitFactor);
+
+    // } else if (command == "divide") {
+    //   string i, i1, i2;
+    //   in >> i;
+    //   in >> i1;
+    //   in >> i2;
+
+    //   size_t divideFactor;
+    //   in >> divideFactor;
+
+    //   IndexVar divide1(i1);
+    //   IndexVar divide2(i2);
+    //   stmt = stmt.divide(findVar(i), divide1, divide2, divideFactor);
+
+    } else if (command == "precompute") {
+      string exprStr, i, iw;
+      in >> exprStr;
+      in >> i;
+      in >> iw;
+
+      IndexVar orig = findVar(i);
+      IndexVar pre;
+      try {
+        pre = findVar(iw);
+      } catch (const char* e) {
+        pre = IndexVar(iw);
+      }
+
+      struct GetExpr : public IndexNotationVisitor {
+        using IndexNotationVisitor::visit;
+
+        string exprStr;
+        IndexExpr expr;
+
+        void setExprStr(string input) {
+          exprStr = input;
+          exprStr.erase(remove(exprStr.begin(), exprStr.end(), ' '), exprStr.end());
+        }
+
+        string toString(IndexExpr e) {
+          stringstream tempStream;
+          tempStream << e;
+          string tempStr = tempStream.str();
+          tempStr.erase(remove(tempStr.begin(), tempStr.end(), ' '), tempStr.end());
+          return tempStr;
+        }
+
+        void visit(const AccessNode* node) {
+          IndexExpr currentExpr(node);
+          if (toString(currentExpr) == exprStr) {
+            expr = currentExpr;
+          }
+          else {
+            IndexNotationVisitor::visit(node);
+          }
+        }
+
+        void visit(const UnaryExprNode* node) {
+          IndexExpr currentExpr(node);
+          if (toString(currentExpr) == exprStr) {
+            expr = currentExpr;
+          }
+          else {
+            IndexNotationVisitor::visit(node);
+          }
+        }
+
+        void visit(const BinaryExprNode* node) {
+          IndexExpr currentExpr(node);
+          if (toString(currentExpr) == exprStr) {
+            expr = currentExpr;
+          }
+          else {
+            IndexNotationVisitor::visit(node);
+          }
+        }
+      };
+
+      GetExpr visitor;
+      visitor.setExprStr(exprStr);
+      stmt.accept(&visitor);
+
+      Dimension dim;
+      auto domains = stmt.getIndexVarDomains();
+      auto it = domains.find(orig);
+      if (it != domains.end()) {
+        dim = it->second;
+      } else {
+        dim = Dimension(orig);
+      }
+
+      TensorVar workspace("workspace", Type(Float64, {dim}), Dense);
+      stmt = stmt.precompute(visitor.expr, orig, pre, workspace);
+
+    } else if (command == "reorder") {
+      string line;
+      getline(in, line);
+      stringstream temp;
+      temp << line;
+
+      vector<IndexVar> reorderedVars;
+      string var;
+      while (temp >> var) {
+        reorderedVars.push_back(findVar(var));
+      }
+
+      stmt = stmt.reorder(reorderedVars);
+
+    } else if (command == "bound") {
+      string i, i1;
+      in >> i;
+      in >> i1;
+
+      size_t bound;
+      in >> bound;
+
+      string type;
+      in >> type;
+
+      BoundType bound_type;
+      if (type == "MinExact") {
+        bound_type = BoundType::MinExact;
+      } else if (type == "MinConstraint") {
+        bound_type = BoundType::MinConstraint;
+      } else if (type == "MaxExact") {
+        bound_type = BoundType::MaxExact;
+      } else if (type == "MaxConstraint") {
+        bound_type = BoundType::MaxConstraint;
+      } else {
+        taco_uerror << "Bound type not defined.";
+        goto end;
+      }
+
+      IndexVar bound1(i1);
+      stmt = stmt.bound(findVar(i), bound1, bound, bound_type);
+
+    } else if (command == "unroll") {
+      string i;
+      in >> i;
+
+      size_t unrollFactor;
+      in >> unrollFactor;
+
+      stmt = stmt.unroll(findVar(i), unrollFactor);
+
+    } else if (command == "parallelize") {
+      string i, unit, strategy;
+      in >> i;
+      in >> unit;
+      in >> strategy;
+
+      ParallelUnit parallel_unit;
+      if (unit == "NotParallel") {
+        parallel_unit = ParallelUnit::NotParallel;
+      } else if (unit == "GPUBlock") {
+        parallel_unit = ParallelUnit::GPUBlock;
+        isGPU = true;
+      } else if (unit == "GPUWarp") {
+        parallel_unit = ParallelUnit::GPUWarp;
+        isGPU = true;
+      } else if (unit == "GPUThread") {
+        parallel_unit = ParallelUnit::GPUThread;
+        isGPU = true;
+      } else if (unit == "CPUThread") {
+        parallel_unit = ParallelUnit::CPUThread;
+      } else if (unit == "CPUVector") {
+        parallel_unit = ParallelUnit::CPUVector;
+      } else {
+        taco_uerror << "Parallel hardware not defined.";
+        goto end;
+      }
+
+      OutputRaceStrategy output_race_strategy;
+      if (strategy == "IgnoreRaces") {
+        output_race_strategy = OutputRaceStrategy::IgnoreRaces;
+      } else if (strategy == "NoRaces") {
+        output_race_strategy = OutputRaceStrategy::NoRaces;
+      } else if (strategy == "Atomics") {
+        output_race_strategy = OutputRaceStrategy::Atomics;
+      } else if (strategy == "Temporary") {
+        output_race_strategy = OutputRaceStrategy::Temporary;
+      } else if (strategy == "ParallelReduction") {
+        output_race_strategy = OutputRaceStrategy::ParallelReduction;
+      } else {
+        taco_uerror << "Race strategy not defined.";
+        goto end;
+      }
+
+      stmt = stmt.parallelize(findVar(i), parallel_unit, output_race_strategy);
+
+    } else {
+      break;
+    }
+
+    end:;
+  }
+
+  return isGPU;
+}
+
 int main(int argc, char* argv[]) {
   if (argc < 2) {
     printUsageInfo();
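
(Illustrative note, not part of the diff: from the dispatch above, the scheduling commands accepted via -s and their parameter order are pos(i, ipos, tensor); fuse(i, j, f); split(i, i0, i1, factor); precompute(expr, i, iw); reorder(i, j, ...); bound(i, i1, bound, BoundType); unroll(i, factor); parallelize(i, ParallelUnit, OutputRaceStrategy). An unrecognized command ends processing, and choosing a GPU parallel unit implicitly enables CUDA code generation.)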
@@ -228,6 +490,8 @@ int main(int argc, char* argv[]) {
   bool readKernels = false;
   bool cuda = false;

+  bool setSchedule = false;
+
   ParallelSchedule sched = ParallelSchedule::Static;
   int chunkSize = 0;
   int nthreads = 0;
@@ -256,6 +520,8 @@ int main(int argc, char* argv[]) {

   vector<string> kernelFilenames;

+  vector<string> scheduleCommands;
+
   for (int i = 1; i < argc; i++) {
     string arg = argv[i];
     vector<string> argparts = util::split(arg, "=");
@@ -543,6 +809,30 @@ int main(int argc, char* argv[]) {
     else if ("-print-kernels" == argName) {
      printKernels = true;
     }
+    else if ("-s" == argName) {
+      setSchedule = true;
+      bool insideCall = false;
+      bool parsingExpr = false;
+
+      std::replace_if(argValue.begin(), argValue.end(), [&insideCall, &parsingExpr](char c) {
+        if (c == '(') {
+          if (insideCall) {
+            parsingExpr = true; // need to handle precompute case specially
+          } else {
+            insideCall = true;
+            return true;
+          }
+        } else if (c == ',') {
+          return !parsingExpr;
+        } else if (c == ')') {
+          bool previous = parsingExpr;
+          parsingExpr = false;
+          return !previous;
+        }
+        return false;
+      }, ' ');
+      scheduleCommands.push_back(argValue);
+    }
     else {
       if (exprStr.size() != 0) {
         printUsageInfo();
@@ -623,16 +913,6 @@ int main(int argc, char* argv[]) {
     }
   }

-  if (cuda) {
-    if (!CUDA_BUILT && benchmark) {
-      return reportError("TACO must be built for CUDA (cmake -DCUDA=ON ..) to benchmark", 2);
-    }
-    set_CUDA_codegen_enabled(true);
-  }
-  else {
-    set_CUDA_codegen_enabled(false);
-  }
-
   ir::Stmt assemble;
   ir::Stmt compute;
   ir::Stmt evaluate;
@@ -645,6 +925,26 @@ int main(int argc, char* argv[]) {
   stmt = reorderLoopsTopologically(stmt);
   stmt = insertTemporaries(stmt);
   stmt = parallelizeOuterLoop(stmt);
+
+  if (setSchedule) {
+    stringstream scheduleStream;
+    for (string command : scheduleCommands) {
+      scheduleStream << command << endl;
+    }
+
+    cuda |= setSchedulingCommands(scheduleStream, parser, stmt);
+  }
+
+  if (cuda) {
+    if (!CUDA_BUILT && benchmark) {
+      return reportError("TACO must be built for CUDA (cmake -DCUDA=ON ..) to benchmark", 2);
+    }
+    set_CUDA_codegen_enabled(true);
+  }
+  else {
+    set_CUDA_codegen_enabled(false);
+  }
+
   stmt = scalarPromote(stmt);
   if (printConcrete) {
     cout << stmt << endl;
@@ -749,7 +1049,7 @@ int main(int argc, char* argv[]) {
     " * For both, the `_COO_pos` arrays contain two elements, where the first is 0\n"
     " * and the second is the number of nonzeros in the tensor.\n"
     " */";
-
+
   vector<ir::Stmt> packs;
   for (auto a : getArgumentAccesses(stmt)) {
     TensorVar tensor = a.getTensorVar();
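
(Illustrative note, not part of the diff: the -s command strings are dispatched in setSchedulingCommands onto TACO's existing C++ scheduling API. A minimal sketch of the equivalent direct API usage, assuming a TACO library build with the public taco.h header; tensor names, sizes, and formats are illustrative assumptions, not taken from this commit.)

// Sketch: what -s="split(i,i0,i1,16)" -s="parallelize(i0,CPUThread,NoRaces)"
// amounts to when applied to y(i) = A(i,j) * x(j).
#include "taco.h"
using namespace taco;

int main() {
  Format csr({Dense, Sparse});   // illustrative formats
  Format dv({Dense});

  Tensor<double> A("A", {1024, 1024}, csr);
  Tensor<double> x("x", {1024}, dv);
  Tensor<double> y("y", {1024}, dv);
  // (A real program would insert or load values into A and x here.)

  IndexVar i("i"), j("j"), i0("i0"), i1("i1");
  y(i) = A(i, j) * x(j);

  // Concretize the assignment, then apply the transformations the CLI
  // commands map to: split(i,i0,i1,16) -> stmt.split(...),
  // parallelize(i0,CPUThread,NoRaces) -> stmt.parallelize(...).
  IndexStmt stmt = y.getAssignment().concretize();
  stmt = stmt.split(i, i0, i1, 16)
             .parallelize(i0, ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces);

  y.compile(stmt);
  y.assemble();
  y.compute();
  return 0;
}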
