11# rasterio wrappers
22from __future__ import annotations
33
4- from collections .abc import Sequence
4+ import functools
5+ from collections .abc import Callable , Sequence
56from functools import partial
6- from typing import TYPE_CHECKING , Any
7+ from typing import TYPE_CHECKING , Any , TypeVar
78
89import geopandas as gpd
910import numpy as np
10- import odc .geo .xr # noqa
1111import rasterio as rio
1212import xarray as xr
13+ from affine import Affine
1314from rasterio .features import MergeAlg , geometry_mask
1415from rasterio .features import rasterize as rasterize_rio
1516
16- from .utils import is_in_memory , prepare_for_dask
17+ from .utils import XAXIS , YAXIS , get_affine , is_in_memory , prepare_for_dask
18+
19+ F = TypeVar ("F" , bound = Callable [..., Any ])
1720
1821if TYPE_CHECKING :
1922 import dask_geopandas
2023
2124
def with_rio_env(func: F) -> F:
    """
    Wrap ``func`` so that it executes inside a rasterio environment.

    The returned callable accepts two extra keyword arguments, neither of
    which is forwarded to ``func``:

    - ``env``: context manager to run ``func`` under. A fresh ``rio.Env()``
      is created when it is omitted or ``None``.
    - ``clear_cache``: when truthy, ``rio.Env(GDAL_CACHEMAX=0)`` is entered
      and immediately exited after ``func`` has returned, as a best-effort
      attempt to flush the GDAL cache.

    The wrapper returns whatever ``func`` returns.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Strip the decorator-only options up front: the wrapped function
        # is not responsible for any of the context management.
        clear_cache = kwargs.pop("clear_cache", False)
        env = kwargs.pop("env", None)
        active_env = rio.Env() if env is None else env

        with active_env:
            result = func(*args, **kwargs)

        if clear_cache:
            # attempt to force-clear the GDAL cache
            with rio.Env(GDAL_CACHEMAX=0):
                pass

        return result

    return wrapper
51+
52+
2253def dask_rasterize_wrapper (
2354 geom_array : np .ndarray ,
24- tile_array : np .ndarray ,
55+ x_offsets : np .ndarray ,
56+ y_offsets : np .ndarray ,
57+ x_sizes : np .ndarray ,
58+ y_sizes : np .ndarray ,
2559 offset_array : np .ndarray ,
2660 * ,
2761 fill : Any ,
62+ affine : Affine ,
2863 all_touched : bool ,
2964 merge_alg : MergeAlg ,
3065 dtype_ : np .dtype ,
3166 env : rio .Env | None = None ,
3267) -> np .ndarray :
33- tile = tile_array .item ()
3468 offset = offset_array .item ()
3569
3670 return rasterize_geometries (
3771 geom_array [:, 0 , 0 ].tolist (),
38- tile = tile ,
72+ affine = affine * affine .translation (x_offsets .item (), y_offsets .item ()),
73+ shape = (y_sizes .item (), x_sizes .item ()),
3974 offset = offset ,
4075 all_touched = all_touched ,
4176 merge_alg = merge_alg ,
@@ -45,44 +80,25 @@ def dask_rasterize_wrapper(
4580 )[np .newaxis , :, :]
4681
4782
@with_rio_env
def rasterize_geometries(
    geometries: Sequence[Any],
    *,
    dtype: np.dtype,
    shape: tuple[int, int],
    affine: Affine,
    offset: int,
    env: rio.Env | None = None,
    clear_cache: bool = False,
    **kwargs,
):
    """
    Burn ``geometries`` into a 2D array, labelling each one with an integer.

    The i-th geometry is paired with the value ``offset + i`` and handed to
    ``rasterio.features.rasterize``.

    Parameters
    ----------
    geometries
        Geometries to rasterize.
    dtype
        Accepted for the callers' benefit; this function does not forward it
        to rasterio.
    shape
        Shape of the output array, passed as ``out_shape``.
    affine
        Affine transform of the output grid, passed as ``transform``.
    offset
        Label assigned to the first geometry.
    env, clear_cache
        Consumed by the ``with_rio_env`` decorator before this body runs.
    **kwargs
        Forwarded unchanged to ``rasterio.features.rasterize``.

    Returns
    -------
    The array produced by rasterio, whose shape is asserted to equal ``shape``.
    """
    labels = range(offset, offset + len(geometries))
    burned = rasterize_rio(
        zip(geometries, labels, strict=True),
        out_shape=shape,
        transform=affine,
        **kwargs,
    )
    assert burned.shape == shape
    return burned
87103
88104
@@ -129,25 +145,30 @@ def rasterize(
129145 """
130146 if xdim not in obj .dims or ydim not in obj .dims :
131147 raise ValueError (f"Received { xdim = !r} , { ydim = !r} but obj.dims={ tuple (obj .dims )} " )
132- box = obj .odc .geobox
133- rasterize_kwargs = dict (all_touched = all_touched , merge_alg = merge_alg )
148+
149+ rasterize_kwargs = dict (
150+ all_touched = all_touched , merge_alg = merge_alg , affine = get_affine (obj , xdim = xdim , ydim = ydim ), env = env
151+ )
134152 # FIXME: box.crs == geometries.crs
135153 if is_in_memory (obj = obj , geometries = geometries ):
136154 geom_array = geometries .to_numpy ().squeeze (axis = 1 )
137155 rasterized = rasterize_geometries (
138156 geom_array .tolist (),
139- tile = box ,
157+ shape = ( obj . sizes [ ydim ], obj . sizes [ xdim ]) ,
140158 offset = 0 ,
141159 dtype = np .min_scalar_type (len (geometries )),
142160 fill = len (geometries ),
143- env = env ,
144161 ** rasterize_kwargs ,
145162 )
146163 else :
147164 from dask .array import from_array , map_blocks
148165
149- chunks , tiles_array , geom_array = prepare_for_dask (
150- obj , geometries , xdim = xdim , ydim = ydim , geoms_rechunk_size = geoms_rechunk_size
166+ map_blocks_args , chunks , geom_array = prepare_for_dask (
167+ obj ,
168+ geometries ,
169+ xdim = xdim ,
170+ ydim = ydim ,
171+ geoms_rechunk_size = geoms_rechunk_size ,
151172 )
152173 # DaskGeoDataFrame.len() computes!
153174 num_geoms = geom_array .size
@@ -159,10 +180,9 @@ def rasterize(
159180
160181 rasterized = map_blocks (
161182 dask_rasterize_wrapper ,
162- geom_array [:, np .newaxis , np .newaxis ],
163- tiles_array [np .newaxis , :, :],
183+ * map_blocks_args ,
164184 offsets [:, np .newaxis , np .newaxis ],
165- chunks = ((1 ,) * geom_array .numblocks [0 ], chunks [0 ], chunks [1 ]),
185+ chunks = ((1 ,) * geom_array .numblocks [0 ], chunks [YAXIS ], chunks [XAXIS ]),
166186 meta = np .array ([], dtype = dtype ),
167187 fill = 0 , # good identity value for both sum & replace.
168188 ** rasterize_kwargs ,
@@ -205,54 +225,39 @@ def replace_values(array: np.ndarray, to, *, from_=0) -> np.ndarray:
205225
def dask_mask_wrapper(
    geom_array: np.ndarray,
    x_offsets: np.ndarray,
    y_offsets: np.ndarray,
    x_sizes: np.ndarray,
    y_sizes: np.ndarray,
    *,
    affine: Affine,
    all_touched: bool,
    invert: bool,
    env: rio.Env | None = None,
) -> np.ndarray[Any, np.dtype[np.bool_]]:
    """
    Compute the geometry mask for a single block.

    Parameters
    ----------
    geom_array
        Block of geometries; only ``geom_array[:, 0, 0]`` is used.
    x_offsets, y_offsets
        Single-element arrays holding this block's offsets along x and y.
        They are used to translate ``affine``.
    x_sizes, y_sizes
        Single-element arrays holding this block's extent along x and y.
    affine
        Affine transform that the block offsets are applied to.
    all_touched, invert
        Forwarded to ``np_geometry_mask``.
    env
        Forwarded to ``np_geometry_mask``.

    Returns
    -------
    Boolean mask with a new leading axis of length 1.
    """
    # Shift the transform so that it describes this block only.
    block_affine = affine * affine.translation(x_offsets.item(), y_offsets.item())
    return np_geometry_mask(
        geom_array[:, 0, 0].tolist(),
        shape=(y_sizes.item(), x_sizes.item()),
        affine=block_affine,
        # Must be forwarded explicitly; otherwise the caller's setting is
        # silently ignored for dask-backed inputs.
        all_touched=all_touched,
        invert=invert,
        env=env,
    )[np.newaxis, :, :]
223245
224246
@with_rio_env
def np_geometry_mask(
    geometries: Sequence[Any],
    *,
    x_offset: int = 0,
    y_offset: int = 0,
    shape: tuple[int, int],
    affine: Affine,
    env: rio.Env | None = None,
    clear_cache: bool = False,
    **kwargs,
) -> np.ndarray[Any, np.dtype[np.bool_]]:
    """
    Build a boolean mask of ``geometries`` on a grid of the given shape.

    Parameters
    ----------
    geometries
        Geometries to mask.
    x_offset, y_offset
        Not used by this function. They default to 0 so that callers which
        only pass ``shape`` and ``affine`` do not fail with a ``TypeError``
        for missing keyword arguments.
    shape
        Shape of the output mask, passed as ``out_shape``.
    affine
        Affine transform of the output grid, passed as ``transform``.
    env, clear_cache
        Consumed by the ``with_rio_env`` decorator before this body runs.
    **kwargs
        Forwarded unchanged to ``rasterio.features.geometry_mask``.

    Returns
    -------
    The mask produced by rasterio, whose shape is asserted to equal ``shape``.
    """
    res = geometry_mask(geometries, out_shape=shape, transform=affine, **kwargs)
    assert res.shape == shape
    return res
257262
258263
@@ -298,23 +303,31 @@ def geometry_clip(
298303 invert = not invert # rioxarray clip convention -> rasterio geometry_mask convention
299304 if xdim not in obj .dims or ydim not in obj .dims :
300305 raise ValueError (f"Received { xdim = !r} , { ydim = !r} but obj.dims={ tuple (obj .dims )} " )
301- box = obj .odc .geobox
302- geometry_mask_kwargs = dict (all_touched = all_touched , invert = invert )
306+ geometry_mask_kwargs = dict (
307+ all_touched = all_touched , invert = invert , affine = get_affine (obj , xdim = xdim , ydim = ydim ), env = env
308+ )
303309
304310 if is_in_memory (obj = obj , geometries = geometries ):
305311 geom_array = geometries .to_numpy ().squeeze (axis = 1 )
306- mask = np_geometry_mask (geom_array .tolist (), tile = box , env = env , ** geometry_mask_kwargs )
312+ mask = np_geometry_mask (
313+ geom_array .tolist (),
314+ shape = (obj .sizes [ydim ], obj .sizes [xdim ]),
315+ ** geometry_mask_kwargs ,
316+ )
307317 else :
308318 from dask .array import map_blocks
309319
310- chunks , tiles_array , geom_array = prepare_for_dask (
311- obj , geometries , xdim = xdim , ydim = ydim , geoms_rechunk_size = geoms_rechunk_size
320+ map_blocks_args , chunks , geom_array = prepare_for_dask (
321+ obj ,
322+ geometries ,
323+ xdim = xdim ,
324+ ydim = ydim ,
325+ geoms_rechunk_size = geoms_rechunk_size ,
312326 )
313327 mask = map_blocks (
314328 dask_mask_wrapper ,
315- geom_array [:, np .newaxis , np .newaxis ],
316- tiles_array [np .newaxis , :, :],
317- chunks = ((1 ,) * geom_array .numblocks [0 ], chunks [0 ], chunks [1 ]),
329+ * map_blocks_args ,
330+ chunks = ((1 ,) * geom_array .numblocks [0 ], chunks [YAXIS ], chunks [XAXIS ]),
318331 meta = np .array ([], dtype = bool ),
319332 ** geometry_mask_kwargs ,
320333 )
0 commit comments