Skip to content

Commit 1e2d4fb

Browse files
committed
Sema: Record score increases in the trail
1 parent 877c60e commit 1e2d4fb

File tree

8 files changed

+86
-22
lines changed

8 files changed

+86
-22
lines changed

include/swift/Sema/CSTrail.def

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,10 @@ CHANGE(RecordedCaseLabelItemInfo)
7878
CHANGE(RecordedPotentialThrowSite)
7979
CHANGE(RecordedIsolatedParam)
8080
CHANGE(RecordedKeyPath)
81+
CHANGE(IncreasedScore)
82+
CHANGE(DecreasedScore)
8183

82-
LAST_CHANGE(RecordedKeyPath)
84+
LAST_CHANGE(DecreasedScore)
8385

8486
#undef LOCATOR_CHANGE
8587
#undef EXPR_CHANGE

include/swift/Sema/CSTrail.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,12 @@ class SolverTrail {
229229
/// Create a change that recorded a key path expression.
230230
static Change RecordedKeyPath(KeyPathExpr *expr);
231231

232+
/// Create a change that increased the score.
233+
static Change IncreasedScore(ScoreKind kind, unsigned value);
234+
235+
/// Create a change that decreased the score.
236+
static Change DecreasedScore(ScoreKind kind, unsigned value);
237+
232238
/// Undo this change, reverting the constraint graph to the state it
233239
/// had prior to this change.
234240
///

include/swift/Sema/ConstraintSystem.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2857,9 +2857,6 @@ class ConstraintSystem {
28572857
/// FIXME: Remove this.
28582858
unsigned numFixes;
28592859

2860-
/// The previous score.
2861-
Score PreviousScore;
2862-
28632860
/// The scope number of this scope. Set when the scope is registered.
28642861
unsigned scopeNumber = 0;
28652862

@@ -5534,10 +5531,13 @@ class ConstraintSystem {
55345531

55355532
public:
55365533
/// Increase the score of the given kind for the current (partial) solution
5537-
/// along the.
5534+
/// along the current solver path.
55385535
void increaseScore(ScoreKind kind, ConstraintLocatorBuilder Locator,
55395536
unsigned value = 1);
55405537

5538+
/// Primitive form of the above. Records a change in the trail.
5539+
void increaseScore(ScoreKind kind, unsigned value);
5540+
55415541
/// Determine whether this solution is guaranteed to be worse than the best
55425542
/// solution found so far.
55435543
bool worseThanBestSolution() const;

lib/Sema/CSRanking.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,14 @@ static bool shouldIgnoreScoreIncreaseForCodeCompletion(
110110
return false;
111111
}
112112

113+
void ConstraintSystem::increaseScore(ScoreKind kind, unsigned value) {
114+
unsigned index = static_cast<unsigned>(kind);
115+
CurrentScore.Data[index] += value;
116+
117+
if (solverState && value > 0)
118+
recordChange(SolverTrail::Change::IncreasedScore(kind, value));
119+
}
120+
113121
void ConstraintSystem::increaseScore(ScoreKind kind,
114122
ConstraintLocatorBuilder Locator,
115123
unsigned value) {
@@ -135,8 +143,7 @@ void ConstraintSystem::increaseScore(ScoreKind kind,
135143
llvm::errs() << ")\n";
136144
}
137145

138-
unsigned index = static_cast<unsigned>(kind);
139-
CurrentScore.Data[index] += value;
146+
increaseScore(kind, value);
140147
}
141148

142149
bool ConstraintSystem::worseThanBestSolution() const {

lib/Sema/CSSolver.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -274,8 +274,15 @@ Solution ConstraintSystem::finalize() {
274274

275275
void ConstraintSystem::replaySolution(const Solution &solution,
276276
bool shouldIncreaseScore) {
277-
if (shouldIncreaseScore)
278-
CurrentScore += solution.getFixedScore();
277+
if (shouldIncreaseScore) {
278+
// Update the score. We do this instead of operator+= because we
279+
// want to record the increments in the trail.
280+
auto solutionScore = solution.getFixedScore();
281+
for (unsigned i = 0; i < NumScoreKinds; ++i) {
282+
if (unsigned value = solutionScore.Data[i])
283+
increaseScore(ScoreKind(i), value);
284+
}
285+
}
279286

280287
// Assign fixed types to the type variables solved by this solution.
281288
for (auto binding : solution.typeBindings) {
@@ -711,8 +718,6 @@ ConstraintSystem::SolverScope::SolverScope(ConstraintSystem &cs)
711718
numTypeVariables = cs.TypeVariables.size();
712719
numFixes = cs.Fixes.size();
713720

714-
PreviousScore = cs.CurrentScore;
715-
716721
cs.solverState->registerScope(this);
717722
assert(!cs.failedConstraint && "Unexpected failed constraint!");
718723
}
@@ -742,9 +747,6 @@ ConstraintSystem::SolverScope::~SolverScope() {
742747
// constraints introduced by the current scope.
743748
cs.solverState->rollback(this);
744749

745-
// Reset the previous score.
746-
cs.CurrentScore = PreviousScore;
747-
748750
// Clear out other "failed" state.
749751
cs.failedConstraint = nullptr;
750752
}

lib/Sema/CSStep.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -885,6 +885,12 @@ bool ConjunctionStep::attempt(const ConjunctionElement &element) {
885885

886886
// Make sure that element is solved in isolation
887887
// by dropping all scoring information.
888+
for (unsigned i = 0; i < NumScoreKinds; ++i) {
889+
if (unsigned value = CS.CurrentScore.Data[i]) {
890+
CS.recordChange(
891+
SolverTrail::Change::DecreasedScore(ScoreKind(i), value));
892+
}
893+
}
888894
CS.CurrentScore = Score();
889895

890896
// Reset the scope counter to avoid "too complex" failures

lib/Sema/CSStep.h

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -906,8 +906,6 @@ class ConjunctionStep : public BindingStep<ConjunctionElementProducer> {
906906

907907
/// Best solution solver reached so far.
908908
std::optional<Score> BestScore;
909-
/// The score established before conjunction is attempted.
910-
Score CurrentScore;
911909

912910
/// The number of constraint solver scopes already explored
913911
/// before accepting this conjunction.
@@ -949,7 +947,7 @@ class ConjunctionStep : public BindingStep<ConjunctionElementProducer> {
949947
SmallVectorImpl<Solution> &solutions)
950948
: BindingStep(cs, {cs, conjunction},
951949
conjunction->isIsolated() ? IsolatedSolutions : solutions),
952-
BestScore(getBestScore()), CurrentScore(getCurrentScore()),
950+
BestScore(getBestScore()),
953951
OuterScopeCount(cs.CountScopes, 0), Conjunction(conjunction),
954952
AfterConjunction(erase(conjunction)), OuterSolutions(solutions) {
955953
assert(conjunction->getKind() == ConstraintKind::Conjunction);
@@ -975,11 +973,8 @@ class ConjunctionStep : public BindingStep<ConjunctionElementProducer> {
975973

976974
// Restore best score only if conjunction fails because
977975
// successful outcome should keep a score set by `restoreOuterState`.
978-
if (HadFailure) {
979-
auto solutionScore = Score();
976+
if (HadFailure)
980977
restoreBestScore();
981-
restoreCurrentScore(solutionScore);
982-
}
983978

984979
if (OuterTimeRemaining) {
985980
auto anchor = OuterTimeRemaining->first;
@@ -1033,7 +1028,6 @@ class ConjunctionStep : public BindingStep<ConjunctionElementProducer> {
10331028
private:
10341029
/// Restore best and current scores as they were before conjunction.
10351030
void restoreCurrentScore(const Score &solutionScore) const {
1036-
CS.CurrentScore = CurrentScore;
10371031
CS.increaseScore(SK_Fix, Conjunction->getLocator(),
10381032
solutionScore.Data[SK_Fix]);
10391033
CS.increaseScore(SK_Hole, Conjunction->getLocator(),

lib/Sema/CSTrail.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,26 @@ SolverTrail::Change::RecordedKeyPath(KeyPathExpr *expr) {
325325
return result;
326326
}
327327

328+
SolverTrail::Change
329+
SolverTrail::Change::IncreasedScore(ScoreKind kind, unsigned value) {
330+
ASSERT(value <= 0xffffff && "value must fit in 24 bits");
331+
332+
Change result;
333+
result.Kind = ChangeKind::IncreasedScore;
334+
result.Options = unsigned(kind) | (value << 8);
335+
return result;
336+
}
337+
338+
SolverTrail::Change
339+
SolverTrail::Change::DecreasedScore(ScoreKind kind, unsigned value) {
340+
ASSERT(value <= 0xffffff && "value must fit in 24 bits");
341+
342+
Change result;
343+
result.Kind = ChangeKind::DecreasedScore;
344+
result.Options = unsigned(kind) | (value << 8);
345+
return result;
346+
}
347+
328348
SyntacticElementTargetKey
329349
SolverTrail::Change::getSyntacticElementTargetKey() const {
330350
ASSERT(Kind == ChangeKind::RecordedTarget);
@@ -487,6 +507,21 @@ void SolverTrail::Change::undo(ConstraintSystem &cs) const {
487507
case ChangeKind::RecordedKeyPath:
488508
cs.removeKeyPath(KeyPath.Expr);
489509
break;
510+
511+
case ChangeKind::IncreasedScore: {
512+
auto kind = Options & 0xff;
513+
unsigned value = Options >> 8;
514+
ASSERT(cs.CurrentScore.Data[kind] >= value);
515+
cs.CurrentScore.Data[kind] -= value;
516+
break;
517+
}
518+
519+
case ChangeKind::DecreasedScore: {
520+
auto kind = Options & 0xff;
521+
unsigned value = Options >> 8;
522+
cs.CurrentScore.Data[kind] += value;
523+
break;
524+
}
490525
}
491526
}
492527

@@ -706,6 +741,18 @@ void SolverTrail::Change::dump(llvm::raw_ostream &out,
706741
simple_display(out, KeyPath.Expr);
707742
out << ")\n";
708743
break;
744+
745+
case ChangeKind::IncreasedScore:
746+
out << "(IncreasedScore ";
747+
out << Score::getNameFor(ScoreKind(Options & 0xff));
748+
out << " by " << (Options >> 8) << ")\n";
749+
break;
750+
751+
case ChangeKind::DecreasedScore:
752+
out << "(DecreasedScore ";
753+
out << Score::getNameFor(ScoreKind(Options & 0xff));
754+
out << " by " << (Options >> 8) << ")\n";
755+
break;
709756
}
710757
}
711758

0 commit comments

Comments
 (0)