|
| 1 | +from pathlib import Path |
| 2 | +from typing import TYPE_CHECKING, Any |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import xarray as xr |
| 6 | +from numpy.typing import NDArray |
| 7 | +from PIL import Image |
| 8 | + |
| 9 | +from siapy.core.exceptions import InvalidInputError |
| 10 | +from siapy.entities.images.interfaces import ImageBase |
| 11 | + |
| 12 | +if TYPE_CHECKING: |
| 13 | + from siapy.core.types import XarrayType |
| 14 | + |
| 15 | + |
| 16 | +class MockImage(ImageBase): |
| 17 | + def __init__( |
| 18 | + self, |
| 19 | + array: NDArray[np.floating[Any]], |
| 20 | + ) -> None: |
| 21 | + if len(array.shape) != 3: |
| 22 | + raise InvalidInputError( |
| 23 | + input_value=array.shape, |
| 24 | + message="Input array must be 3-dimensional (height, width, bands)", |
| 25 | + ) |
| 26 | + |
| 27 | + self._array = array.astype(np.float32) |
| 28 | + |
| 29 | + @classmethod |
| 30 | + def open(cls, array: NDArray[np.floating[Any]]) -> "MockImage": |
| 31 | + return cls(array=array) |
| 32 | + |
| 33 | + @property |
| 34 | + def filepath(self) -> Path: |
| 35 | + return Path() |
| 36 | + |
| 37 | + @property |
| 38 | + def metadata(self) -> dict[str, Any]: |
| 39 | + return {} |
| 40 | + |
| 41 | + @property |
| 42 | + def shape(self) -> tuple[int, int, int]: |
| 43 | + x = self._array.shape[1] |
| 44 | + y = self._array.shape[0] |
| 45 | + bands = self._array.shape[2] |
| 46 | + return (y, x, bands) |
| 47 | + |
| 48 | + @property |
| 49 | + def bands(self) -> int: |
| 50 | + return self._array.shape[2] |
| 51 | + |
| 52 | + @property |
| 53 | + def default_bands(self) -> list[int]: |
| 54 | + if self.bands >= 3: |
| 55 | + return [0, 1, 2] |
| 56 | + return list(range(min(3, self.bands))) |
| 57 | + |
| 58 | + @property |
| 59 | + def wavelengths(self) -> list[float]: |
| 60 | + return list(range(self.bands)) |
| 61 | + |
| 62 | + @property |
| 63 | + def camera_id(self) -> str: |
| 64 | + return "" |
| 65 | + |
| 66 | + def to_display(self, equalize: bool = True) -> Image.Image: |
| 67 | + if self.bands >= 3: |
| 68 | + display_bands = self._array[:, :, self.default_bands] |
| 69 | + else: |
| 70 | + display_bands = np.stack([self._array[:, :, 0]] * 3, axis=2) |
| 71 | + |
| 72 | + if equalize: |
| 73 | + for i in range(display_bands.shape[2]): |
| 74 | + band = display_bands[:, :, i] |
| 75 | + non_nan = ~np.isnan(band) |
| 76 | + if np.any(non_nan): |
| 77 | + min_val = np.nanmin(band) |
| 78 | + max_val = np.nanmax(band) |
| 79 | + if max_val > min_val: |
| 80 | + band = (band - min_val) / (max_val - min_val) * 255 |
| 81 | + display_bands[:, :, i] = band |
| 82 | + |
| 83 | + display_array = np.nan_to_num(display_bands).astype(np.uint8) |
| 84 | + return Image.fromarray(display_array) |
| 85 | + |
| 86 | + def to_numpy(self, nan_value: float | None = None) -> NDArray[np.floating[Any]]: |
| 87 | + if nan_value is not None: |
| 88 | + return np.nan_to_num(self._array, nan=nan_value) |
| 89 | + return self._array.copy() |
| 90 | + |
| 91 | + def to_xarray(self) -> "XarrayType": |
| 92 | + return xr.DataArray( |
| 93 | + self._array, |
| 94 | + dims=["y", "x", "band"], |
| 95 | + coords={ |
| 96 | + "band": self.wavelengths, |
| 97 | + "x": np.arange(self.shape[1]), |
| 98 | + "y": np.arange(self.shape[0]), |
| 99 | + }, |
| 100 | + attrs={ |
| 101 | + "camera_id": self.camera_id, |
| 102 | + }, |
| 103 | + ) |
0 commit comments