@@ -151,11 +151,22 @@ func getPointerMutability(text: String) -> Mutability? {
151
151
case " UnsafeMutablePointer " : return . Mutable
152
152
case " UnsafeRawPointer " : return . Immutable
153
153
case " UnsafeMutableRawPointer " : return . Mutable
154
+ case " OpaquePointer " : return . Immutable
154
155
default :
155
156
return nil
156
157
}
157
158
}
158
159
160
+ func isRawPointerType( text: String ) -> Bool {
161
+ switch text {
162
+ case " UnsafeRawPointer " : return true
163
+ case " UnsafeMutableRawPointer " : return true
164
+ case " OpaquePointer " : return true
165
+ default :
166
+ return false
167
+ }
168
+ }
169
+
159
170
func getSafePointerName( mut: Mutability , generateSpan: Bool , isRaw: Bool ) -> TokenSyntax {
160
171
switch ( mut, generateSpan, isRaw) {
161
172
case ( . Immutable, true , true ) : return " RawSpan "
@@ -180,9 +191,13 @@ func transformType(_ prev: TypeSyntax, _ variant: Variant, _ isSizedBy: Bool) th
180
191
}
181
192
let name = try getTypeName ( prev)
182
193
let text = name. text
183
- if !isSizedBy && ( text == " UnsafeRawPointer " || text == " UnsafeMutableRawPointer " ) {
194
+ let isRaw = isRawPointerType ( text: text)
195
+ if isRaw && !isSizedBy {
184
196
throw DiagnosticError ( " raw pointers only supported for SizedBy " , node: name)
185
197
}
198
+ if !isRaw && isSizedBy {
199
+ throw DiagnosticError ( " SizedBy only supported for raw pointers " , node: name)
200
+ }
186
201
187
202
guard let kind: Mutability = getPointerMutability ( text: text) else {
188
203
throw DiagnosticError (
@@ -390,7 +405,7 @@ struct CountedOrSizedPointerThunkBuilder: PointerBoundsThunkBuilder {
390
405
var args = argOverrides
391
406
let argExpr = ExprSyntax ( " \( unwrappedName) .baseAddress " )
392
407
assert ( args [ index] == nil )
393
- args [ index] = unwrapIfNonnullable ( argExpr)
408
+ args [ index] = try castPointerToOpaquePointer ( unwrapIfNonnullable ( argExpr) )
394
409
let call = try base. buildFunctionCall ( args, variant)
395
410
let ptrRef = unwrapIfNullable ( ExprSyntax ( DeclReferenceExprSyntax ( baseName: name) ) )
396
411
@@ -412,7 +427,26 @@ struct CountedOrSizedPointerThunkBuilder: PointerBoundsThunkBuilder {
412
427
return ExprSyntax ( " \( name) . \( raw: countName) " )
413
428
}
414
429
415
- func getPointerArg( ) -> ExprSyntax {
430
+ func peelOptionalType( _ type: TypeSyntax ) -> TypeSyntax {
431
+ if let optType = type. as ( OptionalTypeSyntax . self) {
432
+ return optType. wrappedType
433
+ }
434
+ if let impOptType = type. as ( ImplicitlyUnwrappedOptionalTypeSyntax . self) {
435
+ return impOptType. wrappedType
436
+ }
437
+ return type
438
+ }
439
+
440
+ func castPointerToOpaquePointer( _ baseAddress: ExprSyntax ) throws -> ExprSyntax {
441
+ let i = try getParameterIndexForParamName ( signature. parameterClause. parameters, name)
442
+ let type = peelOptionalType ( getParam ( signature, i) . type)
443
+ if type. canRepresentBasicType ( type: OpaquePointer . self) {
444
+ return ExprSyntax ( " OpaquePointer( \( baseAddress) ) " )
445
+ }
446
+ return baseAddress
447
+ }
448
+
449
+ func getPointerArg( ) throws -> ExprSyntax {
416
450
if nullable {
417
451
return ExprSyntax ( " \( name) ?.baseAddress " )
418
452
}
@@ -450,7 +484,7 @@ struct CountedOrSizedPointerThunkBuilder: PointerBoundsThunkBuilder {
450
484
return unwrappedCall
451
485
}
452
486
453
- args [ index] = getPointerArg ( )
487
+ args [ index] = try castPointerToOpaquePointer ( getPointerArg ( ) )
454
488
return try base. buildFunctionCall ( args, variant)
455
489
}
456
490
}
@@ -499,22 +533,28 @@ func getOptionalArgumentByName(_ argumentList: LabeledExprListSyntax, _ name: St
499
533
} ) ? . expression
500
534
}
501
535
502
- func getParameterIndexForDeclRef (
503
- _ parameterList: FunctionParameterListSyntax , _ ref : DeclReferenceExprSyntax
536
+ func getParameterIndexForParamName (
537
+ _ parameterList: FunctionParameterListSyntax , _ tok : TokenSyntax
504
538
) throws -> Int {
505
- let name = ref . baseName . text
539
+ let name = tok . text
506
540
guard
507
541
let index = parameterList. enumerated ( ) . first ( where: {
508
542
( _: Int , param: FunctionParameterSyntax ) in
509
543
let paramenterName = param. secondName ?? param. firstName
510
544
return paramenterName. trimmed. text == name
511
545
} ) ? . offset
512
546
else {
513
- throw DiagnosticError ( " no parameter with name ' \( name) ' in ' \( parameterList) ' " , node: ref )
547
+ throw DiagnosticError ( " no parameter with name ' \( name) ' in ' \( parameterList) ' " , node: tok )
514
548
}
515
549
return index
516
550
}
517
551
552
+ func getParameterIndexForDeclRef(
553
+ _ parameterList: FunctionParameterListSyntax , _ ref: DeclReferenceExprSyntax
554
+ ) throws -> Int {
555
+ return try getParameterIndexForParamName ( ( parameterList) , ref. baseName)
556
+ }
557
+
518
558
/// A macro that adds safe(r) wrappers for functions with unsafe pointer types.
519
559
/// Depends on bounds, escapability and lifetime information for each pointer.
520
560
/// Intended to map to C attributes like __counted_by, __ended_by and __no_escape,
0 commit comments