Skip to content

Commit 7bb9c53

Browse files
committed
add mode argument to zarr.save
1 parent 8a33df7 commit 7bb9c53

File tree

3 files changed

+54
-5
lines changed

3 files changed

+54
-5
lines changed

src/zarr/api/asynchronous.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ async def save(
332332
*args: NDArrayLike,
333333
zarr_version: ZarrFormat | None = None, # deprecated
334334
zarr_format: ZarrFormat | None = None,
335+
mode: AccessModeLiteral | None = None,
335336
path: str | None = None,
336337
**kwargs: Any, # TODO: type kwargs as valid args to save
337338
) -> None:
@@ -345,19 +346,31 @@ async def save(
345346
NumPy arrays with data to save.
346347
zarr_format : {2, 3, None}, optional
347348
The zarr format to use when saving.
349+
mode: {'r', 'r+', 'a', 'w', 'w-'}, optional
350+
Persistence mode: 'r' means read only (must exist); 'r+' means
351+
read/write (must exist); 'a' means read/write (create if doesn't
352+
exist); 'w' means create (overwrite if exists); 'w-' means create
353+
(fail if exists).
348354
path : str or None, optional
349355
The path within the group where the arrays will be saved.
350356
**kwargs
351357
NumPy arrays with data to save.
352358
"""
353359
zarr_format = _handle_zarr_version_or_format(zarr_version=zarr_version, zarr_format=zarr_format)
354360

361+
for arg in args:
362+
if not isinstance(arg, np.ndarray):
363+
raise TypeError("All arguments must be numpy arrays")
364+
for k, v in kwargs.items():
365+
if not isinstance(v, np.ndarray):
366+
raise TypeError(f"Keyword argument '{k}' must be a numpy array")
367+
355368
if len(args) == 0 and len(kwargs) == 0:
356369
raise ValueError("at least one array must be provided")
357370
if len(args) == 1 and len(kwargs) == 0:
358-
await save_array(store, args[0], zarr_format=zarr_format, path=path)
371+
await save_array(store, args[0], zarr_format=zarr_format, mode=mode, path=path)
359372
else:
360-
await save_group(store, *args, zarr_format=zarr_format, path=path, **kwargs)
373+
await save_group(store, *args, zarr_format=zarr_format, mode=mode, path=path, **kwargs)
361374

362375

363376
async def save_array(
@@ -366,6 +379,7 @@ async def save_array(
366379
*,
367380
zarr_version: ZarrFormat | None = None, # deprecated
368381
zarr_format: ZarrFormat | None = None,
382+
mode: AccessModeLiteral | None = None,
369383
path: str | None = None,
370384
storage_options: dict[str, Any] | None = None,
371385
**kwargs: Any, # TODO: type kwargs as valid args to create
@@ -381,6 +395,11 @@ async def save_array(
381395
NumPy array with data to save.
382396
zarr_format : {2, 3, None}, optional
383397
The zarr format to use when saving.
398+
mode: {'r', 'r+', 'a', 'w', 'w-'}, optional
399+
Persistence mode: 'r' means read only (must exist); 'r+' means
400+
read/write (must exist); 'a' means read/write (create if doesn't
401+
exist); 'w' means create (overwrite if exists); 'w-' means create
402+
(fail if exists).
384403
path : str or None, optional
385404
The path within the store where the array will be saved.
386405
storage_options : dict
@@ -394,7 +413,6 @@ async def save_array(
394413
or _default_zarr_version()
395414
)
396415

397-
mode = kwargs.pop("mode", None)
398416
store_path = await make_store_path(store, path=path, mode=mode, storage_options=storage_options)
399417
new = await AsyncArray.create(
400418
store_path,
@@ -412,6 +430,7 @@ async def save_group(
412430
*args: NDArrayLike,
413431
zarr_version: ZarrFormat | None = None, # deprecated
414432
zarr_format: ZarrFormat | None = None,
433+
mode: AccessModeLiteral | None = None,
415434
path: str | None = None,
416435
storage_options: dict[str, Any] | None = None,
417436
**kwargs: NDArrayLike,
@@ -427,6 +446,11 @@ async def save_group(
427446
NumPy arrays with data to save.
428447
zarr_format : {2, 3, None}, optional
429448
The zarr format to use when saving.
449+
mode: {'r', 'r+', 'a', 'w', 'w-'}, optional
450+
Persistence mode: 'r' means read only (must exist); 'r+' means
451+
read/write (must exist); 'a' means read/write (create if doesn't
452+
exist); 'w' means create (overwrite if exists); 'w-' means create
453+
(fail if exists).
430454
path : str or None, optional
431455
Path within the store where the group will be saved.
432456
storage_options : dict
@@ -452,6 +476,7 @@ async def save_group(
452476
store,
453477
arr,
454478
zarr_format=zarr_format,
479+
mode=mode,
455480
path=f"{path}/arr_{i}",
456481
storage_options=storage_options,
457482
)
@@ -460,7 +485,12 @@ async def save_group(
460485
_path = f"{path}/{k}" if path is not None else k
461486
aws.append(
462487
save_array(
463-
store, arr, zarr_format=zarr_format, path=_path, storage_options=storage_options
488+
store,
489+
arr,
490+
zarr_format=zarr_format,
491+
mode=mode,
492+
path=_path,
493+
storage_options=storage_options,
464494
)
465495
)
466496
await asyncio.gather(*aws)

src/zarr/api/synchronous.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,19 @@ def save(
101101
*args: NDArrayLike,
102102
zarr_version: ZarrFormat | None = None, # deprecated
103103
zarr_format: ZarrFormat | None = None,
104+
mode: AccessModeLiteral | None = None,
104105
path: str | None = None,
105106
**kwargs: Any, # TODO: type kwargs as valid args to async_api.save
106107
) -> None:
107108
return sync(
108109
async_api.save(
109-
store, *args, zarr_version=zarr_version, zarr_format=zarr_format, path=path, **kwargs
110+
store,
111+
*args,
112+
zarr_version=zarr_version,
113+
zarr_format=zarr_format,
114+
mode=mode,
115+
path=path,
116+
**kwargs,
110117
)
111118
)
112119

@@ -118,6 +125,7 @@ def save_array(
118125
*,
119126
zarr_version: ZarrFormat | None = None, # deprecated
120127
zarr_format: ZarrFormat | None = None,
128+
mode: AccessModeLiteral | None = None,
121129
path: str | None = None,
122130
**kwargs: Any, # TODO: type kwargs as valid args to async_api.save_array
123131
) -> None:
@@ -127,6 +135,7 @@ def save_array(
127135
arr=arr,
128136
zarr_version=zarr_version,
129137
zarr_format=zarr_format,
138+
mode=mode,
130139
path=path,
131140
**kwargs,
132141
)
@@ -138,6 +147,7 @@ def save_group(
138147
*args: NDArrayLike,
139148
zarr_version: ZarrFormat | None = None, # deprecated
140149
zarr_format: ZarrFormat | None = None,
150+
mode: AccessModeLiteral | None = None,
141151
path: str | None = None,
142152
storage_options: dict[str, Any] | None = None,
143153
**kwargs: NDArrayLike,
@@ -148,6 +158,7 @@ def save_group(
148158
*args,
149159
zarr_version=zarr_version,
150160
zarr_format=zarr_format,
161+
mode=mode,
151162
path=path,
152163
storage_options=storage_options,
153164
**kwargs,

tests/test_api.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
)
2424
from zarr.core.common import MemoryOrder, ZarrFormat
2525
from zarr.errors import MetadataValidationError
26+
from zarr.storage import StorePath
2627
from zarr.storage._utils import normalize_path
2728
from zarr.storage.memory import MemoryStore
2829

@@ -999,3 +1000,10 @@ async def test_metadata_validation_error() -> None:
9991000
match="Invalid value for 'zarr_format'. Expected '2, 3, or None'. Got '3.0'.",
10001001
):
10011002
await zarr.api.asynchronous.open_array(shape=(1,), zarr_format="3.0") # type: ignore[arg-type]
1003+
1004+
1005+
@pytest.mark.parametrize("store", ["local"], indirect=["store"])
1006+
def test_zarr_save(store: Store) -> None:
1007+
a = np.arange(1000).reshape(10, 10, 10)
1008+
zarr.save(StorePath(store), a, mode="w")
1009+
assert_array_equal(zarr.load(store), a)

0 commit comments

Comments
 (0)