Skip to content

Commit eeacab2

Browse files
authored
Merge pull request #194 from siapy/fix
Fix
2 parents a6c93cc + b2f9226 commit eeacab2

19 files changed

+763
-461
lines changed

scripts/install-dev.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ pdm install
1111

1212
# Install pre-commit
1313
pdm run pre-commit uninstall
14-
pdm run pre-commit install --hook-type commit-msg
14+
pdm run pre-commit install # --hook-type commit-msg

siapy/entities/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
from .images import SpectralImage
22
from .imagesets import SpectralImageSet
33
from .pixels import Pixels
4-
from .shapefiles import Shapefile
5-
from .shapes import Shape
4+
from .shapes import Shapefile
65
from .signatures import Signatures
76

87
__all__ = [
98
"SpectralImage",
109
"SpectralImageSet",
1110
"Pixels",
1211
"Signatures",
13-
"Shape",
1412
"Shapefile",
1513
]

siapy/entities/images/interfaces.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from abc import ABC, abstractmethod
22
from pathlib import Path
3-
from typing import Any
3+
from typing import TYPE_CHECKING, Any
44

55
import numpy as np
66
from PIL import Image
77

8+
if TYPE_CHECKING:
9+
from siapy.core.types import XarrayType
10+
811
__all__ = [
912
"ImageBase",
1013
]
@@ -58,3 +61,7 @@ def to_display(self, equalize: bool = True) -> Image.Image:
5861
@abstractmethod
5962
def to_numpy(self, nan_value: float | None = None) -> np.ndarray:
6063
pass
64+
65+
@abstractmethod
66+
def to_xarray(self) -> "XarrayType":
67+
pass

siapy/entities/images/rasterio_lib.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def open(cls, filepath: str | Path) -> "RasterioLibImage":
3535
raise InvalidInputError({"filepath": str(filepath)}, f"Failed to open raster file: {e}") from e
3636

3737
if isinstance(raster, list):
38-
raise InvalidInputError({"file_type": type(raster).__name__}, "Expected DataArray, got Dataset")
38+
raise InvalidInputError({"file_type": type(raster).__name__}, "Expected DataArray or Dataset, got list")
3939

4040
return cls(raster)
4141

@@ -49,44 +49,41 @@ def filepath(self) -> Path:
4949

5050
@property
5151
def metadata(self) -> dict[str, Any]:
52-
return dict(self.file.attrs)
52+
return self.file.attrs
5353

5454
@property
5555
def shape(self) -> tuple[int, int, int]:
5656
# rioxarray uses (band, y, x) ordering
5757
return (self.file.y.size, self.file.x.size, self.file.band.size)
5858

59+
@property
60+
def rows(self) -> int:
61+
return self.file.y.size
62+
63+
@property
64+
def cols(self) -> int:
65+
return self.file.x.size
66+
5967
@property
6068
def bands(self) -> int:
6169
return self.file.band.size
6270

6371
@property
6472
def default_bands(self) -> list[int]:
6573
# Most common RGB band combination for satellite imagery
66-
if self.bands >= 3:
67-
return [0, 1, 2]
68-
return list(range(min(3, self.bands)))
74+
return list(range(1, min(3, self.bands) + 1))
6975

7076
@property
7177
def wavelengths(self) -> list[float]:
72-
# Try to get wavelengths from band attributes
73-
wavelengths = []
74-
for band_idx in range(self.bands):
75-
band_data = self.file.sel(band=band_idx + 1)
76-
wave = band_data.attrs.get("wavelength")
77-
if wave:
78-
wavelengths.append(float(wave))
79-
else:
80-
wavelengths.append(float(band_idx + 1))
81-
return wavelengths
78+
return self.file.band.values
8279

8380
@property
8481
def camera_id(self) -> str:
82+
# Todo: camera_id is not a standard metadata field, should be updated
8583
return self.metadata.get("camera_id", "")
8684

8785
def to_display(self, equalize: bool = True) -> Image.Image:
88-
selected_bands = [i + 1 for i in self.default_bands] # Adjust for 1-indexed bands
89-
bands_data = self.file.sel(band=selected_bands)
86+
bands_data = self.file.sel(band=self.default_bands)
9087
image_3ch = bands_data.transpose("y", "x", "band").values
9188
image_3ch_clean = np.nan_to_num(np.asarray(image_3ch))
9289
min_val = np.nanmin(image_3ch_clean)
@@ -100,7 +97,10 @@ def to_display(self, equalize: bool = True) -> Image.Image:
10097
return image
10198

10299
def to_numpy(self, nan_value: float | None = None) -> np.ndarray:
103-
image = np.moveaxis(np.asarray(self.file.values), 0, -1)
100+
image = np.asarray(self.file.transpose("y", "x", "band").values)
104101
if nan_value is not None:
105102
image = np.nan_to_num(image, nan=nan_value)
106103
return image
104+
105+
def to_xarray(self) -> "XarrayType":
106+
return self.file

siapy/entities/images/spectral_lib.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55

66
import numpy as np
77
import spectral as sp
8+
import xarray as xr
89
from PIL import Image, ImageOps
910

1011
from siapy.core.exceptions import InvalidFilepathError, InvalidInputError
1112

1213
from .interfaces import ImageBase
1314

1415
if TYPE_CHECKING:
15-
from siapy.core.types import SpectralLibType
16+
from siapy.core.types import SpectralLibType, XarrayType
1617

1718
__all__ = [
1819
"SpectralLibImage",
@@ -117,6 +118,20 @@ def _remove_nan(self, image: np.ndarray, nan_value: float = 0.0) -> np.ndarray:
117118
image[~image_mask] = nan_value
118119
return image
119120

121+
def to_xarray(self) -> "XarrayType":
122+
data = self._file[:, :, :]
123+
xarray = xr.DataArray(
124+
data,
125+
dims=["y", "x", "band"],
126+
coords={
127+
"y": np.arange(self.rows),
128+
"x": np.arange(self.cols),
129+
"band": self.wavelengths,
130+
},
131+
attrs=self._file.metadata,
132+
)
133+
return xarray
134+
120135

121136
def _parse_description(description: str) -> dict[str, Any]:
122137
def _parse():

siapy/entities/images/spimage.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55
import numpy as np
66
from PIL import Image
77

8-
from ..shapes import GeometricShapes, Shape
8+
from ..shapes import GeometricShapes, ShapeBase
99
from ..signatures import Signatures
1010
from .interfaces import ImageBase
11+
from .rasterio_lib import RasterioLibImage
1112
from .spectral_lib import SpectralLibImage
1213

1314
if TYPE_CHECKING:
15+
from siapy.core.types import XarrayType
16+
1417
from ..pixels import Pixels
1518

1619

@@ -26,7 +29,7 @@ class SpectralImage(Generic[T]):
2629
def __init__(
2730
self,
2831
image: T,
29-
geometric_shapes: list["Shape"] | None = None,
32+
geometric_shapes: list["ShapeBase"] | None = None,
3033
):
3134
self._image = image
3235
self._geometric_shapes = GeometricShapes(self, geometric_shapes)
@@ -52,10 +55,10 @@ def spy_open(
5255
image = SpectralLibImage.open(header_path=header_path, image_path=image_path)
5356
return SpectralImage(image)
5457

55-
# @classmethod
56-
# def rasterio_open(cls, filepath: str | Path) -> "SpectralImage":
57-
# image = RasterioLib.open(filepath)
58-
# return cls(image)
58+
@classmethod
59+
def rasterio_open(cls, filepath: str | Path) -> "SpectralImage[RasterioLibImage]":
60+
image = RasterioLibImage.open(filepath)
61+
return SpectralImage(image)
5962

6063
@property
6164
def image(self) -> T:
@@ -99,6 +102,9 @@ def to_display(self, equalize: bool = True) -> Image.Image:
99102
def to_numpy(self, nan_value: float | None = None) -> np.ndarray:
100103
return self.image.to_numpy(nan_value)
101104

105+
def to_xarray(self) -> "XarrayType":
106+
return self.image.to_xarray()
107+
102108
def to_signatures(self, pixels: "Pixels") -> Signatures:
103109
image_arr = self.to_numpy()
104110
signatures = Signatures.from_array_and_pixels(image_arr, pixels)

0 commit comments

Comments
 (0)