@@ -230,6 +230,7 @@ def _get_custom_criteria(
230
230
231
231
if isinstance (obj , DataArray ):
232
232
obj = obj ._to_temp_dataset ()
233
+ variables = obj ._variables
233
234
234
235
if criteria is None :
235
236
if not OPTIONS ["custom_criteria" ]:
@@ -243,8 +244,8 @@ def _get_custom_criteria(
243
244
results : set = set ()
244
245
if key in criteria_map :
245
246
for criterion , patterns in criteria_map [key ].items ():
246
- for var in obj . variables :
247
- if regex_match (patterns , obj [var ].attrs .get (criterion , "" )):
247
+ for var in variables :
248
+ if regex_match (patterns , variables [var ].attrs .get (criterion , "" )):
248
249
results .update ((var ,))
249
250
# also check name specifically since not in attributes
250
251
elif (
@@ -290,6 +291,8 @@ def _get_axis_coord(obj: DataArray | Dataset, key: str) -> list[str]:
290
291
f"cf_xarray did not understand key { key !r} . Expected one of { valid_keys !r} "
291
292
)
292
293
294
+ crds = obj .coords
295
+ crd_names = set (crds )
293
296
search_in = set ()
294
297
attrs_or_encoding = ChainMap (obj .attrs , obj .encoding )
295
298
coordinates = attrs_or_encoding .get ("coordinates" , None )
@@ -298,15 +301,16 @@ def _get_axis_coord(obj: DataArray | Dataset, key: str) -> list[str]:
298
301
if coordinates :
299
302
search_in .update (coordinates .split (" " ))
300
303
if not search_in :
301
- search_in = set ( obj . coords )
304
+ search_in = crd_names
302
305
303
306
# maybe only do this for key in _AXIS_NAMES?
304
- search_in .update (obj .indexes )
307
+ if obj ._indexes :
308
+ search_in .update (obj ._indexes )
305
309
306
- search_in = search_in & set ( obj . coords )
310
+ search_in = search_in & crd_names
307
311
results : set = set ()
308
312
for coord in search_in :
309
- var = obj . coords [coord ]
313
+ var = crds [coord ]
310
314
if key in coordinate_criteria :
311
315
for criterion , expected in coordinate_criteria [key ].items ():
312
316
if var .attrs .get (criterion , None ) in expected :
@@ -345,9 +349,8 @@ def _get_measure(obj: DataArray | Dataset, key: str) -> list[str]:
345
349
obj = obj ._to_temp_dataset ()
346
350
347
351
results = set ()
348
- for var in obj .variables :
349
- da = obj [var ]
350
- attrs_or_encoding = ChainMap (da .attrs , da .encoding )
352
+ for var in obj ._variables .values ():
353
+ attrs_or_encoding = ChainMap (var .attrs , var .encoding )
351
354
if "cell_measures" in attrs_or_encoding :
352
355
attr = attrs_or_encoding ["cell_measures" ]
353
356
try :
@@ -381,9 +384,13 @@ def _get_bounds(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
381
384
List[str], Variable name(s) in parent xarray object that are bounds of `key`
382
385
"""
383
386
387
+ if isinstance (obj , DataArray ):
388
+ obj = obj ._to_temp_dataset ()
389
+ variables = obj ._variables
390
+
384
391
results = set ()
385
392
for var in apply_mapper (_get_all , obj , key , error = False , default = [key ]):
386
- attrs_or_encoding = ChainMap (obj [var ].attrs , obj [var ].encoding )
393
+ attrs_or_encoding = ChainMap (variables [var ].attrs , variables [var ].encoding )
387
394
if "bounds" in attrs_or_encoding :
388
395
results |= {attrs_or_encoding ["bounds" ]}
389
396
@@ -410,17 +417,17 @@ def _get_grid_mapping_name(obj: DataArray | Dataset, key: str) -> list[str]:
410
417
if isinstance (obj , DataArray ):
411
418
obj = obj ._to_temp_dataset ()
412
419
420
+ variables = obj ._variables
413
421
results = set ()
414
- for var in obj .variables :
415
- da = obj [var ]
416
- attrs_or_encoding = ChainMap (da .attrs , da .encoding )
422
+ for var in variables .values ():
423
+ attrs_or_encoding = ChainMap (var .attrs , var .encoding )
417
424
if "grid_mapping" in attrs_or_encoding :
418
425
grid_mapping_var_name = attrs_or_encoding ["grid_mapping" ]
419
- if grid_mapping_var_name not in obj . variables :
426
+ if grid_mapping_var_name not in variables :
420
427
raise ValueError (
421
428
f"{ var } defines non-existing grid_mapping variable { grid_mapping_var_name } ."
422
429
)
423
- if key == obj [grid_mapping_var_name ].attrs ["grid_mapping_name" ]:
430
+ if key == variables [grid_mapping_var_name ].attrs ["grid_mapping_name" ]:
424
431
results .update ([grid_mapping_var_name ])
425
432
return list (results )
426
433
@@ -474,7 +481,7 @@ def _get_indexes(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
474
481
One or more of ('X', 'Y', 'Z', 'T', 'longitude', 'latitude', 'vertical', 'time',
475
482
'area', 'volume'), or arbitrary measures, or standard names present in .indexes
476
483
"""
477
- return [k for k in _get_all (obj , key ) if k in obj .indexes ]
484
+ return [k for k in _get_all (obj , key ) if k in obj ._indexes ]
478
485
479
486
480
487
def _get_coords (obj : DataArray | Dataset , key : Hashable ) -> list [Hashable ]:
@@ -2251,7 +2258,7 @@ def get_bounds(self, key: Hashable) -> DataArray | Dataset:
2251
2258
DataArray
2252
2259
"""
2253
2260
2254
- results = self .bounds .get (key , [])
2261
+ results = self [[ key ]]. cf .bounds .get (key , [])
2255
2262
if not results :
2256
2263
raise KeyError (f"No results found for { key !r} ." )
2257
2264
@@ -2270,12 +2277,18 @@ def get_bounds_dim_name(self, key: Hashable) -> Hashable:
2270
2277
-------
2271
2278
str
2272
2279
"""
2273
- crd = self [key ]
2274
- bounds = self .get_bounds (key )
2280
+ (crd_name ,) = apply_mapper (_get_all , self ._obj , key , error = False , default = [key ])
2281
+ variables = self ._obj ._variables
2282
+ crd = variables [crd_name ]
2283
+ crd_attrs = crd ._attrs
2284
+ if crd_attrs is None or "bounds" not in crd_attrs :
2285
+ raise KeyError (f"No bounds variable found for { key !r} " )
2286
+
2287
+ bounds = variables [crd_attrs ["bounds" ].strip ()]
2275
2288
bounds_dims = set (bounds .dims ) - set (crd .dims )
2276
2289
assert len (bounds_dims ) == 1
2277
2290
bounds_dim = bounds_dims .pop ()
2278
- assert self . _obj .sizes [bounds_dim ] in [2 , 4 ]
2291
+ assert bounds .sizes [bounds_dim ] in [2 , 4 ]
2279
2292
return bounds_dim
2280
2293
2281
2294
def add_bounds (
0 commit comments