Skip to content

Commit 570d17a

Browse files
authored
Merge pull request #125 from quantifyearth/mwd-import-speed
Load heavier imports lazily
2 parents 228ee2c + 7b3c6ab commit 570d17a

File tree

5 files changed

+25
-9
lines changed

5 files changed

+25
-9
lines changed

CHANGES.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
## v1.12.5 (18/02/2026)
2+
3+
### Changed
4+
5+
* Reduce initial import time.
6+
17
## v1.12.4 (12/02/2026)
28

39
### Added

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
66

77
[project]
88
name = "yirgacheffe"
9-
version = "1.12.4"
9+
version = "1.12.5"
1010
description = "Abstraction of gdal datasets for doing basic math operations"
1111
readme = "README.md"
1212
authors = [{ name = "Michael Dales", email = "mwd24@cam.ac.uk" }]
@@ -31,6 +31,7 @@ dependencies = [
3131
"tomli",
3232
"h3",
3333
"pyproj",
34+
"lazy_loader"
3435
]
3536
requires-python = ">=3.10"
3637

yirgacheffe/_backends/numpy.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from osgeo import gdal
66
import numpy as np
7-
import torch
87

98
from .enumeration import operators as op
109
from .enumeration import dtype
@@ -63,15 +62,18 @@ def init() -> None:
6362
gdal.SetCacheMax(constants.GDAL_CACHE_LIMIT)
6463

6564
def conv2d_op(data, weights):
65+
# Pytorch is very slow to import, and currently a niche use case
66+
# for just this operator. As such we defer import to a local one
67+
# to improve initial load times.
68+
import torch # pylint: disable=C0415
69+
6670
# torch wants to process dimensions of channels of width of height
6771
# Which is why both the data and weights get nested into two arrays here,
6872
# and then we have to unpack it from that nesting.
69-
7073
preped_weights = np.array([[weights]])
7174
conv = torch.nn.Conv2d(1, 1, weights.shape, bias=False)
7275
conv.weight = torch.nn.Parameter(torch.from_numpy(preped_weights))
7376
preped_data = torch.from_numpy(np.array([[data]]))
74-
7577
res = conv(preped_data)
7678
return res.detach().numpy()[0][0]
7779

yirgacheffe/_datatypes/mapprojection.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
import math
33
from functools import lru_cache
44

5-
import pyproj
6-
from pyproj import CRS
5+
import lazy_loader as lazy # type: ignore
76

87
from .pixelscale import PixelScale
98

9+
# Pyproj is relatively slow to import, which was adding reasonably to
10+
# the import time for yirgacheffe, so make it lazy
11+
pyproj = lazy.load('pyproj')
12+
1013
# As per https://xkcd.com/2170/, we need to stop caring about floating point
1114
# accuracy at some point as it becomes problematic.
1215
# The value here is 1 meter, given that geo data that we've been working with
@@ -27,7 +30,7 @@
2730
# that makes 1600 rasters take 72 seconds without this fix, and 2 seconds with it.
2831
@lru_cache(maxsize=128)
2932
def _get_projection_string(provided_name: str) -> str:
30-
crs = CRS.from_string(provided_name)
33+
crs = pyproj.CRS.from_string(provided_name)
3134
epsg = crs.to_epsg()
3235
if epsg is not None:
3336
return f"EPSG:{epsg}"
@@ -64,7 +67,7 @@ class MapProjection:
6467

6568
def __init__(self, projection_string: str, xstep: float, ystep: float) -> None:
6669
try:
67-
self.crs = CRS.from_string(projection_string)
70+
self.crs = pyproj.CRS.from_string(projection_string)
6871
except pyproj.exceptions.CRSError as exc:
6972
raise ValueError(f"Invalid projection: {projection_string}") from exc
7073
self.xstep = xstep

yirgacheffe/layers/h3layer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,18 @@
33
from math import ceil, floor
44
from typing import Any
55

6-
import h3
6+
import lazy_loader as lazy # type: ignore
77
import numpy as np
88

99
from .._datatypes import Area, MapProjection, Window
1010
from .base import YirgacheffeLayer
1111
from .._backends import backend
1212
from .._backends.enumeration import dtype as DataType
1313

14+
# H3 is relatively slow to import, which was adding reasonably to
15+
# the import time for yirgacheffe, so make it lazy
16+
h3 = lazy.load('h3')
17+
1418
class H3CellLayer(YirgacheffeLayer):
1519

1620
def __init__(self, cell_id: str, projection: MapProjection):

0 commit comments

Comments
 (0)