@@ -1066,7 +1066,8 @@ def test_cohorts_map_reduce_consistent_dtypes(method, dtype, labels_dtype):
1066
1066
@pytest .mark .parametrize ("func" , ALL_FUNCS )
1067
1067
@pytest .mark .parametrize ("axis" , (- 1 , None ))
1068
1068
@pytest .mark .parametrize ("method" , ["blockwise" , "cohorts" , "map-reduce" ])
1069
- def test_cohorts_nd_by (func , method , axis , engine ):
1069
+ @pytest .mark .parametrize ("by_is_dask" , [True , False ])
1070
+ def test_cohorts_nd_by (by_is_dask , func , method , axis , engine ):
1070
1071
if (
1071
1072
("arg" in func and (axis is None or engine in ["flox" , "numbagg" ]))
1072
1073
or (method != "blockwise" and func in BLOCKWISE_FUNCS )
@@ -1080,10 +1081,12 @@ def test_cohorts_nd_by(func, method, axis, engine):
1080
1081
o2 = dask .array .ones ((2 , 3 ), chunks = - 1 )
1081
1082
1082
1083
array = dask .array .block ([[o , 2 * o ], [3 * o2 , 4 * o2 ]])
1083
- by = array .compute (). astype (np .int64 )
1084
+ by = array .astype (np .int64 )
1084
1085
by [0 , 1 ] = 30
1085
1086
by [2 , 1 ] = 40
1086
1087
by [0 , 4 ] = 31
1088
+ if not by_is_dask :
1089
+ by = by .compute ()
1087
1090
array = np .broadcast_to (array , (2 , 3 ) + array .shape )
1088
1091
1089
1092
if func in ["any" , "all" ]:
@@ -1099,10 +1102,19 @@ def test_cohorts_nd_by(func, method, axis, engine):
1099
1102
assert_equal (groups , sorted_groups )
1100
1103
assert_equal (actual , expected )
1101
1104
1102
- actual , groups = groupby_reduce (array , by , sort = False , ** kwargs )
1103
- assert_equal (groups , np .array ([1 , 30 , 2 , 31 , 3 , 4 , 40 ], dtype = np .int64 ))
1104
- reindexed = reindex_ (actual , groups , pd .Index (sorted_groups ))
1105
- assert_equal (reindexed , expected )
1105
+ if isinstance (by , dask .array .Array ):
1106
+ cache .clear ()
1107
+ actual_cohorts = find_group_cohorts (by , array .chunks [- by .ndim :])
1108
+ expected_cohorts = find_group_cohorts (by .compute (), array .chunks [- by .ndim :])
1109
+ assert actual_cohorts == expected_cohorts
1110
+ # assert cache.nbytes
1111
+
1112
+ if not isinstance (by , dask .array .Array ):
1113
+ # Always sorting groups with cohorts and dask array
1114
+ actual , groups = groupby_reduce (array , by , sort = False , ** kwargs )
1115
+ assert_equal (groups , np .array ([1 , 30 , 2 , 31 , 3 , 4 , 40 ], dtype = np .int64 ))
1116
+ reindexed = reindex_ (actual , groups , pd .Index (sorted_groups ))
1117
+ assert_equal (reindexed , expected )
1106
1118
1107
1119
1108
1120
@pytest .mark .parametrize ("func" , ["sum" , "count" ])
0 commit comments