Skip to content

Commit 00da70f

Browse files
authored
Merge pull request #84461 from susmonteiro/susmonteiro/copyable-if-annotation
[cxx-interop] Add SWIFT_COPYABLE_IF macro
2 parents e93ffb7 + f4cf914 commit 00da70f

File tree

4 files changed

+185
-86
lines changed

4 files changed

+185
-86
lines changed

include/swift/ClangImporter/ClangImporterRequests.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,7 @@ SourceLoc extractNearestSourceLoc(EscapabilityLookupDescriptor desc);
576576
// the object’s storage. This means reference types can be imported as
577577
// copyable to Swift, even when they are non-copyable in C++.
578578
enum class CxxValueSemanticsKind {
579+
Unknown,
579580
Copyable,
580581
MoveOnly,
581582
// A record that is either not copyable/movable or not destructible.

lib/ClangImporter/ClangImporter.cpp

Lines changed: 128 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -5298,16 +5298,66 @@ static const llvm::StringMap<std::vector<int>> STLConditionalParams{
52985298
{"unordered_multimap", {0, 1}},
52995299
};
53005300

5301+
template <typename Kind>
5302+
static std::optional<Kind> checkConditionalParams(
5303+
clang::RecordDecl *recordDecl, const std::vector<int> &STLParams,
5304+
std::set<StringRef> &conditionalParams,
5305+
std::function<std::optional<Kind>(clang::TemplateArgument &, StringRef)>
5306+
&checkArg) {
5307+
auto specDecl = cast<clang::ClassTemplateSpecializationDecl>(recordDecl);
5308+
SmallVector<std::pair<unsigned, StringRef>, 4> argumentsToCheck;
5309+
bool hasInjectedSTLAnnotation = !STLParams.empty();
5310+
while (specDecl) {
5311+
auto templateDecl = specDecl->getSpecializedTemplate();
5312+
if (hasInjectedSTLAnnotation) {
5313+
auto params = templateDecl->getTemplateParameters();
5314+
for (auto idx : STLParams)
5315+
argumentsToCheck.push_back(
5316+
std::make_pair(idx, params->getParam(idx)->getName()));
5317+
} else {
5318+
for (auto [idx, param] :
5319+
llvm::enumerate(*templateDecl->getTemplateParameters())) {
5320+
if (conditionalParams.erase(param->getName()))
5321+
argumentsToCheck.push_back(std::make_pair(idx, param->getName()));
5322+
}
5323+
}
5324+
auto &argList = specDecl->getTemplateArgs();
5325+
for (auto argToCheck : argumentsToCheck) {
5326+
auto arg = argList[argToCheck.first];
5327+
llvm::SmallVector<clang::TemplateArgument, 1> nonPackArgs;
5328+
if (arg.getKind() == clang::TemplateArgument::Pack) {
5329+
auto pack = arg.getPackAsArray();
5330+
nonPackArgs.assign(pack.begin(), pack.end());
5331+
} else
5332+
nonPackArgs.push_back(arg);
5333+
for (auto nonPackArg : nonPackArgs) {
5334+
auto result = checkArg(nonPackArg, argToCheck.second);
5335+
if (result.has_value())
5336+
return result.value();
5337+
}
5338+
}
5339+
if (hasInjectedSTLAnnotation)
5340+
break;
5341+
clang::DeclContext *dc = specDecl;
5342+
specDecl = nullptr;
5343+
while ((dc = dc->getParent())) {
5344+
specDecl = dyn_cast<clang::ClassTemplateSpecializationDecl>(dc);
5345+
if (specDecl)
5346+
break;
5347+
}
5348+
}
5349+
return std::nullopt;
5350+
}
5351+
53015352
static std::set<StringRef>
5302-
getConditionalEscapableAttrParams(const clang::RecordDecl *decl) {
5353+
getConditionalAttrParams(const clang::RecordDecl *decl, StringRef attrName) {
53035354
std::set<StringRef> result;
53045355
if (!decl->hasAttrs())
53055356
return result;
53065357
for (auto attr : decl->getAttrs()) {
5307-
if (auto swiftAttr = dyn_cast<clang::SwiftAttrAttr>(attr))
5308-
if (swiftAttr->getAttribute().starts_with("escapable_if:")) {
5309-
StringRef params = swiftAttr->getAttribute().drop_front(
5310-
StringRef("escapable_if:").size());
5358+
if (auto swiftAttr = dyn_cast<clang::SwiftAttrAttr>(attr)) {
5359+
StringRef params = swiftAttr->getAttribute();
5360+
if (params.consume_front(attrName)) {
53115361
auto commaPos = params.find(',');
53125362
StringRef nextParam = params.take_front(commaPos);
53135363
while (!nextParam.empty() && commaPos != StringRef::npos) {
@@ -5317,10 +5367,21 @@ getConditionalEscapableAttrParams(const clang::RecordDecl *decl) {
53175367
nextParam = params.take_front(commaPos);
53185368
}
53195369
}
5370+
}
53205371
}
53215372
return result;
53225373
}
53235374

5375+
static std::set<StringRef>
5376+
getConditionalEscapableAttrParams(const clang::RecordDecl *decl) {
5377+
return getConditionalAttrParams(decl, "escapable_if:");
5378+
}
5379+
5380+
static std::set<StringRef>
5381+
getConditionalCopyableAttrParams(const clang::RecordDecl *decl) {
5382+
return getConditionalAttrParams(decl, "copyable_if:");
5383+
}
5384+
53245385
CxxEscapability
53255386
ClangTypeEscapability::evaluate(Evaluator &evaluator,
53265387
EscapabilityLookupDescriptor desc) const {
@@ -5346,60 +5407,33 @@ ClangTypeEscapability::evaluate(Evaluator &evaluator,
53465407
recordDecl->isInStdNamespace()
53475408
? STLConditionalParams.find(recordDecl->getName())
53485409
: STLConditionalParams.end();
5349-
bool hasInjectedSTLAnnotation =
5350-
injectedStlAnnotation != STLConditionalParams.end();
5410+
auto STLParams = injectedStlAnnotation != STLConditionalParams.end()
5411+
? injectedStlAnnotation->second
5412+
: std::vector<int>();
53515413
auto conditionalParams = getConditionalEscapableAttrParams(recordDecl);
5352-
if (!conditionalParams.empty() || hasInjectedSTLAnnotation) {
5353-
auto specDecl = cast<clang::ClassTemplateSpecializationDecl>(recordDecl);
5354-
SmallVector<std::pair<unsigned, StringRef>, 4> argumentsToCheck;
5414+
5415+
if (!STLParams.empty() || !conditionalParams.empty()) {
53555416
HeaderLoc loc{recordDecl->getLocation()};
5356-
while (specDecl) {
5357-
auto templateDecl = specDecl->getSpecializedTemplate();
5358-
if (hasInjectedSTLAnnotation) {
5359-
auto params = templateDecl->getTemplateParameters();
5360-
for (auto idx : injectedStlAnnotation->second)
5361-
argumentsToCheck.push_back(
5362-
std::make_pair(idx, params->getParam(idx)->getName()));
5363-
} else {
5364-
for (auto [idx, param] :
5365-
llvm::enumerate(*templateDecl->getTemplateParameters())) {
5366-
if (conditionalParams.erase(param->getName()))
5367-
argumentsToCheck.push_back(std::make_pair(idx, param->getName()));
5368-
}
5417+
std::function checkArgEscapability =
5418+
[&](clang::TemplateArgument &arg,
5419+
StringRef argToCheck) -> std::optional<CxxEscapability> {
5420+
if (arg.getKind() != clang::TemplateArgument::Type && desc.impl) {
5421+
desc.impl->diagnose(loc, diag::type_template_parameter_expected,
5422+
argToCheck);
5423+
return CxxEscapability::Unknown;
53695424
}
5370-
auto &argList = specDecl->getTemplateArgs();
5371-
for (auto argToCheck : argumentsToCheck) {
5372-
auto arg = argList[argToCheck.first];
5373-
llvm::SmallVector<clang::TemplateArgument, 1> nonPackArgs;
5374-
if (arg.getKind() == clang::TemplateArgument::Pack) {
5375-
auto pack = arg.getPackAsArray();
5376-
nonPackArgs.assign(pack.begin(), pack.end());
5377-
} else
5378-
nonPackArgs.push_back(arg);
5379-
for (auto nonPackArg : nonPackArgs) {
5380-
if (nonPackArg.getKind() != clang::TemplateArgument::Type &&
5381-
desc.impl) {
5382-
desc.impl->diagnose(loc, diag::type_template_parameter_expected,
5383-
argToCheck.second);
5384-
return CxxEscapability::Unknown;
5385-
}
53865425

5387-
auto argEscapability = evaluateEscapability(
5388-
nonPackArg.getAsType()->getUnqualifiedDesugaredType());
5389-
if (argEscapability == CxxEscapability::NonEscapable)
5390-
return CxxEscapability::NonEscapable;
5391-
}
5392-
}
5393-
if (hasInjectedSTLAnnotation)
5394-
break;
5395-
clang::DeclContext *dc = specDecl;
5396-
specDecl = nullptr;
5397-
while ((dc = dc->getParent())) {
5398-
specDecl = dyn_cast<clang::ClassTemplateSpecializationDecl>(dc);
5399-
if (specDecl)
5400-
break;
5401-
}
5402-
}
5426+
auto argEscapability = evaluateEscapability(
5427+
arg.getAsType()->getUnqualifiedDesugaredType());
5428+
if (argEscapability == CxxEscapability::NonEscapable)
5429+
return CxxEscapability::NonEscapable;
5430+
return std::nullopt;
5431+
};
5432+
5433+
auto result = checkConditionalParams<CxxEscapability>(
5434+
recordDecl, STLParams, conditionalParams, checkArgEscapability);
5435+
if (result.has_value())
5436+
return result.value();
54035437

54045438
if (desc.impl)
54055439
for (auto name : conditionalParams)
@@ -8338,36 +8372,48 @@ CxxValueSemantics::evaluate(Evaluator &evaluator,
83388372
if (recordDecl->getIdentifier() &&
83398373
recordDecl->getName() == "_Optional_construct_base")
83408374
return CxxValueSemanticsKind::Copyable;
8375+
}
83418376

8342-
auto injectedStlAnnotation =
8343-
STLConditionalParams.find(recordDecl->getName());
8344-
8345-
if (injectedStlAnnotation != STLConditionalParams.end()) {
8346-
auto specDecl = cast<clang::ClassTemplateSpecializationDecl>(recordDecl);
8347-
auto &argList = specDecl->getTemplateArgs();
8348-
for (auto argToCheck : injectedStlAnnotation->second) {
8349-
auto arg = argList[argToCheck];
8350-
llvm::SmallVector<clang::TemplateArgument, 1> nonPackArgs;
8351-
if (arg.getKind() == clang::TemplateArgument::Pack) {
8352-
auto pack = arg.getPackAsArray();
8353-
nonPackArgs.assign(pack.begin(), pack.end());
8354-
} else
8355-
nonPackArgs.push_back(arg);
8356-
for (auto nonPackArg : nonPackArgs) {
8357-
8358-
auto argValueSemantics = evaluateOrDefault(
8359-
evaluator,
8360-
CxxValueSemantics(
8361-
{nonPackArg.getAsType()->getUnqualifiedDesugaredType(),
8362-
desc.importerImpl}),
8363-
{});
8364-
if (argValueSemantics != CxxValueSemanticsKind::Copyable)
8365-
return argValueSemantics;
8366-
}
8377+
auto injectedStlAnnotation =
8378+
recordDecl->isInStdNamespace()
8379+
? STLConditionalParams.find(recordDecl->getName())
8380+
: STLConditionalParams.end();
8381+
auto STLParams = injectedStlAnnotation != STLConditionalParams.end()
8382+
? injectedStlAnnotation->second
8383+
: std::vector<int>();
8384+
auto conditionalParams = getConditionalCopyableAttrParams(recordDecl);
8385+
8386+
if (!STLParams.empty() || !conditionalParams.empty()) {
8387+
HeaderLoc loc{recordDecl->getLocation()};
8388+
std::function checkArgValueSemantics =
8389+
[&](clang::TemplateArgument &arg,
8390+
StringRef argToCheck) -> std::optional<CxxValueSemanticsKind> {
8391+
if (arg.getKind() != clang::TemplateArgument::Type && importerImpl) {
8392+
importerImpl->diagnose(loc, diag::type_template_parameter_expected,
8393+
argToCheck);
8394+
return CxxValueSemanticsKind::Unknown;
83678395
}
83688396

8369-
return CxxValueSemanticsKind::Copyable;
8370-
}
8397+
auto argValueSemantics = evaluateOrDefault(
8398+
evaluator,
8399+
CxxValueSemantics(
8400+
{arg.getAsType()->getUnqualifiedDesugaredType(), importerImpl}),
8401+
{});
8402+
if (argValueSemantics != CxxValueSemanticsKind::Copyable)
8403+
return argValueSemantics;
8404+
return std::nullopt;
8405+
};
8406+
8407+
auto result = checkConditionalParams<CxxValueSemanticsKind>(
8408+
recordDecl, STLParams, conditionalParams, checkArgValueSemantics);
8409+
if (result.has_value())
8410+
return result.value();
8411+
8412+
if (importerImpl)
8413+
for (auto name : conditionalParams)
8414+
importerImpl->diagnose(loc, diag::unknown_template_parameter, name);
8415+
8416+
return CxxValueSemanticsKind::Copyable;
83718417
}
83728418

83738419
const auto cxxRecordDecl = dyn_cast<clang::CXXRecordDecl>(recordDecl);

lib/ClangImporter/SwiftBridging/swift/bridging

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@
190190
__attribute__((swift_attr("~Copyable"))) \
191191
__attribute__((swift_attr(_CXX_INTEROP_STRINGIFY(destroy:_destroy))))
192192

193+
/// Specifies that a C++ `class` or `struct` should be imported as a copyable
194+
/// Swift value if all of the specified template arguments are copyable.
195+
#define SWIFT_COPYABLE_IF(...) \
196+
__attribute__((swift_attr("copyable_if:" _CXX_INTEROP_CONCAT(__VA_ARGS__))))
197+
193198
/// Specifies that a specific class or struct should be imported
194199
/// as a non-escapable Swift value type.
195200
#define SWIFT_NONESCAPABLE \
@@ -283,6 +288,7 @@
283288
#define SWIFT_UNCHECKED_SENDABLE
284289
#define SWIFT_NONCOPYABLE
285290
#define SWIDT_NONCOPYABLE_WITH_DESTROY(_destroy)
291+
#define SWIFT_COPYABLE_IF(...)
286292
#define SWIFT_NONESCAPABLE
287293
#define SWIFT_ESCAPABLE
288294
#define SWIFT_ESCAPABLE_IF(...)

test/Interop/Cxx/class/noncopyable-typechecker.swift

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// RUN: %empty-directory(%t)
22
// RUN: split-file %s %t
3-
// RUN: %target-swift-frontend -cxx-interoperability-mode=default -typecheck -verify -I %t/Inputs %t/test.swift
4-
// RUN: %target-swift-frontend -cxx-interoperability-mode=default -Xcc -std=c++20 -verify-additional-prefix cpp20- -D CPP20 -typecheck -verify -I %t/Inputs %t/test.swift
3+
// RUN: %target-swift-frontend -cxx-interoperability-mode=default -typecheck -verify -I %swift_src_root/lib/ClangImporter/SwiftBridging -I %t/Inputs %t/test.swift
4+
// RUN: %target-swift-frontend -cxx-interoperability-mode=default -Xcc -std=c++20 -verify-additional-prefix cpp20- -D CPP20 -typecheck -verify -I %swift_src_root/lib/ClangImporter/SwiftBridging -I %t/Inputs %t/test.swift
55

66
//--- Inputs/module.modulemap
77
module Test {
@@ -10,6 +10,7 @@ module Test {
1010
}
1111

1212
//--- Inputs/noncopyable.h
13+
#include "swift/bridging"
1314
#include <string>
1415

1516
struct NonCopyable {
@@ -28,6 +29,29 @@ struct OwnsT {
2829

2930
using OwnsNonCopyable = OwnsT<NonCopyable>;
3031

32+
template <typename T>
33+
struct SWIFT_COPYABLE_IF(T) AnnotatedOwnsT {
34+
T element;
35+
AnnotatedOwnsT() {}
36+
AnnotatedOwnsT(const AnnotatedOwnsT &other) : element(other.element) {}
37+
AnnotatedOwnsT(AnnotatedOwnsT&& other) {}
38+
};
39+
40+
using AnnotatedOwnsNonCopyable = AnnotatedOwnsT<NonCopyable>;
41+
42+
template <typename F, typename S>
43+
struct SWIFT_COPYABLE_IF(F, S) MyPair {
44+
F first;
45+
S second;
46+
};
47+
48+
MyPair<int, NonCopyable> p1();
49+
MyPair<int, NonCopyable*> p2();
50+
MyPair<int, OwnsNonCopyable> p3();
51+
MyPair<int, AnnotatedOwnsNonCopyable> p4();
52+
MyPair<int, MyPair<int, NonCopyable>> p5();
53+
MyPair<NonCopyable, int> p6();
54+
3155
#if __cplusplus >= 202002L
3256
template <typename T>
3357
struct RequiresCopyableT {
@@ -38,6 +62,9 @@ struct RequiresCopyableT {
3862
};
3963

4064
using NonCopyableRequires = RequiresCopyableT<NonCopyable>;
65+
using CopyableIfRequires = RequiresCopyableT<MyPair<int, NonCopyable>>;
66+
67+
MyPair<int, NonCopyableRequires> p7();
4168

4269
#endif
4370

@@ -55,9 +82,28 @@ func userDefinedTypes() {
5582
takeCopyable(ownsT) // no error, OwnsNonCopyable imported as Copyable
5683
}
5784

85+
func useCopyableIf() {
86+
takeCopyable(p1()) // expected-error {{global function 'takeCopyable' requires that 'MyPair<CInt, NonCopyable>' conform to 'Copyable'}}
87+
takeCopyable(p2())
88+
89+
// p3() -> MyPair<int, OwnsNonCopyable> is imported as Copyable and will cause an error during IRGen.
90+
// During typecheck we don't produce an error because we're missing an annotation in OwnsT.
91+
takeCopyable(p3())
92+
// p4() -> (MyPair<int, AnnotatedOwnsNonCopyable>) is imported as NonCopyable because AnnotatedOwnsT is correctly annotated.
93+
takeCopyable(p4()) // expected-error {{global function 'takeCopyable' requires that 'MyPair<CInt, AnnotatedOwnsT<NonCopyable>>' conform to 'Copyable'}}
94+
95+
takeCopyable(p5()) // expected-error {{global function 'takeCopyable' requires that 'MyPair<CInt, MyPair<CInt, NonCopyable>>' conform to 'Copyable'}}
96+
takeCopyable(p6()) // expected-error {{global function 'takeCopyable' requires that 'MyPair<NonCopyable, CInt>' conform to 'Copyable'}}
97+
}
98+
5899
#if CPP20
59100
func useOfRequires() {
60-
let nCop = NonCopyableRequires()
61-
takeCopyable(nCop) // expected-cpp20-error {{global function 'takeCopyable' requires that 'NonCopyableRequires' (aka 'RequiresCopyableT<NonCopyable>') conform to 'Copyable'}}
101+
let a = NonCopyableRequires()
102+
takeCopyable(a) // expected-cpp20-error {{global function 'takeCopyable' requires that 'NonCopyableRequires' (aka 'RequiresCopyableT<NonCopyable>') conform to 'Copyable'}}
103+
104+
let b = CopyableIfRequires()
105+
takeCopyable(b) // expected-cpp20-error {{global function 'takeCopyable' requires that 'CopyableIfRequires' (aka 'RequiresCopyableT<MyPair<CInt, NonCopyable>>') conform to 'Copyable'}}
106+
107+
takeCopyable(p7()) // expected-cpp20-error {{global function 'takeCopyable' requires that 'MyPair<CInt, RequiresCopyableT<NonCopyable>>' conform to 'Copyable'}}
62108
}
63109
#endif

0 commit comments

Comments
 (0)