Skip to content

Commit 7deb84a

Browse files
committed
cleanup
1 parent 177b8de commit 7deb84a

File tree

5 files changed

+101
-120
lines changed

5 files changed

+101
-120
lines changed

flox/aggregate_flox.py

Lines changed: 0 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,110 +1,10 @@
11
from functools import partial
2-
from typing import Self
32

43
import numpy as np
54

65
from . import xrdtypes as dtypes
76
from .xrutils import is_scalar, isnull, notnull
87

9-
MULTIARRAY_HANDLED_FUNCTIONS = {}
10-
11-
12-
class MultiArray:
13-
arrays: tuple[np.ndarray, ...]
14-
15-
def __init__(self, arrays):
16-
self.arrays = arrays # something else needed here to be more careful about types (not sure what)
17-
# Do we want to co-erce arrays into a tuple and make sure it's immutable? Do we want it to be immutable?
18-
assert all(arrays[0].shape == a.shape for a in arrays), "Expect all arrays to have the same shape"
19-
20-
def astype(self, dt, **kwargs):
21-
return MultiArray(tuple(array.astype(dt, **kwargs) for array in self.arrays))
22-
23-
def reshape(self, shape, **kwargs):
24-
return MultiArray(tuple(array.reshape(shape, **kwargs) for array in self.arrays))
25-
26-
def squeeze(self, axis=None):
27-
return MultiArray(tuple(array.squeeze(axis) for array in self.arrays))
28-
29-
def __setitem__(self, key, value):
30-
assert len(value) == len(self.arrays)
31-
for array, val in zip(self.arrays, value):
32-
array[key] = val
33-
34-
def __array_function__(self, func, types, args, kwargs):
35-
if func not in MULTIARRAY_HANDLED_FUNCTIONS:
36-
return NotImplemented
37-
# Note: this allows subclasses that don't override
38-
# __array_function__ to handle MyArray objects
39-
# if not all(issubclass(t, MyArray) for t in types): # I can't see this being relevant at all for this code, but maybe it's safer to leave it in?
40-
# return NotImplemented
41-
return MULTIARRAY_HANDLED_FUNCTIONS[func](*args, **kwargs)
42-
43-
# Shape is needed, seems likely that the other two might be
44-
# Making some strong assumptions here that all the arrays are the same shape, and I don't really like this
45-
@property
46-
def dtype(self) -> np.dtype:
47-
return self.arrays[0].dtype
48-
49-
@property
50-
def shape(self) -> tuple[int, ...]:
51-
return self.arrays[0].shape
52-
53-
@property
54-
def ndim(self) -> int:
55-
return self.arrays[0].ndim
56-
57-
def __getitem__(self, key) -> Self:
58-
return type(self)([array[key] for array in self.arrays])
59-
60-
61-
def implements(numpy_function):
62-
"""Register an __array_function__ implementation for MyArray objects."""
63-
64-
def decorator(func):
65-
MULTIARRAY_HANDLED_FUNCTIONS[numpy_function] = func
66-
return func
67-
68-
return decorator
69-
70-
71-
@implements(np.expand_dims)
72-
def expand_dims_MultiArray(multiarray, axis):
73-
return MultiArray(tuple(np.expand_dims(a, axis) for a in multiarray.arrays))
74-
75-
76-
@implements(np.concatenate)
77-
def concatenate_MultiArray(multiarrays, axis):
78-
n_arrays = len(multiarrays[0].arrays)
79-
for ma in multiarrays[1:]:
80-
assert len(ma.arrays) == n_arrays
81-
return MultiArray(
82-
tuple(np.concatenate(tuple(ma.arrays[i] for ma in multiarrays), axis) for i in range(n_arrays))
83-
)
84-
85-
86-
@implements(np.transpose)
87-
def transpose_MultiArray(multiarray, axes):
88-
return MultiArray(tuple(np.transpose(a, axes) for a in multiarray.arrays))
89-
90-
91-
@implements(np.full)
92-
def full_MultiArray(
93-
shape, fill_values, *args, **kwargs
94-
): # I've used *args, **kwargs instead of the full argument list to give us more flexibility if numpy changes stuff https://numpy.org/doc/stable/reference/generated/numpy.full.html
95-
"""All arguments except fill_value are shared by each array
96-
in the MultiArray.
97-
Iterate over fill_values to create arrays
98-
"""
99-
return MultiArray(
100-
tuple(
101-
np.full(
102-
shape, fv, *args, **kwargs
103-
) # I'm 90% sure I've used *args, **kwargs correctly here -- could you double-check?
104-
for fv in fill_values
105-
)
106-
)
107-
1088

1099
def _prepare_for_flox(group_idx, array):
11010
"""

flox/aggregations.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from . import aggregate_flox, aggregate_npg, xrutils
1616
from . import xrdtypes as dtypes
1717
from .lib import dask_array_type, sparse_array_type
18+
from .multiarray import MultiArray
1819

1920
if TYPE_CHECKING:
2021
FuncTuple = tuple[Callable | str, ...]
@@ -346,8 +347,6 @@ def _mean_finalize(sum_, count):
346347
def var_chunk(
347348
group_idx, array, *, skipna: bool, engine: str, axis=-1, size=None, fill_value=None, dtype=None
348349
):
349-
from .aggregate_flox import MultiArray
350-
351350
# Calculate length and sum - important for the adjustment terms to sum squared deviations
352351
array_lens = generic_aggregate(
353352
group_idx,
@@ -432,22 +431,14 @@ def clip_first(array, n=1):
432431
"Instances where we add something to the denominator must come out to zero"
433432
)
434433

435-
return aggregate_flox.MultiArray(
434+
return MultiArray(
436435
(
437436
np.sum(sum_deviations, axis=axis, keepdims=keepdims)
438437
+ np.sum(adj_terms, axis=axis, keepdims=keepdims), # sum of squared deviations
439438
np.sum(sum_X, axis=axis, keepdims=keepdims), # sum of array items
440439
np.sum(sum_len, axis=axis, keepdims=keepdims), # sum of array lengths
441440
)
442-
) # I'm not even pretending calling this class from there is a good idea, I think it wants to be somewhere else though
443-
444-
445-
# TODO: fix this for complex numbers
446-
# def _var_finalize(sumsq, sum_, count, ddof=0):
447-
# with np.errstate(invalid="ignore", divide="ignore"):
448-
# result = (sumsq - (sum_**2 / count)) / (count - ddof)
449-
# result[count <= ddof] = np.nan
450-
# return result
441+
)
451442

452443

453444
def is_var_chunk_reduction(agg: Callable) -> bool:

flox/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2506,7 +2506,7 @@ def _choose_engine(by, agg: Aggregation):
25062506

25072507
not_arg_reduce = not _is_arg_reduction(agg)
25082508

2509-
if agg.name in ["quantile", "nanquantile", "median", "nanmedian"]:
2509+
if agg.name in ["quantile", "nanquantile", "median", "nanmedian", "var", "nanvar", "std", "nanstd"]:
25102510
logger.debug(f"_choose_engine: Choosing 'flox' since {agg.name}")
25112511
return "flox"
25122512

flox/multiarray.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from typing import Self
2+
3+
import numpy as np
4+
5+
MULTIARRAY_HANDLED_FUNCTIONS = {}
6+
7+
8+
class MultiArray:
9+
arrays: tuple[np.ndarray, ...]
10+
11+
def __init__(self, arrays):
12+
self.arrays = arrays
13+
assert all(arrays[0].shape == a.shape for a in arrays), "Expect all arrays to have the same shape"
14+
15+
def astype(self, dt, **kwargs) -> Self:
16+
return type(self)(tuple(array.astype(dt, **kwargs) for array in self.arrays))
17+
18+
def reshape(self, shape, **kwargs) -> Self:
19+
return type(self)(tuple(array.reshape(shape, **kwargs) for array in self.arrays))
20+
21+
def squeeze(self, axis=None) -> Self:
22+
return type(self)(tuple(array.squeeze(axis) for array in self.arrays))
23+
24+
def __setitem__(self, key, value) -> None:
25+
assert len(value) == len(self.arrays)
26+
for array, val in zip(self.arrays, value):
27+
array[key] = val
28+
29+
def __array_function__(self, func, types, args, kwargs):
30+
if func not in MULTIARRAY_HANDLED_FUNCTIONS:
31+
return NotImplemented
32+
# Note: this allows subclasses that don't override
33+
# __array_function__ to handle MyArray objects
34+
# if not all(issubclass(t, MyArray) for t in types): # I can't see this being relevant at all for this code, but maybe it's safer to leave it in?
35+
# return NotImplemented
36+
return MULTIARRAY_HANDLED_FUNCTIONS[func](*args, **kwargs)
37+
38+
# Shape is needed, seems likely that the other two might be
39+
# Making some strong assumptions here that all the arrays are the same shape, and I don't really like this
40+
@property
41+
def dtype(self) -> np.dtype:
42+
return self.arrays[0].dtype
43+
44+
@property
45+
def shape(self) -> tuple[int, ...]:
46+
return self.arrays[0].shape
47+
48+
@property
49+
def ndim(self) -> int:
50+
return self.arrays[0].ndim
51+
52+
def __getitem__(self, key) -> Self:
53+
return type(self)([array[key] for array in self.arrays])
54+
55+
56+
def implements(numpy_function):
57+
"""Register an __array_function__ implementation for MyArray objects."""
58+
59+
def decorator(func):
60+
MULTIARRAY_HANDLED_FUNCTIONS[numpy_function] = func
61+
return func
62+
63+
return decorator
64+
65+
66+
@implements(np.expand_dims)
67+
def expand_dims(multiarray, axis) -> MultiArray:
68+
return MultiArray(tuple(np.expand_dims(a, axis) for a in multiarray.arrays))
69+
70+
71+
@implements(np.concatenate)
72+
def concatenate(multiarrays, axis) -> MultiArray:
73+
n_arrays = len(multiarrays[0].arrays)
74+
for ma in multiarrays[1:]:
75+
assert len(ma.arrays) == n_arrays
76+
return MultiArray(
77+
tuple(np.concatenate(tuple(ma.arrays[i] for ma in multiarrays), axis) for i in range(n_arrays))
78+
)
79+
80+
81+
@implements(np.transpose)
82+
def transpose(multiarray, axes) -> MultiArray:
83+
return MultiArray(tuple(np.transpose(a, axes) for a in multiarray.arrays))
84+
85+
86+
@implements(np.full)
87+
def full(shape, fill_values, *args, **kwargs) -> MultiArray:
88+
"""All arguments except fill_value are shared by each array in the MultiArray.
89+
Iterate over fill_values to create arrays
90+
"""
91+
return MultiArray(tuple(np.full(shape, fv, *args, **kwargs) for fv in fill_values))

tests/test_core.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def gen_array_by(size, func):
236236
@pytest.mark.parametrize("size", [(1, 12), (12,), (12, 9)])
237237
@pytest.mark.parametrize("nby", [1, 2, 3])
238238
@pytest.mark.parametrize("add_nan_by", [True, False])
239-
@pytest.mark.parametrize("func", ["var", "nanvar", "std", "nanstd"])
239+
@pytest.mark.parametrize("func", ALL_FUNCS)
240240
def test_groupby_reduce_all(to_sparse, nby, size, chunks, func, add_nan_by, engine):
241241
if ("arg" in func and engine in ["flox", "numbagg"]) or (func in BLOCKWISE_FUNCS and chunks != -1):
242242
pytest.skip()
@@ -2242,13 +2242,12 @@ def test_sparse_nan_fill_value_reductions(chunks, fill_value, shape, func):
22422242
assert_equal(actual, expected)
22432243

22442244

2245+
@pytest.mark.parametrize("func", ("nanvar", "var"))
22452246
@pytest.mark.parametrize(
2246-
"func", ("nanvar", "var")
2247-
) # Expect to expand this to other functions once written. "nanvar" has updated chunk, combine functions. "var", for the moment, still uses the old algorithm
2248-
@pytest.mark.parametrize("engine", ("flox",)) # Expect to expand this to other engines once written
2249-
@pytest.mark.parametrize(
2250-
"exponent", (2, 4, 6, 8, 10, 12)
2251-
) # Should fail at 10e8 for old algorithm, and survive 10e12 for current
2247+
# Should fail at 10e8 for old algorithm, and survive 10e12 for current
2248+
"exponent",
2249+
(2, 4, 6, 8, 10, 12),
2250+
)
22522251
def test_std_var_precision(func, exponent, engine):
22532252
# Generate a dataset with small variance and big mean
22542253
# Check that func with engine gives you the same answer as numpy

0 commit comments

Comments
 (0)