@@ -2077,9 +2077,109 @@ static void existingOperatorBindingsForDisjunction(ConstraintSystem &CS,
2077
2077
}
2078
2078
}
2079
2079
2080
+ // / Populates the \c found vector with the indices of the given constraints
2081
+ // / that are arithmetic operator choices in the best order to attempt based on
2082
+ // / the argument function type that the operator is applied to.
2083
+ static void takeArithmeticOperators (ConstraintSystem &CS, const FunctionType *argFnType,
2084
+ ArrayRef<Constraint *> constraints,
2085
+ SmallVectorImpl<unsigned > &found,
2086
+ ConstraintSystem::ConstraintMatchLoop forEachChoice) {
2087
+ auto operatorName = constraints[0 ]->getOverloadChoice ().getName ();
2088
+ if (!operatorName.getBaseIdentifier ().isArithmeticOperator ())
2089
+ return ;
2090
+
2091
+ SmallVector<unsigned , 4 > numericOverloads;
2092
+ SmallVector<unsigned , 4 > sequenceOverloads;
2093
+ SmallVector<unsigned , 4 > simdOverloads;
2094
+ SmallVector<unsigned , 4 > otherGenericOverloads;
2095
+
2096
+ auto refinesOrConformsTo = [&](NominalTypeDecl *nominal, KnownProtocolKind kind) -> bool {
2097
+ if (!nominal)
2098
+ return false ;
2099
+
2100
+ auto *protocol = TypeChecker::getProtocol (CS.getASTContext (), SourceLoc (), kind);
2101
+
2102
+ if (auto *refined = dyn_cast<ProtocolDecl>(nominal))
2103
+ return refined->inheritsFrom (protocol);
2104
+
2105
+ return (bool )TypeChecker::conformsToProtocol (nominal->getDeclaredType (), protocol,
2106
+ nominal->getDeclContext ());
2107
+ };
2108
+
2109
+ // Gather concrete, Numeric, Sequence, and SIMD overloads into separate buckets.
2110
+ forEachChoice (constraints, [&](unsigned index, Constraint *constraint) -> bool {
2111
+ auto *decl = constraint->getOverloadChoice ().getDecl ();
2112
+
2113
+ if (isSIMDOperator (decl)) {
2114
+ simdOverloads.push_back (index);
2115
+ return true ;
2116
+ }
2117
+
2118
+ auto *nominal = decl->getDeclContext ()->getSelfNominalTypeDecl ();
2119
+ if (!decl->getInterfaceType ()->is <GenericFunctionType>()) {
2120
+ // Concrete overloads should always be attempted first.
2121
+ found.push_back (index);
2122
+ } else if (refinesOrConformsTo (nominal, KnownProtocolKind::AdditiveArithmetic)) {
2123
+ numericOverloads.push_back (index);
2124
+ } else if (refinesOrConformsTo (nominal, KnownProtocolKind::Sequence)) {
2125
+ sequenceOverloads.push_back (index);
2126
+ } else {
2127
+ otherGenericOverloads.push_back (index);
2128
+ }
2129
+
2130
+ return true ;
2131
+ });
2132
+
2133
+ auto sortPartition = [&](SmallVectorImpl<unsigned > &partition) {
2134
+ llvm::sort (partition, [&](unsigned lhs, unsigned rhs) -> bool {
2135
+ auto *declA = dyn_cast<ValueDecl>(constraints[lhs]->getOverloadChoice ().getDecl ());
2136
+ auto *declB = dyn_cast<ValueDecl>(constraints[rhs]->getOverloadChoice ().getDecl ());
2137
+
2138
+ return TypeChecker::isDeclRefinementOf (declA, declB);
2139
+ });
2140
+ };
2141
+
2142
+ // Sort sequence overloads so that refinements are attempted first.
2143
+ // If the solver finds a solution with an overload, it can then skip
2144
+ // subsequent choices that the successful choice is a refinement of.
2145
+ sortPartition (sequenceOverloads);
2146
+
2147
+ // Check if any of the known argument types conform to one of the standard
2148
+ // arithmetic protocols. If so, the sovler should attempt the corresponding
2149
+ // overload choices first.
2150
+ for (auto arg : argFnType->getParams ()) {
2151
+ auto argType = arg.getPlainType ();
2152
+ if (!argType || argType->hasTypeVariable ())
2153
+ continue ;
2154
+
2155
+ if (conformsToKnownProtocol (CS.DC , argType, KnownProtocolKind::AdditiveArithmetic)) {
2156
+ found.append (numericOverloads.begin (), numericOverloads.end ());
2157
+ numericOverloads.clear ();
2158
+ break ;
2159
+ }
2160
+
2161
+ if (conformsToKnownProtocol (CS.DC , argType, KnownProtocolKind::Sequence)) {
2162
+ found.append (sequenceOverloads.begin (), sequenceOverloads.end ());
2163
+ sequenceOverloads.clear ();
2164
+ break ;
2165
+ }
2166
+
2167
+ if (conformsToKnownProtocol (CS.DC , argType, KnownProtocolKind::SIMD)) {
2168
+ found.append (simdOverloads.begin (), simdOverloads.end ());
2169
+ simdOverloads.clear ();
2170
+ break ;
2171
+ }
2172
+ }
2173
+
2174
+ found.append (otherGenericOverloads.begin (), otherGenericOverloads.end ());
2175
+ found.append (numericOverloads.begin (), numericOverloads.end ());
2176
+ found.append (sequenceOverloads.begin (), sequenceOverloads.end ());
2177
+ found.append (simdOverloads.begin (), simdOverloads.end ());
2178
+ }
2179
+
2080
2180
void ConstraintSystem::partitionDisjunction (
2081
2181
ArrayRef<Constraint *> Choices, SmallVectorImpl<unsigned > &Ordering,
2082
- SmallVectorImpl<unsigned > &PartitionBeginning) {
2182
+ SmallVectorImpl<unsigned > &PartitionBeginning, ConstraintLocator *locator ) {
2083
2183
// Apply a special-case rule for favoring one generic function over
2084
2184
// another.
2085
2185
if (auto favored = tryOptimizeGenericDisjunction (DC, Choices)) {
@@ -2152,8 +2252,11 @@ void ConstraintSystem::partitionDisjunction(
2152
2252
});
2153
2253
}
2154
2254
2155
- // Partition SIMD operators.
2255
+ // Gather arithmetic and SIMD operators.
2156
2256
if (isOperatorBindOverload (Choices[0 ])) {
2257
+ if (auto *argFnType = AppliedDisjunctions[locator])
2258
+ takeArithmeticOperators (*this , argFnType, Choices, everythingElse, forEachChoice);
2259
+
2157
2260
forEachChoice (Choices, [&](unsigned index, Constraint *constraint) -> bool {
2158
2261
if (!isOperatorBindOverload (constraint))
2159
2262
return false ;
@@ -2167,6 +2270,12 @@ void ConstraintSystem::partitionDisjunction(
2167
2270
});
2168
2271
}
2169
2272
2273
+ // Gather the remaining options.
2274
+ forEachChoice (Choices, [&](unsigned index, Constraint *constraint) -> bool {
2275
+ everythingElse.push_back (index);
2276
+ return true ;
2277
+ });
2278
+
2170
2279
// Local function to create the next partition based on the options
2171
2280
// passed in.
2172
2281
PartitionAppendCallback appendPartition =
@@ -2177,35 +2286,6 @@ void ConstraintSystem::partitionDisjunction(
2177
2286
}
2178
2287
};
2179
2288
2180
- // Gather the remaining options.
2181
-
2182
- SmallVector<unsigned , 4 > genericOverloads;
2183
-
2184
- forEachChoice (Choices, [&](unsigned index, Constraint *constraint) -> bool {
2185
- if (!isForCodeCompletion () && isOperatorBindOverload (constraint)) {
2186
- // Collect generic overload choices separately, and sort these choices
2187
- // by specificity in order to try the most specific choice first.
2188
- auto *decl = constraint->getOverloadChoice ().getDecl ();
2189
- auto *fnDecl = dyn_cast<AbstractFunctionDecl>(decl);
2190
- if (fnDecl && fnDecl->isGeneric ()) {
2191
- genericOverloads.push_back (index);
2192
- return true ;
2193
- }
2194
- }
2195
-
2196
- everythingElse.push_back (index);
2197
- return true ;
2198
- });
2199
-
2200
- llvm::sort (genericOverloads, [&](unsigned lhs, unsigned rhs) -> bool {
2201
- auto *declA = dyn_cast<ValueDecl>(Choices[lhs]->getOverloadChoice ().getDecl ());
2202
- auto *declB = dyn_cast<ValueDecl>(Choices[rhs]->getOverloadChoice ().getDecl ());
2203
-
2204
- return TypeChecker::isDeclRefinementOf (declA, declB);
2205
- });
2206
-
2207
- everythingElse.append (genericOverloads.begin (), genericOverloads.end ());
2208
-
2209
2289
appendPartition (favored);
2210
2290
appendPartition (everythingElse);
2211
2291
appendPartition (simdOperators);
0 commit comments