|
| 1 | +import SwiftDiagnostics |
| 2 | +import SwiftSyntax |
| 3 | +import SwiftSyntaxBuilder |
| 4 | +import SwiftSyntaxMacros |
| 5 | + |
| 6 | +enum OptionSetMacroDiagnostic { |
| 7 | + case requiresStruct |
| 8 | + case requiresStringLiteral(String) |
| 9 | + case requiresOptionsEnum(String) |
| 10 | + case requiresOptionsEnumRawType |
| 11 | +} |
| 12 | + |
| 13 | +extension OptionSetMacroDiagnostic: DiagnosticMessage { |
| 14 | + func diagnose<Node: SyntaxProtocol>(at node: Node) -> Diagnostic { |
| 15 | + Diagnostic(node: Syntax(node), message: self) |
| 16 | + } |
| 17 | + |
| 18 | + var message: String { |
| 19 | + switch self { |
| 20 | + case .requiresStruct: |
| 21 | + return "'OptionSet' macro can only be applied to a struct" |
| 22 | + |
| 23 | + case .requiresStringLiteral(let name): |
| 24 | + return "'OptionSet' macro argument \(name) must be a string literal" |
| 25 | + |
| 26 | + case .requiresOptionsEnum(let name): |
| 27 | + return "'OptionSet' macro requires nested options enum '\(name)'" |
| 28 | + |
| 29 | + case .requiresOptionsEnumRawType: |
| 30 | + return "'OptionSet' macro requires a raw type" |
| 31 | + } |
| 32 | + } |
| 33 | + |
| 34 | + var severity: DiagnosticSeverity { .error } |
| 35 | + |
| 36 | + var diagnosticID: MessageID { |
| 37 | + MessageID(domain: "Swift", id: "OptionSet.\(self)") |
| 38 | + } |
| 39 | +} |
| 40 | + |
| 41 | + |
| 42 | +/// The label used for the OptionSet macro argument that provides the name of |
| 43 | +/// the nested options enum. |
| 44 | +private let optionsEnumNameArgumentLabel = "optionsName" |
| 45 | + |
| 46 | +/// The default name used for the nested "Options" enum. This should |
| 47 | +/// eventually be overridable. |
| 48 | +private let defaultOptionsEnumName = "Options" |
| 49 | + |
| 50 | +extension TupleExprElementListSyntax { |
| 51 | + /// Retrieve the first element with the given label. |
| 52 | + func first(labeled name: String) -> Element? { |
| 53 | + return first { element in |
| 54 | + if let label = element.label, label.text == name { |
| 55 | + return true |
| 56 | + } |
| 57 | + |
| 58 | + return false |
| 59 | + } |
| 60 | + } |
| 61 | +} |
| 62 | + |
| 63 | +public struct OptionSetMacro { |
| 64 | + /// Decodes the arguments to the macro expansion. |
| 65 | + /// |
| 66 | + /// - Returns: the important arguments used by the various roles of this |
| 67 | + /// macro inhabits, or nil if an error occurred. |
| 68 | + static func decodeExpansion< |
| 69 | + Decl: DeclGroupSyntax, |
| 70 | + Context: MacroExpansionContext |
| 71 | + >( |
| 72 | + of attribute: AttributeSyntax, |
| 73 | + attachedTo decl: Decl, |
| 74 | + in context: Context |
| 75 | + ) -> (StructDeclSyntax, EnumDeclSyntax, TypeSyntax)? { |
| 76 | + // Determine the name of the options enum. |
| 77 | + let optionsEnumName: String |
| 78 | + if case let .argumentList(arguments) = attribute.argument, |
| 79 | + let optionEnumNameArg = arguments.first(labeled: optionsEnumNameArgumentLabel) { |
| 80 | + // We have a options name; make sure it is a string literal. |
| 81 | + guard let stringLiteral = optionEnumNameArg.expression.as(StringLiteralExprSyntax.self), |
| 82 | + stringLiteral.segments.count == 1, |
| 83 | + case let .stringSegment(optionsEnumNameString)? = stringLiteral.segments.first else { |
| 84 | + context.diagnose(OptionSetMacroDiagnostic.requiresStringLiteral(optionsEnumNameArgumentLabel).diagnose(at: optionEnumNameArg.expression)) |
| 85 | + return nil |
| 86 | + } |
| 87 | + |
| 88 | + optionsEnumName = optionsEnumNameString.content.text |
| 89 | + } else { |
| 90 | + optionsEnumName = defaultOptionsEnumName |
| 91 | + } |
| 92 | + |
| 93 | + // Only apply to structs. |
| 94 | + guard let structDecl = decl.as(StructDeclSyntax.self) else { |
| 95 | + context.diagnose(OptionSetMacroDiagnostic.requiresStruct.diagnose(at: decl)) |
| 96 | + return nil |
| 97 | + } |
| 98 | + |
| 99 | + // Find the option enum within the struct. |
| 100 | + let optionsEnums: [EnumDeclSyntax] = decl.members.members.compactMap({ member in |
| 101 | + if let enumDecl = member.decl.as(EnumDeclSyntax.self), |
| 102 | + enumDecl.identifier.text == optionsEnumName { |
| 103 | + return enumDecl |
| 104 | + } |
| 105 | + |
| 106 | + return nil |
| 107 | + }) |
| 108 | + |
| 109 | + guard let optionsEnum = optionsEnums.first else { |
| 110 | + context.diagnose(OptionSetMacroDiagnostic.requiresOptionsEnum(optionsEnumName).diagnose(at: decl)) |
| 111 | + return nil |
| 112 | + } |
| 113 | + |
| 114 | + // Retrieve the raw type from the attribute. |
| 115 | + guard let genericArgs = attribute.attributeName.as(SimpleTypeIdentifierSyntax.self)?.genericArgumentClause, |
| 116 | + let rawType = genericArgs.arguments.first?.argumentType else { |
| 117 | + context.diagnose(OptionSetMacroDiagnostic.requiresOptionsEnumRawType.diagnose(at: attribute)) |
| 118 | + return nil |
| 119 | + } |
| 120 | + |
| 121 | + |
| 122 | + return (structDecl, optionsEnum, rawType) |
| 123 | + } |
| 124 | +} |
| 125 | + |
| 126 | +extension OptionSetMacro: ConformanceMacro { |
| 127 | + public static func expansion< |
| 128 | + Decl: DeclGroupSyntax, |
| 129 | + Context: MacroExpansionContext |
| 130 | + >( |
| 131 | + of attribute: AttributeSyntax, |
| 132 | + providingConformancesOf decl: Decl, |
| 133 | + in context: Context |
| 134 | + ) throws -> [(TypeSyntax, GenericWhereClauseSyntax?)] { |
| 135 | + // Decode the expansion arguments. |
| 136 | + guard let (structDecl, _, _) = decodeExpansion(of: attribute, attachedTo: decl, in: context) else { |
| 137 | + return [] |
| 138 | + } |
| 139 | + |
| 140 | + // If there is an explicit conformance to OptionSet already, don't add one. |
| 141 | + if let inheritedTypes = structDecl.inheritanceClause?.inheritedTypeCollection, |
| 142 | + inheritedTypes.contains(where: { inherited in inherited.typeName.trimmedDescription == "OptionSet" }) { |
| 143 | + return [] |
| 144 | + } |
| 145 | + |
| 146 | + return [("OptionSet", nil)] |
| 147 | + } |
| 148 | +} |
| 149 | + |
| 150 | +extension OptionSetMacro: MemberMacro { |
| 151 | + public static func expansion< |
| 152 | + Decl: DeclGroupSyntax, |
| 153 | + Context: MacroExpansionContext |
| 154 | + >( |
| 155 | + of attribute: AttributeSyntax, |
| 156 | + providingMembersOf decl: Decl, |
| 157 | + in context: Context |
| 158 | + ) throws -> [DeclSyntax] { |
| 159 | + // Decode the expansion arguments. |
| 160 | + guard let (_, optionsEnum, rawType) = decodeExpansion(of: attribute, attachedTo: decl, in: context) else { |
| 161 | + return [] |
| 162 | + } |
| 163 | + |
| 164 | + // Find all of the case elements. |
| 165 | + var caseElements: [EnumCaseElementSyntax] = [] |
| 166 | + for member in optionsEnum.members.members { |
| 167 | + if let caseDecl = member.decl.as(EnumCaseDeclSyntax.self) { |
| 168 | + caseElements.append(contentsOf: caseDecl.elements) |
| 169 | + } |
| 170 | + } |
| 171 | + |
| 172 | + // Dig out the access control keyword we need. |
| 173 | + let access = decl.modifiers?.first(where: \.isNeededAccessLevelModifier) |
| 174 | + |
| 175 | + let staticVars = caseElements.map { (element) -> DeclSyntax in |
| 176 | + """ |
| 177 | + \(access) static let \(element.identifier): Self = |
| 178 | + Self(rawValue: 1 << \(optionsEnum.identifier).\(element.identifier).rawValue) |
| 179 | + """ |
| 180 | + } |
| 181 | + |
| 182 | + return [ |
| 183 | + "\(access)typealias RawValue = \(rawType)", |
| 184 | + "\(access)var rawValue: RawValue", |
| 185 | + "\(access)init() { self.rawValue = 0 }", |
| 186 | + "\(access)init(rawValue: RawValue) { self.rawValue = rawValue }", |
| 187 | + ] + staticVars |
| 188 | + } |
| 189 | +} |
| 190 | + |
| 191 | +extension DeclModifierSyntax { |
| 192 | + var isNeededAccessLevelModifier: Bool { |
| 193 | + switch self.name.tokenKind { |
| 194 | + case .keyword(.public): return true |
| 195 | + default: return false |
| 196 | + } |
| 197 | + } |
| 198 | +} |
| 199 | + |
| 200 | +extension SyntaxStringInterpolation { |
| 201 | + // It would be nice for SwiftSyntaxBuilder to provide this out-of-the-box. |
| 202 | + mutating func appendInterpolation<Node: SyntaxProtocol>(_ node: Node?) { |
| 203 | + if let node = node { |
| 204 | + appendInterpolation(node) |
| 205 | + } |
| 206 | + } |
| 207 | +} |
0 commit comments