Skip to content

Commit 66701fb

Browse files
authored
Improve type inference for exit test value captures. (#1163)
This PR refactors the (new) type inference logic for exit test capture lists to use a syntax visitor, which allows for the types of more complex expressions to be inferred. For example, previously the type of this capture would not be inferred: ```swift [x = try await f() as Int] ``` Even though the type (`Int`) is clearly present, because the `AsExprSyntax` is nested in an `AwaitExprSyntax` and then a `TryExprSyntax`. ### Checklist: - [x] Code and documentation should follow the style of the [Style Guide](https://github.com/apple/swift-testing/blob/main/Documentation/StyleGuide.md). - [x] If public symbols are renamed or modified, DocC references should be updated.
1 parent eeeffd4 commit 66701fb

File tree

4 files changed

+207
-45
lines changed

4 files changed

+207
-45
lines changed

Sources/TestingMacros/Support/Additions/TypeSyntaxProtocolAdditions.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ extension TypeSyntaxProtocol {
3636
.contains(.keyword(.some))
3737
}
3838

39+
/// Whether or not this type is `any T` or a type derived from such a type.
40+
var isAny: Bool {
41+
tokens(viewMode: .fixedUp).lazy
42+
.map(\.tokenKind)
43+
.contains(.keyword(.any))
44+
}
45+
3946
/// Check whether or not this type is named with the specified name and
4047
/// module.
4148
///

Sources/TestingMacros/Support/ClosureCaptureListParsing.swift

Lines changed: 159 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -50,62 +50,177 @@ struct CapturedValueInfo {
5050
return
5151
}
5252

53-
// Potentially get the name of the type comprising the current lexical
54-
// context (i.e. whatever `Self` is.)
55-
lazy var lexicalContext = context.lexicalContext
56-
lazy var typeNameOfLexicalContext = {
57-
let lexicalContext = lexicalContext.drop { !$0.isProtocol((any DeclGroupSyntax).self) }
58-
return context.type(ofLexicalContext: lexicalContext)
59-
}()
53+
if let (expr, type) = Self._inferExpressionAndType(of: capture, in: context) {
54+
self.expression = expr
55+
self.type = type
56+
} else {
57+
// Not enough contextual information to derive the type here.
58+
context.diagnose(.typeOfCaptureIsAmbiguous(capture))
59+
}
60+
}
6061

62+
/// Infer the captured expression and the type of a closure capture list item.
63+
///
64+
/// - Parameters:
65+
/// - capture: The closure capture list item to inspect.
66+
/// - context: The macro context in which the expression is being parsed.
67+
///
68+
/// - Returns: A tuple containing the expression and type of `capture`, or
69+
/// `nil` if they could not be inferred.
70+
private static func _inferExpressionAndType(of capture: ClosureCaptureSyntax, in context: some MacroExpansionContext) -> (ExprSyntax, TypeSyntax)? {
6171
if let initializer = capture.initializer {
6272
// Found an initializer clause. Extract the expression it captures.
63-
self.expression = removeParentheses(from: initializer.value) ?? initializer.value
73+
let finder = _ExprTypeFinder(in: context)
74+
finder.walk(initializer.value)
75+
if let inferredType = finder.inferredType {
76+
return (initializer.value, inferredType)
77+
}
78+
} else if capture.name.tokenKind == .keyword(.self),
79+
let typeNameOfLexicalContext = Self._inferSelf(from: context) {
80+
// Capturing self.
81+
return (ExprSyntax(DeclReferenceExprSyntax(baseName: .keyword(.self))), typeNameOfLexicalContext)
82+
} else if let parameterType = Self._findTypeOfParameter(named: capture.name, in: context.lexicalContext) {
83+
return (ExprSyntax(DeclReferenceExprSyntax(baseName: capture.name.trimmed)), parameterType)
84+
}
85+
86+
return nil
87+
}
88+
89+
private final class _ExprTypeFinder<C>: SyntaxAnyVisitor where C: MacroExpansionContext {
90+
var context: C
91+
92+
/// The type that was inferred from the visited syntax tree, if any.
93+
///
94+
/// This type has not been fixed up yet. Use ``inferredType`` for the final
95+
/// derived type.
96+
private var _inferredType: TypeSyntax?
97+
98+
/// Whether or not the inferred type has been made optional by e.g. `try?`.
99+
private var _needsOptionalApplied = false
100+
101+
/// The type that was inferred from the visited syntax tree, if any.
102+
var inferredType: TypeSyntax? {
103+
_inferredType.flatMap { inferredType in
104+
if inferredType.isSome || inferredType.isAny {
105+
// `some` and `any` types are not concrete and cannot be inferred.
106+
nil
107+
} else if _needsOptionalApplied {
108+
TypeSyntax(OptionalTypeSyntax(wrappedType: inferredType.trimmed))
109+
} else {
110+
inferredType
111+
}
112+
}
113+
}
114+
115+
init(in context: C) {
116+
self.context = context
117+
super.init(viewMode: .sourceAccurate)
118+
}
119+
120+
override func visitAny(_ node: Syntax) -> SyntaxVisitorContinueKind {
121+
if inferredType != nil {
122+
// Another part of the syntax tree has already provided a type. Stop.
123+
return .skipChildren
124+
}
64125

65-
// Find the 'as' clause so we can determine the type of the captured value.
66-
if let asExpr = self.expression.as(AsExprSyntax.self) {
67-
self.type = if asExpr.questionOrExclamationMark?.tokenKind == .postfixQuestionMark {
126+
switch node.kind {
127+
case .asExpr:
128+
let asExpr = node.cast(AsExprSyntax.self)
129+
if let type = asExpr.type.as(IdentifierTypeSyntax.self), type.name.tokenKind == .keyword(.Self) {
130+
// `Self` should resolve to the lexical context's type.
131+
_inferredType = CapturedValueInfo._inferSelf(from: context)
132+
} else if asExpr.questionOrExclamationMark?.tokenKind == .postfixQuestionMark {
68133
// If the caller is using as?, make the type optional.
69-
TypeSyntax(OptionalTypeSyntax(wrappedType: asExpr.type.trimmed))
134+
_inferredType = TypeSyntax(OptionalTypeSyntax(wrappedType: asExpr.type.trimmed))
70135
} else {
71-
asExpr.type
136+
_inferredType = asExpr.type
72137
}
73-
} else if let selfExpr = self.expression.as(DeclReferenceExprSyntax.self),
74-
selfExpr.baseName.tokenKind == .keyword(.self),
75-
selfExpr.argumentNames == nil,
76-
let typeNameOfLexicalContext {
77-
// Copying self.
78-
self.type = typeNameOfLexicalContext
79-
} else {
80-
// Handle literals. Any other types are ambiguous.
81-
switch self.expression.kind {
82-
case .integerLiteralExpr:
83-
self.type = TypeSyntax(IdentifierTypeSyntax(name: .identifier("IntegerLiteralType")))
84-
case .floatLiteralExpr:
85-
self.type = TypeSyntax(IdentifierTypeSyntax(name: .identifier("FloatLiteralType")))
86-
case .booleanLiteralExpr:
87-
self.type = TypeSyntax(IdentifierTypeSyntax(name: .identifier("BooleanLiteralType")))
88-
case .stringLiteralExpr, .simpleStringLiteralExpr:
89-
self.type = TypeSyntax(IdentifierTypeSyntax(name: .identifier("StringLiteralType")))
90-
default:
91-
context.diagnose(.typeOfCaptureIsAmbiguous(capture, initializedWith: initializer))
138+
return .skipChildren
139+
140+
case .awaitExpr, .unsafeExpr:
141+
// These effect keywords do not affect the type of the expression.
142+
return .visitChildren
143+
144+
case .tryExpr:
145+
let tryExpr = node.cast(TryExprSyntax.self)
146+
if tryExpr.questionOrExclamationMark?.tokenKind == .postfixQuestionMark {
147+
// The resulting type from the inner expression will be optionalized.
148+
_needsOptionalApplied = true
92149
}
93-
}
150+
return .visitChildren
94151

95-
} else if capture.name.tokenKind == .keyword(.self),
96-
let typeNameOfLexicalContext {
97-
// Capturing self.
98-
self.expression = "self"
99-
self.type = typeNameOfLexicalContext
100-
} else if let parameterType = Self._findTypeOfParameter(named: capture.name, in: lexicalContext) {
101-
self.expression = ExprSyntax(DeclReferenceExprSyntax(baseName: capture.name.trimmed))
102-
self.type = parameterType
103-
} else {
104-
// Not enough contextual information to derive the type here.
105-
context.diagnose(.typeOfCaptureIsAmbiguous(capture))
152+
case .tupleExpr:
153+
// If the tuple contains exactly one element, it's just parentheses
154+
// around that expression.
155+
let tupleExpr = node.cast(TupleExprSyntax.self)
156+
if tupleExpr.elements.count == 1 {
157+
return .visitChildren
158+
}
159+
160+
// Otherwise, we need to try to compose the type as a tuple type from
161+
// the types of all elements in the tuple expression. Note that tuples
162+
// do not conform to Sendable or Codable, so our current use of this
163+
// code in exit tests will still diagnose an error, but the error ("must
164+
// conform") will be more useful than "couldn't infer".
165+
let elements = tupleExpr.elements.compactMap { element in
166+
let finder = Self(in: context)
167+
finder.walk(element.expression)
168+
return finder.inferredType.map { type in
169+
TupleTypeElementSyntax(firstName: element.label?.trimmed, type: type.trimmed)
170+
}
171+
}
172+
if elements.count == tupleExpr.elements.count {
173+
_inferredType = TypeSyntax(
174+
TupleTypeSyntax(elements: TupleTypeElementListSyntax { elements })
175+
)
176+
}
177+
return .skipChildren
178+
179+
case .declReferenceExpr:
180+
// If the reference is to `self` without any arguments, its type can be
181+
// inferred from the lexical context.
182+
let expr = node.cast(DeclReferenceExprSyntax.self)
183+
if expr.baseName.tokenKind == .keyword(.self), expr.argumentNames == nil {
184+
_inferredType = CapturedValueInfo._inferSelf(from: context)
185+
}
186+
return .skipChildren
187+
188+
case .integerLiteralExpr:
189+
_inferredType = TypeSyntax(IdentifierTypeSyntax(name: .identifier("IntegerLiteralType")))
190+
return .skipChildren
191+
192+
case .floatLiteralExpr:
193+
_inferredType = TypeSyntax(IdentifierTypeSyntax(name: .identifier("FloatLiteralType")))
194+
return .skipChildren
195+
196+
case .booleanLiteralExpr:
197+
_inferredType = TypeSyntax(IdentifierTypeSyntax(name: .identifier("BooleanLiteralType")))
198+
return .skipChildren
199+
200+
case .stringLiteralExpr, .simpleStringLiteralExpr:
201+
_inferredType = TypeSyntax(IdentifierTypeSyntax(name: .identifier("StringLiteralType")))
202+
return .skipChildren
203+
204+
default:
205+
// We don't know how to infer a type from this syntax node, so do not
206+
// proceed further.
207+
return .skipChildren
208+
}
106209
}
107210
}
108211

212+
/// Get the type of `self` inferred from the given context.
213+
///
214+
/// - Parameters:
215+
/// - context: The macro context in which the expression is being parsed.
216+
///
217+
/// - Returns: The type in `lexicalContext` corresponding to `Self`, or `nil`
218+
/// if it could not be determined.
219+
private static func _inferSelf(from context: some MacroExpansionContext) -> TypeSyntax? {
220+
let lexicalContext = context.lexicalContext.drop { !$0.isProtocol((any DeclGroupSyntax).self) }
221+
return context.type(ofLexicalContext: lexicalContext)
222+
}
223+
109224
/// Find a function or closure parameter in the given lexical context with a
110225
/// given name and return its type.
111226
///

Tests/TestingMacrosTests/ConditionMacroTests.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,10 @@ struct ConditionMacroTests {
461461
"Type of captured value 'a' is ambiguous",
462462
"#expectExitTest(processExitsWith: x) { [a = b] in }":
463463
"Type of captured value 'a' is ambiguous",
464+
"#expectExitTest(processExitsWith: x) { [a = b as any T] in }":
465+
"Type of captured value 'a' is ambiguous",
466+
"#expectExitTest(processExitsWith: x) { [a = b as some T] in }":
467+
"Type of captured value 'a' is ambiguous",
464468
"struct S<T> { func f() { #expectExitTest(processExitsWith: x) { [a] in } } }":
465469
"Cannot call macro ''#expectExitTest(processExitsWith:_:)'' within generic structure 'S'",
466470
]

Tests/TestingTests/ExitTestTests.swift

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,9 +407,10 @@ private import _TestingInternals
407407

408408
@Test("self in capture list")
409409
func captureListWithSelf() async {
410-
await #expect(processExitsWith: .success) { [self, x = self] in
410+
await #expect(processExitsWith: .success) { [self, x = self, y = self as Self] in
411411
#expect(self.property == 456)
412412
#expect(x.property == 456)
413+
#expect(y.property == 456)
413414
}
414415
}
415416
}
@@ -506,6 +507,41 @@ private import _TestingInternals
506507
}
507508
}
508509

510+
@Test("Capturing an optional value")
511+
func captureListWithOptionalValue() async throws {
512+
await #expect(processExitsWith: .success) { [x = nil as Int?] in
513+
#expect(x != 1)
514+
}
515+
await #expect(processExitsWith: .success) { [x = (0 as Any) as? String] in
516+
#expect(x == nil)
517+
}
518+
}
519+
520+
@Test("Capturing an effectful expression")
521+
func captureListWithEffectfulExpression() async throws {
522+
func f() async throws -> Int { 0 }
523+
try await #require(processExitsWith: .success) { [f = try await f() as Int] in
524+
#expect(f == 0)
525+
}
526+
try await #expect(processExitsWith: .success) { [f = f() as Int] in
527+
#expect(f == 0)
528+
}
529+
}
530+
531+
#if false // intentionally fails to compile
532+
@Test("Capturing a tuple")
533+
func captureListWithTuple() async throws {
534+
// A tuple whose elements conform to Codable does not itself conform to
535+
// Codable, so we cannot actually express this capture list in a way that
536+
// works with #expect().
537+
await #expect(processExitsWith: .success) { [x = (0 as Int, 1 as Double, "2" as String)] in
538+
#expect(x.0 == 0)
539+
#expect(x.1 == 1)
540+
#expect(x.2 == "2")
541+
}
542+
}
543+
#endif
544+
509545
#if false // intentionally fails to compile
510546
struct NonCodableValue {}
511547

0 commit comments

Comments
 (0)