@@ -216,3 +216,40 @@ def test_stats():
216216 # And preserve integer dtypes
217217 assert formatted .data .dtype == source .data .dtype
218218 assert (formatted .longitude .diff ("longitude" ) == 1 ).all ()
219+
220+
221+ def test_maintain_single_chunk ():
222+ dx_source = 2
223+ source = xarray_regrid .Grid (
224+ north = 90 - dx_source / 2 ,
225+ east = 360 - dx_source / 2 ,
226+ south = - 90 + dx_source / 2 ,
227+ west = 0 + dx_source / 2 ,
228+ resolution_lat = dx_source ,
229+ resolution_lon = dx_source ,
230+ ).create_regridding_dataset ()
231+ source ["a" ] = xr .DataArray (
232+ np .ones ((source .latitude .size , source .longitude .size )),
233+ dims = ["latitude" , "longitude" ],
234+ coords = {"latitude" : source .latitude , "longitude" : source .longitude },
235+ ).chunk ({"latitude" : - 1 , "longitude" : - 1 })
236+ source ["b" ] = source .a .copy ().chunk ({"latitude" : 45 , "longitude" : 90 })
237+
238+ dx_target = 1
239+ target = xarray_regrid .Grid (
240+ north = 90 ,
241+ east = 360 ,
242+ south = - 90 ,
243+ west = 0 ,
244+ resolution_lat = dx_target ,
245+ resolution_lon = dx_target ,
246+ ).create_regridding_dataset ()
247+
248+ # dataset
249+ formatted = format_for_regrid (source , target )
250+ assert formatted .a .chunks == ((92 ,), (182 ,))
251+ assert formatted .b .chunks == ((1 , 45 , 45 , 1 ), (1 , 90 , 90 , 1 ))
252+
253+ # dataarray
254+ formatted = format_for_regrid (source .a , target )
255+ assert formatted .chunks == ((92 ,), (182 ,))
0 commit comments