Skip to content

Commit c4abc33

Browse files
authored
Merge pull request #61807 from zoecarver/astgen-astnode
[astgen] Introduce ASTNode; update ResultType to be ASTNode; use ASTNode to implement BraceStmt correctly.
2 parents fb7b09e + bec4ff7 commit c4abc33

File tree

11 files changed

+184
-142
lines changed

11 files changed

+184
-142
lines changed

include/swift/AST/CASTBridging.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,13 @@ void *SwiftVarDecl_create(void *ctx, BridgedIdentifier _Nullable name,
128128
void *IfStmt_create(void *ctx, void *ifLoc, void *cond, void *_Nullable then, void *_Nullable elseLoc,
129129
void *_Nullable elseStmt);
130130

131-
void *BraceStmt_createExpr(void *ctx, void *lbloc, BridgedArrayRef elements, void *rbloc);
132-
void *BraceStmt_createStmt(void *ctx, void *lbloc, BridgedArrayRef elements, void *rbloc);
131+
struct ASTNodeBridged {
132+
void *ptr;
133+
_Bool isExpr; // Must be expr or stmt.
134+
};
135+
136+
void *BraceStmt_create(void *ctx, void *lbloc, BridgedArrayRef elements, void *rbloc);
137+
void *BraceStmt_create(void *ctx, void *lbloc, BridgedArrayRef elements, void *rbloc);
133138

134139
void *BridgedSourceLoc_advanced(void *loc, long len);
135140

lib/AST/CASTBridging.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -160,17 +160,16 @@ void *IfStmt_create(void *ctx, void *ifLoc, void *cond, void *_Nullable then, vo
160160
getSourceLocFromPointer(elseLoc), (Stmt *)elseStmt, None, Context);
161161
}
162162

163-
void *BraceStmt_createExpr(void *ctx, void *lbloc, BridgedArrayRef elements, void *rbloc) {
164-
ASTContext &Context = *static_cast<ASTContext *>(ctx);
165-
return BraceStmt::create(Context, getSourceLocFromPointer(lbloc),
166-
getArrayRef<ASTNode>(elements),
167-
getSourceLocFromPointer(rbloc));
168-
}
169-
170-
void *BraceStmt_createStmt(void *ctx, void *lbloc, BridgedArrayRef elements, void *rbloc) {
163+
void *BraceStmt_create(void *ctx, void *lbloc, BridgedArrayRef elements, void *rbloc) {
171164
llvm::SmallVector<ASTNode, 6> nodes;
172-
for (auto stmt : getArrayRef<Stmt *>(elements)) {
173-
nodes.push_back(stmt);
165+
for (auto node : getArrayRef<ASTNodeBridged>(elements)) {
166+
if (node.isExpr) {
167+
auto expr = (Expr *)node.ptr;
168+
nodes.push_back(expr);
169+
} else {
170+
auto stmt = (Stmt *)node.ptr;
171+
nodes.push_back(stmt);
172+
}
174173
}
175174

176175
ASTContext &Context = *static_cast<ASTContext *>(ctx);

lib/ASTGen/Sources/ASTGen/ASTGen.swift

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,37 @@ extension UnsafePointer {
1717
}
1818
}
1919

20+
enum ASTNode {
21+
case decl(UnsafeMutableRawPointer)
22+
case stmt(UnsafeMutableRawPointer)
23+
case expr(UnsafeMutableRawPointer)
24+
case type(UnsafeMutableRawPointer)
25+
26+
var rawValue: UnsafeMutableRawPointer {
27+
switch self {
28+
case .decl(let ptr):
29+
return ptr
30+
case .stmt(let ptr):
31+
return ptr
32+
case .expr(let ptr):
33+
return ptr
34+
case .type(let ptr):
35+
return ptr
36+
}
37+
}
38+
39+
func bridged() -> ASTNodeBridged {
40+
switch self {
41+
case .expr(let e):
42+
return ASTNodeBridged(ptr: e, isExpr: true)
43+
case .stmt(let s):
44+
return ASTNodeBridged(ptr: s, isExpr: false)
45+
default:
46+
fatalError("Must be expr or stmt.")
47+
}
48+
}
49+
}
50+
2051
/// Little utility wrapper that lets us have some mutable state within
2152
/// immutable structs, and is therefore pretty evil.
2253
@propertyWrapper
@@ -29,6 +60,8 @@ class Boxed<Value> {
2960
}
3061

3162
struct ASTGenVisitor: SyntaxTransformVisitor {
63+
typealias ResultType = ASTNode
64+
3265
let ctx: UnsafeMutableRawPointer
3366
let base: UnsafePointer<UInt8>
3467

@@ -41,11 +74,11 @@ struct ASTGenVisitor: SyntaxTransformVisitor {
4174
// }
4275

4376
@_disfavoredOverload
44-
public func visit(_ node: SourceFileSyntax) -> UnsafeMutableRawPointer {
77+
public func visit(_ node: SourceFileSyntax) -> ASTNode {
4578
fatalError("Use other overload.")
4679
}
4780

48-
public func visitAny(_ node: Syntax) -> UnsafeMutableRawPointer {
81+
public func visitAny(_ node: Syntax) -> ASTNode {
4982
fatalError("Not implemented.")
5083
}
5184

@@ -55,13 +88,15 @@ struct ASTGenVisitor: SyntaxTransformVisitor {
5588

5689
for element in node.statements {
5790
let swiftASTNodes = visit(element)
58-
if element.item.is(StmtSyntax.self) {
59-
out.append(SwiftTopLevelCodeDecl_createStmt(ctx, declContext, loc, swiftASTNodes, loc))
60-
} else if element.item.is(ExprSyntax.self) {
61-
out.append(SwiftTopLevelCodeDecl_createExpr(ctx, declContext, loc, swiftASTNodes, loc))
62-
} else {
63-
assert(element.item.is(DeclSyntax.self))
64-
out.append(swiftASTNodes)
91+
switch swiftASTNodes {
92+
case .decl(let d):
93+
out.append(d)
94+
case .stmt(let s):
95+
out.append(SwiftTopLevelCodeDecl_createStmt(ctx, declContext, loc, s, loc))
96+
case .expr(let e):
97+
out.append(SwiftTopLevelCodeDecl_createExpr(ctx, declContext, loc, e, loc))
98+
case .type(_):
99+
fatalError("Type should not exist at top level.")
65100
}
66101
}
67102

lib/ASTGen/Sources/ASTGen/Decls.swift

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,35 +4,37 @@ import SwiftSyntax
44
import CASTBridging
55

66
extension ASTGenVisitor {
7-
public func visit(_ node: TypealiasDeclSyntax) -> UnsafeMutableRawPointer {
7+
public func visit(_ node: TypealiasDeclSyntax) -> ASTNode {
88
let aliasLoc = self.base.advanced(by: node.typealiasKeyword.position.utf8Offset).raw
99
let equalLoc = self.base.advanced(by: node.initializer.equal.position.utf8Offset).raw
1010
var nameText = node.identifier.text
1111
let name = nameText.withUTF8 { buf in
1212
return SwiftASTContext_getIdentifier(ctx, buf.baseAddress, buf.count)
1313
}
1414
let nameLoc = self.base.advanced(by: node.identifier.position.utf8Offset).raw
15-
let genericParams = node.genericParameterClause.map(self.visit)
15+
let genericParams = node.genericParameterClause.map(self.visit).map { $0.rawValue }
1616
let out = TypeAliasDecl_create(self.ctx, self.declContext, aliasLoc, equalLoc, name, nameLoc, genericParams)
1717

1818
let oldDeclContext = declContext
1919
declContext = out.declContext
2020
defer { declContext = oldDeclContext }
2121

22-
let underlying = self.visit(node.initializer.value)
22+
let underlying = self.visit(node.initializer.value).rawValue
2323
TypeAliasDecl_setUnderlyingTypeRepr(out.nominalDecl, underlying)
2424

25-
return out.decl
25+
return .decl(out.decl)
2626
}
2727

28-
public func visit(_ node: StructDeclSyntax) -> UnsafeMutableRawPointer {
28+
public func visit(_ node: StructDeclSyntax) -> ASTNode {
2929
let loc = self.base.advanced(by: node.position.utf8Offset).raw
3030
var nameText = node.identifier.text
3131
let name = nameText.withUTF8 { buf in
3232
return SwiftASTContext_getIdentifier(ctx, buf.baseAddress, buf.count)
3333
}
3434

35-
let genericParams = node.genericParameterClause.map(self.visit)
35+
let genericParams = node.genericParameterClause
36+
.map(self.visit)
37+
.map { $0.rawValue }
3638
let out = StructDecl_create(ctx, loc, name, loc, genericParams, declContext)
3739
let oldDeclContext = declContext
3840
declContext = out.declContext
@@ -42,10 +44,10 @@ extension ASTGenVisitor {
4244
NominalTypeDecl_setMembers(out.nominalDecl, ref)
4345
}
4446

45-
return out.decl
47+
return .decl(out.decl)
4648
}
4749

48-
public func visit(_ node: ClassDeclSyntax) -> UnsafeMutableRawPointer {
50+
public func visit(_ node: ClassDeclSyntax) -> ASTNode {
4951
let loc = self.base.advanced(by: node.position.utf8Offset).raw
5052
var nameText = node.identifier.text
5153
let name = nameText.withUTF8 { buf in
@@ -61,31 +63,22 @@ extension ASTGenVisitor {
6163
NominalTypeDecl_setMembers(out.nominalDecl, ref)
6264
}
6365

64-
return out.decl
66+
return .decl(out.decl)
6567
}
6668

67-
public func visit(_ node: VariableDeclSyntax) -> UnsafeMutableRawPointer {
68-
let pattern = visit(node.bindings.first!.pattern)
69-
let initializer = visit(node.bindings.first!.initializer!)
69+
public func visit(_ node: VariableDeclSyntax) -> ASTNode {
70+
let pattern = visit(node.bindings.first!.pattern).rawValue
71+
let initializer = visit(node.bindings.first!.initializer!).rawValue
7072

7173
let loc = self.base.advanced(by: node.position.utf8Offset).raw
7274
let isStateic = false // TODO: compute this
7375
let isLet = node.letOrVarKeyword.tokenKind == .letKeyword
7476

7577
// TODO: don't drop "initializer" on the floor.
76-
return SwiftVarDecl_create(ctx, nil, loc, isStateic, isLet, declContext)
78+
return .decl(SwiftVarDecl_create(ctx, nil, loc, isStateic, isLet, declContext))
7779
}
7880

79-
public func visit(_ node: CodeBlockSyntax) -> UnsafeMutableRawPointer {
80-
let statements = node.statements.map(self.visit)
81-
let loc = self.base.advanced(by: node.position.utf8Offset).raw
82-
83-
return statements.withBridgedArrayRef { ref in
84-
BraceStmt_createStmt(ctx, loc, ref, loc)
85-
}
86-
}
87-
88-
public func visit(_ node: FunctionParameterSyntax) -> UnsafeMutableRawPointer {
81+
public func visit(_ node: FunctionParameterSyntax) -> ASTNode {
8982
let loc = self.base.advanced(by: node.position.utf8Offset).raw
9083

9184
let firstName: UnsafeMutableRawPointer?
@@ -109,34 +102,34 @@ extension ASTGenVisitor {
109102
secondName = nil
110103
}
111104

112-
return ParamDecl_create(ctx, loc, loc, firstName, loc, secondName, declContext)
105+
return .decl(ParamDecl_create(ctx, loc, loc, firstName, loc, secondName, declContext))
113106
}
114107

115-
public func visit(_ node: FunctionDeclSyntax) -> UnsafeMutableRawPointer {
108+
public func visit(_ node: FunctionDeclSyntax) -> ASTNode {
116109
let loc = self.base.advanced(by: node.position.utf8Offset).raw
117110

118111
var nameText = node.identifier.text
119112
let name = nameText.withUTF8 { buf in
120113
return SwiftASTContext_getIdentifier(ctx, buf.baseAddress, buf.count)
121114
}
122115

123-
let body: UnsafeMutableRawPointer?
116+
let body: ASTNode?
124117
if let nodeBody = node.body {
125118
body = visit(nodeBody)
126119
} else {
127120
body = nil
128121
}
129122

130-
let returnType: UnsafeMutableRawPointer?
123+
let returnType: ASTNode?
131124
if let output = node.signature.output {
132125
returnType = visit(output.returnType)
133126
} else {
134127
returnType = nil
135128
}
136129

137130
let params = node.signature.input.parameterList.map { visit($0) }
138-
return params.withBridgedArrayRef { ref in
139-
FuncDecl_create(ctx, loc, false, loc, name, loc, false, nil, false, nil, loc, ref, loc, body, returnType, declContext)
140-
}
131+
return .decl(params.withBridgedArrayRef { ref in
132+
FuncDecl_create(ctx, loc, false, loc, name, loc, false, nil, false, nil, loc, ref, loc, body?.rawValue, returnType?.rawValue, declContext)
133+
})
141134
}
142135
}

lib/ASTGen/Sources/ASTGen/Exprs.swift

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,74 +4,74 @@ import SwiftSyntax
44
import CASTBridging
55

66
extension ASTGenVisitor {
7-
public func visit(_ node: ClosureExprSyntax) -> UnsafeMutableRawPointer {
8-
let statements = node.statements.map(self.visit)
7+
public func visit(_ node: ClosureExprSyntax) -> ASTNode {
8+
let statements = node.statements.map(self.visit).map { $0.bridged() }
99
let loc = self.base.advanced(by: node.position.utf8Offset).raw
1010

1111
let body = statements.withBridgedArrayRef { ref in
12-
BraceStmt_createExpr(ctx, loc, ref, loc)
12+
BraceStmt_create(ctx, loc, ref, loc)
1313
}
1414

15-
return ClosureExpr_create(ctx, body, declContext)
15+
return .expr(ClosureExpr_create(ctx, body, declContext))
1616
}
1717

18-
public func visit(_ node: FunctionCallExprSyntax) -> UnsafeMutableRawPointer {
18+
public func visit(_ node: FunctionCallExprSyntax) -> ASTNode {
1919
// Transform the trailing closure into an argument.
2020
if let trailingClosure = node.trailingClosure {
2121
let tupleElement = TupleExprElementSyntax(label: nil, colon: nil, expression: ExprSyntax(trailingClosure), trailingComma: nil)
2222

2323
return visit(node.addArgument(tupleElement).withTrailingClosure(nil))
2424
}
2525

26-
let args = visit(node.argumentList)
26+
let args = visit(node.argumentList).rawValue
2727
// TODO: hack
28-
let callee = visit(node.calledExpression)
28+
let callee = visit(node.calledExpression).rawValue
2929

30-
return SwiftFunctionCallExpr_create(self.ctx, callee, args)
30+
return .expr(SwiftFunctionCallExpr_create(self.ctx, callee, args))
3131
}
3232

33-
public func visit(_ node: IdentifierExprSyntax) -> UnsafeMutableRawPointer {
33+
public func visit(_ node: IdentifierExprSyntax) -> ASTNode {
3434
let loc = self.base.advanced(by: node.position.utf8Offset).raw
3535

3636
var text = node.identifier.text
3737
let id = text.withUTF8 { buf in
3838
return SwiftASTContext_getIdentifier(ctx, buf.baseAddress, buf.count)
3939
}
4040

41-
return SwiftIdentifierExpr_create(ctx, id, loc)
41+
return .expr(SwiftIdentifierExpr_create(ctx, id, loc))
4242
}
4343

44-
public func visit(_ node: IdentifierPatternSyntax) -> UnsafeMutableRawPointer {
44+
public func visit(_ node: IdentifierPatternSyntax) -> ASTNode {
4545
let loc = self.base.advanced(by: node.position.utf8Offset).raw
4646

4747
var text = node.identifier.text
4848
let id = text.withUTF8 { buf in
4949
return SwiftASTContext_getIdentifier(ctx, buf.baseAddress, buf.count)
5050
}
5151

52-
return SwiftIdentifierExpr_create(ctx, id, loc)
52+
return .expr(SwiftIdentifierExpr_create(ctx, id, loc))
5353
}
5454

55-
public func visit(_ node: MemberAccessExprSyntax) -> UnsafeMutableRawPointer {
55+
public func visit(_ node: MemberAccessExprSyntax) -> ASTNode {
5656
let loc = self.base.advanced(by: node.position.utf8Offset).raw
57-
let base = visit(node.base!)
57+
let base = visit(node.base!).rawValue
5858
var nameText = node.name.text
5959
let name = nameText.withUTF8 { buf in
6060
return SwiftASTContext_getIdentifier(ctx, buf.baseAddress, buf.count)
6161
}
6262

63-
return UnresolvedDotExpr_create(ctx, base, loc, name, loc)
63+
return .expr(UnresolvedDotExpr_create(ctx, base, loc, name, loc))
6464
}
6565

66-
public func visit(_ node: TupleExprElementListSyntax) -> UnsafeMutableRawPointer {
67-
let elements = node.map(self.visit)
66+
public func visit(_ node: TupleExprElementListSyntax) -> ASTNode {
67+
let elements = node.map(self.visit).map { $0.rawValue }
6868

6969
// TODO: find correct paren locs.
7070
let lParenLoc = self.base.advanced(by: node.position.utf8Offset).raw
7171
let rParenLoc = self.base.advanced(by: node.position.utf8Offset).raw
7272

73-
return elements.withBridgedArrayRef { elementsRef in
73+
return .expr(elements.withBridgedArrayRef { elementsRef in
7474
SwiftTupleExpr_create(self.ctx, lParenLoc, elementsRef, rParenLoc)
75-
}
75+
})
7676
}
7777
}

0 commit comments

Comments
 (0)