11from pathlib import Path
2+ import warnings
23import rasterio as rio
34from rasterio.windows import Window
45import numpy as np
56import pickle
67import multiprocessing
78import time
9+ from tqdm import tqdm
810
911
1012class ImageChip:
@@ -36,6 +38,7 @@ def __init__(
3638 output_format="tif",
3739 max_batch_size=1000,
3840 ):
41+
3942 self.input_image_path = Path(input_image_path)
4043 self.output_path = Path(output_path) if output_path else Path(input_image_path)
4144 self.output_name = output_name if output_name else Path(input_image_path).stem
@@ -46,6 +49,41 @@ def __init__(
4649 self.use_multiprocessing = use_multiprocessing
4750 self.output_format = output_format
4851 self.max_batch_size = max_batch_size
52+ if not self.input_image_path.exists():
53+ raise FileNotFoundError(f"Input image not found: {self.input_image_path}")
54+ self._read_image_metadata()
55+ if self.pixel_dimensions <= 0:
56+ raise ValueError("pixel_dimensions must be a positive integer")
57+ if self.offset <= 0:
58+ raise ValueError("offset must be a positive integer")
59+ if self.output_format not in ("tif", "npz"):
60+ raise ValueError(
61+ f"output_format must be 'tif' or 'npz', got '{self.output_format}'"
62+ )
63+
def _read_image_metadata(self) -> None:
    """
    Read image profile metadata on initialisation.

    Opens ``self.input_image_path`` with rasterio and caches two values on
    the instance:

    * ``self._band_count`` - the band count reported by the dataset, for use
      in band validation.
    * ``self.nodata_val`` - the dataset's nodata value. If the image does not
      declare one, 0 is used instead and a ``UserWarning`` is emitted.

    Warns:
        UserWarning: If the image has no nodata value set and 0 is assumed.
    """
    with rio.open(self.input_image_path) as src:
        self._band_count = src.count
        nodata = src.nodata

    if nodata is not None:
        # NOTE(review): a NaN nodata value is stored as-is; the
        # `!= self.nodata_val` masks used elsewhere in this class can never
        # match NaN - confirm whether float rasters need special handling.
        self.nodata_val = nodata
        return

    self.nodata_val = 0
    # stacklevel=3 skips this helper and __init__ (its caller) so the warning
    # is attributed to the code that constructed the object. With a shallower
    # level every warning shares one library location and Python's default
    # filter would show it only once per process, hiding it for later images.
    warnings.warn(
        f"No nodata value found in {self.input_image_path.name}. "
        "Defaulting to 0. If 0 is a valid pixel value in your image, "
        "scaling and normalisation will incorrectly treat those pixels as nodata.",
        UserWarning,
        stacklevel=3,
    )
4987
5088 def _generate_windows(self, src):
5189 """
@@ -91,7 +129,7 @@ def _save_chip(self, chip, transform, output_file_path, d_type, src) -> None:
91129 dtype=d_type,
92130 crs=src.crs,
93131 transform=transform,
94- nodata=None ,
132+ nodata=self.nodata_val ,
95133 ) as dst:
96134 dst.write(chip)
97135
@@ -122,11 +160,9 @@ def _output_file(self, x: int, y: int) -> Path:
122160 Returns:
123161 Path: The full path (as a `Path` object) where the chip will be saved, including the generated file name.
124162 """
125- if self.output_name is None:
126- output_file_name = f"{self.input_image_path.stem}_{x}_{y}.tif"
127- else:
128- output_name = self.output_name.replace(".tif", "")
129- output_file_name = f"{output_name}_{x}_{y}.tif"
163+
164+ output_name = self.output_name.replace(".tif", "")
165+ output_file_name = f"{output_name}_{x}_{y}.tif"
130166 return self.output_path / output_file_name
131167
132168 def set_scaler(self, sample_size=10000, write_file=True, write_path=None):
@@ -147,7 +183,7 @@ def set_scaler(self, sample_size=10000, write_file=True, write_path=None):
147183 <image_file>_scaler_<sample_size>.pkl.
148184 """
149185 if self.normaliser is not None:
150- print ("normaliser will be set to None")
186+ warnings.warn ("normaliser will be set to None")
151187 self.normaliser = None
152188 self.standard_scaler = self.sample_to_scaler(sample_size=sample_size)
153189 if write_file:
@@ -176,8 +212,7 @@ def _validate_normaliser_inputs(self, value, name):
176212 Raises:
177213 ValueError: If value is not valid.
178214 """
179- with rio.open(self.input_image_path) as f:
180- bands = f.profile["count"]
215+ bands = self._band_count
181216 if isinstance(value, list):
182217 if len(value) != bands:
183218 raise ValueError(
@@ -212,7 +247,7 @@ def set_normaliser(
212247 write_file (bool): If True a pickle file is written containing the normaliser dictionary.
213248 write_path (string): The directory and filename (.pkl) where the scaler will be written if `write_file`.
214249 None by default when the file path is written to the same dir as ImageChip.input_image_path and file name
215- <image_file>_scaler_<sample_size> .pkl.
250+ <image_file>_normaliser .pkl.
216251
217252 """
218253 if min_val is None or max_val is None:
@@ -227,7 +262,7 @@ def set_normaliser(
227262 max_val = self._validate_normaliser_inputs(max_val, "max_val")
228263
229264 if self.standard_scaler is not None:
230- print ("standard_scaler will be set to None")
265+ warnings.warn ("standard_scaler will be set to None")
231266 self.standard_scaler = None
232267
233268 self.normaliser = {"min_val": min_val, "max_val": max_val}
@@ -243,13 +278,12 @@ def set_normaliser(
243278 pickle_file_path = output_dir / pickle_file_name
244279 else:
245280 pickle_file_path = write_path
246- # Save the dictionary to a pickle file
281+ # Save the dictionary to a pickle file
247282 with open(pickle_file_path, "wb") as f:
248283 pickle.dump(self.normaliser, f)
249284 print(f"Written normaliser to {pickle_file_path}")
250285
251- @staticmethod
252- def apply_normaliser(array: np.ndarray, normaliser_dict: dict) -> np.ndarray:
286+ def apply_normaliser(self, array: np.ndarray, normaliser_dict: dict) -> np.ndarray:
253287 """Normalises a numpy array based on min and max values created by `set_normaliser`.
254288
255289 Args:
@@ -273,8 +307,7 @@ def apply_normaliser(array: np.ndarray, normaliser_dict: dict) -> np.ndarray:
273307 for i in range(array.shape[0]):
274308 min_val = min_vals[i]
275309 max_val = max_vals[i]
276- # Apply normalising only to non-zero values (assuming 0 is nodata)
277- mask = array[i, :, :] != 0
310+ mask = array[i, :, :] != self.nodata_val
278311 clipped = np.clip(array[i, :, :], min_val, max_val)
279312 normalised_array[i, :, :] = np.where(
280313 mask, (clipped - min_val) / (max_val - min_val), 0
@@ -318,6 +351,7 @@ def sample_to_scaler(self, sample_size: int) -> dict:
318351 band_pixel_values = pixel_values[:, band_index]
319352 valid_band_pixel_values = band_pixel_values[
320353 ~np.isnan(band_pixel_values)
354+ & (band_pixel_values != self.nodata_val)
321355 ]
322356 band_vals = {
323357 "band_name": band_names[band_index],
@@ -328,9 +362,8 @@ def sample_to_scaler(self, sample_size: int) -> dict:
328362
329363 return stats_dict
330364
331- @staticmethod
332365 def apply_scaler(
333- array: np.ndarray, scaler_dict: dict[int, dict[str, float]]
366+ self, array: np.ndarray, scaler_dict: dict[int, dict[str, float]]
334367 ) -> np.ndarray:
335368 """Standard scales a numpy array based on mean and std values from a dictionary.
336369
@@ -353,8 +386,7 @@ def apply_scaler(
353386 band_info = scaler_dict.get(i)
354387 mean = band_info["mean"]
355388 std = band_info["std"]
356- # Apply scaling only to non-zero values (assuming 0 is nodata)
357- mask = array[i, :, :] != 0
389+ mask = array[i, :, :] != self.nodata_val
358390 scaled_array[i, :, :] = np.where(mask, (array[i, :, :] - mean) / std, 0)
359391 return scaled_array
360392
@@ -400,7 +432,9 @@ def _process_batch(self, batch_vals):
400432 out = {}
401433 with rio.open(self.input_image_path) as src:
402434 for x, y, window in batch:
403- chip = src.read(window=window, boundless=True, fill_value=0)
435+ chip = src.read(
436+ window=window, boundless=True, fill_value=self.nodata_val
437+ )
404438 if self.standard_scaler:
405439 chip = self.apply_scaler(chip, self.standard_scaler)
406440 if self.normaliser:
@@ -471,16 +505,19 @@ def chip_image(self) -> None:
471505 batches = self._calculate_batches(windows)
472506
473507 if self.use_multiprocessing:
474- print(f"Processing {len(batches)} batches in parallel.")
475508 num_cores = multiprocessing.cpu_count() - 1 # leave a core free?
476- print(f"Using {num_cores} cores.")
509+ print(
510+ f"Processing {len(batches)} batches in parallel using {num_cores} cores."
511+ )
477512 with multiprocessing.Pool(processes=num_cores) as pool:
478- pool.map(self._process_batch, batches)
513+ with tqdm(
514+ total=len(batches), desc="Chipping (parallel)", unit="batch"
515+ ) as pbar:
516+ for _ in pool.imap_unordered(self._process_batch, batches):
517+ pbar.update()
479518 else:
480- print(f"Processing in {len(batches)} batches")
481- for i, batch in enumerate(batches):
519+ for batch in tqdm(batches, desc="Chipping", unit="batch"):
482520 self._process_batch(batch)
483- print(f"Processed batch {i + 1} of {len(batches)}.")
484521
485522 elapsed_time = time.time() - start_time
486523 print(f"Chipping completed in {elapsed_time:.2f} seconds.")
0 commit comments