diff --git a/src/traffic/core/flight.py b/src/traffic/core/flight.py index 7513b08f..88e10a9b 100644 --- a/src/traffic/core/flight.py +++ b/src/traffic/core/flight.py @@ -57,6 +57,7 @@ if TYPE_CHECKING: import altair as alt import plotly.graph_objects as go + import pyarrow as pa from cartopy import crs from cartopy.mpl.geoaxes import GeoAxes from ipyleaflet import Map as LeafletMap @@ -679,13 +680,13 @@ def label( .. code:: python - flight.label(holding_pattern, holding=True) + flight.label('holding_pattern', holding=True) - Add a column `index` to enumerate holding patterns: .. code:: python - flight.label(holding_pattern, index="{i}") + flight.label('holding_pattern', index="{i}") - More complicated enriching: @@ -1903,6 +1904,69 @@ def forward( ).ffill() ) + def pyarrow_table( + self, + columns: None | list[str] = None, + geo: None | Literal["line2d", "line3d", "point2d", "point3d"] = None, + ) -> "pa.Table": + import pyarrow as pa + + len_ = len(self) + + default_columns = ["track", "groundspeed", "altitude", "vertical_rate"] + if columns is None: + columns = default_columns + else: + columns += default_columns + + entry = dict( + icao24=pa.array([self.icao24]), + timestamp=pa.ListArray.from_arrays([0, len_], self.data.timestamp), + ) + + if (callsign := self.callsign) is not None: + entry["callsign"] = pa.array([callsign]) + + if (flight_id := self.flight_id) is not None: + entry["flight_id"] = pa.array([flight_id]) + + # The usual features + for col in columns: + if col in self.data.columns: + entry[col] = pa.ListArray.from_arrays([0, len_], self.data[col]) + + pa_table = pa.table(entry) + + # The geometric feature + if geo is not None: + coords = [ + self.data.longitude, + self.data.latitude, + ] + if geo in ["line3d", "point3d"]: + coords.append(self.data.altitude) + + extension_name = dict( + line2d="geoarrow.linestring", + line3d="geoarrow.linestring", + point2d="geoarrow.multipoint", + point3d="geoarrow.multipoint", + ) + + np_coords = np.column_stack(coords).ravel("C") + pa_coords = pa.FixedSizeListArray.from_arrays( + np_coords, list_size=len(coords) + ) + pa_linestrings = pa.ListArray.from_arrays([0, len_], pa_coords) + pa_field = pa.field( + "geometry", + pa_linestrings.type, + metadata={"ARROW:extension:name": extension_name[geo]}, + ) + pa_table = pa_table.append_column(pa_field, pa_linestrings) + + return pa_table + # -- Air traffic management -- def assign_id( diff --git a/src/traffic/core/traffic.py b/src/traffic/core/traffic.py index 6220ca86..1c0a8428 100644 --- a/src/traffic/core/traffic.py +++ b/src/traffic/core/traffic.py @@ -3,6 +3,7 @@ import logging import warnings from datetime import timedelta +from itertools import accumulate from pathlib import Path from typing import ( TYPE_CHECKING, @@ -45,6 +46,7 @@ if TYPE_CHECKING: import plotly.graph_objects as go + import pyarrow as pa from cartopy import crs from cartopy.mpl.geoaxes import GeoAxesSubplot from ipyleaflet import Map as LeafletMap @@ -753,6 +755,10 @@ def all(self, *args, **kwargs): # type: ignore def next(self, *args, **kwargs): # type: ignore ... + @lazy_evaluation() + def label(self, *args, **kwargs): # type: ignore + ... + @lazy_evaluation() def final(self, *args, **kwargs): # type: ignore ... @@ -1047,6 +1053,69 @@ def basic_stats(self) -> pd.DataFrame: @lazy_evaluation() def summary(self, attributes: list[str]) -> pd.DataFrame: ... + def pyarrow_table( + self, + columns: None | list[str] = None, + geo: None | Literal["line2d", "line3d", "point2d", "point3d"] = None, + ) -> "pa.Table": + import pyarrow as pa + + summary = cast( + pd.DataFrame, + self.summary(["icao24", "callsign", "flight_id", "count"]).eval(), + ) + offsets = [0, *accumulate(summary["count"])] + + default_columns = ["track", "groundspeed", "altitude", "vertical_rate"] + if columns is None: + columns = default_columns + else: + columns += default_columns + + entry = dict( + icao24=pa.array(list(summary["icao24"])), + callsign=pa.array(list(summary["callsign"])), + flight_id=pa.array(list(summary["flight_id"])), + timestamp=pa.ListArray.from_arrays(offsets, self.data.timestamp), + ) + + # The usual features + for col in columns: + if col in self.data.columns: + entry[col] = pa.ListArray.from_arrays(offsets, self.data[col]) + + pa_table = pa.table(entry) + + # The geometric feature + if geo is not None: + coords = [ + self.data.longitude, + self.data.latitude, + ] + if geo in ["line3d", "point3d"]: + coords.append(self.data.altitude) + + extension_name = dict( + line2d="geoarrow.linestring", + line3d="geoarrow.linestring", + point2d="geoarrow.multipoint", + point3d="geoarrow.multipoint", + ) + + np_coords = np.column_stack(coords).ravel("C") + pa_coords = pa.FixedSizeListArray.from_arrays( + np_coords, list_size=len(coords) + ) + pa_linestrings = pa.ListArray.from_arrays(offsets, pa_coords) + pa_field = pa.field( + "geometry", + pa_linestrings.type, + metadata={"ARROW:extension:name": extension_name[geo]}, + ) + pa_table = pa_table.append_column(pa_field, pa_linestrings) + + return pa_table + def geoencode(self, *args: Any, **kwargs: Any) -> NoReturn: """ .. danger::