Skip to content

Commit e8e7606

Browse files
authored
BUG: generate proper code for custom aggregation func (#428)
1 parent e6b765e commit e8e7606

File tree

2 files changed

+87
-8
lines changed

2 files changed

+87
-8
lines changed

python/xorbits/_mars/dataframe/groupby/tests/test_groupby_execution.py

Lines changed: 69 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1674,8 +1674,7 @@ def test_gpu_groupby_size(data_type, chunked, as_index, sort, setup_gpu):
16741674
pd.testing.assert_series_equal(expected, actual)
16751675

16761676

1677-
# TODO: support cuda
1678-
# @support_cuda
1677+
@support_cuda
16791678
@pytest.mark.parametrize(
16801679
"as_index",
16811680
[True, False],
@@ -1706,16 +1705,78 @@ def g3(x):
17061705
df.groupby("a", as_index=False).agg((g1, g2, g3)),
17071706
mdf.groupby("a", as_index=False).agg((g1, g2, g3)).execute().fetch(),
17081707
)
1709-
pd.testing.assert_frame_equal(
1710-
df.groupby("a", as_index=as_index).agg((g1, g1)),
1711-
mdf.groupby("a", as_index=as_index).agg((g1, g1)).execute().fetch(),
1712-
)
1708+
if not gpu:
1709+
# cuDF doesn't yet support multiple columns with the same name.
1710+
pd.testing.assert_frame_equal(
1711+
df.groupby("a", as_index=as_index).agg((g1, g1)),
1712+
mdf.groupby("a", as_index=as_index).agg((g1, g1)).execute().fetch(),
1713+
)
17131714

17141715
pd.testing.assert_frame_equal(
17151716
df.groupby("a", as_index=as_index)["b"].agg((g1, g2, g3)),
17161717
mdf.groupby("a", as_index=as_index)["b"].agg((g1, g2, g3)).execute().fetch(),
17171718
)
1719+
if not gpu:
1720+
# cuDF doesn't yet support multiple columns with the same name.
1721+
pd.testing.assert_frame_equal(
1722+
df.groupby("a", as_index=as_index)["b"].agg((g1, g1)),
1723+
mdf.groupby("a", as_index=as_index)["b"].agg((g1, g1)).execute().fetch(),
1724+
)
1725+
1726+
1727+
@support_cuda
1728+
def test_groupby_agg_on_custom_funcs(setup_gpu, gpu):
1729+
rs = np.random.RandomState(0)
1730+
df = pd.DataFrame(
1731+
{
1732+
"a": rs.choice(["foo", "bar", "baz"], size=100),
1733+
"b": rs.choice(["foo", "bar", "baz"], size=100),
1734+
"c": rs.choice(["foo", "bar", "baz"], size=100),
1735+
},
1736+
)
1737+
1738+
mdf = md.DataFrame(df, chunk_size=34, gpu=gpu)
1739+
1740+
def g1(x):
1741+
return ("foo" == x).sum()
1742+
1743+
def g2(x):
1744+
return ("foo" != x).sum()
1745+
1746+
def g3(x):
1747+
return (x > "bar").sum()
1748+
1749+
def g4(x):
1750+
return (x >= "bar").sum()
1751+
1752+
def g5(x):
1753+
return (x < "baz").sum()
1754+
1755+
def g6(x):
1756+
return (x <= "baz").sum()
1757+
17181758
pd.testing.assert_frame_equal(
1719-
df.groupby("a", as_index=as_index)["b"].agg((g1, g1)),
1720-
mdf.groupby("a", as_index=as_index)["b"].agg((g1, g1)).execute().fetch(),
1759+
df.groupby("a", as_index=False).agg(
1760+
(
1761+
g1,
1762+
g2,
1763+
g3,
1764+
g4,
1765+
g5,
1766+
g6,
1767+
)
1768+
),
1769+
mdf.groupby("a", as_index=False)
1770+
.agg(
1771+
(
1772+
g1,
1773+
g2,
1774+
g3,
1775+
g4,
1776+
g5,
1777+
g6,
1778+
)
1779+
)
1780+
.execute()
1781+
.fetch(),
17211782
)

python/xorbits/_mars/dataframe/reduction/core.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,6 +1164,24 @@ def _interpret_var(v):
11641164
axis_expr = f"axis={op_axis!r}, " if op_axis is not None else ""
11651165
op_str = _func_name_to_op[func_name]
11661166
if t.op.lhs is t.inputs[0]:
1167+
if (
1168+
(
1169+
func_name
1170+
in (
1171+
"gt",
1172+
"ge",
1173+
"lt",
1174+
"le",
1175+
"eq",
1176+
"ne",
1177+
)
1178+
)
1179+
and isinstance(t.op.lhs, DATAFRAME_TYPE)
1180+
and isinstance(t.op.rhs, str)
1181+
):
1182+
# For a cuDF DataFrame, `df == 'foo'` doesn't work, so we repeat the rhs
1183+
# into a tuple with one element per column.
1184+
rhs = f"({rhs},) * len({lhs}.columns)"
11671185
statements = [
11681186
f"try:",
11691187
f" {var_name} = {lhs}.{func_name}({rhs}, {axis_expr})",

0 commit comments

Comments
 (0)