From 54d625c815e4a67b8df1def83f4bd8e2843ebe88 Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Sun, 2 Jun 2024 18:28:22 +0200 Subject: [PATCH 01/10] Move json equality logic outside of test_arrow --- tests/__init__.py | 0 tests/json_equals.py | 167 ++++++++++++++++++++++++++++++++++++++++++ tests/test_arrow.py | 168 +------------------------------------------ 3 files changed, 169 insertions(+), 166 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/json_equals.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/json_equals.py b/tests/json_equals.py new file mode 100644 index 0000000..79a3ad5 --- /dev/null +++ b/tests/json_equals.py @@ -0,0 +1,167 @@ +import math +from typing import Any, Dict, Sequence, Union + +from ciso8601 import parse_rfc3339 + + +JsonValue = Union[list, tuple, int, float, dict, str, bool, None] + + +def assert_json_value_equal( + result: JsonValue, + expected: JsonValue, + *, + key_name: str = "root", + precision: float = 0.0001, +) -> None: + """Assert that the JSON value in `result` and `expected` are equal for our purposes. + + We allow these variations between result and expected: + + - We allow numbers to vary up to `precision`. + - We consider `key: None` and a missing key to be equivalent. + - We allow RFC3339 date strings with varying precision levels, as long as they + represent the same parsed datetime. + + Args: + result: The result to assert against. + expected: The expected item to compare against. + key_name: The key name of the current path in the JSON. Used for error messages. + precision: The precision to use for comparing integers and floats. + + Raises: + AssertionError: If the two values are not equal + """ + if isinstance(result, list) and isinstance(expected, list): + assert_sequence_equal(result, expected, key_name=key_name, precision=precision) + + elif isinstance(result, tuple) and isinstance(expected, tuple): + assert_sequence_equal(result, expected, key_name=key_name, precision=precision) + + elif isinstance(result, (int, float)) and isinstance(expected, (int, float)): + assert_number_equal(result, expected, key_name=key_name, precision=precision) + + elif isinstance(result, dict) and isinstance(expected, dict): + assert_dict_equal(result, expected, key_name=key_name, precision=precision) + + elif isinstance(result, str) and isinstance(expected, str): + assert_string_equal(result, expected, key_name=key_name) + + elif isinstance(result, bool) and isinstance(expected, bool): + assert_bool_equal(result, expected, key_name=key_name) + + elif result is None and expected is None: + pass + + else: + raise AssertionError( + f"Mismatched types at {key_name}. {type(result)=}, {type(expected)=}" + ) + + +def assert_sequence_equal( + result: Sequence, expected: Sequence, *, key_name: str, precision: float +) -> None: + """Compare two JSON arrays, recursively""" + assert len(result) == len(expected), ( + f"List at {key_name} has different lengths." 
f"{len(result)=}, {len(expected)=}"
+    )
+
+    for i in range(len(result)):
+        assert_json_value_equal(
+            result[i], expected[i], key_name=f"{key_name}.[{i}]", precision=precision
+        )
+
+
+def assert_number_equal(
+    result: Union[int, float],
+    expected: Union[int, float],
+    *,
+    precision: float,
+    key_name: str,
+) -> None:
+    """Compare two JSON numbers"""
+    # Allow NaN equality
+    if math.isnan(result) and math.isnan(expected):
+        return
+
+    assert abs(result - expected) <= precision, (
+        f"Number at {key_name} not within precision. "
+        f"{result=}, {expected=}, {precision=}."
+    )
+
+
+def assert_string_equal(
+    result: str,
+    expected: str,
+    *,
+    key_name: str,
+) -> None:
+    """Compare two JSON strings.
+
+    We attempt to parse each string to a datetime. If this succeeds, then we compare the
+    datetime.datetime representations instead of the bare strings.
+    """
+
+    # Check if both strings are dates, then assert the parsed datetimes are equal
+    try:
+        result_datetime = parse_rfc3339(result)
+        expected_datetime = parse_rfc3339(expected)
+
+        assert result_datetime == expected_datetime, (
+            f"Date string at {key_name} not equal. "
+            f"{result=}, {expected=}."
+            f"{result_datetime=}, {expected_datetime=}."
+        )
+
+    except ValueError:
+        assert (
+            result == expected
+        ), f"String at {key_name} not equal. {result=}, {expected=}."
+
+
+def assert_bool_equal(
+    result: bool,
+    expected: bool,
+    *,
+    key_name: str,
+) -> None:
+    """Compare two JSON booleans."""
+    assert result == expected, f"Bool at {key_name} not equal. {result=}, {expected=}."
+
+
+def assert_dict_equal(
+    result: Dict[str, Any],
+    expected: Dict[str, Any],
+    *,
+    key_name: str,
+    precision: float,
+) -> None:
+    """
+    Assert that two JSON dicts are equal, recursively, allowing missing keys to equal
+    None.
+    """
+    result_keys = set(result.keys())
+    expected_keys = set(expected.keys())
+
+    # For any keys that exist in result but not expected, assert that the result value
+    # is None
+    for key in result_keys - expected_keys:
+        assert (
+            result[key] is None
+        ), f"Expected key at {key_name} to be None in result. Got {result[key]}"
+
+    # And vice versa
+    for key in expected_keys - result_keys:
+        assert (
+            expected[key] is None
+        ), f"Expected key at {key_name} to be None in expected. Got {expected[key]}"
+
+    # For any overlapping keys, assert that their values are equal
+    for key in result_keys & expected_keys:
+        assert_json_value_equal(
+            result[key],
+            expected[key],
+            key_name=f"{key_name}.{key}",
+            precision=precision,
+        )
diff --git a/tests/test_arrow.py b/tests/test_arrow.py
index 2b9bca4..94e5bfa 100644
--- a/tests/test_arrow.py
+++ b/tests/test_arrow.py
@@ -1,178 +1,14 @@
 import json
-import math
 from pathlib import Path
-from typing import Any, Dict, Sequence, Union
 
 import pyarrow as pa
 import pytest
-from ciso8601 import parse_rfc3339
 
 from stac_geoparquet.arrow import parse_stac_items_to_arrow, stac_table_to_items
 
-HERE = Path(__file__).parent
-
-JsonValue = Union[list, tuple, int, float, dict, str, bool, None]
-
-
-def assert_json_value_equal(
-    result: JsonValue,
-    expected: JsonValue,
-    *,
-    key_name: str = "root",
-    precision: float = 0.0001,
-) -> None:
-    """Assert that the JSON value in `result` and `expected` are equal for our purposes.
-
-    We allow these variations between result and expected:
-
-    - We allow numbers to vary up to `precision`.
-    - We consider `key: None` and a missing key to be equivalent.
- - We allow RFC3339 date strings with varying precision levels, as long as they - represent the same parsed datetime. - - Args: - result: The result to assert against. - expected: The expected item to compare against. - key_name: The key name of the current path in the JSON. Used for error messages. - precision: The precision to use for comparing integers and floats. - - Raises: - AssertionError: If the two values are not equal - """ - if isinstance(result, list) and isinstance(expected, list): - assert_sequence_equal(result, expected, key_name=key_name, precision=precision) - - elif isinstance(result, tuple) and isinstance(expected, tuple): - assert_sequence_equal(result, expected, key_name=key_name, precision=precision) - - elif isinstance(result, (int, float)) and isinstance(expected, (int, float)): - assert_number_equal(result, expected, key_name=key_name, precision=precision) - - elif isinstance(result, dict) and isinstance(expected, dict): - assert_dict_equal(result, expected, key_name=key_name, precision=precision) - - elif isinstance(result, str) and isinstance(expected, str): - assert_string_equal(result, expected, key_name=key_name) - - elif isinstance(result, bool) and isinstance(expected, bool): - assert_bool_equal(result, expected, key_name=key_name) - - elif result is None and expected is None: - pass - - else: - raise AssertionError( - f"Mismatched types at {key_name}. {type(result)=}, {type(expected)=}" - ) - - -def assert_sequence_equal( - result: Sequence, expected: Sequence, *, key_name: str, precision: float -) -> None: - """Compare two JSON arrays, recursively""" - assert len(result) == len(expected), ( - f"List at {key_name} has different lengths." f"{len(result)=}, {len(expected)=}" - ) - - for i in range(len(result)): - assert_json_value_equal( - result[i], expected[i], key_name=f"{key_name}.[{i}]", precision=precision - ) - - -def assert_number_equal( - result: Union[int, float], - expected: Union[int, float], - *, - precision: float, - key_name: str, -) -> None: - """Compare two JSON numbers""" - # Allow NaN equality - if math.isnan(result) and math.isnan(expected): - return - - assert abs(result - expected) <= precision, ( - f"Number at {key_name} not within precision. " - f"{result=}, {expected=}, {precision=}." - ) - - -def assert_string_equal( - result: str, - expected: str, - *, - key_name: str, -) -> None: - """Compare two JSON strings. - - We attempt to parse each string to a datetime. If this succeeds, then we compare the - datetime.datetime representations instead of the bare strings. - """ - - # Check if both strings are dates, then assert the parsed datetimes are equal - try: - result_datetime = parse_rfc3339(result) - expected_datetime = parse_rfc3339(expected) - - assert result_datetime == expected_datetime, ( - f"Date string at {key_name} not equal. " - f"{result=}, {expected=}." - f"{result_datetime=}, {expected_datetime=}." - ) - - except ValueError: - assert ( - result == expected - ), f"String at {key_name} not equal. {result=}, {expected=}." - - -def assert_bool_equal( - result: bool, - expected: bool, - *, - key_name: str, -) -> None: - """Compare two JSON booleans.""" - assert result == expected, f"Bool at {key_name} not equal. {result=}, {expected=}." - - -def assert_dict_equal( - result: Dict[str, Any], - expected: Dict[str, Any], - *, - key_name: str, - precision: float, -) -> None: - """ - Assert that two JSON dicts are equal, recursively, allowing missing keys to equal - None. 
- """ - result_keys = set(result.keys()) - expected_keys = set(expected.keys()) - - # For any keys that exist in result but not expected, assert that the result value - # is None - for key in result_keys - expected_keys: - assert ( - result[key] is None - ), f"Expected key at {key_name} to be None in result. Got {result['key']}" - - # And vice versa - for key in expected_keys - result_keys: - assert ( - expected[key] is None - ), f"Expected key at {key_name} to be None in expected. Got {expected['key']}" - - # For any overlapping keys, assert that their values are equal - for key in result_keys & expected_keys: - assert_json_value_equal( - result[key], - expected[key], - key_name=f"{key_name}.{key}", - precision=precision, - ) +from .json_equals import assert_json_value_equal +HERE = Path(__file__).parent TEST_COLLECTIONS = [ "3dep-lidar-copc", From a0433c586f4d24d6d5fa1ae1ca4719fda90f074b Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Sun, 2 Jun 2024 19:45:43 +0200 Subject: [PATCH 02/10] Refactor to RawBatch and CleanBatch wrapper types --- stac_geoparquet/arrow/_batch.py | 196 +++++++++++++++++ stac_geoparquet/arrow/_from_arrow.py | 112 +--------- stac_geoparquet/arrow/_to_arrow.py | 193 ++++++++++++++++- stac_geoparquet/arrow/_util.py | 303 ++++----------------------- 4 files changed, 438 insertions(+), 366 deletions(-) create mode 100644 stac_geoparquet/arrow/_batch.py diff --git a/stac_geoparquet/arrow/_batch.py b/stac_geoparquet/arrow/_batch.py new file mode 100644 index 0000000..950c9fe --- /dev/null +++ b/stac_geoparquet/arrow/_batch.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +import os +from copy import deepcopy +from typing import ( + Any, + Dict, + Iterable, + List, + Optional, + Union, +) + +import numpy as np +import orjson +import pyarrow as pa +import pyarrow.compute as pc +import shapely +import shapely.geometry +from numpy.typing import NDArray +from typing_extensions import Self + +from stac_geoparquet.arrow._to_arrow import ( + assign_geoarrow_metadata, + bring_properties_to_top_level, + convert_bbox_to_struct, + convert_timestamp_columns, +) +from stac_geoparquet.arrow._util import convert_tuples_to_lists, set_by_path +from stac_geoparquet.from_arrow import ( + convert_bbox_to_array, + convert_timestamp_columns_to_string, + lower_properties_from_top_level, +) + + +class RawBatch: + """ + An Arrow RecordBatch of STAC Items that has been **minimally converted** to Arrow. + That is, it aligns as much as possible to the raw STAC JSON representation. + + The **only** transformations that have already been applied here are those that are + necessary to represent the core STAC items in Arrow. + + - `geometry` has been converted to WKB binary + - `properties.proj:geometry`, if it exists, has been converted to WKB binary + ISO encoding + - The `proj:geometry` in any asset properties, if it exists, has been converted to + WKB binary. + + No other transformations have yet been applied. I.e. all properties are still in a + top-level `properties` struct column. + """ + + inner: pa.RecordBatch + """The underlying pyarrow RecordBatch""" + + def __init__(self, batch: pa.RecordBatch) -> None: + self.inner = batch + + @classmethod + def from_dicts( + cls, items: Iterable[Dict[str, Any]], *, schema: Optional[pa.Schema] = None + ) -> Self: + """Construct a RawBatch from an iterable of dicts representing STAC items. + + All items will be parsed into a single RecordBatch, meaning that each internal + array is fully contiguous in memory for the length of `items`. 
+ + Args: + items: STAC Items to convert to Arrow + + Kwargs: + schema: An optional schema that describes the format of the data. Note that + this must represent the geometry column and any `proj:geometry` columns + as binary type. + + Returns: + _description_ + """ + # Preprocess GeoJSON to WKB in each STAC item + # Otherwise, pyarrow will try to parse coordinates into a native geometry type + # and if you have multiple geometry types pyarrow will error with + # `ArrowInvalid: cannot mix list and non-list, non-null values` + wkb_items = [] + for item in items: + wkb_item = deepcopy(item) + wkb_item["geometry"] = shapely.to_wkb( + shapely.geometry.shape(wkb_item["geometry"]), flavor="iso" + ) + + # If a proj:geometry key exists in top-level properties, convert that to WKB + if "proj:geometry" in wkb_item["properties"]: + wkb_item["properties"]["proj:geometry"] = shapely.to_wkb( + shapely.geometry.shape(wkb_item["properties"]["proj:geometry"]), + flavor="iso", + ) + + # If a proj:geometry key exists in any asset properties, convert that to WKB + for asset_value in wkb_item["assets"].values(): + if "proj:geometry" in asset_value: + asset_value["proj:geometry"] = shapely.to_wkb( + shapely.geometry.shape(asset_value["proj:geometry"]), + flavor="iso", + ) + + wkb_items.append(wkb_item) + + if schema is not None: + array = pa.array(wkb_items, type=pa.struct(schema)) + else: + array = pa.array(wkb_items) + + return cls(pa.RecordBatch.from_struct_array(array)) + + def iter_dicts(self) -> Iterable[dict]: + batch = self.inner + + # Find all paths in the schema that have a WKB geometry + geometry_paths = [["geometry"]] + try: + batch.schema.field("properties").type.field("proj:geometry") + geometry_paths.append(["properties", "proj:geometry"]) + except KeyError: + pass + + assets_struct = batch.schema.field("assets").type + for asset_idx in range(assets_struct.num_fields): + asset_field = assets_struct.field(asset_idx) + if "proj:geometry" in pa.schema(asset_field).names: + geometry_paths.append(["assets", asset_field.name, "proj:geometry"]) + + # Convert each geometry column to a Shapely geometry, and then assign the + # geojson geometry when converting each row to a dictionary. 
+ geometries: List[NDArray[np.object_]] = [] + for geometry_path in geometry_paths: + col = batch + for path_segment in geometry_path: + if isinstance(col, pa.RecordBatch): + col = col[path_segment] + elif pa.types.is_struct(col.type): + col = pc.struct_field(col, path_segment) # type: ignore + else: + raise AssertionError(f"unexpected type {type(col)}") + + geometries.append(shapely.from_wkb(col)) + + struct_batch = batch.to_struct_array() + for row_idx in range(len(struct_batch)): + row_dict = struct_batch[row_idx].as_py() + for geometry_path, geometry_column in zip(geometry_paths, geometries): + geojson_g = geometry_column[row_idx].__geo_interface__ + geojson_g["coordinates"] = convert_tuples_to_lists( + geojson_g["coordinates"] + ) + set_by_path(row_dict, geometry_path, geojson_g) + + yield row_dict + + def to_clean_batch(self) -> CleanBatch: + batch = self.inner + + batch = bring_properties_to_top_level(batch) + batch = convert_timestamp_columns(batch) + batch = convert_bbox_to_struct(batch) + batch = assign_geoarrow_metadata(batch) + + return CleanBatch(batch) + + def to_ndjson(self, dest: Union[str, os.PathLike[bytes]]) -> None: + with open(dest, "ab") as f: + for item_dict in self.iter_dicts(): + f.write(orjson.dumps(item_dict)) + f.write(b"\n") + + +class CleanBatch: + """ + An Arrow RecordBatch of STAC Items that has been processed to match the + STAC-GeoParquet specification. + """ + + inner: pa.RecordBatch + """The underlying pyarrow RecordBatch""" + + def __init__(self, batch: pa.RecordBatch) -> None: + self.inner = batch + + def to_raw_batch(self) -> RawBatch: + batch = self.inner + + batch = convert_timestamp_columns_to_string(batch) + batch = lower_properties_from_top_level(batch) + batch = convert_bbox_to_array(batch) + + return RawBatch(batch) diff --git a/stac_geoparquet/arrow/_from_arrow.py b/stac_geoparquet/arrow/_from_arrow.py index c06b0ad..a92f894 100644 --- a/stac_geoparquet/arrow/_from_arrow.py +++ b/stac_geoparquet/arrow/_from_arrow.py @@ -1,60 +1,17 @@ """Convert STAC Items in Arrow Table format to JSON Lines or Python dicts.""" import orjson -import operator import os -from functools import reduce -from typing import Any, Dict, Iterable, List, Sequence, Tuple, Union +from typing import Iterable, List, Union import numpy as np import pyarrow as pa import pyarrow.compute as pc -import shapely -from numpy.typing import NDArray -import shapely.geometry def stac_batch_to_items(batch: pa.RecordBatch) -> Iterable[dict]: """Convert a stac arrow recordbatch to item dicts.""" batch = _undo_stac_transformations(batch) - # Find all paths in the schema that have a WKB geometry - geometry_paths = [["geometry"]] - try: - batch.schema.field("properties").type.field("proj:geometry") - geometry_paths.append(["properties", "proj:geometry"]) - except KeyError: - pass - - assets_struct = batch.schema.field("assets").type - for asset_idx in range(assets_struct.num_fields): - asset_field = assets_struct.field(asset_idx) - if "proj:geometry" in pa.schema(asset_field).names: - geometry_paths.append(["assets", asset_field.name, "proj:geometry"]) - - # Convert each geometry column to a Shapely geometry, and then assign the - # geojson geometry when converting each row to a dictionary. 
- geometries: List[NDArray[np.object_]] = [] - for geometry_path in geometry_paths: - col = batch - for path_segment in geometry_path: - if isinstance(col, pa.RecordBatch): - col = col[path_segment] - elif pa.types.is_struct(col.type): - col = pc.struct_field(col, path_segment) - else: - raise AssertionError(f"unexpected type {type(col)}") - - geometries.append(shapely.from_wkb(col)) - - struct_batch = batch.to_struct_array() - for row_idx in range(len(struct_batch)): - row_dict = struct_batch[row_idx].as_py() - for geometry_path, geometry_column in zip(geometry_paths, geometries): - geojson_g = geometry_column[row_idx].__geo_interface__ - geojson_g["coordinates"] = convert_tuples_to_lists(geojson_g["coordinates"]) - set_by_path(row_dict, geometry_path, geojson_g) - - yield row_dict def stac_table_to_ndjson(table: pa.Table, dest: Union[str, os.PathLike[str]]) -> None: @@ -77,13 +34,12 @@ def _undo_stac_transformations(batch: pa.RecordBatch) -> pa.RecordBatch: Note that this function does _not_ undo the GeoJSON -> WKB geometry transformation, as that is easier to do when converting each item in the table to a dict. """ - batch = _convert_timestamp_columns_to_string(batch) - batch = _lower_properties_from_top_level(batch) - batch = _convert_bbox_to_array(batch) + batch = lower_properties_from_top_level(batch) + batch = convert_bbox_to_array(batch) return batch -def _convert_timestamp_columns_to_string(batch: pa.RecordBatch) -> pa.RecordBatch: +def convert_timestamp_columns_to_string(batch: pa.RecordBatch) -> pa.RecordBatch: """Convert any datetime columns in the table to a string representation""" allowed_column_names = { "datetime", # common metadata @@ -102,13 +58,14 @@ def _convert_timestamp_columns_to_string(batch: pa.RecordBatch) -> pa.RecordBatc continue batch = batch.drop_columns((column_name,)).append_column( - column_name, pc.strftime(column, format="%Y-%m-%dT%H:%M:%SZ") + column_name, + pc.strftime(column, format="%Y-%m-%dT%H:%M:%SZ"), # type: ignore ) return batch -def _lower_properties_from_top_level(batch: pa.RecordBatch) -> pa.RecordBatch: +def lower_properties_from_top_level(batch: pa.RecordBatch) -> pa.RecordBatch: """Take properties columns from the top level and wrap them in a struct column""" stac_top_level_keys = { "stac_version", @@ -141,7 +98,7 @@ def _lower_properties_from_top_level(batch: pa.RecordBatch) -> pa.RecordBatch: ) -def _convert_bbox_to_array(batch: pa.RecordBatch) -> pa.RecordBatch: +def convert_bbox_to_array(batch: pa.RecordBatch) -> pa.RecordBatch: """Convert the struct bbox column back to a list column for writing to JSON""" bbox_col_idx = batch.schema.get_field_index("bbox") @@ -191,56 +148,3 @@ def _convert_bbox_to_array(batch: pa.RecordBatch) -> pa.RecordBatch: raise ValueError("Expected 4 or 6 fields in bbox struct.") return batch.set_column(bbox_col_idx, "bbox", list_arr) - - -def convert_tuples_to_lists(t: List | Tuple) -> List[Any]: - """Convert tuples to lists, recursively - - For example, converts: - ``` - ( - ( - (-112.4820566, 38.1261015), - (-112.4816283, 38.1331311), - (-112.4833551, 38.1338897), - (-112.4832919, 38.1307687), - (-112.4855415, 38.1291793), - (-112.4820566, 38.1261015), - ), - ) - ``` - - to - - ```py - [ - [ - [-112.4820566, 38.1261015], - [-112.4816283, 38.1331311], - [-112.4833551, 38.1338897], - [-112.4832919, 38.1307687], - [-112.4855415, 38.1291793], - [-112.4820566, 38.1261015], - ] - ] - ``` - - From https://stackoverflow.com/a/1014669. 
- """ - return list(map(convert_tuples_to_lists, t)) if isinstance(t, (list, tuple)) else t - - -def get_by_path(root: Dict[str, Any], keys: Sequence[str]) -> Any: - """Access a nested object in root by item sequence. - - From https://stackoverflow.com/a/14692747 - """ - return reduce(operator.getitem, keys, root) - - -def set_by_path(root: Dict[str, Any], keys: Sequence[str], value: Any) -> None: - """Set a value in a nested object in root by item sequence. - - From https://stackoverflow.com/a/14692747 - """ - get_by_path(root, keys[:-1])[keys[-1]] = value # type: ignore diff --git a/stac_geoparquet/arrow/_to_arrow.py b/stac_geoparquet/arrow/_to_arrow.py index b99dcec..8efcdf3 100644 --- a/stac_geoparquet/arrow/_to_arrow.py +++ b/stac_geoparquet/arrow/_to_arrow.py @@ -10,11 +10,15 @@ Union, ) +import ciso8601 +import numpy as np +import orjson import pyarrow as pa +from stac_geoparquet.arrow._crs import WGS84_CRS_JSON from stac_geoparquet.arrow._schema.models import InferredSchema +from stac_geoparquet.arrow._util import batched_iter from stac_geoparquet.json_reader import read_json_chunked -from stac_geoparquet.arrow._util import stac_items_to_arrow, batched_iter def parse_stac_items_to_arrow( @@ -100,3 +104,190 @@ def parse_stac_ndjson_to_arrow( for batch in read_json_chunked(path, chunk_size=chunk_size): yield stac_items_to_arrow(batch, schema=schema) + + +def bring_properties_to_top_level( + batch: pa.RecordBatch, +) -> pa.RecordBatch: + """Bring all the fields inside of the nested "properties" struct to the top level""" + properties_field = batch.schema.field("properties") + properties_column = batch["properties"] + + for field_idx in range(properties_field.type.num_fields): + inner_prop_field = properties_field.type.field(field_idx) + batch = batch.append_column( + inner_prop_field, pc.struct_field(properties_column, field_idx) + ) + + batch = batch.drop_columns( + [ + "properties", + ] + ) + return batch + + +def convert_timestamp_columns( + batch: pa.RecordBatch, +) -> pa.RecordBatch: + """Convert all timestamp columns from a string to an Arrow Timestamp data type""" + allowed_column_names = { + "datetime", # common metadata + "start_datetime", + "end_datetime", + "created", + "updated", + "expires", # timestamps extension + "published", + "unpublished", + } + for column_name in allowed_column_names: + try: + column = batch[column_name] + except KeyError: + continue + + field_index = batch.schema.get_field_index(column_name) + + if pa.types.is_timestamp(column.type): + continue + + # STAC allows datetimes to be null. If all rows are null, the column type may be + # inferred as null. We cast this to a timestamp column. 
+ elif pa.types.is_null(column.type): + batch = batch.set_column( + field_index, column_name, column.cast(pa.timestamp("us")) + ) + + elif pa.types.is_string(column.type): + batch = batch.set_column( + field_index, column_name, _convert_single_timestamp_column(column) + ) + else: + raise ValueError( + f"Inferred time column '{column_name}' was expected to be a string or" + f" timestamp data type but got {column.type}" + ) + + return batch + + +def _convert_single_timestamp_column(column: pa.Array) -> pa.TimestampArray: + """Convert an individual timestamp column from string to a Timestamp type""" + return pa.array( + [ciso8601.parse_rfc3339(str(t)) for t in column], pa.timestamp("us", tz="UTC") + ) + + +def _is_bbox_3d(bbox_col: pa.Array) -> bool: + """Infer whether the bounding box column represents 2d or 3d bounding boxes.""" + offsets_set = set() + offsets = bbox_col.offsets.to_numpy() + offsets_set.update(np.unique(offsets[1:] - offsets[:-1])) + + if len(offsets_set) > 1: + raise ValueError("Mixed 2d-3d bounding boxes not yet supported") + + offset = list(offsets_set)[0] + if offset == 6: + return True + elif offset == 4: + return False + else: + raise ValueError(f"Unexpected bbox offset: {offset=}") + + +def convert_bbox_to_struct(batch: pa.RecordBatch) -> pa.RecordBatch: + """Convert bbox column to a struct representation + + Since the bbox in JSON is stored as an array, pyarrow automatically converts the + bbox column to a ListArray. But according to GeoParquet 1.1, we should save the bbox + column as a StructArray, which allows for Parquet statistics to infer any spatial + partitioning in the dataset. + + Args: + batch: _description_ + + Returns: + New record batch + """ + bbox_col_idx = batch.schema.get_field_index("bbox") + bbox_col = batch.column(bbox_col_idx) + bbox_3d = _is_bbox_3d(bbox_col) + + assert ( + pa.types.is_list(bbox_col.type) + or pa.types.is_large_list(bbox_col.type) + or pa.types.is_fixed_size_list(bbox_col.type) + ) + if bbox_3d: + coords = bbox_col.flatten().to_numpy().reshape(-1, 6) + else: + coords = bbox_col.flatten().to_numpy().reshape(-1, 4) + + if bbox_3d: + xmin = coords[:, 0] + ymin = coords[:, 1] + zmin = coords[:, 2] + xmax = coords[:, 3] + ymax = coords[:, 4] + zmax = coords[:, 5] + + struct_arr = pa.StructArray.from_arrays( + [ + xmin, + ymin, + zmin, + xmax, + ymax, + zmax, + ], + names=[ + "xmin", + "ymin", + "zmin", + "xmax", + "ymax", + "zmax", + ], + ) + + else: + xmin = coords[:, 0] + ymin = coords[:, 1] + xmax = coords[:, 2] + ymax = coords[:, 3] + + struct_arr = pa.StructArray.from_arrays( + [ + xmin, + ymin, + xmax, + ymax, + ], + names=[ + "xmin", + "ymin", + "xmax", + "ymax", + ], + ) + + return batch.set_column(bbox_col_idx, "bbox", struct_arr) + + +def assign_geoarrow_metadata( + batch: pa.RecordBatch, +) -> pa.RecordBatch: + """Tag the primary geometry column with `geoarrow.wkb` on the field metadata.""" + existing_field_idx = batch.schema.get_field_index("geometry") + existing_field = batch.schema.field(existing_field_idx) + ext_metadata = {"crs": WGS84_CRS_JSON} + field_metadata = { + b"ARROW:extension:name": b"geoarrow.wkb", + b"ARROW:extension:metadata": orjson.dumps(ext_metadata), + } + new_field = existing_field.with_metadata(field_metadata) + return batch.set_column( + existing_field_idx, new_field, batch.column(existing_field_idx) + ) diff --git a/stac_geoparquet/arrow/_util.py b/stac_geoparquet/arrow/_util.py index 5390af5..6b73dac 100644 --- a/stac_geoparquet/arrow/_util.py +++ b/stac_geoparquet/arrow/_util.py @@ -1,23 +1,18 
@@ -from copy import deepcopy +import operator +from functools import reduce from typing import ( Any, Dict, Iterable, + List, Optional, Sequence, + Union, ) -import ciso8601 -import numpy as np import pyarrow as pa -import pyarrow.compute as pc -import shapely -import shapely.geometry -import orjson from itertools import islice -from stac_geoparquet.arrow._crs import WGS84_CRS_JSON - def update_batch_schema( batch: pa.RecordBatch, @@ -42,268 +37,54 @@ def batched_iter( return -def stac_items_to_arrow( - items: Iterable[Dict[str, Any]], *, schema: Optional[pa.Schema] = None -) -> pa.RecordBatch: - """Convert dicts representing STAC Items to Arrow - - This converts GeoJSON geometries to WKB before Arrow conversion to allow multiple - geometry types. - - All items will be parsed into a single RecordBatch, meaning that each internal array - is fully contiguous in memory for the length of `items`. - - Args: - items: STAC Items to convert to Arrow - - Kwargs: - schema: An optional schema that describes the format of the data. Note that this - must represent the geometry column as binary type. - - Returns: - Arrow RecordBatch with items in Arrow - """ - # Preprocess GeoJSON to WKB in each STAC item - # Otherwise, pyarrow will try to parse coordinates into a native geometry type and - # if you have multiple geometry types pyarrow will error with - # `ArrowInvalid: cannot mix list and non-list, non-null values` - wkb_items = [] - for item in items: - wkb_item = deepcopy(item) - wkb_item["geometry"] = shapely.to_wkb( - shapely.geometry.shape(wkb_item["geometry"]), flavor="iso" - ) - - # If a proj:geometry key exists in top-level properties, convert that to WKB - if "proj:geometry" in wkb_item["properties"]: - wkb_item["properties"]["proj:geometry"] = shapely.to_wkb( - shapely.geometry.shape(wkb_item["properties"]["proj:geometry"]), - flavor="iso", - ) - - # If a proj:geometry key exists in any asset properties, convert that to WKB - for asset_value in wkb_item["assets"].values(): - if "proj:geometry" in asset_value: - asset_value["proj:geometry"] = shapely.to_wkb( - shapely.geometry.shape(asset_value["proj:geometry"]), - flavor="iso", - ) - - wkb_items.append(wkb_item) - - if schema is not None: - array = pa.array(wkb_items, type=pa.struct(schema)) - else: - array = pa.array(wkb_items) - - return _process_arrow_batch(pa.RecordBatch.from_struct_array(array)) - - -def _bring_properties_to_top_level( - batch: pa.RecordBatch, -) -> pa.RecordBatch: - """Bring all the fields inside of the nested "properties" struct to the top level""" - properties_field = batch.schema.field("properties") - properties_column = batch["properties"] - - for field_idx in range(properties_field.type.num_fields): - inner_prop_field = properties_field.type.field(field_idx) - batch = batch.append_column( - inner_prop_field, pc.struct_field(properties_column, field_idx) - ) - - batch = batch.drop_columns( - [ - "properties", - ] +def convert_tuples_to_lists(t: Union[list, tuple]) -> List[Any]: + """Convert tuples to lists, recursively + + For example, converts: + ``` + ( + ( + (-112.4820566, 38.1261015), + (-112.4816283, 38.1331311), + (-112.4833551, 38.1338897), + (-112.4832919, 38.1307687), + (-112.4855415, 38.1291793), + (-112.4820566, 38.1261015), + ), ) - return batch + ``` + to -def _convert_geometry_to_wkb( - batch: pa.RecordBatch, -) -> pa.RecordBatch: - """Convert the geometry column in the table to WKB""" - geoms = shapely.from_geojson( - [orjson.dumps(item) for item in batch["geometry"].to_pylist()] - ) - wkb_geoms = 
shapely.to_wkb(geoms) - return batch.drop_columns( + ```py + [ [ - "geometry", + [-112.4820566, 38.1261015], + [-112.4816283, 38.1331311], + [-112.4833551, 38.1338897], + [-112.4832919, 38.1307687], + [-112.4855415, 38.1291793], + [-112.4820566, 38.1261015], ] - ).append_column("geometry", pa.array(wkb_geoms)) - - -def _convert_timestamp_columns( - batch: pa.RecordBatch, -) -> pa.RecordBatch: - """Convert all timestamp columns from a string to an Arrow Timestamp data type""" - allowed_column_names = { - "datetime", # common metadata - "start_datetime", - "end_datetime", - "created", - "updated", - "expires", # timestamps extension - "published", - "unpublished", - } - for column_name in allowed_column_names: - try: - column = batch[column_name] - except KeyError: - continue - - field_index = batch.schema.get_field_index(column_name) - - if pa.types.is_timestamp(column.type): - continue - - # STAC allows datetimes to be null. If all rows are null, the column type may be - # inferred as null. We cast this to a timestamp column. - elif pa.types.is_null(column.type): - batch = batch.set_column( - field_index, column_name, column.cast(pa.timestamp("us")) - ) - - elif pa.types.is_string(column.type): - batch = batch.set_column( - field_index, column_name, _convert_timestamp_column(column) - ) - else: - raise ValueError( - f"Inferred time column '{column_name}' was expected to be a string or" - f" timestamp data type but got {column.type}" - ) - - return batch - - -def _convert_timestamp_column(column: pa.Array) -> pa.TimestampArray: - """Convert an individual timestamp column from string to a Timestamp type""" - return pa.array( - [ciso8601.parse_rfc3339(str(t)) for t in column], pa.timestamp("us", tz="UTC") - ) - - -def _is_bbox_3d(bbox_col: pa.Array) -> bool: - """Infer whether the bounding box column represents 2d or 3d bounding boxes.""" - offsets_set = set() - offsets = bbox_col.offsets.to_numpy() - offsets_set.update(np.unique(offsets[1:] - offsets[:-1])) - - if len(offsets_set) > 1: - raise ValueError("Mixed 2d-3d bounding boxes not yet supported") - - offset = list(offsets_set)[0] - if offset == 6: - return True - elif offset == 4: - return False - else: - raise ValueError(f"Unexpected bbox offset: {offset=}") - + ] + ``` -def _convert_bbox_to_struct(batch: pa.RecordBatch) -> pa.RecordBatch: - """Convert bbox column to a struct representation - - Since the bbox in JSON is stored as an array, pyarrow automatically converts the - bbox column to a ListArray. But according to GeoParquet 1.1, we should save the bbox - column as a StructArray, which allows for Parquet statistics to infer any spatial - partitioning in the dataset. - - Args: - batch: _description_ - - Returns: - New record batch + From https://stackoverflow.com/a/1014669. 
""" - bbox_col_idx = batch.schema.get_field_index("bbox") - bbox_col = batch.column(bbox_col_idx) - bbox_3d = _is_bbox_3d(bbox_col) - - assert ( - pa.types.is_list(bbox_col.type) - or pa.types.is_large_list(bbox_col.type) - or pa.types.is_fixed_size_list(bbox_col.type) - ) - if bbox_3d: - coords = bbox_col.flatten().to_numpy().reshape(-1, 6) - else: - coords = bbox_col.flatten().to_numpy().reshape(-1, 4) - - if bbox_3d: - xmin = coords[:, 0] - ymin = coords[:, 1] - zmin = coords[:, 2] - xmax = coords[:, 3] - ymax = coords[:, 4] - zmax = coords[:, 5] - - struct_arr = pa.StructArray.from_arrays( - [ - xmin, - ymin, - zmin, - xmax, - ymax, - zmax, - ], - names=[ - "xmin", - "ymin", - "zmin", - "xmax", - "ymax", - "zmax", - ], - ) + return list(map(convert_tuples_to_lists, t)) if isinstance(t, (list, tuple)) else t - else: - xmin = coords[:, 0] - ymin = coords[:, 1] - xmax = coords[:, 2] - ymax = coords[:, 3] - struct_arr = pa.StructArray.from_arrays( - [ - xmin, - ymin, - xmax, - ymax, - ], - names=[ - "xmin", - "ymin", - "xmax", - "ymax", - ], - ) - - return batch.set_column(bbox_col_idx, "bbox", struct_arr) +def get_by_path(root: Dict[str, Any], keys: Sequence[str]) -> Any: + """Access a nested object in root by item sequence. + From https://stackoverflow.com/a/14692747 + """ + return reduce(operator.getitem, keys, root) -def _assign_geoarrow_metadata( - batch: pa.RecordBatch, -) -> pa.RecordBatch: - """Tag the primary geometry column with `geoarrow.wkb` on the field metadata.""" - existing_field_idx = batch.schema.get_field_index("geometry") - existing_field = batch.schema.field(existing_field_idx) - ext_metadata = {"crs": WGS84_CRS_JSON} - field_metadata = { - b"ARROW:extension:name": b"geoarrow.wkb", - b"ARROW:extension:metadata": orjson.dumps(ext_metadata), - } - new_field = existing_field.with_metadata(field_metadata) - return batch.set_column( - existing_field_idx, new_field, batch.column(existing_field_idx) - ) +def set_by_path(root: Dict[str, Any], keys: Sequence[str], value: Any) -> None: + """Set a value in a nested object in root by item sequence. 
-def _process_arrow_batch(batch: pa.RecordBatch) -> pa.RecordBatch: - batch = _bring_properties_to_top_level(batch) - batch = _convert_timestamp_columns(batch) - batch = _convert_bbox_to_struct(batch) - batch = _assign_geoarrow_metadata(batch) - return batch + From https://stackoverflow.com/a/14692747 + """ + get_by_path(root, keys[:-1])[keys[-1]] = value # type: ignore From 7fabc9ae7277c0a13af61d3719a3717142534d23 Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Sun, 2 Jun 2024 19:49:52 +0200 Subject: [PATCH 03/10] Move _from_arrow functions to _api --- stac_geoparquet/arrow/__init__.py | 4 +-- stac_geoparquet/arrow/_api.py | 45 ++++++++++++++++++++++++++++ stac_geoparquet/arrow/_batch.py | 2 +- stac_geoparquet/arrow/_from_arrow.py | 34 +-------------------- 4 files changed, 49 insertions(+), 36 deletions(-) create mode 100644 stac_geoparquet/arrow/_api.py diff --git a/stac_geoparquet/arrow/__init__.py b/stac_geoparquet/arrow/__init__.py index ee781a3..31a55ee 100644 --- a/stac_geoparquet/arrow/__init__.py +++ b/stac_geoparquet/arrow/__init__.py @@ -1,3 +1,3 @@ -from ._from_arrow import stac_table_to_items, stac_table_to_ndjson from ._to_arrow import parse_stac_items_to_arrow, parse_stac_ndjson_to_arrow -from ._to_parquet import to_parquet +from ._to_parquet import to_parquet, parse_stac_ndjson_to_parquet +from ._api import stac_table_to_items, stac_table_to_ndjson diff --git a/stac_geoparquet/arrow/_api.py b/stac_geoparquet/arrow/_api.py new file mode 100644 index 0000000..1051e20 --- /dev/null +++ b/stac_geoparquet/arrow/_api.py @@ -0,0 +1,45 @@ +import os +from typing import Any, Dict, Iterable, Optional, Union + +import pyarrow as pa + +from stac_geoparquet.arrow._batch import CleanBatch, RawBatch + + +def stac_items_to_arrow( + items: Iterable[Dict[str, Any]], *, schema: Optional[pa.Schema] = None +) -> pa.RecordBatch: + """Convert dicts representing STAC Items to Arrow + + This converts GeoJSON geometries to WKB before Arrow conversion to allow multiple + geometry types. + + All items will be parsed into a single RecordBatch, meaning that each internal array + is fully contiguous in memory for the length of `items`. + + Args: + items: STAC Items to convert to Arrow + + Kwargs: + schema: An optional schema that describes the format of the data. Note that this + must represent the geometry column as binary type. + + Returns: + Arrow RecordBatch with items in Arrow + """ + raw_batch = RawBatch.from_dicts(items, schema=schema) + return raw_batch.to_clean_batch().inner + + +def stac_table_to_items(table: pa.Table) -> Iterable[dict]: + """Convert a STAC Table to a generator of STAC Item `dict`s""" + for batch in table.to_batches(): + clean_batch = CleanBatch(batch) + yield from clean_batch.to_raw_batch().iter_dicts() + + +def stac_table_to_ndjson(table: pa.Table, dest: Union[str, os.PathLike[bytes]]) -> None: + """Write a STAC Table to a newline-delimited JSON file.""" + for batch in table.to_batches(): + clean_batch = CleanBatch(batch) + clean_batch.to_raw_batch().to_ndjson(dest) diff --git a/stac_geoparquet/arrow/_batch.py b/stac_geoparquet/arrow/_batch.py index 950c9fe..adf900c 100644 --- a/stac_geoparquet/arrow/_batch.py +++ b/stac_geoparquet/arrow/_batch.py @@ -76,7 +76,7 @@ def from_dicts( as binary type. Returns: - _description_ + a new RawBatch of data. 
""" # Preprocess GeoJSON to WKB in each STAC item # Otherwise, pyarrow will try to parse coordinates into a native geometry type diff --git a/stac_geoparquet/arrow/_from_arrow.py b/stac_geoparquet/arrow/_from_arrow.py index a92f894..a0dbfbe 100644 --- a/stac_geoparquet/arrow/_from_arrow.py +++ b/stac_geoparquet/arrow/_from_arrow.py @@ -1,44 +1,12 @@ """Convert STAC Items in Arrow Table format to JSON Lines or Python dicts.""" -import orjson -import os -from typing import Iterable, List, Union +from typing import List import numpy as np import pyarrow as pa import pyarrow.compute as pc -def stac_batch_to_items(batch: pa.RecordBatch) -> Iterable[dict]: - """Convert a stac arrow recordbatch to item dicts.""" - batch = _undo_stac_transformations(batch) - - -def stac_table_to_ndjson(table: pa.Table, dest: Union[str, os.PathLike[str]]) -> None: - """Write a STAC Table to a newline-delimited JSON file.""" - with open(dest, "wb") as f: - for item_dict in stac_table_to_items(table): - f.write(orjson.dumps(item_dict)) - f.write(b"\n") - - -def stac_table_to_items(table: pa.Table) -> Iterable[dict]: - """Convert a STAC Table to a generator of STAC Item `dict`s""" - for batch in table.to_batches(): - yield from stac_batch_to_items(batch) - - -def _undo_stac_transformations(batch: pa.RecordBatch) -> pa.RecordBatch: - """Undo the transformations done to convert STAC Json into an Arrow Table - - Note that this function does _not_ undo the GeoJSON -> WKB geometry transformation, - as that is easier to do when converting each item in the table to a dict. - """ - batch = lower_properties_from_top_level(batch) - batch = convert_bbox_to_array(batch) - return batch - - def convert_timestamp_columns_to_string(batch: pa.RecordBatch) -> pa.RecordBatch: """Convert any datetime columns in the table to a string representation""" allowed_column_names = { From 46295d0f24be9548f302ee4bb390de9efc78593e Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Sun, 2 Jun 2024 19:58:06 +0200 Subject: [PATCH 04/10] Update imports --- stac_geoparquet/arrow/__init__.py | 10 ++- stac_geoparquet/arrow/_api.py | 101 +++++++++++++++++---- stac_geoparquet/arrow/_schema/models.py | 4 +- stac_geoparquet/arrow/_to_arrow.py | 112 +++++------------------- stac_geoparquet/arrow/_to_parquet.py | 4 +- 5 files changed, 116 insertions(+), 115 deletions(-) diff --git a/stac_geoparquet/arrow/__init__.py b/stac_geoparquet/arrow/__init__.py index 31a55ee..c88deb2 100644 --- a/stac_geoparquet/arrow/__init__.py +++ b/stac_geoparquet/arrow/__init__.py @@ -1,3 +1,7 @@ -from ._to_arrow import parse_stac_items_to_arrow, parse_stac_ndjson_to_arrow -from ._to_parquet import to_parquet, parse_stac_ndjson_to_parquet -from ._api import stac_table_to_items, stac_table_to_ndjson +from ._api import ( + parse_stac_items_to_arrow, + parse_stac_ndjson_to_arrow, + stac_table_to_items, + stac_table_to_ndjson, +) +from ._to_parquet import parse_stac_ndjson_to_parquet, to_parquet diff --git a/stac_geoparquet/arrow/_api.py b/stac_geoparquet/arrow/_api.py index 1051e20..f83cdad 100644 --- a/stac_geoparquet/arrow/_api.py +++ b/stac_geoparquet/arrow/_api.py @@ -1,34 +1,99 @@ import os -from typing import Any, Dict, Iterable, Optional, Union +from pathlib import Path +from typing import Any, Dict, Iterable, Iterator, Optional, Union import pyarrow as pa -from stac_geoparquet.arrow._batch import CleanBatch, RawBatch +from stac_geoparquet.arrow._batch import CleanBatch +from stac_geoparquet.arrow._schema.models import InferredSchema +from stac_geoparquet.arrow._to_arrow import 
stac_items_to_arrow +from stac_geoparquet.arrow._util import batched_iter +from stac_geoparquet.json_reader import read_json_chunked -def stac_items_to_arrow( - items: Iterable[Dict[str, Any]], *, schema: Optional[pa.Schema] = None -) -> pa.RecordBatch: - """Convert dicts representing STAC Items to Arrow +def parse_stac_items_to_arrow( + items: Iterable[Dict[str, Any]], + *, + chunk_size: int = 8192, + schema: Optional[Union[pa.Schema, InferredSchema]] = None, +) -> Iterable[pa.RecordBatch]: + """Parse a collection of STAC Items to an iterable of :class:`pyarrow.RecordBatch`. - This converts GeoJSON geometries to WKB before Arrow conversion to allow multiple - geometry types. + The objects under `properties` are moved up to the top-level of the + Table, similar to :meth:`geopandas.GeoDataFrame.from_features`. + + Args: + items: the STAC Items to convert + chunk_size: The chunk size to use for Arrow record batches. This only takes + effect if `schema` is not None. When `schema` is None, the input will be + parsed into a single contiguous record batch. Defaults to 8192. + schema: The schema of the input data. If provided, can improve memory use; + otherwise all items need to be parsed into a single array for schema + inference. Defaults to None. + + Returns: + an iterable of pyarrow RecordBatches with the STAC-GeoParquet representation of items. + """ + if schema is not None: + if isinstance(schema, InferredSchema): + schema = schema.inner + + # If schema is provided, then for better memory usage we parse input STAC items + # to Arrow batches in chunks. + for chunk in batched_iter(items, chunk_size): + yield stac_items_to_arrow(chunk, schema=schema) + + else: + # If schema is _not_ provided, then we must convert to Arrow all at once, or + # else it would be possible for a STAC item late in the collection (after the + # first chunk) to have a different schema and not match the schema inferred for + # the first chunk. + yield stac_items_to_arrow(items) - All items will be parsed into a single RecordBatch, meaning that each internal array - is fully contiguous in memory for the length of `items`. + +def parse_stac_ndjson_to_arrow( + path: Union[str, Path, Iterable[Union[str, Path]]], + *, + chunk_size: int = 65536, + schema: Optional[pa.Schema] = None, + limit: Optional[int] = None, +) -> Iterator[pa.RecordBatch]: + """ + Convert one or more newline-delimited JSON STAC files to a generator of Arrow + RecordBatches. + + Each RecordBatch in the returned iterator is guaranteed to have an identical schema, + and can be used to write to one or more Parquet files. Args: - items: STAC Items to convert to Arrow + path: One or more paths to files with STAC items. + chunk_size: The chunk size. Defaults to 65536. + schema: The schema to represent the input STAC data. Defaults to None, in which + case the schema will first be inferred via a full pass over the input data. + In this case, there will be two full passes over the input data: one to + infer a common schema across all data and another to read the data. - Kwargs: - schema: An optional schema that describes the format of the data. Note that this - must represent the geometry column as binary type. + Other args: + limit: The maximum number of JSON Items to use for schema inference - Returns: - Arrow RecordBatch with items in Arrow + Yields: + Arrow RecordBatch with a single chunk of Item data. 
""" - raw_batch = RawBatch.from_dicts(items, schema=schema) - return raw_batch.to_clean_batch().inner + # If the schema was not provided, then we need to load all data into memory at once + # to perform schema resolution. + if schema is None: + inferred_schema = InferredSchema() + inferred_schema.update_from_json(path, chunk_size=chunk_size, limit=limit) + yield from parse_stac_ndjson_to_arrow( + path, chunk_size=chunk_size, schema=inferred_schema + ) + return + + if isinstance(schema, InferredSchema): + schema = schema.inner + + for batch in read_json_chunked(path, chunk_size=chunk_size): + yield stac_items_to_arrow(batch, schema=schema) def stac_table_to_items(table: pa.Table) -> Iterable[dict]: diff --git a/stac_geoparquet/arrow/_schema/models.py b/stac_geoparquet/arrow/_schema/models.py index 06fcbd2..adf5eaf 100644 --- a/stac_geoparquet/arrow/_schema/models.py +++ b/stac_geoparquet/arrow/_schema/models.py @@ -3,7 +3,7 @@ import pyarrow as pa -from stac_geoparquet.arrow._util import stac_items_to_arrow +from stac_geoparquet.arrow._batch import RawBatch from stac_geoparquet.json_reader import read_json_chunked @@ -48,7 +48,7 @@ def update_from_json( def update_from_items(self, items: Sequence[Dict[str, Any]]) -> None: """Update this inferred schema from a sequence of STAC Items.""" self.count += len(items) - current_schema = stac_items_to_arrow(items, schema=None).schema + current_schema = RawBatch.from_dicts(items, schema=None).inner.schema new_schema = pa.unify_schemas( [self.inner, current_schema], promote_options="permissive" ) diff --git a/stac_geoparquet/arrow/_to_arrow.py b/stac_geoparquet/arrow/_to_arrow.py index 8efcdf3..4c24a84 100644 --- a/stac_geoparquet/arrow/_to_arrow.py +++ b/stac_geoparquet/arrow/_to_arrow.py @@ -1,109 +1,40 @@ """Convert STAC data into Arrow tables""" -from pathlib import Path -from typing import ( - Any, - Dict, - Iterable, - Iterator, - Optional, - Union, -) +from typing import Any, Dict, Iterable, Optional import ciso8601 import numpy as np import orjson import pyarrow as pa +import pyarrow.compute as pc +from stac_geoparquet.arrow._batch import RawBatch from stac_geoparquet.arrow._crs import WGS84_CRS_JSON -from stac_geoparquet.arrow._schema.models import InferredSchema -from stac_geoparquet.arrow._util import batched_iter -from stac_geoparquet.json_reader import read_json_chunked -def parse_stac_items_to_arrow( - items: Iterable[Dict[str, Any]], - *, - chunk_size: int = 8192, - schema: Optional[Union[pa.Schema, InferredSchema]] = None, -) -> Iterable[pa.RecordBatch]: - """Parse a collection of STAC Items to an iterable of :class:`pyarrow.RecordBatch`. - - The objects under `properties` are moved up to the top-level of the - Table, similar to :meth:`geopandas.GeoDataFrame.from_features`. - - Args: - items: the STAC Items to convert - chunk_size: The chunk size to use for Arrow record batches. This only takes - effect if `schema` is not None. When `schema` is None, the input will be - parsed into a single contiguous record batch. Defaults to 8192. - schema: The schema of the input data. If provided, can improve memory use; - otherwise all items need to be parsed into a single array for schema - inference. Defaults to None. - - Returns: - an iterable of pyarrow RecordBatches with the STAC-GeoParquet representation of items. - """ - if schema is not None: - if isinstance(schema, InferredSchema): - schema = schema.inner - - # If schema is provided, then for better memory usage we parse input STAC items - # to Arrow batches in chunks. 
- for chunk in batched_iter(items, chunk_size): - yield stac_items_to_arrow(chunk, schema=schema) +def stac_items_to_arrow( + items: Iterable[Dict[str, Any]], *, schema: Optional[pa.Schema] = None +) -> pa.RecordBatch: + """Convert dicts representing STAC Items to Arrow - else: - # If schema is _not_ provided, then we must convert to Arrow all at once, or - # else it would be possible for a STAC item late in the collection (after the - # first chunk) to have a different schema and not match the schema inferred for - # the first chunk. - yield stac_items_to_arrow(items) - - -def parse_stac_ndjson_to_arrow( - path: Union[str, Path, Iterable[Union[str, Path]]], - *, - chunk_size: int = 65536, - schema: Optional[pa.Schema] = None, - limit: Optional[int] = None, -) -> Iterator[pa.RecordBatch]: - """ - Convert one or more newline-delimited JSON STAC files to a generator of Arrow - RecordBatches. + This converts GeoJSON geometries to WKB before Arrow conversion to allow multiple + geometry types. - Each RecordBatch in the returned iterator is guaranteed to have an identical schema, - and can be used to write to one or more Parquet files. + All items will be parsed into a single RecordBatch, meaning that each internal array + is fully contiguous in memory for the length of `items`. Args: - path: One or more paths to files with STAC items. - chunk_size: The chunk size. Defaults to 65536. - schema: The schema to represent the input STAC data. Defaults to None, in which - case the schema will first be inferred via a full pass over the input data. - In this case, there will be two full passes over the input data: one to - infer a common schema across all data and another to read the data. - - Other args: - limit: The maximum number of JSON Items to use for schema inference - - Yields: - Arrow RecordBatch with a single chunk of Item data. - """ - # If the schema was not provided, then we need to load all data into memory at once - # to perform schema resolution. - if schema is None: - inferred_schema = InferredSchema() - inferred_schema.update_from_json(path, chunk_size=chunk_size, limit=limit) - yield from parse_stac_ndjson_to_arrow( - path, chunk_size=chunk_size, schema=inferred_schema - ) - return + items: STAC Items to convert to Arrow - if isinstance(schema, InferredSchema): - schema = schema.inner + Kwargs: + schema: An optional schema that describes the format of the data. Note that this + must represent the geometry column as binary type. 
- for batch in read_json_chunked(path, chunk_size=chunk_size): - yield stac_items_to_arrow(batch, schema=schema) + Returns: + Arrow RecordBatch with items in Arrow + """ + raw_batch = RawBatch.from_dicts(items, schema=schema) + return raw_batch.to_clean_batch().inner def bring_properties_to_top_level( @@ -116,7 +47,8 @@ def bring_properties_to_top_level( for field_idx in range(properties_field.type.num_fields): inner_prop_field = properties_field.type.field(field_idx) batch = batch.append_column( - inner_prop_field, pc.struct_field(properties_column, field_idx) + inner_prop_field, + pc.struct_field(properties_column, field_idx), # type: ignore ) batch = batch.drop_columns( diff --git a/stac_geoparquet/arrow/_to_parquet.py b/stac_geoparquet/arrow/_to_parquet.py index 294e216..7197d16 100644 --- a/stac_geoparquet/arrow/_to_parquet.py +++ b/stac_geoparquet/arrow/_to_parquet.py @@ -5,9 +5,9 @@ import pyarrow as pa import pyarrow.parquet as pq -from stac_geoparquet.arrow._schema.models import InferredSchema -from stac_geoparquet.arrow._to_arrow import parse_stac_ndjson_to_arrow +from stac_geoparquet.arrow._api import parse_stac_ndjson_to_arrow from stac_geoparquet.arrow._crs import WGS84_CRS_JSON +from stac_geoparquet.arrow._schema.models import InferredSchema def parse_stac_ndjson_to_parquet( From fa226d453bcfb9be4dfb609b0d233cf4f864c77e Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Mon, 3 Jun 2024 09:43:10 +0200 Subject: [PATCH 05/10] fix circular import --- stac_geoparquet/arrow/_api.py | 28 ++++++++++++++++++++++++++-- stac_geoparquet/arrow/_to_arrow.py | 28 ---------------------------- 2 files changed, 26 insertions(+), 30 deletions(-) diff --git a/stac_geoparquet/arrow/_api.py b/stac_geoparquet/arrow/_api.py index f83cdad..27e3276 100644 --- a/stac_geoparquet/arrow/_api.py +++ b/stac_geoparquet/arrow/_api.py @@ -4,9 +4,8 @@ import pyarrow as pa -from stac_geoparquet.arrow._batch import CleanBatch +from stac_geoparquet.arrow._batch import CleanBatch, RawBatch from stac_geoparquet.arrow._schema.models import InferredSchema -from stac_geoparquet.arrow._to_arrow import stac_items_to_arrow from stac_geoparquet.arrow._util import batched_iter from stac_geoparquet.json_reader import read_json_chunked @@ -108,3 +107,28 @@ def stac_table_to_ndjson(table: pa.Table, dest: Union[str, os.PathLike[bytes]]) for batch in table.to_batches(): clean_batch = CleanBatch(batch) clean_batch.to_raw_batch().to_ndjson(dest) + + +def stac_items_to_arrow( + items: Iterable[Dict[str, Any]], *, schema: Optional[pa.Schema] = None +) -> pa.RecordBatch: + """Convert dicts representing STAC Items to Arrow + + This converts GeoJSON geometries to WKB before Arrow conversion to allow multiple + geometry types. + + All items will be parsed into a single RecordBatch, meaning that each internal array + is fully contiguous in memory for the length of `items`. + + Args: + items: STAC Items to convert to Arrow + + Kwargs: + schema: An optional schema that describes the format of the data. Note that this + must represent the geometry column as binary type. 
+ + Returns: + Arrow RecordBatch with items in Arrow + """ + raw_batch = RawBatch.from_dicts(items, schema=schema) + return raw_batch.to_clean_batch().inner diff --git a/stac_geoparquet/arrow/_to_arrow.py b/stac_geoparquet/arrow/_to_arrow.py index 4c24a84..38e1511 100644 --- a/stac_geoparquet/arrow/_to_arrow.py +++ b/stac_geoparquet/arrow/_to_arrow.py @@ -1,42 +1,14 @@ """Convert STAC data into Arrow tables""" -from typing import Any, Dict, Iterable, Optional - import ciso8601 import numpy as np import orjson import pyarrow as pa import pyarrow.compute as pc -from stac_geoparquet.arrow._batch import RawBatch from stac_geoparquet.arrow._crs import WGS84_CRS_JSON -def stac_items_to_arrow( - items: Iterable[Dict[str, Any]], *, schema: Optional[pa.Schema] = None -) -> pa.RecordBatch: - """Convert dicts representing STAC Items to Arrow - - This converts GeoJSON geometries to WKB before Arrow conversion to allow multiple - geometry types. - - All items will be parsed into a single RecordBatch, meaning that each internal array - is fully contiguous in memory for the length of `items`. - - Args: - items: STAC Items to convert to Arrow - - Kwargs: - schema: An optional schema that describes the format of the data. Note that this - must represent the geometry column as binary type. - - Returns: - Arrow RecordBatch with items in Arrow - """ - raw_batch = RawBatch.from_dicts(items, schema=schema) - return raw_batch.to_clean_batch().inner - - def bring_properties_to_top_level( batch: pa.RecordBatch, ) -> pa.RecordBatch: From cc7beec5e6533e7baee69202675c92eb841dc323 Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Mon, 3 Jun 2024 09:56:05 +0200 Subject: [PATCH 06/10] keep deprecated api --- stac_geoparquet/arrow/_batch.py | 2 +- stac_geoparquet/from_arrow.py | 5 ++++- stac_geoparquet/to_arrow.py | 2 +- tests/test_arrow.py | 22 +++++++++++++++++++++- 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/stac_geoparquet/arrow/_batch.py b/stac_geoparquet/arrow/_batch.py index adf900c..927ba86 100644 --- a/stac_geoparquet/arrow/_batch.py +++ b/stac_geoparquet/arrow/_batch.py @@ -27,7 +27,7 @@ convert_timestamp_columns, ) from stac_geoparquet.arrow._util import convert_tuples_to_lists, set_by_path -from stac_geoparquet.from_arrow import ( +from stac_geoparquet.arrow._from_arrow import ( convert_bbox_to_array, convert_timestamp_columns_to_string, lower_properties_from_top_level, diff --git a/stac_geoparquet/from_arrow.py b/stac_geoparquet/from_arrow.py index dc19fca..2af5920 100644 --- a/stac_geoparquet/from_arrow.py +++ b/stac_geoparquet/from_arrow.py @@ -5,4 +5,7 @@ FutureWarning, ) -from stac_geoparquet.arrow._from_arrow import * # noqa + +from stac_geoparquet.arrow._api import stac_items_to_arrow # noqa +from stac_geoparquet.arrow._api import stac_table_to_items # noqa +from stac_geoparquet.arrow._api import stac_table_to_ndjson # noqa diff --git a/stac_geoparquet/to_arrow.py b/stac_geoparquet/to_arrow.py index 9b3f81d..923acd8 100644 --- a/stac_geoparquet/to_arrow.py +++ b/stac_geoparquet/to_arrow.py @@ -5,4 +5,4 @@ FutureWarning, ) -from stac_geoparquet.arrow._to_arrow import * # noqa +from stac_geoparquet.arrow import parse_stac_items_to_arrow, parse_stac_ndjson_to_arrow # noqa diff --git a/tests/test_arrow.py b/tests/test_arrow.py index 94e5bfa..4e48efd 100644 --- a/tests/test_arrow.py +++ b/tests/test_arrow.py @@ -4,7 +4,11 @@ import pyarrow as pa import pytest -from stac_geoparquet.arrow import parse_stac_items_to_arrow, stac_table_to_items +from stac_geoparquet.arrow import ( + 
parse_stac_items_to_arrow, + parse_stac_ndjson_to_arrow, + stac_table_to_items, +) from .json_equals import assert_json_value_equal @@ -55,6 +59,22 @@ def test_table_contains_geoarrow_metadata(): } +@pytest.mark.parametrize( + "collection_id", + TEST_COLLECTIONS, +) +def test_parse_json_to_arrow(collection_id: str): + path = HERE / "data" / f"{collection_id}-pc.json" + table = pa.Table.from_batches(parse_stac_ndjson_to_arrow(path)) + items_result = list(stac_table_to_items(table)) + + with open(HERE / "data" / f"{collection_id}-pc.json") as f: + items = json.load(f) + + for result, expected in zip(items_result, items): + assert_json_value_equal(result, expected, precision=0) + + def test_to_arrow_deprecated(): with pytest.warns(FutureWarning): import stac_geoparquet.to_arrow From 6060644d3c722e3e42a5ea188615ca5ab9407709 Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Mon, 3 Jun 2024 10:09:50 +0200 Subject: [PATCH 07/10] Add write-read test and fix typing --- stac_geoparquet/arrow/_api.py | 4 +++- stac_geoparquet/arrow/_batch.py | 13 +++++++------ tests/test_arrow.py | 26 +++++++++++++++++--------- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/stac_geoparquet/arrow/_api.py b/stac_geoparquet/arrow/_api.py index 27e3276..b16dbf0 100644 --- a/stac_geoparquet/arrow/_api.py +++ b/stac_geoparquet/arrow/_api.py @@ -102,7 +102,9 @@ def stac_table_to_items(table: pa.Table) -> Iterable[dict]: yield from clean_batch.to_raw_batch().iter_dicts() -def stac_table_to_ndjson(table: pa.Table, dest: Union[str, os.PathLike[bytes]]) -> None: +def stac_table_to_ndjson( + table: pa.Table, dest: Union[str, Path, os.PathLike[bytes]] +) -> None: """Write a STAC Table to a newline-delimited JSON file.""" for batch in table.to_batches(): clean_batch = CleanBatch(batch) diff --git a/stac_geoparquet/arrow/_batch.py b/stac_geoparquet/arrow/_batch.py index 927ba86..80c4af7 100644 --- a/stac_geoparquet/arrow/_batch.py +++ b/stac_geoparquet/arrow/_batch.py @@ -2,6 +2,7 @@ import os from copy import deepcopy +from pathlib import Path from typing import ( Any, Dict, @@ -20,6 +21,11 @@ from numpy.typing import NDArray from typing_extensions import Self +from stac_geoparquet.arrow._from_arrow import ( + convert_bbox_to_array, + convert_timestamp_columns_to_string, + lower_properties_from_top_level, +) from stac_geoparquet.arrow._to_arrow import ( assign_geoarrow_metadata, bring_properties_to_top_level, @@ -27,11 +33,6 @@ convert_timestamp_columns, ) from stac_geoparquet.arrow._util import convert_tuples_to_lists, set_by_path -from stac_geoparquet.arrow._from_arrow import ( - convert_bbox_to_array, - convert_timestamp_columns_to_string, - lower_properties_from_top_level, -) class RawBatch: @@ -167,7 +168,7 @@ def to_clean_batch(self) -> CleanBatch: return CleanBatch(batch) - def to_ndjson(self, dest: Union[str, os.PathLike[bytes]]) -> None: + def to_ndjson(self, dest: Union[str, Path, os.PathLike[bytes]]) -> None: with open(dest, "ab") as f: for item_dict in self.iter_dicts(): f.write(orjson.dumps(item_dict)) diff --git a/tests/test_arrow.py b/tests/test_arrow.py index 4e48efd..df7f9b7 100644 --- a/tests/test_arrow.py +++ b/tests/test_arrow.py @@ -8,6 +8,7 @@ parse_stac_items_to_arrow, parse_stac_ndjson_to_arrow, stac_table_to_items, + stac_table_to_ndjson, ) from .json_equals import assert_json_value_equal @@ -30,11 +31,8 @@ ] -@pytest.mark.parametrize( - "collection_id", - TEST_COLLECTIONS, -) -def test_round_trip(collection_id: str): +@pytest.mark.parametrize("collection_id", TEST_COLLECTIONS) +def 
test_round_trip_read_write(collection_id: str): with open(HERE / "data" / f"{collection_id}-pc.json") as f: items = json.load(f) @@ -45,6 +43,19 @@ def test_round_trip(collection_id: str): assert_json_value_equal(result, expected, precision=0) +@pytest.mark.parametrize("collection_id", TEST_COLLECTIONS) +def test_round_trip_write_read_ndjson(collection_id: str, tmp_path: Path): + # First load into a STAC-GeoParquet table + path = HERE / "data" / f"{collection_id}-pc.json" + table = pa.Table.from_batches(parse_stac_ndjson_to_arrow(path)) + + # Then write to disk + stac_table_to_ndjson(table, tmp_path / "tmp.ndjson") + + # Then read back and assert tables match + table = pa.Table.from_batches(parse_stac_ndjson_to_arrow(tmp_path / "tmp.ndjson")) + + def test_table_contains_geoarrow_metadata(): collection_id = "naip" with open(HERE / "data" / f"{collection_id}-pc.json") as f: @@ -59,10 +70,7 @@ def test_table_contains_geoarrow_metadata(): } -@pytest.mark.parametrize( - "collection_id", - TEST_COLLECTIONS, -) +@pytest.mark.parametrize("collection_id", TEST_COLLECTIONS) def test_parse_json_to_arrow(collection_id: str): path = HERE / "data" / f"{collection_id}-pc.json" table = pa.Table.from_batches(parse_stac_ndjson_to_arrow(path)) From 4c5d08b1906ca12b3b7d58e560bc8bd845d03b44 Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Mon, 3 Jun 2024 10:54:01 +0200 Subject: [PATCH 08/10] add parquet tests --- tests/test_arrow.py | 7 ------ tests/test_parquet.py | 53 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 7 deletions(-) create mode 100644 tests/test_parquet.py diff --git a/tests/test_arrow.py b/tests/test_arrow.py index df7f9b7..f51787b 100644 --- a/tests/test_arrow.py +++ b/tests/test_arrow.py @@ -89,13 +89,6 @@ def test_to_arrow_deprecated(): stac_geoparquet.to_arrow.parse_stac_items_to_arrow -def test_to_parquet_deprecated(): - with pytest.warns(FutureWarning): - import stac_geoparquet.to_parquet - - stac_geoparquet.to_parquet.to_parquet - - def test_from_arrow_deprecated(): with pytest.warns(FutureWarning): import stac_geoparquet.from_arrow diff --git a/tests/test_parquet.py b/tests/test_parquet.py new file mode 100644 index 0000000..10dd938 --- /dev/null +++ b/tests/test_parquet.py @@ -0,0 +1,53 @@ +import json +from pathlib import Path + +import pyarrow.parquet as pq +import pytest + +from stac_geoparquet.arrow import parse_stac_ndjson_to_parquet, stac_table_to_items + +from .json_equals import assert_json_value_equal + +HERE = Path(__file__).parent + + +def test_to_parquet_deprecated(): + with pytest.warns(FutureWarning): + import stac_geoparquet.to_parquet + + stac_geoparquet.to_parquet.to_parquet + + +TEST_COLLECTIONS = [ + "3dep-lidar-copc", + "3dep-lidar-dsm", + "cop-dem-glo-30", + "io-lulc-annual-v02", + "io-lulc", + "landsat-c2-l1", + "landsat-c2-l2", + "naip", + "planet-nicfi-analytic", + "sentinel-1-rtc", + "sentinel-2-l2a", + "us-census", +] + + +@pytest.mark.parametrize("collection_id", TEST_COLLECTIONS) +def test_round_trip_via_parquet(collection_id: str, tmp_path: Path): + path = HERE / "data" / f"{collection_id}-pc.json" + out_path = tmp_path / "file.parquet" + # Convert to Parquet + parse_stac_ndjson_to_parquet(path, out_path) + + # Read back into table and convert to json + table = pq.read_table(out_path) + items_result = list(stac_table_to_items(table)) + + # Compare with original json + with open(HERE / "data" / f"{collection_id}-pc.json") as f: + items = json.load(f) + + for result, expected in zip(items_result, items): + 
assert_json_value_equal(result, expected, precision=0) From 14a6bc928902ae730bb0b5222000c1a31339c9f2 Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Mon, 3 Jun 2024 11:01:36 +0200 Subject: [PATCH 09/10] fix ci --- stac_geoparquet/arrow/_batch.py | 15 ++++----------- stac_geoparquet/to_arrow.py | 10 ++++++++-- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/stac_geoparquet/arrow/_batch.py b/stac_geoparquet/arrow/_batch.py index 80c4af7..6b01a14 100644 --- a/stac_geoparquet/arrow/_batch.py +++ b/stac_geoparquet/arrow/_batch.py @@ -3,14 +3,7 @@ import os from copy import deepcopy from pathlib import Path -from typing import ( - Any, - Dict, - Iterable, - List, - Optional, - Union, -) +from typing import Any, Iterable import numpy as np import orjson @@ -61,7 +54,7 @@ def __init__(self, batch: pa.RecordBatch) -> None: @classmethod def from_dicts( - cls, items: Iterable[Dict[str, Any]], *, schema: Optional[pa.Schema] = None + cls, items: Iterable[dict[str, Any]], *, schema: pa.Schema | None = None ) -> Self: """Construct a RawBatch from an iterable of dicts representing STAC items. @@ -133,7 +126,7 @@ def iter_dicts(self) -> Iterable[dict]: # Convert each geometry column to a Shapely geometry, and then assign the # geojson geometry when converting each row to a dictionary. - geometries: List[NDArray[np.object_]] = [] + geometries: list[NDArray[np.object_]] = [] for geometry_path in geometry_paths: col = batch for path_segment in geometry_path: @@ -168,7 +161,7 @@ def to_clean_batch(self) -> CleanBatch: return CleanBatch(batch) - def to_ndjson(self, dest: Union[str, Path, os.PathLike[bytes]]) -> None: + def to_ndjson(self, dest: str | Path | os.PathLike[bytes]) -> None: with open(dest, "ab") as f: for item_dict in self.iter_dicts(): f.write(orjson.dumps(item_dict)) diff --git a/stac_geoparquet/to_arrow.py b/stac_geoparquet/to_arrow.py index 923acd8..2802d6e 100644 --- a/stac_geoparquet/to_arrow.py +++ b/stac_geoparquet/to_arrow.py @@ -1,8 +1,14 @@ +# This doesn't work inline on these imports for some reason +# flake8: noqa: F401 + import warnings +from stac_geoparquet.arrow._api import ( + parse_stac_items_to_arrow, + parse_stac_ndjson_to_arrow, +) + warnings.warn( "stac_geoparquet.to_arrow is deprecated. 
Please use stac_geoparquet.arrow instead.", FutureWarning, ) - -from stac_geoparquet.arrow import parse_stac_items_to_arrow, parse_stac_ndjson_to_arrow # noqa From 7b83081343cd487276d0f5d1e96329349f1c22f5 Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Tue, 4 Jun 2024 12:29:19 -0400 Subject: [PATCH 10/10] Rename wrapper types --- stac_geoparquet/arrow/_api.py | 8 ++++---- stac_geoparquet/arrow/_batch.py | 16 ++++++++-------- stac_geoparquet/arrow/_schema/models.py | 4 ++-- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/stac_geoparquet/arrow/_api.py b/stac_geoparquet/arrow/_api.py index b16dbf0..2fee475 100644 --- a/stac_geoparquet/arrow/_api.py +++ b/stac_geoparquet/arrow/_api.py @@ -4,7 +4,7 @@ import pyarrow as pa -from stac_geoparquet.arrow._batch import CleanBatch, RawBatch +from stac_geoparquet.arrow._batch import StacArrowBatch, StacJsonBatch from stac_geoparquet.arrow._schema.models import InferredSchema from stac_geoparquet.arrow._util import batched_iter from stac_geoparquet.json_reader import read_json_chunked @@ -98,7 +98,7 @@ def parse_stac_ndjson_to_arrow( def stac_table_to_items(table: pa.Table) -> Iterable[dict]: """Convert a STAC Table to a generator of STAC Item `dict`s""" for batch in table.to_batches(): - clean_batch = CleanBatch(batch) + clean_batch = StacArrowBatch(batch) yield from clean_batch.to_raw_batch().iter_dicts() @@ -107,7 +107,7 @@ def stac_table_to_ndjson( ) -> None: """Write a STAC Table to a newline-delimited JSON file.""" for batch in table.to_batches(): - clean_batch = CleanBatch(batch) + clean_batch = StacArrowBatch(batch) clean_batch.to_raw_batch().to_ndjson(dest) @@ -132,5 +132,5 @@ def stac_items_to_arrow( Returns: Arrow RecordBatch with items in Arrow """ - raw_batch = RawBatch.from_dicts(items, schema=schema) + raw_batch = StacJsonBatch.from_dicts(items, schema=schema) return raw_batch.to_clean_batch().inner diff --git a/stac_geoparquet/arrow/_batch.py b/stac_geoparquet/arrow/_batch.py index 6b01a14..7130cb2 100644 --- a/stac_geoparquet/arrow/_batch.py +++ b/stac_geoparquet/arrow/_batch.py @@ -28,7 +28,7 @@ from stac_geoparquet.arrow._util import convert_tuples_to_lists, set_by_path -class RawBatch: +class StacJsonBatch: """ An Arrow RecordBatch of STAC Items that has been **minimally converted** to Arrow. That is, it aligns as much as possible to the raw STAC JSON representation. @@ -56,7 +56,7 @@ def __init__(self, batch: pa.RecordBatch) -> None: def from_dicts( cls, items: Iterable[dict[str, Any]], *, schema: pa.Schema | None = None ) -> Self: - """Construct a RawBatch from an iterable of dicts representing STAC items. + """Construct a StacJsonBatch from an iterable of dicts representing STAC items. All items will be parsed into a single RecordBatch, meaning that each internal array is fully contiguous in memory for the length of `items`. @@ -70,7 +70,7 @@ def from_dicts( as binary type. Returns: - a new RawBatch of data. + a new StacJsonBatch of data. 
""" # Preprocess GeoJSON to WKB in each STAC item # Otherwise, pyarrow will try to parse coordinates into a native geometry type @@ -151,7 +151,7 @@ def iter_dicts(self) -> Iterable[dict]: yield row_dict - def to_clean_batch(self) -> CleanBatch: + def to_clean_batch(self) -> StacArrowBatch: batch = self.inner batch = bring_properties_to_top_level(batch) @@ -159,7 +159,7 @@ def to_clean_batch(self) -> CleanBatch: batch = convert_bbox_to_struct(batch) batch = assign_geoarrow_metadata(batch) - return CleanBatch(batch) + return StacArrowBatch(batch) def to_ndjson(self, dest: str | Path | os.PathLike[bytes]) -> None: with open(dest, "ab") as f: @@ -168,7 +168,7 @@ def to_ndjson(self, dest: str | Path | os.PathLike[bytes]) -> None: f.write(b"\n") -class CleanBatch: +class StacArrowBatch: """ An Arrow RecordBatch of STAC Items that has been processed to match the STAC-GeoParquet specification. @@ -180,11 +180,11 @@ class CleanBatch: def __init__(self, batch: pa.RecordBatch) -> None: self.inner = batch - def to_raw_batch(self) -> RawBatch: + def to_raw_batch(self) -> StacJsonBatch: batch = self.inner batch = convert_timestamp_columns_to_string(batch) batch = lower_properties_from_top_level(batch) batch = convert_bbox_to_array(batch) - return RawBatch(batch) + return StacJsonBatch(batch) diff --git a/stac_geoparquet/arrow/_schema/models.py b/stac_geoparquet/arrow/_schema/models.py index adf5eaf..17fd169 100644 --- a/stac_geoparquet/arrow/_schema/models.py +++ b/stac_geoparquet/arrow/_schema/models.py @@ -3,7 +3,7 @@ import pyarrow as pa -from stac_geoparquet.arrow._batch import RawBatch +from stac_geoparquet.arrow._batch import StacJsonBatch from stac_geoparquet.json_reader import read_json_chunked @@ -48,7 +48,7 @@ def update_from_json( def update_from_items(self, items: Sequence[Dict[str, Any]]) -> None: """Update this inferred schema from a sequence of STAC Items.""" self.count += len(items) - current_schema = RawBatch.from_dicts(items, schema=None).inner.schema + current_schema = StacJsonBatch.from_dicts(items, schema=None).inner.schema new_schema = pa.unify_schemas( [self.inner, current_schema], promote_options="permissive" )