Skip to content

Commit 2c53d31

Browse files
make round tripping more thorough
1 parent 598d2d3 commit 2c53d31

File tree

3 files changed

+42
-8
lines changed

3 files changed

+42
-8
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,5 +298,5 @@ convention = "numpy"
298298
[tool.ruff.lint.pylint]
299299
# TODO: refactor to reduce complexity, if possible
300300
max-args = 10
301-
max-branches = 23
301+
max-branches = 25
302302
max-statements = 110

src/esmf_regrid/experimental/io.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
ESMFAreaWeightedRegridder,
2121
ESMFBilinear,
2222
ESMFBilinearRegridder,
23-
ESMFNearest,
2423
ESMFNearestRegridder,
2524
GridRecord,
2625
MeshRecord,
@@ -76,6 +75,7 @@
7675
"extrap_method": _EXTRAP_METHOD_DICT,
7776
"unmapped_action": _UNMAPPED_ACTION_DICT,
7877
}
78+
_ESMF_BOOL_ARGS = ["ignore_degenerate", "large_file"]
7979

8080

8181
def _add_mask_to_cube(mask, cube, name):
@@ -294,7 +294,7 @@ def save_regridder(rg, filename, allow_partial=False):
294294
if tgt_slice is None:
295295
tgt_slice = []
296296
tgt_slice_cube = Cube(
297-
src_slice, long_name=_TGT_SLICE_NAME, var_name=_TGT_SLICE_NAME
297+
tgt_slice, long_name=_TGT_SLICE_NAME, var_name=_TGT_SLICE_NAME
298298
)
299299
extra_cubes = [src_slice_cube, tgt_slice_cube]
300300

@@ -416,11 +416,11 @@ def load_regridder(filename, allow_partial=False):
416416
mdtol = weights_cube.attributes[_MDTOL]
417417

418418
if src_cube.coords(_SOURCE_MASK_NAME):
419-
use_src_mask = src_cube.coord(_SOURCE_MASK_NAME).points
419+
use_src_mask = src_cube.coord(_SOURCE_MASK_NAME).points.astype(bool)
420420
else:
421421
use_src_mask = False
422422
if tgt_cube.coords(_TARGET_MASK_NAME):
423-
use_tgt_mask = tgt_cube.coord(_TARGET_MASK_NAME).points
423+
use_tgt_mask = tgt_cube.coord(_TARGET_MASK_NAME).points.astype(bool)
424424
else:
425425
use_tgt_mask = False
426426

@@ -433,6 +433,9 @@ def load_regridder(filename, allow_partial=False):
433433
for arg, arg_dict in _ESMF_ENUM_ARGS.items():
434434
if arg in esmf_args:
435435
esmf_args[arg] = arg_dict[esmf_args[arg]]
436+
for arg in _ESMF_BOOL_ARGS:
437+
if arg in esmf_args:
438+
esmf_args[arg] = bool(esmf_args[arg])
436439

437440
if scheme is GridToMeshESMFRegridder:
438441
resolution_keyword = _SOURCE_RESOLUTION
@@ -463,16 +466,21 @@ def load_regridder(filename, allow_partial=False):
463466
Constants.Method.BILINEAR: ESMFBilinear,
464467
}[method]
465468
mdtol = kwargs.pop(_MDTOL, None)
466-
mdtol_dict = {}
469+
sub_kwargs = {}
467470
if mdtol is not None:
468-
mdtol_dict[_MDTOL] = mdtol
471+
sub_kwargs[_MDTOL] = mdtol
469472
regridder = scheme(
470473
src_cube,
471474
tgt_cube,
472475
src_slice,
473476
tgt_slice,
474477
weight_matrix,
475-
sub_scheme(**mdtol_dict),
478+
sub_scheme(
479+
use_src_mask=use_src_mask,
480+
use_tgt_mask=use_tgt_mask,
481+
esmf_args=esmf_args,
482+
**sub_kwargs,
483+
),
476484
**kwargs,
477485
)
478486
else:

src/esmf_regrid/tests/unit/experimental/partition/test_PartialRegridder.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
"""Unit tests for :mod:`esmf_regrid.experimental.partition`."""
22

3+
import numpy as np
4+
35
from esmf_regrid import ESMFAreaWeighted
46
from esmf_regrid.experimental._partial import PartialRegridder
7+
from esmf_regrid.experimental.io import load_regridder, save_regridder
58
from esmf_regrid.tests.unit.schemes.test__cube_to_GridInfo import (
69
_grid_cube,
710
)
@@ -23,3 +26,26 @@ def test_PartialRegridder_repr():
2326
"scheme=ESMFAreaWeighted(mdtol=0.5, use_src_mask=False, use_tgt_mask=False, esmf_args={}))"
2427
)
2528
assert repr(pr) == expected_repr
29+
30+
31+
def test_PartialRegridder_roundtrip(tmp_path):
32+
"""Test load/save for PartialRegridder instance."""
33+
src = _grid_cube(10, 15, (-180, 180), (-90, 90), circular=True)
34+
mask = np.zeros_like(src.data)
35+
mask[0, 0] = 1
36+
src.data = np.ma.array(src.data, mask=mask)
37+
tgt = _grid_cube(5, 10, (-180, 180), (-90, 90), circular=True)
38+
src_slice = [[10, 20], [15, 30]]
39+
tgt_slice = [[0, 5], [0, 10]]
40+
weights = None
41+
scheme = ESMFAreaWeighted(
42+
mdtol=0.5, use_src_mask=src.data.mask, esmf_args={"ignore_degenerate": True}
43+
)
44+
45+
pr = PartialRegridder(src, tgt, src_slice, tgt_slice, weights, scheme)
46+
file = tmp_path / "partial.nc"
47+
48+
save_regridder(pr, file, allow_partial=True)
49+
loaded_pr = load_regridder(file, allow_partial=True)
50+
51+
assert repr(loaded_pr) == repr(pr)

0 commit comments

Comments
 (0)