Skip to content

Commit 52b0fa9

Browse files
authored
Merge pull request #70696 from slavapestov/fix-rdar119499800-5.10
Sema: Associated type inference fixes [5.10]
2 parents 8fe51f3 + 1cc41a9 commit 52b0fa9

File tree

6 files changed

+116
-28
lines changed

6 files changed

+116
-28
lines changed

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7172,10 +7172,11 @@ void TypeChecker::inferDefaultWitnesses(ProtocolDecl *proto) {
71727172
DefaultWitnessChecker checker(proto);
71737173

71747174
// Find the default for the given associated type.
7175-
auto findAssociatedTypeDefault = [](AssociatedTypeDecl *assocType)
7175+
auto findAssociatedTypeDefault = [proto](AssociatedTypeDecl *assocType)
71767176
-> std::pair<Type, AssociatedTypeDecl *> {
71777177
auto defaultedAssocType =
7178-
AssociatedTypeInference::findDefaultedAssociatedType(assocType);
7178+
AssociatedTypeInference::findDefaultedAssociatedType(
7179+
proto, proto, assocType);
71797180
if (!defaultedAssocType)
71807181
return {Type(), nullptr};
71817182

lib/Sema/TypeCheckProtocol.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1271,7 +1271,8 @@ class AssociatedTypeInference {
12711271

12721272
/// Find an associated type declaration that provides a default definition.
12731273
static AssociatedTypeDecl *findDefaultedAssociatedType(
1274-
AssociatedTypeDecl *assocType);
1274+
DeclContext *dc, NominalTypeDecl *adoptee,
1275+
AssociatedTypeDecl *assocType);
12751276
};
12761277

12771278
/// Match the given witness to the given requirement.

lib/Sema/TypeCheckProtocolInference.cpp

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -165,27 +165,25 @@ namespace {
165165
/// Try to avoid situations where resolving the type of a witness calls back
166166
/// into associated type inference.
167167
struct TypeReprCycleCheckWalker : ASTWalker {
168+
ASTContext &ctx;
168169
llvm::SmallDenseSet<Identifier, 2> circularNames;
169170
ValueDecl *witness;
170171
bool found;
171172

172173
TypeReprCycleCheckWalker(
174+
ASTContext &ctx,
173175
const llvm::SetVector<AssociatedTypeDecl *> &allUnresolved)
174-
: witness(nullptr), found(false) {
176+
: ctx(ctx), witness(nullptr), found(false) {
175177
for (auto *assocType : allUnresolved) {
176178
circularNames.insert(assocType->getName());
177179
}
178180
}
179181

180182
PreWalkAction walkToTypeReprPre(TypeRepr *T) override {
181-
// FIXME: We should still visit any generic arguments of this member type.
182-
// However, we want to skip 'Foo.Element' because the 'Element' reference is
183-
// not unqualified.
184-
if (auto *memberTyR = dyn_cast<MemberTypeRepr>(T)) {
185-
return Action::SkipChildren();
186-
}
183+
// FIXME: Visit generic arguments.
187184

188185
if (auto *identTyR = dyn_cast<SimpleIdentTypeRepr>(T)) {
186+
// If we're inferring `Foo`, don't look at a witness mentioning `Foo`.
189187
if (circularNames.count(identTyR->getNameRef().getBaseIdentifier()) > 0) {
190188
// If unqualified lookup can find a type with this name without looking
191189
// into protocol members, don't skip the witness, since this type might
@@ -194,7 +192,6 @@ struct TypeReprCycleCheckWalker : ASTWalker {
194192
identTyR->getNameRef(), witness->getDeclContext(),
195193
identTyR->getLoc(), UnqualifiedLookupOptions());
196194

197-
auto &ctx = witness->getASTContext();
198195
auto results =
199196
evaluateOrDefault(ctx.evaluator, UnqualifiedLookupRequest{desc}, {});
200197

@@ -207,6 +204,34 @@ struct TypeReprCycleCheckWalker : ASTWalker {
207204
}
208205
}
209206

207+
if (auto *memberTyR = dyn_cast<MemberTypeRepr>(T)) {
208+
// If we're looking at a member type`Foo.Bar`, check `Foo` recursively.
209+
auto *baseTyR = memberTyR->getBaseComponent();
210+
baseTyR->walk(*this);
211+
212+
// If we're inferring `Foo`, don't look at a witness mentioning `Self.Foo`.
213+
if (auto *identTyR = dyn_cast<SimpleIdentTypeRepr>(baseTyR)) {
214+
if (identTyR->getNameRef().getBaseIdentifier() == ctx.Id_Self) {
215+
// But if qualified lookup can find a type with this name without
216+
// looking into protocol members, don't skip the witness, since this
217+
// type might be a candidate witness.
218+
SmallVector<ValueDecl *, 2> results;
219+
witness->getInnermostDeclContext()->lookupQualified(
220+
witness->getDeclContext()->getSelfTypeInContext(),
221+
identTyR->getNameRef(), SourceLoc(), NLOptions(), results);
222+
223+
// Ok, resolving this member type would trigger associated type
224+
// inference recursively. We're going to skip this witness.
225+
if (results.empty()) {
226+
found = true;
227+
return Action::Stop();
228+
}
229+
}
230+
}
231+
232+
return Action::SkipChildren();
233+
}
234+
210235
return Action::Continue();
211236
}
212237

@@ -296,7 +321,7 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
296321
abort();
297322
}
298323

299-
TypeReprCycleCheckWalker cycleCheck(allUnresolved);
324+
TypeReprCycleCheckWalker cycleCheck(dc->getASTContext(), allUnresolved);
300325

301326
InferredAssociatedTypesByWitnesses result;
302327

@@ -917,26 +942,31 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitness(ValueDecl *req,
917942
}
918943

919944
AssociatedTypeDecl *AssociatedTypeInference::findDefaultedAssociatedType(
945+
DeclContext *dc,
946+
NominalTypeDecl *adoptee,
920947
AssociatedTypeDecl *assocType) {
921948
// If this associated type has a default, we're done.
922949
if (assocType->hasDefaultDefinitionType())
923950
return assocType;
924951

925-
// Look at overridden associated types.
952+
// Otherwise, look for all associated types with the same name along all the
953+
// protocols that the adoptee conforms to.
954+
SmallVector<ValueDecl *, 4> decls;
955+
auto options = NL_ProtocolMembers | NL_OnlyTypes;
956+
dc->lookupQualified(adoptee, DeclNameRef(assocType->getName()),
957+
SourceLoc(), options, decls);
958+
926959
SmallPtrSet<CanType, 4> canonicalTypes;
927960
SmallVector<AssociatedTypeDecl *, 2> results;
928-
for (auto overridden : assocType->getOverriddenDecls()) {
929-
auto overriddenDefault = findDefaultedAssociatedType(overridden);
930-
if (!overriddenDefault) continue;
931-
932-
Type overriddenType =
933-
overriddenDefault->getDefaultDefinitionType();
934-
assert(overriddenType);
935-
if (!overriddenType) continue;
961+
for (auto *decl : decls) {
962+
if (auto *assocDecl = dyn_cast<AssociatedTypeDecl>(decl)) {
963+
auto defaultType = assocDecl->getDefaultDefinitionType();
964+
if (!defaultType) continue;
936965

937-
CanType key = overriddenType->getCanonicalType();
966+
CanType key = defaultType->getCanonicalType();
938967
if (canonicalTypes.insert(key).second)
939-
results.push_back(overriddenDefault);
968+
results.push_back(assocDecl);
969+
}
940970
}
941971

942972
// If there was a single result, return it.
@@ -997,7 +1027,8 @@ llvm::Optional<AbstractTypeWitness>
9971027
AssociatedTypeInference::computeDefaultTypeWitness(
9981028
AssociatedTypeDecl *assocType) const {
9991029
// Go find a default definition.
1000-
auto *const defaultedAssocType = findDefaultedAssociatedType(assocType);
1030+
auto *const defaultedAssocType = findDefaultedAssociatedType(
1031+
dc, dc->getSelfNominalTypeDecl(), assocType);
10011032
if (!defaultedAssocType)
10021033
return llvm::None;
10031034

test/decl/protocol/req/assoc_type_inference_cycle.swift

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,39 @@ public enum CaseWitness: CaseProtocol {
9595
case b(_: A)
9696
case c(_: A)
9797
}
98+
99+
// rdar://119499800 #1
100+
public typealias A8 = Batch.Iterator
101+
102+
public struct Batch: Collection {
103+
public typealias Element = Int
104+
public typealias Index = Array<Element>.Index
105+
106+
var elements: [Element]
107+
108+
init(_ elements: some Collection<Element>) {
109+
self.elements = Array(elements)
110+
}
111+
112+
public var startIndex: Index { return elements.startIndex }
113+
public var endIndex: Index { return elements.endIndex }
114+
115+
public subscript(index: Index) -> Iterator.Element {
116+
return elements[index]
117+
}
118+
119+
public func index(after i: Index) -> Index {
120+
return elements.index(after: i)
121+
}
122+
}
123+
124+
// rdar://119499800 #2
125+
public typealias A9 = LogTypes.RawValue
126+
127+
public struct LogTypes: OptionSet {
128+
public init(rawValue: Self.RawValue) {
129+
self.rawValue = rawValue
130+
}
131+
132+
public let rawValue: Int
133+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: %target-typecheck-verify-swift
2+
3+
protocol P1 {
4+
associatedtype A
5+
6+
func f(_: A)
7+
}
8+
9+
protocol P2: P1 {
10+
associatedtype A = Int
11+
}
12+
13+
func foo<T: P1>(_: T.Type) -> T.A.Type {}
14+
15+
_ = foo(S.self)
16+
17+
struct S: P2 {
18+
func f(_: A) {}
19+
}

test/decl/protocol/req/associated_type_tuple.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,16 @@ protocol P1 {
99
extension Tuple: P1 where repeat each T: P1 {} // expected-error {{type '(repeat each T)' does not conform to protocol 'P1'}}
1010

1111
protocol P2 {
12-
associatedtype A = Int // expected-note {{default type 'Int' for associated type 'A' (from protocol 'P2') is unsuitable for tuple conformance; the associated type requirement must be fulfilled by a type alias with underlying type '(repeat (each T).A)'}}
12+
associatedtype B = Int // expected-note {{default type 'Int' for associated type 'B' (from protocol 'P2') is unsuitable for tuple conformance; the associated type requirement must be fulfilled by a type alias with underlying type '(repeat (each T).B)'}}
1313
}
1414

1515
extension Tuple: P2 where repeat each T: P2 {} // expected-error {{type '(repeat each T)' does not conform to protocol 'P2'}}
1616

1717
protocol P3 {
18-
associatedtype A // expected-note {{unable to infer associated type 'A' for protocol 'P3'}}
19-
func f() -> A
18+
associatedtype C // expected-note {{unable to infer associated type 'C' for protocol 'P3'}}
19+
func f() -> C
2020
}
2121

2222
extension Tuple: P3 where repeat each T: P3 { // expected-error {{type '(repeat each T)' does not conform to protocol 'P3'}}
23-
func f() -> Int {} // expected-note {{cannot infer 'A' = 'Int' in tuple conformance because the associated type requirement must be fulfilled by a type alias with underlying type '(repeat (each T).A)'}}
23+
func f() -> Int {} // expected-note {{cannot infer 'C' = 'Int' in tuple conformance because the associated type requirement must be fulfilled by a type alias with underlying type '(repeat (each T).C)'}}
2424
}

0 commit comments

Comments
 (0)