@@ -40,11 +40,11 @@ protocol ParamInfo: CustomStringConvertible {
4040 var dependencies : [ LifetimeDependence ] { get set }
4141
4242 func getBoundsCheckedThunkBuilder(
43- _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionDeclSyntax
43+ _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionParts
4444 ) -> BoundsCheckedThunkBuilder
4545}
4646
47- func tryGetParamName( _ funcDecl: FunctionDeclSyntax , _ expr: SwiftifyExpr ) -> TokenSyntax ? {
47+ func tryGetParamName( _ funcDecl: FunctionParts , _ expr: SwiftifyExpr ) -> TokenSyntax ? {
4848 switch expr {
4949 case . param( let i) :
5050 let funcParam = getParam ( funcDecl, i - 1 )
@@ -55,7 +55,7 @@ func tryGetParamName(_ funcDecl: FunctionDeclSyntax, _ expr: SwiftifyExpr) -> To
5555 }
5656}
5757
58- func getSwiftifyExprType( _ funcDecl: FunctionDeclSyntax , _ expr: SwiftifyExpr ) -> TypeSyntax {
58+ func getSwiftifyExprType( _ funcDecl: FunctionParts , _ expr: SwiftifyExpr ) -> TypeSyntax {
5959 switch expr {
6060 case . param( let i) :
6161 let funcParam = getParam ( funcDecl, i - 1 )
@@ -79,7 +79,7 @@ struct CxxSpan: ParamInfo {
7979 }
8080
8181 func getBoundsCheckedThunkBuilder(
82- _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionDeclSyntax
82+ _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionParts
8383 ) -> BoundsCheckedThunkBuilder {
8484 switch pointerIndex {
8585 case . param( let i) :
@@ -115,7 +115,7 @@ struct CountedBy: ParamInfo {
115115 }
116116
117117 func getBoundsCheckedThunkBuilder(
118- _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionDeclSyntax
118+ _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionParts
119119 ) -> BoundsCheckedThunkBuilder {
120120 switch pointerIndex {
121121 case . param( let i) :
@@ -400,14 +400,14 @@ func getParam(_ signature: FunctionSignatureSyntax, _ paramIndex: Int) -> Functi
400400 }
401401}
402402
403- func getParam( _ funcDecl: FunctionDeclSyntax , _ paramIndex: Int ) -> FunctionParameterSyntax {
403+ func getParam( _ funcDecl: FunctionParts , _ paramIndex: Int ) -> FunctionParameterSyntax {
404404 return getParam ( funcDecl. signature, paramIndex)
405405}
406406
407407struct FunctionCallBuilder : BoundsCheckedThunkBuilder {
408- let base : FunctionDeclSyntax
408+ let base : FunctionParts
409409
410- init ( _ function: FunctionDeclSyntax ) {
410+ init ( _ function: FunctionParts ) {
411411 base = function
412412 }
413413
@@ -467,14 +467,18 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
467467 FunctionCallExprSyntax (
468468 calledExpression: functionRef, leftParen: . leftParenToken( ) ,
469469 arguments: LabeledExprListSyntax ( labeledArgs) , rightParen: . rightParenToken( ) ) )
470- return " unsafe \( call) "
470+ if base. name. tokenKind == . keyword( . `init`) {
471+ return " unsafe self. \( call) "
472+ } else {
473+ return " unsafe \( call) "
474+ }
471475 }
472476}
473477
474478struct CxxSpanThunkBuilder : SpanBoundsThunkBuilder , ParamBoundsThunkBuilder {
475479 public let base : BoundsCheckedThunkBuilder
476480 public let index : Int
477- public let funcDecl : FunctionDeclSyntax
481+ public let funcDecl : FunctionParts
478482 public let typeMappings : [ String : String ]
479483 public let node : SyntaxProtocol
480484 public let nonescaping : Bool
@@ -525,7 +529,7 @@ struct CxxSpanThunkBuilder: SpanBoundsThunkBuilder, ParamBoundsThunkBuilder {
525529
526530struct CxxSpanReturnThunkBuilder : SpanBoundsThunkBuilder {
527531 public let base : BoundsCheckedThunkBuilder
528- public let funcDecl : FunctionDeclSyntax
532+ public let funcDecl : FunctionParts
529533 public let typeMappings : [ String : String ]
530534 public let node : SyntaxProtocol
531535 let isParameter : Bool = false
@@ -564,7 +568,7 @@ struct CxxSpanReturnThunkBuilder: SpanBoundsThunkBuilder {
564568protocol BoundsThunkBuilder : BoundsCheckedThunkBuilder {
565569 var oldType : TypeSyntax { get }
566570 var newType : TypeSyntax { get throws }
567- var funcDecl : FunctionDeclSyntax { get }
571+ var funcDecl : FunctionParts { get }
568572}
569573
570574extension BoundsThunkBuilder {
@@ -675,7 +679,7 @@ extension ParamBoundsThunkBuilder {
675679struct CountedOrSizedReturnPointerThunkBuilder : PointerBoundsThunkBuilder {
676680 public let base : BoundsCheckedThunkBuilder
677681 public let countExpr : ExprSyntax
678- public let funcDecl : FunctionDeclSyntax
682+ public let funcDecl : FunctionParts
679683 public let nonescaping : Bool
680684 public let isSizedBy : Bool
681685 public let dependencies : [ LifetimeDependence ]
@@ -743,7 +747,7 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
743747 public let base : BoundsCheckedThunkBuilder
744748 public let index : Int
745749 public let countExpr : ExprSyntax
746- public let funcDecl : FunctionDeclSyntax
750+ public let funcDecl : FunctionParts
747751 public let nonescaping : Bool
748752 public let isSizedBy : Bool
749753 let isParameter : Bool = true
@@ -1237,22 +1241,22 @@ func parseMacroParam(
12371241 }
12381242}
12391243
1240- func checkArgs( _ args: [ ParamInfo ] , _ funcDecl : FunctionDeclSyntax ) throws {
1244+ func checkArgs( _ args: [ ParamInfo ] , _ funcComponents : FunctionParts ) throws {
12411245 var argByIndex : [ Int : ParamInfo ] = [ : ]
12421246 var ret : ParamInfo ? = nil
1243- let paramCount = funcDecl . signature. parameterClause. parameters. count
1247+ let paramCount = funcComponents . signature. parameterClause. parameters. count
12441248 try args. forEach { pointerInfo in
12451249 switch pointerInfo. pointerIndex {
12461250 case . param( let i) :
12471251 if i < 1 || i > paramCount {
12481252 let noteMessage =
12491253 paramCount > 0
1250- ? " function \( funcDecl . name) has parameter indices 1.. \( paramCount) "
1251- : " function \( funcDecl . name) has no parameters "
1254+ ? " function \( funcComponents . name) has parameter indices 1.. \( paramCount) "
1255+ : " function \( funcComponents . name) has no parameters "
12521256 throw DiagnosticError (
12531257 " pointer index out of bounds " , node: pointerInfo. original,
12541258 notes: [
1255- Note ( node: Syntax ( funcDecl . name) , message: MacroExpansionNoteMessage ( noteMessage) )
1259+ Note ( node: Syntax ( funcComponents . name) , message: MacroExpansionNoteMessage ( noteMessage) )
12561260 ] )
12571261 }
12581262 if argByIndex [ i] != nil {
@@ -1316,7 +1320,7 @@ func isInout(_ type: TypeSyntax) -> Bool {
13161320}
13171321
13181322func getReturnLifetimeAttribute(
1319- _ funcDecl: FunctionDeclSyntax ,
1323+ _ funcDecl: FunctionParts ,
13201324 _ dependencies: [ SwiftifyExpr : [ LifetimeDependence ] ]
13211325) -> [ AttributeListSyntax . Element ] {
13221326 let returnDependencies = dependencies [ . `return`, default: [ ] ]
@@ -1473,9 +1477,9 @@ class CountExprRewriter: SyntaxRewriter {
14731477 }
14741478}
14751479
1476- func renameParameterNamesIfNeeded( _ funcDecl : FunctionDeclSyntax ) -> ( FunctionDeclSyntax , CountExprRewriter ) {
1477- let params = funcDecl . signature. parameterClause. parameters
1478- let funcName = funcDecl . name. withoutBackticks. trimmed. text
1480+ func renameParameterNamesIfNeeded( _ funcComponents : FunctionParts ) -> ( FunctionParts , CountExprRewriter ) {
1481+ let params = funcComponents . signature. parameterClause. parameters
1482+ let funcName = funcComponents . name. withoutBackticks. trimmed. text
14791483 let shouldRename = params. contains ( where: { param in
14801484 let paramName = param. name. trimmed. text
14811485 return paramName == " _ " || paramName == funcName || " ` \( paramName) ` " == funcName
@@ -1499,13 +1503,32 @@ func renameParameterNamesIfNeeded(_ funcDecl: FunctionDeclSyntax) -> (FunctionDe
14991503 }
15001504 return newParam
15011505 }
1502- let newDecl = if renamedParams. count > 0 {
1503- funcDecl . with ( \. signature . parameterClause. parameters, FunctionParameterListSyntax ( newParams) )
1506+ let newSig = if renamedParams. count > 0 {
1507+ funcComponents . signature . with ( \. parameterClause. parameters, FunctionParameterListSyntax ( newParams) )
15041508 } else {
15051509 // Keeps source locations for diagnostics, in the common case where nothing was renamed
1506- funcDecl
1510+ funcComponents. signature
1511+ }
1512+ return ( FunctionParts ( signature: newSig, name: funcComponents. name, attributes: funcComponents. attributes) ,
1513+ CountExprRewriter ( renamedParams) )
1514+ }
1515+
1516+ struct FunctionParts {
1517+ let signature : FunctionSignatureSyntax
1518+ let name : TokenSyntax
1519+ let attributes : AttributeListSyntax
1520+ }
1521+
1522+ func deconstructFunction( _ declaration: some DeclSyntaxProtocol ) throws -> FunctionParts {
1523+ if let origFuncDecl = declaration. as ( FunctionDeclSyntax . self) {
1524+ return FunctionParts ( signature: origFuncDecl. signature, name: origFuncDecl. name,
1525+ attributes: origFuncDecl. attributes)
1526+ }
1527+ if let origInitDecl = declaration. as ( InitializerDeclSyntax . self) {
1528+ return FunctionParts ( signature: origInitDecl. signature, name: origInitDecl. initKeyword,
1529+ attributes: origInitDecl. attributes)
15071530 }
1508- return ( newDecl , CountExprRewriter ( renamedParams ) )
1531+ throw DiagnosticError ( " @_SwiftifyImport only works on functions and initializers " , node : declaration )
15091532}
15101533
15111534/// A macro that adds safe(r) wrappers for functions with unsafe pointer types.
@@ -1521,10 +1544,8 @@ public struct SwiftifyImportMacro: PeerMacro {
15211544 in context: some MacroExpansionContext
15221545 ) throws -> [ DeclSyntax ] {
15231546 do {
1524- guard let origFuncDecl = declaration. as ( FunctionDeclSyntax . self) else {
1525- throw DiagnosticError ( " @_SwiftifyImport only works on functions " , node: declaration)
1526- }
1527- let ( funcDecl, rewriter) = renameParameterNamesIfNeeded ( origFuncDecl)
1547+ let origFuncComponents = try deconstructFunction ( declaration)
1548+ let ( funcComponents, rewriter) = renameParameterNamesIfNeeded ( origFuncComponents)
15281549
15291550 let argumentList = node. arguments!. as ( LabeledExprListSyntax . self) !
15301551 var arguments = [ LabeledExprSyntax] ( argumentList)
@@ -1540,10 +1561,10 @@ public struct SwiftifyImportMacro: PeerMacro {
15401561 var lifetimeDependencies : [ SwiftifyExpr : [ LifetimeDependence ] ] = [ : ]
15411562 var parsedArgs = try arguments. compactMap {
15421563 try parseMacroParam (
1543- $0, funcDecl . signature, rewriter, nonescapingPointers: & nonescapingPointers,
1564+ $0, funcComponents . signature, rewriter, nonescapingPointers: & nonescapingPointers,
15441565 lifetimeDependencies: & lifetimeDependencies)
15451566 }
1546- parsedArgs. append ( contentsOf: try parseCxxSpansInSignature ( funcDecl . signature, typeMappings) )
1567+ parsedArgs. append ( contentsOf: try parseCxxSpansInSignature ( funcComponents . signature, typeMappings) )
15471568 setNonescapingPointers ( & parsedArgs, nonescapingPointers)
15481569 setLifetimeDependencies ( & parsedArgs, lifetimeDependencies)
15491570 // We only transform non-escaping spans.
@@ -1554,7 +1575,7 @@ public struct SwiftifyImportMacro: PeerMacro {
15541575 return true
15551576 }
15561577 }
1557- try checkArgs ( parsedArgs, funcDecl )
1578+ try checkArgs ( parsedArgs, funcComponents )
15581579 parsedArgs. sort { a, b in
15591580 // make sure return value cast to Span happens last so that withUnsafeBufferPointer
15601581 // doesn't return a ~Escapable type
@@ -1566,12 +1587,12 @@ public struct SwiftifyImportMacro: PeerMacro {
15661587 }
15671588 return paramOrReturnIndex ( a. pointerIndex) < paramOrReturnIndex ( b. pointerIndex)
15681589 }
1569- let baseBuilder = FunctionCallBuilder ( funcDecl )
1590+ let baseBuilder = FunctionCallBuilder ( funcComponents )
15701591
15711592 let builder : BoundsCheckedThunkBuilder = parsedArgs. reduce (
15721593 baseBuilder,
15731594 { ( prev, parsedArg) in
1574- parsedArg. getBoundsCheckedThunkBuilder ( prev, funcDecl )
1595+ parsedArg. getBoundsCheckedThunkBuilder ( prev, funcComponents )
15751596 } )
15761597 let newSignature = try builder. buildFunctionSignature ( [ : ] , nil )
15771598 var eliminatedArgs = Set < Int > ( )
@@ -1580,15 +1601,22 @@ public struct SwiftifyImportMacro: PeerMacro {
15801601 let checks = ( basicChecks + compoundChecks) . map { e in
15811602 CodeBlockItemSyntax ( leadingTrivia: " \n " , item: e)
15821603 }
1583- let call = CodeBlockItemSyntax (
1584- item: CodeBlockItemSyntax . Item (
1585- ReturnStmtSyntax (
1586- returnKeyword: . keyword( . return, trailingTrivia: " " ) ,
1587- expression: try builder. buildFunctionCall ( [ : ] ) ) ) )
1604+ var call : CodeBlockItemSyntax
1605+ if declaration. is ( InitializerDeclSyntax . self) {
1606+ call = CodeBlockItemSyntax (
1607+ item: CodeBlockItemSyntax . Item (
1608+ try builder. buildFunctionCall ( [ : ] ) ) )
1609+ } else {
1610+ call = CodeBlockItemSyntax (
1611+ item: CodeBlockItemSyntax . Item (
1612+ ReturnStmtSyntax (
1613+ returnKeyword: . keyword( . return, trailingTrivia: " " ) ,
1614+ expression: try builder. buildFunctionCall ( [ : ] ) ) ) )
1615+ }
15881616 let body = CodeBlockSyntax ( statements: CodeBlockItemListSyntax ( checks + [ call] ) )
1589- let returnLifetimeAttribute = getReturnLifetimeAttribute ( funcDecl , lifetimeDependencies)
1617+ let returnLifetimeAttribute = getReturnLifetimeAttribute ( funcComponents , lifetimeDependencies)
15901618 let lifetimeAttrs =
1591- returnLifetimeAttribute + paramLifetimeAttributes( newSignature, funcDecl . attributes)
1619+ returnLifetimeAttribute + paramLifetimeAttributes( newSignature, funcComponents . attributes)
15921620 let availabilityAttr = try getAvailability ( newSignature, spanAvailability)
15931621 let disfavoredOverload : [ AttributeListSyntax . Element ] =
15941622 [
@@ -1597,13 +1625,7 @@ public struct SwiftifyImportMacro: PeerMacro {
15971625 atSign: . atSignToken( ) ,
15981626 attributeName: IdentifierTypeSyntax ( name: " _disfavoredOverload " ) ) )
15991627 ]
1600- let newFunc =
1601- funcDecl
1602- . with ( \. signature, newSignature)
1603- . with ( \. body, body)
1604- . with (
1605- \. attributes,
1606- funcDecl. attributes. filter { e in
1628+ let attributes = funcComponents. attributes. filter { e in
16071629 switch e {
16081630 case . attribute( let attr) :
16091631 // don't apply this macro recursively, and avoid dupe _alwaysEmitIntoClient
@@ -1619,9 +1641,23 @@ public struct SwiftifyImportMacro: PeerMacro {
16191641 ]
16201642 + availabilityAttr
16211643 + lifetimeAttrs
1622- + disfavoredOverload)
1623- . with ( \. leadingTrivia, node. leadingTrivia + . docLineComment( " /// This is an auto-generated wrapper for safer interop \n " ) )
1624- return [ DeclSyntax ( newFunc) ]
1644+ + disfavoredOverload
1645+ let trivia = node. leadingTrivia + . docLineComment( " /// This is an auto-generated wrapper for safer interop \n " )
1646+ if let origFuncDecl = declaration. as ( FunctionDeclSyntax . self) {
1647+ return [ DeclSyntax ( origFuncDecl
1648+ . with ( \. signature, newSignature)
1649+ . with ( \. body, body)
1650+ . with ( \. attributes, AttributeListSyntax ( attributes) )
1651+ . with ( \. leadingTrivia, trivia) ) ]
1652+ }
1653+ if let origInitDecl = declaration. as ( InitializerDeclSyntax . self) {
1654+ return [ DeclSyntax ( origInitDecl
1655+ . with ( \. signature, newSignature)
1656+ . with ( \. body, body)
1657+ . with ( \. attributes, AttributeListSyntax ( attributes) )
1658+ . with ( \. leadingTrivia, trivia) ) ]
1659+ }
1660+ return [ ]
16251661 } catch let error as DiagnosticError {
16261662 context. diagnose (
16271663 Diagnostic (
@@ -1686,6 +1722,9 @@ extension FunctionParameterSyntax {
16861722
16871723extension TokenSyntax {
16881724 public var withoutBackticks : TokenSyntax {
1725+ if self . identifier == nil {
1726+ return self
1727+ }
16891728 return . identifier( self . identifier!. name)
16901729 }
16911730
0 commit comments