Skip to content

Commit 8b4a58f

Browse files
committed
Sema: Record applied disjunctions in the trail
1 parent 8799596 commit 8b4a58f

File tree

6 files changed

+66
-12
lines changed

6 files changed

+66
-12
lines changed

include/swift/Sema/CSTrail.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ class SolverTrail {
6161
AddedFixedRequirement,
6262
/// Recorded a disjunction choice.
6363
RecordedDisjunctionChoice,
64+
/// Recorded an applied disjunction.
65+
RecordedAppliedDisjunction,
6466
};
6567

6668
/// A change made to the constraint system.
@@ -177,6 +179,9 @@ class SolverTrail {
177179
static Change recordedDisjunctionChoice(ConstraintLocator *locator,
178180
unsigned index);
179181

182+
/// Create a change that recorded an applied disjunction.
183+
static Change recordedAppliedDisjunction(ConstraintLocator *locator);
184+
180185
/// Undo this change, reverting the constraint graph to the state it
181186
/// had prior to this change.
182187
///

include/swift/Sema/ConstraintSystem.h

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1526,6 +1526,10 @@ class Solution {
15261526
/// which informs constraint application.
15271527
llvm::DenseMap<ConstraintLocator *, unsigned> DisjunctionChoices;
15281528

1529+
/// A map from applied disjunction constraints to the corresponding
1530+
/// argument function type.
1531+
llvm::DenseMap<ConstraintLocator *, FunctionType *> AppliedDisjunctions;
1532+
15291533
/// The set of opened types for a given locator.
15301534
llvm::DenseMap<ConstraintLocator *, ArrayRef<OpenedType>> OpenedTypes;
15311535

@@ -2332,7 +2336,7 @@ class ConstraintSystem {
23322336

23332337
/// A map from applied disjunction constraints to the corresponding
23342338
/// argument function type.
2335-
llvm::SmallMapVector<ConstraintLocator *, const FunctionType *, 4>
2339+
llvm::SmallDenseMap<ConstraintLocator *, FunctionType *, 4>
23362340
AppliedDisjunctions;
23372341

23382342
/// For locators associated with call expressions, the trailing closure
@@ -2879,9 +2883,6 @@ class ConstraintSystem {
28792883
/// FIXME: Remove this.
28802884
unsigned numFixes;
28812885

2882-
/// The length of \c AppliedDisjunctions.
2883-
unsigned numAppliedDisjunctions;
2884-
28852886
/// The length of \c argumentMatchingChoices.
28862887
unsigned numArgumentMatchingChoices;
28872888

@@ -5325,6 +5326,15 @@ class ConstraintSystem {
53255326
ASSERT(erased);
53265327
}
53275328

5329+
/// Record applied disjunction and add a change to the trail.
5330+
void recordAppliedDisjunction(ConstraintLocator *locator,
5331+
FunctionType *type);
5332+
5333+
/// Undo the above change.
5334+
void removeAppliedDisjunction(ConstraintLocator *locator) {
5335+
bool erased = AppliedDisjunctions.erase(locator);
5336+
ASSERT(erased);
5337+
}
53285338

53295339
/// Filter the set of disjunction terms, keeping only those where the
53305340
/// predicate returns \c true.
@@ -5629,9 +5639,12 @@ class ConstraintSystem {
56295639

56305640
// If the given constraint is an applied disjunction, get the argument function
56315641
// that the disjunction is applied to.
5632-
const FunctionType *getAppliedDisjunctionArgumentFunction(const Constraint *disjunction) {
5642+
FunctionType *getAppliedDisjunctionArgumentFunction(const Constraint *disjunction) {
56335643
assert(disjunction->getKind() == ConstraintKind::Disjunction);
5634-
return AppliedDisjunctions[disjunction->getLocator()];
5644+
auto found = AppliedDisjunctions.find(disjunction->getLocator());
5645+
if (found == AppliedDisjunctions.end())
5646+
return nullptr;
5647+
return found->second;
56355648
}
56365649

56375650
/// The overload sets that have already been resolved along the current path.

lib/Sema/CSSimplify.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12823,7 +12823,7 @@ bool ConstraintSystem::simplifyAppliedOverloads(
1282312823
auto *applicableFn = result->first;
1282412824
auto *fnTypeVar = applicableFn->getSecondType()->castTo<TypeVariableType>();
1282512825
auto argFnType = applicableFn->getFirstType()->castTo<FunctionType>();
12826-
AppliedDisjunctions[disjunction->getLocator()] = argFnType;
12826+
recordAppliedDisjunction(disjunction->getLocator(), argFnType);
1282712827
return simplifyAppliedOverloadsImpl(disjunction, fnTypeVar, argFnType,
1282812828
/*numOptionalUnwraps*/ result->second,
1282912829
applicableFn->getLocator());
@@ -12843,7 +12843,7 @@ bool ConstraintSystem::simplifyAppliedOverloads(
1284312843
if (!disjunction)
1284412844
return false;
1284512845

12846-
AppliedDisjunctions[disjunction->getLocator()] = argFnType;
12846+
recordAppliedDisjunction(disjunction->getLocator(), argFnType);
1284712847
return simplifyAppliedOverloadsImpl(disjunction, fnTypeVar, argFnType,
1284812848
numOptionalUnwraps, locator);
1284912849
}

lib/Sema/CSSolver.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,11 @@ Solution ConstraintSystem::finalize() {
146146
solution.DisjunctionChoices.insert(choice);
147147
}
148148

149+
// Remember all the applied disjunctions.
150+
for (auto &choice : AppliedDisjunctions) {
151+
solution.AppliedDisjunctions.insert(choice);
152+
}
153+
149154
// Remember all of the argument/parameter matching choices we made.
150155
for (auto &argumentMatch : argumentMatchingChoices) {
151156
auto inserted = solution.argumentMatchingChoices.insert(argumentMatch);
@@ -305,6 +310,11 @@ void ConstraintSystem::applySolution(const Solution &solution) {
305310
recordDisjunctionChoice(choice.first, choice.second);
306311
}
307312

313+
// Register the solution's applied disjunctions.
314+
for (auto &choice : solution.AppliedDisjunctions) {
315+
recordAppliedDisjunction(choice.first, choice.second);
316+
}
317+
308318
// Remember all of the argument/parameter matching choices we made.
309319
for (auto &argumentMatch : solution.argumentMatchingChoices) {
310320
argumentMatchingChoices.insert(argumentMatch);
@@ -662,7 +672,6 @@ ConstraintSystem::SolverScope::SolverScope(ConstraintSystem &cs)
662672

663673
numTypeVariables = cs.TypeVariables.size();
664674
numFixes = cs.Fixes.size();
665-
numAppliedDisjunctions = cs.AppliedDisjunctions.size();
666675
numArgumentMatchingChoices = cs.argumentMatchingChoices.size();
667676
numOpenedTypes = cs.OpenedTypes.size();
668677
numOpenedExistentialTypes = cs.OpenedExistentialTypes.size();
@@ -726,9 +735,6 @@ ConstraintSystem::SolverScope::~SolverScope() {
726735
// constraints introduced by the current scope.
727736
cs.solverState->rollback(this);
728737

729-
// Remove any applied disjunctions.
730-
truncate(cs.AppliedDisjunctions, numAppliedDisjunctions);
731-
732738
// Remove any argument matching choices;
733739
truncate(cs.argumentMatchingChoices, numArgumentMatchingChoices);
734740

lib/Sema/CSTrail.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,14 @@ SolverTrail::Change::recordedDisjunctionChoice(ConstraintLocator *locator,
161161
return result;
162162
}
163163

164+
SolverTrail::Change
165+
SolverTrail::Change::recordedAppliedDisjunction(ConstraintLocator *locator) {
166+
Change result;
167+
result.Kind = ChangeKind::RecordedAppliedDisjunction;
168+
result.Locator = locator;
169+
return result;
170+
}
171+
164172
void SolverTrail::Change::undo(ConstraintSystem &cs) const {
165173
auto &cg = cs.getConstraintGraph();
166174

@@ -217,6 +225,10 @@ void SolverTrail::Change::undo(ConstraintSystem &cs) const {
217225
case ChangeKind::RecordedDisjunctionChoice:
218226
cs.removeDisjunctionChoice(Locator);
219227
break;
228+
229+
case ChangeKind::RecordedAppliedDisjunction:
230+
cs.removeAppliedDisjunction(Locator);
231+
break;
220232
}
221233
}
222234

@@ -334,6 +346,12 @@ void SolverTrail::Change::dump(llvm::raw_ostream &out,
334346
out << " index ";
335347
out << Options << ")\n";
336348
break;
349+
350+
case ChangeKind::RecordedAppliedDisjunction:
351+
out << "(recorded applied disjunction at ";
352+
Locator->dump(&cs.getASTContext().SourceMgr, out);
353+
out << ")\n";
354+
break;
337355
}
338356
}
339357

lib/Sema/ConstraintSystem.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,18 @@ void ConstraintSystem::recordDisjunctionChoice(
310310
}
311311
}
312312

313+
void ConstraintSystem::recordAppliedDisjunction(
314+
ConstraintLocator *locator, FunctionType *fnType) {
315+
// We shouldn't ever register disjunction choices multiple times.
316+
auto inserted = AppliedDisjunctions.insert(
317+
std::make_pair(locator, fnType));
318+
if (inserted.second) {
319+
if (isRecordingChanges()) {
320+
recordChange(SolverTrail::Change::recordedAppliedDisjunction(locator));
321+
}
322+
}
323+
}
324+
313325
/// Retrieve a dynamic result signature for the given declaration.
314326
static std::tuple<char, ObjCSelector, CanType>
315327
getDynamicResultSignature(ValueDecl *decl) {

0 commit comments

Comments
 (0)