Skip to content

Commit daec9c9

Browse files
committed
[ConstraintSystem] Only attempt the refinement overload heuristic
for arithmetic operators. Only sort overloads that are related, e.g. Sequence overloads. Further, choose which generic overloads to attempt first based on whether any known argument types conform to one of the standard arithmetic protocols.
1 parent e24ac86 commit daec9c9

File tree

7 files changed

+125
-41
lines changed

7 files changed

+125
-41
lines changed

include/swift/AST/Identifier.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ class Identifier {
110110
return isOperatorSlow();
111111
}
112112

113+
bool isArithmeticOperator() const {
114+
return is("+") || is("-") || is("*") || is("/") || is("%");
115+
}
116+
113117
// Returns whether this is a standard comparison operator,
114118
// such as '==', '>=' or '!=='.
115119
bool isStandardComparisonOperator() const {

include/swift/AST/KnownProtocols.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ PROTOCOL(Error)
6969
PROTOCOL_(ErrorCodeProtocol)
7070
PROTOCOL(OptionSet)
7171
PROTOCOL(CaseIterable)
72+
PROTOCOL(SIMD)
7273
PROTOCOL(SIMDScalar)
7374
PROTOCOL(BinaryInteger)
7475

include/swift/Sema/ConstraintSystem.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5464,7 +5464,8 @@ class ConstraintSystem {
54645464
// have to visit all of the options.
54655465
void partitionDisjunction(ArrayRef<Constraint *> Choices,
54665466
SmallVectorImpl<unsigned> &Ordering,
5467-
SmallVectorImpl<unsigned> &PartitionBeginning);
5467+
SmallVectorImpl<unsigned> &PartitionBeginning,
5468+
ConstraintLocator *locator);
54685469

54695470
// If the given constraint is an applied disjunction, get the argument function
54705471
// that the disjunction is applied to.
@@ -5989,7 +5990,7 @@ class DisjunctionChoiceProducer : public BindingProducer<DisjunctionChoice> {
59895990
assert(!disjunction->shouldRememberChoice() || disjunction->getLocator());
59905991

59915992
// Order and partition the disjunction choices.
5992-
CS.partitionDisjunction(Choices, Ordering, PartitionBeginning);
5993+
CS.partitionDisjunction(Choices, Ordering, PartitionBeginning, disjunction->getLocator());
59935994
}
59945995

59955996
DisjunctionChoiceProducer(ConstraintSystem &cs,
@@ -5999,7 +6000,7 @@ class DisjunctionChoiceProducer : public BindingProducer<DisjunctionChoice> {
59996000
IsExplicitConversion(explicitConversion) {
60006001

60016002
// Order and partition the disjunction choices.
6002-
CS.partitionDisjunction(Choices, Ordering, PartitionBeginning);
6003+
CS.partitionDisjunction(Choices, Ordering, PartitionBeginning, locator);
60036004
}
60046005

60056006
Optional<Element> operator()() override {

lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5123,6 +5123,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
51235123
case KnownProtocolKind::Differentiable:
51245124
case KnownProtocolKind::FloatingPoint:
51255125
case KnownProtocolKind::Actor:
5126+
case KnownProtocolKind::SIMD:
51265127
return SpecialProtocol::None;
51275128
}
51285129

lib/Sema/CSGen.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,7 @@ using namespace swift;
3838
using namespace swift::constraints;
3939

4040
static bool isArithmeticOperatorDecl(ValueDecl *vd) {
41-
return vd &&
42-
(vd->getBaseName() == "+" ||
43-
vd->getBaseName() == "-" ||
44-
vd->getBaseName() == "*" ||
45-
vd->getBaseName() == "/" ||
46-
vd->getBaseName() == "%");
41+
return vd && vd->getBaseIdentifier().isArithmeticOperator();
4742
}
4843

4944
static bool mergeRepresentativeEquivalenceClasses(ConstraintSystem &CS,

lib/Sema/CSSolver.cpp

Lines changed: 111 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2077,9 +2077,109 @@ static void existingOperatorBindingsForDisjunction(ConstraintSystem &CS,
20772077
}
20782078
}
20792079

2080+
/// Populates the \c found vector with the indices of the given constraints
2081+
/// that are arithmetic operator choices in the best order to attempt based on
2082+
/// the argument function type that the operator is applied to.
2083+
static void takeArithmeticOperators(ConstraintSystem &CS, const FunctionType *argFnType,
2084+
ArrayRef<Constraint *> constraints,
2085+
SmallVectorImpl<unsigned> &found,
2086+
ConstraintSystem::ConstraintMatchLoop forEachChoice) {
2087+
auto operatorName = constraints[0]->getOverloadChoice().getName();
2088+
if (!operatorName.getBaseIdentifier().isArithmeticOperator())
2089+
return;
2090+
2091+
SmallVector<unsigned, 4> numericOverloads;
2092+
SmallVector<unsigned, 4> sequenceOverloads;
2093+
SmallVector<unsigned, 4> simdOverloads;
2094+
SmallVector<unsigned, 4> otherGenericOverloads;
2095+
2096+
auto refinesOrConformsTo = [&](NominalTypeDecl *nominal, KnownProtocolKind kind) -> bool {
2097+
if (!nominal)
2098+
return false;
2099+
2100+
auto *protocol = TypeChecker::getProtocol(CS.getASTContext(), SourceLoc(), kind);
2101+
2102+
if (auto *refined = dyn_cast<ProtocolDecl>(nominal))
2103+
return refined->inheritsFrom(protocol);
2104+
2105+
return (bool)TypeChecker::conformsToProtocol(nominal->getDeclaredType(), protocol,
2106+
nominal->getDeclContext());
2107+
};
2108+
2109+
// Gather concrete, Numeric, Sequence, and SIMD overloads into separate buckets.
2110+
forEachChoice(constraints, [&](unsigned index, Constraint *constraint) -> bool {
2111+
auto *decl = constraint->getOverloadChoice().getDecl();
2112+
2113+
if (isSIMDOperator(decl)) {
2114+
simdOverloads.push_back(index);
2115+
return true;
2116+
}
2117+
2118+
auto *nominal = decl->getDeclContext()->getSelfNominalTypeDecl();
2119+
if (!decl->getInterfaceType()->is<GenericFunctionType>()) {
2120+
// Concrete overloads should always be attempted first.
2121+
found.push_back(index);
2122+
} else if (refinesOrConformsTo(nominal, KnownProtocolKind::AdditiveArithmetic)) {
2123+
numericOverloads.push_back(index);
2124+
} else if (refinesOrConformsTo(nominal, KnownProtocolKind::Sequence)) {
2125+
sequenceOverloads.push_back(index);
2126+
} else {
2127+
otherGenericOverloads.push_back(index);
2128+
}
2129+
2130+
return true;
2131+
});
2132+
2133+
auto sortPartition = [&](SmallVectorImpl<unsigned> &partition) {
2134+
llvm::sort(partition, [&](unsigned lhs, unsigned rhs) -> bool {
2135+
auto *declA = dyn_cast<ValueDecl>(constraints[lhs]->getOverloadChoice().getDecl());
2136+
auto *declB = dyn_cast<ValueDecl>(constraints[rhs]->getOverloadChoice().getDecl());
2137+
2138+
return TypeChecker::isDeclRefinementOf(declA, declB);
2139+
});
2140+
};
2141+
2142+
// Sort sequence overloads so that refinements are attempted first.
2143+
// If the solver finds a solution with an overload, it can then skip
2144+
// subsequent choices that the successful choice is a refinement of.
2145+
sortPartition(sequenceOverloads);
2146+
2147+
// Check if any of the known argument types conform to one of the standard
2148+
// arithmetic protocols. If so, the sovler should attempt the corresponding
2149+
// overload choices first.
2150+
for (auto arg : argFnType->getParams()) {
2151+
auto argType = arg.getPlainType();
2152+
if (!argType || argType->hasTypeVariable())
2153+
continue;
2154+
2155+
if (conformsToKnownProtocol(CS.DC, argType, KnownProtocolKind::AdditiveArithmetic)) {
2156+
found.append(numericOverloads.begin(), numericOverloads.end());
2157+
numericOverloads.clear();
2158+
break;
2159+
}
2160+
2161+
if (conformsToKnownProtocol(CS.DC, argType, KnownProtocolKind::Sequence)) {
2162+
found.append(sequenceOverloads.begin(), sequenceOverloads.end());
2163+
sequenceOverloads.clear();
2164+
break;
2165+
}
2166+
2167+
if (conformsToKnownProtocol(CS.DC, argType, KnownProtocolKind::SIMD)) {
2168+
found.append(simdOverloads.begin(), simdOverloads.end());
2169+
simdOverloads.clear();
2170+
break;
2171+
}
2172+
}
2173+
2174+
found.append(otherGenericOverloads.begin(), otherGenericOverloads.end());
2175+
found.append(numericOverloads.begin(), numericOverloads.end());
2176+
found.append(sequenceOverloads.begin(), sequenceOverloads.end());
2177+
found.append(simdOverloads.begin(), simdOverloads.end());
2178+
}
2179+
20802180
void ConstraintSystem::partitionDisjunction(
20812181
ArrayRef<Constraint *> Choices, SmallVectorImpl<unsigned> &Ordering,
2082-
SmallVectorImpl<unsigned> &PartitionBeginning) {
2182+
SmallVectorImpl<unsigned> &PartitionBeginning, ConstraintLocator *locator) {
20832183
// Apply a special-case rule for favoring one generic function over
20842184
// another.
20852185
if (auto favored = tryOptimizeGenericDisjunction(DC, Choices)) {
@@ -2152,8 +2252,11 @@ void ConstraintSystem::partitionDisjunction(
21522252
});
21532253
}
21542254

2155-
// Partition SIMD operators.
2255+
// Gather arithmetic and SIMD operators.
21562256
if (isOperatorBindOverload(Choices[0])) {
2257+
if (auto *argFnType = AppliedDisjunctions[locator])
2258+
takeArithmeticOperators(*this, argFnType, Choices, everythingElse, forEachChoice);
2259+
21572260
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
21582261
if (!isOperatorBindOverload(constraint))
21592262
return false;
@@ -2167,6 +2270,12 @@ void ConstraintSystem::partitionDisjunction(
21672270
});
21682271
}
21692272

2273+
// Gather the remaining options.
2274+
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
2275+
everythingElse.push_back(index);
2276+
return true;
2277+
});
2278+
21702279
// Local function to create the next partition based on the options
21712280
// passed in.
21722281
PartitionAppendCallback appendPartition =
@@ -2177,35 +2286,6 @@ void ConstraintSystem::partitionDisjunction(
21772286
}
21782287
};
21792288

2180-
// Gather the remaining options.
2181-
2182-
SmallVector<unsigned, 4> genericOverloads;
2183-
2184-
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
2185-
if (!isForCodeCompletion() && isOperatorBindOverload(constraint)) {
2186-
// Collect generic overload choices separately, and sort these choices
2187-
// by specificity in order to try the most specific choice first.
2188-
auto *decl = constraint->getOverloadChoice().getDecl();
2189-
auto *fnDecl = dyn_cast<AbstractFunctionDecl>(decl);
2190-
if (fnDecl && fnDecl->isGeneric()) {
2191-
genericOverloads.push_back(index);
2192-
return true;
2193-
}
2194-
}
2195-
2196-
everythingElse.push_back(index);
2197-
return true;
2198-
});
2199-
2200-
llvm::sort(genericOverloads, [&](unsigned lhs, unsigned rhs) -> bool {
2201-
auto *declA = dyn_cast<ValueDecl>(Choices[lhs]->getOverloadChoice().getDecl());
2202-
auto *declB = dyn_cast<ValueDecl>(Choices[rhs]->getOverloadChoice().getDecl());
2203-
2204-
return TypeChecker::isDeclRefinementOf(declA, declB);
2205-
});
2206-
2207-
everythingElse.append(genericOverloads.begin(), genericOverloads.end());
2208-
22092289
appendPartition(favored);
22102290
appendPartition(everythingElse);
22112291
appendPartition(simdOperators);

lib/Sema/CSStep.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -624,8 +624,10 @@ bool DisjunctionStep::shouldSkip(const DisjunctionChoice &choice) const {
624624
auto *declA = LastSolvedChoice->first->getOverloadChoice().getDecl();
625625
auto *declB = static_cast<Constraint *>(choice)->getOverloadChoice().getDecl();
626626

627-
if (TypeChecker::isDeclRefinementOf(declA, declB))
627+
if (declA->getBaseIdentifier().isArithmeticOperator() &&
628+
TypeChecker::isDeclRefinementOf(declA, declB)) {
628629
return skip("subtype");
630+
}
629631
}
630632

631633
// If the solver already found a solution with a choice that did not

0 commit comments

Comments
 (0)