2
2
3
3
import numpy as np
4
4
import xarray as xr
5
+ from dask .array .core import _check_regular_chunks
5
6
6
7
# TODOs:
7
8
# - API grouping: should the constructors look similar to the DataArray/Dataset constructors?
@@ -24,6 +25,9 @@ class DataArraySchema:
24
25
Shape of the DataArray. `None` may be used as a wildcard value. By default None
25
26
dims : Tuple[Union[Hashable, None]], optional
26
27
Dimensions of the DataArray. `None` may be used as a wildcard value. By default None
28
+ chunks : Union[bool, Dict[Hashable, Union[int, None]]], optional
29
+ If bool, specifies whether DataArray is chunked or not, agnostic to chunk sizes.
30
+ If dict, maps dimension names to the expected chunks along each dimension of the DataArray. By default None
27
31
name : str, optional
28
32
Name of the DataArray, by default None
29
33
array_type : Any, optional
@@ -37,7 +41,7 @@ def __init__(
37
41
shape : Tuple [Union [int , None ]] = None ,
38
42
dims : Tuple [Union [Hashable , None ]] = None ,
39
43
coords : Dict [Hashable , Any ] = None ,
40
- chunks : Dict [Hashable , Union [int , None ]] = None ,
44
+ chunks : Union [ bool , Dict [Hashable , Union [int , None ] ]] = None ,
41
45
name : str = None ,
42
46
array_type : Any = None ,
43
47
attrs : Dict [Hashable , Any ] = None ,
@@ -72,6 +76,12 @@ def validate(self, da: xr.DataArray) -> xr.DataArray:
72
76
------
73
77
SchemaError
74
78
'''
79
+ if not isinstance (da , xr .DataArray ):
80
+ raise ValueError ('Input must be a xarray.DataArray' )
81
+
82
+ if self .chunks is not None :
83
+ if self .chunks is False and da .chunks :
84
+ raise SchemaError ('Schema expected unchunked DataArray but it is chunked!' )
75
85
76
86
if self .dtype is not None and not np .issubdtype (da .dtype , self .dtype ):
77
87
raise SchemaError (f'dtype { da .dtype } != { self .dtype } ' )
@@ -103,17 +113,30 @@ def validate(self, da: xr.DataArray) -> xr.DataArray:
103
113
raise NotImplementedError ('coords schema not implemented yet' )
104
114
105
115
if self .chunks :
106
- dim_chunks = dict (zip (da .dims , da .chunks ))
107
- for key , ec in self .chunks .items ():
108
- if isinstance (ec , int ):
109
- for ac in dim_chunks [key ][:- 1 ]:
116
+ if self .chunks is True :
117
+ if not da .chunks :
118
+ raise SchemaError ('Schema expected DataArray to be chunked but it is not' )
119
+
120
+ else :
121
+ assert type (self .chunks ) == dict , 'Must pass chunks information as dictionary'
122
+ dim_chunks = dict (zip (da .dims , da .chunks ))
123
+ dim_sizes = dict (zip (da .dims , da .shape ))
124
+ # check whether chunk sizes are regular because we assume the first chunk to be representative below
125
+ assert _check_regular_chunks (da .chunks ), 'Good gracious no! Chunks are not regular!'
126
+ for key , ec in self .chunks .items ():
127
+ if isinstance (ec , int ):
128
# a negative expected chunk size (e.g. the -1 shorthand) stands for the full length of the dimension
129
+ if ec < 0 :
130
+ ec = dim_sizes [key ]
131
+ # grab the first entry in da's tuple of chunks to be representative (since we've checked above that they're regular)
132
+ ac = dim_chunks [key ][0 ]
110
133
if ac != ec :
111
134
raise SchemaError (f'{ key } chunks did not match: { ac } != { ec } ' )
112
135
113
- else : # assumes ec is an iterable
114
- ac = dim_chunks [key ]
115
- if tuple (ac ) != tuple (ec ):
116
- raise SchemaError (f'{ key } chunks did not match: { ac } != { ec } ' )
136
+ else : # assumes ec is an iterable
137
+ ac = dim_chunks [key ]
138
+ if tuple (ac ) != tuple (ec ):
139
+ raise SchemaError (f'{ key } chunks did not match: { ac } != { ec } ' )
117
140
118
141
if self .attrs :
119
142
raise NotImplementedError ('attrs schema not implemented yet' )
0 commit comments