Skip to content

Commit cd33d6d

Browse files
authored
TYP: eliminate some typing issues (#59)
* eliminate some mypy errors * REF: eliminate some typing issues * avoid xarray internals * mypy in ci * more types * try different mypy command * fix errors * coverage fix
1 parent 17a1308 commit cd33d6d

File tree

7 files changed

+121
-93
lines changed

7 files changed

+121
-93
lines changed

.github/workflows/tests.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ jobs:
4848
id: status
4949
run: pytest -v . --cov=xvec --cov-append --cov-report term-missing --cov-report xml --color=yes --report-log pytest-log.jsonl
5050

51+
- name: run mypy
52+
if: contains(matrix.environment-file, 'ci/312.yaml') && contains(matrix.os, 'ubuntu')
53+
run: mypy xvec/ --install-types --ignore-missing-imports --non-interactive
54+
5155
- uses: codecov/codecov-action@v3
5256

5357
- name: Generate and publish the report

ci/312.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,5 @@ dependencies:
1919
- geopandas-base
2020
- geodatasets
2121
- pyogrio
22+
- mypy
2223

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ omit = ["xvec/tests/*"]
4949
exclude_lines = [
5050
"except ImportError",
5151
"except PackageNotFoundError",
52+
"if TYPE_CHECKING:"
5253
]
5354

5455
[tool.ruff]

xvec/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from importlib.metadata import PackageNotFoundError, version
22

3-
from .accessor import XvecAccessor # noqa
4-
from .index import GeometryIndex # noqa
3+
from .accessor import XvecAccessor # noqa: F401
4+
from .index import GeometryIndex # noqa: F401
55

66
try:
77
__version__ = version("xvec")

xvec/accessor.py

Lines changed: 69 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import warnings
44
from collections.abc import Hashable, Mapping, Sequence
5-
from typing import Any, Callable
5+
from typing import TYPE_CHECKING, Any, Callable, cast
66

77
import numpy as np
88
import pandas as pd
@@ -13,6 +13,9 @@
1313
from .index import GeometryIndex
1414
from .zonal import _zonal_stats_iterative, _zonal_stats_rasterize
1515

16+
if TYPE_CHECKING:
17+
from geopandas import GeoDataFrame
18+
1619

1720
@xr.register_dataarray_accessor("xvec")
1821
@xr.register_dataset_accessor("xvec")
@@ -22,7 +25,7 @@ class XvecAccessor:
2225
Currently works on coordinates with :class:`xvec.GeometryIndex`.
2326
"""
2427

25-
def __init__(self, xarray_obj: xr.Dataset | xr.DataArray):
28+
def __init__(self, xarray_obj: xr.Dataset | xr.DataArray) -> None:
2629
"""xvec init, nothing to be done here."""
2730
self._obj = xarray_obj
2831
self._geom_coords_all = [
@@ -36,7 +39,9 @@ def __init__(self, xarray_obj: xr.Dataset | xr.DataArray):
3639
if self.is_geom_variable(name, has_index=True)
3740
]
3841

39-
def is_geom_variable(self, name: Hashable, has_index: bool = True):
42+
def is_geom_variable(
43+
self, name: Hashable, has_index: bool = True
44+
) -> bool | np.bool_:
4045
"""Check if coordinate variable is composed of :class:`shapely.Geometry`.
4146
4247
Can return all such variables or only those using :class:`~xvec.GeometryIndex`.
@@ -208,7 +213,7 @@ def to_crs(
208213
self,
209214
variable_crs: Mapping[Any, Any] | None = None,
210215
**variable_crs_kwargs: Any,
211-
):
216+
) -> xr.DataArray | xr.Dataset:
212217
"""
213218
Transform :class:`shapely.Geometry` objects of a variable to a new coordinate
214219
reference system.
@@ -313,20 +318,15 @@ def to_crs(
313318
currently wraps :meth:`Dataset.assign_coords <xarray.Dataset.assign_coords>`
314319
or :meth:`DataArray.assign_coords <xarray.DataArray.assign_coords>`.
315320
"""
316-
if variable_crs and variable_crs_kwargs:
317-
raise ValueError(
318-
"Cannot specify both keyword and positional arguments to "
319-
"'.xvec.to_crs'."
320-
)
321+
variable_crs_solved = _resolve_input(
322+
variable_crs, variable_crs_kwargs, "to_crs"
323+
)
321324

322325
_obj = self._obj.copy(deep=False)
323326

324-
if variable_crs_kwargs:
325-
variable_crs = variable_crs_kwargs
326-
327327
transformed = {}
328328

329-
for key, crs in variable_crs.items():
329+
for key, crs in variable_crs_solved.items():
330330
if not isinstance(self._obj.xindexes[key], GeometryIndex):
331331
raise ValueError(
332332
f"The index '{key}' is not an xvec.GeometryIndex. "
@@ -335,7 +335,7 @@ def to_crs(
335335
)
336336

337337
data = _obj[key]
338-
data_crs = self._obj.xindexes[key].crs
338+
data_crs = self._obj.xindexes[key].crs # type: ignore
339339

340340
# transformation code taken from geopandas (BSD 3-clause license)
341341
if data_crs is None:
@@ -374,21 +374,21 @@ def to_crs(
374374
for key, (result, _crs) in transformed.items():
375375
_obj = _obj.assign_coords({key: result})
376376

377-
_obj = _obj.drop_indexes(variable_crs.keys())
377+
_obj = _obj.drop_indexes(variable_crs_solved.keys())
378378

379-
for key, crs in variable_crs.items():
379+
for key, crs in variable_crs_solved.items():
380380
if crs:
381381
_obj[key].attrs["crs"] = CRS.from_user_input(crs)
382-
_obj = _obj.set_xindex(key, GeometryIndex, crs=crs)
382+
_obj = _obj.set_xindex([key], GeometryIndex, crs=crs)
383383

384384
return _obj
385385

386386
def set_crs(
387387
self,
388388
variable_crs: Mapping[Any, Any] | None = None,
389-
allow_override=False,
389+
allow_override: bool = False,
390390
**variable_crs_kwargs: Any,
391-
):
391+
) -> xr.DataArray | xr.Dataset:
392392
"""Set the Coordinate Reference System (CRS) of coordinates backed by
393393
:class:`~xvec.GeometryIndex`.
394394
@@ -480,27 +480,21 @@ def set_crs(
480480
transform the geometries to a new CRS, use the :meth:`to_crs`
481481
method.
482482
"""
483-
484-
if variable_crs and variable_crs_kwargs:
485-
raise ValueError(
486-
"Cannot specify both keyword and positional arguments to "
487-
".xvec.set_crs."
488-
)
483+
variable_crs_solved = _resolve_input(
484+
variable_crs, variable_crs_kwargs, "set_crs"
485+
)
489486

490487
_obj = self._obj.copy(deep=False)
491488

492-
if variable_crs_kwargs:
493-
variable_crs = variable_crs_kwargs
494-
495-
for key, crs in variable_crs.items():
489+
for key, crs in variable_crs_solved.items():
496490
if not isinstance(self._obj.xindexes[key], GeometryIndex):
497491
raise ValueError(
498492
f"The index '{key}' is not an xvec.GeometryIndex. "
499493
"Set the xvec.GeometryIndex using '.xvec.set_geom_indexes' before "
500494
"handling projection information."
501495
)
502496

503-
data_crs = self._obj.xindexes[key].crs
497+
data_crs = self._obj.xindexes[key].crs # type: ignore
504498

505499
if not allow_override and data_crs is not None and not data_crs == crs:
506500
raise ValueError(
@@ -510,23 +504,23 @@ def set_crs(
510504
"want to transform the geometries, use '.xvec.to_crs' instead."
511505
)
512506

513-
_obj = _obj.drop_indexes(variable_crs.keys())
507+
_obj = _obj.drop_indexes(variable_crs_solved.keys())
514508

515-
for key, crs in variable_crs.items():
509+
for key, crs in variable_crs_solved.items():
516510
if crs:
517511
_obj[key].attrs["crs"] = CRS.from_user_input(crs)
518-
_obj = _obj.set_xindex(key, GeometryIndex, crs=crs)
512+
_obj = _obj.set_xindex([key], GeometryIndex, crs=crs)
519513

520514
return _obj
521515

522516
def query(
523517
self,
524518
coord_name: str,
525519
geometry: shapely.Geometry | Sequence[shapely.Geometry],
526-
predicate: str = None,
527-
distance: float | Sequence[float] = None,
528-
unique=False,
529-
):
520+
predicate: str | None = None,
521+
distance: float | Sequence[float] | None = None,
522+
unique: bool = False,
523+
) -> xr.DataArray | xr.Dataset:
530524
"""Return a subset of a DataArray/Dataset filtered using a spatial query on
531525
:class:`~xvec.GeometryIndex`.
532526
@@ -619,12 +613,12 @@ def query(
619613
620614
"""
621615
if isinstance(geometry, shapely.Geometry):
622-
ilocs = self._obj.xindexes[coord_name].sindex.query(
616+
ilocs = self._obj.xindexes[coord_name].sindex.query( # type: ignore
623617
geometry, predicate=predicate, distance=distance
624618
)
625619

626620
else:
627-
_, ilocs = self._obj.xindexes[coord_name].sindex.query(
621+
_, ilocs = self._obj.xindexes[coord_name].sindex.query( # type: ignore
628622
geometry, predicate=predicate, distance=distance
629623
)
630624
if unique:
@@ -634,11 +628,11 @@ def query(
634628

635629
def set_geom_indexes(
636630
self,
637-
coord_names: str | Sequence[Hashable],
631+
coord_names: str | Sequence[str],
638632
crs: Any = None,
639633
allow_override: bool = False,
640-
**kwargs,
641-
):
634+
**kwargs: dict[str, Any],
635+
) -> xr.DataArray | xr.Dataset:
642636
"""Set a new :class:`~xvec.GeometryIndex` for one or more existing
643637
coordinate(s). One :class:`~xvec.GeometryIndex` is set per coordinate. Only
644638
1-dimensional coordinates are supported.
@@ -691,7 +685,7 @@ def set_geom_indexes(
691685

692686
for coord in coord_names:
693687
if isinstance(self._obj.xindexes[coord], GeometryIndex):
694-
data_crs = self._obj.xindexes[coord].crs
688+
data_crs = self._obj.xindexes[coord].crs # type: ignore
695689

696690
if not allow_override and data_crs is not None and not data_crs == crs:
697691
raise ValueError(
@@ -710,7 +704,7 @@ def set_geom_indexes(
710704

711705
return _obj
712706

713-
def to_geopandas(self):
707+
def to_geopandas(self) -> GeoDataFrame | pd.DataFrame:
714708
"""Convert this array into a GeoPandas :class:`~geopandas.GeoDataFrame`
715709
716710
Returns a :class:`~geopandas.GeoDataFrame` with coordinates based on a
@@ -762,11 +756,11 @@ def to_geopandas(self):
762756
if len(self._geom_indexes):
763757
if self._obj.ndim == 1:
764758
gdf = self._obj.to_pandas()
765-
elif self._obj.ndim == 2:
759+
else:
766760
gdf = self._obj.to_pandas()
767761
if gdf.columns.name == self._geom_indexes[0]:
768762
gdf = gdf.T
769-
return gdf.reset_index().set_geometry(
763+
return gdf.reset_index().set_geometry( # type: ignore
770764
self._geom_indexes[0],
771765
crs=self._obj.xindexes[self._geom_indexes[0]].crs,
772766
)
@@ -790,7 +784,7 @@ def to_geopandas(self):
790784
if index_name in self._geom_coords_all:
791785
return gdf.reset_index().set_geometry(
792786
index_name, crs=self._obj[index_name].attrs.get("crs", None)
793-
)
787+
) # type: ignore
794788

795789
warnings.warn(
796790
"No active geometry column to be set. The resulting object "
@@ -810,7 +804,7 @@ def to_geodataframe(
810804
dim_order: Sequence[Hashable] | None = None,
811805
geometry: Hashable | None = None,
812806
long: bool = True,
813-
):
807+
) -> GeoDataFrame | pd.DataFrame:
814808
"""Convert this array and its coordinates into a tidy geopandas.GeoDataFrame.
815809
816810
The GeoDataFrame is indexed by the Cartesian product of index coordinates
@@ -884,7 +878,7 @@ def to_geodataframe(
884878
level
885879
for level in df.index.names
886880
if level not in self._geom_coords_all
887-
]
881+
] # type: ignore
888882
)
889883

890884
if isinstance(df.index, pd.MultiIndex):
@@ -907,7 +901,7 @@ def to_geodataframe(
907901
if geometry is not None:
908902
return df.set_geometry(
909903
geometry, crs=self._obj[geometry].attrs.get("crs", None)
910-
)
904+
) # type: ignore
911905

912906
warnings.warn(
913907
"No active geometry column to be set. The resulting object "
@@ -926,12 +920,12 @@ def zonal_stats(
926920
y_coords: Hashable,
927921
stats: str | Callable | Sequence[str | Callable | tuple] = "mean",
928922
name: Hashable = "geometry",
929-
index: bool = None,
923+
index: bool | None = None,
930924
method: str = "rasterize",
931925
all_touched: bool = False,
932926
n_jobs: int = -1,
933-
**kwargs,
934-
):
927+
**kwargs: dict[str, Any],
928+
) -> xr.DataArray | xr.Dataset:
935929
"""Extract the values from a dataset indexed by a set of geometries
936930
937931
Given an object indexed by x and y coordinates (or latitude and longitude), such
@@ -1121,9 +1115,9 @@ def extract_points(
11211115
y_coords: Hashable,
11221116
tolerance: float | None = None,
11231117
name: str = "geometry",
1124-
crs: Any = None,
1125-
index: bool = None,
1126-
):
1118+
crs: Any | None = None,
1119+
index: bool | None = None,
1120+
) -> xr.DataArray | xr.Dataset:
11271121
"""Extract points from a DataArray or a Dataset indexed by spatial coordinates
11281122
11291123
Given an object indexed by x and y coordinates (or latitude and longitude), such
@@ -1263,3 +1257,22 @@ def extract_points(
12631257
}
12641258
)
12651259
return result
1260+
1261+
1262+
def _resolve_input(
1263+
positional: Mapping[Any, Any] | None,
1264+
keyword: Mapping[str, Any],
1265+
func_name: str,
1266+
) -> Mapping[Hashable, Any]:
1267+
"""Resolve combination of positional and keyword arguments.
1268+
1269+
Based on xarray's ``either_dict_or_kwargs``.
1270+
"""
1271+
if positional and keyword:
1272+
raise ValueError(
1273+
"Cannot specify both keyword and positional arguments to "
1274+
f"'.xvec.{func_name}'."
1275+
)
1276+
if positional is None or positional == {}:
1277+
return cast(Mapping[Hashable, Any], keyword)
1278+
return positional

0 commit comments

Comments
 (0)