|
1 | 1 | import functools
|
2 | 2 | import inspect
|
| 3 | +from collections import ChainMap |
3 | 4 | from typing import Any, List, Optional, Set, Union
|
4 | 5 |
|
5 | 6 | import xarray as xr
|
@@ -95,7 +96,7 @@ def _get_axis_coord_single(var, key, *args):
|
95 | 96 | results = _get_axis_coord(var, key, *args)
|
96 | 97 | if len(results) > 1:
|
97 | 98 | raise ValueError(
|
98 |
| - "Multiple results for {key!r} found: {results!r}. Is this valid CF? Please open an issue." |
| 99 | + f"Multiple results for {key!r} found: {results!r}. Is this valid CF? Please open an issue." |
99 | 100 | )
|
100 | 101 | else:
|
101 | 102 | return results[0]
|
@@ -335,20 +336,34 @@ def _process_signature(self, func, args, kwargs, key_mappers):
|
335 | 336 | def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
|
336 | 337 | """ rewrites 'dim' for example using 'mapper' """
|
337 | 338 | updates: dict = {}
|
338 |
| - key_mappers.update(dict.fromkeys(var_kws, _get_axis_coord_single)) |
| 339 | + |
| 340 | + # allow multiple return values here. |
| 341 | + # these are valid for .sel, .isel, .coarsen |
| 342 | + key_mappers.update(dict.fromkeys(var_kws, _get_axis_coord)) |
| 343 | + |
339 | 344 | for key, mapper in key_mappers.items():
|
340 | 345 | value = kwargs.get(key, None)
|
| 346 | + |
341 | 347 | if value is not None:
|
342 | 348 | if isinstance(value, str):
|
343 | 349 | value = [value]
|
344 | 350 |
|
345 | 351 | if isinstance(value, dict):
|
346 | 352 | # this for things like isel where **kwargs captures things like T=5
|
347 |
| - updates[key] = { |
348 |
| - mapper(self._obj, k, False, k): v for k, v in value.items() |
349 |
| - } |
| 353 | + # .sel, .isel, .rolling |
| 354 | + # Account for multiple names matching the key. |
| 355 | + # e.g. .isel(X=5) → .isel(xi_rho=5, xi_u=5, xi_v=5, xi_psi=5) |
| 356 | + # where xi_* have attrs["axis"] = "X" |
| 357 | + updates[key] = ChainMap( |
| 358 | + *[ |
| 359 | + dict.fromkeys(mapper(self._obj, k, False, k), v) |
| 360 | + for k, v in value.items() |
| 361 | + ] |
| 362 | + ) |
| 363 | + |
350 | 364 | elif value is Ellipsis:
|
351 | 365 | pass
|
| 366 | + |
352 | 367 | else:
|
353 | 368 | # things like sum which have dim
|
354 | 369 | updates[key] = [mapper(self._obj, v, False, v) for v in value]
|
|
0 commit comments