diff --git a/stac_geoparquet/arrow/_schema/__init__.py b/stac_geoparquet/arrow/_schema/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/stac_geoparquet/arrow/_schema/default_schemas.py b/stac_geoparquet/arrow/_schema/default_schemas.py new file mode 100644 index 0000000..cc0de82 --- /dev/null +++ b/stac_geoparquet/arrow/_schema/default_schemas.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +import pyarrow as pa +from typing import List, Literal, Optional, Sequence, Tuple + +TimestampResolution = Literal["s", "ms", "us", "ns"] + +# Unsure +# - "license relation": https://github.com/radiantearth/stac-spec/blob/v1.0.0/item-spec/common-metadata.md#relation-types + + +def item_core( + *, + timestamp_resolution: TimestampResolution = "us", + asset_keys: Optional[List[str]] = None, +) -> pa.Schema: + provider_object = pa.struct( + [ + ("name", pa.utf8()), + ("description", pa.utf8()), + ("roles", pa.list_(pa.utf8())), + ("url", pa.utf8()), + ] + ) + + core_properties = [ + ("title", pa.utf8()), + ("description", pa.utf8()), + ("datetime", pa.timestamp(timestamp_resolution, "UTC")), + ("created", pa.timestamp(timestamp_resolution, "UTC")), + ("updated", pa.timestamp(timestamp_resolution, "UTC")), + ("start_datetime", pa.timestamp(timestamp_resolution, "UTC")), + ("end_datetime", pa.timestamp(timestamp_resolution, "UTC")), + ("license", pa.utf8()), + ("providers", pa.list_(provider_object)), + ("platform", pa.utf8()), + ("instruments", pa.list_(pa.utf8())), + ("constellation", pa.utf8()), + ("mission", pa.utf8()), + ("gsd", pa.float64()), + ] + + link_object = pa.struct( + [ + ("href", pa.utf8()), + ("rel", pa.utf8()), + ("type", pa.utf8()), + ("title", pa.utf8()), + ] + ) + + asset_object = pa.struct( + [ + ("href", pa.utf8()), + ("title", pa.utf8()), + ("description", pa.utf8()), + ("type", pa.utf8()), + ("roles", pa.list_(pa.utf8())), + ] + ) + + # If asset_keys was not provided, we use a Map type + if asset_keys is not None: + assets_type 
= pa.struct([(key, asset_object) for key in asset_keys]) + else: + assets_type = pa.map_(pa.utf8(), asset_object) + + return pa.schema( + [ + ("type", pa.dictionary(index_type=pa.int8(), value_type=pa.utf8())), + ( + "stac_version", + pa.dictionary(index_type=pa.int8(), value_type=pa.utf8()), + ), + ( + "stac_extensions", + pa.list_(pa.dictionary(index_type=pa.int16(), value_type=pa.utf8())), + ), + ("id", pa.utf8()), + ("geometry", pa.binary()), + ("bbox", bbox_type(dim=3)), + *core_properties, + ("links", pa.list_(link_object)), + ("assets", assets_type), + ("collection", pa.utf8()), + ] + ) + + +def bbox_type(dim: int) -> pa.StructType: + fields = [ + ("xmin", pa.float64()), + ("ymin", pa.float64()), + ] + if dim == 3: + fields.append(("zmin", pa.float64())) + + fields.extend( + [ + ("xmax", pa.float64()), + ("ymax", pa.float64()), + ] + ) + + if dim == 3: + fields.append(("zmax", pa.float64())) + + return pa.struct(fields) + + +def item_eo( + *, properties: bool, asset_keys: Optional[Sequence[str]] = None +) -> pa.Schema: + """Construct the partial schema for the STAC EO extension + + The EO extension allows information to be assigned either at the top-level properties or within assets. + + Args: + properties: Set to `True` if EO information is set on properties. + asset_keys: Pass a sequence of string asset keys that contain EO information. 
+ + Returns: + Partial EO extension Arrow schema + """ + band_object = pa.struct( + [ + ("name", pa.utf8()), + ("common_name", pa.utf8()), + ("description", pa.utf8()), + ("center_wavelength", pa.float64()), + ("full_width_half_max", pa.float64()), + ("solar_illumination", pa.float64()), + ] + ) + eo_fields = [ + ("eo:bands", pa.list_(band_object)), + ("eo:cloud_cover", pa.float64()), + ("eo:snow_cover", pa.float64()), + ] + + eo_object = pa.struct(eo_fields) + if asset_keys is not None: + assets_type = pa.struct([(key, eo_object) for key in asset_keys]) + else: + assets_type = pa.map_(pa.utf8(), eo_object) + fields: List[Tuple[str, pa.Field]] = [ + ("assets", assets_type), + ] + + if properties: + fields.extend(eo_fields) + + return pa.schema(fields) + + +def item_proj( + *, properties: bool, asset_keys: Optional[Sequence[str]] = None +) -> pa.Schema: + centroid_object = pa.struct( + [ + ("lat", pa.float64()), + ("lon", pa.float64()), + ] + ) + + proj_fields = [ + ("proj:epsg", pa.uint16()), + ("proj:wkt2", pa.utf8()), + # TODO: this arbitrary JSON will need special handling + ("proj:projjson", pa.utf8()), + # TODO: this arbitrary JSON will need special handling + ("proj:geometry", pa.binary()), + # TODO: this bbox will need special handling + # TODO: should this use list or struct encoding? 
+ # ("proj:bbox", bbox_type(dim=3)), + ("proj:bbox", pa.list_(pa.float64())), + ("proj:centroid", centroid_object), + ("proj:shape", pa.list_(pa.uint32(), 2)), + # TODO: switch this to a fixed size list of 6 or 9 elements + ("proj:transform", pa.list_(pa.float64())), + ] + + proj_object = pa.struct(proj_fields) + if asset_keys is not None: + assets_type = pa.struct([(key, proj_object) for key in asset_keys]) + else: + assets_type = pa.map_(pa.utf8(), proj_object) + fields: List[Tuple[str, pa.Field]] = [ + ("assets", assets_type), + ] + + if properties: + fields.extend(proj_fields) + + return pa.schema(fields) + + +def item_sci() -> pa.Schema: + publication_object = pa.struct( + [ + ("doi", pa.utf8()), + ("citation", pa.utf8()), + ] + ) + + sci_fields = [ + ("sci:doi", pa.utf8()), + ("sci:citation", pa.utf8()), + ("sci:publications", pa.list_(publication_object)), + ] + + return pa.schema(sci_fields) + + +def item_view(*, properties: bool, asset_keys: Optional[Sequence[str]] = None): + view_fields = [ + ("view:off_nadir", pa.float64()), + ("view:incidence_angle", pa.float64()), + ("view:azimuth", pa.float64()), + ("view:sun_azimuth", pa.float64()), + ("view:sun_elevation", pa.float64()), + ] + + view_object = pa.struct(view_fields) + if asset_keys is not None: + assets_type = pa.struct([(key, view_object) for key in asset_keys]) + else: + assets_type = pa.map_(pa.utf8(), view_object) + fields: List[Tuple[str, pa.Field]] = [ + ("assets", assets_type), + ] + + if properties: + fields.extend(view_fields) + + return pa.schema(fields) diff --git a/stac_geoparquet/arrow/_schema/ingest.py b/stac_geoparquet/arrow/_schema/ingest.py new file mode 100644 index 0000000..9421c6b --- /dev/null +++ b/stac_geoparquet/arrow/_schema/ingest.py @@ -0,0 +1,43 @@ +"""Schema-aware ingestion""" + +import pyarrow as pa +from copy import deepcopy +from typing import Sequence + +import json + +from stac_geoparquet.arrow._schema.models import PartialSchema + + +# TODO: convert `items` to an 
# Iterable to allow generator input. Then this function
# should return a generator of Arrow RecordBatches for output.
def ingest(
    items: Sequence[dict], schema_fragments: Sequence[PartialSchema]
) -> pa.RecordBatch:
    """Convert STAC Item dicts into a single Arrow RecordBatch.

    Args:
        items: STAC Items as parsed JSON dicts.
        schema_fragments: Partial schemas (core spec plus extensions)
            describing the fields present on the input items.

    Returns:
        A RecordBatch whose schema is the permissive union of all fragments'
        dict-input schemas.
    """
    # Preprocess a deep copy of each item: fragment preprocessing is allowed
    # to mutate (e.g. WKB-encode geometry), so never touch caller data.
    new_items = []
    for item in items:
        new_item = deepcopy(item)
        for schema_fragment in schema_fragments:
            schema_fragment.preprocess_item(new_item)

        new_items.append(new_item)

    # Combine Arrow schemas across fragments
    arrow_schema_fragments = [fragment.to_dict_input() for fragment in schema_fragments]
    unified_arrow_schema = pa.unify_schemas(
        arrow_schema_fragments, promote_options="permissive"
    )

    struct_array = pa.array(new_items, pa.struct(unified_arrow_schema))
    return pa.RecordBatch.from_struct_array(struct_array)


def _example() -> pa.RecordBatch:
    """Manual smoke test against a local fixture (not run in CI)."""
    # Imported here so this dev-only helper does not add module-level deps.
    # (Fixes a NameError: Core/EO/Proj were referenced but never imported.)
    from stac_geoparquet.arrow._schema.models import Core, EO, Proj

    path = "/Users/kyle/github/stac-utils/stac-geoparquet/tests/data/naip-pc.json"
    with open(path) as f:
        items = json.load(f)

    schema_fragments = [Core(), EO(), Proj()]

    return ingest(items, schema_fragments)


# file: stac_geoparquet/arrow/_schema/models.py
import json
from typing import List, Optional

import pyarrow as pa
import shapely
import shapely.geometry

from stac_geoparquet.arrow._schema.default_schemas import (
    TimestampResolution,
    item_core,
    item_eo,
    item_proj,
)


class PartialSchema:
    """A fragment of an Item schema (core spec or one extension)."""

    # The idealized Arrow schema for this fragment's fields.
    inner: pa.Schema

    def to_dict_input(self) -> pa.Schema:
        """Convert this partial schema to one that works on input STAC data"""
        return self.inner

    def preprocess_item(self, item: dict) -> dict:
        """
        Any pre-processing steps to be applied to the input STAC dict before converting
        with Arrow.

        Note: this pre-processing is allowed to mutate input.
        """
        return item


class Core(PartialSchema):
    """Partial schema for the core STAC Item spec."""

    def __init__(
        self,
        *,
        timestamp_resolution: TimestampResolution = "us",
        asset_keys: Optional[List[str]] = None,
    ) -> None:
        schema = item_core(
            timestamp_resolution=timestamp_resolution, asset_keys=asset_keys
        )
        self.inner = schema
        super().__init__()

    def to_dict_input(self) -> pa.Schema:
        # Input JSON carries datetimes as strings, properties nested under a
        # "properties" key, and bbox as a plain list — adapt `inner` to match.
        schema = self.inner
        schema = _timestamp_to_string(schema)
        schema = _lower_properties(schema)
        schema = _bbox_struct_to_list(schema)
        return schema

    def preprocess_item(self, item: dict) -> dict:
        # Encode GeoJSON geometry to ISO WKB so it fits the binary column.
        item["geometry"] = shapely.to_wkb(
            shapely.geometry.shape(item["geometry"]), flavor="iso"
        )
        return item


class EO(PartialSchema):
    """Partial schema for the Electro-Optical (eo) extension."""

    def __init__(
        self,
        *,
        properties: bool = True,
        asset_keys: Optional[List[str]] = None,
    ) -> None:
        schema = item_eo(properties=properties, asset_keys=asset_keys)
        self.inner = schema
        super().__init__()

    def to_dict_input(self) -> pa.Schema:
        # Input STAC dicts carry eo:* under "properties"; wrap the
        # non-top-level fields accordingly (same treatment as Core/Proj —
        # without this the eo:* columns would never see the input values).
        return _lower_properties(self.inner)


class Proj(PartialSchema):
    """Partial schema for the Projection (proj) extension."""

    # Whether proj:* fields are expected on item properties.
    properties: bool
    # Known asset keys carrying proj:* info, or None for a map encoding.
    asset_keys: Optional[List[str]]

    def __init__(
        self,
        *,
        properties: bool = True,
        asset_keys: Optional[List[str]] = None,
    ) -> None:
        schema = item_proj(properties=properties, asset_keys=asset_keys)
        self.inner = schema
        self.properties = properties
        self.asset_keys = asset_keys
        super().__init__()

    def to_dict_input(self) -> pa.Schema:
        schema = self.inner
        schema = _lower_properties(schema)
        return schema

    def preprocess_item(self, item: dict) -> dict:
        # Serialize arbitrary PROJJSON to a compact string for the utf8 column.
        projjson = item["properties"].get("proj:projjson")
        if projjson is not None:
            item["properties"]["proj:projjson"] = json.dumps(
                projjson, separators=(",", ":")
            )

        # Encode proj:geometry to ISO WKB, matching the binary column.
        geometry = item["properties"].get("proj:geometry")
        if geometry is not None:
            item["properties"]["proj:geometry"] = shapely.to_wkb(
                shapely.geometry.shape(geometry), flavor="iso"
            )

        # TODO: handle projjson and geometry inside asset keys

        return super().preprocess_item(item)


# Keys that live at the top level of a STAC Item JSON document; everything
# else belongs inside the "properties" object.
STAC_TOP_LEVEL_KEYS = {
    "stac_version",
    "stac_extensions",
    "type",
    "id",
    "bbox",
    "geometry",
    "collection",
    "links",
    "assets",
}


def _lower_properties(schema: pa.Schema) -> pa.Schema:
    """Take properties fields from the top level and wrap them in a struct column"""

    properties_fields: List[pa.Field] = []
    top_level_fields: List[pa.Field] = []

    for field in schema:
        if field.name in STAC_TOP_LEVEL_KEYS:
            top_level_fields.append(field)
        else:
            properties_fields.append(field)

    top_level_fields.append(pa.field("properties", pa.struct(properties_fields)))
    return pa.schema(top_level_fields)


def _bbox_struct_to_list(schema: pa.Schema) -> pa.Schema:
    """Convert the bbox struct field to a variable-sized list"""
    bbox_idx = schema.get_field_index("bbox")
    bbox_field = schema.field(bbox_idx)
    return schema.set(bbox_idx, bbox_field.with_type(pa.list_(pa.float64())))


def _timestamp_to_string(schema: pa.Schema) -> pa.Schema:
    """Replace top-level timestamp fields with utf8 (input JSON has string dates)."""
    new_fields = []
    for field in schema:
        if pa.types.is_timestamp(field.type):
            new_fields.append(field.with_type(pa.utf8()))
        else:
            # TODO: recurse into struct fields so nested timestamps
            # (e.g. inside assets) are converted too.
            new_fields.append(field)

    return pa.schema(new_fields)