Skip to content

Commit e657795

Browse files
dcherian and claude committed
Move _postprocess_numbagg to aggregate_numbagg.py
- Relocates the `_postprocess_numbagg` function from core.py to aggregate_numbagg.py
- Updates the import in core.py to use the new location
- Groups numbagg-specific functionality together for better organization

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 13b50f8 commit e657795

File tree

2 files changed

+19
-18
lines changed

2 files changed

+19
-18
lines changed

flox/aggregate_numbagg.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,23 @@ def nanlen(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None)
136136
any = partial(_numbagg_wrapper, func="nanany")
137137
all = partial(_numbagg_wrapper, func="nanall")
138138

139+
140+
def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups):
141+
"""Account for numbagg not providing a fill_value kwarg."""
142+
if not isinstance(func, str) or func not in DEFAULT_FILL_VALUE:
143+
return result
144+
# The condition needs to be
145+
# len(found_groups) < size; if so we mask with fill_value (?)
146+
default_fv = DEFAULT_FILL_VALUE[func]
147+
needs_masking = fill_value is not None and not np.array_equal(fill_value, default_fv, equal_nan=True)
148+
groups = np.arange(size)
149+
if needs_masking:
150+
mask = np.isin(groups, seen_groups, assume_unique=True, invert=True)
151+
if mask.any():
152+
result[..., groups[mask]] = fill_value
153+
return result
154+
155+
139156
# sum = nansum
140157
# mean = nanmean
141158
# sum_of_squares = nansum_of_squares

flox/core.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -143,24 +143,6 @@ def get_dask_meta(self, other, *, fill_value, dtype) -> Any:
143143
return sparse.COO.from_numpy(np.ones(shape=(0,) * other.ndim, dtype=dtype), fill_value=fill_value)
144144

145145

146-
def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups):
147-
"""Account for numbagg not providing a fill_value kwarg."""
148-
from .aggregate_numbagg import DEFAULT_FILL_VALUE
149-
150-
if not isinstance(func, str) or func not in DEFAULT_FILL_VALUE:
151-
return result
152-
# The condition needs to be
153-
# len(found_groups) < size; if so we mask with fill_value (?)
154-
default_fv = DEFAULT_FILL_VALUE[func]
155-
needs_masking = fill_value is not None and not np.array_equal(fill_value, default_fv, equal_nan=True)
156-
groups = np.arange(size)
157-
if needs_masking:
158-
mask = np.isin(groups, seen_groups, assume_unique=True, invert=True)
159-
if mask.any():
160-
result[..., groups[mask]] = fill_value
161-
return result
162-
163-
164146
def identity(x: T) -> T:
165147
return x
166148

@@ -901,6 +883,8 @@ def chunk_reduce(
901883
group_idx, array, axis=-1, engine=engine, func=reduction, **kw_func
902884
).astype(dt, copy=False)
903885
if engine == "numbagg":
886+
from .aggregate_numbagg import _postprocess_numbagg
887+
904888
result = _postprocess_numbagg(
905889
result,
906890
func=reduction,

0 commit comments

Comments
 (0)