11import argparse
2- import math
32import os
43import logging
5- import shutil
6- import tempfile
74from functools import partial
85from multiprocessing import Pool , cpu_count
96from pathlib import Path
107from typing import Optional , Set
118
129import numpy as np
1310import psutil
11+ import yirgacheffe as yg
1412from osgeo import gdal # type: ignore
15- from yirgacheffe .layers import RasterLayer # type: ignore
1613
1714logger = logging .getLogger (__name__ )
1815logging .basicConfig (level = logging .INFO , format = '%(asctime)s %(levelname)-8s %(message)s' )
@@ -24,7 +21,7 @@ def _enumerate_subset(
2421 offset : int ,
2522) -> Set [int ]:
2623 gdal .SetCacheMax (1 * 1024 * 1024 * 1024 )
27- with RasterLayer . layer_from_file (habitat_path ) as habitat_map :
24+ with yg . read_raster (habitat_path ) as habitat_map :
2825 blocksize = min (BLOCKSIZE , habitat_map .window .ysize - offset )
2926 data = habitat_map .read_array (0 , offset , habitat_map .window .xsize , blocksize )
3027 values = np .unique (data )
@@ -36,7 +33,7 @@ def enumerate_terrain_types(
3633 habitat_path : Path
3734) -> Set [int ]:
3835 gdal .SetCacheMax (1 * 1024 * 1024 * 1024 )
39- with RasterLayer . layer_from_file (habitat_path ) as habitat_map :
36+ with yg . read_raster (habitat_path ) as habitat_map :
4037 ysize = habitat_map .window .ysize
4138 blocks = range (0 , ysize , BLOCKSIZE )
4239 logger .info ("Enumerating habitat classes in raster..." )
@@ -51,44 +48,65 @@ def enumerate_terrain_types(
5148 pass
5249 return superset
5350
54- def _make_single_type_map (
class VsimemFile:
    """Context manager that unlinks a GDAL virtual file when the block ends.

    Entering the context simply hands back the path it was constructed
    with; nothing is created here. On exit ``gdal.Unlink`` is called on
    that path so a temporary raster (for example one under ``/vsimem/``)
    does not outlive the ``with`` block.
    """

    def __init__(self, path: str) -> None:
        # Path passed to ``gdal.Unlink`` on exit.
        self.path = path

    def __enter__(self) -> str:
        return self.path

    def __exit__(self, *args) -> None:
        # Clean-up is best effort: the file may never have been written if
        # the body failed early, so a failed unlink must not mask the
        # original error. Record it rather than dropping it silently so a
        # leaked in-memory file can still be diagnosed.
        try:
            gdal.Unlink(self.path)
        except RuntimeError as exc:
            logger.debug("Could not unlink %s: %s", self.path, exc)
63+
def make_single_type_map(
    habitat_path: Path,
    pixel_scale: Optional[float],
    target_projection: Optional[str],
    output_directory_path: Path,
    max_threads: int,
    habitat_value: int | float,
) -> None:
    """Build the warped raster for a single habitat class.

    The habitat raster is compared against ``habitat_value`` and the result
    is written to a temporary in-memory GeoTIFF. That file is then warped
    with GDAL (``average`` resampling, Float32 output) to the requested
    projection and pixel scale, and the warped raster is saved as
    ``lcc_<habitat_value>.tif`` in ``output_directory_path``.

    Note that this sets the GDAL block cache size and the
    ``GDAL_NUM_THREADS`` config option as a side effect.

    Args:
        habitat_path: Raster holding the habitat class values.
        pixel_scale: Target resolution passed to the warp as ``xRes`` and
            (negated) as ``yRes``; a falsy value is passed through as is.
        target_projection: Passed to the warp as ``dstSRS``.
        output_directory_path: Directory that receives the final GeoTIFF.
        max_threads: Thread count given to GDAL and to the filter write.
        habitat_value: Habitat class value to extract.

    Raises:
        RuntimeError: If ``gdal.Warp`` reports failure by returning ``None``.
    """
    logger.info("Building layer for %s...", habitat_value)

    # NOTE(review): the same "available" figure is given to both the GDAL
    # block cache and the warp working buffer, and the intermediate rasters
    # also live in memory - confirm this budget is intended.
    available_mem = psutil.virtual_memory().available
    gdal.SetCacheMax(available_mem)
    gdal.SetConfigOption('GDAL_NUM_THREADS', str(max_threads))

    with yg.read_raster(habitat_path) as habitat_map:
        logger.info("Filtering for %s...", habitat_value)

        # We use the GDAL in memory file system for all this
        with VsimemFile(f"/vsimem/filtered_{habitat_value}.tif") as filter_map_path:
            filtered_map = habitat_map == habitat_value
            filtered_map.to_geotiff(filter_map_path, parallelism=max_threads)

            with VsimemFile(f"/vsimem/warped_{habitat_value}.tif") as warped_map_path:
                logger.info("Projecting %s...", habitat_value)
                warped = gdal.Warp(
                    warped_map_path,
                    filter_map_path,
                    options=gdal.WarpOptions(
                        creationOptions=[],
                        multithread=True,
                        dstSRS=target_projection,
                        outputType=gdal.GDT_Float32,
                        xRes=pixel_scale,
                        yRes=((0.0 - pixel_scale) if pixel_scale else pixel_scale),
                        resampleAlg="average",
                        warpOptions=[f'NUM_THREADS={max_threads}'],
                        warpMemoryLimit=available_mem,
                        workingType=gdal.GDT_Float32
                    )
                )
                # Without GDAL exceptions enabled a failed warp returns None;
                # fail here rather than with a confusing error when the
                # missing output is read back below.
                if warped is None:
                    raise RuntimeError(f"gdal.Warp failed for habitat value {habitat_value}")
                # Drop the dataset handle so the warped file is flushed and
                # closed before it is reopened.
                del warped

                logger.info("Saving %s...", habitat_value)
                filename = f"lcc_{habitat_value}.tif"
                with yg.read_raster(warped_map_path) as result:
                    result.to_geotiff(output_directory_path / filename)
92110
93111def habitat_process (
94112 habitat_path : Path ,
@@ -99,48 +117,20 @@ def habitat_process(
99117) -> None :
100118 os .makedirs (output_directory_path , exist_ok = True )
101119
102- with RasterLayer .layer_from_file (habitat_path ) as habitat_map :
103- # The processing stage uses GDAL warp directly, with no chunking, so we should
104- # take a guess at how much memory we need based on the dimensions of the base map
105- pixels = habitat_map .window .xsize * habitat_map .window .ysize
106- # I really tried not to write this statement and use introspection, but nothing
107- # I tried gave a sensible answer. Normally I'd be more paranoid due to numpy bloat,
108- # but we're calling GDALwarp and passing it filenames, so everything should be done
109- # in the C++ world of GDAL, so I have more confidence that we won't see the usual
110- # 4x plus memory bloat of loading raster data into the python world.
111- match habitat_map .datatype :
112- case gdal .GDT_Byte | gdal .GDT_Int8 :
113- pixel_size = 1
114- case gdal .GDT_CInt16 | gdal .GDT_Int16 | gdal .GDT_UInt16 :
115- pixel_size = 2
116- case gdal .GDT_CFloat32 | gdal .GDT_CInt32 | gdal .GDT_Float32 | gdal .GDT_Int32 :
117- pixel_size = 4
118- case _:
119- pixel_size = 8
120- estimated_memory = pixel_size * pixels
121-
122- mem_stats = psutil .virtual_memory ()
123- max_copies = math .floor ((mem_stats .available * 0.8 ) / estimated_memory )
124- if max_copies == 0 :
125- logger .warning ("Low memory" )
126- max_copies = 1
127- process_count = min (max_copies , process_count )
128- logger .info ("Estimating we can run %s concurrent tasks" , process_count )
129-
130120 # We need to know how many terrains there are. We could get this from the crosswalk
131121 # table, but we can also work out the unique values ourselves. In practice this is
132122 # worth the effort, otherwise we generate a lot of empty maps potentially.
133123 habitats = enumerate_terrain_types (habitat_path )
134124
135- if max_copies > 1 :
136- with Pool ( processes = process_count ) as pool :
137- pool . map (
138- partial ( _make_single_type_map , habitat_path , pixel_scale , target_projection , output_directory_path ) ,
139- habitats
140- )
141- else :
142- for habitat in habitats :
143- _make_single_type_map ( habitat_path , pixel_scale , target_projection , output_directory_path , habitat )
125+ for habitat in habitats :
126+ make_single_type_map (
127+ habitat_path ,
128+ pixel_scale ,
129+ target_projection ,
130+ output_directory_path ,
131+ process_count ,
132+ habitat ,
133+ )
144134
145135def main () -> None :
146136 parser = argparse .ArgumentParser (description = "Downsample habitat map to raster per terrain type." )
@@ -177,7 +167,7 @@ def main() -> None:
177167 "-j" ,
178168 type = int ,
179169 required = False ,
180- default = round ( cpu_count () / 2 ),
170+ default = cpu_count (),
181171 dest = "processes_count" ,
182172 help = "Optional number of concurrent threads to use."
183173 )
0 commit comments