Skip to content

Commit e1e00ed

Browse files
authored
Merge pull request #60001 from apple/egorzhdan/synthesize-iterator-conformance
[cxx-interop] Synthesize conformances to `UnsafeCxxInputIterator`
2 parents d1220fd + d85d2e9 commit e1e00ed

16 files changed

+514
-40
lines changed

include/swift/AST/KnownIdentifiers.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ IDENTIFIER(CGFloat)
5959
IDENTIFIER(CoreFoundation)
6060
IDENTIFIER(count)
6161
IDENTIFIER(CVarArg)
62+
IDENTIFIER(Cxx)
6263
IDENTIFIER(Darwin)
6364
IDENTIFIER(Distributed)
6465
IDENTIFIER(dealloc)

include/swift/AST/KnownProtocols.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ PROTOCOL(DistributedTargetInvocationEncoder)
104104
PROTOCOL(DistributedTargetInvocationDecoder)
105105
PROTOCOL(DistributedTargetInvocationResultHandler)
106106

107+
// C++ Standard Library Overlay:
108+
PROTOCOL(UnsafeCxxInputIterator)
109+
107110
PROTOCOL(AsyncSequence)
108111
PROTOCOL(AsyncIteratorProtocol)
109112

lib/AST/ASTContext.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,6 +1044,9 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
10441044
case KnownProtocolKind::DistributedTargetInvocationResultHandler:
10451045
M = getLoadedModule(Id_Distributed);
10461046
break;
1047+
case KnownProtocolKind::UnsafeCxxInputIterator:
1048+
M = getLoadedModule(Id_Cxx);
1049+
break;
10471050
default:
10481051
M = getStdlibModule();
10491052
break;

lib/ClangImporter/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_gyb_target(generated_sorted_cf_database
88
add_swift_host_library(swiftClangImporter STATIC
99
CFTypeInfo.cpp
1010
ClangAdapter.cpp
11+
ClangDerivedConformances.cpp
1112
ClangDiagnosticConsumer.cpp
1213
ClangImporter.cpp
1314
ClangImporterRequests.cpp
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
//===--- ClangDerivedConformances.cpp -------------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2022 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "ClangDerivedConformances.h"
14+
#include "swift/AST/NameLookup.h"
15+
#include "swift/AST/ParameterList.h"
16+
17+
using namespace swift;
18+
19+
static clang::TypeDecl *
20+
getIteratorCategoryDecl(const clang::CXXRecordDecl *clangDecl) {
21+
clang::IdentifierInfo *iteratorCategoryDeclName =
22+
&clangDecl->getASTContext().Idents.get("iterator_category");
23+
auto iteratorCategories = clangDecl->lookup(iteratorCategoryDeclName);
24+
if (!iteratorCategories.isSingleResult())
25+
return nullptr;
26+
auto iteratorCategory = iteratorCategories.front();
27+
28+
return dyn_cast_or_null<clang::TypeDecl>(iteratorCategory);
29+
}
30+
31+
bool swift::isIterator(const clang::CXXRecordDecl *clangDecl) {
32+
return getIteratorCategoryDecl(clangDecl);
33+
}
34+
35+
void swift::conformToCxxIteratorIfNeeded(
36+
ClangImporter::Implementation &impl, NominalTypeDecl *decl,
37+
const clang::CXXRecordDecl *clangDecl) {
38+
assert(decl);
39+
assert(clangDecl);
40+
ASTContext &ctx = decl->getASTContext();
41+
42+
// We consider a type to be an input iterator if it defines an
43+
// `iterator_category` that inherits from `std::input_iterator_tag`, e.g.
44+
// `using iterator_category = std::input_iterator_tag`.
45+
auto iteratorCategory = getIteratorCategoryDecl(clangDecl);
46+
if (!iteratorCategory)
47+
return;
48+
49+
// If `iterator_category` is a typedef or a using-decl, retrieve the
50+
// underlying struct decl.
51+
clang::CXXRecordDecl *underlyingCategoryDecl = nullptr;
52+
if (auto typedefDecl = dyn_cast<clang::TypedefNameDecl>(iteratorCategory)) {
53+
auto type = typedefDecl->getUnderlyingType();
54+
underlyingCategoryDecl = type->getAsCXXRecordDecl();
55+
} else {
56+
underlyingCategoryDecl = dyn_cast<clang::CXXRecordDecl>(iteratorCategory);
57+
}
58+
if (underlyingCategoryDecl) {
59+
underlyingCategoryDecl = underlyingCategoryDecl->getDefinition();
60+
}
61+
62+
if (!underlyingCategoryDecl)
63+
return;
64+
65+
auto isInputIteratorDecl = [&](const clang::CXXRecordDecl *base) {
66+
return base->isInStdNamespace() && base->getIdentifier() &&
67+
base->getName() == "input_iterator_tag";
68+
};
69+
70+
// Traverse all transitive bases of `underlyingDecl` to check if
71+
// it inherits from `std::input_iterator_tag`.
72+
bool isInputIterator = isInputIteratorDecl(underlyingCategoryDecl);
73+
underlyingCategoryDecl->forallBases([&](const clang::CXXRecordDecl *base) {
74+
if (isInputIteratorDecl(base)) {
75+
isInputIterator = true;
76+
return false;
77+
}
78+
return true;
79+
});
80+
81+
if (!isInputIterator)
82+
return;
83+
84+
// Check if present: `var pointee: Pointee { get }`
85+
auto pointeeId = ctx.getIdentifier("pointee");
86+
auto pointees = decl->lookupDirect(pointeeId);
87+
if (pointees.size() != 1)
88+
return;
89+
auto pointee = dyn_cast<VarDecl>(pointees.front());
90+
if (!pointee || pointee->isGetterMutating())
91+
return;
92+
93+
// Check if present: `func successor() -> Self`
94+
auto successorId = ctx.getIdentifier("successor");
95+
auto successors = decl->lookupDirect(successorId);
96+
if (successors.size() != 1)
97+
return;
98+
auto successor = dyn_cast<FuncDecl>(successors.front());
99+
if (!successor || successor->isMutating())
100+
return;
101+
auto successorTy = successor->getResultInterfaceType();
102+
if (!successorTy || successorTy->getAnyNominal() != decl)
103+
return;
104+
105+
// Check if present: `func ==`
106+
// FIXME: this only detects `operator==` declared as a member.
107+
auto equalEquals = decl->lookupDirect(ctx.Id_EqualsOperator);
108+
if (equalEquals.empty())
109+
return;
110+
auto equalEqual = dyn_cast<FuncDecl>(equalEquals.front());
111+
if (!equalEqual || !equalEqual->hasParameterList())
112+
return;
113+
auto equalEqualParams = equalEqual->getParameters();
114+
if (equalEqualParams->size() != 2)
115+
return;
116+
auto equalEqualLHS = equalEqualParams->get(0);
117+
auto equalEqualRHS = equalEqualParams->get(1);
118+
if (equalEqualLHS->isInOut() || equalEqualRHS->isInOut())
119+
return;
120+
auto equalEqualLHSTy = equalEqualLHS->getType();
121+
auto equalEqualRHSTy = equalEqualRHS->getType();
122+
if (!equalEqualLHSTy || !equalEqualRHSTy ||
123+
equalEqualLHSTy->getAnyNominal() != equalEqualRHSTy->getAnyNominal())
124+
return;
125+
126+
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Pointee"),
127+
pointee->getType());
128+
impl.addSynthesizedProtocolAttrs(decl,
129+
{KnownProtocolKind::UnsafeCxxInputIterator});
130+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===--- ClangDerivedConformances.h -----------------------------*- C++ -*-===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2022 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef SWIFT_CLANG_DERIVED_CONFORMANCES_H
14+
#define SWIFT_CLANG_DERIVED_CONFORMANCES_H
15+
16+
#include "ImporterImpl.h"
17+
#include "swift/AST/ASTContext.h"
18+
19+
namespace swift {
20+
21+
bool isIterator(const clang::CXXRecordDecl *clangDecl);
22+
23+
/// If the decl is a C++ input iterator, synthesize a conformance to the
24+
/// UnsafeCxxInputIterator protocol, which is defined in the std overlay.
25+
void conformToCxxIteratorIfNeeded(ClangImporter::Implementation &impl,
26+
NominalTypeDecl *decl,
27+
const clang::CXXRecordDecl *clangDecl);
28+
29+
} // namespace swift
30+
31+
#endif // SWIFT_CLANG_DERIVED_CONFORMANCES_H

lib/ClangImporter/ClangImporter.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
//
1515
//===----------------------------------------------------------------------===//
1616
#include "swift/ClangImporter/ClangImporter.h"
17+
#include "ClangDerivedConformances.h"
1718
#include "ClangDiagnosticConsumer.h"
1819
#include "ClangIncludePaths.h"
1920
#include "ImporterImpl.h"
@@ -3232,6 +3233,19 @@ static void getImportDecls(ClangModuleUnit *ClangUnit, const clang::Module *M,
32323233
auto *ID = createImportDecl(Ctx, ClangUnit, ImportedMod, Exported);
32333234
Results.push_back(ID);
32343235
}
3236+
3237+
if (Ctx.LangOpts.EnableCXXInterop && requiresCPlusPlus(M)) {
3238+
// Try to load the Cxx module. We don't use it directly here, but we need to
3239+
// make sure that ASTContext has loaded this module.
3240+
auto *cxxModule = Ctx.getModuleByIdentifier(Ctx.Id_Cxx);
3241+
if (cxxModule) {
3242+
ImportPath::Builder builder(Ctx.Id_Cxx);
3243+
auto *importCxx =
3244+
ImportDecl::create(Ctx, ClangUnit, SourceLoc(), ImportKind::Module,
3245+
SourceLoc(), builder.get());
3246+
Results.push_back(importCxx);
3247+
}
3248+
}
32353249
}
32363250

32373251
void ClangModuleUnit::getDisplayDecls(SmallVectorImpl<Decl*> &results, bool recursive) const {
@@ -5962,7 +5976,7 @@ CxxRecordSemantics::evaluate(Evaluator &evaluator,
59625976
return CxxRecordSemanticsKind::Owned;
59635977
}
59645978

5965-
if (hasIteratorAPIAttr(decl)) {
5979+
if (hasIteratorAPIAttr(decl) || isIterator(decl)) {
59665980
return CxxRecordSemanticsKind::Iterator;
59675981
}
59685982

lib/ClangImporter/ImportDecl.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "CFTypeInfo.h"
1818
#include "ImporterImpl.h"
19+
#include "ClangDerivedConformances.h"
1920
#include "SwiftDeclSynthesizer.h"
2021
#include "swift/AST/ASTContext.h"
2122
#include "swift/AST/Attr.h"
@@ -2499,7 +2500,18 @@ namespace {
24992500
return nullptr;
25002501
}
25012502

2502-
return VisitRecordDecl(decl);
2503+
auto result = VisitRecordDecl(decl);
2504+
2505+
// If this module is declared as a C++ module, try to synthesize
2506+
// conformances to Swift protocols from the Cxx module.
2507+
auto clangModule = decl->getOwningModule();
2508+
if (clangModule && requiresCPlusPlus(clangModule)) {
2509+
if (auto structDecl = dyn_cast_or_null<NominalTypeDecl>(result)) {
2510+
conformToCxxIteratorIfNeeded(Impl, structDecl, decl);
2511+
}
2512+
}
2513+
2514+
return result;
25032515
}
25042516

25052517
bool isSpecializationDepthGreaterThan(

lib/ClangImporter/ImporterImpl.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1861,6 +1861,22 @@ inline Optional<const clang::EnumDecl *> findAnonymousEnumForTypedef(
18611861
return None;
18621862
}
18631863

1864+
inline bool requiresCPlusPlus(const clang::Module *module) {
1865+
// The libc++ modulemap doesn't currently declare the requirement.
1866+
if (module->getTopLevelModuleName() == "std")
1867+
return true;
1868+
1869+
// Modulemaps often declare the requirement for the top-level module only.
1870+
if (auto parent = module->Parent) {
1871+
if (requiresCPlusPlus(parent))
1872+
return true;
1873+
}
1874+
1875+
return llvm::any_of(module->Requirements, [](clang::Module::Requirement req) {
1876+
return req.first == "cplusplus";
1877+
});
1878+
}
1879+
18641880
}
18651881
}
18661882

lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5778,6 +5778,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
57785778
case KnownProtocolKind::DistributedTargetInvocationEncoder:
57795779
case KnownProtocolKind::DistributedTargetInvocationDecoder:
57805780
case KnownProtocolKind::DistributedTargetInvocationResultHandler:
5781+
case KnownProtocolKind::UnsafeCxxInputIterator:
57815782
case KnownProtocolKind::SerialExecutor:
57825783
case KnownProtocolKind::Sendable:
57835784
case KnownProtocolKind::UnsafeSendable:

0 commit comments

Comments
 (0)