Skip to content

Commit b9cf9ad

Browse files
authored
Merge pull request swiftlang#63497 from apple/egorzhdan/cxx-dictionary
[cxx-interop] Add `CxxDictionary` protocol for `std::map` ergonomics
2 parents 3b4d03c + 919eea7 commit b9cf9ad

File tree

11 files changed

+195
-43
lines changed

11 files changed

+195
-43
lines changed

include/swift/AST/KnownProtocols.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ PROTOCOL(DistributedTargetInvocationResultHandler)
106106

107107
// C++ Standard Library Overlay:
108108
PROTOCOL(CxxConvertibleToCollection)
109+
PROTOCOL(CxxDictionary)
110+
PROTOCOL(CxxPair)
109111
PROTOCOL(CxxSet)
110112
PROTOCOL(CxxRandomAccessCollection)
111113
PROTOCOL(CxxSequence)

lib/AST/ASTContext.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,6 +1125,8 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
11251125
M = getLoadedModule(Id_Distributed);
11261126
break;
11271127
case KnownProtocolKind::CxxConvertibleToCollection:
1128+
case KnownProtocolKind::CxxDictionary:
1129+
case KnownProtocolKind::CxxPair:
11281130
case KnownProtocolKind::CxxRandomAccessCollection:
11291131
case KnownProtocolKind::CxxSet:
11301132
case KnownProtocolKind::CxxSequence:

lib/ClangImporter/ClangDerivedConformances.cpp

Lines changed: 99 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,15 @@ lookupDirectWithoutExtensions(NominalTypeDecl *decl, Identifier id) {
6161
return result;
6262
}
6363

64+
template <typename Decl>
65+
static Decl *lookupDirectSingleWithoutExtensions(NominalTypeDecl *decl,
66+
Identifier id) {
67+
auto results = lookupDirectWithoutExtensions(decl, id);
68+
if (results.size() != 1)
69+
return nullptr;
70+
return dyn_cast<Decl>(results.front());
71+
}
72+
6473
/// Similar to ModuleDecl::conformsToProtocol, but doesn't introduce a
6574
/// dependency on Sema.
6675
static bool isConcreteAndValid(ProtocolConformanceRef conformanceRef,
@@ -315,19 +324,14 @@ void swift::conformToCxxIteratorIfNeeded(
315324

316325
// Check if present: `var pointee: Pointee { get }`
317326
auto pointeeId = ctx.getIdentifier("pointee");
318-
auto pointees = lookupDirectWithoutExtensions(decl, pointeeId);
319-
if (pointees.size() != 1)
320-
return;
321-
auto pointee = dyn_cast<VarDecl>(pointees.front());
327+
auto pointee = lookupDirectSingleWithoutExtensions<VarDecl>(decl, pointeeId);
322328
if (!pointee || pointee->isGetterMutating() || pointee->getType()->hasError())
323329
return;
324330

325331
// Check if present: `func successor() -> Self`
326332
auto successorId = ctx.getIdentifier("successor");
327-
auto successors = lookupDirectWithoutExtensions(decl, successorId);
328-
if (successors.size() != 1)
329-
return;
330-
auto successor = dyn_cast<FuncDecl>(successors.front());
333+
auto successor =
334+
lookupDirectSingleWithoutExtensions<FuncDecl>(decl, successorId);
331335
if (!successor || successor->isMutating())
332336
return;
333337
auto successorTy = successor->getResultInterfaceType();
@@ -398,20 +402,14 @@ void swift::conformToCxxSequenceIfNeeded(
398402

399403
// Check if present: `func __beginUnsafe() -> RawIterator`
400404
auto beginId = ctx.getIdentifier("__beginUnsafe");
401-
auto begins = lookupDirectWithoutExtensions(decl, beginId);
402-
if (begins.size() != 1)
403-
return;
404-
auto begin = dyn_cast<FuncDecl>(begins.front());
405+
auto begin = lookupDirectSingleWithoutExtensions<FuncDecl>(decl, beginId);
405406
if (!begin)
406407
return;
407408
auto rawIteratorTy = begin->getResultInterfaceType();
408409

409410
// Check if present: `func __endUnsafe() -> RawIterator`
410411
auto endId = ctx.getIdentifier("__endUnsafe");
411-
auto ends = lookupDirectWithoutExtensions(decl, endId);
412-
if (ends.size() != 1)
413-
return;
414-
auto end = dyn_cast<FuncDecl>(ends.front());
412+
auto end = lookupDirectSingleWithoutExtensions<FuncDecl>(decl, endId);
415413
if (!end)
416414
return;
417415

@@ -524,6 +522,16 @@ void swift::conformToCxxSequenceIfNeeded(
524522
}
525523
}
526524

525+
static bool isStdDecl(const clang::CXXRecordDecl *clangDecl,
526+
llvm::ArrayRef<StringRef> names) {
527+
if (!clangDecl->isInStdNamespace())
528+
return false;
529+
if (!clangDecl->getIdentifier())
530+
return false;
531+
StringRef name = clangDecl->getName();
532+
return llvm::is_contained(names, name);
533+
}
534+
527535
void swift::conformToCxxSetIfNeeded(ClangImporter::Implementation &impl,
528536
NominalTypeDecl *decl,
529537
const clang::CXXRecordDecl *clangDecl) {
@@ -535,29 +543,88 @@ void swift::conformToCxxSetIfNeeded(ClangImporter::Implementation &impl,
535543

536544
// Only auto-conform types from the C++ standard library. Custom user types
537545
// might have a similar interface but different semantics.
538-
if (!clangDecl->isInStdNamespace())
539-
return;
540-
if (!clangDecl->getIdentifier())
541-
return;
542-
StringRef name = clangDecl->getName();
543-
if (name != "set" && name != "unordered_set" && name != "multiset")
544-
return;
545-
546-
auto valueTypeId = ctx.getIdentifier("value_type");
547-
auto valueTypes = lookupDirectWithoutExtensions(decl, valueTypeId);
548-
if (valueTypes.size() != 1)
546+
if (!isStdDecl(clangDecl, {"set", "unordered_set", "multiset"}))
549547
return;
550-
auto valueType = dyn_cast<TypeAliasDecl>(valueTypes.front());
551548

552-
auto sizeTypeId = ctx.getIdentifier("size_type");
553-
auto sizeTypes = lookupDirectWithoutExtensions(decl, sizeTypeId);
554-
if (sizeTypes.size() != 1)
549+
auto valueType = lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
550+
decl, ctx.getIdentifier("value_type"));
551+
auto sizeType = lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
552+
decl, ctx.getIdentifier("size_type"));
553+
if (!valueType || !sizeType)
555554
return;
556-
auto sizeType = dyn_cast<TypeAliasDecl>(sizeTypes.front());
557555

558556
impl.addSynthesizedTypealias(decl, ctx.Id_Element,
559557
valueType->getUnderlyingType());
560558
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Size"),
561559
sizeType->getUnderlyingType());
562560
impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxSet});
563561
}
562+
563+
void swift::conformToCxxPairIfNeeded(ClangImporter::Implementation &impl,
564+
NominalTypeDecl *decl,
565+
const clang::CXXRecordDecl *clangDecl) {
566+
PrettyStackTraceDecl trace("conforming to CxxPair", decl);
567+
568+
assert(decl);
569+
assert(clangDecl);
570+
ASTContext &ctx = decl->getASTContext();
571+
572+
// Only auto-conform types from the C++ standard library. Custom user types
573+
// might have a similar interface but different semantics.
574+
if (!isStdDecl(clangDecl, {"pair"}))
575+
return;
576+
577+
auto firstType = lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
578+
decl, ctx.getIdentifier("first_type"));
579+
auto secondType = lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
580+
decl, ctx.getIdentifier("second_type"));
581+
if (!firstType || !secondType)
582+
return;
583+
584+
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("First"),
585+
firstType->getUnderlyingType());
586+
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Second"),
587+
secondType->getUnderlyingType());
588+
impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxPair});
589+
}
590+
591+
void swift::conformToCxxDictionaryIfNeeded(
592+
ClangImporter::Implementation &impl, NominalTypeDecl *decl,
593+
const clang::CXXRecordDecl *clangDecl) {
594+
PrettyStackTraceDecl trace("conforming to CxxDictionary", decl);
595+
596+
assert(decl);
597+
assert(clangDecl);
598+
ASTContext &ctx = decl->getASTContext();
599+
600+
// Only auto-conform types from the C++ standard library. Custom user types
601+
// might have a similar interface but different semantics.
602+
if (!isStdDecl(clangDecl, {"map", "unordered_map"}))
603+
return;
604+
605+
auto keyType = lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
606+
decl, ctx.getIdentifier("key_type"));
607+
auto valueType = lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
608+
decl, ctx.getIdentifier("mapped_type"));
609+
auto iterType = lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
610+
decl, ctx.getIdentifier("const_iterator"));
611+
if (!keyType || !valueType || !iterType)
612+
return;
613+
614+
// Make the original subscript that returns a non-optional value unavailable.
615+
// CxxDictionary adds another subscript that returns an optional value,
616+
// similarly to Swift.Dictionary.
617+
for (auto member : decl->getCurrentMembersWithoutLoading()) {
618+
if (auto subscript = dyn_cast<SubscriptDecl>(member)) {
619+
impl.markUnavailable(subscript,
620+
"use subscript with optional return value");
621+
}
622+
}
623+
624+
impl.addSynthesizedTypealias(decl, ctx.Id_Key, keyType->getUnderlyingType());
625+
impl.addSynthesizedTypealias(decl, ctx.Id_Value,
626+
valueType->getUnderlyingType());
627+
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("RawIterator"),
628+
iterType->getUnderlyingType());
629+
impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxDictionary});
630+
}

lib/ClangImporter/ClangDerivedConformances.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,18 @@ void conformToCxxSetIfNeeded(ClangImporter::Implementation &impl,
3939
NominalTypeDecl *decl,
4040
const clang::CXXRecordDecl *clangDecl);
4141

42+
/// If the decl is an instantiation of C++ `std::pair`, synthesize a conformance
43+
/// to CxxPair, which is defined in the Cxx module.
44+
void conformToCxxPairIfNeeded(ClangImporter::Implementation &impl,
45+
NominalTypeDecl *decl,
46+
const clang::CXXRecordDecl *clangDecl);
47+
48+
/// If the decl is an instantiation of C++ `std::map` or `std::unordered_map`,
49+
/// synthesize a conformance to CxxDictionary, which is defined in the Cxx module.
50+
void conformToCxxDictionaryIfNeeded(ClangImporter::Implementation &impl,
51+
NominalTypeDecl *decl,
52+
const clang::CXXRecordDecl *clangDecl);
53+
4254
} // namespace swift
4355

4456
#endif // SWIFT_CLANG_DERIVED_CONFORMANCES_H

lib/ClangImporter/ImportDecl.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2613,6 +2613,8 @@ namespace {
26132613
conformToCxxIteratorIfNeeded(Impl, nominalDecl, decl);
26142614
conformToCxxSequenceIfNeeded(Impl, nominalDecl, decl);
26152615
conformToCxxSetIfNeeded(Impl, nominalDecl, decl);
2616+
conformToCxxDictionaryIfNeeded(Impl, nominalDecl, decl);
2617+
conformToCxxPairIfNeeded(Impl, nominalDecl, decl);
26162618
}
26172619

26182620
if (auto *ntd = dyn_cast<NominalTypeDecl>(result))

lib/IRGen/GenMeta.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5875,6 +5875,8 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
58755875
case KnownProtocolKind::DistributedTargetInvocationDecoder:
58765876
case KnownProtocolKind::DistributedTargetInvocationResultHandler:
58775877
case KnownProtocolKind::CxxConvertibleToCollection:
5878+
case KnownProtocolKind::CxxDictionary:
5879+
case KnownProtocolKind::CxxPair:
58785880
case KnownProtocolKind::CxxRandomAccessCollection:
58795881
case KnownProtocolKind::CxxSet:
58805882
case KnownProtocolKind::CxxSequence:

stdlib/public/Cxx/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ endif()
55

66
add_swift_target_library(swiftCxx ${SWIFT_CXX_LIBRARY_KIND} NO_LINK_NAME IS_STDLIB
77
CxxConvertibleToCollection.swift
8+
CxxDictionary.swift
9+
CxxPair.swift
810
CxxSet.swift
911
CxxRandomAccessCollection.swift
1012
CxxSequence.swift

stdlib/public/Cxx/CxxDictionary.swift

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2023 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
public protocol CxxDictionary<Key, Value> {
14+
associatedtype Key
15+
associatedtype Value
16+
associatedtype RawIterator: UnsafeCxxInputIterator
17+
where RawIterator.Pointee: CxxPair<Key, Value>
18+
19+
/// Do not implement this function manually in Swift.
20+
func __findUnsafe(_ key: Key) -> RawIterator
21+
22+
/// Do not implement this function manually in Swift.
23+
func __endUnsafe() -> RawIterator
24+
}
25+
26+
extension CxxDictionary {
27+
@inlinable
28+
public subscript(key: Key) -> Value? {
29+
get {
30+
let iter = __findUnsafe(key)
31+
guard iter != __endUnsafe() else {
32+
return nil
33+
}
34+
return iter.pointee.second
35+
}
36+
}
37+
}

stdlib/public/Cxx/CxxPair.swift

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2023 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
public protocol CxxPair<First, Second> {
14+
associatedtype First
15+
associatedtype Second
16+
17+
var first: First { get }
18+
var second: Second { get }
19+
}

test/Interop/Cxx/stdlib/Inputs/std-map.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22
#define TEST_INTEROP_CXX_STDLIB_INPUTS_STD_MAP_H
33

44
#include <map>
5+
#include <unordered_map>
56

67
using Map = std::map<int, int>;
8+
using UnorderedMap = std::unordered_map<int, int>;
79

810
inline Map initMap() { return {{1, 3}, {2, 2}, {3, 3}}; }
11+
inline UnorderedMap initUnorderedMap() { return {{1, 3}, {3, 3}, {2, 2}}; }
912

1013
#endif // TEST_INTEROP_CXX_STDLIB_INPUTS_STD_MAP_H

0 commit comments

Comments
 (0)