Skip to content

Commit 7e51d9e

Browse files
authored
Handle non-uniform chunk sizes in formatting (#57)
* handle variable chunks individually and test * ignore new ruff rule * changelog
1 parent 1425e13 commit 7e51d9e

File tree

4 files changed

+52
-9
lines changed

4 files changed

+52
-9
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/).
88

99
Fixed:
1010
- Attributes are now properly preserved when updating coordinates during pre-formatting for regridding ([#54](https://github.com/xarray-contrib/xarray-regrid/pull/54)).
11+
- Handle datasets with inconsistent chunksizes during pre-formatting ([#57](https://github.com/xarray-contrib/xarray-regrid/pull/57)).
1112

1213

1314
## 0.4.0 (2024-09-26)

src/xarray_regrid/methods/conservative.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def conservative_regrid_dataset(
119119
# Create weights array and coverage mask for each regridding dim
120120
weights = {}
121121
covered = {}
122-
for coord in coords:
122+
for coord in coords: # noqa: PLC0206
123123
covered[coord] = (coords[coord] <= data[coord].max()) & (
124124
coords[coord] >= data[coord].min()
125125
)
@@ -137,7 +137,7 @@ def conservative_regrid_dataset(
137137
weights[coord] = da_weights
138138

139139
# Apply the weights, using a unique set that matches chunking of each array
140-
for array in data_vars.keys():
140+
for array in data_vars.keys(): # noqa: PLC0206
141141
var_weights = {}
142142
for coord, weight_array in weights.items():
143143
var_input_chunks = data_vars[array].chunksizes.get(coord)

src/xarray_regrid/utils.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,6 @@ def format_for_regrid(
249249
Currently handles padding of spherical geometry if lat/lon coordinates can
250250
be inferred and the domain size requires boundary padding.
251251
"""
252-
orig_chunksizes = obj.chunksizes
253-
254252
# Special-cased coordinates with accepted names and formatting function
255253
coord_handlers: dict[str, CoordHandler] = {
256254
"lat": {"names": ["lat", "latitude"], "func": format_lat},
@@ -270,15 +268,22 @@ def format_for_regrid(
270268
formatted_coords[coord_type] = str(coord)
271269

272270
# Apply formatting
271+
result = obj.copy()
273272
for coord_type, coord in formatted_coords.items():
274273
# Make sure formatted coords are sorted
275-
obj = ensure_monotonic(obj, coord)
274+
result = ensure_monotonic(result, coord)
276275
target = ensure_monotonic(target, coord)
277-
obj = coord_handlers[coord_type]["func"](obj, target, formatted_coords)
276+
result = coord_handlers[coord_type]["func"](result, target, formatted_coords)
277+
278278
# Coerce back to a single chunk if that's what was passed
279-
if len(orig_chunksizes.get(coord, [])) == 1:
280-
obj = obj.chunk({coord: -1})
281-
return obj
279+
if isinstance(obj, xr.DataArray) and len(obj.chunksizes.get(coord, ())) == 1:
280+
result = result.chunk({coord: -1})
281+
elif isinstance(obj, xr.Dataset):
282+
for var in result.data_vars:
283+
if len(obj[var].chunksizes.get(coord, ())) == 1:
284+
result[var] = result[var].chunk({coord: -1})
285+
286+
return result
282287

283288

284289
def format_lat(

tests/test_format.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,40 @@ def test_stats():
216216
# And preserve integer dtypes
217217
assert formatted.data.dtype == source.data.dtype
218218
assert (formatted.longitude.diff("longitude") == 1).all()
219+
220+
221+
def test_maintain_single_chunk():
222+
dx_source = 2
223+
source = xarray_regrid.Grid(
224+
north=90 - dx_source / 2,
225+
east=360 - dx_source / 2,
226+
south=-90 + dx_source / 2,
227+
west=0 + dx_source / 2,
228+
resolution_lat=dx_source,
229+
resolution_lon=dx_source,
230+
).create_regridding_dataset()
231+
source["a"] = xr.DataArray(
232+
np.ones((source.latitude.size, source.longitude.size)),
233+
dims=["latitude", "longitude"],
234+
coords={"latitude": source.latitude, "longitude": source.longitude},
235+
).chunk({"latitude": -1, "longitude": -1})
236+
source["b"] = source.a.copy().chunk({"latitude": 45, "longitude": 90})
237+
238+
dx_target = 1
239+
target = xarray_regrid.Grid(
240+
north=90,
241+
east=360,
242+
south=-90,
243+
west=0,
244+
resolution_lat=dx_target,
245+
resolution_lon=dx_target,
246+
).create_regridding_dataset()
247+
248+
# dataset
249+
formatted = format_for_regrid(source, target)
250+
assert formatted.a.chunks == ((92,), (182,))
251+
assert formatted.b.chunks == ((1, 45, 45, 1), (1, 90, 90, 1))
252+
253+
# dataarray
254+
formatted = format_for_regrid(source.a, target)
255+
assert formatted.chunks == ((92,), (182,))

0 commit comments

Comments
 (0)