Skip to content

Commit 4f4ac27

Browse files
committed
Adds SpMM test
1 parent b7aecb6 commit 4f4ac27

File tree

3 files changed

+52
-20
lines changed

3 files changed

+52
-20
lines changed

include/taco/lower/lowerer_impl.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,13 +266,13 @@ class LowererImpl : public util::Uncopyable {
266266
/**
267267
* Replace scalar tensor pointers with stack scalar for lowering.
268268
*/
269-
ir::Stmt declareScalarVariable(TensorVar var, bool zero);
269+
ir::Stmt defineScalarVariable(TensorVar var, bool zero);
270270

271271
/**
272272
* Creates code to declare temporaries.
273273
*/
274-
ir::Stmt declareTemporaries(std::vector<TensorVar> temporaries,
275-
std::map<TensorVar,ir::Expr> scalars);
274+
ir::Stmt defineTemporaries(std::vector<TensorVar> temporaries,
275+
std::map<TensorVar,ir::Expr> scalars);
276276

277277
ir::Stmt initResultArrays(IndexVar var, std::vector<Access> writes,
278278
std::vector<Access> reads,

src/lower/lowerer_impl.cpp

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble,
138138
vector<Stmt> headerStmts;
139139
vector<Stmt> footerStmts;
140140

141-
// Declare and initialize dimension variables
141+
// Define and initialize dimension variables
142142
vector<IndexVar> indexVars = getIndexVars(stmt);
143143
for (auto& indexVar : indexVars) {
144144
Expr dimension;
@@ -168,22 +168,22 @@ Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble,
168168
dimensions.insert({indexVar, dimension});
169169
}
170170

171-
// Declare and initialize scalar results and arguments
171+
// Define and initialize scalar results and arguments
172172
if (generateComputeCode()) {
173173
for (auto& result : results) {
174174
if (isScalar(result.getType())) {
175175
taco_iassert(!util::contains(scalars, result));
176176
taco_iassert(util::contains(tensorVars, result));
177177
scalars.insert({result, tensorVars.at(result)});
178-
headerStmts.push_back(declareScalarVariable(result, true));
178+
headerStmts.push_back(defineScalarVariable(result, true));
179179
}
180180
}
181181
for (auto& argument : arguments) {
182182
if (isScalar(argument.getType())) {
183183
taco_iassert(!util::contains(scalars, argument));
184184
taco_iassert(util::contains(tensorVars, argument));
185185
scalars.insert({argument, tensorVars.at(argument)});
186-
headerStmts.push_back(declareScalarVariable(argument, false));
186+
headerStmts.push_back(defineScalarVariable(argument, false));
187187
}
188188
}
189189
}
@@ -204,7 +204,7 @@ Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble,
204204
reducedAccesses);
205205

206206
// Declare and initialize non-scalar temporaries
207-
Stmt declTemporaries = declareTemporaries(temporaries, scalars);
207+
Stmt tempDefinitions = defineTemporaries(temporaries, scalars);
208208

209209
// Lower the index statement to compute and/or assemble
210210
Stmt body = lower(stmt);
@@ -231,7 +231,7 @@ Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble,
231231
Stmt footer = footerStmts.empty() ? Stmt() : Block::make(footerStmts);
232232
return Function::make(name, resultsIR, argumentsIR,
233233
Block::blanks(header,
234-
declTemporaries,
234+
tempDefinitions,
235235
initializeResults,
236236
body,
237237
finalizeResults,
@@ -661,7 +661,7 @@ Stmt LowererImpl::lowerWhere(Where where) {
661661
TensorVar temporary = where.getTemporary();
662662
Stmt declareTemporary;
663663
if (isScalar(temporary.getType())) {
664-
declareTemporary = declareScalarVariable(temporary, true);
664+
declareTemporary = defineScalarVariable(temporary, true);
665665
}
666666
else {
667667
taco_not_supported_yet;
@@ -1062,21 +1062,29 @@ ir::Stmt LowererImpl::finalizeResultArrays(std::vector<Access> writes) {
10621062
}
10631063

10641064

1065-
Stmt LowererImpl::declareTemporaries(vector<TensorVar> temporaries,
1066-
map<TensorVar, Expr> scalars) {
1065+
Stmt LowererImpl::defineTemporaries(vector<TensorVar> temporaries,
1066+
map<TensorVar, Expr> scalars) {
10671067
vector<Stmt> result;
10681068
if (generateComputeCode()) {
10691069
for (auto& temporary : temporaries) {
1070-
if (!isScalar(temporary.getType())) {
1071-
taco_not_supported_yet;
1070+
if (isScalar(temporary.getType())) {
1071+
// We will define scalar temporaries locally where they are initialized
1072+
continue;
10721073
}
1073-
// We will declare scalar temporaries locally when they are initialized
1074+
1075+
taco_not_supported_yet;
1076+
Expr temporaryPtr = Var::make(temporary.getName(),
1077+
temporary.getType().getDataType(),
1078+
true, true);
1079+
1080+
Stmt definition = VarDecl::make(temporaryPtr, 0);
1081+
result.push_back(definition);
10741082
}
10751083
}
10761084
return result.empty() ? Stmt() : Block::make(result);
10771085
}
10781086

1079-
Stmt LowererImpl::declareScalarVariable(TensorVar var, bool zero) {
1087+
Stmt LowererImpl::defineScalarVariable(TensorVar var, bool zero) {
10801088
Datatype type = var.getType().getDataType();
10811089
Expr varValueIR = Var::make(var.getName() + "_val", type, false, false);
10821090
Expr init = (zero) ? ir::Literal::zero(type)

test/tests-lower.cpp

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,26 @@ TEST_STMT(where_matrix_vector_mul,
726726
}
727727
)
728728

729+
TEST_STMT(DISABLED_where_spmm,
730+
forall(i,
731+
where(forall(j,
732+
A(i,j) = w(j)),
733+
forall(k,
734+
forall(j,
735+
w(j) += B(i,k) * C(k,j))))),
736+
Values(
737+
Formats({{A,Format({dense,dense})},
738+
{B,Format({dense,dense})}, {C,Format({dense,dense})}}),
739+
Formats({{A,Format({dense,sparse})},
740+
{B,Format({dense,sparse})}, {C,Format({dense,sparse})}})
741+
),
742+
{
743+
TestCase({{B, { {{0,1}, 2.0}, {{2,0}, 3.0}, {{2,2}, 4.0}} },
744+
{C, { {{0,0},10.0}, {{0,1}, 20.0}, {{2,1},30.0}} }},
745+
{{A, { {{2,0},30.0}, {{2,1},180.0} }}})
746+
}
747+
)
748+
729749

730750
// Test sequence statements
731751

@@ -782,10 +802,14 @@ TEST_STMT(matrix_transposed_input,
782802
A(i,j) = B(i,j) + C(j,i)
783803
)),
784804
Values(
785-
Formats({{A,Format({ dense, dense})}, {B,Format({ dense, dense})}, {C,Format({dense,dense})}}),
786-
Formats({{A,Format({ dense,sparse})}, {B,Format({ dense,sparse})}, {C,Format({dense,dense})}}),
787-
Formats({{A,Format({sparse, dense})}, {B,Format({sparse, dense})}, {C,Format({dense,dense})}}),
788-
Formats({{A,Format({sparse,sparse})}, {B,Format({sparse,sparse})}, {C,Format({dense,dense})}})
805+
Formats({{A,Format({ dense, dense})}, {B,Format({ dense, dense})},
806+
{C,Format({dense,dense})}}),
807+
Formats({{A,Format({ dense,sparse})}, {B,Format({ dense,sparse})},
808+
{C,Format({dense,dense})}}),
809+
Formats({{A,Format({sparse, dense})}, {B,Format({sparse, dense})},
810+
{C,Format({dense,dense})}}),
811+
Formats({{A,Format({sparse,sparse})}, {B,Format({sparse,sparse})},
812+
{C,Format({dense,dense})}})
789813
),
790814
{
791815
TestCase({{B, {{{0,0}, 42.0}, {{0,2}, 2.0}, {{1,3}, 3.0}, {{3,2}, 4.0}}},

0 commit comments

Comments
 (0)