Skip to content

Commit e24ac86

Browse files
committed
[ConstraintSystem] Cache applied disjunction constraints in the
constraint system to use later in DisjunctionStep.
1 parent 0c01b62 commit e24ac86

File tree

5 files changed

+23
-27
lines changed

5 files changed

+23
-27
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2177,6 +2177,11 @@ class ConstraintSystem {
21772177
std::vector<std::pair<ConstraintLocator*, unsigned>>
21782178
DisjunctionChoices;
21792179

2180+
/// A map from applied disjunction constraints to the corresponding
2181+
/// argument function type.
2182+
llvm::SmallMapVector<ConstraintLocator *, const FunctionType *, 4>
2183+
AppliedDisjunctions;
2184+
21802185
/// For locators associated with call expressions, the trailing closure
21812186
/// matching rule that was applied.
21822187
std::vector<std::pair<ConstraintLocator*, TrailingClosureMatching>>
@@ -2669,6 +2674,9 @@ class ConstraintSystem {
26692674
/// The length of \c DisjunctionChoices.
26702675
unsigned numDisjunctionChoices;
26712676

2677+
/// The length of \c AppliedDisjunctions.
2678+
unsigned numAppliedDisjunctions;
2679+
26722680
/// The length of \c trailingClosureMatchingChoices;
26732681
unsigned numTrailingClosureMatchingChoices;
26742682

@@ -3923,7 +3931,6 @@ class ConstraintSystem {
39233931
llvm::function_ref<void(unsigned int, Type, ConstraintLocator *)>
39243932
verifyThatArgumentIsHashable);
39253933

3926-
public:
39273934
/// Describes a direction of optional wrapping, either increasing optionality
39283935
/// or decreasing optionality.
39293936
enum class OptionalWrappingDirection {
@@ -3951,7 +3958,6 @@ class ConstraintSystem {
39513958
TypeVariableType *typeVar, OptionalWrappingDirection optionalDirection,
39523959
llvm::function_ref<bool(Constraint *, TypeVariableType *)> predicate);
39533960

3954-
private:
39553961
/// Attempt to simplify the set of overloads corresponding to a given
39563962
/// function application constraint.
39573963
///
@@ -5460,6 +5466,13 @@ class ConstraintSystem {
54605466
SmallVectorImpl<unsigned> &Ordering,
54615467
SmallVectorImpl<unsigned> &PartitionBeginning);
54625468

5469+
// If the given constraint is an applied disjunction, get the argument function
5470+
// that the disjunction is applied to.
5471+
const FunctionType *getAppliedDisjunctionArgumentFunction(Constraint *disjunction) {
5472+
assert(disjunction->getKind() == ConstraintKind::Disjunction);
5473+
return AppliedDisjunctions[disjunction->getLocator()];
5474+
}
5475+
54635476
/// The overload sets that have already been resolved along the current path.
54645477
const llvm::MapVector<ConstraintLocator *, SelectedOverload> &
54655478
getResolvedOverloads() const {

lib/Sema/CSSimplify.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8893,6 +8893,7 @@ bool ConstraintSystem::simplifyAppliedOverloads(
88938893
auto *applicableFn = result->first;
88948894
auto *fnTypeVar = applicableFn->getSecondType()->castTo<TypeVariableType>();
88958895
auto argFnType = applicableFn->getFirstType()->castTo<FunctionType>();
8896+
AppliedDisjunctions[disjunction->getLocator()] = argFnType;
88968897
return simplifyAppliedOverloadsImpl(disjunction, fnTypeVar, argFnType,
88978898
/*numOptionalUnwraps*/ result->second,
88988899
locator);
@@ -8912,6 +8913,8 @@ bool ConstraintSystem::simplifyAppliedOverloads(
89128913
getUnboundBindOverloadDisjunction(fnTypeVar, &numOptionalUnwraps);
89138914
if (!disjunction)
89148915
return false;
8916+
8917+
AppliedDisjunctions[disjunction->getLocator()] = argFnType;
89158918
return simplifyAppliedOverloadsImpl(disjunction, fnTypeVar, argFnType,
89168919
numOptionalUnwraps, locator);
89178920
}

lib/Sema/CSSolver.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,7 @@ ConstraintSystem::SolverScope::SolverScope(ConstraintSystem &cs)
471471
numFixes = cs.Fixes.size();
472472
numFixedRequirements = cs.FixedRequirements.size();
473473
numDisjunctionChoices = cs.DisjunctionChoices.size();
474+
numAppliedDisjunctions = cs.AppliedDisjunctions.size();
474475
numTrailingClosureMatchingChoices = cs.trailingClosureMatchingChoices.size();
475476
numOpenedTypes = cs.OpenedTypes.size();
476477
numOpenedExistentialTypes = cs.OpenedExistentialTypes.size();
@@ -526,6 +527,9 @@ ConstraintSystem::SolverScope::~SolverScope() {
526527
// Remove any disjunction choices.
527528
truncate(cs.DisjunctionChoices, numDisjunctionChoices);
528529

530+
// Remove any applied disjunctions.
531+
truncate(cs.AppliedDisjunctions, numAppliedDisjunctions);
532+
529533
// Remove any trailing closure matching choices;
530534
truncate(
531535
cs.trailingClosureMatchingChoices, numTrailingClosureMatchingChoices);

lib/Sema/CSStep.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,7 @@ bool DisjunctionStep::shouldSkip(const DisjunctionChoice &choice) const {
632632
// introduce any conversions (i.e., the score is not worse than the
633633
// current score), we can skip any generic operators with conformance
634634
// requirements that are not satisfied by any known argument types.
635+
auto argFnType = CS.getAppliedDisjunctionArgumentFunction(Disjunction);
635636
auto bestScore = getBestScore(Solutions);
636637
auto bestChoiceNeedsConversions = bestScore && (bestScore > getCurrentScore());
637638
if (bestScore && !bestChoiceNeedsConversions && choice.isGenericOperator() && argFnType) {

lib/Sema/CSStep.h

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -639,8 +639,6 @@ class DisjunctionStep final : public BindingStep<DisjunctionChoiceProducer> {
639639
Optional<Score> BestNonGenericScore;
640640
Optional<std::pair<Constraint *, Score>> LastSolvedChoice;
641641

642-
FunctionType *argFnType = nullptr;
643-
644642
public:
645643
DisjunctionStep(ConstraintSystem &cs, Constraint *disjunction,
646644
SmallVectorImpl<Solution> &solutions)
@@ -649,29 +647,6 @@ class DisjunctionStep final : public BindingStep<DisjunctionChoiceProducer> {
649647
assert(Disjunction->getKind() == ConstraintKind::Disjunction);
650648
pruneOverloadSet(Disjunction);
651649
++cs.solverState->NumDisjunctions;
652-
653-
// FIXME: This is duplicate (and expensive) work from simplifyAppliedOverloads
654-
auto choices = disjunction->getNestedConstraints();
655-
auto *typeVar = choices.front()->getFirstType()->getAs<TypeVariableType>();
656-
if (!typeVar)
657-
return;
658-
659-
auto result = cs.findConstraintThroughOptionals(
660-
typeVar, ConstraintSystem::OptionalWrappingDirection::Unwrap,
661-
[&](Constraint *match, TypeVariableType *currentRep) {
662-
// Check to see if we have an applicable fn with a type var RHS that
663-
// matches the disjunction.
664-
if (match->getKind() != ConstraintKind::ApplicableFunction)
665-
return false;
666-
667-
auto *rhsTyVar = match->getSecondType()->getAs<TypeVariableType>();
668-
return rhsTyVar && currentRep == cs.getRepresentative(rhsTyVar);
669-
});
670-
671-
if (result) {
672-
auto *applicableFn = result->first;
673-
argFnType = applicableFn->getFirstType()->castTo<FunctionType>();
674-
}
675650
}
676651

677652
~DisjunctionStep() override {

0 commit comments

Comments
 (0)