Skip to content

Commit 2edba9d

Browse files
gregomnihborla
authored andcommitted
Instead of chaining binops, favor disjunctions with op overloads whose types match existing binding choices
1 parent e0199f2 commit 2edba9d

File tree

4 files changed

+68
-130
lines changed

4 files changed

+68
-130
lines changed

include/swift/Sema/Constraint.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,15 @@ class Constraint final : public llvm::ilist_node<Constraint>,
672672
return Nested;
673673
}
674674

675+
unsigned countFavoredNestedConstraints() const {
676+
unsigned count = 0;
677+
for (auto *constraint : Nested)
678+
if (constraint->isFavored() && !constraint->isDisabled())
679+
count++;
680+
681+
return count;
682+
}
683+
675684
unsigned countActiveNestedConstraints() const {
676685
unsigned count = 0;
677686
for (auto *constraint : Nested)

lib/Sema/CSGen.cpp

Lines changed: 0 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -46,26 +46,6 @@ static bool isArithmeticOperatorDecl(ValueDecl *vd) {
4646
vd->getBaseName() == "%");
4747
}
4848

49-
static bool hasBinOpOverloadWithSameArgTypesDifferingResult(
50-
OverloadedDeclRefExpr *overloads) {
51-
for (auto decl : overloads->getDecls()) {
52-
auto metaFuncType = decl->getInterfaceType()->getAs<AnyFunctionType>();
53-
auto funcType = metaFuncType->getResult()->getAs<FunctionType>();
54-
if (!funcType)
55-
continue;
56-
57-
auto params = funcType->getParams();
58-
if (params.size() != 2)
59-
continue;
60-
61-
if (params[0].getPlainType().getPointer() != params[1].getPlainType().getPointer())
62-
continue;
63-
if (funcType->getResult().getPointer() != params[0].getPlainType().getPointer())
64-
return true;
65-
}
66-
return false;
67-
}
68-
6949
static bool mergeRepresentativeEquivalenceClasses(ConstraintSystem &CS,
7050
TypeVariableType* tyvar1,
7151
TypeVariableType* tyvar2) {
@@ -323,75 +303,6 @@ namespace {
323303
// solver can still make progress.
324304
auto favoredTy = (*lti.collectedTypes.begin())->getWithoutSpecifierType();
325305
CS.setFavoredType(expr, favoredTy.getPointer());
326-
327-
// If we have a chain of identical binop expressions with homogeneous
328-
// argument types, we can directly simplify the associated constraint
329-
// graph.
330-
auto simplifyBinOpExprTyVars = [&]() {
331-
// Don't attempt to do linking if there are
332-
// literals intermingled with other inferred types.
333-
if (lti.hasLiteral)
334-
return;
335-
336-
for (auto binExp1 : lti.binaryExprs) {
337-
for (auto binExp2 : lti.binaryExprs) {
338-
if (binExp1 == binExp2)
339-
continue;
340-
341-
auto fnTy1 = CS.getType(binExp1)->getAs<TypeVariableType>();
342-
auto fnTy2 = CS.getType(binExp2)->getAs<TypeVariableType>();
343-
344-
if (!(fnTy1 && fnTy2))
345-
return;
346-
347-
auto ODR1 = dyn_cast<OverloadedDeclRefExpr>(binExp1->getFn());
348-
auto ODR2 = dyn_cast<OverloadedDeclRefExpr>(binExp2->getFn());
349-
350-
if (!(ODR1 && ODR2))
351-
return;
352-
353-
// TODO: We currently limit this optimization to known arithmetic
354-
// operators, but we should be able to broaden this out to
355-
// logical operators as well.
356-
if (!isArithmeticOperatorDecl(ODR1->getDecls()[0]))
357-
return;
358-
359-
if (ODR1->getDecls()[0]->getBaseName() !=
360-
ODR2->getDecls()[0]->getBaseName())
361-
return;
362-
363-
if (hasBinOpOverloadWithSameArgTypesDifferingResult(ODR1))
364-
return;
365-
366-
// All things equal, we can merge the tyvars for the function
367-
// types.
368-
auto rep1 = CS.getRepresentative(fnTy1);
369-
auto rep2 = CS.getRepresentative(fnTy2);
370-
371-
if (rep1 != rep2) {
372-
CS.mergeEquivalenceClasses(rep1, rep2,
373-
/*updateWorkList*/ false);
374-
}
375-
376-
auto odTy1 = CS.getType(ODR1)->getAs<TypeVariableType>();
377-
auto odTy2 = CS.getType(ODR2)->getAs<TypeVariableType>();
378-
379-
if (odTy1 && odTy2) {
380-
auto odRep1 = CS.getRepresentative(odTy1);
381-
auto odRep2 = CS.getRepresentative(odTy2);
382-
383-
// Since we'll be choosing the same overload, we can merge
384-
// the overload tyvar as well.
385-
if (odRep1 != odRep2)
386-
CS.mergeEquivalenceClasses(odRep1, odRep2,
387-
/*updateWorkList*/ false);
388-
}
389-
}
390-
}
391-
};
392-
393-
simplifyBinOpExprTyVars();
394-
395306
return true;
396307
}
397308

lib/Sema/CSSolver.cpp

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2013,9 +2013,47 @@ static Constraint *tryOptimizeGenericDisjunction(
20132013
llvm_unreachable("covered switch");
20142014
}
20152015

2016+
// Performance hack: favor operator overloads with types we're already binding elsewhere in this expression.
2017+
static void existingOperatorBindingsForDisjunction(ConstraintSystem &CS, ArrayRef<Constraint *> constraints, SmallVectorImpl<Constraint *> &found) {
2018+
auto *choice = constraints.front();
2019+
if (choice->getKind() != ConstraintKind::BindOverload)
2020+
return;
2021+
2022+
auto overload = choice->getOverloadChoice();
2023+
if (!overload.isDecl())
2024+
return;
2025+
auto decl = overload.getDecl();
2026+
if (!decl->isOperator())
2027+
return;
2028+
2029+
for (auto *resolved = CS.getResolvedOverloadSets(); resolved;
2030+
resolved = resolved->Previous) {
2031+
if (!resolved->Choice.isDecl())
2032+
continue;
2033+
auto representativeDecl = resolved->Choice.getDecl();
2034+
2035+
if (!representativeDecl->isOperator())
2036+
continue;
2037+
2038+
for (auto *constraint : constraints) {
2039+
if (constraint->isFavored())
2040+
continue;
2041+
auto choice = constraint->getOverloadChoice();
2042+
if (choice.getDecl()->getInterfaceType()->isEqual(representativeDecl->getInterfaceType()))
2043+
found.push_back(constraint);
2044+
}
2045+
}
2046+
}
2047+
20162048
void ConstraintSystem::partitionDisjunction(
20172049
ArrayRef<Constraint *> Choices, SmallVectorImpl<unsigned> &Ordering,
20182050
SmallVectorImpl<unsigned> &PartitionBeginning) {
2051+
2052+
SmallVector<Constraint *, 4> existing;
2053+
existingOperatorBindingsForDisjunction(*this, Choices, existing);
2054+
for (auto constraint : existing)
2055+
favorConstraint(constraint);
2056+
20192057
// Apply a special-case rule for favoring one generic function over
20202058
// another.
20212059
if (auto favored = tryOptimizeGenericDisjunction(DC, Choices)) {
@@ -2134,12 +2172,30 @@ Constraint *ConstraintSystem::selectDisjunction() {
21342172
if (auto *disjunction = selectBestBindingDisjunction(*this, disjunctions))
21352173
return disjunction;
21362174

2137-
// Pick the disjunction with the smallest number of active choices.
2175+
// Pick the disjunction with the smallest number of favored, then active choices.
2176+
auto cs = this;
21382177
auto minDisjunction =
21392178
std::min_element(disjunctions.begin(), disjunctions.end(),
21402179
[&](Constraint *first, Constraint *second) -> bool {
2141-
return first->countActiveNestedConstraints() <
2142-
second->countActiveNestedConstraints();
2180+
unsigned firstFavored = first->countFavoredNestedConstraints();
2181+
unsigned secondFavored = second->countFavoredNestedConstraints();
2182+
2183+
if (firstFavored == secondFavored) {
2184+
// Look for additional choices to favor
2185+
SmallVector<Constraint *, 4> firstExisting;
2186+
SmallVector<Constraint *, 4> secondExisting;
2187+
2188+
existingOperatorBindingsForDisjunction(*cs, first->getNestedConstraints(), firstExisting);
2189+
firstFavored = firstExisting.size() ?: first->countActiveNestedConstraints();
2190+
existingOperatorBindingsForDisjunction(*cs, second->getNestedConstraints(), secondExisting);
2191+
secondFavored = secondExisting.size() ?: second->countActiveNestedConstraints();
2192+
2193+
return firstFavored < secondFavored;
2194+
} else {
2195+
firstFavored = firstFavored ?: first->countActiveNestedConstraints();
2196+
secondFavored = secondFavored ?: second->countActiveNestedConstraints();
2197+
return firstFavored < secondFavored;
2198+
}
21432199
});
21442200

21452201
if (minDisjunction != disjunctions.end())

lib/Sema/CSStep.h

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,6 @@ class DisjunctionStep final : public BindingStep<DisjunctionChoiceProducer> {
654654
: BindingStep(cs, {cs, disjunction}, solutions), Disjunction(disjunction),
655655
AfterDisjunction(erase(disjunction)) {
656656
assert(Disjunction->getKind() == ConstraintKind::Disjunction);
657-
pruneOverloadSet(Disjunction);
658657
++cs.solverState->NumDisjunctions;
659658
}
660659

@@ -702,43 +701,6 @@ class DisjunctionStep final : public BindingStep<DisjunctionChoiceProducer> {
702701
/// simplified further, false otherwise.
703702
bool attempt(const DisjunctionChoice &choice) override;
704703

705-
// Check if selected disjunction has a representative
706-
// this might happen when there are multiple binary operators
707-
// chained together. If so, disable choices which differ
708-
// from currently selected representative.
709-
void pruneOverloadSet(Constraint *disjunction) {
710-
auto *choice = disjunction->getNestedConstraints().front();
711-
auto *typeVar = choice->getFirstType()->getAs<TypeVariableType>();
712-
if (!typeVar)
713-
return;
714-
715-
auto *repr = typeVar->getImpl().getRepresentative(nullptr);
716-
if (!repr || repr == typeVar)
717-
return;
718-
719-
for (auto elt : getResolvedOverloads()) {
720-
auto resolved = elt.second;
721-
if (!resolved.boundType->isEqual(repr))
722-
continue;
723-
724-
auto &representative = resolved.choice;
725-
if (!representative.isDecl())
726-
return;
727-
728-
// Disable all of the overload choices which are different from
729-
// the one which is currently picked for representative.
730-
for (auto *constraint : disjunction->getNestedConstraints()) {
731-
auto choice = constraint->getOverloadChoice();
732-
if (!choice.isDecl() || choice.getDecl() == representative.getDecl())
733-
continue;
734-
735-
constraint->setDisabled();
736-
DisabledChoices.push_back(constraint);
737-
}
738-
break;
739-
}
740-
};
741-
742704
// Figure out which of the solutions has the smallest score.
743705
static Optional<Score> getBestScore(SmallVectorImpl<Solution> &solutions) {
744706
assert(!solutions.empty());

0 commit comments

Comments
 (0)