Skip to content

Commit a46b52a

Browse files
committed
Use standard name mapper in more places.
1 parent 450f6a4 commit a46b52a

File tree

3 files changed

+59
-21
lines changed

3 files changed

+59
-21
lines changed

cf_xarray/accessor.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -201,21 +201,23 @@ def _apply_single_mapper(mapper):
201201
for mapper in mappers:
202202
results.append(_apply_single_mapper(mapper))
203203

204-
nresults = sum([bool(v) for v in results])
205-
if nresults > 1:
206-
raise KeyError(
207-
f"Multiple mappers succeeded with key {key!r}.\nI was using mappers: {mappers!r}."
208-
f"I received results: {results!r}.\nPlease open an issue."
209-
)
210-
if nresults == 0:
204+
flat = list(itertools.chain(*results))
205+
# de-duplicate
206+
if all(not isinstance(r, DataArray) for r in flat):
207+
results = list(set(flat))
208+
else:
209+
results = flat
210+
211+
nresults = any([bool(v) for v in [results]])
212+
if not nresults:
211213
if error:
212214
raise KeyError(
213215
f"cf-xarray cannot interpret key {key!r}. Perhaps some needed attributes are missing."
214216
)
215217
else:
216218
# none of the mappers worked. Return the default
217219
return default
218-
return list(itertools.chain(*results))
220+
return results
219221

220222

221223
def _get_axis_coord_single(var: Union[DataArray, Dataset], key: str) -> List[str]:
@@ -370,6 +372,21 @@ def _get_measure(obj: Union[DataArray, Dataset], key: str) -> List[str]:
370372
return list(results)
371373

372374

375+
def _get_with_standard_name(
376+
obj: Union[DataArray, Dataset], name: Union[str, List[str]]
377+
) -> List[str]:
378+
""" returns a list of variable names with standard name == name. """
379+
varnames = []
380+
if isinstance(obj, DataArray):
381+
obj = obj._to_temp_dataset()
382+
for vname, var in obj.variables.items():
383+
stdname = var.attrs.get("standard_name", None)
384+
if stdname == name:
385+
varnames.append(str(vname))
386+
387+
return varnames
388+
389+
373390
#: Default mappers for common keys.
374391
_DEFAULT_KEY_MAPPERS: Mapping[str, Tuple[Mapper, ...]] = {
375392
"dim": (_get_axis_coord,),
@@ -385,24 +402,18 @@ def _get_measure(obj: Union[DataArray, Dataset], key: str) -> List[str]:
385402
"dims_or_levels": (_get_axis_coord,), # reset_index
386403
"window": (_get_axis_coord,), # rolling_exp
387404
"coord": (_get_axis_coord_single,), # differentiate, integrate
388-
"group": (_get_axis_coord_single, _get_axis_coord_time_accessor),
405+
"group": (
406+
_get_axis_coord_single,
407+
_get_axis_coord_time_accessor,
408+
_get_with_standard_name,
409+
),
389410
"indexer": (_get_axis_coord_single,), # resample
390-
"variables": (_get_axis_coord,), # sortby
411+
"variables": (_get_axis_coord, _get_with_standard_name), # sortby
391412
"weights": (_get_measure_variable,), # type: ignore
413+
"chunks": (_get_axis_coord,), # chunk
392414
}
393415

394416

395-
def _get_with_standard_name(ds: Dataset, name: Union[str, List[str]]) -> List[str]:
396-
""" returns a list of variable names with standard name == name. """
397-
varnames = []
398-
for vname, var in ds.variables.items():
399-
stdname = var.attrs.get("standard_name", None)
400-
if stdname == name:
401-
varnames.append(str(vname))
402-
403-
return varnames
404-
405-
406417
def _guess_bounds_dim(da):
407418
"""
408419
Guess bounds values given a 1D coordinate variable.

cf_xarray/tests/test_accessor.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,3 +610,25 @@ def test_param_vcoord_ocean_s_coord():
610610
copy.s_rho.attrs["formula_terms"] = "s: s_rho C: Cs_r depth: h depth_c: hc"
611611
with pytest.raises(KeyError):
612612
copy.cf.decode_vertical_coords()
613+
614+
615+
def test_standard_name_mapper():
616+
da = xr.DataArray(
617+
np.arange(6),
618+
dims="time",
619+
coords={
620+
"label": (
621+
"time",
622+
["A", "B", "B", "A", "B", "C"],
623+
{"standard_name": "standard_label"},
624+
)
625+
},
626+
)
627+
628+
actual = da.cf.groupby("standard_label").mean()
629+
expected = da.cf.groupby("label").mean()
630+
assert_identical(actual, expected)
631+
632+
actual = da.cf.sortby("standard_label")
633+
expected = da.sortby("label")
634+
assert_identical(actual, expected)

doc/whats-new.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33
What's New
44
----------
55

6+
v0.4.1 (unreleased)
7+
===================
8+
9+
- Support for using ``standard_name`` in more functions. By `Deepak Cherian`_
10+
611
v0.4.0 (Jan 22, 2021)
712
=====================
813
- Support for arbitrary cell measures indexing. By `Mattia Almansi`_.

0 commit comments

Comments
 (0)