@@ -867,7 +867,8 @@ def test_groupby_bins(chunk_labels, kwargs, chunks, engine, method) -> None:
867867 if chunk_labels :
868868 labels = dask .array .from_array (labels , chunks = chunks )
869869
870- with raise_if_dask_computes ():
870+ max_computes = 1 if method == "cohorts" else 0
871+ with raise_if_dask_computes (max_computes ):
871872 actual , * groups = groupby_reduce (
872873 array , labels , func = "count" , fill_value = 0 , engine = engine , method = method , ** kwargs
873874 )
@@ -1072,7 +1073,9 @@ def test_cohorts_nd_by(by_is_dask, func, method, axis, engine):
10721073 ):
10731074 pytest .skip ()
10741075 if axis is not None and method != "map-reduce" :
1075- pytest .xfail ()
1076+ pytest .skip ()
1077+ if by_is_dask and method == "blockwise" :
1078+ pytest .skip ()
10761079
10771080 o = dask .array .ones ((3 ,), chunks = - 1 )
10781081 o2 = dask .array .ones ((2 , 3 ), chunks = - 1 )
@@ -1092,6 +1095,9 @@ def test_cohorts_nd_by(by_is_dask, func, method, axis, engine):
10921095 fill_value = - 123
10931096
10941097 kwargs = dict (func = func , engine = engine , method = method , axis = axis , fill_value = fill_value )
1098+ if by_is_dask and axis is not None and method == "map-reduce" :
1099+ kwargs ["expected_groups" ] = pd .Index ([1 , 2 , 3 , 4 , 30 , 31 , 40 ])
1100+
10951101 if "quantile" in func :
10961102 kwargs ["finalize_kwargs" ] = {"q" : DEFAULT_QUANTILE }
10971103 actual , groups = groupby_reduce (array , by , ** kwargs )
@@ -1102,6 +1108,7 @@ def test_cohorts_nd_by(by_is_dask, func, method, axis, engine):
11021108 if isinstance (by , dask .array .Array ):
11031109 cache .clear ()
11041110 actual_cohorts = find_group_cohorts (by , array .chunks [- by .ndim :])
1111+ cache .clear ()
11051112 expected_cohorts = find_group_cohorts (by .compute (), array .chunks [- by .ndim :])
11061113 assert actual_cohorts == expected_cohorts
11071114 # assert cache.nbytes
0 commit comments