@@ -40,11 +40,11 @@ protocol ParamInfo: CustomStringConvertible {
4040 var dependencies : [ LifetimeDependence ] { get set }
4141
4242 func getBoundsCheckedThunkBuilder(
43- _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionDeclSyntax
43+ _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionParts
4444 ) -> BoundsCheckedThunkBuilder
4545}
4646
47- func tryGetParamName( _ funcDecl: FunctionDeclSyntax , _ expr: SwiftifyExpr ) -> TokenSyntax ? {
47+ func tryGetParamName( _ funcDecl: FunctionParts , _ expr: SwiftifyExpr ) -> TokenSyntax ? {
4848 switch expr {
4949 case . param( let i) :
5050 let funcParam = getParam ( funcDecl, i - 1 )
@@ -55,7 +55,7 @@ func tryGetParamName(_ funcDecl: FunctionDeclSyntax, _ expr: SwiftifyExpr) -> To
5555 }
5656}
5757
58- func getSwiftifyExprType( _ funcDecl: FunctionDeclSyntax , _ expr: SwiftifyExpr ) -> TypeSyntax {
58+ func getSwiftifyExprType( _ funcDecl: FunctionParts , _ expr: SwiftifyExpr ) -> TypeSyntax {
5959 switch expr {
6060 case . param( let i) :
6161 let funcParam = getParam ( funcDecl, i - 1 )
@@ -79,7 +79,7 @@ struct CxxSpan: ParamInfo {
7979 }
8080
8181 func getBoundsCheckedThunkBuilder(
82- _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionDeclSyntax
82+ _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionParts
8383 ) -> BoundsCheckedThunkBuilder {
8484 switch pointerIndex {
8585 case . param( let i) :
@@ -115,7 +115,7 @@ struct CountedBy: ParamInfo {
115115 }
116116
117117 func getBoundsCheckedThunkBuilder(
118- _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionDeclSyntax
118+ _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionParts
119119 ) -> BoundsCheckedThunkBuilder {
120120 switch pointerIndex {
121121 case . param( let i) :
@@ -424,14 +424,14 @@ func getParam(_ signature: FunctionSignatureSyntax, _ paramIndex: Int) -> Functi
424424 }
425425}
426426
427- func getParam( _ funcDecl: FunctionDeclSyntax , _ paramIndex: Int ) -> FunctionParameterSyntax {
427+ func getParam( _ funcDecl: FunctionParts , _ paramIndex: Int ) -> FunctionParameterSyntax {
428428 return getParam ( funcDecl. signature, paramIndex)
429429}
430430
431431struct FunctionCallBuilder : BoundsCheckedThunkBuilder {
432- let base : FunctionDeclSyntax
432+ let base : FunctionParts
433433
434- init ( _ function: FunctionDeclSyntax ) {
434+ init ( _ function: FunctionParts ) {
435435 base = function
436436 }
437437
@@ -491,14 +491,18 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
491491 FunctionCallExprSyntax (
492492 calledExpression: functionRef, leftParen: . leftParenToken( ) ,
493493 arguments: LabeledExprListSyntax ( labeledArgs) , rightParen: . rightParenToken( ) ) )
494- return " unsafe \( call) "
494+ if base. name. tokenKind == . keyword( . `init`) {
495+ return " unsafe self. \( call) "
496+ } else {
497+ return " unsafe \( call) "
498+ }
495499 }
496500}
497501
498502struct CxxSpanThunkBuilder : SpanBoundsThunkBuilder , ParamBoundsThunkBuilder {
499503 public let base : BoundsCheckedThunkBuilder
500504 public let index : Int
501- public let funcDecl : FunctionDeclSyntax
505+ public let funcDecl : FunctionParts
502506 public let typeMappings : [ String : String ]
503507 public let node : SyntaxProtocol
504508 public let nonescaping : Bool
@@ -549,7 +553,7 @@ struct CxxSpanThunkBuilder: SpanBoundsThunkBuilder, ParamBoundsThunkBuilder {
549553
550554struct CxxSpanReturnThunkBuilder : SpanBoundsThunkBuilder {
551555 public let base : BoundsCheckedThunkBuilder
552- public let funcDecl : FunctionDeclSyntax
556+ public let funcDecl : FunctionParts
553557 public let typeMappings : [ String : String ]
554558 public let node : SyntaxProtocol
555559 let isParameter : Bool = false
@@ -588,7 +592,7 @@ struct CxxSpanReturnThunkBuilder: SpanBoundsThunkBuilder {
588592protocol BoundsThunkBuilder : BoundsCheckedThunkBuilder {
589593 var oldType : TypeSyntax { get }
590594 var newType : TypeSyntax { get throws }
591- var funcDecl : FunctionDeclSyntax { get }
595+ var funcDecl : FunctionParts { get }
592596}
593597
594598extension BoundsThunkBuilder {
@@ -700,7 +704,7 @@ extension ParamBoundsThunkBuilder {
700704struct CountedOrSizedReturnPointerThunkBuilder : PointerBoundsThunkBuilder {
701705 public let base : BoundsCheckedThunkBuilder
702706 public let countExpr : ExprSyntax
703- public let funcDecl : FunctionDeclSyntax
707+ public let funcDecl : FunctionParts
704708 public let nonescaping : Bool
705709 public let isSizedBy : Bool
706710 public let dependencies : [ LifetimeDependence ]
@@ -768,7 +772,7 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
768772 public let base : BoundsCheckedThunkBuilder
769773 public let index : Int
770774 public let countExpr : ExprSyntax
771- public let funcDecl : FunctionDeclSyntax
775+ public let funcDecl : FunctionParts
772776 public let nonescaping : Bool
773777 public let isSizedBy : Bool
774778 let isParameter : Bool = true
@@ -1267,22 +1271,22 @@ func parseMacroParam(
12671271 }
12681272}
12691273
1270- func checkArgs( _ args: [ ParamInfo ] , _ funcDecl : FunctionDeclSyntax ) throws {
1274+ func checkArgs( _ args: [ ParamInfo ] , _ funcComponents : FunctionParts ) throws {
12711275 var argByIndex : [ Int : ParamInfo ] = [ : ]
12721276 var ret : ParamInfo ? = nil
1273- let paramCount = funcDecl . signature. parameterClause. parameters. count
1277+ let paramCount = funcComponents . signature. parameterClause. parameters. count
12741278 try args. forEach { pointerInfo in
12751279 switch pointerInfo. pointerIndex {
12761280 case . param( let i) :
12771281 if i < 1 || i > paramCount {
12781282 let noteMessage =
12791283 paramCount > 0
1280- ? " function \( funcDecl . name) has parameter indices 1.. \( paramCount) "
1281- : " function \( funcDecl . name) has no parameters "
1284+ ? " function \( funcComponents . name) has parameter indices 1.. \( paramCount) "
1285+ : " function \( funcComponents . name) has no parameters "
12821286 throw DiagnosticError (
12831287 " pointer index out of bounds " , node: pointerInfo. original,
12841288 notes: [
1285- Note ( node: Syntax ( funcDecl . name) , message: MacroExpansionNoteMessage ( noteMessage) )
1289+ Note ( node: Syntax ( funcComponents . name) , message: MacroExpansionNoteMessage ( noteMessage) )
12861290 ] )
12871291 }
12881292 if argByIndex [ i] != nil {
@@ -1346,7 +1350,7 @@ func isInout(_ type: TypeSyntax) -> Bool {
13461350}
13471351
13481352func getReturnLifetimeAttribute(
1349- _ funcDecl: FunctionDeclSyntax ,
1353+ _ funcDecl: FunctionParts ,
13501354 _ dependencies: [ SwiftifyExpr : [ LifetimeDependence ] ]
13511355) -> [ AttributeListSyntax . Element ] {
13521356 let returnDependencies = dependencies [ . `return`, default: [ ] ]
@@ -1503,9 +1507,9 @@ class CountExprRewriter: SyntaxRewriter {
15031507 }
15041508}
15051509
1506- func renameParameterNamesIfNeeded( _ funcDecl : FunctionDeclSyntax ) -> ( FunctionDeclSyntax , CountExprRewriter ) {
1507- let params = funcDecl . signature. parameterClause. parameters
1508- let funcName = funcDecl . name. withoutBackticks. trimmed. text
1510+ func renameParameterNamesIfNeeded( _ funcComponents : FunctionParts ) -> ( FunctionParts , CountExprRewriter ) {
1511+ let params = funcComponents . signature. parameterClause. parameters
1512+ let funcName = funcComponents . name. withoutBackticks. trimmed. text
15091513 let shouldRename = params. contains ( where: { param in
15101514 let paramName = param. name. trimmed. text
15111515 return paramName == " _ " || paramName == funcName || " ` \( paramName) ` " == funcName
@@ -1529,13 +1533,32 @@ func renameParameterNamesIfNeeded(_ funcDecl: FunctionDeclSyntax) -> (FunctionDe
15291533 }
15301534 return newParam
15311535 }
1532- let newDecl = if renamedParams. count > 0 {
1533- funcDecl . with ( \. signature . parameterClause. parameters, FunctionParameterListSyntax ( newParams) )
1536+ let newSig = if renamedParams. count > 0 {
1537+ funcComponents . signature . with ( \. parameterClause. parameters, FunctionParameterListSyntax ( newParams) )
15341538 } else {
15351539 // Keeps source locations for diagnostics, in the common case where nothing was renamed
1536- funcDecl
1540+ funcComponents. signature
1541+ }
1542+ return ( FunctionParts ( signature: newSig, name: funcComponents. name, attributes: funcComponents. attributes) ,
1543+ CountExprRewriter ( renamedParams) )
1544+ }
1545+
1546+ struct FunctionParts {
1547+ let signature : FunctionSignatureSyntax
1548+ let name : TokenSyntax
1549+ let attributes : AttributeListSyntax
1550+ }
1551+
1552+ func deconstructFunction( _ declaration: some DeclSyntaxProtocol ) throws -> FunctionParts {
1553+ if let origFuncDecl = declaration. as ( FunctionDeclSyntax . self) {
1554+ return FunctionParts ( signature: origFuncDecl. signature, name: origFuncDecl. name,
1555+ attributes: origFuncDecl. attributes)
1556+ }
1557+ if let origInitDecl = declaration. as ( InitializerDeclSyntax . self) {
1558+ return FunctionParts ( signature: origInitDecl. signature, name: origInitDecl. initKeyword,
1559+ attributes: origInitDecl. attributes)
15371560 }
1538- return ( newDecl , CountExprRewriter ( renamedParams ) )
1561+ throw DiagnosticError ( " @_SwiftifyImport only works on functions and initializers " , node : declaration )
15391562}
15401563
15411564/// A macro that adds safe(r) wrappers for functions with unsafe pointer types.
@@ -1551,10 +1574,8 @@ public struct SwiftifyImportMacro: PeerMacro {
15511574 in context: some MacroExpansionContext
15521575 ) throws -> [ DeclSyntax ] {
15531576 do {
1554- guard let origFuncDecl = declaration. as ( FunctionDeclSyntax . self) else {
1555- throw DiagnosticError ( " @_SwiftifyImport only works on functions " , node: declaration)
1556- }
1557- let ( funcDecl, rewriter) = renameParameterNamesIfNeeded ( origFuncDecl)
1577+ let origFuncComponents = try deconstructFunction ( declaration)
1578+ let ( funcComponents, rewriter) = renameParameterNamesIfNeeded ( origFuncComponents)
15581579
15591580 let argumentList = node. arguments!. as ( LabeledExprListSyntax . self) !
15601581 var arguments = [ LabeledExprSyntax] ( argumentList)
@@ -1570,10 +1591,10 @@ public struct SwiftifyImportMacro: PeerMacro {
15701591 var lifetimeDependencies : [ SwiftifyExpr : [ LifetimeDependence ] ] = [ : ]
15711592 var parsedArgs = try arguments. compactMap {
15721593 try parseMacroParam (
1573- $0, funcDecl . signature, rewriter, nonescapingPointers: & nonescapingPointers,
1594+ $0, funcComponents . signature, rewriter, nonescapingPointers: & nonescapingPointers,
15741595 lifetimeDependencies: & lifetimeDependencies)
15751596 }
1576- parsedArgs. append ( contentsOf: try parseCxxSpansInSignature ( funcDecl . signature, typeMappings) )
1597+ parsedArgs. append ( contentsOf: try parseCxxSpansInSignature ( funcComponents . signature, typeMappings) )
15771598 setNonescapingPointers ( & parsedArgs, nonescapingPointers)
15781599 setLifetimeDependencies ( & parsedArgs, lifetimeDependencies)
15791600 // We only transform non-escaping spans.
@@ -1584,7 +1605,7 @@ public struct SwiftifyImportMacro: PeerMacro {
15841605 return true
15851606 }
15861607 }
1587- try checkArgs ( parsedArgs, funcDecl )
1608+ try checkArgs ( parsedArgs, funcComponents )
15881609 parsedArgs. sort { a, b in
15891610 // make sure return value cast to Span happens last so that withUnsafeBufferPointer
15901611 // doesn't return a ~Escapable type
@@ -1596,12 +1617,12 @@ public struct SwiftifyImportMacro: PeerMacro {
15961617 }
15971618 return paramOrReturnIndex ( a. pointerIndex) < paramOrReturnIndex ( b. pointerIndex)
15981619 }
1599- let baseBuilder = FunctionCallBuilder ( funcDecl )
1620+ let baseBuilder = FunctionCallBuilder ( funcComponents )
16001621
16011622 let builder : BoundsCheckedThunkBuilder = parsedArgs. reduce (
16021623 baseBuilder,
16031624 { ( prev, parsedArg) in
1604- parsedArg. getBoundsCheckedThunkBuilder ( prev, funcDecl )
1625+ parsedArg. getBoundsCheckedThunkBuilder ( prev, funcComponents )
16051626 } )
16061627 let newSignature = try builder. buildFunctionSignature ( [ : ] , nil )
16071628 var eliminatedArgs = Set < Int > ( )
@@ -1610,15 +1631,22 @@ public struct SwiftifyImportMacro: PeerMacro {
16101631 let checks = ( basicChecks + compoundChecks) . map { e in
16111632 CodeBlockItemSyntax ( leadingTrivia: " \n " , item: e)
16121633 }
1613- let call = CodeBlockItemSyntax (
1614- item: CodeBlockItemSyntax . Item (
1615- ReturnStmtSyntax (
1616- returnKeyword: . keyword( . return, trailingTrivia: " " ) ,
1617- expression: try builder. buildFunctionCall ( [ : ] ) ) ) )
1634+ var call : CodeBlockItemSyntax
1635+ if declaration. is ( InitializerDeclSyntax . self) {
1636+ call = CodeBlockItemSyntax (
1637+ item: CodeBlockItemSyntax . Item (
1638+ try builder. buildFunctionCall ( [ : ] ) ) )
1639+ } else {
1640+ call = CodeBlockItemSyntax (
1641+ item: CodeBlockItemSyntax . Item (
1642+ ReturnStmtSyntax (
1643+ returnKeyword: . keyword( . return, trailingTrivia: " " ) ,
1644+ expression: try builder. buildFunctionCall ( [ : ] ) ) ) )
1645+ }
16181646 let body = CodeBlockSyntax ( statements: CodeBlockItemListSyntax ( checks + [ call] ) )
1619- let returnLifetimeAttribute = getReturnLifetimeAttribute ( funcDecl , lifetimeDependencies)
1647+ let returnLifetimeAttribute = getReturnLifetimeAttribute ( funcComponents , lifetimeDependencies)
16201648 let lifetimeAttrs =
1621- returnLifetimeAttribute + paramLifetimeAttributes( newSignature, funcDecl . attributes)
1649+ returnLifetimeAttribute + paramLifetimeAttributes( newSignature, funcComponents . attributes)
16221650 let availabilityAttr = try getAvailability ( newSignature, spanAvailability)
16231651 let disfavoredOverload : [ AttributeListSyntax . Element ] =
16241652 [
@@ -1627,13 +1655,7 @@ public struct SwiftifyImportMacro: PeerMacro {
16271655 atSign: . atSignToken( ) ,
16281656 attributeName: IdentifierTypeSyntax ( name: " _disfavoredOverload " ) ) )
16291657 ]
1630- let newFunc =
1631- funcDecl
1632- . with ( \. signature, newSignature)
1633- . with ( \. body, body)
1634- . with (
1635- \. attributes,
1636- funcDecl. attributes. filter { e in
1658+ let attributes = funcComponents. attributes. filter { e in
16371659 switch e {
16381660 case . attribute( let attr) :
16391661 // don't apply this macro recursively, and avoid dupe _alwaysEmitIntoClient
@@ -1649,9 +1671,23 @@ public struct SwiftifyImportMacro: PeerMacro {
16491671 ]
16501672 + availabilityAttr
16511673 + lifetimeAttrs
1652- + disfavoredOverload)
1653- . with ( \. leadingTrivia, node. leadingTrivia + . docLineComment( " /// This is an auto-generated wrapper for safer interop \n " ) )
1654- return [ DeclSyntax ( newFunc) ]
1674+ + disfavoredOverload
1675+ let trivia = node. leadingTrivia + . docLineComment( " /// This is an auto-generated wrapper for safer interop \n " )
1676+ if let origFuncDecl = declaration. as ( FunctionDeclSyntax . self) {
1677+ return [ DeclSyntax ( origFuncDecl
1678+ . with ( \. signature, newSignature)
1679+ . with ( \. body, body)
1680+ . with ( \. attributes, AttributeListSyntax ( attributes) )
1681+ . with ( \. leadingTrivia, trivia) ) ]
1682+ }
1683+ if let origInitDecl = declaration. as ( InitializerDeclSyntax . self) {
1684+ return [ DeclSyntax ( origInitDecl
1685+ . with ( \. signature, newSignature)
1686+ . with ( \. body, body)
1687+ . with ( \. attributes, AttributeListSyntax ( attributes) )
1688+ . with ( \. leadingTrivia, trivia) ) ]
1689+ }
1690+ return [ ]
16551691 } catch let error as DiagnosticError {
16561692 context. diagnose (
16571693 Diagnostic (
@@ -1716,6 +1752,9 @@ extension FunctionParameterSyntax {
17161752
17171753extension TokenSyntax {
17181754 public var withoutBackticks : TokenSyntax {
1755+ if self . identifier == nil {
1756+ return self
1757+ }
17191758 return . identifier( self . identifier!. name)
17201759 }
17211760
0 commit comments