Skip to content

Commit 9d37b40

Browse files
committed
feat(collection): add bands,bounds,__repr__, fix metadata roundtrip, clean API surface
- Add `bands` property (band codes from `_metadata` columns) - Add `bounds` property (spatial extent via pyarrow.compute min/max) - Add `__repr__` showing name, source, bands, records, date range - Add return type annotations: get_xarray -> xr.Dataset, get_gdf -> gpd.GeoDataFrame, to_torchgeo_dataset -> RasteretGeoDataset - Restore metadata (name, data_source, description, date_range) from parquet schema on load -- fixes empty data_source after roundtrip - Make from_parquet() Hive-aware (tries Hive first, falls back) - Rename from_local() -> _load_cached() to clarify internal vs public API - Add xr.Dataset return type to get_collection_xarray() in execution.py Signed-off-by: print-sid8 sidsub94@gmail.com
1 parent f6098cc commit 9d37b40

File tree

8 files changed

+150
-31
lines changed

8 files changed

+150
-31
lines changed

src/rasteret/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def build_from_stac(
115115
collection_path = workspace_dir_path / f"{collection_name}_stac"
116116

117117
if collection_path.exists() and not force:
118-
return Collection.from_local(collection_path)
118+
return Collection._load_cached(collection_path)
119119

120120
from rasteret.cloud import CloudConfig, backend_config_from_cloud_config
121121
from rasteret.ingest.stac_indexer import StacCollectionBuilder
@@ -666,7 +666,7 @@ def build_from_table(
666666
if resolved_workspace is not None:
667667
rw = Path(resolved_workspace)
668668
if rw.exists() and not force:
669-
return Collection.from_local(rw)
669+
return Collection._load_cached(rw)
670670

671671
# Arrow-native path: accept an in-memory Arrow table / dataset.
672672
import pyarrow as pa

src/rasteret/cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def _handle_cache_list(args: argparse.Namespace) -> int:
181181
def _handle_cache_info(args: argparse.Namespace) -> int:
182182
workspace_dir = _workspace_dir(args.workspace_dir)
183183
collection_path = _resolve_collection_path(args.name, workspace_dir)
184-
collection = Collection.from_local(collection_path)
184+
collection = Collection._load_cached(collection_path)
185185
summary = _collection_summary(collection, collection_path)
186186

187187
if args.json:
@@ -220,7 +220,7 @@ def _handle_cache_import(args: argparse.Namespace) -> int:
220220
collection_path = workspace_dir / f"{args.name}_records"
221221
if collection_path.exists():
222222
if not args.force:
223-
collection = Collection.from_local(collection_path)
223+
collection = Collection._load_cached(collection_path)
224224
summary = _collection_summary(collection, collection_path)
225225
if args.json:
226226
print(json.dumps(summary, indent=2))

src/rasteret/core/collection.py

Lines changed: 133 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from collections.abc import Sequence
1212
from datetime import datetime
1313
from pathlib import Path
14-
from typing import Any, AsyncIterator
14+
from typing import TYPE_CHECKING, Any, AsyncIterator
1515

1616
import pandas as pd
1717
import pyarrow as pa
@@ -23,6 +23,12 @@
2323
from rasteret.core.raster_accessor import RasterAccessor
2424
from rasteret.types import RasterInfo
2525

26+
if TYPE_CHECKING:
27+
import geopandas as gpd
28+
import xarray as xr
29+
30+
from rasteret.integrations.torchgeo import RasteretGeoDataset
31+
2632
logger = logging.getLogger(__name__)
2733

2834
# WKB geometry type id → GeoParquet type name (OGC Simple Features).
@@ -83,7 +89,7 @@ class Collection:
8389
Examples
8490
--------
8591
# From partitioned dataset
86-
>>> collection = Collection.from_local("path/to/dataset")
92+
>>> collection = Collection.from_parquet("path/to/dataset")
8793
8894
# Filter and process
8995
>>> filtered = collection.subset(cloud_cover_lt=20)
@@ -136,12 +142,31 @@ def _view(self, dataset: ds.Dataset) -> Collection:
136142
end_date=self.end_date,
137143
)
138144

145+
@staticmethod
146+
def _metadata_from_schema(dataset: ds.Dataset) -> dict[str, str]:
147+
"""Extract Rasteret metadata stored by ``export()``."""
148+
raw = dataset.schema.metadata or {}
149+
out: dict[str, str] = {}
150+
for key in (b"name", b"data_source", b"description", b"date_range"):
151+
val = raw.get(key)
152+
if val:
153+
try:
154+
out[key.decode()] = val.decode("utf-8")
155+
except (UnicodeDecodeError, AttributeError):
156+
pass
157+
return out
158+
139159
@classmethod
140-
def from_local(cls, path: str | Path) -> Collection:
141-
"""Create collection from a local Parquet dataset.
160+
def _load_cached(cls, path: str | Path) -> Collection:
161+
"""Load a Collection from a workspace cache directory.
142162
143-
Tries Hive-style partitioning first (year/month), falls back to
144-
plain Parquet if the directory isn't Hive-partitioned.
163+
Internal fast-path for ``build()`` / ``build_from_table()`` cache
164+
hits. Trusts the data (no schema validation), strips workspace
165+
suffixes (``_stac``, ``_records``) from the name, and detects
166+
Hive partitioning.
167+
168+
For user-facing loading, use :meth:`from_parquet` or
169+
:func:`rasteret.load` instead.
145170
"""
146171
path = Path(path)
147172
if not path.exists():
@@ -161,22 +186,53 @@ def from_local(cls, path: str | Path) -> Collection:
161186
exclude_invalid_files=True,
162187
)
163188

164-
name = path.stem.removesuffix("_stac").removesuffix("_records")
165-
return cls(dataset=dataset, name=name)
189+
meta = cls._metadata_from_schema(dataset)
190+
name = meta.get("name") or path.stem.removesuffix("_stac").removesuffix(
191+
"_records"
192+
)
193+
194+
start_date = None
195+
end_date = None
196+
dr = meta.get("date_range", "")
197+
if "," in dr:
198+
start_date, end_date = dr.split(",", 1)
199+
200+
return cls(
201+
dataset=dataset,
202+
name=name,
203+
data_source=meta.get("data_source", ""),
204+
description=meta.get("description", ""),
205+
start_date=start_date,
206+
end_date=end_date,
207+
)
166208

167209
@classmethod
168210
def from_parquet(cls, path: str | Path, name: str = "") -> Collection:
169211
"""Load a Collection from any Parquet file or directory.
170212
171-
The Parquet must contain the core columns:
172-
``id``, ``datetime``, ``geometry``, ``assets``, ``scene_bbox``.
213+
Tries Hive-style partitioning first (year/month), falls back to
214+
plain Parquet. Validates that the core contract columns are present.
215+
173216
See the `Schema Contract <../explanation/schema-contract/>`_ docs page.
174217
"""
175218
path = Path(path)
176219
if not path.exists():
177220
raise FileNotFoundError(f"Parquet not found at {path}")
178221

179-
dataset = ds.dataset(str(path), format="parquet")
222+
try:
223+
dataset = ds.dataset(
224+
str(path),
225+
format="parquet",
226+
partitioning="hive",
227+
exclude_invalid_files=True,
228+
)
229+
except pa.ArrowInvalid:
230+
dataset = ds.dataset(
231+
str(path),
232+
format="parquet",
233+
exclude_invalid_files=True,
234+
)
235+
180236
required = {"id", "datetime", "geometry", "assets", "scene_bbox"}
181237
missing = required - set(dataset.schema.names)
182238
if missing:
@@ -185,8 +241,23 @@ def from_parquet(cls, path: str | Path, name: str = "") -> Collection:
185241
"See the Schema Contract page in docs for the expected schema."
186242
)
187243

188-
name = name or path.stem
189-
return cls(dataset=dataset, name=name)
244+
meta = cls._metadata_from_schema(dataset)
245+
resolved_name = name or meta.get("name") or path.stem
246+
247+
start_date = None
248+
end_date = None
249+
dr = meta.get("date_range", "")
250+
if "," in dr:
251+
start_date, end_date = dr.split(",", 1)
252+
253+
return cls(
254+
dataset=dataset,
255+
name=resolved_name,
256+
data_source=meta.get("data_source", ""),
257+
description=meta.get("description", ""),
258+
start_date=start_date,
259+
end_date=end_date,
260+
)
190261

191262
def subset(
192263
self,
@@ -589,6 +660,52 @@ async def get_first_raster(self) -> RasterAccessor:
589660
return raster
590661
raise ValueError("No raster records found in collection")
591662

663+
@property
664+
def bands(self) -> list[str]:
665+
"""Available band codes in this collection."""
666+
if self.dataset is None:
667+
return []
668+
return [
669+
c.removesuffix("_metadata")
670+
for c in self.dataset.schema.names
671+
if c.endswith("_metadata")
672+
]
673+
674+
@property
675+
def bounds(self) -> tuple[float, float, float, float] | None:
676+
"""Spatial extent as ``(minx, miny, maxx, maxy)`` or ``None``."""
677+
if self.dataset is None:
678+
return None
679+
names = set(self.dataset.schema.names)
680+
cols = ("bbox_minx", "bbox_miny", "bbox_maxx", "bbox_maxy")
681+
if not all(c in names for c in cols):
682+
return None
683+
t = self.dataset.to_table(columns=list(cols))
684+
return (
685+
pc.min(t["bbox_minx"]).as_py(),
686+
pc.min(t["bbox_miny"]).as_py(),
687+
pc.max(t["bbox_maxx"]).as_py(),
688+
pc.max(t["bbox_maxy"]).as_py(),
689+
)
690+
691+
def __repr__(self) -> str:
692+
n_bands = len(self.bands)
693+
try:
694+
n_rows = self.dataset.count_rows() if self.dataset is not None else 0
695+
except Exception:
696+
n_rows = "?"
697+
698+
parts = [f"Collection({self.name!r}"]
699+
if self.data_source:
700+
parts.append(f"source={self.data_source!r}")
701+
parts.append(f"bands={n_bands}")
702+
parts.append(f"records={n_rows}")
703+
if self.start_date and self.end_date:
704+
s = str(self.start_date)[:10]
705+
e = str(self.end_date)[:10]
706+
parts.append(f"{s}..{e}")
707+
return ", ".join(parts) + ")"
708+
592709
def _validate_parquet_dataset(self) -> None:
593710
"""Basic dataset validation."""
594711
if not isinstance(self.dataset, ds.Dataset):
@@ -724,7 +841,7 @@ def to_torchgeo_dataset(
724841
backend: Any = None,
725842
time_series: bool = False,
726843
target_crs: int | None = None,
727-
) -> Any:
844+
) -> RasteretGeoDataset:
728845
"""Create a TorchGeo GeoDataset backed by this Collection.
729846
730847
This integration is optional and requires ``torchgeo`` and its
@@ -830,7 +947,7 @@ def get_xarray(
830947
backend: Any = None,
831948
target_crs: int | None = None,
832949
**filters: Any,
833-
) -> Any:
950+
) -> xr.Dataset:
834951
"""Load selected bands into an xarray Dataset.
835952
836953
Parameters
@@ -886,7 +1003,7 @@ def get_gdf(
8861003
backend: Any = None,
8871004
target_crs: int | None = None,
8881005
**filters: Any,
889-
) -> Any:
1006+
) -> gpd.GeoDataFrame:
8901007
"""Load selected bands into a GeoDataFrame.
8911008
8921009
Parameters

src/rasteret/core/execution.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from rasteret.core.utils import infer_data_source, run_sync
2828

2929
if TYPE_CHECKING: # pragma: no cover
30+
import xarray as xr
31+
3032
from rasteret.core.collection import Collection
3133

3234
logger = logging.getLogger(__name__)
@@ -223,7 +225,7 @@ def get_collection_xarray(
223225
backend: object | None = None,
224226
target_crs: int | None = None,
225227
**filters: Any,
226-
):
228+
) -> xr.Dataset:
227229
"""Load selected bands as an ``xarray.Dataset``.
228230
229231
Parameters

src/rasteret/tests/test_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def test_cache_build_passes_args_and_returns_summary(
9191
def fake_build_from_stac(**kwargs):
9292
captured.update(kwargs)
9393
_write_cached_collection(cache_dir)
94-
return Collection.from_local(cache_dir)
94+
return Collection._load_cached(cache_dir)
9595

9696
monkeypatch.setattr("rasteret.cli.build_from_stac", fake_build_from_stac)
9797

src/rasteret/tests/test_execution.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def _make_collection_and_infer(
6969
pq.write_to_dataset(
7070
table, root_path=str(path), partition_cols=["year", "month"]
7171
)
72-
c = Collection.from_local(path)
72+
c = Collection._load_cached(path)
7373
c.data_source = data_source
7474
return infer_data_source(c)
7575

@@ -161,7 +161,7 @@ def test_single_crs_returns_none(self):
161161
path = Path(tmp) / "single_crs"
162162
path.mkdir()
163163
pq.write_table(table, str(path / "data.parquet"))
164-
c = Collection.from_local(path)
164+
c = Collection._load_cached(path)
165165
assert _detect_target_crs(c, {}) is None
166166

167167
def test_multi_crs_returns_most_common(self):
@@ -171,7 +171,7 @@ def test_multi_crs_returns_most_common(self):
171171
path = Path(tmp) / "multi_crs"
172172
path.mkdir()
173173
pq.write_table(table, str(path / "data.parquet"))
174-
c = Collection.from_local(path)
174+
c = Collection._load_cached(path)
175175
result = _detect_target_crs(c, {})
176176
assert result == 32632
177177

@@ -182,6 +182,6 @@ def test_multi_crs_equal_counts_picks_one(self):
182182
path = Path(tmp) / "equal_crs"
183183
path.mkdir()
184184
pq.write_table(table, str(path / "data.parquet"))
185-
c = Collection.from_local(path)
185+
c = Collection._load_cached(path)
186186
result = _detect_target_crs(c, {})
187187
assert result in (32632, 32633)

src/rasteret/tests/test_ingest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ def test_build_from_table_workspace_dir_persists(self, tmp_path):
385385
# workspace_dir gets _records suffix for discoverability
386386
expected = out_dir / "demo_records"
387387
assert expected.exists()
388-
reloaded = Collection.from_local(expected)
388+
reloaded = Collection._load_cached(expected)
389389
assert reloaded.dataset is not None
390390
assert reloaded.dataset.count_rows() == collection.dataset.count_rows()
391391

src/rasteret/tests/test_public_api_surface.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def test_collection_analysis_methods_delegate_to_execution_layer() -> None:
5959
with TemporaryDirectory() as tmp_dir:
6060
dataset_path = Path(tmp_dir) / "example_stac"
6161
_write_minimal_partitioned_collection(dataset_path)
62-
collection = Collection.from_local(dataset_path)
62+
collection = Collection._load_cached(dataset_path)
6363

6464
with (
6565
patch(
@@ -134,10 +134,10 @@ def test_load_rejects_missing_file() -> None:
134134
rasteret.load("/nonexistent/path.parquet")
135135

136136

137-
def test_from_local_fallback_to_non_hive() -> None:
138-
"""from_local should work on non-Hive partitioned parquet."""
137+
def test_from_parquet_fallback_to_non_hive() -> None:
138+
"""from_parquet should work on non-Hive partitioned parquet."""
139139
with TemporaryDirectory() as tmp_dir:
140140
path = Path(tmp_dir) / "flat.parquet"
141141
_write_minimal_flat_collection(path)
142-
collection = Collection.from_local(path)
142+
collection = Collection.from_parquet(path)
143143
assert isinstance(collection, Collection)

0 commit comments

Comments
 (0)