Skip to content

Commit eb1b2ab

Browse files
authored
Merge pull request #30 from quantifyearth/mwd-habitat-process
Simplify and improve performance of habitat processing
2 parents 6fc2c22 + 7df8b70 commit eb1b2ab

File tree

4 files changed

+89
-75
lines changed

4 files changed

+89
-75
lines changed

CHANGES.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
## v1.1.0 (4/11/2025)
2+
3+
### Added
4+
5+
* Implementation of point validation based on [Dahal et al](https://gmd.copernicus.org/articles/15/5093/2022/).
6+
7+
### Changed
8+
9+
* Performance improvements and simplification of habitat processing.
10+
* Store more analysis data from model validation.
11+
12+
## v1.0.1 (19/10/2025)
13+
14+
### Fixed
15+
16+
* Fixed GitHub Action for publishing to PyPI.
17+
18+
## v1.0.0 (19/10/2025)
19+
20+
### Added
21+
22+
* Initial release as a standalone package.

aoh/habitat_process.py

Lines changed: 61 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
import argparse
2-
import math
32
import os
43
import logging
5-
import shutil
6-
import tempfile
74
from functools import partial
85
from multiprocessing import Pool, cpu_count
96
from pathlib import Path
107
from typing import Optional, Set
118

129
import numpy as np
1310
import psutil
11+
import yirgacheffe as yg
1412
from osgeo import gdal # type: ignore
15-
from yirgacheffe.layers import RasterLayer # type: ignore
1613

1714
logger = logging.getLogger(__name__)
1815
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s')
@@ -24,7 +21,7 @@ def _enumerate_subset(
2421
offset: int,
2522
) -> Set[int]:
2623
gdal.SetCacheMax(1 * 1024 * 1024 * 1024)
27-
with RasterLayer.layer_from_file(habitat_path) as habitat_map:
24+
with yg.read_raster(habitat_path) as habitat_map:
2825
blocksize = min(BLOCKSIZE, habitat_map.window.ysize - offset)
2926
data = habitat_map.read_array(0, offset, habitat_map.window.xsize, blocksize)
3027
values = np.unique(data)
@@ -36,7 +33,7 @@ def enumerate_terrain_types(
3633
habitat_path: Path
3734
) -> Set[int]:
3835
gdal.SetCacheMax(1 * 1024 * 1024 * 1024)
39-
with RasterLayer.layer_from_file(habitat_path) as habitat_map:
36+
with yg.read_raster(habitat_path) as habitat_map:
4037
ysize = habitat_map.window.ysize
4138
blocks = range(0, ysize, BLOCKSIZE)
4239
logger.info("Enumerating habitat classes in raster...")
@@ -51,44 +48,65 @@ def enumerate_terrain_types(
5148
pass
5249
return superset
5350

54-
def _make_single_type_map(
51+
class VsimemFile:
52+
def __init__(self, path):
53+
self.path = path
54+
55+
def __enter__(self):
56+
return self.path
57+
58+
def __exit__(self, *args):
59+
try:
60+
gdal.Unlink(self.path)
61+
except RuntimeError:
62+
pass
63+
64+
def make_single_type_map(
5565
habitat_path: Path,
5666
pixel_scale: Optional[float],
5767
target_projection: Optional[str],
5868
output_directory_path: Path,
69+
max_threads: int,
5970
habitat_value: int | float,
6071
) -> None:
6172
logger.info("Building layer for %s...", habitat_value)
6273

63-
# We could do this via yirgacheffe if it wasn't for the need to
64-
# both rescale and reproject. So we do the initial filtering
65-
# in that, but then bounce it to a temporary file for the
66-
# warping
67-
with tempfile.TemporaryDirectory() as tmpdir:
68-
with RasterLayer.layer_from_file(habitat_path) as habitat_map:
69-
logger.info("Filtering for %s...", habitat_value)
70-
calc = habitat_map == habitat_value
71-
with RasterLayer.empty_raster_layer_like(habitat_map, datatype=gdal.GDT_Byte) as filtered_map:
72-
calc.save(filtered_map)
74+
mem_stats = psutil.virtual_memory()
75+
available_mem = mem_stats.available
76+
gdal.SetCacheMax(available_mem)
77+
gdal.SetConfigOption('GDAL_NUM_THREADS', str(max_threads))
7378

74-
filename = f"lcc_{habitat_value}.tif"
75-
tempname = os.path.join(tmpdir, filename)
79+
with yg.read_raster(habitat_path) as habitat_map:
80+
logger.info("Filtering for %s...", habitat_value)
7681

77-
dataset = filtered_map._dataset # pylint: disable=W0212
82+
# We use the GDAL in memory file system for all this
83+
with VsimemFile(f"/vsimem/filtered_{habitat_value}.tif") as filter_map_path:
84+
filtered_map = habitat_map == habitat_value
85+
filtered_map.to_geotiff(filter_map_path, parallelism=max_threads)
86+
87+
with VsimemFile(f"/vsimem/warped_{habitat_value}.tif") as warped_map_path:
7888
logger.info("Projecting %s...", habitat_value)
79-
gdal.Warp(tempname, dataset, options=gdal.WarpOptions(
80-
creationOptions=['COMPRESS=LZW', 'NUM_THREADS=16'],
81-
multithread=True,
82-
dstSRS=target_projection,
83-
outputType=gdal.GDT_Float32,
84-
xRes=pixel_scale,
85-
yRes=((0.0 - pixel_scale) if pixel_scale else pixel_scale),
86-
resampleAlg="average",
87-
workingType=gdal.GDT_Float32
88-
))
89-
90-
logger.info("Saving %s...", habitat_value)
91-
shutil.move(tempname, output_directory_path / filename)
89+
gdal.Warp(
90+
warped_map_path,
91+
filter_map_path,
92+
options=gdal.WarpOptions(
93+
creationOptions=[],
94+
multithread=True,
95+
dstSRS=target_projection,
96+
outputType=gdal.GDT_Float32,
97+
xRes=pixel_scale,
98+
yRes=((0.0 - pixel_scale) if pixel_scale else pixel_scale),
99+
resampleAlg="average",
100+
warpOptions=[f'NUM_THREADS={max_threads}'],
101+
warpMemoryLimit=available_mem,
102+
workingType=gdal.GDT_Float32
103+
)
104+
)
105+
106+
logger.info("Saving %s...", habitat_value)
107+
filename = f"lcc_{habitat_value}.tif"
108+
with yg.read_raster(warped_map_path) as result:
109+
result.to_geotiff(output_directory_path / filename)
92110

93111
def habitat_process(
94112
habitat_path: Path,
@@ -99,48 +117,20 @@ def habitat_process(
99117
) -> None:
100118
os.makedirs(output_directory_path, exist_ok=True)
101119

102-
with RasterLayer.layer_from_file(habitat_path) as habitat_map:
103-
# The processing stage uses GDAL warp directly, with no chunking, so we should
104-
# take a guess at how much memory we need based on the dimensions of the base map
105-
pixels = habitat_map.window.xsize * habitat_map.window.ysize
106-
# I really tried not to write this statement and use introspection, but nothing
107-
# I tried gave a sensible answer. Normally I'd be more paranoid due to numpy bloat,
108-
# but we're calling GDALwarp and passing it filenames, so everything should be done
109-
# in the C++ world of GDAL, so I have more confidence that we won't see the usual
110-
# 4x plus memory bloat of loading raster data into the python world.
111-
match habitat_map.datatype:
112-
case gdal.GDT_Byte | gdal.GDT_Int8:
113-
pixel_size = 1
114-
case gdal.GDT_CInt16 | gdal.GDT_Int16 | gdal.GDT_UInt16:
115-
pixel_size = 2
116-
case gdal.GDT_CFloat32 | gdal.GDT_CInt32 | gdal.GDT_Float32 | gdal.GDT_Int32:
117-
pixel_size = 4
118-
case _:
119-
pixel_size = 8
120-
estimated_memory = pixel_size * pixels
121-
122-
mem_stats = psutil.virtual_memory()
123-
max_copies = math.floor((mem_stats.available * 0.8) / estimated_memory)
124-
if max_copies == 0:
125-
logger.warning("Low memory")
126-
max_copies = 1
127-
process_count = min(max_copies, process_count)
128-
logger.info("Estimating we can run %s concurrent tasks", process_count)
129-
130120
# We need to know how many terrains there are. We could get this from the crosswalk
131121
# table, but we can also work out the unique values ourselves. In practice this is
132122
# worth the effort, otherwise we generate a lot of empty maps potentially.
133123
habitats = enumerate_terrain_types(habitat_path)
134124

135-
if max_copies > 1:
136-
with Pool(processes=process_count) as pool:
137-
pool.map(
138-
partial(_make_single_type_map, habitat_path, pixel_scale, target_projection, output_directory_path),
139-
habitats
140-
)
141-
else:
142-
for habitat in habitats:
143-
_make_single_type_map(habitat_path, pixel_scale, target_projection, output_directory_path, habitat)
125+
for habitat in habitats:
126+
make_single_type_map(
127+
habitat_path,
128+
pixel_scale,
129+
target_projection,
130+
output_directory_path,
131+
process_count,
132+
habitat,
133+
)
144134

145135
def main() -> None:
146136
parser = argparse.ArgumentParser(description="Downsample habitat map to raster per terrain type.")
@@ -177,7 +167,7 @@ def main() -> None:
177167
"-j",
178168
type=int,
179169
required=False,
180-
default=round(cpu_count() / 2),
170+
default=cpu_count(),
181171
dest="processes_count",
182172
help="Optional number of concurrent threads to use."
183173
)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ dependencies = [
2929
"psutil",
3030
"pyproj>=3.4,<4.0",
3131
"scikit-image>=0.20,<1.0",
32-
"yirgacheffe>=1.9.1,<2.0",
32+
"yirgacheffe>=1.10.2,<2.0",
3333
"zenodo_search",
3434
"pandas>=2.0,<3.0",
3535
"gdal[numpy]>=3.8,<3.12",

tests/test_habitat_process.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import yirgacheffe as yg
77
from osgeo import gdal # type: ignore
88

9-
from aoh.habitat_process import enumerate_terrain_types, _make_single_type_map
9+
from aoh.habitat_process import enumerate_terrain_types, make_single_type_map
1010

1111
def generate_habitat_map(
1212
output_path: Path,
@@ -47,11 +47,12 @@ def test_simple_make_single_map() -> None:
4747
generate_habitat_map(habitat_path, (20, 10), options)
4848
assert habitat_path.exists()
4949

50-
_make_single_type_map(
50+
make_single_type_map(
5151
habitat_path,
5252
None,
5353
None,
5454
tmp,
55+
1,
5556
100,
5657
)
5758
expected_result_path = tmp / "lcc_100.tif"
@@ -76,11 +77,12 @@ def test_rescale_make_single_map() -> None:
7677
generate_habitat_map(habitat_path, (20, 10), options)
7778
assert habitat_path.exists()
7879

79-
_make_single_type_map(
80+
make_single_type_map(
8081
habitat_path,
8182
180.0 / 5.0, # Scale down by half
8283
None,
8384
tmp,
85+
1,
8486
100,
8587
)
8688
expected_result_path = tmp / "lcc_100.tif"

0 commit comments

Comments
 (0)