@@ -329,10 +329,7 @@ func transformType(
329
329
let text = name. text
330
330
let isRaw = isRawPointerType ( text: text)
331
331
if isRaw && !isSizedBy {
332
- throw DiagnosticError ( " raw pointers only supported for SizedBy " , node: name)
333
- }
334
- if !isRaw && isSizedBy {
335
- throw DiagnosticError ( " SizedBy only supported for raw pointers " , node: name)
332
+ throw DiagnosticError ( " void pointers not supported for countedBy " , node: name)
336
333
}
337
334
338
335
guard let kind: Mutability = getPointerMutability ( text: text) else {
@@ -375,6 +372,33 @@ func isMutablePointerType(_ type: TypeSyntax) -> Bool {
375
372
}
376
373
}
377
374
375
+ func getPointeeType( _ type: TypeSyntax ) -> TypeSyntax ? {
376
+ if let optType = type. as ( OptionalTypeSyntax . self) {
377
+ return getPointeeType ( optType. wrappedType)
378
+ }
379
+ if let impOptType = type. as ( ImplicitlyUnwrappedOptionalTypeSyntax . self) {
380
+ return getPointeeType ( impOptType. wrappedType)
381
+ }
382
+ if let attrType = type. as ( AttributedTypeSyntax . self) {
383
+ return getPointeeType ( attrType. baseType)
384
+ }
385
+
386
+ guard let idType = type. as ( IdentifierTypeSyntax . self) else {
387
+ return nil
388
+ }
389
+ let text = idType. name. text
390
+ if text != " UnsafePointer " && text != " UnsafeMutablePointer " {
391
+ return nil
392
+ }
393
+ guard let x = idType. genericArgumentClause else {
394
+ return nil
395
+ }
396
+ guard let y = x. arguments. first else {
397
+ return nil
398
+ }
399
+ return y. argument. as ( TypeSyntax . self)
400
+ }
401
+
378
402
protocol BoundsCheckedThunkBuilder {
379
403
func buildFunctionCall( _ pointerArgs: [ Int : ExprSyntax ] ) throws -> ExprSyntax
380
404
// buildBasicBoundsChecks creates a variable with the same name as the parameter it replaced,
@@ -652,6 +676,7 @@ extension PointerBoundsThunkBuilder {
652
676
return try transformType ( oldType, generateSpan, isSizedBy, isParameter)
653
677
}
654
678
}
679
+
655
680
var countLabel : String {
656
681
return isSizedBy && generateSpan ? " byteCount " : " count "
657
682
}
@@ -830,7 +855,7 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
830
855
var args = argOverrides
831
856
let argExpr = ExprSyntax ( " \( unwrappedName) .baseAddress " )
832
857
assert ( args [ index] == nil )
833
- args [ index] = try castPointerToOpaquePointer ( unwrapIfNonnullable ( argExpr) )
858
+ args [ index] = try castPointerToTargetType ( unwrapIfNonnullable ( argExpr) )
834
859
let call = try base. buildFunctionCall ( args)
835
860
let ptrRef = unwrapIfNullable ( " \( name) " )
836
861
@@ -875,11 +900,16 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
875
900
return type
876
901
}
877
902
878
- func castPointerToOpaquePointer ( _ baseAddress: ExprSyntax ) throws -> ExprSyntax {
903
+ func castPointerToTargetType ( _ baseAddress: ExprSyntax ) throws -> ExprSyntax {
879
904
let type = peelOptionalType ( getParam ( signature, index) . type)
880
905
if type. canRepresentBasicType ( type: OpaquePointer . self) {
881
906
return ExprSyntax ( " OpaquePointer( \( baseAddress) ) " )
882
907
}
908
+ if isSizedBy {
909
+ if let pointeeType = getPointeeType ( type) {
910
+ return " \( baseAddress) .assumingMemoryBound(to: \( pointeeType) .self) "
911
+ }
912
+ }
883
913
return baseAddress
884
914
}
885
915
@@ -911,7 +941,7 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
911
941
return unwrappedCall
912
942
}
913
943
914
- args [ index] = try castPointerToOpaquePointer ( getPointerArg ( ) )
944
+ args [ index] = try castPointerToTargetType ( getPointerArg ( ) )
915
945
return try base. buildFunctionCall ( args)
916
946
}
917
947
}
0 commit comments