diff --git a/deepnote_toolkit/kernel_checkpoint/__init__.py b/deepnote_toolkit/kernel_checkpoint/__init__.py new file mode 100644 index 0000000..fccdf9f --- /dev/null +++ b/deepnote_toolkit/kernel_checkpoint/__init__.py @@ -0,0 +1,72 @@ +"""Kernel-state snapshot / restore for cold-start reduction. + +The dominant cold-start cost in Deepnote today is re-running the init notebook +on every fresh container boot. For projects with heavy init (data loading, +client setup), this can be tens of seconds to minutes per cold start. + +This module captures the kernel's globals after init completes and restores +them on the next cold boot, skipping the init re-run entirely. + +# Design surface + +Three layers, each with its own file: + + - `store`: a `SnapshotStore` protocol with a `LocalDiskSnapshotStore` impl + for the POC. The S3 production impl will be a sibling later. + - `checkpoint`: `save_checkpoint` and `try_restore_checkpoint` — the actual + serialise/deserialise logic. Built on `dill` (already a deepnote-toolkit + dep) to handle closures and most user-defined types. + - `key`: `compute_checkpoint_key` — the stable cache-key composition. Any + input change here (init source, environment) must invalidate the cache. + +A `__main__` CLI exists for end-to-end testing without touching production +code paths. + +# Correctness invariant + +A restore is correct iff the restored namespace produces the same downstream +execution behaviour as freshly re-running init in the current environment. + +The checkpoint key includes the init source hash and environment id so that +source/env changes always invalidate the cache. 
What the key does NOT +capture (limitations the production PR must address): + + - filesystem state in `/work` — if init read a file that has since + changed, restore returns stale data + - external connections (db, http clients) — they are unpicklable and get + skipped on save; restored kernel will need to re-establish them + - cross-Python-version restore — dill bytes from 3.11 may not load on + 3.12; the env id is expected to discriminate Python versions + +# What is NOT in this PR + + - S3 backend (LocalDiskSnapshotStore only) + - Wiring into `runtime/executor.py` init lifecycle + - Webapp signaling + - Per-workspace feature flag + +See the [POC plan](~/.claude/plans/snapshot-restore-poc.md) for the +productionisation path. +""" + +from deepnote_toolkit.kernel_checkpoint.checkpoint import ( + RestoreReport, + SaveReport, + save_checkpoint, + try_restore_checkpoint, +) +from deepnote_toolkit.kernel_checkpoint.key import compute_checkpoint_key +from deepnote_toolkit.kernel_checkpoint.store import ( + LocalDiskSnapshotStore, + SnapshotStore, +) + +__all__ = [ + "LocalDiskSnapshotStore", + "RestoreReport", + "SaveReport", + "SnapshotStore", + "compute_checkpoint_key", + "save_checkpoint", + "try_restore_checkpoint", +] diff --git a/deepnote_toolkit/kernel_checkpoint/__main__.py b/deepnote_toolkit/kernel_checkpoint/__main__.py new file mode 100644 index 0000000..2434840 --- /dev/null +++ b/deepnote_toolkit/kernel_checkpoint/__main__.py @@ -0,0 +1,86 @@ +"""End-to-end CLI for the kernel-checkpoint POC. + +Two commands so the round-trip is testable without touching production code: + + $ python -m deepnote_toolkit.kernel_checkpoint save + $ python -m deepnote_toolkit.kernel_checkpoint restore + +The `save` command execs the python file in a fresh namespace, then snapshots +the resulting globals. `restore` reads the snapshot back into a fresh +namespace and prints the restored names. 
def _cmd_save(args: argparse.Namespace) -> int:
    """Exec `args.python_file` in a fresh namespace and snapshot its globals.

    Prints the number of saved globals, the payload size, and every name that
    was skipped (with the reason reported by `save_checkpoint`).

    Returns:
        The process exit code. Always 0 here; failures (unreadable file, an
        exception raised by the exec'd script, store errors) propagate as
        exceptions.
    """
    # Hand compile() the raw bytes rather than `read_text()`. `read_text()`
    # decodes with the locale's default encoding, which mis-decodes (or fails
    # on) UTF-8 sources on non-UTF-8 locales. compile() applies Python's own
    # source-encoding rules: UTF-8 by default, or a PEP 263 coding cookie.
    source = Path(args.python_file).read_bytes()
    namespace: dict[str, object] = {}
    exec(compile(source, args.python_file, "exec"), namespace)

    store = LocalDiskSnapshotStore(args.root)
    report = save_checkpoint(namespace, store, args.key)

    # NOTE: `_path_for` is private to the store; it is used here only to show
    # the operator where the snapshot landed.
    print(f"[checkpoint] saved {len(report.saved_names)} globals to {store._path_for(args.key)}")
    print(f" bytes_written: {report.bytes_written}")
    if report.skipped_unpicklable:
        print(" skipped (unpicklable):")
        for name, kind in report.skipped_unpicklable:
            print(f" - {name}: {kind}")
    if report.skipped_by_rule:
        print(" skipped (by rule):")
        for name, reason in report.skipped_by_rule:
            print(f" - {name}: {reason}")
    print(" saved:")
    for name in report.saved_names:
        print(f" - {name}")
    return 0


def _cmd_restore(args: argparse.Namespace) -> int:
    """Load the snapshot stored under `args.key` and print the restored names.

    Returns:
        0 when a snapshot was found and loaded, 1 when no snapshot exists for
        the key. Deserialisation errors are not caught.
    """
    store = LocalDiskSnapshotStore(args.root)
    result = try_restore_checkpoint(store, args.key)
    if result is None:
        print(f"[checkpoint] no snapshot found for key: {args.key}")
        return 1

    globals_dict, report = result
    print(f"[checkpoint] restored {len(report.restored_names)} globals ({report.bytes_read} bytes):")
    for name in report.restored_names:
        value = globals_dict[name]
        print(f" {name}: {type(value).__name__}")
    return 0


def main(argv: list[str] | None = None) -> int:
    """Parse the command line and dispatch to the `save` / `restore` handler.

    Args:
        argv: Argument list to parse; `None` makes argparse read `sys.argv`.

    Returns:
        The exit code produced by the selected sub-command handler.
    """
    parser = argparse.ArgumentParser(prog="python -m deepnote_toolkit.kernel_checkpoint")
    parser.add_argument(
        "--root", default="/tmp/deepnote-checkpoint", help="Local-disk snapshot store root."
    )
    sub = parser.add_subparsers(dest="cmd", required=True)

    save = sub.add_parser("save", help="Exec a python file and snapshot the resulting globals.")
    save.add_argument("key")
    save.add_argument("python_file")
    save.set_defaults(func=_cmd_save)

    restore = sub.add_parser("restore", help="Restore a snapshot into a fresh namespace and print names.")
    restore.add_argument("key")
    restore.set_defaults(func=_cmd_restore)

    args = parser.parse_args(argv)
    # Each sub-parser registered its handler via set_defaults(func=...).
    return args.func(args)
@dataclass
class SaveReport:
    """Per-checkpoint save outcome — useful for tests and prod metrics."""

    # Names that were written to the snapshot, in iteration order.
    saved_names: list[str] = field(default_factory=list)
    # (name, exception class name) for values whose trial dump raised.
    skipped_unpicklable: list[tuple[str, str]] = field(default_factory=list)
    # (name, reason) for entries excluded by a name/value skip rule.
    skipped_by_rule: list[tuple[str, str]] = field(default_factory=list)
    # Length of the serialised payload handed to the store.
    bytes_written: int = 0


@dataclass
class RestoreReport:
    """Per-checkpoint restore outcome — useful for tests and prod metrics."""

    # Keys of the restored dict.
    restored_names: list[str]
    # Length of the payload read from the store.
    bytes_read: int


_ALWAYS_SKIP_PREFIXES = ("_",)
_ALWAYS_SKIP_NAMES = {
    "In",
    "Out",
    "exit",
    "quit",
    "get_ipython",
}


def _should_skip_name(name: str) -> str | None:
    """Return a reason string when the name must be skipped, else None."""
    if name in _ALWAYS_SKIP_NAMES:
        return "ipython_artifact"
    # str.startswith accepts the tuple of prefixes directly.
    if name.startswith(_ALWAYS_SKIP_PREFIXES):
        return "dunder_or_private"
    return None


def _should_skip_value(value: Any) -> str | None:
    """Return a reason string when the value type must be skipped, else None."""
    if isinstance(value, types.ModuleType):
        return "module_object"
    return None


def save_checkpoint(globals_dict: dict[str, Any], store: SnapshotStore, key: str) -> SaveReport:
    """Snapshot the named globals to `store` under `key`.

    Entries excluded by a skip rule, and entries whose value cannot be dumped
    by dill, are left out of the snapshot and recorded in the report; the
    remaining entries are serialised together as one dict.

    Args:
        globals_dict: The namespace to snapshot. It is not modified.
        store: Destination for the serialised payload.
        key: Key under which the payload is written.

    Returns:
        A SaveReport so the caller can log/metric the result.

    Raises:
        Whatever `dill.dumps` raises for the final combined dump, and whatever
        `store.write` raises; neither is caught here.
    """
    report = SaveReport()
    saved: dict[str, Any] = {}

    # Iterate over a point-in-time copy of the items. The trial dump below runs
    # arbitrary pickling hooks (`__reduce__`, `__getstate__`, ...), and the
    # namespace may be a live kernel globals dict; iterating the dict directly
    # would raise "dictionary changed size during iteration" if anything adds
    # or removes a name while we are serialising.
    for name, value in list(globals_dict.items()):
        rule_reason = _should_skip_name(name) or _should_skip_value(value)
        if rule_reason:
            report.skipped_by_rule.append((name, rule_reason))
            continue
        try:
            # Trial dump only (the bytes are discarded, nothing is loaded
            # back): it detects values dill cannot serialise so that one of
            # them does not make the combined dump below fail.
            dill.dumps(value)
        except Exception as exc:  # noqa: BLE001 — broad on purpose; dill raises many types
            report.skipped_unpicklable.append((name, type(exc).__name__))
            logger.info("[checkpoint] skipping unpicklable %s: %s", name, exc)
            continue
        saved[name] = value
        report.saved_names.append(name)

    payload = dill.dumps(saved)
    store.write(key, payload)
    report.bytes_written = len(payload)
    return report


def try_restore_checkpoint(store: SnapshotStore, key: str) -> tuple[dict[str, Any], RestoreReport] | None:
    """Attempt to restore from `store`. Returns (globals, report) or None if absent.

    Callers MERGE the returned dict into their target namespace — this module
    intentionally doesn't mutate any caller state.

    Raises:
        TypeError: if the payload deserialises to something other than a dict.
        Any exception raised by `dill.loads` is propagated unchanged.
    """
    payload = store.read(key)
    if payload is None:
        return None

    # SECURITY: dill.loads, like pickle.loads, can execute arbitrary code
    # embedded in the payload. Only point this at a store whose contents are
    # written by trusted code.
    restored = dill.loads(payload)
    if not isinstance(restored, dict):
        # Fail with a clear message instead of an AttributeError on `.keys()`
        # when the blob under this key is not a checkpoint namespace.
        raise TypeError(
            f"checkpoint {key!r} did not deserialise to a dict, got: {type(restored).__name__}"
        )
    return restored, RestoreReport(restored_names=list(restored), bytes_read=len(payload))
def compute_checkpoint_key(*, project_id: str, init_source_hash: str, environment_id: str) -> str:
    """Compose a stable, unambiguous checkpoint key from its inputs.

    The format is
    `proj:<project_id>:init:<init_source_hash>:env:<environment_id>` so the
    key is human-readable when listing the snapshot store.

    `:` is the field separator, so a `:` inside a component is escaped as
    `%3A` (and a literal `%` as `%25`, which keeps the escaping reversible).
    Without this, different inputs could compose to the same key — e.g.
    project_id `p:init:x` with hash `h`, and project_id `p` with hash
    `x:init:h`. Components containing neither `:` nor `%` appear verbatim.

    Args:
        project_id: Project the checkpoint belongs to.
        init_source_hash: Hash of the init source.
        environment_id: Identifier of the execution environment.

    Returns:
        The composed key string.

    Raises:
        ValueError: if any component is not a non-empty string.
    """
    components = (
        ("project_id", project_id),
        ("init_source_hash", init_source_hash),
        ("environment_id", environment_id),
    )

    escaped: list[str] = []
    for name, value in components:
        if not isinstance(value, str) or not value:
            raise ValueError(f"{name} must be a non-empty string, got: {value!r}")
        # Escape `%` first so the `%` introduced for `:` is not escaped again.
        escaped.append(value.replace("%", "%25").replace(":", "%3A"))

    project, init_hash, env = escaped
    return f"proj:{project}:init:{init_hash}:env:{env}"
class LocalDiskSnapshotStore:
    """File-system implementation. Each key maps to one file directly under `root`.

    Keys are percent-encoded into filenames, so any separator may be used in
    the key composition and two different keys never share a file. (A lossy
    flattening such as `/` -> `__`, `:` -> `_` would make `a/b` and `a__b`, or
    `a:b` and `a_b`, read and overwrite each other's snapshot.)

    Writes are atomic (temp file in the same directory, then rename) so a
    reader never observes a partially written snapshot.

    The default root is `/tmp/deepnote-checkpoint/`. Production-shaped use
    would point this at a per-project persistent volume, or replace the whole
    class with an S3-backed implementation.
    """

    # Appended to every snapshot filename. It guarantees that keys such as
    # "", "." or ".." still name a regular file inside `root`, and keeps
    # snapshot files distinct from the `*.tmp` files used while writing.
    _SNAPSHOT_SUFFIX = ".snap"

    def __init__(self, root: str | Path = "/tmp/deepnote-checkpoint") -> None:
        """Create the store, creating `root` (and parents) if missing."""
        self._root = Path(root)
        self._root.mkdir(parents=True, exist_ok=True)

    def _path_for(self, key: str) -> Path:
        """Return the file path that holds the snapshot for `key`."""
        # Imported here so the block does not depend on a new module-level import.
        from urllib.parse import quote

        # safe="" also encodes "/", so the result is a single path component;
        # "%" itself is encoded, which makes the mapping one-to-one.
        return self._root / (quote(key, safe="") + self._SNAPSHOT_SUFFIX)

    def read(self, key: str) -> bytes | None:
        """Return the bytes stored under `key`, or None if absent."""
        # Read first and treat "missing" as the exceptional case: an
        # exists()-then-read sequence can race with a concurrent delete.
        try:
            return self._path_for(key).read_bytes()
        except FileNotFoundError:
            return None

    def write(self, key: str, data: bytes) -> None:
        """Write `data` under `key`, atomically replacing any prior value."""
        path = self._path_for(key)
        # Temp file in the same directory, so os.replace() is a same-filesystem
        # rename and a partial file is never visible under the final name.
        fd, tmp_path = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(data)
            os.replace(tmp_path, path)
        except BaseException:
            # Best-effort cleanup of the temp file, then re-raise the original
            # error (including KeyboardInterrupt / SystemExit).
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise
+ +Each test exercises one shape of the snapshot/restore contract: + - basic types survive + - user-defined functions and classes survive (dill handles closures) + - unpicklable items are skipped-not-fatal + - dunders / modules / ipython artifacts are skipped by rule + - key composition is stable, and any input change invalidates the key +""" + +from __future__ import annotations + +import pytest + +from deepnote_toolkit.kernel_checkpoint import ( + LocalDiskSnapshotStore, + compute_checkpoint_key, + save_checkpoint, + try_restore_checkpoint, +) + + +def _has_pandas() -> bool: + try: + import pandas # noqa: F401 + return True + except ImportError: + return False + + +PANDAS_AVAILABLE = _has_pandas() + + +@pytest.fixture +def store(tmp_path) -> LocalDiskSnapshotStore: + return LocalDiskSnapshotStore(root=tmp_path / "snap") + + +def test_roundtrip_primitives(store): + globals_dict = {"x": 1, "y": "hello", "z": True, "w": None, "n": 3.14} + save_checkpoint(globals_dict, store, "k1") + + result = try_restore_checkpoint(store, "k1") + assert result is not None + restored, _ = result + assert restored == {"x": 1, "y": "hello", "z": True, "w": None, "n": 3.14} + + +def test_roundtrip_containers(store): + globals_dict = {"lst": [1, 2, 3], "d": {"a": 1, "b": [True, False]}, "t": (1, 2)} + save_checkpoint(globals_dict, store, "k1") + + restored, _ = try_restore_checkpoint(store, "k1") + assert restored == globals_dict + + +def test_roundtrip_user_function(store): + def adder(x, y): + return x + y + + save_checkpoint({"adder": adder}, store, "k1") + + restored, _ = try_restore_checkpoint(store, "k1") + assert restored["adder"](2, 3) == 5 + + +def test_roundtrip_user_class_instance(store): + class Holder: + def __init__(self, n): + self.n = n + + def doubled(self): + return self.n * 2 + + save_checkpoint({"obj": Holder(21)}, store, "k1") + + restored, _ = try_restore_checkpoint(store, "k1") + assert restored["obj"].doubled() == 42 + + +def 
test_unpicklable_is_skipped_rest_survives(store): + # Generators are one of the few types dill genuinely refuses to serialise + # (their internal frame state isn't recoverable). A real init script might + # leak similar via a third-party library; this proves a single unpicklable + # value doesn't take the whole snapshot down with it. + bad = (x for x in range(3)) + next(bad) + globals_dict = {"good": 42, "bad_gen": bad, "other_good": "ok"} + report = save_checkpoint(globals_dict, store, "k1") + + assert "good" in report.saved_names + assert "other_good" in report.saved_names + assert any(name == "bad_gen" for name, _ in report.skipped_unpicklable) + + restored, _ = try_restore_checkpoint(store, "k1") + assert restored == {"good": 42, "other_good": "ok"} + + +def test_modules_skipped_by_rule(store): + import os as os_module # noqa: PLC0415 — intentional, testing skip rule + + globals_dict = {"good": 1, "os_module": os_module} + report = save_checkpoint(globals_dict, store, "k1") + + assert "good" in report.saved_names + assert ("os_module", "module_object") in report.skipped_by_rule + + restored, _ = try_restore_checkpoint(store, "k1") + assert restored == {"good": 1} + + +def test_dunder_names_skipped_by_rule(store): + globals_dict = {"__name__": "__main__", "__builtins__": {}, "_private": 1, "public": 2} + report = save_checkpoint(globals_dict, store, "k1") + + saved = set(report.saved_names) + assert saved == {"public"} + assert ("__name__", "dunder_or_private") in report.skipped_by_rule + assert ("__builtins__", "dunder_or_private") in report.skipped_by_rule + assert ("_private", "dunder_or_private") in report.skipped_by_rule + + +def test_ipython_artifacts_skipped(store): + globals_dict = {"In": ["cell1"], "Out": {}, "exit": "shouldnt-survive", "real": 1} + report = save_checkpoint(globals_dict, store, "k1") + + assert report.saved_names == ["real"] + for name in ("In", "Out", "exit"): + assert (name, "ipython_artifact") in report.skipped_by_rule + + +def 
test_restore_returns_none_when_key_missing(store): + assert try_restore_checkpoint(store, "never-saved") is None + + +def test_restore_returns_none_for_different_key(store): + save_checkpoint({"x": 1}, store, "key-a") + assert try_restore_checkpoint(store, "key-b") is None + + +def test_report_bytes_written_matches_round_trip(store): + save_report = save_checkpoint({"x": [1, 2, 3, 4, 5]}, store, "k1") + restored, restore_report = try_restore_checkpoint(store, "k1") + assert restored == {"x": [1, 2, 3, 4, 5]} + assert save_report.bytes_written == restore_report.bytes_read + assert save_report.bytes_written > 0 + + +# ----- key composition ----- + + +def test_key_is_stable_for_same_inputs(): + a = compute_checkpoint_key(project_id="p1", init_source_hash="h1", environment_id="e1") + b = compute_checkpoint_key(project_id="p1", init_source_hash="h1", environment_id="e1") + assert a == b + + +def test_key_changes_when_project_changes(): + a = compute_checkpoint_key(project_id="p1", init_source_hash="h1", environment_id="e1") + b = compute_checkpoint_key(project_id="p2", init_source_hash="h1", environment_id="e1") + assert a != b + + +def test_key_changes_when_init_source_hash_changes(): + a = compute_checkpoint_key(project_id="p1", init_source_hash="h1", environment_id="e1") + b = compute_checkpoint_key(project_id="p1", init_source_hash="h2", environment_id="e1") + assert a != b + + +def test_key_changes_when_environment_changes(): + a = compute_checkpoint_key(project_id="p1", init_source_hash="h1", environment_id="e1") + b = compute_checkpoint_key(project_id="p1", init_source_hash="h1", environment_id="e2") + assert a != b + + +def test_key_is_human_readable(): + key = compute_checkpoint_key(project_id="proj-123", init_source_hash="abc", environment_id="py311") + assert "proj-123" in key + assert "abc" in key + assert "py311" in key + + +def test_key_rejects_empty_inputs(): + with pytest.raises(ValueError): + compute_checkpoint_key(project_id="", init_source_hash="h", 
environment_id="e") + with pytest.raises(ValueError): + compute_checkpoint_key(project_id="p", init_source_hash="", environment_id="e") + with pytest.raises(ValueError): + compute_checkpoint_key(project_id="p", init_source_hash="h", environment_id="") + + +# ----- store implementation ----- + + +def test_local_disk_store_atomic_overwrite(tmp_path): + store = LocalDiskSnapshotStore(root=tmp_path / "snap") + store.write("k", b"first") + assert store.read("k") == b"first" + store.write("k", b"second") + assert store.read("k") == b"second" + + +def test_local_disk_store_returns_none_for_missing(tmp_path): + store = LocalDiskSnapshotStore(root=tmp_path / "snap") + assert store.read("never-written") is None + + +def test_local_disk_store_handles_keys_with_colons_and_slashes(tmp_path): + store = LocalDiskSnapshotStore(root=tmp_path / "snap") + key = "proj:abc/def:env:py311" + store.write(key, b"payload") + assert store.read(key) == b"payload" + + +# ----- pandas DataFrame round-trip, only if pandas is available ----- + + +@pytest.mark.skipif(not PANDAS_AVAILABLE, reason="pandas not installed") +def test_roundtrip_pandas_dataframe(store): + import pandas as pd + + df = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]}) + save_checkpoint({"df": df}, store, "k1") + restored, _ = try_restore_checkpoint(store, "k1") + pd.testing.assert_frame_equal(restored["df"], df)