@@ -465,6 +465,7 @@ ConstraintSystem::SolverScope::SolverScope(ConstraintSystem &cs)
465
465
numFixes = cs.Fixes .size ();
466
466
numFixedRequirements = cs.FixedRequirements .size ();
467
467
numDisjunctionChoices = cs.DisjunctionChoices .size ();
468
+ numAppliedDisjunctions = cs.AppliedDisjunctions .size ();
468
469
numTrailingClosureMatchingChoices = cs.trailingClosureMatchingChoices .size ();
469
470
numOpenedTypes = cs.OpenedTypes .size ();
470
471
numOpenedExistentialTypes = cs.OpenedExistentialTypes .size ();
@@ -519,6 +520,9 @@ ConstraintSystem::SolverScope::~SolverScope() {
519
520
// Remove any disjunction choices.
520
521
truncate (cs.DisjunctionChoices , numDisjunctionChoices);
521
522
523
+ // Remove any applied disjunctions.
524
+ truncate (cs.AppliedDisjunctions , numAppliedDisjunctions);
525
+
522
526
// Remove any trailing closure matching choices;
523
527
truncate (
524
528
cs.trailingClosureMatchingChoices , numTrailingClosureMatchingChoices);
@@ -2063,6 +2067,95 @@ static void existingOperatorBindingsForDisjunction(ConstraintSystem &CS,
2063
2067
}
2064
2068
}
2065
2069
2070
+ void ConstraintSystem::partitionGenericOperators (ArrayRef<Constraint *> constraints,
2071
+ SmallVectorImpl<unsigned >::iterator first,
2072
+ SmallVectorImpl<unsigned >::iterator last,
2073
+ ConstraintLocator *locator) {
2074
+ auto *argFnType = AppliedDisjunctions[locator];
2075
+ if (!isOperatorBindOverload (constraints[0 ]) || !argFnType)
2076
+ return ;
2077
+
2078
+ auto operatorName = constraints[0 ]->getOverloadChoice ().getName ();
2079
+ if (!operatorName.getBaseIdentifier ().isArithmeticOperator ())
2080
+ return ;
2081
+
2082
+ SmallVector<unsigned , 4 > concreteOverloads;
2083
+ SmallVector<unsigned , 4 > numericOverloads;
2084
+ SmallVector<unsigned , 4 > sequenceOverloads;
2085
+ SmallVector<unsigned , 4 > otherGenericOverloads;
2086
+
2087
+ auto refinesOrConformsTo = [&](NominalTypeDecl *nominal, KnownProtocolKind kind) -> bool {
2088
+ if (!nominal)
2089
+ return false ;
2090
+
2091
+ auto *protocol = TypeChecker::getProtocol (getASTContext (), SourceLoc (), kind);
2092
+
2093
+ if (auto *refined = dyn_cast<ProtocolDecl>(nominal))
2094
+ return refined->inheritsFrom (protocol);
2095
+
2096
+ return (bool )TypeChecker::conformsToProtocol (nominal->getDeclaredType (), protocol,
2097
+ nominal->getDeclContext ());
2098
+ };
2099
+
2100
+ // Gather Numeric and Sequence overloads into separate buckets.
2101
+ for (auto iter = first; iter != last; ++iter) {
2102
+ unsigned index = *iter;
2103
+ auto *decl = constraints[index]->getOverloadChoice ().getDecl ();
2104
+ auto *nominal = decl->getDeclContext ()->getSelfNominalTypeDecl ();
2105
+ if (!decl->getInterfaceType ()->is <GenericFunctionType>()) {
2106
+ concreteOverloads.push_back (index);
2107
+ } else if (refinesOrConformsTo (nominal, KnownProtocolKind::AdditiveArithmetic)) {
2108
+ numericOverloads.push_back (index);
2109
+ } else if (refinesOrConformsTo (nominal, KnownProtocolKind::Sequence)) {
2110
+ sequenceOverloads.push_back (index);
2111
+ } else {
2112
+ otherGenericOverloads.push_back (index);
2113
+ }
2114
+ }
2115
+
2116
+ auto sortPartition = [&](SmallVectorImpl<unsigned > &partition) {
2117
+ llvm::sort (partition, [&](unsigned lhs, unsigned rhs) -> bool {
2118
+ auto *declA = dyn_cast<ValueDecl>(constraints[lhs]->getOverloadChoice ().getDecl ());
2119
+ auto *declB = dyn_cast<ValueDecl>(constraints[rhs]->getOverloadChoice ().getDecl ());
2120
+
2121
+ return TypeChecker::isDeclRefinementOf (declA, declB);
2122
+ });
2123
+ };
2124
+
2125
+ // Sort sequence overloads so that refinements are attempted first.
2126
+ // If the solver finds a solution with an overload, it can then skip
2127
+ // subsequent choices that the successful choice is a refinement of.
2128
+ sortPartition (sequenceOverloads);
2129
+
2130
+ // Attempt concrete overloads first.
2131
+ first = std::copy (concreteOverloads.begin (), concreteOverloads.end (), first);
2132
+
2133
+ // Check if any of the known argument types conform to one of the standard
2134
+ // arithmetic protocols. If so, the sovler should attempt the corresponding
2135
+ // overload choices first.
2136
+ for (auto arg : argFnType->getParams ()) {
2137
+ auto argType = arg.getPlainType ();
2138
+ if (!argType || argType->hasTypeVariable ())
2139
+ continue ;
2140
+
2141
+ if (conformsToKnownProtocol (DC, argType, KnownProtocolKind::AdditiveArithmetic)) {
2142
+ first = std::copy (numericOverloads.begin (), numericOverloads.end (), first);
2143
+ numericOverloads.clear ();
2144
+ break ;
2145
+ }
2146
+
2147
+ if (conformsToKnownProtocol (DC, argType, KnownProtocolKind::Sequence)) {
2148
+ first = std::copy (sequenceOverloads.begin (), sequenceOverloads.end (), first);
2149
+ sequenceOverloads.clear ();
2150
+ break ;
2151
+ }
2152
+ }
2153
+
2154
+ first = std::copy (otherGenericOverloads.begin (), otherGenericOverloads.end (), first);
2155
+ first = std::copy (numericOverloads.begin (), numericOverloads.end (), first);
2156
+ first = std::copy (sequenceOverloads.begin (), sequenceOverloads.end (), first);
2157
+ }
2158
+
2066
2159
void ConstraintSystem::partitionDisjunction (
2067
2160
ArrayRef<Constraint *> Choices, SmallVectorImpl<unsigned > &Ordering,
2068
2161
SmallVectorImpl<unsigned > &PartitionBeginning) {
@@ -2153,6 +2246,12 @@ void ConstraintSystem::partitionDisjunction(
2153
2246
});
2154
2247
}
2155
2248
2249
+ // Gather the remaining options.
2250
+ forEachChoice (Choices, [&](unsigned index, Constraint *constraint) -> bool {
2251
+ everythingElse.push_back (index);
2252
+ return true ;
2253
+ });
2254
+
2156
2255
// Local function to create the next partition based on the options
2157
2256
// passed in.
2158
2257
PartitionAppendCallback appendPartition =
@@ -2163,16 +2262,9 @@ void ConstraintSystem::partitionDisjunction(
2163
2262
}
2164
2263
};
2165
2264
2166
- // Gather the remaining options.
2167
- forEachChoice (Choices, [&](unsigned index, Constraint *constraint) -> bool {
2168
- everythingElse.push_back (index);
2169
- return true ;
2170
- });
2171
2265
appendPartition (favored);
2172
2266
appendPartition (everythingElse);
2173
2267
appendPartition (simdOperators);
2174
-
2175
- // Now create the remaining partitions from what we previously collected.
2176
2268
appendPartition (unavailable);
2177
2269
appendPartition (disabled);
2178
2270
0 commit comments