Skip to content

Commit d9db680

Browse files
authored
Use np.digitize instead of pd.cut (#217)
* Use np.digitize instead of pd.cut
* Update core.py
* Update core.py
* fix datetime issues
* handle dtype issue
* Update core.py
* handle nans by reversing the check
* test without to_numeric, dtype checks should fail now
* ignore dtype
* Add benchmark
1 parent 2c98760 commit d9db680

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

flox/core.py

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -468,8 +468,26 @@ def factorize_(
468468
bins = np.concatenate([expect.left.to_numpy(), [expect.right[-1].to_numpy()]])
469469
else:
470470
bins = np.concatenate([expect.left.to_numpy(), [expect.right[-1]]])
471+
471472
# code is -1 for values outside the bounds of all intervals
472-
idx = pd.cut(flat, bins=bins, right=expect.closed_right).codes.copy()
473+
# idx = pd.cut(flat, bins=bins, right=expect.closed_right).codes.copy()
474+
475+
# digitize is 0 or idx.max() for values outside the bounds of all intervals
476+
# make it behave like pd.cut:
477+
if len(bins) > 1:
478+
right = expect.closed_right
479+
idx = np.digitize(
480+
flat,
481+
bins=bins.view(np.intp) if bins.dtype.kind == "M" else bins,
482+
right=right,
483+
)
484+
# idx = pd.to_numeric(idx, downcast="integer")
485+
idx -= 1
486+
within_bins = flat <= bins.max() if right else flat < bins.max()
487+
idx[~within_bins] = -1
488+
else:
489+
idx = np.zeros_like(flat, dtype=np.intp) - 1
490+
473491
found_groups.append(expect)
474492
else:
475493
if expect is not None and reindex:

tests/test_core.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -884,7 +884,8 @@ def test_datetime_binning():
884884

885885
ret = factorize_((by.to_numpy(),), axis=0, expected_groups=(actual,))
886886
group_idx = ret[0]
887-
expected = pd.cut(by, time_bins).codes.copy()
887+
# Ignore pd.cut's dtype as it won't match np.digitize:
888+
expected = pd.cut(by, time_bins).codes.copy().astype(group_idx.dtype)
888889
expected[0] = 14 # factorize doesn't return -1 for nans
889890
assert_equal(group_idx, expected)
890891

0 commit comments

Comments
 (0)