Skip to content

Commit f1bdb2b

Browse files
authored
Optimistic backjumping with cutoff (#188)
1 parent 207da9a commit f1bdb2b

File tree

7 files changed

+276
-12
lines changed

7 files changed

+276
-12
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ exclude = [
9898
]
9999

100100
[tool.ruff.lint.mccabe]
101-
max-complexity = 12
101+
max-complexity = 20
102102

103103
[tool.mypy]
104104
warn_unused_configs = true

src/resolvelib/resolvers/resolution.py

Lines changed: 91 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import collections
44
import itertools
55
import operator
6-
from typing import TYPE_CHECKING, Collection, Generic, Iterable, Mapping
6+
from typing import TYPE_CHECKING, Generic
77

88
from ..structs import (
99
CT,
@@ -27,9 +27,13 @@
2727
)
2828

2929
if TYPE_CHECKING:
30+
from collections.abc import Collection, Iterable, Mapping
31+
3032
from ..providers import AbstractProvider, Preference
3133
from ..reporters import BaseReporter
3234

35+
_OPTIMISTIC_BACKJUMPING_RATIO: float = 0.1
36+
3337

3438
def _build_result(state: State[RT, CT, KT]) -> Result[RT, CT, KT]:
3539
mapping = state.mapping
@@ -77,6 +81,11 @@ def __init__(
7781
self._r = reporter
7882
self._states: list[State[RT, CT, KT]] = []
7983

84+
# Optimistic backjumping variables
85+
self._optimistic_backjumping_ratio = _OPTIMISTIC_BACKJUMPING_RATIO
86+
self._save_states: list[State[RT, CT, KT]] | None = None
87+
self._optimistic_start_round: int | None = None
88+
8089
@property
8190
def state(self) -> State[RT, CT, KT]:
8291
try:
@@ -274,6 +283,25 @@ def _patch_criteria(
274283
)
275284
return True
276285

286+
def _save_state(self) -> None:
287+
"""Save states for potential rollback if optimistic backjumping fails."""
288+
if self._save_states is None:
289+
self._save_states = [
290+
State(
291+
mapping=s.mapping.copy(),
292+
criteria=s.criteria.copy(),
293+
backtrack_causes=s.backtrack_causes[:],
294+
)
295+
for s in self._states
296+
]
297+
298+
def _rollback_states(self) -> None:
299+
"""Rollback states and disable optimistic backjumping."""
300+
self._optimistic_backjumping_ratio = 0.0
301+
if self._save_states:
302+
self._states = self._save_states
303+
self._save_states = None
304+
277305
def _backjump(self, causes: list[RequirementInformation[RT, CT]]) -> bool:
278306
"""Perform backjumping.
279307
@@ -324,13 +352,26 @@ def _backjump(self, causes: list[RequirementInformation[RT, CT]]) -> bool:
324352
except (IndexError, KeyError):
325353
raise ResolutionImpossible(causes) from None
326354

327-
# Only backjump if the current broken state is
328-
# an incompatible dependency
329-
if name not in incompatible_deps:
355+
if (
356+
not self._optimistic_backjumping_ratio
357+
and name not in incompatible_deps
358+
):
359+
# For safe backjumping only backjump if the current dependency
360+
# is not the same as the incompatible dependency
330361
break
331362

363+
# On the first time a non-safe backjump is done the state
364+
# is saved so we can restore it later if the resolution fails
365+
if (
366+
self._optimistic_backjumping_ratio
367+
and self._save_states is None
368+
and name not in incompatible_deps
369+
):
370+
self._save_state()
371+
332372
# If the current dependencies and the incompatible dependencies
333-
# are overlapping then we have found a cause of the incompatibility
373+
# are overlapping then we have likely found a cause of the
374+
# incompatibility
334375
current_dependencies = {
335376
self._p.identify(d) for d in self._p.get_dependencies(candidate)
336377
}
@@ -394,9 +435,32 @@ def resolve(self, requirements: Iterable[RT], max_rounds: int) -> State[RT, CT,
394435
# pinning the virtual "root" package in the graph.
395436
self._push_new_state()
396437

438+
# Variables for optimistic backjumping
439+
optimistic_rounds_cutoff: int | None = None
440+
optimistic_backjumping_start_round: int | None = None
441+
397442
for round_index in range(max_rounds):
398443
self._r.starting_round(index=round_index)
399444

445+
# Handle if optimistic backjumping has been running for too long
446+
if self._optimistic_backjumping_ratio and self._save_states is not None:
447+
if optimistic_backjumping_start_round is None:
448+
optimistic_backjumping_start_round = round_index
449+
optimistic_rounds_cutoff = int(
450+
(max_rounds - round_index) * self._optimistic_backjumping_ratio
451+
)
452+
453+
if optimistic_rounds_cutoff <= 0:
454+
self._rollback_states()
455+
continue
456+
elif optimistic_rounds_cutoff is not None:
457+
if (
458+
round_index - optimistic_backjumping_start_round
459+
>= optimistic_rounds_cutoff
460+
):
461+
self._rollback_states()
462+
continue
463+
400464
unsatisfied_names = [
401465
key
402466
for key, criterion in self.state.criteria.items()
@@ -448,12 +512,29 @@ def resolve(self, requirements: Iterable[RT], max_rounds: int) -> State[RT, CT,
448512
# Backjump if pinning fails. The backjump process puts us in
449513
# an unpinned state, so we can work on it in the next round.
450514
self._r.resolving_conflicts(causes=causes)
451-
success = self._backjump(causes)
452-
self.state.backtrack_causes[:] = causes
453515

454-
# Dead ends everywhere. Give up.
455-
if not success:
456-
raise ResolutionImpossible(self.state.backtrack_causes)
516+
try:
517+
success = self._backjump(causes)
518+
except ResolutionImpossible:
519+
if self._optimistic_backjumping_ratio and self._save_states:
520+
failed_optimistic_backjumping = True
521+
else:
522+
raise
523+
else:
524+
failed_optimistic_backjumping = bool(
525+
not success
526+
and self._optimistic_backjumping_ratio
527+
and self._save_states
528+
)
529+
530+
if failed_optimistic_backjumping and self._save_states:
531+
self._rollback_states()
532+
else:
533+
self.state.backtrack_causes[:] = causes
534+
535+
# Dead ends everywhere. Give up.
536+
if not success:
537+
raise ResolutionImpossible(self.state.backtrack_causes)
457538
else:
458539
# discard as information sources any invalidated names
459540
# (unsatisfied names that were previously satisfied)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
{
2+
"index": "backjump-test-3",
3+
"requested": [
4+
"a==1",
5+
"b",
6+
"c"
7+
],
8+
"resolved": {
9+
"a": "1",
10+
"b": "1",
11+
"c": "2",
12+
"d": "1"
13+
},
14+
"unvisited": {
15+
"c": ["1"]
16+
},
17+
"needs_optimistic": true
18+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
{
2+
"index": "backjump-test-4",
3+
"requested": [
4+
"a==1",
5+
"b",
6+
"c",
7+
"e",
8+
"f"
9+
],
10+
"resolved": {
11+
"a": "1",
12+
"b": "1",
13+
"c": "2",
14+
"d": "1",
15+
"e": "2",
16+
"f": "2"
17+
},
18+
"unvisited": {
19+
"c": ["1"],
20+
"e": ["1"],
21+
"f": ["1"]
22+
},
23+
"needs_optimistic": true
24+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
{
2+
"a": {
3+
"1": {
4+
"dependencies": []
5+
},
6+
"2": {
7+
"dependencies": []
8+
}
9+
},
10+
"b": {
11+
"2": {
12+
"dependencies": [
13+
"d==2"
14+
]
15+
},
16+
"1": {
17+
"dependencies": [
18+
"d==1"
19+
]
20+
}
21+
},
22+
"c": {
23+
"1": {
24+
"dependencies": []
25+
},
26+
"2": {
27+
"dependencies": []
28+
}
29+
},
30+
"d": {
31+
"2": {
32+
"dependencies": [
33+
"a==2"
34+
]
35+
},
36+
"1": {
37+
"dependencies": [
38+
"a==1"
39+
]
40+
}
41+
},
42+
"needs_optimistic": true
43+
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
{
2+
"a": {
3+
"1": {
4+
"dependencies": []
5+
},
6+
"2": {
7+
"dependencies": []
8+
}
9+
},
10+
"b": {
11+
"2": {
12+
"dependencies": [
13+
"d==2"
14+
]
15+
},
16+
"1": {
17+
"dependencies": [
18+
"d==1"
19+
]
20+
}
21+
},
22+
"c": {
23+
"1": {
24+
"dependencies": []
25+
},
26+
"2": {
27+
"dependencies": []
28+
}
29+
},
30+
"d": {
31+
"2": {
32+
"dependencies": [
33+
"a==2"
34+
]
35+
},
36+
"1": {
37+
"dependencies": [
38+
"a==1"
39+
]
40+
}
41+
},
42+
"e": {
43+
"1": {
44+
"dependencies": []
45+
},
46+
"2": {
47+
"dependencies": []
48+
}
49+
},
50+
"f": {
51+
"1": {
52+
"dependencies": []
53+
},
54+
"2": {
55+
"dependencies": []
56+
}
57+
}
58+
}

tests/functional/python/test_resolvers_python.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66

77
import packaging.markers
88
import packaging.requirements
9-
import packaging.specifiers
109
import packaging.utils
1110
import packaging.version
1211
import pytest
1312

13+
import resolvelib.resolvers.resolution
1414
from resolvelib import AbstractProvider, ResolutionImpossible, Resolver
1515

1616
Candidate = collections.namedtuple("Candidate", "name version extras")
@@ -65,6 +65,11 @@ def __init__(self, filename):
6565
else:
6666
self.expected_unvisited = None
6767

68+
if "needs_optimistic" in case_data:
69+
self.needs_optimistic = case_data["needs_optimistic"]
70+
else:
71+
self.needs_optimistic = False
72+
6873
def identify(self, requirement_or_candidate):
6974
name = packaging.utils.canonicalize_name(requirement_or_candidate.name)
7075
if requirement_or_candidate.extras:
@@ -219,3 +224,38 @@ def test_resolver(provider, reporter):
219224
assert not unexpected_versions, (
220225
f"Unexpcted versions visited {name}: {', '.join(unexpected_versions)}"
221226
)
227+
228+
229+
def test_no_optimistic_backtracking_resolver(provider, reporter, monkeypatch):
230+
"""
231+
Tests the resolver works with optimistic backtracking disabled for all
232+
cases, except for skipping candidates that is known to require optimistic
233+
backtracking.
234+
"""
235+
monkeypatch.setattr(
236+
resolvelib.resolvers.resolution, "_OPTIMISTIC_BACKJUMPING_RATIO", 0.0
237+
)
238+
resolver = Resolver(provider, reporter)
239+
240+
if provider.expected_confliction:
241+
with pytest.raises(ResolutionImpossible) as ctx:
242+
result = resolver.resolve(provider.root_requirements)
243+
print(_format_resolution(result)) # Provide some debugging hints.
244+
assert _format_confliction(ctx.value) == provider.expected_confliction
245+
else:
246+
resolution = resolver.resolve(provider.root_requirements)
247+
assert _format_resolution(resolution) == provider.expected_resolution
248+
249+
if provider.expected_unvisited and not provider.needs_optimistic:
250+
visited_versions = defaultdict(set)
251+
for visited_candidate in reporter.visited:
252+
visited_versions[visited_candidate.name].add(str(visited_candidate.version))
253+
254+
for name, versions in provider.expected_unvisited.items():
255+
if name not in visited_versions:
256+
continue
257+
258+
unexpected_versions = set(versions).intersection(visited_versions[name])
259+
assert not unexpected_versions, (
260+
f"Unexpcted versions visited {name}: {', '.join(unexpected_versions)}"
261+
)

0 commit comments

Comments
 (0)