3
3
import numpy as np
4
4
import xarray as xr
5
5
from dask .array .core import _check_regular_chunks
6
+
6
7
# TODOs:
7
8
# - API grouping: should the constructors look similar to the DataArray/Dataset constructors?
8
9
@@ -24,6 +25,9 @@ class DataArraySchema:
24
25
Shape of the DataArray. `None` may be used as a wildcard value. By default None
25
26
dims : Tuple[Union[Hashable, None]], optional
26
27
Dimensions of the DataArray. `None` may be used as a wildcard value. By default None
28
+ chunks : Union[bool, Dict[Hashable, Union[int, None]]], optional
29
+ If bool, specifies whether DataArray is chunked or not, agnostic to chunk sizes.
30
+ If dict, includes the expected chunks for the DataArray, by default None
27
31
name : str, optional
28
32
Name of the DataArray, by default None
29
33
array_type : Any, optional
@@ -37,7 +41,7 @@ def __init__(
37
41
shape : Tuple [Union [int , None ]] = None ,
38
42
dims : Tuple [Union [Hashable , None ]] = None ,
39
43
coords : Dict [Hashable , Any ] = None ,
40
- chunks : Dict [Hashable , Union [int , None ]] = None ,
44
+ chunks : Union [ bool , Dict [Hashable , Union [int , None ] ]] = None ,
41
45
name : str = None ,
42
46
array_type : Any = None ,
43
47
attrs : Dict [Hashable , Any ] = None ,
@@ -74,9 +78,10 @@ def validate(self, da: xr.DataArray) -> xr.DataArray:
74
78
'''
75
79
if not isinstance (da , xr .DataArray ):
76
80
raise ValueError ('Input must be a xarray.DataArray' )
77
-
78
- if da .chunks :
79
- assert _check_regular_chunks (da .chunks ), 'Good gracious no! Chunks are not regular!'
81
+
82
+ if self .chunks is not None :
83
+ if self .chunks is False and da .chunks :
84
+ raise SchemaError ('Schema expected unchunked DataArray but it is chunked!' )
80
85
81
86
if self .dtype is not None and not np .issubdtype (da .dtype , self .dtype ):
82
87
raise SchemaError (f'dtype { da .dtype } != { self .dtype } ' )
@@ -108,22 +113,30 @@ def validate(self, da: xr.DataArray) -> xr.DataArray:
108
113
raise NotImplementedError ('coords schema not implemented yet' )
109
114
110
115
if self .chunks :
111
- dim_chunks = dict (zip (da .dims , da .chunks ))
112
- dim_sizes = dict (zip (da .dims , da .shape ))
113
- for key , ec in self .chunks .items ():
114
- if isinstance (ec , int ):
115
- # handles case of expected chunksize is shorthand of -1 which translates to the full length of dimension
116
- if ec < 0 :
117
- ec = dim_sizes [key ]
118
- # grab the first entry in da's tuple of chunks to be representative (as it should be assuming they're regular)
119
- ac = dim_chunks [key ][0 ]
120
- if ac != ec :
121
- raise SchemaError (f'{ key } chunks did not match: { ac } != { ec } ' )
122
-
123
- else : # assumes ec is an iterable
124
- ac = dim_chunks [key ]
125
- if tuple (ac ) != tuple (ec ):
126
- raise SchemaError (f'{ key } chunks did not match: { ac } != { ec } ' )
116
+ if self .chunks is True :
117
+ if not da .chunks :
118
+ raise SchemaError ('Schema expected DataArray to be chunked but it is not' )
119
+
120
+ else :
121
+ assert type (self .chunks ) == dict , 'Must pass chunks information as dictionary'
122
+ dim_chunks = dict (zip (da .dims , da .chunks ))
123
+ dim_sizes = dict (zip (da .dims , da .shape ))
124
+ # check whether chunk sizes are regular because we assume the first chunk to be representative below
125
+ assert _check_regular_chunks (da .chunks ), 'Good gracious no! Chunks are not regular!'
126
+ for key , ec in self .chunks .items ():
127
+ if isinstance (ec , int ):
128
+ # handle the case where the expected chunk size is the shorthand -1, which translates to the full length of the dimension
129
+ if ec < 0 :
130
+ ec = dim_sizes [key ]
131
+ # grab the first entry in da's tuple of chunks to be representative (since we've checked above that they're regular)
132
+ ac = dim_chunks [key ][0 ]
133
+ if ac != ec :
134
+ raise SchemaError (f'{ key } chunks did not match: { ac } != { ec } ' )
135
+
136
+ else : # assumes ec is an iterable
137
+ ac = dim_chunks [key ]
138
+ if tuple (ac ) != tuple (ec ):
139
+ raise SchemaError (f'{ key } chunks did not match: { ac } != { ec } ' )
127
140
128
141
if self .attrs :
129
142
raise NotImplementedError ('attrs schema not implemented yet' )
@@ -134,7 +147,7 @@ def validate(self, da: xr.DataArray) -> xr.DataArray:
134
147
if self .checks :
135
148
for check in self .checks :
136
149
da = check (da )
137
-
150
+
138
151
return da
139
152
140
153
0 commit comments