Skip to content

Commit 0bc58c6

Browse files
Merge pull request #369 from rohany/refactor-break
include,src: introduce a true break statement, rename current to continue
2 parents dafe2ba + f35573d commit 0bc58c6

File tree

11 files changed

+47
-15
lines changed

11 files changed

+47
-15
lines changed

include/taco/ir/ir.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,9 @@ enum class IRNodeType {
6565
BlankLine,
6666
Print,
6767
GetProperty,
68-
Break,
69-
Sort
68+
Continue,
69+
Sort,
70+
Break
7071
};
7172

7273
enum class TensorProperty {
@@ -719,7 +720,14 @@ struct BlankLine : public StmtNode<BlankLine> {
719720
static const IRNodeType _type_info = IRNodeType::BlankLine;
720721
};
721722

722-
/** Breaks current loop */
723+
/** Continues past current iteration of current loop */
724+
struct Continue : public StmtNode<Continue> {
725+
static Stmt make();
726+
727+
static const IRNodeType _type_info = IRNodeType::Continue;
728+
};
729+
730+
/** Breaks out of the current loop */
723731
struct Break : public StmtNode<Break> {
724732
static Stmt make();
725733

include/taco/ir/ir_printer.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,11 @@ class IRPrinter : public IRVisitorStrict {
6565
virtual void visit(const Free*);
6666
virtual void visit(const Comment*);
6767
virtual void visit(const BlankLine*);
68-
virtual void visit(const Break*);
68+
virtual void visit(const Continue*);
6969
virtual void visit(const Print*);
7070
virtual void visit(const GetProperty*);
7171
virtual void visit(const Sort*);
72+
virtual void visit(const Break*);
7273

7374
std::ostream &stream;
7475
int indent;

include/taco/ir/ir_rewriter.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,11 @@ class IRRewriter : public IRVisitorStrict {
6565
virtual void visit(const Free* op);
6666
virtual void visit(const Comment* op);
6767
virtual void visit(const BlankLine* op);
68-
virtual void visit(const Break* op);
68+
virtual void visit(const Continue* op);
6969
virtual void visit(const Print* op);
7070
virtual void visit(const GetProperty* op);
7171
virtual void visit(const Sort *op);
72+
virtual void visit(const Break *op);
7273
};
7374

7475
}}

include/taco/ir/ir_visitor.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,11 @@ struct Allocate;
4545
struct Free;
4646
struct Comment;
4747
struct BlankLine;
48-
struct Break;
48+
struct Continue;
4949
struct Print;
5050
struct GetProperty;
5151
struct Sort;
52+
struct Break;
5253

5354
/// Extend this class to visit every node in the IR.
5455
class IRVisitorStrict {
@@ -96,10 +97,11 @@ class IRVisitorStrict {
9697
virtual void visit(const Free*) = 0;
9798
virtual void visit(const Comment*) = 0;
9899
virtual void visit(const BlankLine*) = 0;
99-
virtual void visit(const Break*) = 0;
100+
virtual void visit(const Continue*) = 0;
100101
virtual void visit(const Print*) = 0;
101102
virtual void visit(const GetProperty*) = 0;
102103
virtual void visit(const Sort*) = 0;
104+
virtual void visit(const Break*) = 0;
103105
};
104106

105107

@@ -150,10 +152,11 @@ class IRVisitor : public IRVisitorStrict {
150152
virtual void visit(const Free* op);
151153
virtual void visit(const Comment* op);
152154
virtual void visit(const BlankLine* op);
153-
virtual void visit(const Break* op);
155+
virtual void visit(const Continue* op);
154156
virtual void visit(const Print* op);
155157
virtual void visit(const GetProperty* op);
156158
virtual void visit(const Sort* op);
159+
virtual void visit(const Break* op);
157160
};
158161

159162
}}

src/codegen/codegen_cuda.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1142,7 +1142,7 @@ void CodeGen_CUDA::visit(const Sqrt* op) {
11421142
stream << ")";
11431143
}
11441144

1145-
void CodeGen_CUDA::visit(const Break*) {
1145+
void CodeGen_CUDA::visit(const Continue*) {
11461146
doIndent();
11471147
if(!isHostFunction && deviceFunctionLoopDepth == 0) {
11481148
// can't break out of kernel

src/codegen/codegen_cuda.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class CodeGen_CUDA : public CodeGen {
4646
void visit(const Call*);
4747
void visit(const Store*);
4848
void visit(const Assign*);
49-
void visit(const Break*);
49+
void visit(const Continue*);
5050
void visit(const Free* op);
5151
std::string printDeviceFuncName(const std::vector<std::pair<std::string, Expr>> currentParameters, int index);
5252
void printDeviceFuncCall(const std::vector<std::pair<std::string, Expr>> currentParameters, Expr blockSize, int index, Expr gridSize);

src/ir/ir.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,11 @@ Stmt BlankLine::make() {
786786
return new BlankLine;
787787
}
788788

789+
// Continue
790+
Stmt Continue::make() {
791+
return new Continue;
792+
}
793+
789794
// Break
790795
Stmt Break::make() {
791796
return new Break;
@@ -954,14 +959,16 @@ template<> void StmtNode<Comment>::accept(IRVisitorStrict *v)
954959
const { v->visit((const Comment*)this); }
955960
template<> void StmtNode<BlankLine>::accept(IRVisitorStrict *v)
956961
const { v->visit((const BlankLine*)this); }
957-
template<> void StmtNode<Break>::accept(IRVisitorStrict *v)
958-
const { v->visit((const Break*)this); }
962+
template<> void StmtNode<Continue>::accept(IRVisitorStrict *v)
963+
const { v->visit((const Continue*)this); }
959964
template<> void StmtNode<Print>::accept(IRVisitorStrict *v)
960965
const { v->visit((const Print*)this); }
961966
template<> void ExprNode<GetProperty>::accept(IRVisitorStrict *v)
962967
const { v->visit((const GetProperty*)this); }
963968
template<> void StmtNode<Sort>::accept(IRVisitorStrict *v)
964969
const { v->visit((const Sort*)this); }
970+
template<> void StmtNode<Break>::accept(IRVisitorStrict *v)
971+
const { v->visit((const Break*)this); }
965972

966973
// printing methods
967974
std::ostream& operator<<(std::ostream& os, const Stmt& stmt) {

src/ir/ir_printer.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,9 +558,14 @@ void IRPrinter::visit(const BlankLine*) {
558558
stream << endl;
559559
}
560560

561+
void IRPrinter::visit(const Continue*) {
562+
doIndent();
563+
stream << "continue;" << endl;
564+
}
565+
561566
void IRPrinter::visit(const Break*) {
562567
doIndent();
563-
stream << "continue;" << endl; // TODO: add continue statement
568+
stream << "break;" << endl;
564569
}
565570

566571
void IRPrinter::visit(const Print* op) {

src/ir/ir_rewriter.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,10 @@ void IRRewriter::visit(const BlankLine* op) {
447447
stmt = op;
448448
}
449449

450+
void IRRewriter::visit(const Continue* op) {
451+
stmt = op;
452+
}
453+
450454
void IRRewriter::visit(const Break* op) {
451455
stmt = op;
452456
}

src/ir/ir_visitor.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,9 @@ void IRVisitor::visit(const Comment*) {
228228
void IRVisitor::visit(const BlankLine*) {
229229
}
230230

231+
void IRVisitor::visit(const Continue*) {
232+
}
233+
231234
void IRVisitor::visit(const Break*) {
232235
}
233236

0 commit comments

Comments
 (0)