Skip to content

Commit d73d9f5

Browse files
authored
Add standard name mapper in more places. (#151)
1 parent 9d32c18 commit d73d9f5

File tree

3 files changed

+131
-21
lines changed

3 files changed

+131
-21
lines changed

cf_xarray/accessor.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,7 @@ def _get_axis_coord_single(var: Union[DataArray, Dataset], key: str) -> List[str
213213
return results
214214

215215

216-
def _get_axis_coord_time_accessor(
217-
var: Union[DataArray, Dataset], key: str
218-
) -> List[str]:
216+
def _get_groupby_time_accessor(var: Union[DataArray, Dataset], key: str) -> List[str]:
219217
"""
220218
Helper method for when our key name is of the nature "T.month" and we want to
221219
isolate the "T" for coordinate mapping
@@ -238,7 +236,11 @@ def _get_axis_coord_time_accessor(
238236
if "." in key:
239237
key, ext = key.split(".", 1)
240238

241-
results = _get_axis_coord_single(var, key)
239+
results = apply_mapper(
240+
(_get_axis_coord, _get_with_standard_name), var, key, error=False
241+
)
242+
if len(results) > 1:
243+
raise KeyError(f"Multiple results received for {key}.")
242244
return [v + "." + ext for v in results]
243245

244246
else:
@@ -370,34 +372,34 @@ def _get_with_standard_name(
370372

371373
#: Default mappers for common keys.
372374
_DEFAULT_KEY_MAPPERS: Mapping[str, Tuple[Mapper, ...]] = {
373-
"dim": (_get_axis_coord,),
374-
"dims": (_get_axis_coord,), # transpose
375-
"drop_dims": (_get_axis_coord,), # drop_dims
376-
"dimensions": (_get_axis_coord,), # stack
377-
"dims_dict": (_get_axis_coord,), # swap_dims, rename_dims
378-
"shifts": (_get_axis_coord,), # shift, roll
379-
"pad_width": (_get_axis_coord,), # shift, roll
375+
"dim": (_get_axis_coord, _get_with_standard_name),
376+
"dims": (_get_axis_coord, _get_with_standard_name), # transpose
377+
"drop_dims": (_get_axis_coord, _get_with_standard_name), # drop_dims
378+
"dimensions": (_get_axis_coord, _get_with_standard_name), # stack
379+
"dims_dict": (_get_axis_coord, _get_with_standard_name), # swap_dims, rename_dims
380+
"shifts": (_get_axis_coord, _get_with_standard_name), # shift, roll
381+
"pad_width": (_get_axis_coord, _get_with_standard_name), # shift, roll
380382
"names": (
381383
_get_axis_coord,
382384
_get_measure,
383385
_get_with_standard_name,
384386
), # set_coords, reset_coords, drop_vars
385387
"labels": (_get_axis_coord, _get_measure, _get_with_standard_name), # drop
386-
"coords": (_get_axis_coord,), # interp
387-
"indexers": (_get_axis_coord,), # sel, isel, reindex
388+
"coords": (_get_axis_coord, _get_with_standard_name), # interp
389+
"indexers": (_get_axis_coord, _get_with_standard_name), # sel, isel, reindex
388390
# "indexes": (_get_axis_coord,), # set_index
389-
"dims_or_levels": (_get_axis_coord,), # reset_index
390-
"window": (_get_axis_coord,), # rolling_exp
391+
"dims_or_levels": (_get_axis_coord, _get_with_standard_name), # reset_index
392+
"window": (_get_axis_coord, _get_with_standard_name), # rolling_exp
391393
"coord": (_get_axis_coord_single,), # differentiate, integrate
392394
"group": (
393395
_get_axis_coord_single,
394-
_get_axis_coord_time_accessor,
396+
_get_groupby_time_accessor,
395397
_get_with_standard_name,
396398
),
397399
"indexer": (_get_axis_coord_single,), # resample
398400
"variables": (_get_axis_coord, _get_with_standard_name), # sortby
399401
"weights": (_get_measure_variable,), # type: ignore
400-
"chunks": (_get_axis_coord,), # chunk
402+
"chunks": (_get_axis_coord, _get_with_standard_name), # chunk
401403
}
402404

403405

@@ -430,7 +432,7 @@ def _build_docstring(func):
430432
mapper_docstrings = {
431433
_get_axis_coord: f"One or more of {(_AXIS_NAMES + _COORD_NAMES)!r}",
432434
_get_axis_coord_single: f"One of {(_AXIS_NAMES + _COORD_NAMES)!r}",
433-
_get_axis_coord_time_accessor: "Time variable accessor e.g. 'T.month'",
435+
_get_groupby_time_accessor: "Time variable accessor e.g. 'T.month'",
434436
_get_with_standard_name: "Standard names",
435437
_get_measure_variable: f"One of {_CELL_MEASURES!r}",
436438
}
@@ -900,7 +902,10 @@ def _rewrite_values(
900902

901903
# allow multiple return values here.
902904
# these are valid for .sel, .isel, .coarsen
903-
all_mappers = ChainMap(key_mappers, dict.fromkeys(var_kws, (_get_axis_coord,)))
905+
all_mappers = ChainMap(
906+
key_mappers,
907+
dict.fromkeys(var_kws, (_get_axis_coord, _get_with_standard_name)),
908+
)
904909

905910
for key in set(all_mappers) & set(kwargs):
906911
value = kwargs[key]

cf_xarray/datasets.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,95 @@
188188
lat_vertices=xr.DataArray(lat_vertices, dims=("x_vertices", "y_vertices")),
189189
),
190190
)
191+
192+
forecast = xr.decode_cf(
193+
xr.Dataset.from_dict(
194+
{
195+
"coords": {
196+
"L": {
197+
"dims": ("L",),
198+
"attrs": {
199+
"long_name": "Lead",
200+
"standard_name": "forecast_period",
201+
"pointwidth": 1.0,
202+
"gridtype": 0,
203+
"units": "months",
204+
},
205+
"data": [0, 1],
206+
},
207+
"M": {
208+
"dims": ("M",),
209+
"attrs": {
210+
"standard_name": "realization",
211+
"long_name": "Ensemble Member",
212+
"pointwidth": 1.0,
213+
"gridtype": 0,
214+
"units": "unitless",
215+
},
216+
"data": [0, 1, 2],
217+
},
218+
"S": {
219+
"dims": ("S",),
220+
"attrs": {
221+
"calendar": "360_day",
222+
"long_name": "Forecast Start Time",
223+
"standard_name": "forecast_reference_time",
224+
"pointwidth": 0,
225+
"gridtype": 0,
226+
"units": "months since 1960-01-01",
227+
},
228+
"data": [0, 1, 2, 3],
229+
},
230+
"X": {
231+
"dims": ("X",),
232+
"attrs": {
233+
"standard_name": "longitude",
234+
"pointwidth": 1.0,
235+
"gridtype": 1,
236+
"units": "degree_east",
237+
},
238+
"data": [0, 1, 2, 3, 4],
239+
},
240+
"Y": {
241+
"dims": ("Y",),
242+
"attrs": {
243+
"standard_name": "latitude",
244+
"pointwidth": 1.0,
245+
"gridtype": 0,
246+
"units": "degree_north",
247+
},
248+
"data": [0, 1, 2, 3, 4, 5],
249+
},
250+
},
251+
"attrs": {"Conventions": "IRIDL"},
252+
"dims": {"L": 2, "M": 3, "S": 4, "X": 5, "Y": 6},
253+
"data_vars": {
254+
"sst": {
255+
"dims": ("S", "L", "M", "Y", "X"),
256+
"attrs": {
257+
"pointwidth": 0,
258+
"PDS_TimeRange": 3,
259+
"center": "US Weather Service - National Met. Center",
260+
"grib_name": "TMP",
261+
"gribNumBits": 21,
262+
"gribcenter": 7,
263+
"gribparam": 11,
264+
"gribleveltype": 1,
265+
"GRIBgridcode": 3,
266+
"process": 'Spectral Statistical Interpolation (SSI) analysis from "Final" run.',
267+
"PTVersion": 2,
268+
"gribfield": 1,
269+
"units": "Celsius_scale",
270+
"scale_min": -69.97389221191406,
271+
"scale_max": 43.039306640625,
272+
"long_name": "Sea Surface Temperature",
273+
"standard_name": "sea_surface_temperature",
274+
},
275+
"data": np.arange(np.prod((4, 2, 3, 6, 5))).reshape(
276+
(4, 2, 3, 6, 5)
277+
),
278+
}
279+
},
280+
}
281+
)
282+
)

cf_xarray/tests/test_accessor.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import cf_xarray # noqa
1111

12-
from ..datasets import airds, anc, ds_no_attrs, multiple, popds, romsds
12+
from ..datasets import airds, anc, ds_no_attrs, forecast, multiple, popds, romsds
1313
from . import raise_if_dask_computes
1414

1515
mpl.use("Agg")
@@ -163,7 +163,6 @@ def test_rename_like():
163163
reason="xarray GH4120. any test after this will fail since attrs are lost"
164164
),
165165
),
166-
# groupby("time.day")?
167166
),
168167
)
169168
def test_wrapped_classes(obj, attr, xrkwargs, cfkwargs):
@@ -744,6 +743,20 @@ def test_drop_dims(ds):
744743
assert_identical(ds.drop_dims("lon"), ds.cf.drop_dims(cf_name))
745744

746745

746+
def test_new_standard_name_mappers():
747+
assert_identical(forecast.cf.mean("realization"), forecast.mean("M"))
748+
assert_identical(
749+
forecast.cf.mean(["realization", "forecast_period"]), forecast.mean(["M", "L"])
750+
)
751+
assert_identical(forecast.cf.chunk({"realization": 1}), forecast.chunk({"M": 1}))
752+
assert_identical(forecast.cf.isel({"realization": 1}), forecast.isel({"M": 1}))
753+
assert_identical(forecast.cf.isel(**{"realization": 1}), forecast.isel(**{"M": 1}))
754+
assert_identical(
755+
forecast.cf.groupby("forecast_reference_time.month").mean(),
756+
forecast.groupby("S.month").mean(),
757+
)
758+
759+
747760
def test_possible_x_y_plot():
748761
from ..accessor import _possible_x_y_plot
749762

0 commit comments

Comments
 (0)