Skip to content

Commit 61e418c

Browse files
authored
Merge pull request #85513 from slavapestov/refactor-csbindings
Refactor BindingSet a little bit
2 parents a3f9e0a + ce3b5eb commit 61e418c

File tree

6 files changed

+119
-108
lines changed

6 files changed

+119
-108
lines changed

include/swift/Sema/CSBindings.h

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -469,10 +469,6 @@ class BindingSet {
469469
/// checking.
470470
bool isViable(PotentialBinding &binding, bool isTransitive);
471471

472-
explicit operator bool() const {
473-
return hasViableBindings() || isDirectHole();
474-
}
475-
476472
/// Determine whether this set has any "viable" (or non-hole) bindings.
477473
///
478474
/// A viable binding could be - a direct or transitive binding
@@ -486,6 +482,12 @@ class BindingSet {
486482
!Defaults.empty();
487483
}
488484

485+
/// Determine whether this set can be chosen as the next binding set
486+
/// to attempt.
487+
bool isViable() const {
488+
return hasViableBindings() || isDirectHole();
489+
}
490+
489491
ArrayRef<Constraint *> getConformanceRequirements() const {
490492
return Protocols;
491493
}
@@ -544,6 +546,8 @@ class BindingSet {
544546
/// Check if this binding is favored over a conjunction.
545547
bool favoredOverConjunction(Constraint *conjunction) const;
546548

549+
void inferTransitiveKeyPathBindings();
550+
547551
/// Detect `subtype` relationship between two type variables and
548552
/// attempt to infer supertype bindings transitively e.g.
549553
///
@@ -553,19 +557,27 @@ class BindingSet {
553557
///
554558
/// \param inferredBindings The set of all bindings inferred for type
555559
/// variables in the workset.
556-
void inferTransitiveBindings();
560+
void inferTransitiveSupertypeBindings();
561+
562+
void inferTransitiveUnresolvedMemberRefBindings();
557563

558564
/// Detect subtype, conversion or equivalence relationship
559565
/// between two type variables and attempt to propagate protocol
560566
/// requirements down the subtype or equivalence chain.
561567
void inferTransitiveProtocolRequirements();
562568

563-
/// Finalize binding computation for this type variable by
564-
/// inferring bindings from context e.g. transitive bindings.
569+
/// Check whether the given binding set covers any of the
570+
/// literal protocols associated with this type variable.
571+
void determineLiteralCoverage();
572+
573+
/// Finalize binding computation for key path type variables.
565574
///
566575
/// \returns true if finalization successful (which makes binding set viable),
567576
/// and false otherwise.
568-
bool finalize(bool transitive);
577+
bool finalizeKeyPathBindings();
578+
579+
/// Handle diagnostics of unresolved member chains.
580+
void finalizeUnresolvedMemberChainResult();
569581

570582
static BindingScore formBindingScore(const BindingSet &b);
571583

@@ -590,10 +602,6 @@ class BindingSet {
590602

591603
void addDefault(Constraint *constraint);
592604

593-
/// Check whether the given binding set covers any of the
594-
/// literal protocols associated with this type variable.
595-
void determineLiteralCoverage();
596-
597605
StringRef getLiteralBindingKind(LiteralBindingKind K) const {
598606
#define ENTRY(Kind, String) \
599607
case LiteralBindingKind::Kind: \

lib/Sema/CSBindings.cpp

Lines changed: 72 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,7 @@ bool BindingSet::isDirectHole() const {
103103
if (!CS.shouldAttemptFixes())
104104
return false;
105105

106-
return Bindings.empty() && getNumViableLiteralBindings() == 0 &&
107-
Defaults.empty() && TypeVar->getImpl().canBindToHole();
106+
return !hasViableBindings() && TypeVar->getImpl().canBindToHole();
108107
}
109108

110109
static bool isGenericParameter(TypeVariableType *TypeVar) {
@@ -494,9 +493,7 @@ void BindingSet::inferTransitiveProtocolRequirements() {
494493
} while (!workList.empty());
495494
}
496495

497-
void BindingSet::inferTransitiveBindings() {
498-
using BindingKind = AllowedBindingKind;
499-
496+
void BindingSet::inferTransitiveKeyPathBindings() {
500497
// If the current type variable represents a key path root type
501498
// let's try to transitively infer its type through bindings of
502499
// a key path type.
@@ -551,15 +548,17 @@ void BindingSet::inferTransitiveBindings() {
551548
}
552549
} else {
553550
addBinding(
554-
binding.withSameSource(inferredRootTy, BindingKind::Exact),
551+
binding.withSameSource(inferredRootTy, AllowedBindingKind::Exact),
555552
/*isTransitive=*/true);
556553
}
557554
}
558555
}
559556
}
560557
}
561558
}
559+
}
562560

561+
void BindingSet::inferTransitiveSupertypeBindings() {
563562
for (const auto &entry : Info.SupertypeOf) {
564563
auto &node = CS.getConstraintGraph()[entry.first];
565564
if (!node.hasBindingSet())
@@ -609,8 +608,8 @@ void BindingSet::inferTransitiveBindings() {
609608
// either be Exact or Supertypes in order for it to make sense
610609
// to add Supertype bindings based on the relationship between
611610
// our type variables.
612-
if (binding.Kind != BindingKind::Exact &&
613-
binding.Kind != BindingKind::Supertypes)
611+
if (binding.Kind != AllowedBindingKind::Exact &&
612+
binding.Kind != AllowedBindingKind::Supertypes)
614613
continue;
615614

616615
auto type = binding.BindingType;
@@ -621,12 +620,49 @@ void BindingSet::inferTransitiveBindings() {
621620
if (ConstraintSystem::typeVarOccursInType(TypeVar, type))
622621
continue;
623622

624-
addBinding(binding.withSameSource(type, BindingKind::Supertypes),
623+
addBinding(binding.withSameSource(type, AllowedBindingKind::Supertypes),
625624
/*isTransitive=*/true);
626625
}
627626
}
628627
}
629628

629+
void BindingSet::inferTransitiveUnresolvedMemberRefBindings() {
630+
if (!hasViableBindings()) {
631+
if (auto *locator = TypeVar->getImpl().getLocator()) {
632+
if (locator->isLastElement<LocatorPathElt::MemberRefBase>()) {
633+
// If this is a base of an unresolved member chain, as a last
634+
// resort effort let's infer base to be a protocol type based
635+
// on contextual conformance requirements.
636+
//
637+
// This allows us to find solutions in cases like this:
638+
//
639+
// \code
640+
// func foo<T: P>(_: T) {}
641+
// foo(.bar) <- `.bar` should be a static member of `P`.
642+
// \endcode
643+
inferTransitiveProtocolRequirements();
644+
645+
if (TransitiveProtocols.has_value()) {
646+
for (auto *constraint : *TransitiveProtocols) {
647+
Type protocolTy = constraint->getSecondType();
648+
649+
// Compiler-known marker protocols cannot be extended with members,
650+
// so do not consider them.
651+
if (auto p = protocolTy->getAs<ProtocolType>()) {
652+
if (ProtocolDecl *decl = p->getDecl())
653+
if (decl->getKnownProtocolKind() && decl->isMarkerProtocol())
654+
continue;
655+
}
656+
657+
addBinding({protocolTy, AllowedBindingKind::Exact, constraint},
658+
/*isTransitive=*/false);
659+
}
660+
}
661+
}
662+
}
663+
}
664+
}
665+
630666
static Type getKeyPathType(ASTContext &ctx, KeyPathCapability capability,
631667
Type rootType, Type valueType) {
632668
KeyPathMutability mutability;
@@ -664,51 +700,11 @@ static Type getKeyPathType(ASTContext &ctx, KeyPathCapability capability,
664700
return keyPathTy;
665701
}
666702

667-
bool BindingSet::finalize(bool transitive) {
668-
if (transitive)
669-
inferTransitiveBindings();
670-
671-
determineLiteralCoverage();
672-
703+
bool BindingSet::finalizeKeyPathBindings() {
673704
if (auto *locator = TypeVar->getImpl().getLocator()) {
674-
if (locator->isLastElement<LocatorPathElt::MemberRefBase>()) {
675-
// If this is a base of an unresolved member chain, as a last
676-
// resort effort let's infer base to be a protocol type based
677-
// on contextual conformance requirements.
678-
//
679-
// This allows us to find solutions in cases like this:
680-
//
681-
// \code
682-
// func foo<T: P>(_: T) {}
683-
// foo(.bar) <- `.bar` should be a static member of `P`.
684-
// \endcode
685-
if (transitive && !hasViableBindings()) {
686-
inferTransitiveProtocolRequirements();
687-
688-
if (TransitiveProtocols.has_value()) {
689-
for (auto *constraint : *TransitiveProtocols) {
690-
Type protocolTy = constraint->getSecondType();
691-
692-
// Compiler-known marker protocols cannot be extended with members,
693-
// so do not consider them.
694-
if (auto p = protocolTy->getAs<ProtocolType>()) {
695-
if (ProtocolDecl *decl = p->getDecl())
696-
if (decl->getKnownProtocolKind() && decl->isMarkerProtocol())
697-
continue;
698-
}
699-
700-
addBinding({protocolTy, AllowedBindingKind::Exact, constraint},
701-
/*isTransitive=*/false);
702-
}
703-
}
704-
}
705-
}
706-
707705
if (TypeVar->getImpl().isKeyPathType()) {
708706
auto &ctx = CS.getASTContext();
709-
710-
auto *keyPathLoc = TypeVar->getImpl().getLocator();
711-
auto *keyPath = castToExpr<KeyPathExpr>(keyPathLoc->getAnchor());
707+
auto *keyPath = castToExpr<KeyPathExpr>(locator->getAnchor());
712708

713709
bool isValid;
714710
std::optional<KeyPathCapability> capability;
@@ -775,7 +771,7 @@ bool BindingSet::finalize(bool transitive) {
775771
auto keyPathTy = getKeyPathType(ctx, *capability, rootTy,
776772
CS.getKeyPathValueType(keyPath));
777773
updatedBindings.insert(
778-
{keyPathTy, AllowedBindingKind::Exact, keyPathLoc});
774+
{keyPathTy, AllowedBindingKind::Exact, locator});
779775
} else if (CS.shouldAttemptFixes()) {
780776
auto fixedRootTy = CS.getFixedType(rootTy);
781777
// If key path is structurally correct and has a resolved root
@@ -802,10 +798,14 @@ bool BindingSet::finalize(bool transitive) {
802798

803799
Bindings = std::move(updatedBindings);
804800
Defaults.clear();
805-
806-
return true;
807801
}
802+
}
808803

804+
return true;
805+
}
806+
807+
void BindingSet::finalizeUnresolvedMemberChainResult() {
808+
if (auto *locator = TypeVar->getImpl().getLocator()) {
809809
if (CS.shouldAttemptFixes() &&
810810
locator->isLastElement<LocatorPathElt::UnresolvedMemberChainResult>()) {
811811
// Let's see whether this chain is valid, if it isn't then to avoid
@@ -828,8 +828,6 @@ bool BindingSet::finalize(bool transitive) {
828828
}
829829
}
830830
}
831-
832-
return true;
833831
}
834832

835833
void BindingSet::addBinding(PotentialBinding binding, bool isTransitive) {
@@ -1143,37 +1141,6 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
11431141
node.initBindingSet();
11441142
}
11451143

1146-
// Determine whether given type variable with its set of bindings is
1147-
// viable to be attempted on the next step of the solver. If type variable
1148-
// has no "direct" bindings of any kind e.g. direct bindings to concrete
1149-
// types, default types from "defaultable" constraints or literal
1150-
// conformances, such type variable is not viable to be evaluated to be
1151-
// attempted next.
1152-
auto isViableForRanking = [this](const BindingSet &bindings) -> bool {
1153-
auto *typeVar = bindings.getTypeVariable();
1154-
1155-
// Key path root type variable is always viable because it can be
1156-
// transitively inferred from key path type during binding set
1157-
// finalization.
1158-
if (typeVar->getImpl().isKeyPathRoot())
1159-
return true;
1160-
1161-
// Type variable representing a base of unresolved member chain should
1162-
// always be considered viable for ranking since it's allow to infer
1163-
// types from transitive protocol requirements.
1164-
if (auto *locator = typeVar->getImpl().getLocator()) {
1165-
if (locator->isLastElement<LocatorPathElt::MemberRefBase>())
1166-
return true;
1167-
}
1168-
1169-
// If type variable is marked as a potential hole there is always going
1170-
// to be at least one binding available for it.
1171-
if (shouldAttemptFixes() && typeVar->getImpl().canBindToHole())
1172-
return true;
1173-
1174-
return bool(bindings);
1175-
};
1176-
11771144
// Now let's see if we could infer something for related type
11781145
// variables based on other bindings.
11791146
for (auto *typeVar : getTypeVariables()) {
@@ -1183,6 +1150,16 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
11831150

11841151
auto &bindings = node.getBindingSet();
11851152

1153+
// Special handling for key paths.
1154+
bindings.inferTransitiveKeyPathBindings();
1155+
if (!bindings.finalizeKeyPathBindings())
1156+
continue;
1157+
1158+
// Special handling for "leading-dot" unresolved member references,
1159+
// like .foo.
1160+
bindings.inferTransitiveUnresolvedMemberRefBindings();
1161+
bindings.finalizeUnresolvedMemberChainResult();
1162+
11861163
// Before attempting to infer transitive bindings let's check
11871164
// whether there are any viable "direct" bindings associated with
11881165
// current type variable, if there are none - it means that this type
@@ -1193,12 +1170,12 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
11931170
// associated with given type variable, any default constraints,
11941171
// or any conformance requirements to literal protocols with can
11951172
// produce a default type.
1196-
bool isViable = isViableForRanking(bindings);
1173+
bool isViable = bindings.isViable();
11971174

1198-
if (!bindings.finalize(true))
1199-
continue;
1175+
bindings.inferTransitiveSupertypeBindings();
1176+
bindings.determineLiteralCoverage();
12001177

1201-
if (!bindings || !isViable)
1178+
if (!isViable)
12021179
continue;
12031180

12041181
onCandidate(bindings);
@@ -1591,7 +1568,10 @@ BindingSet ConstraintSystem::getBindingsFor(TypeVariableType *typeVar) {
15911568
assert(!typeVar->getImpl().getFixedType(nullptr) && "has a fixed type");
15921569

15931570
BindingSet bindings(*this, typeVar, CG[typeVar].getPotentialBindings());
1594-
bindings.finalize(false);
1571+
1572+
(void) bindings.finalizeKeyPathBindings();
1573+
bindings.finalizeUnresolvedMemberChainResult();
1574+
bindings.determineLiteralCoverage();
15951575

15961576
return bindings;
15971577
}

lib/Sema/CSOptimizer.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1105,7 +1105,9 @@ static void determineBestChoicesInContext(
11051105
// Simply adding it as a binding won't work because if the second argument
11061106
// is non-optional the overload that returns `T?` would still have a lower
11071107
// score.
1108-
if (!bindingSet && isNilCoalescingOperator(disjunction)) {
1108+
if (!bindingSet.hasViableBindings() &&
1109+
!bindingSet.isDirectHole() &&
1110+
isNilCoalescingOperator(disjunction)) {
11091111
auto &cg = cs.getConstraintGraph();
11101112
if (llvm::any_of(cg[typeVar].getConstraints(),
11111113
[&typeVar](Constraint *constraint) {

lib/Sema/ConstraintGraph.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -921,7 +921,8 @@ bool ConstraintGraph::contractEdges() {
921921
// us enough information to decided on l-valueness.
922922
if (tyvar1->getImpl().canBindToInOut()) {
923923
bool isNotContractable = true;
924-
if (auto bindings = CS.getBindingsFor(tyvar1)) {
924+
auto bindings = CS.getBindingsFor(tyvar1);
925+
if (bindings.isViable()) {
925926
// Holes can't be contracted.
926927
if (bindings.isHole())
927928
continue;

unittests/Sema/BindingInferenceTests.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,15 @@ TEST_F(SemaTest, TestIntLiteralBindingInference) {
125125

126126
cs.getConstraintGraph()[floatLiteralTy].initBindingSet();
127127

128-
bindings.finalize(/*transitive=*/true);
128+
bindings.inferTransitiveKeyPathBindings();
129+
(void) bindings.finalizeKeyPathBindings();
130+
131+
bindings.inferTransitiveUnresolvedMemberRefBindings();
132+
bindings.finalizeUnresolvedMemberChainResult();
133+
134+
bindings.inferTransitiveSupertypeBindings();
135+
136+
bindings.determineLiteralCoverage();
129137

130138
// Inferred a single transitive binding through `$T_float`.
131139
ASSERT_EQ(bindings.Bindings.size(), (unsigned)1);

0 commit comments

Comments
 (0)