Skip to content

Commit bec4ff7

Browse files
committed
[astgen] Use ASTNode to implement BraceStmt correctly.
1 parent 6b7123e commit bec4ff7

File tree

5 files changed

+31
-16
lines changed

5 files changed

+31
-16
lines changed

include/swift/AST/CASTBridging.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,13 @@ void *SwiftVarDecl_create(void *ctx, BridgedIdentifier _Nullable name,
128128
void *IfStmt_create(void *ctx, void *ifLoc, void *cond, void *_Nullable then, void *_Nullable elseLoc,
129129
void *_Nullable elseStmt);
130130

131-
void *BraceStmt_createExpr(void *ctx, void *lbloc, BridgedArrayRef elements, void *rbloc);
132-
void *BraceStmt_createStmt(void *ctx, void *lbloc, BridgedArrayRef elements, void *rbloc);
131+
struct ASTNodeBridged {
132+
void *ptr;
133+
_Bool isExpr; // Must be expr or stmt.
134+
};
135+
136+
void *BraceStmt_create(void *ctx, void *lbloc, BridgedArrayRef elements, void *rbloc);
137+
void *BraceStmt_create(void *ctx, void *lbloc, BridgedArrayRef elements, void *rbloc);
133138

134139
void *BridgedSourceLoc_advanced(void *loc, long len);
135140

lib/AST/CASTBridging.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -160,17 +160,16 @@ void *IfStmt_create(void *ctx, void *ifLoc, void *cond, void *_Nullable then, vo
160160
getSourceLocFromPointer(elseLoc), (Stmt *)elseStmt, None, Context);
161161
}
162162

163-
void *BraceStmt_createExpr(void *ctx, void *lbloc, BridgedArrayRef elements, void *rbloc) {
164-
ASTContext &Context = *static_cast<ASTContext *>(ctx);
165-
return BraceStmt::create(Context, getSourceLocFromPointer(lbloc),
166-
getArrayRef<ASTNode>(elements),
167-
getSourceLocFromPointer(rbloc));
168-
}
169-
170-
void *BraceStmt_createStmt(void *ctx, void *lbloc, BridgedArrayRef elements, void *rbloc) {
163+
void *BraceStmt_create(void *ctx, void *lbloc, BridgedArrayRef elements, void *rbloc) {
171164
llvm::SmallVector<ASTNode, 6> nodes;
172-
for (auto stmt : getArrayRef<Stmt *>(elements)) {
173-
nodes.push_back(stmt);
165+
for (auto node : getArrayRef<ASTNodeBridged>(elements)) {
166+
if (node.isExpr) {
167+
auto expr = (Expr *)node.ptr;
168+
nodes.push_back(expr);
169+
} else {
170+
auto stmt = (Stmt *)node.ptr;
171+
nodes.push_back(stmt);
172+
}
174173
}
175174

176175
ASTContext &Context = *static_cast<ASTContext *>(ctx);

lib/ASTGen/Sources/ASTGen/ASTGen.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,17 @@ enum ASTNode {
3535
return ptr
3636
}
3737
}
38+
39+
func bridged() -> ASTNodeBridged {
40+
switch self {
41+
case .expr(let e):
42+
return ASTNodeBridged(ptr: e, isExpr: true)
43+
case .stmt(let s):
44+
return ASTNodeBridged(ptr: s, isExpr: false)
45+
default:
46+
fatalError("Must be expr or stmt.")
47+
}
48+
}
3849
}
3950

4051
/// Little utility wrapper that lets us have some mutable state within

lib/ASTGen/Sources/ASTGen/Exprs.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ import CASTBridging
55

66
extension ASTGenVisitor {
77
public func visit(_ node: ClosureExprSyntax) -> ASTNode {
8-
let statements = node.statements.map(self.visit)
8+
let statements = node.statements.map(self.visit).map { $0.bridged() }
99
let loc = self.base.advanced(by: node.position.utf8Offset).raw
1010

1111
let body = statements.withBridgedArrayRef { ref in
12-
BraceStmt_createExpr(ctx, loc, ref, loc)
12+
BraceStmt_create(ctx, loc, ref, loc)
1313
}
1414

1515
return .expr(ClosureExpr_create(ctx, body, declContext))

lib/ASTGen/Sources/ASTGen/Stmts.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ import CASTBridging
55

66
extension ASTGenVisitor {
77
public func visit(_ node: CodeBlockSyntax) -> ASTNode {
8-
let statements = node.statements.map(self.visit)
8+
let statements = node.statements.map(self.visit).map { $0.bridged() }
99
let loc = self.base.advanced(by: node.position.utf8Offset).raw
1010

1111
return .stmt(statements.withBridgedArrayRef { ref in
12-
BraceStmt_createStmt(ctx, loc, ref, loc)
12+
BraceStmt_create(ctx, loc, ref, loc)
1313
})
1414
}
1515

0 commit comments

Comments
 (0)