Skip to content

Commit b29b033

Browse files
committed
Adds compound operator to assignments
1 parent 8ab0a12 commit b29b033

File tree

6 files changed

+34
-15
lines changed

6 files changed

+34
-15
lines changed

include/taco/index_notation/expr_nodes.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,14 +177,16 @@ struct FloatImmNode : public ImmExprNode {
177177

178178
// Tensor Index Expressions
179179
struct AssignmentNode : public TensorExprNode {
180-
AssignmentNode(const Access& lhs, const IndexExpr& rhs) : lhs(lhs), rhs(rhs){}
180+
AssignmentNode(const Access& lhs, const IndexExpr& rhs, const IndexExpr& op)
181+
: lhs(lhs), rhs(rhs), op(op) {}
181182

182183
void accept(ExprVisitorStrict* v) const {
183184
v->visit(this);
184185
}
185186

186187
Access lhs;
187188
IndexExpr rhs;
189+
IndexExpr op;
188190
};
189191

190192
struct ForallNode : public TensorExprNode {

include/taco/index_notation/index_notation.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,14 +178,14 @@ class Access : public IndexExpr {
178178
/// ```
179179
Assignment operator=(const IndexExpr&);
180180

181-
// Must override the default Access operator=, otherwise it is a copy.
181+
/// Must override the default Access operator=, otherwise it is a copy.
182182
Assignment operator=(const Access&);
183183

184184
/// Accumulate the result of an expression to a left-hand-side tensor access.
185185
/// ```
186186
/// a(i) += B(i,j) * c(j);
187187
/// ```
188-
void operator+=(const IndexExpr&);
188+
Assignment operator+=(const IndexExpr&);
189189

190190
private:
191191
const Node* getPtr() const;
@@ -224,7 +224,11 @@ std::ostream& operator<<(std::ostream&, const TensorExpr&);
224224
class Assignment : public TensorExpr {
225225
public:
226226
Assignment(const AssignmentNode*);
227-
Assignment(TensorVar tensor, std::vector<IndexVar> indices, IndexExpr expr);
227+
228+
/// Create an assignment. Can specify an optional operator `op` that turns the
229+
/// assignment into a compound assignment, e.g. `+=`.
230+
Assignment(TensorVar tensor, std::vector<IndexVar> indices, IndexExpr expr,
231+
IndexExpr op = IndexExpr());
228232

229233
Access getLhs() const;
230234
IndexExpr getRhs() const;

src/index_notation/expr_printer.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,20 @@ void ExprPrinter::visit(const ReductionNode* op) {
121121
}
122122

123123
void ExprPrinter::visit(const AssignmentNode* op) {
124+
struct OperatorName : ExprVisitor {
125+
std::string operatorName;
126+
std::string get(IndexExpr expr) {
127+
if (!expr.defined()) return "";
128+
expr.accept(this);
129+
return operatorName;
130+
}
131+
void visit(const BinaryExprNode* node) {
132+
operatorName = node->getOperatorString();
133+
}
134+
};
135+
124136
op->lhs.accept(this);
125-
os << " = ";
137+
os << " " << OperatorName().get(op->op) << "= ";
126138
op->rhs.accept(this);
127139
}
128140

src/index_notation/expr_rewriter.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,16 @@ void ExprRewriter::visit(const ReductionNode* op) {
9797
}
9898

9999
void ExprRewriter::visit(const AssignmentNode* op) {
100+
// A design decission is to not visit the rhs access expressions or the op,
101+
// as these are considered part of the assignment. When visiting access
102+
// expressions, therefore, we only visit read access expressions.
100103
IndexExpr rhs = rewrite(op->rhs);
101104
if (rhs == op->rhs) {
102105
texpr = op;
103106
}
104107
else {
105-
texpr = new AssignmentNode(op->lhs, rhs);
108+
texpr = new AssignmentNode(op->lhs, rhs, op->op);
106109
}
107-
108110
}
109111

110112

src/index_notation/index_notation.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,12 +242,13 @@ Assignment Access::operator=(const Access& expr) {
242242
return operator=(static_cast<IndexExpr>(expr));
243243
}
244244

245-
void Access::operator+=(const IndexExpr& expr) {
245+
Assignment Access::operator+=(const IndexExpr& expr) {
246246
TensorVar result = getTensorVar();
247247
taco_uassert(!result.getIndexExpr().defined()) << "Cannot reassign " <<result;
248248
// TODO: check that result format is dense. For now only support accumulation
249249
/// into dense. If it's not dense, then we can insert an operator split.
250250
const_cast<AccessNode*>(getPtr())->setIndexExpression(expr, true);
251+
return Assignment(result, result.getFreeVars(), expr, new AddNode);
251252
}
252253

253254

@@ -286,8 +287,8 @@ Assignment::Assignment(const AssignmentNode* n) : TensorExpr(n) {
286287
}
287288

288289
Assignment::Assignment(TensorVar tensor, vector<IndexVar> indices,
289-
IndexExpr expr)
290-
: Assignment(new AssignmentNode(Access(tensor, indices), expr)) {
290+
IndexExpr expr, IndexExpr op)
291+
: Assignment(new AssignmentNode(Access(tensor, indices), expr, op)) {
291292
}
292293

293294
Access Assignment::getLhs() const {

test/concrete-notation-tests.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,13 @@ TEST(concrete, where) {
2323
// std::cout << vecmul << std::endl;
2424
}
2525

26-
TEST(DISABLED_concrete, spmm) {
26+
TEST(concrete, spmm) {
2727
Type t(type<double>(), {3,3});
2828
TensorVar A("A", t, Sparse), B("B", t, Sparse), C("C", t, Sparse);
29-
TensorVar w("w", Type(type<double>(),{3}), Dense);
30-
31-
auto spmm = forall(i,
29+
TensorVar w("w", Type(type<double>(),{3}), Dense); auto spmm = forall(i,
3230
forall(k,
3331
where(forall(j, A(i,j) = w(j)),
34-
forall(j, w(j) = B(i,k)*C(k,j))
32+
forall(j, w(j) += B(i,k)*C(k,j))
3533
)
3634
)
3735
);

0 commit comments

Comments
 (0)