Skip to content

Commit 4cffa5b

Browse files
committed
handle variable chunks individually and test
1 parent 1425e13 commit 4cffa5b

File tree

2 files changed

+49
-7
lines changed

2 files changed

+49
-7
lines changed

src/xarray_regrid/utils.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,6 @@ def format_for_regrid(
249249
Currently handles padding of spherical geometry if lat/lon coordinates can
250250
be inferred and the domain size requires boundary padding.
251251
"""
252-
orig_chunksizes = obj.chunksizes
253-
254252
# Special-cased coordinates with accepted names and formatting function
255253
coord_handlers: dict[str, CoordHandler] = {
256254
"lat": {"names": ["lat", "latitude"], "func": format_lat},
@@ -270,15 +268,22 @@ def format_for_regrid(
270268
formatted_coords[coord_type] = str(coord)
271269

272270
# Apply formatting
271+
result = obj.copy()
273272
for coord_type, coord in formatted_coords.items():
274273
# Make sure formatted coords are sorted
275-
obj = ensure_monotonic(obj, coord)
274+
result = ensure_monotonic(result, coord)
276275
target = ensure_monotonic(target, coord)
277-
obj = coord_handlers[coord_type]["func"](obj, target, formatted_coords)
276+
result = coord_handlers[coord_type]["func"](result, target, formatted_coords)
277+
278278
# Coerce back to a single chunk if that's what was passed
279-
if len(orig_chunksizes.get(coord, [])) == 1:
280-
obj = obj.chunk({coord: -1})
281-
return obj
279+
if isinstance(obj, xr.DataArray) and len(obj.chunksizes.get(coord, ())) == 1:
280+
result = result.chunk({coord: -1})
281+
elif isinstance(obj, xr.Dataset):
282+
for var in result.data_vars:
283+
if len(obj[var].chunksizes.get(coord, ())) == 1:
284+
result[var] = result[var].chunk({coord: -1})
285+
286+
return result
282287

283288

284289
def format_lat(

tests/test_format.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,40 @@ def test_stats():
216216
# And preserve integer dtypes
217217
assert formatted.data.dtype == source.data.dtype
218218
assert (formatted.longitude.diff("longitude") == 1).all()
219+
220+
221+
def test_maintain_single_chunk():
222+
dx_source = 2
223+
source = xarray_regrid.Grid(
224+
north=90 - dx_source / 2,
225+
east=360 - dx_source / 2,
226+
south=-90 + dx_source / 2,
227+
west=0 + dx_source / 2,
228+
resolution_lat=dx_source,
229+
resolution_lon=dx_source,
230+
).create_regridding_dataset()
231+
source["a"] = xr.DataArray(
232+
np.ones((source.latitude.size, source.longitude.size)),
233+
dims=["latitude", "longitude"],
234+
coords={"latitude": source.latitude, "longitude": source.longitude},
235+
).chunk({"latitude": -1, "longitude": -1})
236+
source["b"] = source.a.copy().chunk({"latitude": 45, "longitude": 90})
237+
238+
dx_target = 1
239+
target = xarray_regrid.Grid(
240+
north=90,
241+
east=360,
242+
south=-90,
243+
west=0,
244+
resolution_lat=dx_target,
245+
resolution_lon=dx_target,
246+
).create_regridding_dataset()
247+
248+
# dataset
249+
formatted = format_for_regrid(source, target)
250+
assert formatted.a.chunks == ((92,), (182,))
251+
assert formatted.b.chunks == ((1, 45, 45, 1), (1, 90, 90, 1))
252+
253+
# dataarray
254+
formatted = format_for_regrid(source.a, target)
255+
assert formatted.chunks == ((92,), (182,))

0 commit comments

Comments
 (0)