Skip to content

Commit 149f3c8

Browse files
make stats and histogram optional #466 (#467)
* make stats and histogram optional #466
* update docstring and changelog
* update comment
* fix(docs): use PR for changelog
* Update tests/core/test_add_raster.py
* Update tests/core/test_add_raster.py

---------

Co-authored-by: Pete Gadomski <[email protected]>
1 parent 53b99ad commit 149f3c8

File tree

3 files changed

+204
-11
lines changed

3 files changed

+204
-11
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

88
## [Unreleased]
99

10+
- Make computation of statistics and histogram optional for `core.add_raster.add_raster_to_item` ([#467](https://github.com/stac-utils/stactools/pull/467))
11+
1012
## [0.5.2] - 2023-09-20
1113

1214
### Fixed

src/stactools/core/add_raster.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,15 @@
1818
BINS = 256
1919

2020

21-
def add_raster_to_item(item: Item) -> Item:
21+
def add_raster_to_item(
22+
item: Item, statistics: bool = True, histogram: bool = True
23+
) -> Item:
2224
"""Adds the raster extension to an item.
2325
2426
Args:
2527
item (Item): The PySTAC Item to extend.
28+
statistics (bool): Compute band statistics (min/max). Defaults to True
29+
histogram (bool): Compute band histogram. Defaults to True
2630
2731
Returns:
2832
Item:
@@ -34,27 +38,38 @@ def add_raster_to_item(item: Item) -> Item:
3438
if asset.roles and "data" in asset.roles:
3539
raster = RasterExtension.ext(asset)
3640
href = make_absolute_href(asset.href, item.get_self_href())
37-
bands = _read_bands(href)
41+
bands = _read_bands(href, statistics, histogram)
3842
if bands:
3943
raster.apply(bands)
4044
return item
4145

4246

43-
def _read_bands(href: str) -> List[RasterBand]:
47+
def _read_bands(href: str, statistics: bool, histogram: bool) -> List[RasterBand]:
4448
bands = []
4549
with rasterio.open(href) as dataset:
4650
for i, index in enumerate(dataset.indexes):
47-
data = dataset.read(index, masked=True)
4851
band = RasterBand.create()
4952
band.nodata = dataset.nodatavals[i]
5053
band.spatial_resolution = dataset.transform[0]
5154
band.data_type = DataType(dataset.dtypes[i])
52-
minimum = float(numpy.min(data))
53-
maximum = float(numpy.max(data))
54-
band.statistics = Statistics.create(minimum=minimum, maximum=maximum)
55-
hist_data, _ = numpy.histogram(data, range=(minimum, maximum), bins=BINS)
56-
band.histogram = Histogram.create(
57-
BINS, minimum, maximum, hist_data.tolist()
58-
)
55+
56+
if statistics or histogram:
57+
data = dataset.read(index, masked=True)
58+
minimum = float(numpy.nanmin(data))
59+
maximum = float(numpy.nanmax(data))
60+
if statistics:
61+
band.statistics = Statistics.create(minimum=minimum, maximum=maximum)
62+
if histogram:
63+
# the entire array is masked, or all values are NAN.
64+
# won't be able to compute histogram and will return empty array.
65+
if numpy.isnan(minimum):
66+
band.histogram = Histogram.create(0, minimum, maximum, [])
67+
else:
68+
hist_data, _ = numpy.histogram(
69+
data, range=(minimum, maximum), bins=BINS
70+
)
71+
band.histogram = Histogram.create(
72+
BINS, minimum, maximum, hist_data.tolist()
73+
)
5974
bands.append(band)
6075
return bands

tests/core/test_add_raster.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
import tempfile
2+
from typing import Callable, List, Optional
3+
4+
import numpy as np
5+
import pystac
6+
import pytest
7+
import rasterio
8+
from rasterio.crs import CRS
9+
from rasterio.transform import Affine
10+
from stactools.core import create
11+
from stactools.core.add_raster import add_raster_to_item
12+
13+
14+
def random_data(count: int) -> np.ndarray:
15+
return np.random.rand(count, 10, 10) * 10
16+
17+
18+
def nan_data(count: int) -> np.ndarray:
19+
data = np.empty((count, 10, 10))
20+
data[:] = np.nan
21+
return data
22+
23+
24+
def data_with_nan(count: int) -> np.ndarray:
25+
data = np.random.rand(count, 10, 10) * 10
26+
data[0][1][1] = np.nan
27+
return data
28+
29+
30+
def zero_data(count: int) -> np.ndarray:
31+
return np.zeros((count, 10, 10))
32+
33+
34+
def test_add_raster(tmp_asset_path) -> None:
35+
item = create.item(tmp_asset_path)
36+
add_raster_to_item(item)
37+
38+
asset: pystac.Asset = item.assets["data"]
39+
40+
_assert_asset(
41+
asset,
42+
expected_count=1,
43+
expected_nodata=None,
44+
expected_spatial_resolution=60.0,
45+
expected_dtype=np.dtype("uint8"),
46+
expected_min=[74.0],
47+
expected_max=[255.0],
48+
)
49+
50+
51+
@pytest.mark.parametrize(
52+
"count,nodata,dtype,datafunc,hist_count",
53+
[
54+
(1, 0, np.dtype("int8"), random_data, 256),
55+
(1, None, np.dtype("float64"), random_data, 256),
56+
(1, np.nan, np.dtype("float64"), random_data, 256),
57+
(2, 0, np.dtype("int8"), random_data, 256),
58+
(2, None, np.dtype("float64"), random_data, 256),
59+
(2, np.nan, np.dtype("float64"), random_data, 256),
60+
(1, 0, np.dtype("uint8"), zero_data, 0),
61+
(1, None, np.dtype("uint8"), zero_data, 256),
62+
(1, None, np.dtype("float64"), nan_data, 0),
63+
(1, np.nan, np.dtype("float64"), nan_data, 0),
64+
(1, None, np.dtype("float64"), data_with_nan, 256),
65+
(1, np.nan, np.dtype("float64"), data_with_nan, 256),
66+
],
67+
)
68+
def test_add_raster_with_nodata(
69+
count: int, nodata: float, dtype: np.dtype, datafunc: Callable, hist_count: int
70+
) -> None:
71+
with tempfile.NamedTemporaryFile(suffix=".tif") as tmpfile:
72+
with rasterio.open(
73+
tmpfile.name,
74+
mode="w",
75+
driver="GTiff",
76+
count=count,
77+
nodata=nodata,
78+
dtype=dtype,
79+
transform=Affine(0.1, 0.0, 1.0, 0.0, -0.1, 1.0),
80+
width=10,
81+
height=10,
82+
crs=CRS.from_epsg(4326),
83+
) as dst:
84+
data = datafunc(count)
85+
data.astype(dtype)
86+
dst.write(data)
87+
88+
with rasterio.open(tmpfile.name) as src:
89+
data = src.read(masked=True)
90+
minimum = []
91+
maximum = []
92+
for i, _ in enumerate(src.indexes):
93+
minimum.append(float(np.nanmin(data[i])))
94+
maximum.append(float(np.nanmax(data[i])))
95+
96+
item = create.item(tmpfile.name)
97+
98+
add_raster_to_item(item)
99+
100+
asset: pystac.Asset = item.assets["data"]
101+
_assert_asset(
102+
asset,
103+
expected_count=count,
104+
expected_nodata=nodata,
105+
expected_spatial_resolution=0.1,
106+
expected_dtype=dtype,
107+
expected_min=minimum,
108+
expected_max=maximum,
109+
expected_hist_count=hist_count,
110+
)
111+
112+
113+
def test_add_raster_without_stats(tmp_asset_path) -> None:
114+
item = create.item(tmp_asset_path)
115+
add_raster_to_item(item, statistics=False)
116+
117+
asset: pystac.Asset = item.assets["data"]
118+
bands = asset.extra_fields.get("raster:bands")
119+
120+
assert bands[0].get("statistics") is None
121+
assert bands[0].get("histogram")
122+
123+
124+
def test_add_raster_without_histogram(tmp_asset_path) -> None:
125+
item = create.item(tmp_asset_path)
126+
add_raster_to_item(item, histogram=False)
127+
128+
asset: pystac.Asset = item.assets["data"]
129+
bands = asset.extra_fields.get("raster:bands")
130+
131+
assert bands[0].get("statistics")
132+
assert bands[0].get("histogram") is None
133+
134+
135+
def _assert_asset(
136+
asset: pystac.Asset,
137+
expected_count: int,
138+
expected_nodata: Optional[float],
139+
expected_dtype: np.dtype,
140+
expected_spatial_resolution: float,
141+
expected_min: List[float],
142+
expected_max: List[float],
143+
expected_hist_count=256,
144+
) -> None:
145+
bands = asset.extra_fields.get("raster:bands")
146+
assert bands
147+
assert len(bands) == expected_count
148+
149+
for i, band in enumerate(bands):
150+
nodata = band.get("nodata")
151+
dtype = band["data_type"].value
152+
spatial_resolution = band["spatial_resolution"]
153+
statistics = band["statistics"]
154+
histogram = band["histogram"]
155+
assert nodata == expected_nodata or (
156+
np.isnan(nodata) and np.isnan(expected_nodata)
157+
)
158+
assert dtype == expected_dtype.name
159+
assert spatial_resolution == expected_spatial_resolution
160+
assert statistics == {
161+
"minimum": expected_min[i],
162+
"maximum": expected_max[i],
163+
} or (
164+
np.isnan(statistics["maximum"])
165+
and np.isnan(expected_max[i])
166+
and np.isnan(statistics["minimum"])
167+
and np.isnan(expected_min[i])
168+
)
169+
assert histogram["count"] == expected_hist_count
170+
assert histogram["max"] == band["statistics"]["maximum"] or (
171+
np.isnan(histogram["max"]) and np.isnan(statistics["maximum"])
172+
)
173+
assert histogram["min"] == band["statistics"]["minimum"] or (
174+
np.isnan(histogram["min"]) and np.isnan(statistics["minimum"])
175+
)
176+
assert len(histogram["buckets"]) == histogram["count"]

0 commit comments

Comments
 (0)