@@ -1974,18 +1974,56 @@ def test_nanlen_string(dtype, engine) -> None:
1974
1974
assert_equal (expected , actual )
1975
1975
1976
1976
1977
- def test_cumsum () -> None :
1978
- array = np .array ([1 , 1 , 1 ], dtype = np .uint64 )
1977
+ @pytest .mark .parametrize (
1978
+ "array" ,
1979
+ [
1980
+ np .array ([1 , 1 , 1 , 2 , 3 , 4 , 5 ], dtype = np .uint64 ),
1981
+ np .array ([1 , 1 , 1 , 2 , np .nan , 4 , 5 ], dtype = np .float64 ),
1982
+ ],
1983
+ )
1984
+ @pytest .mark .parametrize ("func" , ["cumsum" , "nancumsum" ])
1985
+ def test_cumsum_simple (array , func ) -> None :
1979
1986
by = np .array ([0 ] * array .shape [- 1 ])
1980
- expected = np . nancumsum (array , axis = - 1 )
1987
+ expected = getattr ( np , func ) (array , axis = - 1 )
1981
1988
1982
- actual = groupby_scan (array , by , func = "nancumsum" , axis = - 1 )
1983
- assert_equal (expected , actual )
1989
+ actual = groupby_scan (array , by , func = func , axis = - 1 )
1990
+ assert_equal (actual , expected )
1991
+
1992
+ if has_dask :
1993
+ da = dask .array .from_array (array , chunks = 2 )
1994
+ actual = groupby_scan (da , by , func = func , axis = - 1 )
1995
+ assert_equal (actual , expected )
1996
+
1997
+
1998
+ def test_cumsum () -> None :
1999
+ array = np .array (
2000
+ [
2001
+ [1 , 2 , np .nan , 4 , 5 ],
2002
+ [3 , np .nan , 4 , 6 , 6 ],
2003
+ ]
2004
+ )
2005
+ by = [0 , 1 , 1 , 0 , 1 ]
2006
+
2007
+ expected = np .array (
2008
+ [
2009
+ [1 , 2 , np .nan , 5 , np .nan ],
2010
+ [3 , np .nan , np .nan , 9 , np .nan ],
2011
+ ]
2012
+ )
2013
+ actual = groupby_scan (array , by , func = "cumsum" , axis = - 1 )
2014
+ assert_equal (actual , expected )
2015
+ if has_dask :
2016
+ da = dask .array .from_array (array , chunks = 2 )
2017
+ actual = groupby_scan (da , by , func = "cumsum" , axis = - 1 )
2018
+ assert_equal (actual , expected )
1984
2019
2020
+ expected = np .array ([[1 , 2 , 2 , 5 , 7 ], [3 , 0 , 4 , 9 , 10 ]], dtype = np .float64 )
2021
+ actual = groupby_scan (array , by , func = "nancumsum" , axis = - 1 )
2022
+ assert_equal (actual , expected )
1985
2023
if has_dask :
1986
2024
da = dask .array .from_array (array , chunks = 2 )
1987
2025
actual = groupby_scan (da , by , func = "nancumsum" , axis = - 1 )
1988
- assert_equal (expected , actual )
2026
+ assert_equal (actual , expected )
1989
2027
1990
2028
1991
2029
@pytest .mark .parametrize (
0 commit comments