@@ -1674,8 +1674,7 @@ def test_gpu_groupby_size(data_type, chunked, as_index, sort, setup_gpu):
16741674 pd .testing .assert_series_equal (expected , actual )
16751675
16761676
1677- # TODO: support cuda
1678- # @support_cuda
1677+ @support_cuda
16791678@pytest .mark .parametrize (
16801679 "as_index" ,
16811680 [True , False ],
@@ -1706,16 +1705,78 @@ def g3(x):
17061705 df .groupby ("a" , as_index = False ).agg ((g1 , g2 , g3 )),
17071706 mdf .groupby ("a" , as_index = False ).agg ((g1 , g2 , g3 )).execute ().fetch (),
17081707 )
1709- pd .testing .assert_frame_equal (
1710- df .groupby ("a" , as_index = as_index ).agg ((g1 , g1 )),
1711- mdf .groupby ("a" , as_index = as_index ).agg ((g1 , g1 )).execute ().fetch (),
1712- )
1708+ if not gpu :
1709+ # cuDF doesn't support having multiple columns with same names yet.
1710+ pd .testing .assert_frame_equal (
1711+ df .groupby ("a" , as_index = as_index ).agg ((g1 , g1 )),
1712+ mdf .groupby ("a" , as_index = as_index ).agg ((g1 , g1 )).execute ().fetch (),
1713+ )
17131714
17141715 pd .testing .assert_frame_equal (
17151716 df .groupby ("a" , as_index = as_index )["b" ].agg ((g1 , g2 , g3 )),
17161717 mdf .groupby ("a" , as_index = as_index )["b" ].agg ((g1 , g2 , g3 )).execute ().fetch (),
17171718 )
1719+ if not gpu :
1720+ # cuDF doesn't support having multiple columns with same names yet.
1721+ pd .testing .assert_frame_equal (
1722+ df .groupby ("a" , as_index = as_index )["b" ].agg ((g1 , g1 )),
1723+ mdf .groupby ("a" , as_index = as_index )["b" ].agg ((g1 , g1 )).execute ().fetch (),
1724+ )
1725+
1726+
@support_cuda
def test_groupby_agg_on_custom_funcs(setup_gpu, gpu):
    """Check groupby-agg with custom string-comparison aggregators.

    Builds a 100-row frame of three string columns drawn from
    ``{"foo", "bar", "baz"}``, wraps it in an ``md.DataFrame``
    (``chunk_size=34``, ``gpu=gpu``) and asserts that aggregating with six
    user-defined functions yields the same frame as plain pandas.
    """
    rs = np.random.RandomState(0)
    # Columns are drawn in the order a, b, c so the random stream is
    # consumed deterministically.
    raw = pd.DataFrame(
        {name: rs.choice(["foo", "bar", "baz"], size=100) for name in ("a", "b", "c")},
    )
    mdf = md.DataFrame(raw, chunk_size=34, gpu=gpu)

    # One aggregator per comparison operator; the literal is deliberately on
    # the left-hand side for ``==`` / ``!=``.
    def g1(x):
        return ("foo" == x).sum()

    def g2(x):
        return ("foo" != x).sum()

    def g3(x):
        return (x > "bar").sum()

    def g4(x):
        return (x >= "bar").sum()

    def g5(x):
        return (x < "baz").sum()

    def g6(x):
        return (x <= "baz").sum()

    aggregators = (g1, g2, g3, g4, g5, g6)

    expected = raw.groupby("a", as_index=False).agg(aggregators)
    actual = mdf.groupby("a", as_index=False).agg(aggregators).execute().fetch()
    pd.testing.assert_frame_equal(expected, actual)
0 commit comments