Skip to content

Commit fdb9e17

Browse files
authored
Major speedups in mappers. (#427)
* Major speedups in mappers.
* fix test
1 parent 3885b44 commit fdb9e17

File tree

1 file changed

+33
-20
lines changed

1 file changed

+33
-20
lines changed

cf_xarray/accessor.py

Lines changed: 33 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -230,6 +230,7 @@ def _get_custom_criteria(
230230

231231
if isinstance(obj, DataArray):
232232
obj = obj._to_temp_dataset()
233+
variables = obj._variables
233234

234235
if criteria is None:
235236
if not OPTIONS["custom_criteria"]:
@@ -243,8 +244,8 @@ def _get_custom_criteria(
243244
results: set = set()
244245
if key in criteria_map:
245246
for criterion, patterns in criteria_map[key].items():
246-
for var in obj.variables:
247-
if regex_match(patterns, obj[var].attrs.get(criterion, "")):
247+
for var in variables:
248+
if regex_match(patterns, variables[var].attrs.get(criterion, "")):
248249
results.update((var,))
249250
# also check name specifically since not in attributes
250251
elif (
@@ -290,6 +291,8 @@ def _get_axis_coord(obj: DataArray | Dataset, key: str) -> list[str]:
290291
f"cf_xarray did not understand key {key!r}. Expected one of {valid_keys!r}"
291292
)
292293

294+
crds = obj.coords
295+
crd_names = set(crds)
293296
search_in = set()
294297
attrs_or_encoding = ChainMap(obj.attrs, obj.encoding)
295298
coordinates = attrs_or_encoding.get("coordinates", None)
@@ -298,15 +301,16 @@ def _get_axis_coord(obj: DataArray | Dataset, key: str) -> list[str]:
298301
if coordinates:
299302
search_in.update(coordinates.split(" "))
300303
if not search_in:
301-
search_in = set(obj.coords)
304+
search_in = crd_names
302305

303306
# maybe only do this for key in _AXIS_NAMES?
304-
search_in.update(obj.indexes)
307+
if obj._indexes:
308+
search_in.update(obj._indexes)
305309

306-
search_in = search_in & set(obj.coords)
310+
search_in = search_in & crd_names
307311
results: set = set()
308312
for coord in search_in:
309-
var = obj.coords[coord]
313+
var = crds[coord]
310314
if key in coordinate_criteria:
311315
for criterion, expected in coordinate_criteria[key].items():
312316
if var.attrs.get(criterion, None) in expected:
@@ -345,9 +349,8 @@ def _get_measure(obj: DataArray | Dataset, key: str) -> list[str]:
345349
obj = obj._to_temp_dataset()
346350

347351
results = set()
348-
for var in obj.variables:
349-
da = obj[var]
350-
attrs_or_encoding = ChainMap(da.attrs, da.encoding)
352+
for var in obj._variables.values():
353+
attrs_or_encoding = ChainMap(var.attrs, var.encoding)
351354
if "cell_measures" in attrs_or_encoding:
352355
attr = attrs_or_encoding["cell_measures"]
353356
try:
@@ -381,9 +384,13 @@ def _get_bounds(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
381384
List[str], Variable name(s) in parent xarray object that are bounds of `key`
382385
"""
383386

387+
if isinstance(obj, DataArray):
388+
obj = obj._to_temp_dataset()
389+
variables = obj._variables
390+
384391
results = set()
385392
for var in apply_mapper(_get_all, obj, key, error=False, default=[key]):
386-
attrs_or_encoding = ChainMap(obj[var].attrs, obj[var].encoding)
393+
attrs_or_encoding = ChainMap(variables[var].attrs, variables[var].encoding)
387394
if "bounds" in attrs_or_encoding:
388395
results |= {attrs_or_encoding["bounds"]}
389396

@@ -410,17 +417,17 @@ def _get_grid_mapping_name(obj: DataArray | Dataset, key: str) -> list[str]:
410417
if isinstance(obj, DataArray):
411418
obj = obj._to_temp_dataset()
412419

420+
variables = obj._variables
413421
results = set()
414-
for var in obj.variables:
415-
da = obj[var]
416-
attrs_or_encoding = ChainMap(da.attrs, da.encoding)
422+
for var in variables.values():
423+
attrs_or_encoding = ChainMap(var.attrs, var.encoding)
417424
if "grid_mapping" in attrs_or_encoding:
418425
grid_mapping_var_name = attrs_or_encoding["grid_mapping"]
419-
if grid_mapping_var_name not in obj.variables:
426+
if grid_mapping_var_name not in variables:
420427
raise ValueError(
421428
f"{var} defines non-existing grid_mapping variable {grid_mapping_var_name}."
422429
)
423-
if key == obj[grid_mapping_var_name].attrs["grid_mapping_name"]:
430+
if key == variables[grid_mapping_var_name].attrs["grid_mapping_name"]:
424431
results.update([grid_mapping_var_name])
425432
return list(results)
426433

@@ -474,7 +481,7 @@ def _get_indexes(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
474481
One or more of ('X', 'Y', 'Z', 'T', 'longitude', 'latitude', 'vertical', 'time',
475482
'area', 'volume'), or arbitrary measures, or standard names present in .indexes
476483
"""
477-
return [k for k in _get_all(obj, key) if k in obj.indexes]
484+
return [k for k in _get_all(obj, key) if k in obj._indexes]
478485

479486

480487
def _get_coords(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
@@ -2251,7 +2258,7 @@ def get_bounds(self, key: Hashable) -> DataArray | Dataset:
22512258
DataArray
22522259
"""
22532260

2254-
results = self.bounds.get(key, [])
2261+
results = self[[key]].cf.bounds.get(key, [])
22552262
if not results:
22562263
raise KeyError(f"No results found for {key!r}.")
22572264

@@ -2270,12 +2277,18 @@ def get_bounds_dim_name(self, key: Hashable) -> Hashable:
22702277
-------
22712278
str
22722279
"""
2273-
crd = self[key]
2274-
bounds = self.get_bounds(key)
2280+
(crd_name,) = apply_mapper(_get_all, self._obj, key, error=False, default=[key])
2281+
variables = self._obj._variables
2282+
crd = variables[crd_name]
2283+
crd_attrs = crd._attrs
2284+
if crd_attrs is None or "bounds" not in crd_attrs:
2285+
raise KeyError(f"No bounds variable found for {key!r}")
2286+
2287+
bounds = variables[crd_attrs["bounds"].strip()]
22752288
bounds_dims = set(bounds.dims) - set(crd.dims)
22762289
assert len(bounds_dims) == 1
22772290
bounds_dim = bounds_dims.pop()
2278-
assert self._obj.sizes[bounds_dim] in [2, 4]
2291+
assert bounds.sizes[bounds_dim] in [2, 4]
22792292
return bounds_dim
22802293

22812294
def add_bounds(

0 commit comments

Comments
 (0)