Skip to content

Commit 9e159de

Browse files
authored
Merge pull request swiftlang#36911 from xedin/refactor-common-operator-identification-logic
[ConstraintSystem] NFC: Factor operator overload identification and partitioning.
2 parents b15f77d + d57c112 commit 9e159de

File tree

3 files changed

+76
-80
lines changed

3 files changed

+76
-80
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 24 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4983,27 +4983,6 @@ class ConstraintSystem {
49834983
return getExpressionTooComplex(solutionMemory);
49844984
}
49854985

4986-
typedef std::function<bool(unsigned index, Constraint *)> ConstraintMatcher;
4987-
typedef std::function<void(ArrayRef<Constraint *>, ConstraintMatcher)>
4988-
ConstraintMatchLoop;
4989-
typedef std::function<void(SmallVectorImpl<unsigned> &options)>
4990-
PartitionAppendCallback;
4991-
4992-
// Partition the choices in the disjunction into groups that we will
4993-
// iterate over in an order appropriate to attempt to stop before we
4994-
// have to visit all of the options.
4995-
void partitionDisjunction(ArrayRef<Constraint *> Choices,
4996-
SmallVectorImpl<unsigned> &Ordering,
4997-
SmallVectorImpl<unsigned> &PartitionBeginning);
4998-
4999-
/// Partition the choices in the range \c first to \c last into groups and
5000-
/// order the groups in the best order to attempt based on the argument
5001-
/// function type that the operator is applied to.
5002-
void partitionGenericOperators(ArrayRef<Constraint *> Choices,
5003-
SmallVectorImpl<unsigned>::iterator first,
5004-
SmallVectorImpl<unsigned>::iterator last,
5005-
ConstraintLocator *locator);
5006-
50074986
// If the given constraint is an applied disjunction, get the argument function
50084987
// that the disjunction is applied to.
50094988
const FunctionType *getAppliedDisjunctionArgumentFunction(const Constraint *disjunction) {
@@ -5328,6 +5307,8 @@ Type isRawRepresentable(ConstraintSystem &cs, Type type,
53285307
Type getDynamicSelfReplacementType(Type baseObjTy, const ValueDecl *member,
53295308
ConstraintLocator *memberLocator);
53305309

5310+
ValueDecl *getOverloadChoiceDecl(Constraint *choice);
5311+
53315312
class DisjunctionChoice {
53325313
ConstraintSystem &CS;
53335314
unsigned Index;
@@ -5353,7 +5334,7 @@ class DisjunctionChoice {
53535334
}
53545335

53555336
bool isUnavailable() const {
5356-
if (auto *decl = getDecl(Choice))
5337+
if (auto *decl = getOverloadChoiceDecl(Choice))
53575338
return CS.isDeclUnavailable(decl, Choice->getLocator());
53585339
return false;
53595340
}
@@ -5381,23 +5362,12 @@ class DisjunctionChoice {
53815362
void propagateConversionInfo(ConstraintSystem &cs) const;
53825363

53835364
static ValueDecl *getOperatorDecl(Constraint *choice) {
5384-
auto *decl = getDecl(choice);
5365+
auto *decl = getOverloadChoiceDecl(choice);
53855366
if (!decl)
53865367
return nullptr;
53875368

53885369
return decl->isOperator() ? decl : nullptr;
53895370
}
5390-
5391-
static ValueDecl *getDecl(Constraint *constraint) {
5392-
if (constraint->getKind() != ConstraintKind::BindOverload)
5393-
return nullptr;
5394-
5395-
auto choice = constraint->getOverloadChoice();
5396-
if (choice.getKind() != OverloadChoiceKind::Decl)
5397-
return nullptr;
5398-
5399-
return choice.getDecl();
5400-
}
54015371
};
54025372

54035373
class TypeVariableBinding {
@@ -5560,7 +5530,7 @@ class DisjunctionChoiceProducer : public BindingProducer<DisjunctionChoice> {
55605530
assert(!disjunction->shouldRememberChoice() || disjunction->getLocator());
55615531

55625532
// Order and partition the disjunction choices.
5563-
CS.partitionDisjunction(Choices, Ordering, PartitionBeginning);
5533+
partitionDisjunction(Ordering, PartitionBeginning);
55645534
}
55655535

55665536
void setNeedsGenericOperatorOrdering(bool flag) {
@@ -5586,16 +5556,27 @@ class DisjunctionChoiceProducer : public BindingProducer<DisjunctionChoice> {
55865556
if (needsGenericOperatorOrdering && choice.isGenericOperator()) {
55875557
unsigned nextPartitionIndex = (PartitionIndex < PartitionBeginning.size() ?
55885558
PartitionBeginning[PartitionIndex] : Ordering.size());
5589-
CS.partitionGenericOperators(Choices,
5590-
Ordering.begin() + currIndex,
5591-
Ordering.begin() + nextPartitionIndex,
5592-
Disjunction->getLocator());
5559+
partitionGenericOperators(Ordering.begin() + currIndex,
5560+
Ordering.begin() + nextPartitionIndex);
55935561
needsGenericOperatorOrdering = false;
55945562
}
55955563

55965564
return DisjunctionChoice(CS, currIndex, Choices[Ordering[currIndex]],
55975565
IsExplicitConversion, isBeginningOfPartition);
55985566
}
5567+
5568+
private:
5569+
// Partition the choices in the disjunction into groups that we will
5570+
// iterate over in an order appropriate to attempt to stop before we
5571+
// have to visit all of the options.
5572+
void partitionDisjunction(SmallVectorImpl<unsigned> &Ordering,
5573+
SmallVectorImpl<unsigned> &PartitionBeginning);
5574+
5575+
/// Partition the choices in the range \c first to \c last into groups and
5576+
/// order the groups in the best order to attempt based on the argument
5577+
/// function type that the operator is applied to.
5578+
void partitionGenericOperators(SmallVectorImpl<unsigned>::iterator first,
5579+
SmallVectorImpl<unsigned>::iterator last);
55995580
};
56005581

56015582
/// Determine whether given type is a known one
@@ -5620,6 +5601,10 @@ void performSyntacticDiagnosticsForTarget(
56205601
/// generic requirement and if so return that type or null type otherwise.
56215602
Type getConcreteReplacementForProtocolSelfType(ValueDecl *member);
56225603

5604+
/// Determine whether given disjunction constraint represents a set
5605+
/// of operator overload choices.
5606+
bool isOperatorDisjunction(Constraint *disjunction);
5607+
56235608
} // end namespace constraints
56245609

56255610
template<typename ...Args>

lib/Sema/CSSolver.cpp

Lines changed: 36 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1734,18 +1734,6 @@ Constraint *ConstraintSystem::getUnboundBindOverloadDisjunction(
17341734
return result->first;
17351735
}
17361736

1737-
static bool isOperatorBindOverload(Constraint *bindOverload) {
1738-
if (bindOverload->getKind() != ConstraintKind::BindOverload)
1739-
return false;
1740-
1741-
auto choice = bindOverload->getOverloadChoice();
1742-
if (!choice.isDecl())
1743-
return false;
1744-
1745-
auto *funcDecl = dyn_cast<FuncDecl>(choice.getDecl());
1746-
return funcDecl && funcDecl->getOperatorDecl();
1747-
}
1748-
17491737
// Performance hack: if there are two generic overloads, and one is
17501738
// more specialized than the other, prefer the more-specialized one.
17511739
static Constraint *tryOptimizeGenericDisjunction(
@@ -1878,15 +1866,14 @@ static void existingOperatorBindingsForDisjunction(ConstraintSystem &CS,
18781866
}
18791867
}
18801868

1881-
void ConstraintSystem::partitionGenericOperators(ArrayRef<Constraint *> constraints,
1882-
SmallVectorImpl<unsigned>::iterator first,
1883-
SmallVectorImpl<unsigned>::iterator last,
1884-
ConstraintLocator *locator) {
1885-
auto *argFnType = AppliedDisjunctions[locator];
1886-
if (!isOperatorBindOverload(constraints[0]) || !argFnType)
1869+
void DisjunctionChoiceProducer::partitionGenericOperators(
1870+
SmallVectorImpl<unsigned>::iterator first,
1871+
SmallVectorImpl<unsigned>::iterator last) {
1872+
auto *argFnType = CS.getAppliedDisjunctionArgumentFunction(Disjunction);
1873+
if (!isOperatorDisjunction(Disjunction) || !argFnType)
18871874
return;
18881875

1889-
auto operatorName = constraints[0]->getOverloadChoice().getName();
1876+
auto operatorName = Choices[0]->getOverloadChoice().getName();
18901877
if (!operatorName.getBaseIdentifier().isArithmeticOperator())
18911878
return;
18921879

@@ -1899,7 +1886,8 @@ void ConstraintSystem::partitionGenericOperators(ArrayRef<Constraint *> constrai
18991886
if (!nominal)
19001887
return false;
19011888

1902-
auto *protocol = TypeChecker::getProtocol(getASTContext(), SourceLoc(), kind);
1889+
auto *protocol =
1890+
TypeChecker::getProtocol(CS.getASTContext(), SourceLoc(), kind);
19031891

19041892
if (auto *refined = dyn_cast<ProtocolDecl>(nominal))
19051893
return refined->inheritsFrom(protocol);
@@ -1911,7 +1899,7 @@ void ConstraintSystem::partitionGenericOperators(ArrayRef<Constraint *> constrai
19111899
// Gather Numeric and Sequence overloads into separate buckets.
19121900
for (auto iter = first; iter != last; ++iter) {
19131901
unsigned index = *iter;
1914-
auto *decl = constraints[index]->getOverloadChoice().getDecl();
1902+
auto *decl = Choices[index]->getOverloadChoice().getDecl();
19151903
auto *nominal = decl->getDeclContext()->getSelfNominalTypeDecl();
19161904
if (!decl->getInterfaceType()->is<GenericFunctionType>()) {
19171905
concreteOverloads.push_back(index);
@@ -1926,8 +1914,10 @@ void ConstraintSystem::partitionGenericOperators(ArrayRef<Constraint *> constrai
19261914

19271915
auto sortPartition = [&](SmallVectorImpl<unsigned> &partition) {
19281916
llvm::sort(partition, [&](unsigned lhs, unsigned rhs) -> bool {
1929-
auto *declA = dyn_cast<ValueDecl>(constraints[lhs]->getOverloadChoice().getDecl());
1930-
auto *declB = dyn_cast<ValueDecl>(constraints[rhs]->getOverloadChoice().getDecl());
1917+
auto *declA =
1918+
dyn_cast<ValueDecl>(Choices[lhs]->getOverloadChoice().getDecl());
1919+
auto *declB =
1920+
dyn_cast<ValueDecl>(Choices[rhs]->getOverloadChoice().getDecl());
19311921

19321922
return TypeChecker::isDeclRefinementOf(declA, declB);
19331923
});
@@ -1946,19 +1936,22 @@ void ConstraintSystem::partitionGenericOperators(ArrayRef<Constraint *> constrai
19461936
// overload choices first.
19471937
for (auto arg : argFnType->getParams()) {
19481938
auto argType = arg.getPlainType();
1949-
argType = getFixedTypeRecursive(argType, /*wantRValue=*/true);
1939+
argType = CS.getFixedTypeRecursive(argType, /*wantRValue=*/true);
19501940

19511941
if (argType->isTypeVariableOrMember())
19521942
continue;
19531943

1954-
if (conformsToKnownProtocol(DC, argType, KnownProtocolKind::AdditiveArithmetic)) {
1955-
first = std::copy(numericOverloads.begin(), numericOverloads.end(), first);
1944+
if (conformsToKnownProtocol(CS.DC, argType,
1945+
KnownProtocolKind::AdditiveArithmetic)) {
1946+
first =
1947+
std::copy(numericOverloads.begin(), numericOverloads.end(), first);
19561948
numericOverloads.clear();
19571949
break;
19581950
}
19591951

1960-
if (conformsToKnownProtocol(DC, argType, KnownProtocolKind::Sequence)) {
1961-
first = std::copy(sequenceOverloads.begin(), sequenceOverloads.end(), first);
1952+
if (conformsToKnownProtocol(CS.DC, argType, KnownProtocolKind::Sequence)) {
1953+
first =
1954+
std::copy(sequenceOverloads.begin(), sequenceOverloads.end(), first);
19621955
sequenceOverloads.clear();
19631956
break;
19641957
}
@@ -1969,17 +1962,23 @@ void ConstraintSystem::partitionGenericOperators(ArrayRef<Constraint *> constrai
19691962
first = std::copy(sequenceOverloads.begin(), sequenceOverloads.end(), first);
19701963
}
19711964

1972-
void ConstraintSystem::partitionDisjunction(
1973-
ArrayRef<Constraint *> Choices, SmallVectorImpl<unsigned> &Ordering,
1965+
void DisjunctionChoiceProducer::partitionDisjunction(
1966+
SmallVectorImpl<unsigned> &Ordering,
19741967
SmallVectorImpl<unsigned> &PartitionBeginning) {
19751968
// Apply a special-case rule for favoring one generic function over
19761969
// another.
1977-
if (auto favored = tryOptimizeGenericDisjunction(DC, Choices)) {
1978-
favorConstraint(favored);
1970+
if (auto favored = tryOptimizeGenericDisjunction(CS.DC, Choices)) {
1971+
CS.favorConstraint(favored);
19791972
}
19801973

19811974
SmallSet<Constraint *, 16> taken;
19821975

1976+
using ConstraintMatcher = std::function<bool(unsigned index, Constraint *)>;
1977+
using ConstraintMatchLoop =
1978+
std::function<void(ArrayRef<Constraint *>, ConstraintMatcher)>;
1979+
using PartitionAppendCallback =
1980+
std::function<void(SmallVectorImpl<unsigned> & options)>;
1981+
19831982
// Local function used to iterate over the untaken choices from the
19841983
// disjunction and use a higher-order function to determine if they
19851984
// should be part of a partition.
@@ -2006,7 +2005,7 @@ void ConstraintSystem::partitionDisjunction(
20062005

20072006
// Add existing operator bindings to the main partition first. This often
20082007
// helps the solver find a solution fast.
2009-
existingOperatorBindingsForDisjunction(*this, Choices, everythingElse);
2008+
existingOperatorBindingsForDisjunction(CS, Choices, everythingElse);
20102009
for (auto index : everythingElse)
20112010
taken.insert(Choices[index]);
20122011

@@ -2026,7 +2025,7 @@ void ConstraintSystem::partitionDisjunction(
20262025
});
20272026

20282027
// Then unavailable constraints if we're skipping them.
2029-
if (!shouldAttemptFixes()) {
2028+
if (!CS.shouldAttemptFixes()) {
20302029
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
20312030
if (constraint->getKind() != ConstraintKind::BindOverload)
20322031
return false;
@@ -2036,7 +2035,7 @@ void ConstraintSystem::partitionDisjunction(
20362035
if (!funcDecl)
20372036
return false;
20382037

2039-
if (!isDeclUnavailable(funcDecl, constraint->getLocator()))
2038+
if (!CS.isDeclUnavailable(funcDecl, constraint->getLocator()))
20402039
return false;
20412040

20422041
unavailable.push_back(index);
@@ -2045,11 +2044,8 @@ void ConstraintSystem::partitionDisjunction(
20452044
}
20462045

20472046
// Partition SIMD operators.
2048-
if (isOperatorBindOverload(Choices[0])) {
2047+
if (isOperatorDisjunction(Disjunction)) {
20492048
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
2050-
if (!isOperatorBindOverload(constraint))
2051-
return false;
2052-
20532049
if (isSIMDOperator(constraint->getOverloadChoice().getDecl())) {
20542050
simdOperators.push_back(index);
20552051
return true;
@@ -2103,8 +2099,7 @@ Constraint *ConstraintSystem::selectDisjunction() {
21032099
unsigned firstFavored = first->countFavoredNestedConstraints();
21042100
unsigned secondFavored = second->countFavoredNestedConstraints();
21052101

2106-
if (!isOperatorBindOverload(first->getNestedConstraints().front()) ||
2107-
!isOperatorBindOverload(second->getNestedConstraints().front()))
2102+
if (!isOperatorDisjunction(first) || !isOperatorDisjunction(second))
21082103
return firstActive < secondActive;
21092104

21102105
if (firstFavored == secondFavored) {

lib/Sema/ConstraintSystem.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5689,3 +5689,19 @@ TypeVarBindingProducer::getDefaultBinding(Constraint *constraint) const {
56895689
? binding.withType(OptionalType::get(type))
56905690
: binding;
56915691
}
5692+
5693+
ValueDecl *constraints::getOverloadChoiceDecl(Constraint *choice) {
5694+
if (choice->getKind() != ConstraintKind::BindOverload)
5695+
return nullptr;
5696+
return choice->getOverloadChoice().getDeclOrNull();
5697+
}
5698+
5699+
bool constraints::isOperatorDisjunction(Constraint *disjunction) {
5700+
assert(disjunction->getKind() == ConstraintKind::Disjunction);
5701+
5702+
auto choices = disjunction->getNestedConstraints();
5703+
assert(!choices.empty());
5704+
5705+
auto *decl = getOverloadChoiceDecl(choices.front());
5706+
return decl ? decl->isOperator() : false;
5707+
}

0 commit comments

Comments
 (0)