Skip to content

Commit 9b8a96b

Browse files
committed
Add acceleration flag
1 parent b18d2a5 commit 9b8a96b

File tree

4 files changed

+67
-3
lines changed

4 files changed

+67
-3
lines changed

include/taco/index_notation/index_notation.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,16 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
778778
IndexStmt assemble(TensorVar result, AssembleStrategy strategy,
779779
bool separately_schedulable = false) const;
780780

781+
/// The wsaccel primitive specifies the dimensions of a workspace that
782+
/// will be accelerated. Acc controls whether acceleration will be applied.
783+
/// If accels is empty it means all dimensions should be accelerated.
784+
/// Currently, it only supports one-dimension acceleration. Acceleration is used
785+
/// by default.
786+
///
787+
/// Precondition:
788+
/// Workspace can be accessed by the IndexVars in the accels.
789+
IndexStmt wsaccel(TensorVar& ws, const std::vector<IndexVar>& accels, bool Acc = true);
790+
781791
/// Casts index statement to specified subtype.
782792
template <typename SubType>
783793
SubType as() {
@@ -1156,6 +1166,15 @@ class TensorVar : public util::Comparable<TensorVar> {
11561166
/// Gets the fill value of the tensor variable. May be left undefined.
11571167
const Literal& getFill() const;
11581168

1169+
/// Gets the acceleration dimensions
1170+
const std::vector<IndexVar>& getAccels() const;
1171+
1172+
/// Gets the acceleration flag
1173+
bool getAcc() const;
1174+
1175+
/// Set the acceleration dimensions
1176+
void setAccels(const std::vector<IndexVar>& accels, bool Acc);
1177+
11591178
/// Set the fill value of the tensor variable
11601179
void setFill(const Literal& fill);
11611180

src/index_notation/index_notation.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2048,6 +2048,30 @@ IndexStmt IndexStmt::assemble(TensorVar result, AssembleStrategy strategy,
20482048
return transformed;
20492049
}
20502050

2051+
IndexStmt IndexStmt::wsaccel(TensorVar& ws, const std::vector<IndexVar>& accels, bool Acc) {
2052+
if (accels.size() == 0) {
2053+
ws.setAccels(accels, Acc);
2054+
return *this;
2055+
}
2056+
set<IndexVar> TempVars;
2057+
match(*this,
2058+
std::function<void(const WhereNode*)>([&](const WhereNode* where) {
2059+
auto Temp = getResultAccesses(where->producer).first[0];
2060+
if (Temp.getTensorVar() == ws) {
2061+
for (auto i :getIndexVars()){
2062+
TempVars.insert(i);
2063+
}
2064+
}
2065+
}));
2066+
for (auto i : accels) {
2067+
if (TempVars.find(i) == TempVars.end()) {
2068+
taco_uerror << "No matching indexVars in the Accel";
2069+
}
2070+
}
2071+
ws.setAccels(accels, Acc);
2072+
return *this;
2073+
}
2074+
20512075
std::ostream& operator<<(std::ostream& os, const IndexStmt& expr) {
20522076
if (!expr.defined()) return os << "IndexStmt()";
20532077
IndexNotationPrinter printer(os);
@@ -2520,6 +2544,8 @@ struct TensorVar::Content {
25202544
Format format;
25212545
Schedule schedule;
25222546
Literal fill;
2547+
std::vector<IndexVar> accels;
2548+
bool Acc;
25232549
};
25242550

25252551
TensorVar::TensorVar() : content(nullptr) {
@@ -2552,6 +2578,8 @@ TensorVar::TensorVar(const int& id, const string& name, const Type& type, const
25522578
content->type = type;
25532579
content->format = format;
25542580
content->fill = fill.defined()? fill : Literal::zero(type.getDataType());
2581+
content->accels = std::vector<IndexVar> {};
2582+
content->Acc = true;
25552583
}
25562584

25572585
int TensorVar::getId() const {
@@ -2595,6 +2623,19 @@ const Literal& TensorVar::getFill() const {
25952623
return content->fill;
25962624
}
25972625

2626+
const std::vector<IndexVar>& TensorVar::getAccels() const {
2627+
return content->accels;
2628+
}
2629+
2630+
bool TensorVar::getAcc() const {
2631+
return content->Acc;
2632+
}
2633+
2634+
void TensorVar::setAccels(const std::vector<IndexVar>& accels, bool Acc) {
2635+
content->Acc = Acc;
2636+
content->accels = accels;
2637+
}
2638+
25982639
void TensorVar::setFill(const Literal &fill) {
25992640
content->fill = fill;
26002641
}

src/lower/lowerer_impl_imperative.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2274,6 +2274,12 @@ std::pair<bool,bool> LowererImplImperative::canAccelerateDenseTemp(Where where)
22742274
}
22752275

22762276
TensorVar temporary = where.getTemporary();
2277+
2278+
// (0) Acceleration flag is true
2279+
if (!temporary.getAcc()) {
2280+
return std::make_pair(false, false);
2281+
}
2282+
22772283
// (1) Temporary is dense vector
22782284
if(!isDense(temporary.getFormat()) || temporary.getOrder() != 1) {
22792285
return std::make_pair(false, false);
@@ -2302,9 +2308,6 @@ std::pair<bool,bool> LowererImplImperative::canAccelerateDenseTemp(Where where)
23022308
return resultVar == tempVar[0] ||
23032309
provGraph.isDerivedFrom(tempVar[0], resultVar);
23042310
});
2305-
if (resultVars.size() == 0){
2306-
return std::make_pair(false, false);
2307-
}
23082311
if (it == resultVars.end()) {
23092312
return std::make_pair(true, false);
23102313
}

test/tests-workspaces.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,7 @@ TEST(workspaces, tile_dotProduct_2) {
546546

547547
stmt = stmt.concretize();
548548

549+
stmt = stmt.wsaccel(precomputed, {}, false);
549550
A.compile(stmt);
550551
A.assemble();
551552
A.compute();

0 commit comments

Comments
 (0)