@@ -329,10 +329,7 @@ func transformType(
329329 let text = name. text
330330 let isRaw = isRawPointerType ( text: text)
331331 if isRaw && !isSizedBy {
332- throw DiagnosticError ( " raw pointers only supported for SizedBy " , node: name)
333- }
334- if !isRaw && isSizedBy {
335- throw DiagnosticError ( " SizedBy only supported for raw pointers " , node: name)
332+ throw DiagnosticError ( " void pointers not supported for countedBy " , node: name)
336333 }
337334
338335 guard let kind: Mutability = getPointerMutability ( text: text) else {
@@ -375,6 +372,33 @@ func isMutablePointerType(_ type: TypeSyntax) -> Bool {
375372 }
376373}
377374
375+ func getPointeeType( _ type: TypeSyntax ) -> TypeSyntax ? {
376+ if let optType = type. as ( OptionalTypeSyntax . self) {
377+ return getPointeeType ( optType. wrappedType)
378+ }
379+ if let impOptType = type. as ( ImplicitlyUnwrappedOptionalTypeSyntax . self) {
380+ return getPointeeType ( impOptType. wrappedType)
381+ }
382+ if let attrType = type. as ( AttributedTypeSyntax . self) {
383+ return getPointeeType ( attrType. baseType)
384+ }
385+
386+ guard let idType = type. as ( IdentifierTypeSyntax . self) else {
387+ return nil
388+ }
389+ let text = idType. name. text
390+ if text != " UnsafePointer " && text != " UnsafeMutablePointer " {
391+ return nil
392+ }
393+ guard let x = idType. genericArgumentClause else {
394+ return nil
395+ }
396+ guard let y = x. arguments. first else {
397+ return nil
398+ }
399+ return y. argument. as ( TypeSyntax . self)
400+ }
401+
378402protocol BoundsCheckedThunkBuilder {
379403 func buildFunctionCall( _ pointerArgs: [ Int : ExprSyntax ] ) throws -> ExprSyntax
380404 // buildBasicBoundsChecks creates a variable with the same name as the parameter it replaced,
@@ -648,6 +672,7 @@ extension PointerBoundsThunkBuilder {
648672 return try transformType ( oldType, generateSpan, isSizedBy, isParameter)
649673 }
650674 }
675+
651676 var countLabel : String {
652677 return isSizedBy && generateSpan ? " byteCount " : " count "
653678 }
@@ -826,7 +851,7 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
826851 var args = argOverrides
827852 let argExpr = ExprSyntax ( " \( unwrappedName) .baseAddress " )
828853 assert ( args [ index] == nil )
829- args [ index] = try castPointerToOpaquePointer ( unwrapIfNonnullable ( argExpr) )
854+ args [ index] = try castPointerToTargetType ( unwrapIfNonnullable ( argExpr) )
830855 let call = try base. buildFunctionCall ( args)
831856 let ptrRef = unwrapIfNullable ( " \( name) " )
832857
@@ -871,11 +896,16 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
871896 return type
872897 }
873898
874- func castPointerToOpaquePointer ( _ baseAddress: ExprSyntax ) throws -> ExprSyntax {
899+ func castPointerToTargetType ( _ baseAddress: ExprSyntax ) throws -> ExprSyntax {
875900 let type = peelOptionalType ( getParam ( signature, index) . type)
876901 if type. canRepresentBasicType ( type: OpaquePointer . self) {
877902 return ExprSyntax ( " OpaquePointer( \( baseAddress) ) " )
878903 }
904+ if isSizedBy {
905+ if let pointeeType = getPointeeType ( type) {
906+ return " \( baseAddress) .assumingMemoryBound(to: \( pointeeType) .self) "
907+ }
908+ }
879909 return baseAddress
880910 }
881911
@@ -907,7 +937,7 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
907937 return unwrappedCall
908938 }
909939
910- args [ index] = try castPointerToOpaquePointer ( getPointerArg ( ) )
940+ args [ index] = try castPointerToTargetType ( getPointerArg ( ) )
911941 return try base. buildFunctionCall ( args)
912942 }
913943}
0 commit comments