Skip to content

Commit ed51c19

Browse files
Some typing updates (#208)
* Some typing updates * Little more typing * Introduce TypedDict for Aggregation.dtype * Cleanup * Upgrade types. * Try with typing_extensions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revert "Try with typing_extensions" This reverts commit 21983a5. * Guard with TYPE_CHECKING Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent c1358e7 commit ed51c19

File tree

5 files changed

+81
-50
lines changed

5 files changed

+81
-50
lines changed

.pre-commit-config.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,9 @@ repos:
4646
hooks:
4747
- id: nbstripout
4848
args: [--extra-keys=metadata.kernelspec metadata.language_info.version]
49+
- repo: https://github.com/asottile/pyupgrade
50+
rev: v3.3.1
51+
hooks:
52+
- id: pyupgrade
53+
args:
54+
- "--py38-plus"

docs/source/conf.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# -*- coding: utf-8 -*-
21
#
32
# complexity documentation build configuration file, created by
43
# sphinx-quickstart on Tue Jul 9 22:26:36 2013.

flox/aggregations.py

Lines changed: 61 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,17 @@
33
import copy
44
import warnings
55
from functools import partial
6+
from typing import TYPE_CHECKING, Any, Callable, TypedDict
67

78
import numpy as np
89
import numpy_groupies as npg
10+
from numpy.typing import DTypeLike
911

1012
from . import aggregate_flox, aggregate_npg, xrdtypes as dtypes, xrutils
1113

14+
if TYPE_CHECKING:
15+
FuncTuple = tuple[Callable | str, ...]
16+
1217

1318
def _is_arg_reduction(func: str | Aggregation) -> bool:
1419
if isinstance(func, str) and func in ["argmin", "argmax", "nanargmax", "nanargmin"]:
@@ -18,6 +23,17 @@ def _is_arg_reduction(func: str | Aggregation) -> bool:
1823
return False
1924

2025

26+
class AggDtypeInit(TypedDict):
27+
final: DTypeLike | None
28+
intermediate: tuple[DTypeLike, ...]
29+
30+
31+
class AggDtype(TypedDict):
32+
final: np.dtype
33+
numpy: tuple[np.dtype | type[np.intp], ...]
34+
intermediate: tuple[np.dtype | type[np.intp], ...]
35+
36+
2137
def generic_aggregate(
2238
group_idx,
2339
array,
@@ -57,7 +73,7 @@ def generic_aggregate(
5773
return result
5874

5975

60-
def _normalize_dtype(dtype, array_dtype, fill_value=None):
76+
def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
6177
if dtype is None:
6278
dtype = array_dtype
6379
if dtype is np.floating:
@@ -103,16 +119,16 @@ def __init__(
103119
self,
104120
name,
105121
*,
106-
numpy=None,
107-
chunk,
108-
combine,
109-
preprocess=None,
110-
aggregate=None,
111-
finalize=None,
122+
numpy: str | FuncTuple | None = None,
123+
chunk: str | FuncTuple | None,
124+
combine: str | FuncTuple | None,
125+
preprocess: Callable | None = None,
126+
aggregate: Callable | None = None,
127+
finalize: Callable | None = None,
112128
fill_value=None,
113129
final_fill_value=dtypes.NA,
114130
dtypes=None,
115-
final_dtype=None,
131+
final_dtype: DTypeLike | None = None,
116132
reduction_type="reduce",
117133
):
118134
"""
@@ -162,15 +178,15 @@ def __init__(
162178
self.preprocess = preprocess
163179
# Use "chunk_reduce" or "chunk_argreduce"
164180
self.reduction_type = reduction_type
165-
self.numpy = (numpy,) if numpy else (self.name,)
181+
self.numpy: FuncTuple = (numpy,) if numpy else (self.name,)
166182
# initialize blockwise reduction
167-
self.chunk = _atleast_1d(chunk)
183+
self.chunk: FuncTuple = _atleast_1d(chunk)
168184
# how to aggregate results after first round of reduction
169-
self.combine = _atleast_1d(combine)
185+
self.combine: FuncTuple = _atleast_1d(combine)
170186
# final aggregation
171-
self.aggregate = aggregate if aggregate else self.combine[0]
187+
self.aggregate: Callable | str = aggregate if aggregate else self.combine[0]
172188
# finalize results (see mean)
173-
self.finalize = finalize if finalize else lambda x: x
189+
self.finalize: Callable | None = finalize
174190

175191
self.fill_value = {}
176192
# This is used for the final reindexing
@@ -180,13 +196,15 @@ def __init__(
180196
# They should make sense when aggregated together with results from other blocks
181197
self.fill_value["intermediate"] = self._normalize_dtype_fill_value(fill_value, "fill_value")
182198

183-
self.dtype = {}
184-
self.dtype[name] = final_dtype
185-
self.dtype["intermediate"] = self._normalize_dtype_fill_value(dtypes, "dtype")
199+
self.dtype_init: AggDtypeInit = {
200+
"final": final_dtype,
201+
"intermediate": self._normalize_dtype_fill_value(dtypes, "dtype"),
202+
}
203+
self.dtype: AggDtype = None # type: ignore
186204

187205
# The following are set by _initialize_aggregation
188-
self.finalize_kwargs = {}
189-
self.min_count = None
206+
self.finalize_kwargs: dict[Any, Any] = {}
207+
self.min_count: int | None = None
190208

191209
def _normalize_dtype_fill_value(self, value, name):
192210
value = _atleast_1d(value)
@@ -211,15 +229,15 @@ def __dask_tokenize__(self):
211229
self.dtype,
212230
)
213231

214-
def __repr__(self):
232+
def __repr__(self) -> str:
215233
return "\n".join(
216234
(
217-
f"{self.name}, fill: {np.unique(self.fill_value.values())}, dtype: {self.dtype}",
218-
f"chunk: {self.chunk}",
219-
f"combine: {self.combine}",
220-
f"aggregate: {self.aggregate}",
221-
f"finalize: {self.finalize}",
222-
f"min_count: {self.min_count}",
235+
f"{self.name!r}, fill: {self.fill_value.values()!r}, dtype: {self.dtype}",
236+
f"chunk: {self.chunk!r}",
237+
f"combine: {self.combine!r}",
238+
f"aggregate: {self.aggregate!r}",
239+
f"finalize: {self.finalize!r}",
240+
f"min_count: {self.min_count!r}",
223241
)
224242
)
225243

@@ -484,7 +502,7 @@ def _initialize_aggregation(
484502
array_dtype,
485503
fill_value,
486504
min_count: int | None,
487-
finalize_kwargs,
505+
finalize_kwargs: dict[Any, Any] | None,
488506
) -> Aggregation:
489507
if not isinstance(func, Aggregation):
490508
try:
@@ -502,24 +520,30 @@ def _initialize_aggregation(
502520

503521
# np.dtype(None) == np.dtype("float64")!!!
504522
# so check for not None
505-
if dtype is not None and not isinstance(dtype, np.dtype):
506-
dtype = np.dtype(dtype)
523+
dtype_: np.dtype | None = (
524+
np.dtype(dtype) if dtype is not None and not isinstance(dtype, np.dtype) else dtype
525+
)
507526

508-
agg.dtype[func] = _normalize_dtype(dtype or agg.dtype[func], array_dtype, fill_value)
509-
agg.dtype["numpy"] = (agg.dtype[func],)
510-
agg.dtype["intermediate"] = [
511-
_normalize_dtype(int_dtype, np.result_type(array_dtype, agg.dtype[func]), int_fv)
512-
if int_dtype is None
513-
else int_dtype
514-
for int_dtype, int_fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"])
515-
]
527+
final_dtype = _normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value)
528+
agg.dtype = {
529+
"final": final_dtype,
530+
"numpy": (final_dtype,),
531+
"intermediate": tuple(
532+
_normalize_dtype(int_dtype, np.result_type(array_dtype, final_dtype), int_fv)
533+
if int_dtype is None
534+
else np.dtype(int_dtype)
535+
for int_dtype, int_fv in zip(
536+
agg.dtype_init["intermediate"], agg.fill_value["intermediate"]
537+
)
538+
),
539+
}
516540

517541
# Replace sentinel fill values according to dtype
518542
agg.fill_value["intermediate"] = tuple(
519543
_get_fill_value(dt, fv)
520544
for dt, fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"])
521545
)
522-
agg.fill_value[func] = _get_fill_value(agg.dtype[func], agg.fill_value[func])
546+
agg.fill_value[func] = _get_fill_value(agg.dtype["final"], agg.fill_value[func])
523547

524548
fv = fill_value if fill_value is not None else agg.fill_value[agg.name]
525549
if _is_arg_reduction(agg):

flox/core.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -807,7 +807,7 @@ def _finalize_results(
807807
else:
808808
finalized["groups"] = squeezed["groups"]
809809

810-
finalized[agg.name] = finalized[agg.name].astype(agg.dtype[agg.name], copy=False)
810+
finalized[agg.name] = finalized[agg.name].astype(agg.dtype["final"], copy=False)
811811
return finalized
812812

813813

@@ -884,6 +884,7 @@ def _simple_combine(
884884
assert array.ndim >= 2
885885
with warnings.catch_warnings():
886886
warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
887+
assert isinstance(combine, str)
887888
result = getattr(np, combine)(array, axis=axis_, keepdims=True)
888889
if is_aggregate:
889890
# squeeze out DUMMY_AXIS if this is the last step i.e. called from _aggregate
@@ -1015,7 +1016,7 @@ def _grouped_combine(
10151016
if array.shape[-1] == 0:
10161017
# all empty when combined
10171018
results["intermediates"].append(
1018-
np.empty(shape=(1,) * (len(axis) - 1) + (0,), dtype=agg.dtype)
1019+
np.empty(shape=(1,) * (len(axis) - 1) + (0,), dtype=dtype)
10191020
)
10201021
results["groups"] = np.empty(
10211022
shape=(1,) * (len(neg_axis) - 1) + (0,), dtype=groups.dtype
@@ -1059,10 +1060,11 @@ def _reduce_blockwise(
10591060
agg.finalize = None
10601061

10611062
assert agg.finalize_kwargs is not None
1062-
finalize_kwargs = agg.finalize_kwargs
1063-
if isinstance(finalize_kwargs, Mapping):
1064-
finalize_kwargs = (finalize_kwargs,)
1065-
finalize_kwargs = finalize_kwargs + ({},) + ({},)
1063+
if isinstance(agg.finalize_kwargs, Mapping):
1064+
finalize_kwargs_: tuple[dict[Any, Any], ...] = (agg.finalize_kwargs,)
1065+
else:
1066+
finalize_kwargs_ = agg.finalize_kwargs
1067+
finalize_kwargs_ += ({},) + ({},)
10661068

10671069
results = chunk_reduce(
10681070
array,
@@ -1075,7 +1077,7 @@ def _reduce_blockwise(
10751077
# (see below)
10761078
fill_value=agg.fill_value["numpy"],
10771079
dtype=agg.dtype["numpy"],
1078-
kwargs=finalize_kwargs,
1080+
kwargs=finalize_kwargs_,
10791081
engine=engine,
10801082
sort=sort,
10811083
reindex=reindex,
@@ -1102,7 +1104,7 @@ def _normalize_indexes(array: DaskArray, flatblocks, blkshape) -> tuple:
11021104
"""
11031105
unraveled = np.unravel_index(flatblocks, blkshape)
11041106

1105-
normalized: list[Union[int, slice, list[int]]] = []
1107+
normalized: list[int | slice | list[int]] = []
11061108
for ax, idx in enumerate(unraveled):
11071109
i = _unique(idx).squeeze()
11081110
if i.ndim == 0:
@@ -1303,7 +1305,7 @@ def dask_groupby_agg(
13031305
name=f"{name}-chunk-{token}",
13041306
)
13051307

1306-
group_chunks: tuple[tuple[Union[int, float], ...]]
1308+
group_chunks: tuple[tuple[int | float, ...]]
13071309

13081310
if method in ["map-reduce", "cohorts"]:
13091311
combine: Callable[..., IntermediateDict]
@@ -1402,7 +1404,7 @@ def dask_groupby_agg(
14021404
reduced,
14031405
inds,
14041406
adjust_chunks=dict(zip(out_inds, output_chunks)),
1405-
dtype=agg.dtype[agg.name],
1407+
dtype=agg.dtype["final"],
14061408
key=agg.name,
14071409
name=f"{name}-{token}",
14081410
concatenate=False,
@@ -1600,7 +1602,7 @@ def groupby_reduce(
16001602
method: T_Method = "map-reduce",
16011603
engine: T_Engine = "numpy",
16021604
reindex: bool | None = None,
1603-
finalize_kwargs: Mapping | None = None,
1605+
finalize_kwargs: dict[Any, Any] | None = None,
16041606
) -> tuple[DaskArray, np.ndarray | DaskArray]:
16051607
"""
16061608
GroupBy reductions using tree reductions for dask.array

flox/xarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def xarray_reduce(
223223
raise NotImplementedError("sort must be True for xarray_reduce")
224224

225225
# eventually drop the variables we are grouping by
226-
maybe_drop = set(b for b in by if isinstance(b, Hashable))
226+
maybe_drop = {b for b in by if isinstance(b, Hashable)}
227227
unindexed_dims = tuple(
228228
b
229229
for b, isbin_ in zip(by, isbins)

0 commit comments

Comments
 (0)