Skip to content

Commit c387edc

Browse files
authored
Fix factorizing some more. (#115)
1 parent 5646179 commit c387edc

File tree

2 files changed

+69
-11
lines changed

2 files changed

+69
-11
lines changed

flox/core.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ def factorize_(
421421
factorized = []
422422
found_groups = []
423423
for groupvar, expect in zip(by, expected_groups):
424+
flat = groupvar.ravel()
424425
if isinstance(expect, pd.IntervalIndex):
425426
# when binning we change expected groups to integers marking the interval
426427
# this makes the reindexing logic simpler.
@@ -432,21 +433,19 @@ def factorize_(
432433
if groupvar.dtype.kind == "M":
433434
expect = np.concatenate([expect.left.to_numpy(), [expect.right[-1].to_numpy()]])
434435
# code is -1 for values outside the bounds of all intervals
435-
idx = pd.cut(groupvar.ravel(), bins=expect).codes.copy()
436+
idx = pd.cut(flat, bins=expect).codes.copy()
436437
else:
437438
if expect is not None and reindex:
438-
groups = expect
439+
sorter = np.argsort(expect)
440+
groups = expect[(sorter,)] if sort else expect
441+
idx = np.searchsorted(expect, flat, sorter=sorter)
442+
mask = ~np.isin(flat, expect) | isnull(flat) | (idx == len(expect))
439443
if not sort:
440-
sorter = np.argsort(expect)
441-
else:
442-
sorter = None
443-
idx = np.searchsorted(expect, groupvar.ravel(), sorter=sorter)
444-
mask = isnull(groupvar.ravel()) | (idx == len(expect))
445-
# TODO: optimize?
444+
# idx is the index in to the sorted array.
445+
# if we didn't want sorting, unsort it back
446+
idx[(idx == len(expect),)] = -1
447+
idx = sorter[(idx,)]
446448
idx[mask] = -1
447-
if not sort:
448-
idx = sorter[idx]
449-
idx[mask] = -1
450449
else:
451450
idx, groups = pd.factorize(groupvar.ravel(), sort=sort)
452451

tests/test_core.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -905,3 +905,62 @@ def test_factorize_values_outside_bins():
905905
actual = vals[0]
906906
expected = np.array([[-1, -1], [-1, 0], [6, 12], [18, 24], [-1, -1]])
907907
assert_equal(expected, actual)
908+
909+
910+
def test_multiple_groupers():
    """Count with two identical binned groupers and compare against an identity matrix."""
    # Both groupers carry the same labels and are binned with the same breaks.
    first_by, second_by = (np.arange(10).reshape(5, 2) for _ in range(2))
    first_bins, second_bins = (
        pd.IntervalIndex.from_breaks(np.arange(2, 8, 1)) for _ in range(2)
    )

    counts, *_rest = groupby_reduce(
        np.ones((5, 2)),
        first_by,
        second_by,
        axis=(0, 1),
        expected_groups=(first_bins, second_bins),
        reindex=True,
        func="count",
    )

    assert_equal(np.eye(5, 5), counts)
925+
926+
927+
def test_factorize_reindex_sorting_strings():
    """Check the codes from ``factorize_`` for string labels across reindex/sort combinations."""
    common = {
        "by": (np.array(["El-Nino", "La-Nina", "boo", "Neutral"]),),
        "axis": -1,
        "expected_groups": (np.array(["El-Nino", "Neutral", "foo", "La-Nina"]),),
    }

    # Each row: (reindex, sort, codes the first return value must equal).
    cases = (
        (True, True, [0, 1, 4, 2]),
        (True, False, [0, 3, 4, 1]),
        (False, False, [0, 1, 2, 3]),
        (False, True, [0, 1, 3, 2]),
    )

    for reindex, sort, codes in cases:
        actual = factorize_(**common, reindex=reindex, sort=sort)[0]
        assert_equal(actual, codes)
945+
946+
947+
def test_factorize_reindex_sorting_ints():
    """Check the codes from ``factorize_`` for integer labels with ascending and descending expected groups."""
    common = {
        "by": (np.array([-10, 1, 10, 2, 3, 5]),),
        "axis": -1,
    }
    ascending = np.array([0, 1, 2, 3, 4, 5])
    descending = np.arange(5, -1, -1)

    # Each row: (expected groups, sort, codes the first return value must equal).
    cases = (
        (ascending, True, [6, 1, 6, 2, 3, 5]),
        (ascending, False, [6, 1, 6, 2, 3, 5]),
        (descending, True, [6, 1, 6, 2, 3, 5]),
        (descending, False, [6, 4, 6, 3, 2, 0]),
    )

    for groups, sort, codes in cases:
        common["expected_groups"] = (groups,)
        actual = factorize_(**common, reindex=True, sort=sort)[0]
        assert_equal(actual, codes)

0 commit comments

Comments
 (0)