Skip to content

Commit 4c24193

Browse files
Merge pull request #491 from tensor-compiler/ell_fix
Fixed support for ELL format
2 parents d0654a8 + b5c2a5e commit 4c24193

17 files changed

+165
-38
lines changed

include/taco/format.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ class ModeFormat {
9797
/// Properties of a mode format
9898
enum Property {
9999
FULL, NOT_FULL, ORDERED, NOT_ORDERED, UNIQUE, NOT_UNIQUE, BRANCHLESS,
100-
NOT_BRANCHLESS, COMPACT, NOT_COMPACT, ZEROLESS, NOT_ZEROLESS
100+
NOT_BRANCHLESS, COMPACT, NOT_COMPACT, ZEROLESS, NOT_ZEROLESS, PADDED,
101+
NOT_PADDED
101102
};
102103

103104
/// Instantiates an undefined mode format
@@ -129,6 +130,7 @@ class ModeFormat {
129130
bool isBranchless() const;
130131
bool isCompact() const;
131132
bool isZeroless() const;
133+
bool isPadded() const;
132134

133135
/// Returns true if a mode format has a specific capability, false otherwise
134136
bool hasCoordValIter() const;

include/taco/lower/mode_format_impl.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,10 @@ class ModeFormatImpl {
106106
public:
107107
ModeFormatImpl(std::string name, bool isFull, bool isOrdered, bool isUnique,
108108
bool isBranchless, bool isCompact, bool isZeroless,
109-
bool hasCoordValIter, bool hasCoordPosIter, bool hasLocate,
110-
bool hasInsert, bool hasAppend, bool hasSeqInsertEdge,
111-
bool hasInsertCoord, bool isYieldPosPure);
109+
bool isPadded, bool hasCoordValIter, bool hasCoordPosIter,
110+
bool hasLocate, bool hasInsert, bool hasAppend,
111+
bool hasSeqInsertEdge, bool hasInsertCoord,
112+
bool isYieldPosPure);
112113

113114
virtual ~ModeFormatImpl();
114115

@@ -246,6 +247,7 @@ class ModeFormatImpl {
246247
const bool isBranchless;
247248
const bool isCompact;
248249
const bool isZeroless;
250+
const bool isPadded;
249251

250252
const bool hasCoordValIter;
251253
const bool hasCoordPosIter;

include/taco/lower/mode_format_singleton.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@ class SingletonModeFormat : public ModeFormatImpl {
1010
using ModeFormatImpl::getInsertCoord;
1111

1212
SingletonModeFormat();
13-
SingletonModeFormat(bool isFull, bool isOrdered,
14-
bool isUnique, bool isZeroless, long long allocSize = DEFAULT_ALLOC_SIZE);
13+
SingletonModeFormat(bool isFull, bool isOrdered, bool isUnique,
14+
bool isZeroless, bool isPadded,
15+
long long allocSize = DEFAULT_ALLOC_SIZE);
1516

1617
~SingletonModeFormat() override {}
1718

src/format.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,11 @@ bool ModeFormat::hasProperties(const std::vector<Property>& properties) const {
187187
return false;
188188
}
189189
break;
190+
case PADDED:
191+
if (!isPadded()) {
192+
return false;
193+
}
194+
break;
190195
case NOT_FULL:
191196
if (isFull()) {
192197
return false;
@@ -217,6 +222,11 @@ bool ModeFormat::hasProperties(const std::vector<Property>& properties) const {
217222
return false;
218223
}
219224
break;
225+
case NOT_PADDED:
226+
if (isPadded()) {
227+
return false;
228+
}
229+
break;
220230
}
221231
}
222232
return true;
@@ -252,6 +262,11 @@ bool ModeFormat::isZeroless() const {
252262
return impl->isZeroless;
253263
}
254264

265+
bool ModeFormat::isPadded() const {
266+
taco_iassert(defined());
267+
return impl->isPadded;
268+
}
269+
255270
bool ModeFormat::hasCoordValIter() const {
256271
taco_iassert(defined());
257272
return impl->hasCoordValIter;

src/index_notation/index_notation.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4155,6 +4155,22 @@ IndexStmt generatePackStmt(TensorVar tensor,
41554155
packStmt = forall(indexVars[mode], packStmt);
41564156
}
41574157

4158+
bool doAppend = true;
4159+
const Format lhsFormat = otherIsOnRight ? format : otherFormat;
4160+
for (int i = lhsFormat.getOrder() - 1; i >= 0; --i) {
4161+
const auto modeFormat = lhsFormat.getModeFormats()[i];
4162+
if (modeFormat.isBranchless() && i != 0) {
4163+
const auto parentModeFormat = lhsFormat.getModeFormats()[i - 1];
4164+
if (parentModeFormat.isUnique() || !parentModeFormat.hasAppend()) {
4165+
doAppend = false;
4166+
break;
4167+
}
4168+
}
4169+
}
4170+
if (!doAppend) {
4171+
packStmt = packStmt.assemble(otherIsOnRight ? tensor : other, AssembleStrategy::Insert);
4172+
}
4173+
41584174
return packStmt;
41594175
}
41604176

src/ir/ir_rewriter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ void IRRewriter::visit(const Allocate* op) {
435435
stmt = op;
436436
}
437437
else {
438-
stmt = Allocate::make(var, num_elements, op->is_realloc, op->old_elements);
438+
stmt = Allocate::make(var, num_elements, op->is_realloc, op->old_elements, op->clear);
439439
}
440440
}
441441

src/lower/lowerer_impl_imperative.cpp

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,24 +1391,28 @@ Stmt LowererImplImperative::lowerForallPosition(Forall forall, Iterator iterator
13911391
endBound = endBounds[1];
13921392
}
13931393

1394-
LoopKind kind = LoopKind::Serial;
1395-
if (forall.getParallelUnit() == ParallelUnit::CPUVector && !ignoreVectorize) {
1396-
kind = LoopKind::Vectorized;
1397-
}
1398-
else if (forall.getParallelUnit() != ParallelUnit::NotParallel
1399-
&& forall.getOutputRaceStrategy() != OutputRaceStrategy::ParallelReduction && !ignoreVectorize) {
1400-
kind = LoopKind::Runtime;
1394+
Stmt loop = Block::make(strideGuard, declareCoordinate, boundsGuard, body);
1395+
if (iterator.isBranchless() && iterator.isCompact() &&
1396+
(iterator.getParent().isRoot() || iterator.getParent().isUnique())) {
1397+
loop = Block::make(VarDecl::make(iterator.getPosVar(), startBound), loop);
1398+
} else {
1399+
LoopKind kind = LoopKind::Serial;
1400+
if (forall.getParallelUnit() == ParallelUnit::CPUVector && !ignoreVectorize) {
1401+
kind = LoopKind::Vectorized;
1402+
}
1403+
else if (forall.getParallelUnit() != ParallelUnit::NotParallel &&
1404+
forall.getOutputRaceStrategy() != OutputRaceStrategy::ParallelReduction &&
1405+
!ignoreVectorize) {
1406+
kind = LoopKind::Runtime;
1407+
}
1408+
1409+
loop = For::make(iterator.getPosVar(), startBound, endBound, 1, loop, kind,
1410+
ignoreVectorize ? ParallelUnit::NotParallel : forall.getParallelUnit(),
1411+
ignoreVectorize ? 0 : forall.getUnrollFactor());
14011412
}
14021413

14031414
// Loop with preamble and postamble
1404-
return Block::blanks(
1405-
boundsCompute,
1406-
For::make(iterator.getPosVar(), startBound, endBound, 1,
1407-
Block::make(strideGuard, declareCoordinate, boundsGuard, body),
1408-
kind,
1409-
ignoreVectorize ? ParallelUnit::NotParallel : forall.getParallelUnit(), ignoreVectorize ? 0 : forall.getUnrollFactor()),
1410-
posAppend);
1411-
1415+
return Block::blanks(boundsCompute, loop, posAppend);
14121416
}
14131417

14141418
Stmt LowererImplImperative::lowerForallFusedPosition(Forall forall, Iterator iterator,

src/lower/mode_format_compressed.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ CompressedModeFormat::CompressedModeFormat(bool isFull, bool isOrdered,
1717
bool isUnique, bool isZeroless,
1818
long long allocSize) :
1919
ModeFormatImpl("compressed", isFull, isOrdered, isUnique, false, true,
20-
isZeroless, false, true, false, false, true, true, true,
21-
false),
20+
isZeroless, false, false, true, false, false, true, true,
21+
true, false),
2222
allocSize(allocSize) {
2323
}
2424

src/lower/mode_format_dense.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ DenseModeFormat::DenseModeFormat() : DenseModeFormat(true, true, false) {
1111
DenseModeFormat::DenseModeFormat(const bool isOrdered, const bool isUnique,
1212
const bool isZeroless) :
1313
ModeFormatImpl("dense", true, isOrdered, isUnique, false, true, isZeroless,
14-
false, false, true, true, false, false, false, true) {
14+
true, false, false, true, true, false, false, false, true) {
1515
}
1616

1717
ModeFormat DenseModeFormat::copy(

src/lower/mode_format_impl.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,15 +147,16 @@ std::ostream& operator<<(std::ostream& os, const ModeFunction& modeFunction) {
147147
// class ModeTypeImpl
148148
ModeFormatImpl::ModeFormatImpl(const std::string name, bool isFull,
149149
bool isOrdered, bool isUnique, bool isBranchless,
150-
bool isCompact, bool isZeroless,
150+
bool isCompact, bool isZeroless, bool isPadded,
151151
bool hasCoordValIter, bool hasCoordPosIter,
152152
bool hasLocate, bool hasInsert, bool hasAppend,
153153
bool hasSeqInsertEdge, bool hasInsertCoord,
154154
bool isYieldPosPure) :
155155
name(name), isFull(isFull), isOrdered(isOrdered), isUnique(isUnique),
156156
isBranchless(isBranchless), isCompact(isCompact), isZeroless(isZeroless),
157-
hasCoordValIter(hasCoordValIter), hasCoordPosIter(hasCoordPosIter),
158-
hasLocate(hasLocate), hasInsert(hasInsert), hasAppend(hasAppend),
157+
isPadded(isPadded), hasCoordValIter(hasCoordValIter),
158+
hasCoordPosIter(hasCoordPosIter), hasLocate(hasLocate),
159+
hasInsert(hasInsert), hasAppend(hasAppend),
159160
hasSeqInsertEdge(hasSeqInsertEdge), hasInsertCoord(hasInsertCoord),
160161
isYieldPosPure(isYieldPosPure) {
161162
}

0 commit comments

Comments
 (0)