Skip to content

Commit 50212a0

Browse files
committed
[ConstraintSystem] If the solver has already found a solution with a
disjunction choice that does not introduce conversions, check to see if known argument types satisfy generic operator conformance requirements early, and skip the overload choice if any requirements fail. This helps the solver avoid exploring way too much search space when the right solution involves a generic operator, but the argument types are known up front, such as `collection + collection + collection`.
1 parent bda36cf commit 50212a0

File tree

3 files changed

+65
-1
lines changed

3 files changed

+65
-1
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3923,6 +3923,7 @@ class ConstraintSystem {
39233923
llvm::function_ref<void(unsigned int, Type, ConstraintLocator *)>
39243924
verifyThatArgumentIsHashable);
39253925

3926+
public:
39263927
/// Describes a direction of optional wrapping, either increasing optionality
39273928
/// or decreasing optionality.
39283929
enum class OptionalWrappingDirection {
@@ -3950,6 +3951,7 @@ class ConstraintSystem {
39503951
TypeVariableType *typeVar, OptionalWrappingDirection optionalDirection,
39513952
llvm::function_ref<bool(Constraint *, TypeVariableType *)> predicate);
39523953

3954+
private:
39533955
/// Attempt to simplify the set of overloads corresponding to a given
39543956
/// function application constraint.
39553957
///

lib/Sema/CSStep.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
//===----------------------------------------------------------------------===//
1717

1818
#include "CSStep.h"
19+
#include "TypeChecker.h"
1920
#include "swift/AST/Types.h"
2021
#include "swift/Sema/ConstraintSystem.h"
2122
#include "llvm/ADT/ArrayRef.h"
@@ -555,6 +556,40 @@ bool DisjunctionStep::shouldSkip(const DisjunctionChoice &choice) const {
555556
if (ctx.TypeCheckerOpts.DisableConstraintSolverPerformanceHacks)
556557
return false;
557558

559+
// If the solver already found a solution with a choice that did not
560+
// introduce any conversions (i.e., the score is not worse than the
561+
// current score), we can skip any generic operators with conformance
562+
// requirements that are not satisfied by any known argument types.
563+
auto bestScore = getBestScore(Solutions);
564+
auto bestChoiceNeedsConversions = bestScore && (bestScore > getCurrentScore());
565+
if (!bestChoiceNeedsConversions && choice.isGenericOperator() && argFnType) {
566+
Constraint *constraint = choice;
567+
auto *decl = constraint->getOverloadChoice().getDecl();
568+
auto *useDC = constraint->getOverloadUseDC();
569+
auto choiceType = CS.getEffectiveOverloadType(constraint->getOverloadChoice(),
570+
/*allowMembers=*/true, useDC);
571+
auto choiceFnType = choiceType->getAs<FunctionType>();
572+
auto genericFnType = decl->getInterfaceType()->getAs<GenericFunctionType>();
573+
auto signature = genericFnType->getGenericSignature();
574+
575+
for (auto argParamPair : llvm::zip(argFnType->getParams(),
576+
choiceFnType->getParams())) {
577+
auto argType = std::get<0>(argParamPair).getPlainType();
578+
auto paramType = std::get<1>(argParamPair).getPlainType();
579+
580+
// Only check argument types with no type variables that will be matched
581+
// against a plain type parameter.
582+
argType = argType->getCanonicalType()->getWithoutSpecifierType();
583+
if (argType->hasTypeVariable() || !paramType->isTypeParameter())
584+
continue;
585+
586+
for (auto *protocol : signature->getRequiredProtocols(paramType)) {
587+
if (!TypeChecker::conformsToProtocol(argType, protocol, useDC))
588+
return skip("unsatisfied");
589+
}
590+
}
591+
}
592+
558593
// Don't attempt to solve for generic operators if we already have
559594
// a non-generic solution.
560595

lib/Sema/CSStep.h

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,8 @@ class DisjunctionStep final : public BindingStep<DisjunctionChoiceProducer> {
639639
Optional<Score> BestNonGenericScore;
640640
Optional<std::pair<Constraint *, Score>> LastSolvedChoice;
641641

642+
FunctionType *argFnType = nullptr;
643+
642644
public:
643645
DisjunctionStep(ConstraintSystem &cs, Constraint *disjunction,
644646
SmallVectorImpl<Solution> &solutions)
@@ -647,6 +649,29 @@ class DisjunctionStep final : public BindingStep<DisjunctionChoiceProducer> {
647649
assert(Disjunction->getKind() == ConstraintKind::Disjunction);
648650
pruneOverloadSet(Disjunction);
649651
++cs.solverState->NumDisjunctions;
652+
653+
// FIXME: This is duplicate (and expensive) work from simplifyAppliedOverloads
654+
auto choices = disjunction->getNestedConstraints();
655+
auto *typeVar = choices.front()->getFirstType()->getAs<TypeVariableType>();
656+
if (!typeVar)
657+
return;
658+
659+
auto result = cs.findConstraintThroughOptionals(
660+
typeVar, ConstraintSystem::OptionalWrappingDirection::Unwrap,
661+
[&](Constraint *match, TypeVariableType *currentRep) {
662+
// Check to see if we have an applicable fn with a type var RHS that
663+
// matches the disjunction.
664+
if (match->getKind() != ConstraintKind::ApplicableFunction)
665+
return false;
666+
667+
auto *rhsTyVar = match->getSecondType()->getAs<TypeVariableType>();
668+
return rhsTyVar && currentRep == cs.getRepresentative(rhsTyVar);
669+
});
670+
671+
if (result) {
672+
auto *applicableFn = result->first;
673+
argFnType = applicableFn->getFirstType()->castTo<FunctionType>();
674+
}
650675
}
651676

652677
~DisjunctionStep() override {
@@ -732,7 +757,9 @@ class DisjunctionStep final : public BindingStep<DisjunctionChoiceProducer> {
732757

733758
// Figure out which of the solutions has the smallest score.
734759
static Optional<Score> getBestScore(SmallVectorImpl<Solution> &solutions) {
735-
assert(!solutions.empty());
760+
if (solutions.empty())
761+
return None;
762+
736763
Score bestScore = solutions.front().getFixedScore();
737764
if (solutions.size() == 1)
738765
return bestScore;

0 commit comments

Comments
 (0)