Skip to content

Commit fef108f

Browse files
committed
[ConstraintSystem] NFC: Move partitioning methods to DisjunctionChoiceProducer
1 parent c9a289e commit fef108f

File tree

2 files changed

+50
-49
lines changed

2 files changed

+50
-49
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4983,27 +4983,6 @@ class ConstraintSystem {
49834983
return getExpressionTooComplex(solutionMemory);
49844984
}
49854985

4986-
typedef std::function<bool(unsigned index, Constraint *)> ConstraintMatcher;
4987-
typedef std::function<void(ArrayRef<Constraint *>, ConstraintMatcher)>
4988-
ConstraintMatchLoop;
4989-
typedef std::function<void(SmallVectorImpl<unsigned> &options)>
4990-
PartitionAppendCallback;
4991-
4992-
// Partition the choices in the disjunction into groups that we will
4993-
// iterate over in an order appropriate to attempt to stop before we
4994-
// have to visit all of the options.
4995-
void partitionDisjunction(ArrayRef<Constraint *> Choices,
4996-
SmallVectorImpl<unsigned> &Ordering,
4997-
SmallVectorImpl<unsigned> &PartitionBeginning);
4998-
4999-
/// Partition the choices in the range \c first to \c last into groups and
5000-
/// order the groups in the best order to attempt based on the argument
5001-
/// function type that the operator is applied to.
5002-
void partitionGenericOperators(ArrayRef<Constraint *> Choices,
5003-
SmallVectorImpl<unsigned>::iterator first,
5004-
SmallVectorImpl<unsigned>::iterator last,
5005-
ConstraintLocator *locator);
5006-
50074986
// If the given constraint is an applied disjunction, get the argument function
50084987
// that the disjunction is applied to.
50094988
const FunctionType *getAppliedDisjunctionArgumentFunction(const Constraint *disjunction) {
@@ -5560,7 +5539,7 @@ class DisjunctionChoiceProducer : public BindingProducer<DisjunctionChoice> {
55605539
assert(!disjunction->shouldRememberChoice() || disjunction->getLocator());
55615540

55625541
// Order and partition the disjunction choices.
5563-
CS.partitionDisjunction(Choices, Ordering, PartitionBeginning);
5542+
partitionDisjunction(Ordering, PartitionBeginning);
55645543
}
55655544

55665545
void setNeedsGenericOperatorOrdering(bool flag) {
@@ -5586,16 +5565,27 @@ class DisjunctionChoiceProducer : public BindingProducer<DisjunctionChoice> {
55865565
if (needsGenericOperatorOrdering && choice.isGenericOperator()) {
55875566
unsigned nextPartitionIndex = (PartitionIndex < PartitionBeginning.size() ?
55885567
PartitionBeginning[PartitionIndex] : Ordering.size());
5589-
CS.partitionGenericOperators(Choices,
5590-
Ordering.begin() + currIndex,
5591-
Ordering.begin() + nextPartitionIndex,
5592-
Disjunction->getLocator());
5568+
partitionGenericOperators(Ordering.begin() + currIndex,
5569+
Ordering.begin() + nextPartitionIndex);
55935570
needsGenericOperatorOrdering = false;
55945571
}
55955572

55965573
return DisjunctionChoice(CS, currIndex, Choices[Ordering[currIndex]],
55975574
IsExplicitConversion, isBeginningOfPartition);
55985575
}
5576+
5577+
private:
5578+
// Partition the choices in the disjunction into groups that we will
5579+
// iterate over in an order appropriate to attempt to stop before we
5580+
// have to visit all of the options.
5581+
void partitionDisjunction(SmallVectorImpl<unsigned> &Ordering,
5582+
SmallVectorImpl<unsigned> &PartitionBeginning);
5583+
5584+
/// Partition the choices in the range \c first to \c last into groups and
5585+
/// order the groups in the best order to attempt based on the argument
5586+
/// function type that the operator is applied to.
5587+
void partitionGenericOperators(SmallVectorImpl<unsigned>::iterator first,
5588+
SmallVectorImpl<unsigned>::iterator last);
55995589
};
56005590

56015591
/// Determine whether given type is a known one

lib/Sema/CSSolver.cpp

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1878,15 +1878,14 @@ static void existingOperatorBindingsForDisjunction(ConstraintSystem &CS,
18781878
}
18791879
}
18801880

1881-
void ConstraintSystem::partitionGenericOperators(ArrayRef<Constraint *> constraints,
1882-
SmallVectorImpl<unsigned>::iterator first,
1883-
SmallVectorImpl<unsigned>::iterator last,
1884-
ConstraintLocator *locator) {
1885-
auto *argFnType = AppliedDisjunctions[locator];
1886-
if (!isOperatorBindOverload(constraints[0]) || !argFnType)
1881+
void DisjunctionChoiceProducer::partitionGenericOperators(
1882+
SmallVectorImpl<unsigned>::iterator first,
1883+
SmallVectorImpl<unsigned>::iterator last) {
1884+
auto *argFnType = CS.getAppliedDisjunctionArgumentFunction(Disjunction);
1885+
if (!isOperatorBindOverload(Choices.front()) || !argFnType)
18871886
return;
18881887

1889-
auto operatorName = constraints[0]->getOverloadChoice().getName();
1888+
auto operatorName = Choices[0]->getOverloadChoice().getName();
18901889
if (!operatorName.getBaseIdentifier().isArithmeticOperator())
18911890
return;
18921891

@@ -1899,7 +1898,8 @@ void ConstraintSystem::partitionGenericOperators(ArrayRef<Constraint *> constrai
18991898
if (!nominal)
19001899
return false;
19011900

1902-
auto *protocol = TypeChecker::getProtocol(getASTContext(), SourceLoc(), kind);
1901+
auto *protocol =
1902+
TypeChecker::getProtocol(CS.getASTContext(), SourceLoc(), kind);
19031903

19041904
if (auto *refined = dyn_cast<ProtocolDecl>(nominal))
19051905
return refined->inheritsFrom(protocol);
@@ -1911,7 +1911,7 @@ void ConstraintSystem::partitionGenericOperators(ArrayRef<Constraint *> constrai
19111911
// Gather Numeric and Sequence overloads into separate buckets.
19121912
for (auto iter = first; iter != last; ++iter) {
19131913
unsigned index = *iter;
1914-
auto *decl = constraints[index]->getOverloadChoice().getDecl();
1914+
auto *decl = Choices[index]->getOverloadChoice().getDecl();
19151915
auto *nominal = decl->getDeclContext()->getSelfNominalTypeDecl();
19161916
if (!decl->getInterfaceType()->is<GenericFunctionType>()) {
19171917
concreteOverloads.push_back(index);
@@ -1926,8 +1926,10 @@ void ConstraintSystem::partitionGenericOperators(ArrayRef<Constraint *> constrai
19261926

19271927
auto sortPartition = [&](SmallVectorImpl<unsigned> &partition) {
19281928
llvm::sort(partition, [&](unsigned lhs, unsigned rhs) -> bool {
1929-
auto *declA = dyn_cast<ValueDecl>(constraints[lhs]->getOverloadChoice().getDecl());
1930-
auto *declB = dyn_cast<ValueDecl>(constraints[rhs]->getOverloadChoice().getDecl());
1929+
auto *declA =
1930+
dyn_cast<ValueDecl>(Choices[lhs]->getOverloadChoice().getDecl());
1931+
auto *declB =
1932+
dyn_cast<ValueDecl>(Choices[rhs]->getOverloadChoice().getDecl());
19311933

19321934
return TypeChecker::isDeclRefinementOf(declA, declB);
19331935
});
@@ -1946,19 +1948,22 @@ void ConstraintSystem::partitionGenericOperators(ArrayRef<Constraint *> constrai
19461948
// overload choices first.
19471949
for (auto arg : argFnType->getParams()) {
19481950
auto argType = arg.getPlainType();
1949-
argType = getFixedTypeRecursive(argType, /*wantRValue=*/true);
1951+
argType = CS.getFixedTypeRecursive(argType, /*wantRValue=*/true);
19501952

19511953
if (argType->isTypeVariableOrMember())
19521954
continue;
19531955

1954-
if (conformsToKnownProtocol(DC, argType, KnownProtocolKind::AdditiveArithmetic)) {
1955-
first = std::copy(numericOverloads.begin(), numericOverloads.end(), first);
1956+
if (conformsToKnownProtocol(CS.DC, argType,
1957+
KnownProtocolKind::AdditiveArithmetic)) {
1958+
first =
1959+
std::copy(numericOverloads.begin(), numericOverloads.end(), first);
19561960
numericOverloads.clear();
19571961
break;
19581962
}
19591963

1960-
if (conformsToKnownProtocol(DC, argType, KnownProtocolKind::Sequence)) {
1961-
first = std::copy(sequenceOverloads.begin(), sequenceOverloads.end(), first);
1964+
if (conformsToKnownProtocol(CS.DC, argType, KnownProtocolKind::Sequence)) {
1965+
first =
1966+
std::copy(sequenceOverloads.begin(), sequenceOverloads.end(), first);
19621967
sequenceOverloads.clear();
19631968
break;
19641969
}
@@ -1969,17 +1974,23 @@ void ConstraintSystem::partitionGenericOperators(ArrayRef<Constraint *> constrai
19691974
first = std::copy(sequenceOverloads.begin(), sequenceOverloads.end(), first);
19701975
}
19711976

1972-
void ConstraintSystem::partitionDisjunction(
1973-
ArrayRef<Constraint *> Choices, SmallVectorImpl<unsigned> &Ordering,
1977+
void DisjunctionChoiceProducer::partitionDisjunction(
1978+
SmallVectorImpl<unsigned> &Ordering,
19741979
SmallVectorImpl<unsigned> &PartitionBeginning) {
19751980
// Apply a special-case rule for favoring one generic function over
19761981
// another.
1977-
if (auto favored = tryOptimizeGenericDisjunction(DC, Choices)) {
1978-
favorConstraint(favored);
1982+
if (auto favored = tryOptimizeGenericDisjunction(CS.DC, Choices)) {
1983+
CS.favorConstraint(favored);
19791984
}
19801985

19811986
SmallSet<Constraint *, 16> taken;
19821987

1988+
using ConstraintMatcher = std::function<bool(unsigned index, Constraint *)>;
1989+
using ConstraintMatchLoop =
1990+
std::function<void(ArrayRef<Constraint *>, ConstraintMatcher)>;
1991+
using PartitionAppendCallback =
1992+
std::function<void(SmallVectorImpl<unsigned> & options)>;
1993+
19831994
// Local function used to iterate over the untaken choices from the
19841995
// disjunction and use a higher-order function to determine if they
19851996
// should be part of a partition.
@@ -2006,7 +2017,7 @@ void ConstraintSystem::partitionDisjunction(
20062017

20072018
// Add existing operator bindings to the main partition first. This often
20082019
// helps the solver find a solution fast.
2009-
existingOperatorBindingsForDisjunction(*this, Choices, everythingElse);
2020+
existingOperatorBindingsForDisjunction(CS, Choices, everythingElse);
20102021
for (auto index : everythingElse)
20112022
taken.insert(Choices[index]);
20122023

@@ -2026,7 +2037,7 @@ void ConstraintSystem::partitionDisjunction(
20262037
});
20272038

20282039
// Then unavailable constraints if we're skipping them.
2029-
if (!shouldAttemptFixes()) {
2040+
if (!CS.shouldAttemptFixes()) {
20302041
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
20312042
if (constraint->getKind() != ConstraintKind::BindOverload)
20322043
return false;
@@ -2036,7 +2047,7 @@ void ConstraintSystem::partitionDisjunction(
20362047
if (!funcDecl)
20372048
return false;
20382049

2039-
if (!isDeclUnavailable(funcDecl, constraint->getLocator()))
2050+
if (!CS.isDeclUnavailable(funcDecl, constraint->getLocator()))
20402051
return false;
20412052

20422053
unavailable.push_back(index);

0 commit comments

Comments
 (0)