@@ -50,62 +50,177 @@ struct CapturedValueInfo {
50
50
return
51
51
}
52
52
53
- // Potentially get the name of the type comprising the current lexical
54
- // context (i.e. whatever `Self` is.)
55
- lazy var lexicalContext = context. lexicalContext
56
- lazy var typeNameOfLexicalContext = {
57
- let lexicalContext = lexicalContext. drop { !$0. isProtocol ( ( any DeclGroupSyntax ) . self) }
58
- return context. type ( ofLexicalContext: lexicalContext)
59
- } ( )
53
+ if let ( expr, type) = Self . _inferExpressionAndType ( of: capture, in: context) {
54
+ self . expression = expr
55
+ self . type = type
56
+ } else {
57
+ // Not enough contextual information to derive the type here.
58
+ context. diagnose ( . typeOfCaptureIsAmbiguous( capture) )
59
+ }
60
+ }
60
61
62
+ /// Infer the captured expression and the type of a closure capture list item.
63
+ ///
64
+ /// - Parameters:
65
+ /// - capture: The closure capture list item to inspect.
66
+ /// - context: The macro context in which the expression is being parsed.
67
+ ///
68
+ /// - Returns: A tuple containing the expression and type of `capture`, or
69
+ /// `nil` if they could not be inferred.
70
+ private static func _inferExpressionAndType( of capture: ClosureCaptureSyntax , in context: some MacroExpansionContext ) -> ( ExprSyntax , TypeSyntax ) ? {
61
71
if let initializer = capture. initializer {
62
72
// Found an initializer clause. Extract the expression it captures.
63
- self . expression = removeParentheses ( from: initializer. value) ?? initializer. value
73
+ let finder = _ExprTypeFinder ( in: context)
74
+ finder. walk ( initializer. value)
75
+ if let inferredType = finder. inferredType {
76
+ return ( initializer. value, inferredType)
77
+ }
78
+ } else if capture. name. tokenKind == . keyword( . self ) ,
79
+ let typeNameOfLexicalContext = Self . _inferSelf ( from: context) {
80
+ // Capturing self.
81
+ return ( ExprSyntax ( DeclReferenceExprSyntax ( baseName: . keyword( . self ) ) ) , typeNameOfLexicalContext)
82
+ } else if let parameterType = Self . _findTypeOfParameter ( named: capture. name, in: context. lexicalContext) {
83
+ return ( ExprSyntax ( DeclReferenceExprSyntax ( baseName: capture. name. trimmed) ) , parameterType)
84
+ }
85
+
86
+ return nil
87
+ }
88
+
89
+ private final class _ExprTypeFinder < C> : SyntaxAnyVisitor where C: MacroExpansionContext {
90
+ var context : C
91
+
92
+ /// The type that was inferred from the visited syntax tree, if any.
93
+ ///
94
+ /// This type has not been fixed up yet. Use ``inferredType`` for the final
95
+ /// derived type.
96
+ private var _inferredType : TypeSyntax ?
97
+
98
+ /// Whether or not the inferred type has been made optional by e.g. `try?`.
99
+ private var _needsOptionalApplied = false
100
+
101
+ /// The type that was inferred from the visited syntax tree, if any.
102
+ var inferredType : TypeSyntax ? {
103
+ _inferredType. flatMap { inferredType in
104
+ if inferredType. isSome || inferredType. isAny {
105
+ // `some` and `any` types are not concrete and cannot be inferred.
106
+ nil
107
+ } else if _needsOptionalApplied {
108
+ TypeSyntax ( OptionalTypeSyntax ( wrappedType: inferredType. trimmed) )
109
+ } else {
110
+ inferredType
111
+ }
112
+ }
113
+ }
114
+
115
+ init ( in context: C ) {
116
+ self . context = context
117
+ super. init ( viewMode: . sourceAccurate)
118
+ }
119
+
120
+ override func visitAny( _ node: Syntax ) -> SyntaxVisitorContinueKind {
121
+ if inferredType != nil {
122
+ // Another part of the syntax tree has already provided a type. Stop.
123
+ return . skipChildren
124
+ }
64
125
65
- // Find the 'as' clause so we can determine the type of the captured value.
66
- if let asExpr = self . expression. as ( AsExprSyntax . self) {
67
- self . type = if asExpr. questionOrExclamationMark? . tokenKind == . postfixQuestionMark {
126
+ switch node. kind {
127
+ case . asExpr:
128
+ let asExpr = node. cast ( AsExprSyntax . self)
129
+ if let type = asExpr. type. as ( IdentifierTypeSyntax . self) , type. name. tokenKind == . keyword( . Self) {
130
+ // `Self` should resolve to the lexical context's type.
131
+ _inferredType = CapturedValueInfo . _inferSelf ( from: context)
132
+ } else if asExpr. questionOrExclamationMark? . tokenKind == . postfixQuestionMark {
68
133
// If the caller is using as?, make the type optional.
69
- TypeSyntax ( OptionalTypeSyntax ( wrappedType: asExpr. type. trimmed) )
134
+ _inferredType = TypeSyntax ( OptionalTypeSyntax ( wrappedType: asExpr. type. trimmed) )
70
135
} else {
71
- asExpr. type
136
+ _inferredType = asExpr. type
72
137
}
73
- } else if let selfExpr = self . expression. as ( DeclReferenceExprSyntax . self) ,
74
- selfExpr. baseName. tokenKind == . keyword( . self ) ,
75
- selfExpr. argumentNames == nil ,
76
- let typeNameOfLexicalContext {
77
- // Copying self.
78
- self . type = typeNameOfLexicalContext
79
- } else {
80
- // Handle literals. Any other types are ambiguous.
81
- switch self . expression. kind {
82
- case . integerLiteralExpr:
83
- self . type = TypeSyntax ( IdentifierTypeSyntax ( name: . identifier( " IntegerLiteralType " ) ) )
84
- case . floatLiteralExpr:
85
- self . type = TypeSyntax ( IdentifierTypeSyntax ( name: . identifier( " FloatLiteralType " ) ) )
86
- case . booleanLiteralExpr:
87
- self . type = TypeSyntax ( IdentifierTypeSyntax ( name: . identifier( " BooleanLiteralType " ) ) )
88
- case . stringLiteralExpr, . simpleStringLiteralExpr:
89
- self . type = TypeSyntax ( IdentifierTypeSyntax ( name: . identifier( " StringLiteralType " ) ) )
90
- default :
91
- context. diagnose ( . typeOfCaptureIsAmbiguous( capture, initializedWith: initializer) )
138
+ return . skipChildren
139
+
140
+ case . awaitExpr, . unsafeExpr:
141
+ // These effect keywords do not affect the type of the expression.
142
+ return . visitChildren
143
+
144
+ case . tryExpr:
145
+ let tryExpr = node. cast ( TryExprSyntax . self)
146
+ if tryExpr. questionOrExclamationMark? . tokenKind == . postfixQuestionMark {
147
+ // The resulting type from the inner expression will be optionalized.
148
+ _needsOptionalApplied = true
92
149
}
93
- }
150
+ return . visitChildren
94
151
95
- } else if capture. name. tokenKind == . keyword( . self ) ,
96
- let typeNameOfLexicalContext {
97
- // Capturing self.
98
- self . expression = " self "
99
- self . type = typeNameOfLexicalContext
100
- } else if let parameterType = Self . _findTypeOfParameter ( named: capture. name, in: lexicalContext) {
101
- self . expression = ExprSyntax ( DeclReferenceExprSyntax ( baseName: capture. name. trimmed) )
102
- self . type = parameterType
103
- } else {
104
- // Not enough contextual information to derive the type here.
105
- context. diagnose ( . typeOfCaptureIsAmbiguous( capture) )
152
+ case . tupleExpr:
153
+ // If the tuple contains exactly one element, it's just parentheses
154
+ // around that expression.
155
+ let tupleExpr = node. cast ( TupleExprSyntax . self)
156
+ if tupleExpr. elements. count == 1 {
157
+ return . visitChildren
158
+ }
159
+
160
+ // Otherwise, we need to try to compose the type as a tuple type from
161
+ // the types of all elements in the tuple expression. Note that tuples
162
+ // do not conform to Sendable or Codable, so our current use of this
163
+ // code in exit tests will still diagnose an error, but the error ("must
164
+ // conform") will be more useful than "couldn't infer".
165
+ let elements = tupleExpr. elements. compactMap { element in
166
+ let finder = Self ( in: context)
167
+ finder. walk ( element. expression)
168
+ return finder. inferredType. map { type in
169
+ TupleTypeElementSyntax ( firstName: element. label? . trimmed, type: type. trimmed)
170
+ }
171
+ }
172
+ if elements. count == tupleExpr. elements. count {
173
+ _inferredType = TypeSyntax (
174
+ TupleTypeSyntax ( elements: TupleTypeElementListSyntax { elements } )
175
+ )
176
+ }
177
+ return . skipChildren
178
+
179
+ case . declReferenceExpr:
180
+ // If the reference is to `self` without any arguments, its type can be
181
+ // inferred from the lexical context.
182
+ let expr = node. cast ( DeclReferenceExprSyntax . self)
183
+ if expr. baseName. tokenKind == . keyword( . self ) , expr. argumentNames == nil {
184
+ _inferredType = CapturedValueInfo . _inferSelf ( from: context)
185
+ }
186
+ return . skipChildren
187
+
188
+ case . integerLiteralExpr:
189
+ _inferredType = TypeSyntax ( IdentifierTypeSyntax ( name: . identifier( " IntegerLiteralType " ) ) )
190
+ return . skipChildren
191
+
192
+ case . floatLiteralExpr:
193
+ _inferredType = TypeSyntax ( IdentifierTypeSyntax ( name: . identifier( " FloatLiteralType " ) ) )
194
+ return . skipChildren
195
+
196
+ case . booleanLiteralExpr:
197
+ _inferredType = TypeSyntax ( IdentifierTypeSyntax ( name: . identifier( " BooleanLiteralType " ) ) )
198
+ return . skipChildren
199
+
200
+ case . stringLiteralExpr, . simpleStringLiteralExpr:
201
+ _inferredType = TypeSyntax ( IdentifierTypeSyntax ( name: . identifier( " StringLiteralType " ) ) )
202
+ return . skipChildren
203
+
204
+ default :
205
+ // We don't know how to infer a type from this syntax node, so do not
206
+ // proceed further.
207
+ return . skipChildren
208
+ }
106
209
}
107
210
}
108
211
212
+ /// Get the type of `self` inferred from the given context.
213
+ ///
214
+ /// - Parameters:
215
+ /// - context: The macro context in which the expression is being parsed.
216
+ ///
217
+ /// - Returns: The type in `lexicalContext` corresponding to `Self`, or `nil`
218
+ /// if it could not be determined.
219
+ private static func _inferSelf( from context: some MacroExpansionContext ) -> TypeSyntax ? {
220
+ let lexicalContext = context. lexicalContext. drop { !$0. isProtocol ( ( any DeclGroupSyntax ) . self) }
221
+ return context. type ( ofLexicalContext: lexicalContext)
222
+ }
223
+
109
224
/// Find a function or closure parameter in the given lexical context with a
110
225
/// given name and return its type.
111
226
///
0 commit comments