Skip to content

Commit 44a32e5

Browse files
authored
Merge pull request #119 from quantifyearth/mwd-cache-limits
Enforce cache limits on GDAL and MLX
2 parents 4886dce + 64fc872 commit 44a32e5

File tree

5 files changed

+32
-1
lines changed

5 files changed

+32
-1
lines changed

CHANGES.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
## v1.12.2 (26/1/2026)
1+
## v1.12.2 (27/1/2026)
22

33
### Added
44

55
* Added `yg.sum`, `yg.all`, `yg.any` to build layers from lists of layers.
6+
* Cache limits applied to GDAL and MLX which tend to assume that they are the only thing using memory. You can override these by adjusting `yg.constants.GDAL_CACHE_LIMIT` and `yg.constants.MLX_CACHE_LIMIT`.
67

78
## v1.12.1 (21/1/2026)
89

yirgacheffe/_backends/mlx.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22

33
from typing import Callable
44

5+
from osgeo import gdal
56
import numpy as np
67
import mlx.core as mx # type: ignore
78
import mlx.nn
89

910
from .enumeration import operators as op
1011
from .enumeration import dtype
12+
from .. import constants
1113

1214
array_t = mx.array
1315
float_t = mx.float32
@@ -53,6 +55,10 @@
5355
round_op = mx.round
5456
ceil_op = mx.ceil
5557

58+
def init() -> None:
59+
gdal.SetCacheMax(constants.GDAL_CACHE_LIMIT)
60+
mx.set_cache_limit(constants.MLX_CACHE_LIMIT)
61+
5662
def sum_op(a):
5763
# By default the type promotion rules for sum in MLX are not the same as with Numpy. E.g.,
5864
#

yirgacheffe/_backends/numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22

33
from typing import Callable
44

5+
from osgeo import gdal
56
import numpy as np
67
import torch
78

89
from .enumeration import operators as op
910
from .enumeration import dtype
11+
from .. import constants
1012

1113
array_t = np.ndarray
1214
float_t = np.float64
@@ -57,6 +59,9 @@
5759
round_op = np.round
5860
ceil_op = np.ceil
5961

62+
def init() -> None:
63+
gdal.SetCacheMax(constants.GDAL_CACHE_LIMIT)
64+
6065
def conv2d_op(data, weights):
6166
# torch wants to process dimensions of channels of width of height
6267
# Which is why both the data and weights get nested into two arrays here,

yirgacheffe/_operators/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,8 @@ def sum(self) -> float:
811811

812812
cse_cache = CSECacheTable(self, computation_window)
813813

814+
backend.init()
815+
814816
for yoffset in range(0, computation_window.ysize, self.ystep):
815817
cse_cache.reset_cache()
816818
step=self.ystep
@@ -837,6 +839,8 @@ def min(self) -> float:
837839

838840
cse_cache = CSECacheTable(self, computation_window)
839841

842+
backend.init()
843+
840844
for yoffset in range(0, computation_window.ysize, self.ystep):
841845
cse_cache.reset_cache()
842846
step=self.ystep
@@ -863,6 +867,8 @@ def max(self) -> float:
863867

864868
cse_cache = CSECacheTable(self, computation_window)
865869

870+
backend.init()
871+
866872
for yoffset in range(0, computation_window.ysize, self.ystep):
867873
cse_cache.reset_cache()
868874
step=self.ystep
@@ -898,6 +904,8 @@ def unique(self, return_counts:bool=False) -> np.ndarray | tuple[np.ndarray,np.n
898904

899905
cse_cache = CSECacheTable(self, computation_window)
900906

907+
backend.init()
908+
901909
for yoffset in range(0, computation_window.ysize, self.ystep):
902910
cse_cache.reset_cache()
903911
step=self.ystep
@@ -984,6 +992,8 @@ def save(self, destination_layer, and_sum=False, callback=None, band=1) -> float
984992

985993
cse_cache = CSECacheTable(self, computation_window)
986994

995+
backend.init()
996+
987997
for yoffset in range(0, computation_window.ysize, self.ystep):
988998

989999
cse_cache.reset_cache()
@@ -1024,6 +1034,8 @@ def _parallel_worker(
10241034
# the cache build once per worker
10251035
cse_cache = CSECacheTable(self, computation_window)
10261036

1037+
backend.init()
1038+
10271039
arr = np.ndarray((self.ystep, width), dtype=np_dtype, buffer=shared_mem.buf) # type: ignore[var-annotated]
10281040
projection = self.map_projection
10291041
# TODO: the `save` method does more sanity checking that parallel save!

yirgacheffe/constants.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,13 @@
33
YSTEP = 512
44
MINIMUM_CHUNKS_PER_THREAD = 1
55

6+
# Both GDAL and MLX assume that there is one instance running and it has the right to use
7+
# all the memory it can. In general Yirgacheffe's chunking and its own caching is what we should
8+
# be relying on, so we set some limits here. These are applied before each calculation, and so
9+
# in theory they can be tweaked if necessary on demand.
10+
GDAL_CACHE_LIMIT = 1 * 1024 * 1024 * 1024
11+
MLX_CACHE_LIMIT = 1 * 1024 * 1024 * 1024
12+
613
# I don't really want this here, but it's just too useful having it exposed
714
# This used to be a fixed string, but now it is at least programmatically generated
815
WGS_84_PROJECTION = pyproj.CRS.from_epsg(4326).to_wkt(version='WKT1_GDAL')

0 commit comments

Comments
 (0)