Skip to content

Commit b42c211

Browse files
committed
Add tests
1 parent 5cb4bcf commit b42c211

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

tests/test_core.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,7 +1066,8 @@ def test_cohorts_map_reduce_consistent_dtypes(method, dtype, labels_dtype):
10661066
@pytest.mark.parametrize("func", ALL_FUNCS)
10671067
@pytest.mark.parametrize("axis", (-1, None))
10681068
@pytest.mark.parametrize("method", ["blockwise", "cohorts", "map-reduce"])
1069-
def test_cohorts_nd_by(func, method, axis, engine):
1069+
@pytest.mark.parametrize("by_is_dask", [True, False])
1070+
def test_cohorts_nd_by(by_is_dask, func, method, axis, engine):
10701071
if (
10711072
("arg" in func and (axis is None or engine in ["flox", "numbagg"]))
10721073
or (method != "blockwise" and func in BLOCKWISE_FUNCS)
@@ -1080,10 +1081,12 @@ def test_cohorts_nd_by(func, method, axis, engine):
10801081
o2 = dask.array.ones((2, 3), chunks=-1)
10811082

10821083
array = dask.array.block([[o, 2 * o], [3 * o2, 4 * o2]])
1083-
by = array.compute().astype(np.int64)
1084+
by = array.astype(np.int64)
10841085
by[0, 1] = 30
10851086
by[2, 1] = 40
10861087
by[0, 4] = 31
1088+
if not by_is_dask:
1089+
by = by.compute()
10871090
array = np.broadcast_to(array, (2, 3) + array.shape)
10881091

10891092
if func in ["any", "all"]:
@@ -1099,10 +1102,19 @@ def test_cohorts_nd_by(func, method, axis, engine):
10991102
assert_equal(groups, sorted_groups)
11001103
assert_equal(actual, expected)
11011104

1102-
actual, groups = groupby_reduce(array, by, sort=False, **kwargs)
1103-
assert_equal(groups, np.array([1, 30, 2, 31, 3, 4, 40], dtype=np.int64))
1104-
reindexed = reindex_(actual, groups, pd.Index(sorted_groups))
1105-
assert_equal(reindexed, expected)
1105+
if isinstance(by, dask.array.Array):
1106+
cache.clear()
1107+
actual_cohorts = find_group_cohorts(by, array.chunks[-by.ndim :])
1108+
expected_cohorts = find_group_cohorts(by.compute(), array.chunks[-by.ndim :])
1109+
assert actual_cohorts == expected_cohorts
1110+
# assert cache.nbytes
1111+
1112+
if not isinstance(by, dask.array.Array):
1113+
# Groups are always sorted with cohorts and a dask array `by`, so only check the sort=False ordering for a non-dask `by`
1114+
actual, groups = groupby_reduce(array, by, sort=False, **kwargs)
1115+
assert_equal(groups, np.array([1, 30, 2, 31, 3, 4, 40], dtype=np.int64))
1116+
reindexed = reindex_(actual, groups, pd.Index(sorted_groups))
1117+
assert_equal(reindexed, expected)
11061118

11071119

11081120
@pytest.mark.parametrize("func", ["sum", "count"])

0 commit comments

Comments
 (0)