Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 42 additions & 11 deletions src/zarr/storage/_local.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import asyncio
import contextlib
import io
import os
import shutil
import sys
import uuid
from pathlib import Path
from typing import TYPE_CHECKING, Self
from typing import TYPE_CHECKING, BinaryIO, Literal, Self

from zarr.abc.store import (
ByteRequest,
Expand All @@ -19,7 +22,7 @@
from zarr.core.common import AccessModeLiteral, concurrent_map

if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterable
from collections.abc import AsyncIterator, Iterable, Iterator

from zarr.core.buffer import BufferPrototype

Expand All @@ -41,27 +44,55 @@ def _get(path: Path, prototype: BufferPrototype, byte_range: ByteRequest | None)
return prototype.buffer.from_bytes(f.read())


# Choose the "move, but never clobber an existing destination" primitive once,
# at import time, based on the platform.
if sys.platform != "win32":

    def _safe_move(src: Path, dst: Path) -> None:
        """Move ``src`` to ``dst``, raising ``FileExistsError`` if ``dst`` exists.

        On Unix, ``os.rename`` silently replaces an existing destination, so the
        move is done as a hard link followed by removal of the source name, like
        atomicwrites does:
        https://github.com/untitaker/python-atomicwrites/blob/1.4.1/atomicwrites/__init__.py#L59-L60
        ``os.link`` raises ``FileExistsError`` if ``dst`` already exists.
        """
        os.link(src, dst)
        os.unlink(src)

else:
    # Per the os.rename docs:
    # On Windows, if dst exists a FileExistsError is always raised.
    _safe_move = os.rename


@contextlib.contextmanager
def _atomic_write(
    path: Path,
    mode: Literal["r+b", "wb"],
    exclusive: bool = False,
) -> Iterator[BinaryIO]:
    """Write to ``path`` through a temporary sibling file that is moved into place.

    The caller receives a binary file handle for a uniquely named ``.partial``
    file in the same directory as ``path``. The temporary file is moved to
    ``path`` only after the ``with`` body finishes and the handle is closed, so
    a failed write never leaves a half-written ``path`` behind.

    Parameters
    ----------
    path
        Final destination of the data.
    mode
        Mode used to open the temporary file.
        NOTE(review): the temporary path is freshly generated, so ``"r+b"``
        would presumably fail with ``FileNotFoundError`` - confirm whether
        that mode is actually needed.
    exclusive
        If true, the temporary file is moved with ``_safe_move``, which raises
        ``FileExistsError`` when ``path`` already exists. Otherwise ``path`` is
        replaced unconditionally.

    Yields
    ------
    BinaryIO
        Open handle for the temporary file.

    Raises
    ------
    FileExistsError
        If ``exclusive`` is true and ``path`` already exists.
    """
    tmp_path = path.with_suffix(f".{uuid.uuid4().hex}.partial")
    try:
        with tmp_path.open(mode) as f:
            yield f
        # The handle is closed at this point; publish the finished file.
        if exclusive:
            _safe_move(tmp_path, path)
        else:
            tmp_path.replace(path)
    except BaseException:
        # Catch BaseException rather than Exception so that KeyboardInterrupt
        # and task cancellation also remove the temporary file. The exception
        # is always re-raised, so nothing is swallowed.
        tmp_path.unlink(missing_ok=True)
        raise


def _put(
    path: Path,
    value: Buffer,
    start: int | None = None,
    exclusive: bool = False,
) -> int | None:
    """Write the contents of ``value`` to ``path``, creating parent directories.

    Parameters
    ----------
    path
        Destination file.
    value
        Data to write; its ``as_buffer_like()`` view is passed to ``write``.
    start
        If given, the existing file is opened in ``"r+b"`` mode and the data is
        written in place starting at this offset. This branch does not go
        through the temporary file and ignores ``exclusive``.
    exclusive
        For whole-file writes only: passed through to ``_atomic_write`` so the
        write fails with ``FileExistsError`` if ``path`` already exists.

    Returns
    -------
    int | None
        ``None`` for a partial write (``start`` given); otherwise the value
        returned by the file's ``write`` call.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # write takes any object supporting the buffer protocol
    # (Reviewer question on this line: why doesn't Buffer implement the buffer
    # protocol itself?)
    view = value.as_buffer_like()
    if start is not None:
        # Partial update: patch the bytes in place, exactly once.
        with path.open("r+b") as f:
            f.seek(start)
            f.write(view)
        return None
    # Whole-file write: go through the temporary file so readers never see a
    # half-written target. The move happens when the ``with`` block exits,
    # which still runs on the ``return`` below.
    with _atomic_write(path, "wb", exclusive=exclusive) as f:
        return f.write(view)


Expand Down
44 changes: 44 additions & 0 deletions tests/test_store/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from zarr import create_array
from zarr.core.buffer import Buffer, cpu
from zarr.storage import LocalStore
from zarr.storage._local import _atomic_write
from zarr.testing.store import StoreTests
from zarr.testing.utils import assert_bytes_equal

Expand Down Expand Up @@ -109,3 +110,46 @@ async def test_move(
FileExistsError, match=re.escape(f"Destination root {destination} already exists")
):
await store2.move(destination)


@pytest.mark.parametrize("exclusive", [True, False])
def test_atomic_write_successful(tmp_path: pathlib.Path, exclusive: bool) -> None:
    """A completed write lands at the target and leaves no temp file behind."""
    target = tmp_path / "data"
    with _atomic_write(target, "wb", exclusive=exclusive) as handle:
        handle.write(b"abc")
    assert target.read_bytes() == b"abc"
    directory_entries = list(target.parent.iterdir())
    assert directory_entries == [target]  # no temp files


@pytest.mark.parametrize("exclusive", [True, False])
def test_atomic_write_incomplete(tmp_path: pathlib.Path, exclusive: bool) -> None:
    """An error raised inside the block leaves neither the target nor a temp file."""
    target = tmp_path / "data"
    with pytest.raises(RuntimeError):  # noqa: PT012
        with _atomic_write(target, "wb", exclusive=exclusive) as handle:
            handle.write(b"a")
            raise RuntimeError
    assert not target.exists()
    assert not list(target.parent.iterdir())  # no temp files


def test_atomic_write_non_exclusive_preexisting(tmp_path: pathlib.Path) -> None:
    """A non-exclusive write replaces the content of an already existing target."""
    target = tmp_path / "data"
    # Seed the target with the old content.
    target.write_bytes(b"xyz")
    assert target.read_bytes() == b"xyz"
    with _atomic_write(target, "wb", exclusive=False) as handle:
        handle.write(b"abc")
    assert target.read_bytes() == b"abc"
    directory_entries = list(target.parent.iterdir())
    assert directory_entries == [target]  # no temp files


def test_atomic_write_exclusive_preexisting(tmp_path: pathlib.Path) -> None:
    """An exclusive write refuses to touch an existing target and cleans up."""
    target = tmp_path / "data"
    # Seed the target with content that must survive the failed write.
    target.write_bytes(b"xyz")
    assert target.read_bytes() == b"xyz"
    with (
        pytest.raises(FileExistsError),
        _atomic_write(target, "wb", exclusive=True) as handle,
    ):
        handle.write(b"abc")
    assert target.read_bytes() == b"xyz"
    directory_entries = list(target.parent.iterdir())
    assert directory_entries == [target]  # no temp files
Loading