Skip to content

Commit 84dafca

Browse files
committed
repro: parallel execution of stages
1 parent f342cc7 commit 84dafca

File tree

5 files changed

+188
-70
lines changed

5 files changed

+188
-70
lines changed

dvc/dvcfile.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import contextlib
22
import os
3+
import threading
4+
from collections import defaultdict
35
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, TypeVar, Union
46

57
from dvc.exceptions import DvcException
@@ -77,6 +79,10 @@ def check_dvcfile_path(repo, path):
7779
raise FileIsGitIgnored(relpath(path), True)
7880

7981

82+
_file_locks: dict[str, threading.Lock] = defaultdict(threading.Lock)
83+
_file_locks_lock = threading.Lock()
84+
85+
8086
class FileMixin:
8187
SCHEMA: Callable[[_T], _T]
8288

@@ -85,6 +91,10 @@ def __init__(self, repo, path, verify=True, **kwargs):
8591
self.path = path
8692
self.verify = verify
8793

94+
def _thread_lock(self) -> threading.Lock:
95+
with _file_locks_lock:
96+
return _file_locks[self.path]
97+
8898
def __repr__(self):
8999
return f"{self.__class__.__name__}: {relpath(self.path, self.repo.root_dir)}"
90100

@@ -148,15 +158,19 @@ def validate(cls, d: _T, fname: Optional[str] = None) -> _T:
148158
def _load_yaml(self, **kwargs: Any) -> tuple[Any, str]:
149159
from dvc.utils import strictyaml
150160

151-
return strictyaml.load(
152-
self.path,
153-
self.SCHEMA, # type: ignore[arg-type]
154-
self.repo.fs,
155-
**kwargs,
156-
)
161+
with self._thread_lock():
162+
return strictyaml.load(
163+
self.path,
164+
self.SCHEMA, # type: ignore[arg-type]
165+
self.repo.fs,
166+
**kwargs,
167+
)
157168

158169
def remove(self, force=False): # noqa: ARG002
159-
with contextlib.suppress(FileNotFoundError):
170+
with (
171+
self._thread_lock(),
172+
contextlib.suppress(FileNotFoundError),
173+
):
160174
os.unlink(self.path)
161175

162176
def dump(self, stage, **kwargs):
@@ -407,7 +421,10 @@ def _load(self, **kwargs: Any):
407421
return {}, ""
408422

409423
def dump_dataset(self, dataset: dict):
410-
with modify_yaml(self.path, fs=self.repo.fs) as data:
424+
with (
425+
self._thread_lock(),
426+
modify_yaml(self.path, fs=self.repo.fs) as data,
427+
):
411428
data.update({"schema": "2.0"})
412429
if not data:
413430
logger.info("Generating lock file '%s'", self.relpath)
@@ -430,7 +447,10 @@ def dump_stages(self, stages, **kwargs):
430447

431448
is_modified = False
432449
log_updated = False
433-
with modify_yaml(self.path, fs=self.repo.fs) as data:
450+
with (
451+
self._thread_lock(),
452+
modify_yaml(self.path, fs=self.repo.fs) as data,
453+
):
434454
if not data:
435455
data.update({"schema": "2.0"})
436456
# order is important, meta should always be at the top
@@ -468,7 +488,8 @@ def remove_stage(self, stage):
468488
del data[stage.name]
469489

470490
if data:
471-
dump_yaml(self.path, d)
491+
with self._thread_lock():
492+
dump_yaml(self.path, d)
472493
else:
473494
self.remove()
474495

dvc/repo/reproduce.py

Lines changed: 113 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import concurrent.futures
12
from collections.abc import Iterable
3+
from dataclasses import dataclass
4+
from enum import Enum
25
from typing import TYPE_CHECKING, Callable, NoReturn, Optional, TypeVar, Union, cast
36

47
from funcy import ldistinct
@@ -120,14 +123,6 @@ def _reproduce_stage(stage: "Stage", **kwargs) -> Optional["Stage"]:
120123
return ret
121124

122125

123-
def _get_upstream_downstream_nodes(
124-
graph: Optional["DiGraph"], node: T
125-
) -> tuple[list[T], list[T]]:
126-
succ = list(graph.successors(node)) if graph else []
127-
pre = list(graph.predecessors(node)) if graph else []
128-
return succ, pre
129-
130-
131126
def _repr(stages: Iterable["Stage"]) -> str:
132127
return humanize.join(repr(stage.addressing) for stage in stages)
133128

@@ -155,56 +150,129 @@ def _raise_error(exc: Optional[Exception], *stages: "Stage") -> NoReturn:
155150
raise ReproductionError(f"failed to reproduce{segment} {names}") from exc
156151

157152

158-
def _reproduce(
159-
stages: list["Stage"],
160-
graph: Optional["DiGraph"] = None,
161-
force_downstream: bool = False,
162-
on_error: str = "fail",
163-
force: bool = False,
153+
class ReproStatus(Enum):
154+
READY = "ready"
155+
IN_PROGRESS = "in-progress"
156+
COMPLETE = "complete"
157+
SKIPPED = "skipped"
158+
FAILED = "failed"
159+
160+
161+
@dataclass
162+
class StageInfo:
163+
upstream: list["Stage"]
164+
upstream_unfinished: set["Stage"]
165+
downstream: list["Stage"]
166+
force: bool
167+
status: ReproStatus
168+
result: Optional["Stage"]
169+
170+
171+
def _start_ready_stages(
172+
to_repro: dict["Stage", StageInfo],
173+
executor: concurrent.futures.ThreadPoolExecutor,
164174
repro_fn: Callable = _reproduce_stage,
165175
**kwargs,
176+
) -> dict[concurrent.futures.Future["Stage"], "Stage"]:
177+
futures = {
178+
executor.submit(
179+
repro_fn,
180+
stage,
181+
upstream=stage_info.upstream,
182+
force=stage_info.force,
183+
**kwargs,
184+
): stage
185+
for stage, stage_info in to_repro.items()
186+
if stage_info.status == ReproStatus.READY and not stage_info.upstream_unfinished
187+
}
188+
for stage in futures.values():
189+
to_repro[stage].status = ReproStatus.IN_PROGRESS
190+
return futures
191+
192+
193+
def _result_or_raise(
194+
to_repro: dict["Stage", StageInfo], stages: list["Stage"], on_error: str
166195
) -> list["Stage"]:
167-
assert on_error in ("fail", "keep-going", "ignore")
168-
169196
result: list[Stage] = []
170197
failed: list[Stage] = []
171-
to_skip: dict[Stage, Stage] = {}
172-
ret: Optional[Stage] = None
173-
174-
force_state = dict.fromkeys(stages, force)
175-
198+
# Preserve original order
176199
for stage in stages:
177-
if stage in to_skip:
178-
continue
179-
180-
if ret:
181-
logger.info("") # add a newline
182-
183-
upstream, downstream = _get_upstream_downstream_nodes(graph, stage)
184-
force_stage = force_state[stage]
185-
186-
try:
187-
ret = repro_fn(stage, upstream=upstream, force=force_stage, **kwargs)
188-
except Exception as exc: # noqa: BLE001
200+
stage_info = to_repro[stage]
201+
if stage_info.status == ReproStatus.FAILED:
189202
failed.append(stage)
190-
if on_error == "fail":
191-
_raise_error(exc, stage)
192-
193-
dependents = handle_error(graph, on_error, exc, stage)
194-
to_skip.update(dict.fromkeys(dependents, stage))
195-
continue
196-
197-
if force_downstream and (ret or force_stage):
198-
force_state.update(dict.fromkeys(downstream, True))
199-
200-
if ret:
201-
result.append(ret)
203+
elif stage_info.result:
204+
result.append(stage_info.result)
202205

203206
if on_error != "ignore" and failed:
204207
_raise_error(None, *failed)
208+
205209
return result
206210

207211

212+
def _reproduce(
213+
stages: list["Stage"],
214+
graph: Optional["DiGraph"] = None,
215+
force_downstream: bool = False,
216+
on_error: str = "fail",
217+
force: bool = False,
218+
max_workers: int = 10,
219+
**kwargs,
220+
) -> list["Stage"]:
221+
assert on_error in ("fail", "keep-going", "ignore")
222+
223+
ret: Optional[Stage] = None
224+
to_repro = {
225+
stage: StageInfo(
226+
upstream=(upstream := list(graph.successors(stage)) if graph else []),
227+
upstream_unfinished=set(upstream).intersection(stages),
228+
downstream=list(graph.predecessors(stage)) if graph else [],
229+
force=force,
230+
status=ReproStatus.READY,
231+
result=None,
232+
)
233+
for stage in stages
234+
}
235+
236+
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
237+
futures = _start_ready_stages(to_repro, executor, **kwargs)
238+
while futures:
239+
done, _ = concurrent.futures.wait(
240+
futures, return_when=concurrent.futures.FIRST_COMPLETED
241+
)
242+
for future in done:
243+
stage = futures.pop(future)
244+
stage_info = to_repro[stage]
245+
246+
try:
247+
ret = future.result()
248+
stage_info.status = ReproStatus.COMPLETE
249+
except Exception as exc: # noqa: BLE001
250+
if on_error == "fail":
251+
return _raise_error(exc, stage)
252+
253+
stage_info.status = ReproStatus.FAILED
254+
dependents = handle_error(graph, on_error, exc, stage)
255+
for dependent in dependents:
256+
to_repro[dependent].status = ReproStatus.SKIPPED
257+
258+
success = stage_info.status == ReproStatus.COMPLETE
259+
for dependent in stage_info.downstream:
260+
if dependent not in to_repro:
261+
continue
262+
dependent_info = to_repro[dependent]
263+
if stage in dependent_info.upstream_unfinished:
264+
dependent_info.upstream_unfinished.remove(stage)
265+
if success and force_downstream and (ret or stage_info.force):
266+
dependent_info.force = True
267+
268+
if success and ret:
269+
stage_info.result = ret
270+
271+
futures.update(_start_ready_stages(to_repro, executor, **kwargs))
272+
273+
return _result_or_raise(to_repro, stages, on_error)
274+
275+
208276
@locked
209277
@scm_context
210278
def reproduce(

dvc/rwlock.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import os
3+
import threading
34
from collections import defaultdict
45
from contextlib import contextmanager
56

@@ -25,6 +26,8 @@
2526
}
2627
)
2728

29+
RWLOCK_THREAD_LOCK = threading.Lock()
30+
RWLOCK_THREAD_TIMEOUT = 3
2831
RWLOCK_FILE = "rwlock"
2932
RWLOCK_LOCK = "rwlock.lock"
3033

@@ -50,21 +53,26 @@ def _edit_rwlock(lock_dir, fs, hardlink):
5053
tmp_dir=lock_dir,
5154
hardlink_lock=hardlink,
5255
)
53-
with rwlock_guard:
54-
try:
55-
with fs.open(path, encoding="utf-8") as fobj:
56-
lock = SCHEMA(json.load(fobj))
57-
except FileNotFoundError:
58-
lock = SCHEMA({})
59-
except json.JSONDecodeError as exc:
60-
raise RWLockFileCorruptedError(path) from exc
61-
except Invalid as exc:
62-
raise RWLockFileFormatError(path) from exc
63-
lock["read"] = defaultdict(list, lock["read"])
64-
lock["write"] = defaultdict(dict, lock["write"])
65-
yield lock
66-
with fs.open(path, "w", encoding="utf-8") as fobj:
67-
json.dump(lock, fobj)
56+
RWLOCK_THREAD_LOCK.acquire(timeout=RWLOCK_THREAD_TIMEOUT)
57+
try:
58+
with rwlock_guard:
59+
try:
60+
with fs.open(path, encoding="utf-8") as fobj:
61+
lock = SCHEMA(json.load(fobj))
62+
except FileNotFoundError:
63+
lock = SCHEMA({})
64+
except json.JSONDecodeError as exc:
65+
raise RWLockFileCorruptedError(path) from exc
66+
except Invalid as exc:
67+
raise RWLockFileFormatError(path) from exc
68+
lock["read"] = defaultdict(list, lock["read"])
69+
lock["write"] = defaultdict(dict, lock["write"])
70+
yield lock
71+
with fs.open(path, "w", encoding="utf-8") as fobj:
72+
json.dump(lock, fobj)
73+
finally:
74+
if RWLOCK_THREAD_LOCK.locked():
75+
RWLOCK_THREAD_LOCK.release()
6876

6977

7078
def _infos_to_str(infos):

dvc/stage/decorators.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import threading
12
from functools import wraps
23

34
from funcy import decorator
@@ -47,6 +48,8 @@ def _chain(names):
4748
def unlocked_repo(f):
4849
@wraps(f)
4950
def wrapper(stage, *args, **kwargs):
51+
if threading.current_thread() is not threading.main_thread():
52+
return f(stage, *args, **kwargs)
5053
stage.repo.lock.unlock()
5154
stage.repo._reset()
5255
try:

tests/unit/test_rwlock.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import concurrent.futures
12
import json
23
import os
4+
import time
35

46
import pytest
57

@@ -59,6 +61,7 @@ def test_rwlock_reentrant(tmp_path):
5961
def test_rwlock_edit_is_guarded(tmp_path, mocker):
6062
# patching to speedup tests
6163
mocker.patch("dvc.lock.DEFAULT_TIMEOUT", 0.01)
64+
mocker.patch("dvc.rwlock.RWLOCK_THREAD_TIMEOUT", 0.01)
6265

6366
path = os.fspath(tmp_path)
6467

@@ -68,6 +71,21 @@ def test_rwlock_edit_is_guarded(tmp_path, mocker):
6871
pass
6972

7073

74+
def test_rwlock_multiple_threads(tmp_path, mocker):
75+
# patching to speedup tests
76+
mocker.patch("dvc.rwlock.RWLOCK_THREAD_TIMEOUT", 0.01)
77+
path = os.fspath(tmp_path)
78+
foo = "foo"
79+
80+
def work():
81+
with rwlock(path, localfs, "cmd1", [foo], [], False):
82+
time.sleep(1)
83+
84+
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
85+
futures = [executor.submit(work) for _ in range(2)]
86+
concurrent.futures.wait(futures)
87+
88+
7189
def test_rwlock_subdirs(tmp_path):
7290
path = os.fspath(tmp_path)
7391
foo = "foo"

0 commit comments

Comments (0)