Skip to content

Commit 64293ec

Browse files
committed
Sema: Push reconciliation down into applySolution() to strengthen invariants
Now, we assert if you try to record the same change twice in any other code path.
1 parent 61575d9 commit 64293ec

File tree

4 files changed

+76
-67
lines changed

4 files changed

+76
-67
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4422,7 +4422,8 @@ class ConstraintSystem {
44224422
/// Record the set of opened types for the given locator.
44234423
void recordOpenedTypes(
44244424
ConstraintLocatorBuilder locator,
4425-
const OpenedTypeMap &replacements);
4425+
const OpenedTypeMap &replacements,
4426+
bool fixmeAllowDuplicates=false);
44264427

44274428
/// Check whether the given type conforms to the given protocol and if
44284429
/// so return a valid conformance reference.

lib/Sema/CSSimplify.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14927,10 +14927,10 @@ void ConstraintSystem::recordMatchCallArgumentResult(
1492714927
ConstraintLocator *locator, MatchCallArgumentResult result) {
1492814928
assert(locator->isLastElement<LocatorPathElt::ApplyArgument>());
1492914929
bool inserted = argumentMatchingChoices.insert({locator, result}).second;
14930-
if (inserted) {
14931-
if (isRecordingChanges())
14932-
recordChange(SolverTrail::Change::recordedMatchCallArgumentResult(locator));
14933-
}
14930+
ASSERT(inserted);
14931+
14932+
if (solverState)
14933+
recordChange(SolverTrail::Change::recordedMatchCallArgumentResult(locator));
1493414934
}
1493514935

1493614936
void ConstraintSystem::recordCallAsFunction(UnresolvedDotExpr *root,

lib/Sema/CSSolver.cpp

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -304,17 +304,20 @@ void ConstraintSystem::applySolution(const Solution &solution) {
304304

305305
// Register the solution's disjunction choices.
306306
for (auto &choice : solution.DisjunctionChoices) {
307-
recordDisjunctionChoice(choice.first, choice.second);
307+
if (DisjunctionChoices.count(choice.first) == 0)
308+
recordDisjunctionChoice(choice.first, choice.second);
308309
}
309310

310311
// Register the solution's applied disjunctions.
311312
for (auto &choice : solution.AppliedDisjunctions) {
312-
recordAppliedDisjunction(choice.first, choice.second);
313+
if (AppliedDisjunctions.count(choice.first) == 0)
314+
recordAppliedDisjunction(choice.first, choice.second);
313315
}
314316

315317
// Remember all of the argument/parameter matching choices we made.
316318
for (auto &argumentMatch : solution.argumentMatchingChoices) {
317-
recordMatchCallArgumentResult(argumentMatch.first, argumentMatch.second);
319+
if (argumentMatchingChoices.count(argumentMatch.first) == 0)
320+
recordMatchCallArgumentResult(argumentMatch.first, argumentMatch.second);
318321
}
319322

320323
// Remember implied results.
@@ -323,32 +326,40 @@ void ConstraintSystem::applySolution(const Solution &solution) {
323326

324327
// Register the solution's opened types.
325328
for (const auto &opened : solution.OpenedTypes) {
326-
recordOpenedType(opened.first, opened.second);
329+
if (OpenedTypes.count(opened.first) == 0)
330+
recordOpenedType(opened.first, opened.second);
327331
}
328332

329333
// Register the solution's opened existential types.
330334
for (const auto &openedExistential : solution.OpenedExistentialTypes) {
331-
recordOpenedExistentialType(openedExistential.first, openedExistential.second);
335+
if (OpenedExistentialTypes.count(openedExistential.first) == 0) {
336+
recordOpenedExistentialType(openedExistential.first,
337+
openedExistential.second);
338+
}
332339
}
333340

334341
// Register the solution's opened pack expansion types.
335342
for (const auto &expansion : solution.OpenedPackExpansionTypes) {
336-
recordOpenedPackExpansionType(expansion.first, expansion.second);
343+
if (OpenedPackExpansionTypes.count(expansion.first) == 0)
344+
recordOpenedPackExpansionType(expansion.first, expansion.second);
337345
}
338346

339347
// Register the solutions's pack expansion environments.
340348
for (const auto &expansion : solution.PackExpansionEnvironments) {
341-
recordPackExpansionEnvironment(expansion.first, expansion.second);
349+
if (PackExpansionEnvironments.count(expansion.first) == 0)
350+
recordPackExpansionEnvironment(expansion.first, expansion.second);
342351
}
343352

344353
// Register the solutions's pack environments.
345354
for (auto &packEnvironment : solution.PackEnvironments) {
346-
addPackEnvironment(packEnvironment.first, packEnvironment.second);
355+
if (PackEnvironments.count(packEnvironment.first) == 0)
356+
addPackEnvironment(packEnvironment.first, packEnvironment.second);
347357
}
348358

349359
// Register the defaulted type variables.
350-
for (auto *locator : solution.DefaultedConstraints)
360+
for (auto *locator : solution.DefaultedConstraints) {
351361
recordDefaultedConstraint(locator);
362+
}
352363

353364
// Add the node types back.
354365
for (auto &nodeType : solution.nodeTypes) {
@@ -424,12 +435,17 @@ void ConstraintSystem::applySolution(const Solution &solution) {
424435
}
425436

426437
// Register any fixes produced along this path.
427-
for (auto *fix : solution.Fixes)
428-
addFix(fix);
438+
for (auto *fix : solution.Fixes) {
439+
if (Fixes.count(fix) == 0)
440+
addFix(fix);
441+
}
429442

430443
// Register fixed requirements.
431-
for (auto fix : solution.FixedRequirements)
432-
recordFixedRequirement(std::get<0>(fix), std::get<1>(fix), std::get<2>(fix));
444+
for (auto fix : solution.FixedRequirements) {
445+
recordFixedRequirement(std::get<0>(fix),
446+
std::get<1>(fix),
447+
std::get<2>(fix));
448+
}
433449
}
434450
bool ConstraintSystem::simplify() {
435451
// While we have a constraint in the worklist, process it.

lib/Sema/ConstraintSystem.cpp

Lines changed: 41 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,7 @@ void ConstraintSystem::removeConversionRestriction(
281281

282282
void ConstraintSystem::addFix(ConstraintFix *fix) {
283283
bool inserted = Fixes.insert(fix);
284-
if (!inserted)
285-
return;
284+
ASSERT(inserted);
286285

287286
if (solverState)
288287
recordChange(SolverTrail::Change::addedFix(fix));
@@ -294,32 +293,23 @@ void ConstraintSystem::removeFix(ConstraintFix *fix) {
294293
}
295294

296295
void ConstraintSystem::recordDisjunctionChoice(
297-
ConstraintLocator *locator,
298-
unsigned index) {
299-
// We shouldn't ever register disjunction choices multiple times.
300-
auto inserted = DisjunctionChoices.insert(
301-
std::make_pair(locator, index));
302-
if (!inserted.second) {
303-
ASSERT(inserted.first->second == index);
304-
return;
305-
}
296+
ConstraintLocator *locator, unsigned index) {
297+
bool inserted = DisjunctionChoices.insert({locator, index}).second;
298+
ASSERT(inserted);
306299

307-
if (solverState) {
308-
recordChange(SolverTrail::Change::recordedDisjunctionChoice(
309-
locator, index));
310-
}
300+
if (solverState)
301+
recordChange(SolverTrail::Change::recordedDisjunctionChoice(locator, index));
311302
}
312303

313304
void ConstraintSystem::recordAppliedDisjunction(
314305
ConstraintLocator *locator, FunctionType *fnType) {
315306
// We shouldn't ever register disjunction choices multiple times.
316-
auto inserted = AppliedDisjunctions.insert(
317-
std::make_pair(locator, fnType));
318-
if (inserted.second) {
319-
if (solverState) {
320-
recordChange(SolverTrail::Change::recordedAppliedDisjunction(locator));
321-
}
322-
}
307+
bool inserted = AppliedDisjunctions.insert(
308+
std::make_pair(locator, fnType)).second;
309+
ASSERT(inserted);
310+
311+
if (solverState)
312+
recordChange(SolverTrail::Change::recordedAppliedDisjunction(locator));
323313
}
324314

325315
/// Retrieve a dynamic result signature for the given declaration.
@@ -853,10 +843,10 @@ std::pair<Type, OpenedArchetypeType *> ConstraintSystem::openExistentialType(
853843
void ConstraintSystem::recordOpenedExistentialType(
854844
ConstraintLocator *locator, OpenedArchetypeType *opened) {
855845
bool inserted = OpenedExistentialTypes.insert({locator, opened}).second;
856-
if (inserted) {
857-
if (solverState)
858-
recordChange(SolverTrail::Change::recordedOpenedExistentialType(locator));
859-
}
846+
ASSERT(inserted);
847+
848+
if (solverState)
849+
recordChange(SolverTrail::Change::recordedOpenedExistentialType(locator));
860850
}
861851

862852
GenericEnvironment *
@@ -894,12 +884,10 @@ ConstraintSystem::getPackElementEnvironment(ConstraintLocator *locator,
894884
void ConstraintSystem::recordPackExpansionEnvironment(
895885
ConstraintLocator *locator, std::pair<UUID, Type> uuidAndShape) {
896886
bool inserted = PackExpansionEnvironments.insert({locator, uuidAndShape}).second;
897-
if (inserted) {
898-
if (solverState) {
899-
recordChange(
900-
SolverTrail::Change::recordedPackExpansionEnvironment(locator));
901-
}
902-
}
887+
ASSERT(inserted);
888+
889+
if (solverState)
890+
recordChange(SolverTrail::Change::recordedPackExpansionEnvironment(locator));
903891
}
904892

905893
PackExpansionExpr *
@@ -910,12 +898,11 @@ ConstraintSystem::getPackEnvironment(PackElementExpr *packElement) const {
910898

911899
void ConstraintSystem::addPackEnvironment(PackElementExpr *packElement,
912900
PackExpansionExpr *packExpansion) {
913-
bool inserted =
914-
PackEnvironments.insert({packElement, packExpansion}).second;
915-
if (inserted) {
916-
if (solverState)
917-
recordChange(SolverTrail::Change::recordedPackEnvironment(packElement));
918-
}
901+
bool inserted = PackEnvironments.insert({packElement, packExpansion}).second;
902+
ASSERT(inserted);
903+
904+
if (solverState)
905+
recordChange(SolverTrail::Change::recordedPackEnvironment(packElement));
919906
}
920907

921908
/// Extend the given depth map by adding depths for all of the subexpressions
@@ -1028,7 +1015,8 @@ Type ConstraintSystem::openUnboundGenericType(GenericTypeDecl *decl,
10281015
openGeneric(decl->getDeclContext(), decl->getGenericSignature(), locator,
10291016
replacements);
10301017

1031-
recordOpenedTypes(locator, replacements);
1018+
// FIXME: Get rid of fixmeAllowDuplicates.
1019+
recordOpenedTypes(locator, replacements, /*fixmeAllowDuplicates=*/true);
10321020

10331021
if (parentTy) {
10341022
const auto parentTyInContext =
@@ -1278,10 +1266,10 @@ Type ConstraintSystem::openPackExpansionType(PackExpansionType *expansion,
12781266
void ConstraintSystem::recordOpenedPackExpansionType(PackExpansionType *expansion,
12791267
TypeVariableType *expansionVar) {
12801268
bool inserted = OpenedPackExpansionTypes.insert({expansion, expansionVar}).second;
1281-
if (inserted) {
1282-
if (solverState)
1283-
recordChange(SolverTrail::Change::recordedOpenedPackExpansionType(expansion));
1284-
}
1269+
ASSERT(inserted);
1270+
1271+
if (solverState)
1272+
recordChange(SolverTrail::Change::recordedOpenedPackExpansionType(expansion));
12851273
}
12861274

12871275
Type ConstraintSystem::openOpaqueType(OpaqueTypeArchetypeType *opaque,
@@ -1687,10 +1675,10 @@ Type ConstraintSystem::getUnopenedTypeOfReference(
16871675
void ConstraintSystem::recordOpenedType(
16881676
ConstraintLocator *locator, ArrayRef<OpenedType> openedTypes) {
16891677
bool inserted = OpenedTypes.insert({locator, openedTypes}).second;
1690-
if (inserted) {
1691-
if (solverState)
1692-
recordChange(SolverTrail::Change::recordedOpenedTypes(locator));
1693-
}
1678+
ASSERT(inserted);
1679+
1680+
if (solverState)
1681+
recordChange(SolverTrail::Change::recordedOpenedTypes(locator));
16941682
}
16951683

16961684
void ConstraintSystem::removeOpenedType(ConstraintLocator *locator) {
@@ -1700,7 +1688,8 @@ void ConstraintSystem::removeOpenedType(ConstraintLocator *locator) {
17001688

17011689
void ConstraintSystem::recordOpenedTypes(
17021690
ConstraintLocatorBuilder locator,
1703-
const OpenedTypeMap &replacements) {
1691+
const OpenedTypeMap &replacements,
1692+
bool fixmeAllowDuplicates) {
17041693
if (replacements.empty())
17051694
return;
17061695

@@ -1721,7 +1710,10 @@ void ConstraintSystem::recordOpenedTypes(
17211710
OpenedType* openedTypes
17221711
= Allocator.Allocate<OpenedType>(replacements.size());
17231712
std::copy(replacements.begin(), replacements.end(), openedTypes);
1724-
recordOpenedType(
1713+
1714+
// FIXME: Get rid of fixmeAllowDuplicates.
1715+
if (!fixmeAllowDuplicates || OpenedTypes.count(locatorPtr) == 0)
1716+
recordOpenedType(
17251717
locatorPtr, llvm::ArrayRef(openedTypes, replacements.size()));
17261718
}
17271719

0 commit comments

Comments
 (0)