Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions deepnote_toolkit/kernel_checkpoint/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Kernel-state snapshot / restore for cold-start reduction.

The dominant cold-start cost in Deepnote today is re-running the init notebook
on every fresh container boot. For projects with heavy init (data loading,
client setup), this can be tens of seconds to minutes per cold start.

This module captures the kernel's globals after init completes and restores
them on the next cold boot, skipping the init re-run entirely.

# Design surface

Three layers, each with its own file:

- `store`: a `SnapshotStore` protocol with a `LocalDiskSnapshotStore` impl
for the POC. The S3 production impl will be a sibling later.
- `checkpoint`: `save_checkpoint` and `try_restore_checkpoint` — the actual
serialise/deserialise logic. Built on `dill` (already a deepnote-toolkit
dep) to handle closures and most user-defined types.
- `key`: `compute_checkpoint_key` — the stable cache-key composition. Any
input change here (init source, environment) must invalidate the cache.

A `__main__` CLI exists for end-to-end testing without touching production
code paths.

# Correctness invariant

A restore is correct iff the restored namespace produces the same downstream
execution behaviour as freshly re-running init in the current environment.

The checkpoint key includes the init source hash and environment id so that
source/env changes always invalidate the cache. What the key does NOT
capture (limitations the production PR must address):

- filesystem state in `/work` — if init read a file that has since
changed, restore returns stale data
- external connections (db, http clients) — they are unpicklable and get
skipped on save; restored kernel will need to re-establish them
- cross-Python-version restore — dill bytes from 3.11 may not load on
3.12; the env id is expected to discriminate Python versions

# What is NOT in this PR

- S3 backend (LocalDiskSnapshotStore only)
- Wiring into `runtime/executor.py` init lifecycle
- Webapp signaling
- Per-workspace feature flag

See the [POC plan](~/.claude/plans/snapshot-restore-poc.md) for the
productionisation path.
"""

from deepnote_toolkit.kernel_checkpoint.checkpoint import (
RestoreReport,
SaveReport,
save_checkpoint,
try_restore_checkpoint,
)
from deepnote_toolkit.kernel_checkpoint.key import compute_checkpoint_key
from deepnote_toolkit.kernel_checkpoint.store import (
LocalDiskSnapshotStore,
SnapshotStore,
)

__all__ = [
"LocalDiskSnapshotStore",
"RestoreReport",
"SaveReport",
"SnapshotStore",
"compute_checkpoint_key",
"save_checkpoint",
"try_restore_checkpoint",
]
86 changes: 86 additions & 0 deletions deepnote_toolkit/kernel_checkpoint/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""End-to-end CLI for the kernel-checkpoint POC.

Two commands so the round-trip is testable without touching production code:

$ python -m deepnote_toolkit.kernel_checkpoint save <key> <python_file>
$ python -m deepnote_toolkit.kernel_checkpoint restore <key>

The `save` command execs the python file in a fresh namespace, then snapshots
the resulting globals. `restore` reads the snapshot back into a fresh
namespace and prints the restored names.

This is the proof-of-life that productionisation in `runtime/executor.py`
will be a wiring change, not a design change.
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path

from deepnote_toolkit.kernel_checkpoint.checkpoint import (
save_checkpoint,
try_restore_checkpoint,
)
from deepnote_toolkit.kernel_checkpoint.store import LocalDiskSnapshotStore


def _cmd_save(args: argparse.Namespace) -> int:
source = Path(args.python_file).read_text()
namespace: dict[str, object] = {}
exec(compile(source, args.python_file, "exec"), namespace)
store = LocalDiskSnapshotStore(args.root)
report = save_checkpoint(namespace, store, args.key)
print(f"[checkpoint] saved {len(report.saved_names)} globals to {store._path_for(args.key)}")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win

Avoid calling a private store method from CLI.

Line 35 reaches into store._path_for(...). Expose a public helper (or return path info from write) so CLI does not depend on private internals.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepnote_toolkit/kernel_checkpoint/__main__.py` at line 35, The CLI currently
calls the private method store._path_for(...) in
deepnote_toolkit/kernel_checkpoint/__main__.py to build the message; instead
expose a public API (either make store.write return the final path or add a
public method like store.path_for/get_path) and update the CLI to use that
public result instead of _path_for; ensure the change returns the same path used
when writing so the print line (which references report.saved_names and
args.key) can display the path without touching private store internals.

print(f" bytes_written: {report.bytes_written}")
if report.skipped_unpicklable:
print(" skipped (unpicklable):")
for name, kind in report.skipped_unpicklable:
print(f" - {name}: {kind}")
if report.skipped_by_rule:
print(" skipped (by rule):")
for name, reason in report.skipped_by_rule:
print(f" - {name}: {reason}")
print(" saved:")
for name in report.saved_names:
print(f" - {name}")
return 0


def _cmd_restore(args: argparse.Namespace) -> int:
store = LocalDiskSnapshotStore(args.root)
result = try_restore_checkpoint(store, args.key)
if result is None:
print(f"[checkpoint] no snapshot found for key: {args.key}")
return 1
globals_dict, report = result
print(f"[checkpoint] restored {len(report.restored_names)} globals ({report.bytes_read} bytes):")
for name in report.restored_names:
value = globals_dict[name]
print(f" {name}: {type(value).__name__}")
return 0


def main(argv: list[str] | None = None) -> int:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

repo_root="."
file="deepnote_toolkit/kernel_checkpoint/__main__.py"

echo "=== File exists? ==="
ls -l "$file" || true

echo "=== Showing around line 55-75 with numbers ==="
nl -ba "$file" | sed -n '50,85p'

Repository: deepnote/deepnote-toolkit

Length of output: 295


🏁 Script executed:

#!/bin/bash
set -euo pipefail

file="deepnote_toolkit/kernel_checkpoint/__main__.py"
echo "=== File exists? ==="
ls -l "$file" || true

echo "=== Showing around line 55-75 with numbers ==="
nl -ba "$file" | sed -n '50,85p'

echo "=== Search for 'Optional' usage in that file ==="
rg -n "Optional\[" "$file" || true

echo "=== Search for other PEP604 unions ' | None' in that file ==="
rg -n "\|\s*None" "$file" || true

Repository: deepnote/deepnote-toolkit

Length of output: 294


🏁 Script executed:

#!/bin/bash
set -euo pipefail

file="deepnote_toolkit/kernel_checkpoint/__main__.py"
echo "=== File exists? ==="
ls -l "$file" || true

echo "=== Showing around line 55-75 with numbers ==="
nl -ba "$file" | sed -n '50,85p'

echo "=== Imports in that file (top 40 lines) ==="
nl -ba "$file" | sed -n '1,45p'

echo "=== Search for 'Optional[' usage in that file ==="
rg -n "Optional\[" "$file" || true

echo "=== Search for PEP604 unions '| None' in that file ==="
rg -n "\|\s*None" "$file" || true

Repository: deepnote/deepnote-toolkit

Length of output: 294


🏁 Script executed:

#!/bin/bash
set -euo pipefail

file="deepnote_toolkit/kernel_checkpoint/__main__.py"
echo "=== File exists? ==="
ls -l "$file" || true

echo "=== Showing around line 55-75 with numbers ==="
nl -ba "$file" | sed -n '55,75p'

echo "=== Imports in that file (top 60 lines) ==="
nl -ba "$file" | sed -n '1,60p'

echo "=== Search for typing.Optional usage in repo (small scope) ==="
rg -n "from typing import Optional|Optional\[" deepnote_toolkit | head -n 20

echo "=== Search for PEP604 '| None' usage in same module area (deepnote_toolkit/kernel_checkpoint) ==="
rg -n "\|\s*None" deepnote_toolkit/kernel_checkpoint | head -n 20

Repository: deepnote/deepnote-toolkit

Length of output: 294


🏁 Script executed:

#!/bin/bash
set -euo pipefail

file="deepnote_toolkit/kernel_checkpoint/__main__.py"
echo "=== Line 1-120 with numbers (awk) ==="
awk 'NR>=1 && NR<=120 {printf "%d\t%s\n", NR, $0}' "$file"

echo "=== Exact main() signature line containing 'def main' ==="
rg -n "def main\(" "$file" || true

echo "=== Search for 'Optional[' usage in that file ==="
rg -n "Optional\[" "$file" || true

echo "=== Search for PEP604 '| None' usage in that file ==="
rg -n "\|\s*None" "$file" || true

echo "=== Imports in that file ==="
rg -n "^(from typing import|import typing|from __future__ import|from typing_extensions import)" "$file" || true

Repository: deepnote/deepnote-toolkit

Length of output: 3797


🏁 Script executed:

#!/bin/bash
set -euo pipefail

rg -n "Optional\[" deepnote_toolkit | head -n 20
echo "----"
rg -n "\|\s*None" deepnote_toolkit | head -n 20
echo "----"
echo "Counts:"
python3 - <<'PY'
import re, subprocess, sys
def count(pattern):
    p=subprocess.run(["bash","-lc",f"rg -n {pattern!r} deepnote_toolkit | wc -l"],capture_output=True,text=True)
    return int(p.stdout.strip() or "0")
# basic counts
opt=count("Optional\\[")
union=count("\\|\\s*None")
print("Optional[ count:",opt)
print("| None count:",union)
PY

Repository: deepnote/deepnote-toolkit

Length of output: 2441


Use Optional[list[str]] for the nullable argv param (or align the whole module)
deepnote_toolkit/kernel_checkpoint/__main__.py:65 uses list[str] | None; this kernel_checkpoint package already uses PEP 604 (... | None) in store.py/checkpoint.py, so either switch to Optional[...] consistently there (incl. this function) or allow | None as the repo convention.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepnote_toolkit/kernel_checkpoint/__main__.py` at line 65, The function
signature for main uses the union syntax "list[str] | None" which is
inconsistent with the module's preferred use of typing.Optional; update the main
function signature to use Optional[list[str]] instead, add "from typing import
Optional" to the imports if missing, and ensure any other type hints in this
module follow the same Optional[...] convention for consistency with store.py
and checkpoint.py.

parser = argparse.ArgumentParser(prog="python -m deepnote_toolkit.kernel_checkpoint")
parser.add_argument(
"--root", default="/tmp/deepnote-checkpoint", help="Local-disk snapshot store root."
)
sub = parser.add_subparsers(dest="cmd", required=True)

save = sub.add_parser("save", help="Exec a python file and snapshot the resulting globals.")
save.add_argument("key")
save.add_argument("python_file")
save.set_defaults(func=_cmd_save)

restore = sub.add_parser("restore", help="Restore a snapshot into a fresh namespace and print names.")
restore.add_argument("key")
restore.set_defaults(func=_cmd_restore)

args = parser.parse_args(argv)
return args.func(args)


if __name__ == "__main__":
sys.exit(main())
127 changes: 127 additions & 0 deletions deepnote_toolkit/kernel_checkpoint/checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""Serialise / deserialise a Python globals namespace via dill.

# Save semantics

`save_checkpoint(globals_dict, store, key)` iterates the globals one-by-one
and tries to serialise each. On any individual failure (unpicklable file
handle, db connection, thread, etc.) the offending name is logged-and-
skipped; the rest of the snapshot is preserved. This is intentional — a
single unpicklable variable should not torch the entire init state.

Items skipped unconditionally:
- dunders (`__name__`, `__builtins__`, ...) — kernel-bootstrap state we
don't want to overwrite on restore
- module objects — restoring them risks shadowing the freshly-booted
kernel's own imports; the user code re-imports as needed
- IPython / ipykernel artifacts — same risk as modules

# Restore semantics

`try_restore_checkpoint(store, key)` returns a dict (the restored globals)
on a successful read, or None if the key is absent. The caller is
responsible for *merging* the restored dict into the target namespace —
this lets the kernel's own bootstrap state survive.

Errors during dill load are NOT caught: a corrupt snapshot is a real
problem and should surface, not silently fail-open to running init.
"""

from __future__ import annotations

import logging
import types
from dataclasses import dataclass, field
from typing import Any

import dill

from deepnote_toolkit.kernel_checkpoint.store import SnapshotStore

logger = logging.getLogger(__name__)


@dataclass
class SaveReport:
"""Per-checkpoint save outcome — useful for tests and prod metrics."""

saved_names: list[str] = field(default_factory=list)
skipped_unpicklable: list[tuple[str, str]] = field(default_factory=list)
skipped_by_rule: list[tuple[str, str]] = field(default_factory=list)
bytes_written: int = 0


@dataclass
class RestoreReport:
"""Per-checkpoint restore outcome — useful for tests and prod metrics."""

restored_names: list[str]
bytes_read: int


_ALWAYS_SKIP_PREFIXES = ("_",)
_ALWAYS_SKIP_NAMES = {
"In",
"Out",
"exit",
"quit",
"get_ipython",
}


def _should_skip_name(name: str) -> str | None:
"""Return a reason string when the name must be skipped, else None."""
if name in _ALWAYS_SKIP_NAMES:
return "ipython_artifact"
if any(name.startswith(prefix) for prefix in _ALWAYS_SKIP_PREFIXES):
return "dunder_or_private"
return None


def _should_skip_value(value: Any) -> str | None:
"""Return a reason string when the value type must be skipped, else None."""
if isinstance(value, types.ModuleType):
return "module_object"
return None


def save_checkpoint(globals_dict: dict[str, Any], store: SnapshotStore, key: str) -> SaveReport:
"""Snapshot the named globals to `store` under `key`.

Returns a SaveReport so the caller can log/metric the result.
"""
report = SaveReport()
saved: dict[str, Any] = {}

for name, value in globals_dict.items():
rule_reason = _should_skip_name(name) or _should_skip_value(value)
if rule_reason:
report.skipped_by_rule.append((name, rule_reason))
continue
try:
# Round-trip-test individually so an unpicklable value doesn't
# corrupt the entire dict's dump later.
dill.dumps(value)
except Exception as exc: # noqa: BLE001 — broad on purpose; dill raises many types
report.skipped_unpicklable.append((name, type(exc).__name__))
logger.info("[checkpoint] skipping unpicklable %s: %s", name, exc)
continue
saved[name] = value
report.saved_names.append(name)

payload = dill.dumps(saved)
store.write(key, payload)
report.bytes_written = len(payload)
return report


def try_restore_checkpoint(store: SnapshotStore, key: str) -> tuple[dict[str, Any], RestoreReport] | None:
"""Attempt to restore from `store`. Returns (globals, report) or None if absent.

Callers MERGE the returned dict into their target namespace — this module
intentionally doesn't mutate any caller state.
"""
payload = store.read(key)
if payload is None:
return None
restored: dict[str, Any] = dill.loads(payload)
return restored, RestoreReport(restored_names=list(restored.keys()), bytes_read=len(payload))
Comment on lines +123 to +127
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify whether checkpoint payload integrity/authenticity is implemented anywhere.
rg -n -C3 'dill\.loads|dill\.dumps|hmac|signature|sign|verify|sha256|blake2|cryptography|fernet' deepnote_toolkit

Repository: deepnote/deepnote-toolkit

Length of output: 15277


Authenticate checkpoint payloads before dill.loadsdeepnote_toolkit/kernel_checkpoint/checkpoint.py (lines 123-127) restores state via dill.loads(payload) directly from store.read(key). If the stored snapshot can be modified, this enables arbitrary code execution. Add integrity/authentication (e.g., signed payloads or an AEAD envelope) and fail closed before deserialization.

🧰 Tools
🪛 Ruff (0.15.14)

[error] 126-126: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue

(S301)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepnote_toolkit/kernel_checkpoint/checkpoint.py` around lines 123 - 127, The
code currently calls dill.loads(payload) on untrusted data read via
store.read(key); instead add integrity/authentication verification before
deserialization: expect an authenticated envelope (e.g., AEAD ciphertext or a
signed payload with signature field) and verify it using your project crypto
utilities (or add HMAC/verify using a configured secret/public key) before
calling dill.loads; if verification fails, return None or raise and do not call
dill.loads. Update checkpoint.restore logic around payload, dill.loads, and
RestoreReport to parse/verify the envelope, decrypt/unwrap or reject invalid
payloads, and only then set restored and construct RestoreReport.

32 changes: 32 additions & 0 deletions deepnote_toolkit/kernel_checkpoint/key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Stable cache-key composition for kernel checkpoints.

The key MUST change whenever any input that influences post-init kernel state
changes. Today that's:

- project_id — checkpoints are project-scoped
- init_source_hash — re-running an updated init produces different state
- environment_id — pandas/numpy version bumps can change in-memory shapes

Production-shaped: a deterministic composition with all inputs visible in the
key string so debugging stale-restore incidents is one `cat` away.
"""

from __future__ import annotations


def compute_checkpoint_key(*, project_id: str, init_source_hash: str, environment_id: str) -> str:
"""Compose a stable checkpoint key from its inputs.

Each component is required and must be a non-empty string. The format is
`proj:<project_id>:init:<init_source_hash>:env:<environment_id>` so the
key is human-readable when listing the snapshot store.
"""
for name, value in (
("project_id", project_id),
("init_source_hash", init_source_hash),
("environment_id", environment_id),
):
if not isinstance(value, str) or not value:
raise ValueError(f"{name} must be a non-empty string, got: {value!r}")

return f"proj:{project_id}:init:{init_source_hash}:env:{environment_id}"
68 changes: 68 additions & 0 deletions deepnote_toolkit/kernel_checkpoint/store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Snapshot storage backend interface and local-disk implementation.

The S3 production backend will be a sibling implementation. Callers must use
the `SnapshotStore` protocol so swapping is a one-line change.
"""

from __future__ import annotations

import os
import tempfile
from pathlib import Path
from typing import Protocol


class SnapshotStore(Protocol):
"""Read/write opaque byte blobs by string key."""

def read(self, key: str) -> bytes | None:
"""Return the bytes for `key`, or None if absent."""
...

def write(self, key: str, data: bytes) -> None:
"""Write `data` under `key`, overwriting any prior value."""
...


class LocalDiskSnapshotStore:
"""File-system implementation. Keys are mapped to files under `root`.

Atomic write (temp + rename) so a crash mid-write doesn't leave a corrupt
snapshot that would later deserialise to garbage. Forward-slashes in keys
are encoded so the key composition is free to use any separator.

The default root is `/tmp/deepnote-checkpoint/`. Production-shaped use
would point this at a per-project persistent volume, or replace the whole
class with an S3-backed implementation.
"""

def __init__(self, root: str | Path = "/tmp/deepnote-checkpoint") -> None:
self._root = Path(root)
self._root.mkdir(parents=True, exist_ok=True)

def _path_for(self, key: str) -> Path:
# Forward slashes and colons would create unwanted directories; flatten.
safe = key.replace("/", "__").replace(":", "_")
return self._root / safe
Comment on lines +43 to +46
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Unsafe key-to-path mapping can collide and traverse.

Line 45’s replacement scheme is non-injective ("a/b" and "a__b" collide) and still allows ".." as a path component. That can overwrite the wrong snapshot or write outside root.

Suggested fix
+import hashlib
+
     def _path_for(self, key: str) -> Path:
-        # Forward slashes and colons would create unwanted directories; flatten.
-        safe = key.replace("/", "__").replace(":", "_")
-        return self._root / safe
+        digest = hashlib.sha256(key.encode("utf-8")).hexdigest()
+        return self._root / f"{digest}.snapshot"
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _path_for(self, key: str) -> Path:
# Forward slashes and colons would create unwanted directories; flatten.
safe = key.replace("/", "__").replace(":", "_")
return self._root / safe
def _path_for(self, key: str) -> Path:
digest = hashlib.sha256(key.encode("utf-8")).hexdigest()
return self._root / f"{digest}.snapshot"
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepnote_toolkit/kernel_checkpoint/store.py` around lines 43 - 46, The
_path_for method currently maps keys to filenames with simple replace, which is
non-injective and allows path-traversal; replace it by producing a
deterministic, filesystem-safe filename derived from a strong hash of the key
(e.g., sha256 hex) and optionally a short sanitized/truncated human-readable
suffix, ensure the filename contains no path separators or "..", join with
self._root, then resolve and assert the resolved path is inside self._root (use
Path.resolve() and Path.is_relative_to or compare parents) before returning to
prevent escapes; update references to _path_for accordingly.


def read(self, key: str) -> bytes | None:
path = self._path_for(key)
if not path.exists():
return None
return path.read_bytes()

def write(self, key: str, data: bytes) -> None:
path = self._path_for(key)
# Atomic write: temp file in the same dir, then rename. Avoids a partial
# file being readable as a "successful" snapshot.
fd, tmp_path = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
try:
with os.fdopen(fd, "wb") as f:
f.write(data)
os.replace(tmp_path, path)
except BaseException:
try:
os.unlink(tmp_path)
except OSError:
pass
raise
Loading
Loading