@@ -35,7 +35,7 @@ def open(cls, filepath: str | Path) -> "RasterioLibImage":
3535 raise InvalidInputError ({"filepath" : str (filepath )}, f"Failed to open raster file: { e } " ) from e
3636
3737 if isinstance (raster , list ):
38- raise InvalidInputError ({"file_type" : type (raster ).__name__ }, "Expected DataArray, got Dataset " )
38+ raise InvalidInputError ({"file_type" : type (raster ).__name__ }, "Expected DataArray or Dataset , got list " )
3939
4040 return cls (raster )
4141
@@ -49,44 +49,41 @@ def filepath(self) -> Path:
4949
5050 @property
5151 def metadata (self ) -> dict [str , Any ]:
52- return dict ( self .file .attrs )
52+ return self .file .attrs
5353
5454 @property
5555 def shape (self ) -> tuple [int , int , int ]:
5656 # rioxarray uses (band, y, x) ordering
5757 return (self .file .y .size , self .file .x .size , self .file .band .size )
5858
59+ @property
60+ def rows (self ) -> int :
61+ return self .file .y .size
62+
63+ @property
64+ def cols (self ) -> int :
65+ return self .file .x .size
66+
5967 @property
6068 def bands (self ) -> int :
6169 return self .file .band .size
6270
6371 @property
6472 def default_bands (self ) -> list [int ]:
6573 # Most common RGB band combination for satellite imagery
66- if self .bands >= 3 :
67- return [0 , 1 , 2 ]
68- return list (range (min (3 , self .bands )))
74+ return list (range (1 , min (3 , self .bands ) + 1 ))
6975
7076 @property
7177 def wavelengths (self ) -> list [float ]:
72- # Try to get wavelengths from band attributes
73- wavelengths = []
74- for band_idx in range (self .bands ):
75- band_data = self .file .sel (band = band_idx + 1 )
76- wave = band_data .attrs .get ("wavelength" )
77- if wave :
78- wavelengths .append (float (wave ))
79- else :
80- wavelengths .append (float (band_idx + 1 ))
81- return wavelengths
78+ return self .file .band .values
8279
8380 @property
8481 def camera_id (self ) -> str :
82+ # Todo: camera_id is not a standard metadata field, should be updated
8583 return self .metadata .get ("camera_id" , "" )
8684
8785 def to_display (self , equalize : bool = True ) -> Image .Image :
88- selected_bands = [i + 1 for i in self .default_bands ] # Adjust for 1-indexed bands
89- bands_data = self .file .sel (band = selected_bands )
86+ bands_data = self .file .sel (band = self .default_bands )
9087 image_3ch = bands_data .transpose ("y" , "x" , "band" ).values
9188 image_3ch_clean = np .nan_to_num (np .asarray (image_3ch ))
9289 min_val = np .nanmin (image_3ch_clean )
@@ -100,7 +97,10 @@ def to_display(self, equalize: bool = True) -> Image.Image:
10097 return image
10198
10299 def to_numpy (self , nan_value : float | None = None ) -> np .ndarray :
103- image = np .moveaxis ( np . asarray (self .file .values ), 0 , - 1 )
100+ image = np .asarray (self .file .transpose ( "y" , "x" , "band" ). values )
104101 if nan_value is not None :
105102 image = np .nan_to_num (image , nan = nan_value )
106103 return image
104+
105+ def to_xarray (self ) -> "XarrayType" :
106+ return self .file
0 commit comments