Skip to content

Commit 1d13dae

Browse files
committed
feat: experimental support for Subarray dtypes and backported support for nested Structured dtypes in V2 (#3582, #3583)
1 parent bff778b commit 1d13dae

File tree

8 files changed

+706
-34
lines changed

8 files changed

+706
-34
lines changed

src/zarr/core/dtype/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from zarr.core.dtype.npy.float import Float16, Float32, Float64
2323
from zarr.core.dtype.npy.int import Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64
2424
from zarr.core.dtype.npy.structured import Structured, StructuredJSON_V2, StructuredJSON_V3
25+
from zarr.core.dtype.npy.subarray import Subarray, SubarrayJSON_V3
2526
from zarr.core.dtype.npy.time import (
2627
DateTime64,
2728
DateTime64JSON_V2,
@@ -78,6 +79,8 @@
7879
"Structured",
7980
"StructuredJSON_V2",
8081
"StructuredJSON_V3",
82+
"Subarray",
83+
"SubarrayJSON_V3",
8184
"TBaseDType",
8285
"TBaseScalar",
8386
"TimeDelta64",
@@ -126,6 +129,7 @@
126129
| StringDType
127130
| BytesDType
128131
| Structured
132+
| Subarray
129133
| TimeDType
130134
| VariableLengthBytes
131135
)
@@ -139,6 +143,7 @@
139143
*STRING_DTYPE,
140144
*BYTES_DTYPE,
141145
Structured,
146+
Subarray,
142147
*TIME_DTYPE,
143148
VariableLengthBytes,
144149
)

src/zarr/core/dtype/common.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
# classes can perform a very specific type check.
4949

5050
# This is the JSON representation of a structured dtype in zarr v2
51-
StructuredName_V2 = Sequence["str | StructuredName_V2"]
51+
StructuredName_V2 = Sequence[Sequence["str | StructuredName_V2 | Sequence[int]"]]
5252

5353
# This models the type of the name a dtype might have in zarr v2 array metadata
5454
DTypeName_V2 = StructuredName_V2 | str
@@ -70,23 +70,39 @@ def check_structured_dtype_v2_inner(data: object) -> TypeGuard[StructuredName_V2
7070
A type guard for the inner elements of a structured dtype. This is a recursive check because
7171
the type is itself recursive.
7272
73-
This check ensures that all the elements are 2-element sequences beginning with a string
74-
and ending with either another string or another 2-element sequence beginning with a string and
75-
ending with another instance of that type.
73+
This check ensures that all the elements are either 2-element or 3-element sequences that:
74+
1. Begin with a string (name)
75+
2. Have as their second element either a string (dtype) or another sequence (structured dtype)
76+
3. If they have a third element, it is a sequence representing the shape of the field.
7677
"""
7778
if isinstance(data, (str, Mapping)):
7879
return False
7980
if not isinstance(data, Sequence):
8081
return False
81-
if len(data) != 2:
82+
if len(data) != 2 and len(data) != 3:
8283
return False
83-
if not (isinstance(data[0], str)):
84+
85+
name, dtype = data[0], data[1]
86+
87+
# check name element
88+
if not (isinstance(name, str)):
8489
return False
85-
if isinstance(data[-1], str):
90+
91+
# check shape element
92+
if len(data) == 3:
93+
shape = data[2]
94+
if not isinstance(shape, Sequence):
95+
return False
96+
if not all(isinstance(dim, int) for dim in shape):
97+
return False
98+
99+
# (recursively) check dtype element
100+
if isinstance(dtype, str):
86101
return True
87-
elif isinstance(data[-1], Sequence):
88-
return check_structured_dtype_name_v2(data[-1])
89-
return False
102+
elif isinstance(dtype, Sequence):
103+
return check_structured_dtype_name_v2(dtype)
104+
else:
105+
return False
90106

91107

92108
def check_structured_dtype_name_v2(data: Sequence[object]) -> TypeGuard[StructuredName_V2]:

src/zarr/core/dtype/npy/bytes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ def _check_native_dtype(
606606
Bool
607607
True if the dtype is an instance of np.dtypes.VoidDType with no fields, False otherwise.
608608
"""
609-
return cls.dtype_cls is type(dtype) and dtype.fields is None
609+
return cls.dtype_cls is type(dtype) and dtype.fields is None and dtype.subdtype is None
610610

611611
@classmethod
612612
def from_native_dtype(cls, dtype: TBaseDType) -> Self:

src/zarr/core/dtype/npy/structured.py

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
bytes_to_json,
2323
check_json_str,
2424
)
25+
from zarr.core.dtype.npy.subarray import Subarray
2526
from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType
2627

2728
if TYPE_CHECKING:
@@ -34,9 +35,11 @@ class StructuredJSON_V2(DTypeConfig_V2[StructuredName_V2, None]):
3435
"""
3536
A wrapper around the JSON representation of the ``Structured`` data type in Zarr V2.
3637
37-
The ``name`` field is a sequence of sequences, where each inner sequence has two values:
38-
the field name and the data type name for that field (which could be another sequence).
39-
The data type names are strings, and the object codec ID is always None.
38+
The ``name`` field is a sequence of sequences, where each inner sequence has 2 or 3 values:
39+
- First value: field name
40+
- Second value: data type name (which could be another sequence for nested structured dtypes)
41+
- Third value (optional): shape of the field (for subarray dtypes)
42+
The object codec ID is always None.
4043
4144
References
4245
----------
@@ -49,7 +52,7 @@ class StructuredJSON_V2(DTypeConfig_V2[StructuredName_V2, None]):
4952
{
5053
"name": [
5154
["f0", "<m8[10s]"],
52-
["f1", "<m8[10s]"],
55+
["f1", "int32", [2, 2]],
5356
],
5457
"object_codec_id": None
5558
}
@@ -252,17 +255,33 @@ def _from_json_v2(cls, data: DTypeJSON) -> Self:
252255
# structured dtypes are constructed directly from a list of lists
253256
# note that we do not handle the object codec here! this will prevent structured
254257
# dtypes from containing object dtypes.
255-
return cls(
256-
fields=tuple( # type: ignore[misc]
257-
( # type: ignore[misc]
258-
f_name,
259-
get_data_type_from_json(
260-
{"name": f_dtype, "object_codec_id": None}, zarr_format=2
261-
),
262-
)
263-
for f_name, f_dtype in data["name"]
258+
fields = []
259+
name = data["name"]
260+
for tpl in name:
261+
f_name = tpl[0]
262+
if not isinstance(f_name, str):
263+
msg = f"Invalid field name. Got {f_name!r}, expected a string."
264+
raise DataTypeValidationError(msg)
265+
266+
f_dtype = tpl[1]
267+
subdtype = get_data_type_from_json(
268+
{"name": f_dtype, "object_codec_id": None}, zarr_format=2
264269
)
265-
)
270+
271+
if len(tpl) == 3:
272+
f_shape = cast("tuple[int]", tuple(tpl[2]))
273+
if not all(isinstance(dim, int) for dim in f_shape):
274+
msg = f"Invalid shape for field {f_name!r}. Got {f_shape!r}, expected a sequence of integers."
275+
raise DataTypeValidationError(msg)
276+
subdtype = Subarray(
277+
subdtype=subdtype,
278+
shape=f_shape,
279+
)
280+
281+
fields.append((f_name, subdtype))
282+
283+
return cls(fields=tuple(fields))
284+
266285
msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected a JSON array of arrays"
267286
raise DataTypeValidationError(msg)
268287

@@ -309,11 +328,23 @@ def to_json(self, zarr_format: ZarrFormat) -> StructuredJSON_V2 | StructuredJSON
309328
If the zarr_format is not 2 or 3.
310329
"""
311330
if zarr_format == 2:
312-
fields = [
313-
[f_name, f_dtype.to_json(zarr_format=zarr_format)["name"]]
314-
for f_name, f_dtype in self.fields
315-
]
316-
return {"name": fields, "object_codec_id": None}
331+
fields = []
332+
for f_name, f_dtype in self.fields:
333+
if isinstance(f_dtype, Subarray):
334+
fields.append(
335+
[
336+
f_name,
337+
f_dtype.subdtype.to_json(zarr_format=zarr_format)["name"],
338+
list(f_dtype.shape),
339+
]
340+
)
341+
else:
342+
fields.append([f_name, f_dtype.to_json(zarr_format=zarr_format)["name"]])
343+
dct = {
344+
"name": fields,
345+
"object_codec_id": None,
346+
}
347+
return cast("StructuredJSON_V2", dct)
317348
elif zarr_format == 3:
318349
v3_unstable_dtype_warning(self)
319350
fields = [
@@ -415,7 +446,6 @@ def default_scalar(self) -> np.void:
415446
The default scalar value, which is the scalar representation of 0
416447
cast to this structured data type.
417448
"""
418-
419449
return self._cast_scalar_unchecked(0)
420450

421451
def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> np.void:

0 commit comments

Comments
 (0)