Skip to content

Commit 23f1e49

Browse files
committed
sparse fixes
1 parent 75e37cd commit 23f1e49

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

flox/aggregate_sparse.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# Unlike the other aggregate_* submodules, this one simply defines a wrapper function
22
# because we run the groupby on the underlying dense data.
33

4+
from collections.abc import Callable
45
from functools import partial
6+
from typing import Any, TypeAlias
57

68
import numpy as np
79
import sparse
@@ -18,19 +20,21 @@ def nanadd(a, b):
1820
1921
From https://stackoverflow.com/a/50642947/1707127
2022
"""
21-
return np.where(np.isnan(a + b), np.where(np.isnan(a), b, a), a + b)
23+
ab = a + b
24+
return np.where(np.isnan(ab), np.where(np.isnan(a), b, a), ab)
2225

2326

24-
BINARY_OPS = {
27+
CallableMap: TypeAlias = dict[str, Callable[[np.ndarray, np.ndarray], np.ndarray]]
28+
BINARY_OPS: CallableMap = {
2529
"sum": np.add,
2630
"nansum": nanadd,
2731
"max": np.maximum,
2832
"nanmax": np.fmax,
2933
"min": np.minimum,
3034
"nanmin": np.fmin,
3135
}
32-
HYPER_OPS = {"sum": np.multiply, "nansum": np.multiply}
33-
IDENTITY = {
36+
HYPER_OPS: CallableMap = {"sum": np.multiply, "nansum": np.multiply}
37+
IDENTITY: dict[str, Any] = {
3438
"sum": 0,
3539
"nansum": 0,
3640
"prod": 1,

0 commit comments

Comments
 (0)