Skip to content

Commit 59ebd61

Browse files
author
Gabor Horvath
committed
[cxx-interop] Better support SIMD types
rdar://153218744
1 parent 016e55b commit 59ebd61

File tree

7 files changed

+144
-25
lines changed

7 files changed

+144
-25
lines changed

lib/IRGen/IRABIDetailsProvider.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include "clang/CodeGen/ModuleBuilder.h"
3939
#include "clang/CodeGen/SwiftCallingConv.h"
4040
#include "llvm/IR/DerivedTypes.h"
41+
#include <optional>
4142

4243
using namespace swift;
4344
using namespace irgen;
@@ -59,13 +60,25 @@ getPrimitiveTypeFromLLVMType(ASTContext &ctx, const llvm::Type *type) {
5960
default:
6061
return std::nullopt;
6162
}
62-
} else if (type->isFloatTy()) {
63+
}
64+
if (type->isFloatTy()) {
6365
return ctx.getFloatType();
64-
} else if (type->isDoubleTy()) {
66+
}
67+
if (type->isDoubleTy()) {
6568
return ctx.getDoubleType();
66-
} else if (type->isPointerTy()) {
69+
}
70+
if (type->isPointerTy()) {
6771
return ctx.getOpaquePointerType();
6872
}
73+
if (const auto *vecTy = dyn_cast<llvm::VectorType>(type)) {
74+
auto elemTy = getPrimitiveTypeFromLLVMType(ctx, vecTy->getElementType());
75+
if (!elemTy)
76+
return std::nullopt;
77+
auto elemCount = vecTy->getElementCount();
78+
if (!elemCount.isFixed())
79+
return std::nullopt;
80+
return BuiltinVectorType::get(ctx, *elemTy, elemCount.getFixedValue());
81+
}
6982
// FIXME: Handle vector type.
7083
return std::nullopt;
7184
}

lib/PrintAsClang/PrimitiveTypeMapping.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,13 @@
1414
#include "swift/AST/ASTContext.h"
1515
#include "swift/AST/Decl.h"
1616
#include "swift/AST/Module.h"
17+
#include "swift/AST/Type.h"
18+
#include "swift/AST/Types.h"
1719
#include "swift/Basic/Assertions.h"
1820
#include "swift/ClangImporter/ClangImporter.h"
21+
#include <cctype>
22+
#include <optional>
23+
#include <string>
1924

2025
using namespace swift;
2126

@@ -102,11 +107,17 @@ void PrimitiveTypeMapping::initialize(ASTContext &ctx) {
102107
StringRef simd2##BASENAME = "swift_" #BASENAME "2"; \
103108
mappedTypeNames[{ctx.Id_simd, ctx.getIdentifier(#BASENAME "2")}] = { \
104109
simd2##BASENAME, simd2##BASENAME, simd2##BASENAME, false}; \
110+
mappedTypeNames[{ctx.Id_simd, ctx.getIdentifier("simd_" #BASENAME "2")}] = { \
111+
simd2##BASENAME, simd2##BASENAME, simd2##BASENAME, false}; \
105112
StringRef simd3##BASENAME = "swift_" #BASENAME "3"; \
106113
mappedTypeNames[{ctx.Id_simd, ctx.getIdentifier(#BASENAME "3")}] = { \
107114
simd3##BASENAME, simd3##BASENAME, simd3##BASENAME, false}; \
115+
mappedTypeNames[{ctx.Id_simd, ctx.getIdentifier("simd_" #BASENAME "3")}] = { \
116+
simd3##BASENAME, simd3##BASENAME, simd3##BASENAME, false}; \
108117
StringRef simd4##BASENAME = "swift_" #BASENAME "4"; \
109118
mappedTypeNames[{ctx.Id_simd, ctx.getIdentifier(#BASENAME "4")}] = { \
119+
simd4##BASENAME, simd4##BASENAME, simd4##BASENAME, false}; \
120+
mappedTypeNames[{ctx.Id_simd, ctx.getIdentifier("simd_" #BASENAME "4")}] = { \
110121
simd4##BASENAME, simd4##BASENAME, simd4##BASENAME, false};
111122
#include "swift/ClangImporter/SIMDMappedTypes.def"
112123
static_assert(SWIFT_MAX_IMPORTED_SIMD_ELEMENTS == 4,
@@ -151,3 +162,27 @@ PrimitiveTypeMapping::getKnownCxxTypeInfo(const TypeDecl *typeDecl) {
151162
}
152163
return std::nullopt;
153164
}
165+
166+
std::optional<PrimitiveTypeMapping::ClangTypeInfo>
167+
PrimitiveTypeMapping::getKnownSIMDTypeInfo(Type t, ASTContext &ctx) {
168+
auto vecTy = t->getAs<BuiltinVectorType>();
169+
if (!vecTy)
170+
return std::nullopt;
171+
172+
auto elemTy = vecTy->getElementType();
173+
auto numElems = vecTy->getNumElements();
174+
175+
if (mappedTypeNames.empty())
176+
initialize(ctx);
177+
178+
Identifier moduleName = ctx.Id_simd;
179+
std::string elemTyName = elemTy.getString();
180+
// While the element type starts with an upper case, vector types start with
181+
// lower case.
182+
elemTyName[0] = std::tolower(elemTyName[0]);
183+
Identifier name = ctx.getIdentifier(elemTyName + std::to_string(numElems));
184+
auto iter = mappedTypeNames.find({moduleName, name});
185+
if (iter == mappedTypeNames.end())
186+
return std::nullopt;
187+
return ClangTypeInfo{*iter->second.cName, false};
188+
}

lib/PrintAsClang/PrimitiveTypeMapping.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@ namespace swift {
2121

2222
class ASTContext;
2323
class TypeDecl;
24+
class Type;
2425

2526
/// Provides a mapping from Swift's primitive types to C / Objective-C / C++
2627
/// primitive types.
2728
///
2829
/// Certain types have mappings that differ in different language modes.
29-
/// For example, Swift's `Int` maps to `NSInteger` for Objective-C declarations,
30-
/// but to something like `intptr_t` or `swift::Int` for C and C++ declarations.
30+
/// For example, Swift's `Int` maps to `NSInteger` for Objective-C
31+
/// declarations, but to something like `intptr_t` or `swift::Int` for C and
32+
/// C++ declarations.
3133
class PrimitiveTypeMapping {
3234
public:
3335
struct ClangTypeInfo {
@@ -47,6 +49,8 @@ class PrimitiveTypeMapping {
4749
/// primitive type declaration, or \c None if no such type name exists.
4850
std::optional<ClangTypeInfo> getKnownCxxTypeInfo(const TypeDecl *typeDecl);
4951

52+
std::optional<ClangTypeInfo> getKnownSIMDTypeInfo(Type t, ASTContext &ctx);
53+
5054
private:
5155
void initialize(ASTContext &ctx);
5256

lib/PrintAsClang/PrintClangFunction.cpp

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,14 @@ getKnownTypeInfo(const TypeDecl *typeDecl, PrimitiveTypeMapping &typeMapping,
5353
}
5454

5555
bool isKnownType(Type t, PrimitiveTypeMapping &typeMapping,
56-
OutputLanguageMode languageMode) {
56+
OutputLanguageMode languageMode, ASTContext &ctx) {
5757
if (auto *typeAliasType = dyn_cast<TypeAliasType>(t.getPointer())) {
5858
auto aliasInfo =
5959
getKnownTypeInfo(typeAliasType->getDecl(), typeMapping, languageMode);
6060
if (aliasInfo != std::nullopt)
6161
return true;
6262
return isKnownType(typeAliasType->getSinglyDesugaredType(), typeMapping,
63-
languageMode);
63+
languageMode, ctx);
6464
}
6565

6666
const TypeDecl *typeDecl;
@@ -80,6 +80,8 @@ bool isKnownType(Type t, PrimitiveTypeMapping &typeMapping,
8080
isa<clang::ObjCInterfaceDecl>(
8181
classType->getClassOrBoundGenericClass()->getClangDecl());
8282
}
83+
if (t->getAs<BuiltinVectorType>())
84+
return (bool)typeMapping.getKnownSIMDTypeInfo(t, ctx);
8385

8486
if (auto *structDecl = t->getStructOrBoundGenericStruct())
8587
typeDecl = structDecl;
@@ -88,12 +90,13 @@ bool isKnownType(Type t, PrimitiveTypeMapping &typeMapping,
8890
return getKnownTypeInfo(typeDecl, typeMapping, languageMode) != std::nullopt;
8991
}
9092

91-
bool isKnownCxxType(Type t, PrimitiveTypeMapping &typeMapping) {
92-
return isKnownType(t, typeMapping, OutputLanguageMode::Cxx);
93+
bool isKnownCxxType(Type t, PrimitiveTypeMapping &typeMapping,
94+
ASTContext &ctx) {
95+
return isKnownType(t, typeMapping, OutputLanguageMode::Cxx, ctx);
9396
}
9497

95-
bool isKnownCType(Type t, PrimitiveTypeMapping &typeMapping) {
96-
return isKnownType(t, typeMapping, OutputLanguageMode::ObjC);
98+
bool isKnownCType(Type t, PrimitiveTypeMapping &typeMapping, ASTContext &ctx) {
99+
return isKnownType(t, typeMapping, OutputLanguageMode::ObjC, ctx);
97100
}
98101

99102
struct CFunctionSignatureTypePrinterModifierDelegate {
@@ -463,7 +466,7 @@ class CFunctionSignatureTypePrinter
463466
llvm::SaveAndRestore<FunctionSignatureTypeUse> typeUseNormal(
464467
typeUseKind, FunctionSignatureTypeUse::TypeReference);
465468
// FIXME: We can definitely support pointers to known Clang types.
466-
if (!isKnownCType(args.front(), typeMapping))
469+
if (!isKnownCType(args.front(), typeMapping, BGT->getASTContext()))
467470
return ClangRepresentation(ClangRepresentation::unsupported);
468471
auto partRepr = visitPart(args.front(), OTK_None, /*isInOutParam=*/false);
469472
if (partRepr.isUnsupported())
@@ -572,9 +575,11 @@ DeclAndTypeClangFunctionPrinter::printClangFunctionReturnType(
572575
static void addABIRecordToTypeEncoding(llvm::raw_ostream &typeEncodingOS,
573576
clang::CharUnits offset,
574577
clang::CharUnits end, Type t,
575-
PrimitiveTypeMapping &typeMapping) {
576-
auto info =
577-
typeMapping.getKnownCTypeInfo(t->getNominalOrBoundGenericNominal());
578+
PrimitiveTypeMapping &typeMapping,
579+
ASTContext &ctx) {
580+
auto info = typeMapping.getKnownSIMDTypeInfo(t, ctx);
581+
if (!info)
582+
info = typeMapping.getKnownCTypeInfo(t->getNominalOrBoundGenericNominal());
578583
assert(info);
579584
typeEncodingOS << '_';
580585
for (char c : info->name) {
@@ -605,7 +610,8 @@ static std::string encodeTypeInfo(const T &abiTypeInfo,
605610
ClangSyntaxPrinter(moduleContext->getASTContext(), typeEncodingOS).printBaseName(moduleContext);
606611
abiTypeInfo.enumerateRecordMembers(
607612
[&](clang::CharUnits offset, clang::CharUnits end, Type t) {
608-
addABIRecordToTypeEncoding(typeEncodingOS, offset, end, t, typeMapping);
613+
addABIRecordToTypeEncoding(typeEncodingOS, offset, end, t, typeMapping,
614+
moduleContext->getASTContext());
609615
});
610616
return std::move(typeEncodingOS.str());
611617
}
@@ -652,7 +658,8 @@ static bool printDirectReturnOrParamCType(
652658
clang::CharUnits end, Type t) {
653659
lastOffset = offset;
654660
++Count;
655-
addABIRecordToTypeEncoding(typeEncodingOS, offset, end, t, typeMapping);
661+
addABIRecordToTypeEncoding(typeEncodingOS, offset, end, t, typeMapping,
662+
emittedModule->getASTContext());
656663
}))
657664
return false;
658665
if (isResultType && Count == 0) {
@@ -664,7 +671,7 @@ static bool printDirectReturnOrParamCType(
664671
assert(Count > 0 && "missing return values");
665672

666673
// FIXME: is this "prettyfying" logic sound for multiple return values?
667-
if (isKnownCType(valueType, typeMapping) ||
674+
if (isKnownCType(valueType, typeMapping, emittedModule->getASTContext()) ||
668675
(Count == 1 && lastOffset.isZero() && !valueType->hasTypeParameter() &&
669676
(valueType->isAnyClassReferenceType() ||
670677
isOptionalObjCExistential(valueType) ||
@@ -683,7 +690,10 @@ static bool printDirectReturnOrParamCType(
683690
abiTypeInfo.enumerateRecordMembers([&](clang::CharUnits offset,
684691
clang::CharUnits end, Type t) {
685692
auto info =
686-
typeMapping.getKnownCTypeInfo(t->getNominalOrBoundGenericNominal());
693+
typeMapping.getKnownSIMDTypeInfo(t, emittedModule->getASTContext());
694+
if (!info)
695+
info =
696+
typeMapping.getKnownCTypeInfo(t->getNominalOrBoundGenericNominal());
687697
os << " " << info->name;
688698
if (info->canBeNullable)
689699
os << " _Nullable";
@@ -962,8 +972,8 @@ ClangRepresentation DeclAndTypeClangFunctionPrinter::printFunctionSignature(
962972
->isAnyClassReferenceType());
963973
if (isConst)
964974
functionSignatureOS << "const ";
965-
if (isKnownCType(param.getParamDecl().getInterfaceType(),
966-
typeMapping) ||
975+
if (isKnownCType(param.getParamDecl().getInterfaceType(), typeMapping,
976+
FD->getASTContext()) ||
967977
(!param.getParamDecl().getInterfaceType()->hasTypeParameter() &&
968978
param.getParamDecl()
969979
.getInterfaceType()
@@ -1079,7 +1089,7 @@ void DeclAndTypeClangFunctionPrinter::printCxxToCFunctionParameterUse(
10791089
Type type, StringRef name, const ModuleDecl *moduleContext, bool isInOut,
10801090
bool isIndirect, std::string directTypeEncoding, bool forceSelf) {
10811091
auto namePrinter = [&]() { ClangSyntaxPrinter(moduleContext->getASTContext(), os).printIdentifier(name); };
1082-
if (!isKnownCxxType(type, typeMapping) &&
1092+
if (!isKnownCxxType(type, typeMapping, moduleContext->getASTContext()) &&
10831093
!hasKnownOptionalNullableCxxMapping(type)) {
10841094
if (type->is<GenericTypeParamType>()) {
10851095
os << "swift::" << cxx_synthesis::getCxxImplNamespaceName()
@@ -1463,7 +1473,7 @@ void DeclAndTypeClangFunctionPrinter::printCxxThunkBody(
14631473

14641474
// Values types are returned either direcly in their C representation, or
14651475
// indirectly by a pointer.
1466-
if (!isKnownCxxType(resultTy, typeMapping) &&
1476+
if (!isKnownCxxType(resultTy, typeMapping, moduleContext->getASTContext()) &&
14671477
!hasKnownOptionalNullableCxxMapping(resultTy)) {
14681478
if (const auto *gtpt = resultTy->getAs<GenericTypeParamType>()) {
14691479
printGenericReturnSequence(os, gtpt, printCallToCFunc);

test/Interop/CxxToSwiftToCxx/simd-bridge-cxx-struct-back-to-cxx.swift

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,14 @@ public func passStruct(_ x : Struct) {
4343
}
4444

4545
// CHECK: class SWIFT_SYMBOL("s:8UseCxxTy6StructV") Struct final {
46-
// CHECK-NOT: init(
47-
// CHECK: // Unavailable in C++: Swift global function 'passStruct(_:)'
46+
47+
// CHECK: SWIFT_INLINE_THUNK void passStruct(const Struct& x) noexcept SWIFT_SYMBOL("s:8UseCxxTy10passStructyyAA0E0VF") {
48+
// CHECK-NEXT: UseCxxTy::_impl::$s8UseCxxTy10passStructyyAA0E0VF(UseCxxTy::_impl::swift_interop_passDirect_UseCxxTy_swift_float4_0_16_swift_float4_16_32_swift_float4_32_48_swift_float4_48_64(UseCxxTy::_impl::_impl_Struct::getOpaquePointer(x)));
49+
// CHECK-NEXT:}
50+
51+
// CHECK: SWIFT_INLINE_THUNK Struct Struct::init() {
52+
// CHECK-NEXT: return UseCxxTy::_impl::_impl_Struct::returnNewValue([&](char * _Nonnull result) SWIFT_INLINE_THUNK_ATTRIBUTES {
53+
// CHECK-NEXT: UseCxxTy::_impl::swift_interop_returnDirect_UseCxxTy_swift_float4_0_16_swift_float4_16_32_swift_float4_32_48_swift_float4_48_64(result, UseCxxTy::_impl::$s8UseCxxTy6StructVACycfC());
54+
// CHECK-NEXT: });
55+
// CHECK-NEXT: }
56+
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: %empty-directory(%t)
2+
3+
// RUN: %target-swift-frontend %S/swift-simd-in-cxx.swift -module-name SIMD -clang-header-expose-decls=all-public -typecheck -verify -emit-clang-header-path %t/simd.h
4+
5+
// RUN: %target-interop-build-clangxx -c %s -I %t -o %t/swift-simd-execution.o
6+
// RUN: %target-interop-build-swift %S/swift-simd-in-cxx.swift -o %t/swift-simd-execution -Xlinker %t/swift-simd-execution.o -module-name SIMD -Xfrontend -entry-point-function-name -Xfrontend swiftMain
7+
8+
// RUN: %target-codesign %t/swift-simd-execution
9+
// RUN: %target-run %t/swift-simd-execution | %FileCheck %s
10+
11+
// REQUIRES: executable_test
12+
13+
#include <assert.h>
14+
#include "simd.h"
15+
16+
int main() {
17+
swift_float3 vec{1.0, 2.0, 3.0};
18+
SIMD::swiftThingSIMD(vec);
19+
// CHECK: SIMD3<Float>(1.0, 2.0, 0.0)
20+
return 0;
21+
}
22+
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: %empty-directory(%t)
2+
3+
// RUN: %target-swift-frontend %s -module-name UseSIMD -cxx-interoperability-mode=default -clang-header-expose-decls=all-public -typecheck -verify -emit-clang-header-path %t/UseSIMD.h
4+
// RUN: %FileCheck %s < %t/UseSIMD.h
5+
6+
// RUN: %target-interop-build-clangxx -std=gnu++20 -fobjc-arc -c -x objective-c++-header %t/UseSIMD.h -o %t/o.o
7+
8+
// REQUIRES: objc_interop
9+
10+
import simd
11+
12+
public func swiftThingScalar(a: Float) {}
13+
public func swiftThingSIMD(a: simd_float3) {
14+
print(a)
15+
}
16+
17+
// CHECK: SWIFT_EXTERN void $s7UseSIMD010swiftThingB01ays5SIMD3VySfG_tF(swift_float3 a) SWIFT_NOEXCEPT SWIFT_CALL; // swiftThingSIMD(a:)
18+
// CHECK: SWIFT_EXTERN void $s7UseSIMD16swiftThingScalar1aySf_tF(float a) SWIFT_NOEXCEPT SWIFT_CALL; // swiftThingScalar(a:)
19+
20+
// CHECK: WIFT_INLINE_THUNK void swiftThingSIMD(swift_float3 a) noexcept SWIFT_SYMBOL("s:7UseSIMD010swiftThingB01ays5SIMD3VySfG_tF") {
21+
// CHECK-NEXT: UseSIMD::_impl::$s7UseSIMD010swiftThingB01ays5SIMD3VySfG_tF(a);
22+
// CHECK-NEXT: }
23+
24+
// CHECK: WIFT_INLINE_THUNK void swiftThingScalar(float a) noexcept SWIFT_SYMBOL("s:7UseSIMD16swiftThingScalar1aySf_tF") {
25+
// CHECK-NEXT: UseSIMD::_impl::$s7UseSIMD16swiftThingScalar1aySf_tF(a);
26+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)