@@ -138,7 +138,7 @@ Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble,
138138 vector<Stmt> headerStmts;
139139 vector<Stmt> footerStmts;
140140
141- // Declare and initialize dimension variables
141+ // Define and initialize dimension variables
142142 vector<IndexVar> indexVars = getIndexVars (stmt);
143143 for (auto & indexVar : indexVars) {
144144 Expr dimension;
@@ -168,22 +168,22 @@ Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble,
168168 dimensions.insert ({indexVar, dimension});
169169 }
170170
171- // Declare and initialize scalar results and arguments
171+ // Define and initialize scalar results and arguments
172172 if (generateComputeCode ()) {
173173 for (auto & result : results) {
174174 if (isScalar (result.getType ())) {
175175 taco_iassert (!util::contains (scalars, result));
176176 taco_iassert (util::contains (tensorVars, result));
177177 scalars.insert ({result, tensorVars.at (result)});
178- headerStmts.push_back (declareScalarVariable (result, true ));
178+ headerStmts.push_back (defineScalarVariable (result, true ));
179179 }
180180 }
181181 for (auto & argument : arguments) {
182182 if (isScalar (argument.getType ())) {
183183 taco_iassert (!util::contains (scalars, argument));
184184 taco_iassert (util::contains (tensorVars, argument));
185185 scalars.insert ({argument, tensorVars.at (argument)});
186- headerStmts.push_back (declareScalarVariable (argument, false ));
186+ headerStmts.push_back (defineScalarVariable (argument, false ));
187187 }
188188 }
189189 }
@@ -204,7 +204,7 @@ Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble,
204204 reducedAccesses);
205205
206206 // Declare and initialize non-scalar temporaries
207- Stmt declTemporaries = declareTemporaries (temporaries, scalars);
207+ Stmt tempDefinitions = defineTemporaries (temporaries, scalars);
208208
209209 // Lower the index statement to compute and/or assemble
210210 Stmt body = lower (stmt);
@@ -231,7 +231,7 @@ Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble,
231231 Stmt footer = footerStmts.empty () ? Stmt () : Block::make (footerStmts);
232232 return Function::make (name, resultsIR, argumentsIR,
233233 Block::blanks (header,
234- declTemporaries ,
234+ tempDefinitions ,
235235 initializeResults,
236236 body,
237237 finalizeResults,
@@ -661,7 +661,7 @@ Stmt LowererImpl::lowerWhere(Where where) {
661661 TensorVar temporary = where.getTemporary ();
662662 Stmt declareTemporary;
663663 if (isScalar (temporary.getType ())) {
664- declareTemporary = declareScalarVariable (temporary, true );
664+ declareTemporary = defineScalarVariable (temporary, true );
665665 }
666666 else {
667667 taco_not_supported_yet;
@@ -1062,21 +1062,29 @@ ir::Stmt LowererImpl::finalizeResultArrays(std::vector<Access> writes) {
10621062}
10631063
10641064
1065- Stmt LowererImpl::declareTemporaries (vector<TensorVar> temporaries,
1066- map<TensorVar, Expr> scalars) {
1065+ Stmt LowererImpl::defineTemporaries (vector<TensorVar> temporaries,
1066+ map<TensorVar, Expr> scalars) {
10671067 vector<Stmt> result;
10681068 if (generateComputeCode ()) {
10691069 for (auto & temporary : temporaries) {
1070- if (!isScalar (temporary.getType ())) {
1071- taco_not_supported_yet;
1070+ if (isScalar (temporary.getType ())) {
1071+ // We will define scalar temporaries locally where they are initialized
1072+ continue ;
10721073 }
1073- // We will declare scalar temporaries locally when they are initialized
1074+
1075+ taco_not_supported_yet;
1076+ Expr temporaryPtr = Var::make (temporary.getName (),
1077+ temporary.getType ().getDataType (),
1078+ true , true );
1079+
1080+ Stmt definition = VarDecl::make (temporaryPtr, 0 );
1081+ result.push_back (definition);
10741082 }
10751083 }
10761084 return result.empty () ? Stmt () : Block::make (result);
10771085}
10781086
1079- Stmt LowererImpl::declareScalarVariable (TensorVar var, bool zero) {
1087+ Stmt LowererImpl::defineScalarVariable (TensorVar var, bool zero) {
10801088 Datatype type = var.getType ().getDataType ();
10811089 Expr varValueIR = Var::make (var.getName () + " _val" , type, false , false );
10821090 Expr init = (zero) ? ir::Literal::zero (type)
0 commit comments