Skip to content

Commit a3957b9

Browse files
committed
[ConstraintSystem] Only do the work of partitioning the generic operator
overloads if generic operators are not going to be skipped.
1 parent daec9c9 commit a3957b9

File tree

6 files changed

+84
-67
lines changed

6 files changed

+84
-67
lines changed

include/swift/AST/KnownProtocols.def

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ PROTOCOL(Error)
6969
PROTOCOL_(ErrorCodeProtocol)
7070
PROTOCOL(OptionSet)
7171
PROTOCOL(CaseIterable)
72-
PROTOCOL(SIMD)
7372
PROTOCOL(SIMDScalar)
7473
PROTOCOL(BinaryInteger)
7574

include/swift/Sema/ConstraintSystem.h

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5464,8 +5464,15 @@ class ConstraintSystem {
54645464
// have to visit all of the options.
54655465
void partitionDisjunction(ArrayRef<Constraint *> Choices,
54665466
SmallVectorImpl<unsigned> &Ordering,
5467-
SmallVectorImpl<unsigned> &PartitionBeginning,
5468-
ConstraintLocator *locator);
5467+
SmallVectorImpl<unsigned> &PartitionBeginning);
5468+
5469+
/// Partition the choices in the range \c first to \c last into groups and
5470+
/// order the groups in the best order to attempt based on the argument
5471+
/// function type that the operator is applied to.
5472+
void partitionGenericOperators(ArrayRef<Constraint *> Choices,
5473+
SmallVectorImpl<unsigned>::iterator first,
5474+
SmallVectorImpl<unsigned>::iterator last,
5475+
ConstraintLocator *locator);
54695476

54705477
// If the given constraint is an applied disjunction, get the argument function
54715478
// that the disjunction is applied to.
@@ -5975,8 +5982,12 @@ class DisjunctionChoiceProducer : public BindingProducer<DisjunctionChoice> {
59755982

59765983
bool IsExplicitConversion;
59775984

5985+
Constraint *Disjunction;
5986+
59785987
unsigned Index = 0;
59795988

5989+
bool needsGenericOperatorOrdering = true;
5990+
59805991
public:
59815992
using Element = DisjunctionChoice;
59825993

@@ -5985,22 +5996,17 @@ class DisjunctionChoiceProducer : public BindingProducer<DisjunctionChoice> {
59855996
? disjunction->getLocator()
59865997
: nullptr),
59875998
Choices(disjunction->getNestedConstraints()),
5988-
IsExplicitConversion(disjunction->isExplicitConversion()) {
5999+
IsExplicitConversion(disjunction->isExplicitConversion()),
6000+
Disjunction(disjunction) {
59896001
assert(disjunction->getKind() == ConstraintKind::Disjunction);
59906002
assert(!disjunction->shouldRememberChoice() || disjunction->getLocator());
59916003

59926004
// Order and partition the disjunction choices.
5993-
CS.partitionDisjunction(Choices, Ordering, PartitionBeginning, disjunction->getLocator());
6005+
CS.partitionDisjunction(Choices, Ordering, PartitionBeginning);
59946006
}
59956007

5996-
DisjunctionChoiceProducer(ConstraintSystem &cs,
5997-
ArrayRef<Constraint *> choices,
5998-
ConstraintLocator *locator, bool explicitConversion)
5999-
: BindingProducer(cs, locator), Choices(choices),
6000-
IsExplicitConversion(explicitConversion) {
6001-
6002-
// Order and partition the disjunction choices.
6003-
CS.partitionDisjunction(Choices, Ordering, PartitionBeginning, locator);
6008+
void setNeedsGenericOperatorOrdering(bool flag) {
6009+
needsGenericOperatorOrdering = flag;
60046010
}
60056011

60066012
Optional<Element> operator()() override {
@@ -6015,6 +6021,20 @@ class DisjunctionChoiceProducer : public BindingProducer<DisjunctionChoice> {
60156021

60166022
++Index;
60176023

6024+
auto choice = DisjunctionChoice(CS, currIndex, Choices[Ordering[currIndex]],
6025+
IsExplicitConversion, isBeginningOfPartition);
6026+
// Partition the generic operators before producing the first generic
6027+
// operator disjunction choice.
6028+
if (needsGenericOperatorOrdering && choice.isGenericOperator()) {
6029+
unsigned nextPartitionIndex = (PartitionIndex < PartitionBeginning.size() ?
6030+
PartitionBeginning[PartitionIndex] : Ordering.size());
6031+
CS.partitionGenericOperators(Choices,
6032+
Ordering.begin() + currIndex,
6033+
Ordering.begin() + nextPartitionIndex,
6034+
Disjunction->getLocator());
6035+
needsGenericOperatorOrdering = false;
6036+
}
6037+
60186038
return DisjunctionChoice(CS, currIndex, Choices[Ordering[currIndex]],
60196039
IsExplicitConversion, isBeginningOfPartition);
60206040
}

lib/IRGen/GenMeta.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5123,7 +5123,6 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
51235123
case KnownProtocolKind::Differentiable:
51245124
case KnownProtocolKind::FloatingPoint:
51255125
case KnownProtocolKind::Actor:
5126-
case KnownProtocolKind::SIMD:
51275126
return SpecialProtocol::None;
51285127
}
51295128

lib/Sema/CSSolver.cpp

Lines changed: 28 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2077,27 +2077,28 @@ static void existingOperatorBindingsForDisjunction(ConstraintSystem &CS,
20772077
}
20782078
}
20792079

2080-
/// Populates the \c found vector with the indices of the given constraints
2081-
/// that are arithmetic operator choices in the best order to attempt based on
2082-
/// the argument function type that the operator is applied to.
2083-
static void takeArithmeticOperators(ConstraintSystem &CS, const FunctionType *argFnType,
2084-
ArrayRef<Constraint *> constraints,
2085-
SmallVectorImpl<unsigned> &found,
2086-
ConstraintSystem::ConstraintMatchLoop forEachChoice) {
2080+
void ConstraintSystem::partitionGenericOperators(ArrayRef<Constraint *> constraints,
2081+
SmallVectorImpl<unsigned>::iterator first,
2082+
SmallVectorImpl<unsigned>::iterator last,
2083+
ConstraintLocator *locator) {
2084+
auto *argFnType = AppliedDisjunctions[locator];
2085+
if (!isOperatorBindOverload(constraints[0]) || !argFnType)
2086+
return;
2087+
20872088
auto operatorName = constraints[0]->getOverloadChoice().getName();
20882089
if (!operatorName.getBaseIdentifier().isArithmeticOperator())
20892090
return;
20902091

2092+
SmallVector<unsigned, 4> concreteOverloads;
20912093
SmallVector<unsigned, 4> numericOverloads;
20922094
SmallVector<unsigned, 4> sequenceOverloads;
2093-
SmallVector<unsigned, 4> simdOverloads;
20942095
SmallVector<unsigned, 4> otherGenericOverloads;
20952096

20962097
auto refinesOrConformsTo = [&](NominalTypeDecl *nominal, KnownProtocolKind kind) -> bool {
20972098
if (!nominal)
20982099
return false;
20992100

2100-
auto *protocol = TypeChecker::getProtocol(CS.getASTContext(), SourceLoc(), kind);
2101+
auto *protocol = TypeChecker::getProtocol(getASTContext(), SourceLoc(), kind);
21012102

21022103
if (auto *refined = dyn_cast<ProtocolDecl>(nominal))
21032104
return refined->inheritsFrom(protocol);
@@ -2106,29 +2107,21 @@ static void takeArithmeticOperators(ConstraintSystem &CS, const FunctionType *ar
21062107
nominal->getDeclContext());
21072108
};
21082109

2109-
// Gather concrete, Numeric, Sequence, and SIMD overloads into separate buckets.
2110-
forEachChoice(constraints, [&](unsigned index, Constraint *constraint) -> bool {
2111-
auto *decl = constraint->getOverloadChoice().getDecl();
2112-
2113-
if (isSIMDOperator(decl)) {
2114-
simdOverloads.push_back(index);
2115-
return true;
2116-
}
2117-
2110+
// Gather Numeric and Sequence overloads into separate buckets.
2111+
for (auto iter = first; iter != last; ++iter) {
2112+
unsigned index = *iter;
2113+
auto *decl = constraints[index]->getOverloadChoice().getDecl();
21182114
auto *nominal = decl->getDeclContext()->getSelfNominalTypeDecl();
21192115
if (!decl->getInterfaceType()->is<GenericFunctionType>()) {
2120-
// Concrete overloads should always be attempted first.
2121-
found.push_back(index);
2116+
concreteOverloads.push_back(index);
21222117
} else if (refinesOrConformsTo(nominal, KnownProtocolKind::AdditiveArithmetic)) {
21232118
numericOverloads.push_back(index);
21242119
} else if (refinesOrConformsTo(nominal, KnownProtocolKind::Sequence)) {
21252120
sequenceOverloads.push_back(index);
21262121
} else {
21272122
otherGenericOverloads.push_back(index);
21282123
}
2129-
2130-
return true;
2131-
});
2124+
}
21322125

21332126
auto sortPartition = [&](SmallVectorImpl<unsigned> &partition) {
21342127
llvm::sort(partition, [&](unsigned lhs, unsigned rhs) -> bool {
@@ -2144,6 +2137,9 @@ static void takeArithmeticOperators(ConstraintSystem &CS, const FunctionType *ar
21442137
// subsequent choices that the successful choice is a refinement of.
21452138
sortPartition(sequenceOverloads);
21462139

2140+
// Attempt concrete overloads first.
2141+
first = std::copy(concreteOverloads.begin(), concreteOverloads.end(), first);
2142+
21472143
// Check if any of the known argument types conform to one of the standard
21482144
// arithmetic protocols. If so, the sovler should attempt the corresponding
21492145
// overload choices first.
@@ -2152,34 +2148,27 @@ static void takeArithmeticOperators(ConstraintSystem &CS, const FunctionType *ar
21522148
if (!argType || argType->hasTypeVariable())
21532149
continue;
21542150

2155-
if (conformsToKnownProtocol(CS.DC, argType, KnownProtocolKind::AdditiveArithmetic)) {
2156-
found.append(numericOverloads.begin(), numericOverloads.end());
2151+
if (conformsToKnownProtocol(DC, argType, KnownProtocolKind::AdditiveArithmetic)) {
2152+
first = std::copy(numericOverloads.begin(), numericOverloads.end(), first);
21572153
numericOverloads.clear();
21582154
break;
21592155
}
21602156

2161-
if (conformsToKnownProtocol(CS.DC, argType, KnownProtocolKind::Sequence)) {
2162-
found.append(sequenceOverloads.begin(), sequenceOverloads.end());
2157+
if (conformsToKnownProtocol(DC, argType, KnownProtocolKind::Sequence)) {
2158+
first = std::copy(sequenceOverloads.begin(), sequenceOverloads.end(), first);
21632159
sequenceOverloads.clear();
21642160
break;
21652161
}
2166-
2167-
if (conformsToKnownProtocol(CS.DC, argType, KnownProtocolKind::SIMD)) {
2168-
found.append(simdOverloads.begin(), simdOverloads.end());
2169-
simdOverloads.clear();
2170-
break;
2171-
}
21722162
}
21732163

2174-
found.append(otherGenericOverloads.begin(), otherGenericOverloads.end());
2175-
found.append(numericOverloads.begin(), numericOverloads.end());
2176-
found.append(sequenceOverloads.begin(), sequenceOverloads.end());
2177-
found.append(simdOverloads.begin(), simdOverloads.end());
2164+
first = std::copy(otherGenericOverloads.begin(), otherGenericOverloads.end(), first);
2165+
first = std::copy(numericOverloads.begin(), numericOverloads.end(), first);
2166+
first = std::copy(sequenceOverloads.begin(), sequenceOverloads.end(), first);
21782167
}
21792168

21802169
void ConstraintSystem::partitionDisjunction(
21812170
ArrayRef<Constraint *> Choices, SmallVectorImpl<unsigned> &Ordering,
2182-
SmallVectorImpl<unsigned> &PartitionBeginning, ConstraintLocator *locator) {
2171+
SmallVectorImpl<unsigned> &PartitionBeginning) {
21832172
// Apply a special-case rule for favoring one generic function over
21842173
// another.
21852174
if (auto favored = tryOptimizeGenericDisjunction(DC, Choices)) {
@@ -2252,11 +2241,8 @@ void ConstraintSystem::partitionDisjunction(
22522241
});
22532242
}
22542243

2255-
// Gather arithmetic and SIMD operators.
2244+
// Partition SIMD operators.
22562245
if (isOperatorBindOverload(Choices[0])) {
2257-
if (auto *argFnType = AppliedDisjunctions[locator])
2258-
takeArithmeticOperators(*this, argFnType, Choices, everythingElse, forEachChoice);
2259-
22602246
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
22612247
if (!isOperatorBindOverload(constraint))
22622248
return false;

lib/Sema/CSStep.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -510,8 +510,15 @@ StepResult DisjunctionStep::resume(bool prevFailed) {
510510
auto score = getBestScore(Solutions);
511511

512512
if (!choice.isGenericOperator() && choice.isSymmetricOperator()) {
513-
if (!BestNonGenericScore || score < BestNonGenericScore)
513+
if (!BestNonGenericScore || score < BestNonGenericScore) {
514514
BestNonGenericScore = score;
515+
if (shouldSkipGenericOperators()) {
516+
// The disjunction choice producer shouldn't do the work
517+
// to partition the generic operator choices if generic
518+
// operators are going to be skipped.
519+
Producer.setNeedsGenericOperatorOrdering(false);
520+
}
521+
}
515522
}
516523

517524
AnySolved = true;
@@ -673,16 +680,8 @@ bool DisjunctionStep::shouldSkip(const DisjunctionChoice &choice) const {
673680
// already have a solution involving non-generic operators,
674681
// but continue looking for a better non-generic operator
675682
// solution.
676-
if (BestNonGenericScore && choice.isGenericOperator()) {
677-
auto &score = BestNonGenericScore->Data;
678-
// Let's skip generic overload choices only in case if
679-
// non-generic score indicates that there were no forced
680-
// unwrappings of optional(s), no unavailable overload
681-
// choices present in the solution, no fixes required,
682-
// and there are no non-trivial function conversions.
683-
if (score[SK_ForceUnchecked] == 0 && score[SK_Unavailable] == 0 &&
684-
score[SK_Fix] == 0 && score[SK_FunctionConversion] == 0)
685-
return skip("generic");
683+
if (shouldSkipGenericOperators() && choice.isGenericOperator()) {
684+
return skip("generic");
686685
}
687686

688687
return false;

lib/Sema/CSStep.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,20 @@ class DisjunctionStep final : public BindingStep<DisjunctionChoiceProducer> {
683683
bool shortCircuitDisjunctionAt(Constraint *currentChoice,
684684
Constraint *lastSuccessfulChoice) const;
685685

686+
bool shouldSkipGenericOperators() const {
687+
if (!BestNonGenericScore)
688+
return false;
689+
690+
// Let's skip generic overload choices only in case if
691+
// non-generic score indicates that there were no forced
692+
// unwrappings of optional(s), no unavailable overload
693+
// choices present in the solution, no fixes required,
694+
// and there are no non-trivial function conversions.
695+
auto &score = BestNonGenericScore->Data;
696+
return (score[SK_ForceUnchecked] == 0 && score[SK_Unavailable] == 0 &&
697+
score[SK_Fix] == 0 && score[SK_FunctionConversion] == 0);
698+
}
699+
686700
/// Attempt to apply given disjunction choice to constraint system.
687701
/// This action is going to establish "active choice" of this disjunction
688702
/// to point to a given choice.

0 commit comments

Comments
 (0)