|
22 | 22 | using namespace swift;
|
23 | 23 | using namespace constraints;
|
24 | 24 |
|
| 25 | +void ConstraintSystem::PotentialBindings::inferTransitiveProtocolRequirements( |
| 26 | + const ConstraintSystem &cs, |
| 27 | + llvm::SmallDenseMap<TypeVariableType *, ConstraintSystem::PotentialBindings> |
| 28 | + &inferredBindings) { |
| 29 | + if (TransitiveProtocols) |
| 30 | + return; |
| 31 | + |
| 32 | + llvm::SmallVector<std::pair<TypeVariableType *, TypeVariableType *>, 4> |
| 33 | + workList; |
| 34 | + llvm::SmallPtrSet<TypeVariableType *, 4> visitedRelations; |
| 35 | + |
| 36 | + llvm::SmallDenseMap<TypeVariableType *, SmallPtrSet<Constraint *, 4>, 4> |
| 37 | + protocols; |
| 38 | + |
| 39 | + auto addToWorkList = [&](TypeVariableType *parent, |
| 40 | + TypeVariableType *typeVar) { |
| 41 | + if (visitedRelations.insert(typeVar).second) |
| 42 | + workList.push_back({parent, typeVar}); |
| 43 | + }; |
| 44 | + |
| 45 | + auto propagateProtocolsTo = |
| 46 | + [&protocols](TypeVariableType *dstVar, |
| 47 | + const SmallVectorImpl<Constraint *> &direct, |
| 48 | + const SmallPtrSetImpl<Constraint *> &transitive) { |
| 49 | + auto &destination = protocols[dstVar]; |
| 50 | + |
| 51 | + for (auto *protocol : direct) |
| 52 | + destination.insert(protocol); |
| 53 | + |
| 54 | + for (auto *protocol : transitive) |
| 55 | + destination.insert(protocol); |
| 56 | + }; |
| 57 | + |
| 58 | + addToWorkList(nullptr, TypeVar); |
| 59 | + |
| 60 | + do { |
| 61 | + auto *currentVar = workList.back().second; |
| 62 | + |
| 63 | + auto cachedBindings = inferredBindings.find(currentVar); |
| 64 | + if (cachedBindings == inferredBindings.end()) { |
| 65 | + workList.pop_back(); |
| 66 | + continue; |
| 67 | + } |
| 68 | + |
| 69 | + auto &bindings = cachedBindings->getSecond(); |
| 70 | + |
| 71 | + // If current variable already has transitive protocol |
| 72 | + // conformances inferred, there is no need to look deeper |
| 73 | + // into subtype/equivalence chain. |
| 74 | + if (bindings.TransitiveProtocols) { |
| 75 | + TypeVariableType *parent = nullptr; |
| 76 | + std::tie(parent, currentVar) = workList.pop_back_val(); |
| 77 | + assert(parent); |
| 78 | + propagateProtocolsTo(parent, bindings.Protocols, |
| 79 | + *bindings.TransitiveProtocols); |
| 80 | + continue; |
| 81 | + } |
| 82 | + |
| 83 | + for (const auto &entry : bindings.SubtypeOf) |
| 84 | + addToWorkList(currentVar, entry.first); |
| 85 | + |
| 86 | + // If current type variable is part of an equivalence |
| 87 | + // class, make it a "representative" and let's it infer |
| 88 | + // supertypes and direct protocol requirements from |
| 89 | + // other members. |
| 90 | + for (const auto &entry : bindings.EquivalentTo) { |
| 91 | + auto eqBindings = inferredBindings.find(entry.first); |
| 92 | + if (eqBindings != inferredBindings.end()) { |
| 93 | + const auto &bindings = eqBindings->getSecond(); |
| 94 | + |
| 95 | + llvm::SmallPtrSet<Constraint *, 2> placeholder; |
| 96 | + // Add any direct protocols from members of the |
| 97 | + // equivalence class, so they could be propagated |
| 98 | + // to all of the members. |
| 99 | + propagateProtocolsTo(currentVar, bindings.Protocols, placeholder); |
| 100 | + |
| 101 | + // Since type variables are equal, current type variable |
| 102 | + // becomes a subtype to any supertype found in the current |
| 103 | + // equivalence class. |
| 104 | + for (const auto &eqEntry : bindings.SubtypeOf) |
| 105 | + addToWorkList(currentVar, eqEntry.first); |
| 106 | + } |
| 107 | + } |
| 108 | + |
| 109 | + // More subtype/equivalences relations have been added. |
| 110 | + if (workList.back().second != currentVar) |
| 111 | + continue; |
| 112 | + |
| 113 | + TypeVariableType *parent = nullptr; |
| 114 | + std::tie(parent, currentVar) = workList.pop_back_val(); |
| 115 | + |
| 116 | + // At all of the protocols associated with current type variable |
| 117 | + // are transitive to its parent, propogate them down the subtype/equivalence |
| 118 | + // chain. |
| 119 | + if (parent) { |
| 120 | + propagateProtocolsTo(parent, bindings.Protocols, protocols[currentVar]); |
| 121 | + } |
| 122 | + |
| 123 | + auto inferredProtocols = std::move(protocols[currentVar]); |
| 124 | + |
| 125 | + llvm::SmallPtrSet<Constraint *, 4> protocolsForEquivalence; |
| 126 | + |
| 127 | + // Equivalence class should contain both: |
| 128 | + // - direct protocol requirements of the current type |
| 129 | + // variable; |
| 130 | + // - all of the transitive protocols inferred through |
| 131 | + // the members of the equivalence class. |
| 132 | + { |
| 133 | + protocolsForEquivalence.insert(bindings.Protocols.begin(), |
| 134 | + bindings.Protocols.end()); |
| 135 | + |
| 136 | + protocolsForEquivalence.insert(inferredProtocols.begin(), |
| 137 | + inferredProtocols.end()); |
| 138 | + } |
| 139 | + |
| 140 | + // Propogate inferred protocols to all of the members of the |
| 141 | + // equivalence class. |
| 142 | + for (const auto &equivalence : bindings.EquivalentTo) { |
| 143 | + auto eqBindings = inferredBindings.find(equivalence.first); |
| 144 | + if (eqBindings != inferredBindings.end()) { |
| 145 | + auto &bindings = eqBindings->getSecond(); |
| 146 | + bindings.TransitiveProtocols.emplace(protocolsForEquivalence); |
| 147 | + } |
| 148 | + } |
| 149 | + |
| 150 | + // Update the bindings associated with current type variable, |
| 151 | + // to avoid repeating this inference process. |
| 152 | + bindings.TransitiveProtocols.emplace(std::move(inferredProtocols)); |
| 153 | + } while (!workList.empty()); |
| 154 | +} |
| 155 | + |
25 | 156 | void ConstraintSystem::PotentialBindings::inferTransitiveBindings(
|
26 | 157 | ConstraintSystem &cs, llvm::SmallPtrSetImpl<CanType> &existingTypes,
|
27 | 158 | const llvm::SmallDenseMap<TypeVariableType *,
|
28 | 159 | ConstraintSystem::PotentialBindings>
|
29 | 160 | &inferredBindings) {
|
30 | 161 | using BindingKind = ConstraintSystem::AllowedBindingKind;
|
31 | 162 |
|
32 |
| - llvm::SmallVector<Constraint *, 4> conversions; |
33 |
| - // First, let's collect all of the conversions associated |
34 |
| - // with this type variable. |
35 |
| - llvm::copy_if( |
36 |
| - Sources, std::back_inserter(conversions), |
37 |
| - [&](const Constraint *constraint) -> bool { |
38 |
| - if (constraint->getKind() != ConstraintKind::Subtype && |
39 |
| - constraint->getKind() != ConstraintKind::Conversion && |
40 |
| - constraint->getKind() != ConstraintKind::ArgumentConversion && |
41 |
| - constraint->getKind() != ConstraintKind::OperatorArgumentConversion) |
42 |
| - return false; |
43 |
| - |
44 |
| - auto rhs = cs.simplifyType(constraint->getSecondType()); |
45 |
| - return rhs->getAs<TypeVariableType>() == TypeVar; |
46 |
| - }); |
47 |
| - |
48 |
| - for (auto *constraint : conversions) { |
49 |
| - auto *tv = |
50 |
| - cs.simplifyType(constraint->getFirstType())->getAs<TypeVariableType>(); |
51 |
| - if (!tv || tv == TypeVar) |
52 |
| - continue; |
53 |
| - |
54 |
| - auto relatedBindings = inferredBindings.find(tv); |
| 163 | + for (const auto &entry : SupertypeOf) { |
| 164 | + auto relatedBindings = inferredBindings.find(entry.first); |
55 | 165 | if (relatedBindings == inferredBindings.end())
|
56 | 166 | continue;
|
57 | 167 |
|
@@ -89,7 +199,7 @@ void ConstraintSystem::PotentialBindings::inferTransitiveBindings(
|
89 | 199 | llvm::copy(bindings.Defaults, std::back_inserter(Defaults));
|
90 | 200 |
|
91 | 201 | // TODO: We shouldn't need this in the future.
|
92 |
| - if (constraint->getKind() != ConstraintKind::Subtype) |
| 202 | + if (entry.second->getKind() != ConstraintKind::Subtype) |
93 | 203 | continue;
|
94 | 204 |
|
95 | 205 | for (auto &binding : bindings.Bindings) {
|
@@ -302,17 +412,16 @@ void ConstraintSystem::PotentialBindings::inferDefaultTypes(
|
302 | 412 |
|
303 | 413 | void ConstraintSystem::PotentialBindings::finalize(
|
304 | 414 | ConstraintSystem &cs,
|
305 |
| - const llvm::SmallDenseMap<TypeVariableType *, |
306 |
| - ConstraintSystem::PotentialBindings> |
| 415 | + llvm::SmallDenseMap<TypeVariableType *, ConstraintSystem::PotentialBindings> |
307 | 416 | &inferredBindings) {
|
308 | 417 | // We need to make sure that there are no duplicate bindings in the
|
309 | 418 | // set, otherwise solver would produce multiple identical solutions.
|
310 | 419 | llvm::SmallPtrSet<CanType, 4> existingTypes;
|
311 | 420 | for (const auto &binding : Bindings)
|
312 | 421 | existingTypes.insert(binding.BindingType->getCanonicalType());
|
313 | 422 |
|
| 423 | + inferTransitiveProtocolRequirements(cs, inferredBindings); |
314 | 424 | inferTransitiveBindings(cs, existingTypes, inferredBindings);
|
315 |
| - |
316 | 425 | inferDefaultTypes(cs, existingTypes);
|
317 | 426 |
|
318 | 427 | // Adjust optionality of existing bindings based on presence of
|
@@ -670,10 +779,6 @@ ConstraintSystem::getPotentialBindingForRelationalConstraint(
|
670 | 779 |
|
671 | 780 | auto *typeVar = result.TypeVar;
|
672 | 781 |
|
673 |
| - // Record constraint which contributes to the |
674 |
| - // finding of potential bindings. |
675 |
| - result.Sources.insert(constraint); |
676 |
| - |
677 | 782 | auto first = simplifyType(constraint->getFirstType());
|
678 | 783 | auto second = simplifyType(constraint->getSecondType());
|
679 | 784 |
|
@@ -790,9 +895,29 @@ ConstraintSystem::getPotentialBindingForRelationalConstraint(
|
790 | 895 | }
|
791 | 896 | }
|
792 | 897 |
|
793 |
| - if (constraint->getKind() == ConstraintKind::Subtype && |
794 |
| - kind == AllowedBindingKind::Subtypes) { |
795 |
| - result.SubtypeOf.insert(bindingTypeVar); |
| 898 | + switch (constraint->getKind()) { |
| 899 | + case ConstraintKind::Subtype: |
| 900 | + case ConstraintKind::Conversion: |
| 901 | + case ConstraintKind::ArgumentConversion: |
| 902 | + case ConstraintKind::OperatorArgumentConversion: { |
| 903 | + if (kind == AllowedBindingKind::Subtypes) { |
| 904 | + result.SubtypeOf.insert({bindingTypeVar, constraint}); |
| 905 | + } else { |
| 906 | + assert(kind == AllowedBindingKind::Supertypes); |
| 907 | + result.SupertypeOf.insert({bindingTypeVar, constraint}); |
| 908 | + } |
| 909 | + break; |
| 910 | + } |
| 911 | + |
| 912 | + case ConstraintKind::Bind: |
| 913 | + case ConstraintKind::BindParam: |
| 914 | + case ConstraintKind::Equal: { |
| 915 | + result.EquivalentTo.insert({bindingTypeVar, constraint}); |
| 916 | + break; |
| 917 | + } |
| 918 | + |
| 919 | + default: |
| 920 | + break; |
796 | 921 | }
|
797 | 922 |
|
798 | 923 | return None;
|
@@ -986,8 +1111,13 @@ bool ConstraintSystem::PotentialBindings::infer(
|
986 | 1111 | break;
|
987 | 1112 |
|
988 | 1113 | case ConstraintKind::ConformsTo:
|
989 |
| - case ConstraintKind::SelfObjectOfProtocol: |
990 |
| - return false; |
| 1114 | + case ConstraintKind::SelfObjectOfProtocol: { |
| 1115 | + auto protocolTy = constraint->getSecondType(); |
| 1116 | + if (!protocolTy->is<ProtocolType>()) |
| 1117 | + return false; |
| 1118 | + |
| 1119 | + LLVM_FALLTHROUGH; |
| 1120 | + } |
991 | 1121 |
|
992 | 1122 | case ConstraintKind::LiteralConformsTo: {
|
993 | 1123 | // Record constraint where protocol requirement originated
|
|
0 commit comments