Skip to content

Commit 09e34f2

Browse files
committed
[Macros] Apply generic arguments from attached macro custom attributes when
resolving the macro reference.
1 parent d6e4b70 commit 09e34f2

File tree

3 files changed

+80
-4
lines changed

3 files changed

+80
-4
lines changed

lib/AST/TypeCheckRequests.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1705,8 +1705,16 @@ SourceRange UnresolvedMacroReference::getGenericArgsRange() const {
17051705
return med->getGenericArgsRange();
17061706
if (auto *mee = pointer.dyn_cast<MacroExpansionExpr *>())
17071707
return mee->getGenericArgsRange();
1708-
if (auto *attr = pointer.dyn_cast<CustomAttr *>())
1709-
return SourceRange();
1708+
1709+
if (auto *attr = pointer.dyn_cast<CustomAttr *>()) {
1710+
auto *typeRepr = attr->getTypeRepr();
1711+
auto *genericTypeRepr = dyn_cast_or_null<GenericIdentTypeRepr>(typeRepr);
1712+
if (!genericTypeRepr)
1713+
return SourceRange();
1714+
1715+
return genericTypeRepr->getAngleBrackets();
1716+
}
1717+
17101718
llvm_unreachable("Unhandled case");
17111719
}
17121720

@@ -1715,8 +1723,16 @@ ArrayRef<TypeRepr *> UnresolvedMacroReference::getGenericArgs() const {
17151723
return med->getGenericArgs();
17161724
if (auto *mee = pointer.dyn_cast<MacroExpansionExpr *>())
17171725
return mee->getGenericArgs();
1718-
if (auto *attr = pointer.dyn_cast<CustomAttr *>())
1719-
return {};
1726+
1727+
if (auto *attr = pointer.dyn_cast<CustomAttr *>()) {
1728+
auto *typeRepr = attr->getTypeRepr();
1729+
auto *genericTypeRepr = dyn_cast_or_null<GenericIdentTypeRepr>(typeRepr);
1730+
if (!genericTypeRepr)
1731+
return {};
1732+
1733+
return genericTypeRepr->getGenericArgs();
1734+
}
1735+
17201736
llvm_unreachable("Unhandled case");
17211737
}
17221738

test/Macros/Inputs/syntax_macro_definitions.swift

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,3 +717,48 @@ public struct ObservablePropertyMacro: AccessorMacro {
717717
return [getAccessor, setAccessor]
718718
}
719719
}
720+
721+
extension DeclModifierSyntax {
722+
fileprivate var isNeededAccessLevelModifier: Bool {
723+
switch self.name.tokenKind {
724+
case .keyword(.public): return true
725+
default: return false
726+
}
727+
}
728+
}
729+
730+
extension SyntaxStringInterpolation {
731+
fileprivate mutating func appendInterpolation<Node: SyntaxProtocol>(_ node: Node?) {
732+
if let node {
733+
appendInterpolation(node)
734+
}
735+
}
736+
}
737+
738+
public struct NewTypeMacro: MemberMacro {
739+
public static func expansion(
740+
of node: AttributeSyntax,
741+
providingMembersOf declaration: some DeclGroupSyntax,
742+
in context: some MacroExpansionContext
743+
) throws -> [DeclSyntax] {
744+
guard let type = node.attributeName.as(SimpleTypeIdentifierSyntax.self),
745+
let genericArguments = type.genericArgumentClause?.arguments,
746+
genericArguments.count == 1,
747+
let rawType = genericArguments.first
748+
else {
749+
throw CustomError.message(#"@NewType requires the raw type as an argument, in the form "<RawType>"."#)
750+
}
751+
752+
guard let declaration = declaration.as(StructDeclSyntax.self) else {
753+
throw CustomError.message("@NewType can only be applied to a struct declarations.")
754+
}
755+
756+
let access = declaration.modifiers?.first(where: \.isNeededAccessLevelModifier)
757+
758+
return [
759+
"\(access)typealias RawValue = \(rawType)",
760+
"\(access)var rawValue: RawValue",
761+
"\(access)init(_ rawValue: RawValue) { self.rawValue = rawValue }",
762+
]
763+
}
764+
}

test/Macros/macro_expand_synthesized_members.swift

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,18 @@ let s = S()
2424
// CHECK: synthesized method
2525
// CHECK: Storage
2626
s.useSynthesized()
27+
28+
@attached(
29+
member,
30+
names: named(RawValue), named(rawValue), named(`init`)
31+
)
32+
public macro NewType<T>() = #externalMacro(module: "MacroDefinition", type: "NewTypeMacro")
33+
34+
@NewType<String>
35+
public struct MyString {}
36+
37+
// CHECK: String
38+
// CHECK: hello
39+
let myString = MyString("hello")
40+
print(MyString.RawValue.self)
41+
print(myString.rawValue)

0 commit comments

Comments
 (0)