Skip to content

Commit 7bab318

Browse files
authored
Merge pull request swiftlang#24159 from jckarter/protocol-extension-context-lookup
MetadataLookup: Use extension's generic context for non-nominal extensions.
2 parents 007fbb6 + 9a411be commit 7bab318

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

stdlib/public/runtime/MetadataLookup.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,9 @@ _findExtendedTypeContextDescriptor(const ExtensionContextDescriptor *extension,
263263
return nullptr;
264264
node = node->getChild(0);
265265
}
266-
node = Demangle::getUnspecialized(node, demangler);
266+
if (Demangle::isSpecialized(node)) {
267+
node = Demangle::getUnspecialized(node, demangler);
268+
}
267269

268270
return _findNominalTypeDescriptor(node, demangler);
269271
}
@@ -839,15 +841,17 @@ bool swift::_gatherGenericParameterCounts(
839841
const ContextDescriptor *descriptor,
840842
SmallVectorImpl<unsigned> &genericParamCounts,
841843
Demangler &BorrowFrom) {
842-
// If we have an extension descriptor, extract the extended type and use
843-
// that.
844844
DemanglerForRuntimeTypeResolution<> demangler;
845845
demangler.providePreallocatedMemory(BorrowFrom);
846846

847847
if (auto extension = dyn_cast<ExtensionContextDescriptor>(descriptor)) {
848+
// If we have an nominal type extension descriptor, extract the extended type
849+
// and use that. If the extension is not nominal, then we can use the
850+
// extension's own signature.
848851
if (auto extendedType =
849-
_findExtendedTypeContextDescriptor(extension, demangler))
852+
_findExtendedTypeContextDescriptor(extension, demangler)) {
850853
descriptor = extendedType;
854+
}
851855
}
852856

853857
// Once we hit a non-generic descriptor, we're done.
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: %target-run-simple-swift | %FileCheck %s
2+
// REQUIRES: executable_test
3+
4+
protocol P {
5+
associatedtype AT
6+
func foo() -> AT
7+
}
8+
9+
extension P {
10+
func foo() -> some P {
11+
return self
12+
}
13+
}
14+
15+
func getPAT<T: P>(_: T.Type) -> Any.Type {
16+
return T.AT.self
17+
}
18+
19+
extension Int: P { }
20+
21+
// CHECK: Int
22+
print(getPAT(Int.self))

0 commit comments

Comments
 (0)