Skip to content

Commit ba90861

Browse files
committed
Sema: Try a little harder to infer associated types to generic parameters if all else fails
If we have an abstract witness, we don't attempt a generic parameter binding at all. But if simplifying the abstract witness failed, we should still attempt it. This would be cleaner as a disjunction in the solver but I want to change behavior as little as possible, so this adds a new fallback that we run when all else fails. Fixes rdar://problem/123345520.
1 parent 29b08a9 commit ba90861

File tree

3 files changed

+72
-38
lines changed

3 files changed

+72
-38
lines changed

lib/Sema/AssociatedTypeInference.cpp

Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,6 +1017,10 @@ class AssociatedTypeInference {
10171017
std::pair<Type, TypeDecl *>
10181018
computeDerivedTypeWitness(AssociatedTypeDecl *assocType);
10191019

1020+
/// See if we have a generic parameter named the same as this associated
1021+
/// type.
1022+
Type computeGenericParamWitness(AssociatedTypeDecl *assocType) const;
1023+
10201024
/// Compute a type witness without using a specific potential witness.
10211025
llvm::Optional<AbstractTypeWitness>
10221026
computeAbstractTypeWitness(AssociatedTypeDecl *assocType);
@@ -2657,6 +2661,28 @@ AssociatedTypeInference::computeAbstractTypeWitness(
26572661
return llvm::None;
26582662
}
26592663

2664+
/// Look for a generic parameter that matches the name of the
2665+
/// associated type.
2666+
Type AssociatedTypeInference::computeGenericParamWitness(
2667+
AssociatedTypeDecl *assocType) const {
2668+
if (auto genericSig = dc->getGenericSignatureOfContext()) {
2669+
// Ignore the generic parameters for AsyncIteratorProtocol.Failure and
2670+
// AsyncSequence.Failure.
2671+
if (!isAsyncIteratorProtocolFailure(assocType)) {
2672+
for (auto *gp : genericSig.getInnermostGenericParams()) {
2673+
// Packs cannot witness associated type requirements.
2674+
if (gp->isParameterPack())
2675+
continue;
2676+
2677+
if (gp->getName() == assocType->getName())
2678+
return dc->mapTypeIntoContext(gp);
2679+
}
2680+
}
2681+
}
2682+
2683+
return Type();
2684+
}
2685+
26602686
void AssociatedTypeInference::collectAbstractTypeWitnesses(
26612687
TypeWitnessSystem &system,
26622688
ArrayRef<AssociatedTypeDecl *> unresolvedAssocTypes) const {
@@ -2705,39 +2731,14 @@ void AssociatedTypeInference::collectAbstractTypeWitnesses(
27052731
if (system.hasResolvedTypeWitness(assocType->getName()))
27062732
continue;
27072733

2708-
bool found = false;
2709-
2710-
// Look for a generic parameter that matches the name of the
2711-
// associated type.
2712-
if (auto genericSig = dc->getGenericSignatureOfContext()) {
2713-
// Ignore the generic parameters for AsyncIteratorProtocol.Failure and
2714-
// AsyncSequence.Failure.
2715-
if (!isAsyncIteratorProtocolFailure(assocType)) {
2716-
for (auto *gp : genericSig.getInnermostGenericParams()) {
2717-
// Packs cannot witness associated type requirements.
2718-
if (gp->isParameterPack())
2719-
continue;
2720-
2721-
if (gp->getName() == assocType->getName()) {
2722-
system.addTypeWitness(assocType->getName(),
2723-
dc->mapTypeIntoContext(gp),
2724-
/*preferred=*/true);
2725-
found = true;
2726-
break;
2727-
}
2728-
}
2729-
}
2730-
}
2731-
2732-
if (!found) {
2733-
// If we find a default type definition, feed it to the system.
2734-
if (const auto &typeWitness = computeDefaultTypeWitness(assocType)) {
2735-
bool preferred = (typeWitness->getDefaultedAssocType()->getDeclContext()
2736-
== conformance->getProtocol());
2737-
system.addDefaultTypeWitness(typeWitness->getType(),
2738-
typeWitness->getDefaultedAssocType(),
2739-
preferred);
2740-
}
2734+
if (auto gpType = computeGenericParamWitness(assocType)) {
2735+
system.addTypeWitness(assocType->getName(), gpType, /*preferred=*/true);
2736+
} else if (const auto &typeWitness = computeDefaultTypeWitness(assocType)) {
2737+
bool preferred = (typeWitness->getDefaultedAssocType()->getDeclContext()
2738+
== conformance->getProtocol());
2739+
system.addDefaultTypeWitness(typeWitness->getType(),
2740+
typeWitness->getDefaultedAssocType(),
2741+
preferred);
27412742
}
27422743
}
27432744
}
@@ -3156,8 +3157,15 @@ AssociatedTypeDecl *AssociatedTypeInference::inferAbstractTypeWitnesses(
31563157

31573158
// If simplification failed, give up.
31583159
if (type->hasTypeParameter()) {
3159-
LLVM_DEBUG(llvm::dbgs() << "-- Simplification failed: " << type << "\n");
3160-
return assocType;
3160+
if (auto gpType = computeGenericParamWitness(assocType)) {
3161+
LLVM_DEBUG(llvm::dbgs() << "-- Found generic parameter as last resort: "
3162+
<< gpType << "\n");
3163+
type = gpType;
3164+
typeWitnesses.insert(assocType, {type, reqDepth});
3165+
} else {
3166+
LLVM_DEBUG(llvm::dbgs() << "-- Simplification failed: " << type << "\n");
3167+
return assocType;
3168+
}
31613169
}
31623170

31633171
if (const auto failed =

test/decl/protocol/req/associated_type_inference_stdlib_3.swift

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
// RUN: not %target-typecheck-verify-swift -enable-experimental-associated-type-inference
1+
// RUN: %target-typecheck-verify-swift -enable-experimental-associated-type-inference
22
// RUN: not %target-typecheck-verify-swift -disable-experimental-associated-type-inference
33

4-
// FIXME: Get this passing with -enable-experimental-associated-type-inference again.
5-
64
struct FooIterator<T: Sequence>: IteratorProtocol {
75
typealias Element = T.Element
86

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: %target-typecheck-verify-swift -enable-experimental-associated-type-inference
2+
// RUN: %target-typecheck-verify-swift -disable-experimental-associated-type-inference
3+
4+
struct G1<T> {}
5+
6+
struct G2<A, B>: AP {
7+
func f1(_: G1<(B) -> A>, _: G1<B>) -> G1<A> { fatalError() }
8+
func f2<C>(_: (A) -> C) -> G1<C> { fatalError() }
9+
}
10+
11+
protocol OP: EP {
12+
associatedtype L
13+
associatedtype R
14+
15+
func f1(_: G1<L>, _: G1<R>) -> G1<A>
16+
}
17+
18+
extension OP {
19+
func f1(_: G1<L>?, _: G1<R>?) -> G1<A> { fatalError() }
20+
}
21+
22+
protocol AP: OP where L == (B) -> A, R == B {}
23+
24+
protocol EP {
25+
associatedtype A
26+
associatedtype B
27+
func f2<C>(_: (A) -> C) -> G1<C>
28+
}

0 commit comments

Comments
 (0)