@@ -101,25 +101,34 @@ def assert_equal(a, b, tolerance=None):
101
101
else :
102
102
tolerance = {}
103
103
104
- if has_dask and isinstance (a , dask_array_type ) or isinstance (b , dask_array_type ):
105
- # sometimes it's nice to see values and shapes
106
- # rather than being dropped into some file in dask
107
- np .testing .assert_allclose (a , b , ** tolerance )
108
- # does some validation of the dask graph
109
- da .utils .assert_eq (a , b , equal_nan = True )
104
+ # Always run the numpy comparison first, so that we get nice error messages with dask.
105
+ # sometimes it's nice to see values and shapes
106
+ # rather than being dropped into some file in dask
107
+ if a .dtype != b .dtype :
108
+ raise AssertionError (f"a and b have different dtypes: (a: { a .dtype } , b: { b .dtype } )" )
109
+
110
+ if has_dask :
111
+ a_eager = a .compute () if isinstance (a , dask_array_type ) else a
112
+ b_eager = b .compute () if isinstance (b , dask_array_type ) else b
110
113
else :
111
- if a .dtype != b .dtype :
112
- raise AssertionError (f"a and b have different dtypes: (a: { a .dtype } , b: { b .dtype } )" )
114
+ a_eager , b_eager = a , b
113
115
114
- np .testing .assert_allclose (a , b , equal_nan = True , ** tolerance )
116
+ if a .dtype .kind in "SUMm" :
117
+ np .testing .assert_equal (a_eager , b_eager )
118
+ else :
119
+ np .testing .assert_allclose (a_eager , b_eager , equal_nan = True , ** tolerance )
120
+
121
+ if has_dask and isinstance (a , dask_array_type ) or isinstance (b , dask_array_type ):
122
+ # does some validation of the dask graph
123
+ dask_assert_eq (a , b , equal_nan = True )
115
124
116
125
117
126
def assert_equal_tuple (a , b ):
118
127
"""assert_equal for .blocks indexing tuples"""
119
128
assert len (a ) == len (b )
120
129
121
130
for a_ , b_ in zip (a , b ):
122
- assert type (a_ ) == type (b_ )
131
+ assert type (a_ ) is type (b_ )
123
132
if isinstance (a_ , np .ndarray ):
124
133
np .testing .assert_array_equal (a_ , b_ )
125
134
else :
@@ -156,3 +165,91 @@ def assert_equal_tuple(a, b):
156
165
"quantile" ,
157
166
"nanquantile" ,
158
167
) + tuple (SCIPY_STATS_FUNCS )
168
+
169
+
170
+ def dask_assert_eq (
171
+ a ,
172
+ b ,
173
+ check_shape = True ,
174
+ check_graph = True ,
175
+ check_meta = True ,
176
+ check_chunks = True ,
177
+ check_ndim = True ,
178
+ check_type = True ,
179
+ check_dtype = True ,
180
+ equal_nan = True ,
181
+ scheduler = "sync" ,
182
+ ** kwargs ,
183
+ ):
184
+ """dask.array.utils.assert_eq modified to skip value checks. Their code is buggy for some dtypes.
185
+ We just check values through numpy and care about validating the graph in this function."""
186
+ from dask .array .utils import _get_dt_meta_computed
187
+
188
+ a_original = a
189
+ b_original = b
190
+
191
+ if isinstance (a , (list , int , float )):
192
+ a = np .array (a )
193
+ if isinstance (b , (list , int , float )):
194
+ b = np .array (b )
195
+
196
+ a , adt , a_meta , a_computed = _get_dt_meta_computed (
197
+ a ,
198
+ check_shape = check_shape ,
199
+ check_graph = check_graph ,
200
+ check_chunks = check_chunks ,
201
+ check_ndim = check_ndim ,
202
+ scheduler = scheduler ,
203
+ )
204
+ b , bdt , b_meta , b_computed = _get_dt_meta_computed (
205
+ b ,
206
+ check_shape = check_shape ,
207
+ check_graph = check_graph ,
208
+ check_chunks = check_chunks ,
209
+ check_ndim = check_ndim ,
210
+ scheduler = scheduler ,
211
+ )
212
+
213
+ if check_type :
214
+ _a = a if a .shape else a .item ()
215
+ _b = b if b .shape else b .item ()
216
+ assert type (_a ) is type (_b ), f"a and b have different types (a: { type (_a )} , b: { type (_b )} )"
217
+ if check_meta :
218
+ if hasattr (a , "_meta" ) and hasattr (b , "_meta" ):
219
+ dask_assert_eq (a ._meta , b ._meta )
220
+ if hasattr (a_original , "_meta" ):
221
+ msg = (
222
+ f"compute()-ing 'a' changes its number of dimensions "
223
+ f"(before: { a_original ._meta .ndim } , after: { a .ndim } )"
224
+ )
225
+ assert a_original ._meta .ndim == a .ndim , msg
226
+ if a_meta is not None :
227
+ msg = (
228
+ f"compute()-ing 'a' changes its type "
229
+ f"(before: { type (a_original ._meta )} , after: { type (a_meta )} )"
230
+ )
231
+ assert type (a_original ._meta ) is type (a_meta ), msg
232
+ if not (np .isscalar (a_meta ) or np .isscalar (a_computed )):
233
+ msg = (
234
+ f"compute()-ing 'a' results in a different type than implied by its metadata "
235
+ f"(meta: { type (a_meta )} , computed: { type (a_computed )} )"
236
+ )
237
+ assert type (a_meta ) is type (a_computed ), msg
238
+ if hasattr (b_original , "_meta" ):
239
+ msg = (
240
+ f"compute()-ing 'b' changes its number of dimensions "
241
+ f"(before: { b_original ._meta .ndim } , after: { b .ndim } )"
242
+ )
243
+ assert b_original ._meta .ndim == b .ndim , msg
244
+ if b_meta is not None :
245
+ msg = (
246
+ f"compute()-ing 'b' changes its type "
247
+ f"(before: { type (b_original ._meta )} , after: { type (b_meta )} )"
248
+ )
249
+ assert type (b_original ._meta ) is type (b_meta ), msg
250
+ if not (np .isscalar (b_meta ) or np .isscalar (b_computed )):
251
+ msg = (
252
+ f"compute()-ing 'b' results in a different type than implied by its metadata "
253
+ f"(meta: { type (b_meta )} , computed: { type (b_computed )} )"
254
+ )
255
+ assert type (b_meta ) is type (b_computed ), msg
0 commit comments