Skip to content

Commit fb25c1d

Browse files
committed
Fix names
1 parent cf4c57f commit fb25c1d

File tree

3 files changed

+30
-28
lines changed

3 files changed

+30
-28
lines changed

include/taco/index_notation/index_notation.h

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -778,15 +778,17 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
778778
IndexStmt assemble(TensorVar result, AssembleStrategy strategy,
779779
bool separately_schedulable = false) const;
780780

781-
/// The wsaccel primitive specifies the dimensions of a workspace that
782-
/// will be accelerated. Acc controls whether acceleration will be applied.
783-
/// If accels is empty it means all dimensions should be accelerated.
784-
/// Currently, it only supports one-dimension acceleration. Acceleration is used
785-
/// by default.
781+
/// The wsaccel primitive specifies the dimensions of a workspace that will be accelerated.
782+
/// Acceleration means adding compressed acceleration datastructures (bitmap, coordinate list) to a dense work space.
783+
/// shouldAccel controls whether acceleration will be applied.
784+
/// When shouldAccel is true, if accelIndexVars is empty, then all dimensions should be accelerated.
785+
/// When shouldAccel is true, if accelIndexVars is not empty, then dimensions in accelIndexVars will be accelerated.
786+
/// When shouldAccel is false, accelIndexVars is ignored.
787+
/// Currently, it only supports one-dimension acceleration. Acceleration is used by default.
786788
///
787789
/// Precondition:
788-
/// Workspace can be accessed by the IndexVars in the accels.
789-
IndexStmt wsaccel(TensorVar& ws, bool Acc = true,const std::vector<IndexVar>& accels={});
790+
/// Workspace can be accessed by the IndexVars in the accelIndexVars.
791+
IndexStmt wsaccel(TensorVar& ws, bool shouldAccel = true,const std::vector<IndexVar>& accelIndexVars ={});
790792

791793
/// Casts index statement to specified subtype.
792794
template <typename SubType>
@@ -1167,13 +1169,13 @@ class TensorVar : public util::Comparable<TensorVar> {
11671169
const Literal& getFill() const;
11681170

11691171
/// Gets the acceleration dimensions
1170-
const std::vector<IndexVar>& getAccels() const;
1172+
const std::vector<IndexVar>& getaccelIndexVars() const;
11711173

11721174
/// Gets the acceleration flag
1173-
bool getAcc() const;
1175+
bool getshouldAccel() const;
11741176

11751177
/// Set the acceleration dimensions
1176-
void setAccels(const std::vector<IndexVar>& accels, bool Acc);
1178+
void setaccelIndexVars(const std::vector<IndexVar>& accelIndexVars, bool shouldAccel);
11771179

11781180
/// Set the fill value of the tensor variable
11791181
void setFill(const Literal& fill);

src/index_notation/index_notation.cpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1972,7 +1972,7 @@ IndexStmt IndexStmt::pos(IndexVar i, IndexVar ipos, Access access) const {
19721972

19731973
// Replace all occurrences of i with ipos
19741974
transformed = Transformation(ForAllReplace({i}, {ipos})).apply(transformed, &reason);
1975-
if (!transformed.defined()) {
1975+
if (!transformed.defined()) {
19761976
taco_uerror << reason;
19771977
}
19781978

@@ -2048,9 +2048,9 @@ IndexStmt IndexStmt::assemble(TensorVar result, AssembleStrategy strategy,
20482048
return transformed;
20492049
}
20502050

2051-
IndexStmt IndexStmt::wsaccel(TensorVar& ws, bool Acc, const std::vector<IndexVar>& accels) {
2052-
if (accels.size() == 0) {
2053-
ws.setAccels(accels, Acc);
2051+
IndexStmt IndexStmt::wsaccel(TensorVar& ws, bool shouldAccel, const std::vector<IndexVar>& accelIndexVars) {
2052+
if (accelIndexVars.size() == 0) {
2053+
ws.setaccelIndexVars(accelIndexVars, shouldAccel);
20542054
return *this;
20552055
}
20562056
set<IndexVar> TempVars;
@@ -2065,12 +2065,12 @@ IndexStmt IndexStmt::wsaccel(TensorVar& ws, bool Acc, const std::vector<IndexVar
20652065
ctx->match(where->producer);
20662066
ctx->match(where->consumer);
20672067
}));
2068-
for (auto i : accels) {
2068+
for (auto i : accelIndexVars) {
20692069
if (TempVars.find(i) == TempVars.end()) {
20702070
taco_uerror << "No matching indexVars in the Accel";
20712071
}
20722072
}
2073-
ws.setAccels(accels, Acc);
2073+
ws.setaccelIndexVars(accelIndexVars, shouldAccel);
20742074
return *this;
20752075
}
20762076

@@ -2546,8 +2546,8 @@ struct TensorVar::Content {
25462546
Format format;
25472547
Schedule schedule;
25482548
Literal fill;
2549-
std::vector<IndexVar> accels;
2550-
bool Acc;
2549+
std::vector<IndexVar> accelIndexVars;
2550+
bool shouldAccel;
25512551
};
25522552

25532553
TensorVar::TensorVar() : content(nullptr) {
@@ -2580,8 +2580,8 @@ TensorVar::TensorVar(const int& id, const string& name, const Type& type, const
25802580
content->type = type;
25812581
content->format = format;
25822582
content->fill = fill.defined()? fill : Literal::zero(type.getDataType());
2583-
content->accels = std::vector<IndexVar> {};
2584-
content->Acc = true;
2583+
content->accelIndexVars = std::vector<IndexVar> {};
2584+
content->shouldAccel = true;
25852585
}
25862586

25872587
int TensorVar::getId() const {
@@ -2625,17 +2625,17 @@ const Literal& TensorVar::getFill() const {
26252625
return content->fill;
26262626
}
26272627

2628-
const std::vector<IndexVar>& TensorVar::getAccels() const {
2629-
return content->accels;
2628+
const std::vector<IndexVar>& TensorVar::getaccelIndexVars() const {
2629+
return content->accelIndexVars;
26302630
}
26312631

2632-
bool TensorVar::getAcc() const {
2633-
return content->Acc;
2632+
bool TensorVar::getshouldAccel() const {
2633+
return content->shouldAccel;
26342634
}
26352635

2636-
void TensorVar::setAccels(const std::vector<IndexVar>& accels, bool Acc) {
2637-
content->Acc = Acc;
2638-
content->accels = accels;
2636+
void TensorVar::setaccelIndexVars(const std::vector<IndexVar>& accelIndexVars, bool shouldAccel) {
2637+
content->shouldAccel = shouldAccel;
2638+
content->accelIndexVars = accelIndexVars;
26392639
}
26402640

26412641
void TensorVar::setFill(const Literal &fill) {

src/lower/lowerer_impl_imperative.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2276,7 +2276,7 @@ std::pair<bool,bool> LowererImplImperative::canAccelerateDenseTemp(Where where)
22762276
TensorVar temporary = where.getTemporary();
22772277

22782278
// (0) Acceleration flag is true
2279-
if (!temporary.getAcc()) {
2279+
if (!temporary.getshouldAccel()) {
22802280
return std::make_pair(false, false);
22812281
}
22822282

0 commit comments

Comments
 (0)