Skip to content

Commit 918a5d7

Browse files
pfackeldey, ikrommyd, and lgray
authored
feat: add checkpointing functionality to coffea.processor.Runner (#1420)
* add checkpointing functionality to coffea.processor.Runner * fix annotation * add test * add check to 'save' step * improve the LocalCheckpointer logic a bit * improve test to not rerun preprocessing every time * use fsspec for load/save * 'LocalCheckpointer' -> 'SimpleCheckpointer' * switch to fsspec path checking * switch to rich print * test with fsspec to be able to do offline tests on remote storages * do not re-open open paths * we don't need the token --------- Co-authored-by: Iason Krommydas <iason.krom@gmail.com> Co-authored-by: Lindsey Gray <lindsey.gray@gmail.com>
1 parent 97dfa70 commit 918a5d7

File tree

5 files changed

+251
-18
lines changed

5 files changed

+251
-18
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ dependencies = [
6262
"cachetools",
6363
"requests",
6464
"aiohttp",
65+
"fsspec",
6566
]
6667
dynamic = ["version"]
6768

src/coffea/processor/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
set_accumulator,
1212
value_accumulator,
1313
)
14+
from .checkpointer import CheckpointerABC, SimpleCheckpointer
1415
from .executor import (
1516
DaskExecutor,
1617
FuturesExecutor,
@@ -29,6 +30,8 @@
2930
"ParslExecutor",
3031
"TaskVineExecutor",
3132
"Runner",
33+
"CheckpointerABC",
34+
"SimpleCheckpointer",
3235
"accumulate",
3336
"Accumulatable",
3437
"AccumulatorABC",
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
from __future__ import annotations
2+
3+
from abc import ABCMeta, abstractmethod
4+
from pathlib import Path
5+
from typing import TYPE_CHECKING, Any
6+
7+
import cloudpickle
8+
import fsspec
9+
from rich import print
10+
11+
if TYPE_CHECKING:
12+
from coffea.processor import Accumulatable, ProcessorABC
13+
14+
15+
class CheckpointerABC(metaclass=ABCMeta):
    """ABC for a generalized checkpointer

    Checkpointers are used to save chunk outputs to disk, and reload them if the same chunk is processed again.
    This is useful for long-running jobs that may be interrupted (resumable processing).

    Examples
    --------

    >>> from datetime import datetime
    >>> from coffea import processor
    >>> from coffea.processor import SimpleCheckpointer

    # create a checkpointer that stores checkpoints in a directory with the current date/time
    # (you may want to use a more specific directory in practice)
    >>> datestring = datetime.now().strftime("%Y%m%d%H")
    >>> checkpointer = SimpleCheckpointer(checkpoint_dir=f"checkpoints/{datestring}", verbose=True)

    # pass the checkpointer to a Runner
    >>> run = processor.Runner(..., checkpointer=checkpointer)
    >>> output = run(...)

    After the run, the checkpoints will be stored in the directory ``checkpoints/{datestring}``. On a subsequent run,
    if the same chunks are processed (and the same checkpointer, or rather ``checkpoint_dir`` is used),
    the results will be loaded from disk instead of being recomputed.
    """

    @abstractmethod
    def load(
        self, metadata: Any, processor_instance: ProcessorABC
    ) -> Accumulatable | None:
        """Return the previously saved output for the chunk described by
        ``metadata``, or ``None`` if no checkpoint is available."""
        ...

    @abstractmethod
    def save(
        self, output: Accumulatable, metadata: Any, processor_instance: ProcessorABC
    ) -> None:
        """Persist ``output`` for the chunk described by ``metadata`` so a
        later :meth:`load` with the same metadata can return it."""
        ...
51+
52+
53+
class SimpleCheckpointer(CheckpointerABC):
    """Checkpointer that stores one lz4-compressed pickle per chunk under ``checkpoint_dir``.

    Checkpoint files are addressed by dataset / file uuid / tree name / entry range,
    so a chunk reprocessed with identical metadata maps to the same path and is
    reloaded instead of recomputed. Both loading and saving are best effort: any
    failure is reported (when ``verbose``) and the run continues.

    Parameters
    ----------
    checkpoint_dir : str
        Any fsspec-compatible URL (local path, ``root://...``, ``s3://...``, ...)
        under which checkpoints are stored.
    verbose : bool, optional
        If True, print diagnostics on cache misses and load/save failures.
    overwrite : bool, optional
        If True (default), saving replaces an existing checkpoint file.
    """

    def __init__(
        self,
        checkpoint_dir: str,
        verbose: bool = False,
        overwrite: bool = True,
    ) -> None:
        # Split the URL into the filesystem implementation and the path within it.
        fs, path = fsspec.url_to_fs(checkpoint_dir)
        self.fs = fs
        self.checkpoint_dir = path
        self.verbose = verbose
        self.overwrite = overwrite

    def filepath(self, metadata: Any, processor_instance: ProcessorABC) -> str:
        """Return the checkpoint path for the chunk described by ``metadata``."""
        del processor_instance  # not used here, but could be in subclasses

        # Join with "/" explicitly: fsspec paths are posix-style on every
        # platform, whereas pathlib.Path would produce backslashes on Windows
        # and corrupt remote paths.
        return "/".join(
            [
                self.checkpoint_dir,
                metadata["dataset"],
                metadata["fileuuid"],
                metadata["treename"],
                f"{metadata['entrystart']}-{metadata['entrystop']}.coffea",
            ]
        )

    def load(
        self, metadata: Any, processor_instance: ProcessorABC
    ) -> Accumulatable | None:
        """Load a previously saved chunk output, or return ``None`` on a miss.

        A corrupt or unreadable checkpoint is treated as a miss so the chunk
        is simply recomputed.
        """
        fs = self.fs
        fpath = self.filepath(metadata, processor_instance)
        if not fs.exists(fpath):
            if self.verbose:
                print(
                    f"Checkpoint file {fpath} does not exist. May be the first run..."
                )
            return None
        try:
            with fs.open(fpath, "rb", compression="lz4") as fin:
                return cloudpickle.load(fin)
        except Exception as e:
            # Never let a bad checkpoint abort the run; fall back to recomputing.
            if self.verbose:
                print(f"Could not load checkpoint: {e}.")
            return None

    def save(
        self, output: Accumulatable, metadata: Any, processor_instance: ProcessorABC
    ) -> None:
        """Save ``output`` to its checkpoint file (best effort, never raises)."""
        fs = self.fs
        fpath = self.filepath(metadata, processor_instance)
        # ensure the parent directory exists
        fs.mkdirs(fpath.rsplit("/", 1)[0], exist_ok=True)
        if fs.exists(fpath) and not self.overwrite:
            if self.verbose:
                print(f"Checkpoint file {fpath} already exists. Not overwriting...")
            return None
        try:
            with fs.open(fpath, "wb", compression="lz4") as fout:
                # Do not rebind ``output`` here: cloudpickle.dump returns None,
                # so the original ``output = cloudpickle.dump(...)`` silently
                # clobbered the accumulator reference.
                cloudpickle.dump(output, fout)
        except Exception as e:
            # Saving is best effort; a failure only costs reprocessing later.
            if self.verbose:
                print(
                    f"Could not save checkpoint: {e}. Continuing without checkpointing..."
                )
        return None

src/coffea/processor/executor.py

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from ..nanoevents import NanoEventsFactory, schemas
3030
from ..util import _exception_chain, _hash, deprecate, rich_bar
3131
from .accumulator import Accumulatable, accumulate, set_accumulator
32+
from .checkpointer import CheckpointerABC
3233
from .processor import ProcessorABC
3334

3435
_PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL
@@ -1038,6 +1039,8 @@ class Runner:
10381039
determine chunking. Defaults to a in-memory LRU cache that holds 100k entries
10391040
(about 1MB depending on the length of filenames, etc.) If you edit an input file
10401041
(please don't) during a session, the session can be restarted to clear the cache.
1042+
checkpointer : CheckpointerABC, optional
1043+
A CheckpointerABC instance to manage checkpointing of each chunk output
10411044
"""
10421045

10431046
executor: ExecutorBase
@@ -1054,6 +1057,7 @@ class Runner:
10541057
use_skyhook: Optional[bool] = False
10551058
skyhook_options: Optional[dict] = field(default_factory=dict)
10561059
format: str = "root"
1060+
checkpointer: Optional[CheckpointerABC] = None
10571061

10581062
@staticmethod
10591063
def read_coffea_config():
@@ -1399,6 +1403,7 @@ def _work_function(
13991403
processor_instance: ProcessorABC,
14001404
uproot_options: dict,
14011405
iteritems_options: dict,
1406+
checkpointer: CheckpointerABC,
14021407
) -> dict:
14031408
if "timeout" in uproot_options:
14041409
xrootdtimeout = uproot_options["timeout"]
@@ -1407,6 +1412,28 @@ def _work_function(
14071412
if not isinstance(processor_instance, ProcessorABC):
14081413
processor_instance = cloudpickle.loads(lz4f.decompress(processor_instance))
14091414

1415+
metadata = {
1416+
"dataset": item.dataset,
1417+
"filename": item.filename,
1418+
"treename": item.treename,
1419+
"entrystart": item.entrystart,
1420+
"entrystop": item.entrystop,
1421+
"fileuuid": (
1422+
str(uuid.UUID(bytes=item.fileuuid)) if len(item.fileuuid) > 0 else ""
1423+
),
1424+
}
1425+
if item.usermeta is not None:
1426+
metadata.update(item.usermeta)
1427+
1428+
if checkpointer is not None:
1429+
if not isinstance(checkpointer, CheckpointerABC):
1430+
raise TypeError("Expected checkpointer to derive from CheckpointerABC")
1431+
# try to load from checkpoint
1432+
out = checkpointer.load(metadata, processor_instance)
1433+
# if we got something, return it
1434+
if out is not None:
1435+
return out
1436+
14101437
try:
14111438
if format == "root":
14121439
filecontext = uproot.open(
@@ -1421,19 +1448,6 @@ def _work_function(
14211448
f"Failed to open file: {item!r}. The error was: {e!r}."
14221449
) from e
14231450

1424-
metadata = {
1425-
"dataset": item.dataset,
1426-
"filename": item.filename,
1427-
"treename": item.treename,
1428-
"entrystart": item.entrystart,
1429-
"entrystop": item.entrystop,
1430-
"fileuuid": (
1431-
str(uuid.UUID(bytes=item.fileuuid)) if len(item.fileuuid) > 0 else ""
1432-
),
1433-
}
1434-
if item.usermeta is not None:
1435-
metadata.update(item.usermeta)
1436-
14371451
with filecontext as file:
14381452
if schema is None:
14391453
raise ValueError("Schema must be set")
@@ -1479,9 +1493,7 @@ def _work_function(
14791493
"Output of process() should not be None. Make sure your processor's process() function returns an accumulator."
14801494
)
14811495
toc = time.time()
1482-
if use_dataframes:
1483-
return out
1484-
else:
1496+
if not use_dataframes:
14851497
if savemetrics:
14861498
metrics = {}
14871499
if isinstance(file, uproot.ReadOnlyDirectory):
@@ -1490,8 +1502,17 @@ def _work_function(
14901502
metrics["columns"] = set(materialized)
14911503
metrics["entries"] = len(events)
14921504
metrics["processtime"] = toc - tic
1493-
return {"out": out, "metrics": metrics, "processed": {item}}
1494-
return {"out": out, "processed": {item}}
1505+
out = {"out": out, "metrics": metrics, "processed": {item}}
1506+
out = {"out": out, "processed": {item}}
1507+
1508+
if checkpointer is not None:
1509+
if not isinstance(checkpointer, CheckpointerABC):
1510+
raise TypeError(
1511+
"Expected checkpointer to derive from CheckpointerABC"
1512+
)
1513+
# save the output
1514+
checkpointer.save(out, metadata, processor_instance)
1515+
return out
14951516

14961517
def __call__(
14971518
self,
@@ -1661,6 +1682,7 @@ def run(
16611682
processor_instance="heavy",
16621683
uproot_options=uproot_options,
16631684
iteritems_options=iteritems_options,
1685+
checkpointer=self.checkpointer,
16641686
)
16651687
else:
16661688
closure = partial(
@@ -1673,6 +1695,7 @@ def run(
16731695
processor_instance=pi_to_send,
16741696
uproot_options=uproot_options,
16751697
iteritems_options=iteritems_options,
1698+
checkpointer=self.checkpointer,
16761699
)
16771700

16781701
chunks = list(chunks)

tests/test_checkpointing.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import os.path as osp
2+
import random
3+
from pathlib import Path
4+
5+
import awkward as ak
6+
import fsspec
7+
import numpy as np
8+
9+
from coffea import processor
10+
from coffea.nanoevents import schemas
11+
12+
# Seed the module-level RNG so the simulated failures below are reproducible:
# we want repeatable failures, and know that we never run indefinitely.
random.seed(1234)
14+
15+
16+
class UnstableNanoEventsProcessor(processor.ProcessorABC):
    """Processor that fails on roughly half of its chunks.

    Used to exercise checkpoint recovery: a failed chunk leaves no checkpoint
    and is retried, while successful chunks are reloaded on subsequent runs.
    """

    @property
    def accumulator(self):
        # A fresh accumulator with an empty cutflow on every access.
        return {"cutflow": {}}

    def process(self, events):
        # Simulate a flaky worker: raise on ~50% of calls (RNG seeded above).
        if random.random() < 0.5:
            raise RuntimeError("Random failure for testing checkpointing")

        dataset = events.metadata["dataset"]
        result = self.accumulator
        # Total muon count for this chunk, keyed by dataset name.
        result["cutflow"]["%s_pt" % dataset] = ak.sum(ak.num(events.Muon, axis=1))
        return result

    def postprocess(self, accumulator):
        # Nothing to finalize after accumulation.
        return accumulator
32+
33+
34+
def test_checkpointing():
    """Re-run a randomly failing processor until every chunk has a checkpoint,
    then verify that the checkpointed result is complete and correct."""
    filelist = {
        "ZJets": {
            "treename": "Events",
            "files": [osp.abspath("tests/samples/nano_dy.root")],
        },
        "Data": {
            "treename": "Events",
            "files": [osp.abspath("tests/samples/nano_dimuon.root")],
        },
    }

    executor = processor.IterativeExecutor()

    checkpoint_dir = str(Path(__file__).parent / "test_checkpointing")
    checkpointer = processor.SimpleCheckpointer(checkpoint_dir)
    run = processor.Runner(
        executor=executor,
        schema=schemas.NanoAODSchema,
        chunksize=10,
        format="root",
        checkpointer=checkpointer,
    )
    # use the chunk generator to not re-run the preprocessing step
    chunks = list(run.preprocess(filelist, "Events"))

    def chunk_gen():
        yield from chunks

    # number of WorkItems
    n_expected_checkpoints = len(chunks)
    ntries = 0
    fs, path = fsspec.url_to_fs(checkpoint_dir)

    # Start from a clean slate: a previous aborted run can leave checkpoints
    # behind (cleanup only happens on success), which would make the retry loop
    # below exit immediately and leave 'out' unbound or stale.
    if fs.exists(path):
        fs.rm(path, recursive=True)
        fs.invalidate_cache()

    def n_checkpoints():
        # count regular files below the checkpoint dir (one per WorkItem)
        return len(list(filter(fs.isfile, fs.glob(f"{path}/**"))))

    # keep trying until we have as many checkpoints as WorkItems
    while n_checkpoints() != n_expected_checkpoints:
        fs.invalidate_cache()
        ntries += 1
        try:
            out = run(chunk_gen(), UnstableNanoEventsProcessor(), "Events")
        except Exception:
            print(f"Run failed, trying again, try number {ntries}...")
            continue

    # make sure we have as many checkpoints as WorkItems
    fs.invalidate_cache()
    assert n_checkpoints() == n_expected_checkpoints

    # make sure we got the right answer
    assert out == {"cutflow": {"Data_pt": np.int64(84), "ZJets_pt": np.int64(18)}}

    # cleanup
    fs.rm(path, recursive=True)

0 commit comments

Comments
 (0)