Skip to content

Commit cf45095

Browse files
committed
Fixed support for ELL format
1 parent 0ede002 commit cf45095

File tree

8 files changed

+111
-17
lines changed

8 files changed

+111
-17
lines changed

src/index_notation/index_notation.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3538,6 +3538,22 @@ IndexStmt generatePackStmt(TensorVar tensor,
35383538
packStmt = forall(indexVars[mode], packStmt);
35393539
}
35403540

3541+
bool doAppend = true;
3542+
const Format lhsFormat = otherIsOnRight ? format : otherFormat;
3543+
for (int i = lhsFormat.getOrder() - 1; i >= 0; --i) {
3544+
const auto modeFormat = lhsFormat.getModeFormats()[i];
3545+
if (modeFormat.isBranchless() && i != 0) {
3546+
const auto parentModeFormat = lhsFormat.getModeFormats()[i - 1];
3547+
if (parentModeFormat.isUnique() || !parentModeFormat.hasAppend()) {
3548+
doAppend = false;
3549+
break;
3550+
}
3551+
}
3552+
}
3553+
if (!doAppend) {
3554+
packStmt = packStmt.assemble(otherIsOnRight ? tensor : other, AssembleStrategy::Insert);
3555+
}
3556+
35413557
return packStmt;
35423558
}
35433559

src/ir/ir_rewriter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ void IRRewriter::visit(const Allocate* op) {
425425
stmt = op;
426426
}
427427
else {
428-
stmt = Allocate::make(var, num_elements, op->is_realloc, op->old_elements);
428+
stmt = Allocate::make(var, num_elements, op->is_realloc, op->old_elements, op->clear);
429429
}
430430
}
431431

src/lower/lowerer_impl_imperative.cpp

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,24 +1335,28 @@ Stmt LowererImplImperative::lowerForallPosition(Forall forall, Iterator iterator
13351335
endBound = endBounds[1];
13361336
}
13371337

1338-
LoopKind kind = LoopKind::Serial;
1339-
if (forall.getParallelUnit() == ParallelUnit::CPUVector && !ignoreVectorize) {
1340-
kind = LoopKind::Vectorized;
1341-
}
1342-
else if (forall.getParallelUnit() != ParallelUnit::NotParallel
1343-
&& forall.getOutputRaceStrategy() != OutputRaceStrategy::ParallelReduction && !ignoreVectorize) {
1344-
kind = LoopKind::Runtime;
1338+
Stmt loop = Block::make(strideGuard, declareCoordinate, boundsGuard, body);
1339+
if (iterator.isBranchless() && iterator.isCompact() &&
1340+
(iterator.getParent().isRoot() || iterator.getParent().isUnique())) {
1341+
loop = Block::make(VarDecl::make(iterator.getPosVar(), startBound), loop);
1342+
} else {
1343+
LoopKind kind = LoopKind::Serial;
1344+
if (forall.getParallelUnit() == ParallelUnit::CPUVector && !ignoreVectorize) {
1345+
kind = LoopKind::Vectorized;
1346+
}
1347+
else if (forall.getParallelUnit() != ParallelUnit::NotParallel &&
1348+
forall.getOutputRaceStrategy() != OutputRaceStrategy::ParallelReduction &&
1349+
!ignoreVectorize) {
1350+
kind = LoopKind::Runtime;
1351+
}
1352+
1353+
loop = For::make(iterator.getPosVar(), startBound, endBound, 1, loop, kind,
1354+
ignoreVectorize ? ParallelUnit::NotParallel : forall.getParallelUnit(),
1355+
ignoreVectorize ? 0 : forall.getUnrollFactor());
13451356
}
13461357

13471358
// Loop with preamble and postamble
1348-
return Block::blanks(
1349-
boundsCompute,
1350-
For::make(iterator.getPosVar(), startBound, endBound, 1,
1351-
Block::make(strideGuard, declareCoordinate, boundsGuard, body),
1352-
kind,
1353-
ignoreVectorize ? ParallelUnit::NotParallel : forall.getParallelUnit(), ignoreVectorize ? 0 : forall.getUnrollFactor()),
1354-
posAppend);
1355-
1359+
return Block::blanks(boundsCompute, loop, posAppend);
13561360
}
13571361

13581362
Stmt LowererImplImperative::lowerForallFusedPosition(Forall forall, Iterator iterator,

src/lower/mode_format_singleton.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ Expr SingletonModeFormat::getAssembledSize(Expr prevSize, Mode mode) const {
128128
Stmt SingletonModeFormat::getInitCoords(Expr prevSize,
129129
std::vector<AttrQueryResult> queries, Mode mode) const {
130130
Expr crdArray = getCoordArray(mode.getModePack());
131-
return Allocate::make(crdArray, prevSize, false, Expr());
131+
return Allocate::make(crdArray, prevSize, false, Expr(), true);
132132
}
133133

134134
ModeFunction SingletonModeFormat::getYieldPos(Expr parentPos,

src/tensor.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -941,6 +941,7 @@ TensorBase::getHelperFunctions(const Format& format, Datatype ctype,
941941
TensorVar packedTensor(Type(ctype, Shape(dims)), format);
942942

943943
// Define packing and iterator routines in index notation.
944+
// TODO: Use `generatePackCOOStmt` function to generate pack routine.
944945
std::vector<IndexVar> indexVars(format.getOrder());
945946
IndexStmt packStmt = (packedTensor(indexVars) = bufferTensor(indexVars));
946947
IndexStmt iterateStmt = Yield(indexVars, packedTensor(indexVars));
@@ -950,6 +951,21 @@ TensorBase::getHelperFunctions(const Format& format, Datatype ctype,
950951
iterateStmt = forall(indexVars[mode], iterateStmt);
951952
}
952953

954+
bool doAppend = true;
955+
for (int i = format.getOrder() - 1; i >= 0; --i) {
956+
const auto modeFormat = format.getModeFormats()[i];
957+
if (modeFormat.isBranchless() && i != 0) {
958+
const auto parentModeFormat = format.getModeFormats()[i - 1];
959+
if (parentModeFormat.isUnique() || !parentModeFormat.hasAppend()) {
960+
doAppend = false;
961+
break;
962+
}
963+
}
964+
}
965+
if (!doAppend) {
966+
packStmt = packStmt.assemble(packedTensor, AssembleStrategy::Insert);
967+
}
968+
953969
// Lower packing and iterator code.
954970
helperModule->addFunction(lower(packStmt, "pack", true, true));
955971
helperModule->addFunction(lower(iterateStmt, "iterate", false, true));

test/test_tensors.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,16 @@ TensorData<double> d5d_data() {
137137
});
138138
}
139139

140+
TensorData<double> d5e_data() {
141+
return TensorData<double>({5}, {
142+
{{0}, 1},
143+
{{1}, 2},
144+
{{2}, 3},
145+
{{3}, 4},
146+
{{4}, 5}
147+
});
148+
}
149+
140150
TensorData<double> d8a_data() {
141151
return TensorData<double>({8}, {
142152
{{0}, 1},
@@ -328,6 +338,23 @@ TensorData<double> d333a_data() {
328338
});
329339
}
330340

341+
TensorData<double> d355a_data() {
342+
return TensorData<double>({3,5,5}, {
343+
{{0,0,0}, 1},
344+
{{0,1,1}, 2},
345+
{{0,2,1}, 3},
346+
{{0,3,1}, 4},
347+
{{0,4,1}, 5},
348+
{{1,0,1}, 6},
349+
{{1,1,0}, 7},
350+
{{1,2,0}, 8},
351+
{{1,4,2}, 9},
352+
{{2,1,2}, 10},
353+
{{2,2,3}, 11},
354+
{{2,4,4}, 12},
355+
});
356+
}
357+
331358
TensorData<double> d32b_data() {
332359
return TensorData<double>({3,2}, {
333360
{{0,0}, 10},
@@ -406,6 +433,10 @@ Tensor<double> d5d(std::string name, Format format) {
406433
return d5d_data().makeTensor(name, format);
407434
}
408435

436+
Tensor<double> d5e(std::string name, Format format) {
437+
return d5e_data().makeTensor(name, format);
438+
}
439+
409440
Tensor<double> d8a(std::string name, Format format) {
410441
return d8a_data().makeTensor(name, format);
411442
}
@@ -486,6 +517,10 @@ Tensor<double> d333a(std::string name, Format format) {
486517
return d333a_data().makeTensor(name, format);
487518
}
488519

520+
Tensor<double> d355a(std::string name, Format format) {
521+
return d355a_data().makeTensor(name, format);
522+
}
523+
489524
Tensor<double> d32b(std::string name, Format format) {
490525
return d32b_data().makeTensor(name, format);
491526
}

test/test_tensors.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ TensorData<double> d5a_data();
101101
TensorData<double> d5b_data();
102102
TensorData<double> d5c_data();
103103
TensorData<double> d5d_data();
104+
TensorData<double> d5e_data();
104105

105106
TensorData<double> d8a_data();
106107
TensorData<double> d8b_data();
@@ -127,6 +128,8 @@ TensorData<double> d233c_data();
127128

128129
TensorData<double> d333a_data();
129130

131+
TensorData<double> d355a_data();
132+
130133
TensorData<double> d32b_data();
131134
TensorData<double> d3322a_data();
132135

@@ -146,6 +149,7 @@ Tensor<double> d5a(std::string name, Format format);
146149
Tensor<double> d5b(std::string name, Format format);
147150
Tensor<double> d5c(std::string name, Format format);
148151
Tensor<double> d5d(std::string name, Format format);
152+
Tensor<double> d5e(std::string name, Format format);
149153

150154
Tensor<double> d8a(std::string name, Format format);
151155
Tensor<double> d8b(std::string name, Format format);
@@ -175,6 +179,8 @@ Tensor<double> d233c(std::string name, Format format);
175179

176180
Tensor<double> d333a(std::string name, Format format);
177181

182+
Tensor<double> d355a(std::string name, Format format);
183+
178184
Tensor<double> d32b(std::string name, Format format);
179185
Tensor<double> d3322a(std::string name, Format format);
180186

test/tests-expr_storage.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,23 @@ INSTANTIATE_TEST_CASE_P(bspmv, expr,
957957
)
958958
);
959959

960+
INSTANTIATE_TEST_CASE_P(espmv, expr,
961+
Values(
962+
TestData(Tensor<double>("a",{5},Format({Dense})),
963+
{i},
964+
d355a("B",Format({Dense, Dense, Singleton(ModeFormat::UNIQUE)}))(j,i,k) *
965+
d5e("c",Format({Dense}))(k),
966+
{
967+
{
968+
// Dense index
969+
{5}
970+
},
971+
},
972+
{13,41,58,8,97}
973+
)
974+
)
975+
);
976+
960977
INSTANTIATE_TEST_CASE_P(matrix_sum, expr,
961978
Values(
962979
TestData(Tensor<double>("a",{},Format()),

0 commit comments

Comments
 (0)