Commit 569aef4

SNOW-1757443: Implement GroupBy rolling (#3686)
Signed-off-by: Labanya Mukhopadhyay <[email protected]>
1 parent 226b9d9 commit 569aef4

8 files changed: 865 additions, 57 deletions

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -84,6 +84,7 @@
 - Added support for the `dtypes` parameter of `pd.get_dummies`
 - Added support for `nunique` in `df.pivot_table`, `df.agg` and other places where aggregate functions can be used.
 - Added support for `DataFrame.interpolate` and `Series.interpolate` with the "linear", "ffill"/"pad", and "backfill"/bfill" methods. These use the SQL `INTERPOLATE_LINEAR`, `INTERPOLATE_FFILL`, and `INTERPOLATE_BFILL` functions (PuPr).
+- Added support for `Dataframe.groupby.rolling()`.
 
 #### Improvements
 
docs/source/modin/supported/groupby_supported.rst

Lines changed: 4 additions & 1 deletion
@@ -153,7 +153,10 @@ Computations/descriptive stats
 |                             |                                 | will be lost. ``rule`` frequencies 's', 'min',     |
 |                             |                                 | 'h', and 'D' are supported.                        |
 +-----------------------------+---------------------------------+----------------------------------------------------+
-| ``rolling``                 | N                               |                                                    |
+| ``rolling``                 | P                               | Implemented for DataframeGroupby objects. ``N`` for|
+|                             |                                 | ``on``, non-integer ``window``, ``axis = 1``,      |
+|                             |                                 | ``method`` != ``single``, ``min_periods = 0``, or  |
+|                             |                                 | ``closed`` != ``None``.                            |
 +-----------------------------+---------------------------------+----------------------------------------------------+
 | ``sample``                  | N                               |                                                    |
 +-----------------------------+---------------------------------+----------------------------------------------------+
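
The `N` cases listed in the `rolling` row can be read as a set of argument checks. The following is a minimal sketch of that reading, not the code from this commit: the helper name `unsupported_rolling_args` is hypothetical, and treating `axis="columns"` the same as `axis=1` is an assumption taken from the docstring rather than from the table.

```python
from typing import Any, List, Optional, Union


def unsupported_rolling_args(
    window: Any,
    min_periods: Optional[int] = None,
    on: Optional[str] = None,
    axis: Union[int, str] = 0,
    closed: Optional[str] = None,
    method: str = "single",
) -> List[str]:
    """Return the argument names that fall under an ``N`` case in the table.

    Hypothetical helper for illustration only; it is not part of Snowpark.
    """
    problems: List[str] = []
    if on is not None:
        problems.append("on")
    # bool is a subclass of int, so exclude it explicitly.
    if isinstance(window, bool) or not isinstance(window, int):
        problems.append("window")
    # Assumption: "columns" is the string alias for axis 1.
    if axis in (1, "columns"):
        problems.append("axis")
    if method != "single":
        problems.append("method")
    if min_periods == 0:
        problems.append("min_periods")
    if closed is not None:
        problems.append("closed")
    return problems


if __name__ == "__main__":
    print(unsupported_rolling_args(2))
    print(unsupported_rolling_args("2s", min_periods=0, closed="left"))
```

An empty list means the call stays inside the partially supported (`P`) surface described by the table.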

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 236 additions & 51 deletions
Large diffs are not rendered by default.

src/snowflake/snowpark/modin/plugin/docstrings/groupby.py

Lines changed: 70 additions & 1 deletion
@@ -2424,7 +2424,76 @@ def expanding():
     pass
 
 def rolling():
-    pass
+    """
+    Return a rolling grouper, providing rolling functionality per group.
+
+    This implementation supports both fixed window-based and time-based rolling operations
+    with groupby functionality.
+
+    Parameters
+    ----------
+    window : int, timedelta, str, offset, or BaseIndexer subclass
+        Size of the moving window.
+        If an integer, the fixed number of observations used for each window.
+        If a timedelta, str, or offset, the time period of each window. Each window will be a variable sized based on the observations included in the time-period. This is only valid for datetimelike indexes. To learn more about the offsets & frequency strings, please see `this link <https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases>`.
+        If a BaseIndexer subclass, the window boundaries based on the defined `get_window_bounds` method. Additional rolling keyword arguments, namely `min_periods`, `center`, `closed` and `step` will be passed to `get_window_bounds`.
+
+    min_periods : int, default None
+        Minimum number of observations in window required to have a value; otherwise, result is `np.nan`.
+        For a window that is specified by an offset, `min_periods` will default to 1.
+        For a window that is specified by an integer, `min_periods` will default to the size of the window.
+
+    center : bool, default False
+        If False, set the window labels as the right edge of the window index.
+        If True, set the window labels as the center of the window index.
+
+    win_type : str, default None
+        If `None`, all points are evenly weighted.
+        If a string, it must be a valid scipy.signal window function.
+        Certain Scipy window types require additional parameters to be passed in the aggregation function. The additional parameters must match the keywords specified in the Scipy window type method signature.
+
+    on : str, optional
+        For a DataFrame, a column label or Index level on which to calculate the rolling window, rather than the DataFrame's index.
+        Provided integer column is ignored and excluded from result since an integer index is not used to calculate the rolling window.
+
+    axis : int or str, default 0
+        If `0` or `'index'`, roll across the rows.
+        If `1` or `'columns'`, roll across the columns.
+        For Series this parameter is unused and defaults to 0.
+
+    closed : str, default None
+        If `'right'`, the first point in the window is excluded from calculations.
+        If `'left'`, the last point in the window is excluded from calculations.
+        If `'both'`, no points in the window are excluded from calculations.
+        If `'neither'`, the first and last points in the window are excluded from calculations.
+        Default `None` (`'right'`).
+
+    method : str {'single', 'table'}, default 'single'
+        Execute the rolling operation per single column or row (`'single'`) or over the entire object (`'table'`).
+        This argument is only implemented when specifying `engine='numba'` in the method call.
+
+    Returns
+    -------
+    Rolling
+        A rolling grouper, providing rolling functionality per group.
+
+    Examples
+    --------
+    >>> df = pd.DataFrame({'A': [1, 1, 2, 2], 'B': [1, 2, 3, 4], 'C': [0.362, 0.227, 1.267, -0.562]})
+    >>> df
+       A  B      C
+    0  1  1  0.362
+    1  1  2  0.227
+    2  2  3  1.267
+    3  2  4 -0.562
+    >>> df.groupby('A').rolling(2).sum()    # doctest: +NORMALIZE_WHITESPACE
+           B      C
+    A
+    1 0  NaN    NaN
+      1  3.0  0.589
+    2 2  NaN    NaN
+      3  7.0  0.705
+    """
 
 def hist():
     pass
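
To make the docstring example concrete, here is a minimal pure-Python sketch of a per-group rolling sum with an integer window. It only illustrates the semantics described above (windows never cross group boundaries, and `min_periods` defaults to the window size for integer windows); it is not how the commit computes the result in Snowflake, and the function name `grouped_rolling_sum` is hypothetical. Missing-value handling and sorting options are deliberately left out.

```python
from collections import defaultdict
from typing import Dict, Hashable, List, Optional, Sequence, Tuple


def grouped_rolling_sum(
    keys: Sequence[Hashable],
    values: Sequence[float],
    window: int,
    min_periods: Optional[int] = None,
) -> List[Tuple[Hashable, int, Optional[float]]]:
    """Rolling sum computed independently inside each group.

    Returns (group key, original row position, rolling sum or None) tuples,
    ordered by group key and then by row position.
    """
    if min_periods is None:
        # For an integer window, min_periods defaults to the window size.
        min_periods = window

    # Collect the row positions that belong to each group, in row order.
    positions: Dict[Hashable, List[int]] = defaultdict(list)
    for pos, key in enumerate(keys):
        positions[key].append(pos)

    out: List[Tuple[Hashable, int, Optional[float]]] = []
    for key in sorted(positions):
        rows = positions[key]
        for i, pos in enumerate(rows):
            # The window covers the current row and up to window - 1
            # preceding rows of the same group.
            in_window = rows[max(0, i - window + 1): i + 1]
            if len(in_window) < min_periods:
                out.append((key, pos, None))  # None stands in for NaN
            else:
                out.append((key, pos, sum(values[p] for p in in_window)))
    return out


if __name__ == "__main__":
    a = [1, 1, 2, 2]
    b = [1, 2, 3, 4]
    for row in grouped_rolling_sum(a, b, window=2):
        print(row)
```

With the `A` and `B` columns from the docstring example, the first row of each group has no value and the second rows sum to 3 and 7, matching the `B` column of the docstring output.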

src/snowflake/snowpark/modin/plugin/extensions/dataframe_groupby_overrides.py

Lines changed: 29 additions & 3 deletions
@@ -847,9 +847,35 @@ def resample(
 
 
 @register_df_groupby_override("rolling")
-def rolling(self, *args, **kwargs):
-    # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions
-    ErrorMessage.method_not_implemented_error(name="rolling", class_="GroupBy")
+def rolling(
+    self,
+    window,
+    min_periods: Union[int, None] = None,
+    center: bool = False,
+    win_type: Union[str, None] = None,
+    on: Union[str, None] = None,
+    axis: Union[int, str] = 0,
+    closed: Union[str, None] = None,
+    method: str = "single",
+    **kwargs,
+):
+    from snowflake.snowpark.modin.plugin.extensions.rolling_groupby_overrides import (
+        RollingGroupby,
+    )
+
+    return RollingGroupby(
+        dataframe=self._df,
+        by=self._by,
+        window=window,
+        min_periods=min_periods,
+        center=center,
+        win_type=win_type,
+        on=on,
+        axis=axis,
+        closed=closed,
+        method=method,
+        dropna=self._kwargs.get("dropna", True),
+    )
 
 
 @register_df_groupby_override("sample")
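
The override above does no computation itself: it forwards its arguments to a separate rolling object and adds the `dropna` setting stored on the groupby. A minimal, self-contained sketch of that forwarding pattern follows. `ToyGroupBy` and `RollingSpec` are hypothetical stand-ins invented for illustration; they are not the Snowpark `DataFrameGroupBy` or `RollingGroupby` classes, and only a subset of the parameters is shown.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class RollingSpec:
    """Hypothetical stand-in for the object returned by ``rolling``."""

    by: Any
    window: Any
    min_periods: Optional[int]
    center: bool
    dropna: bool


class ToyGroupBy:
    """Hypothetical stand-in for a groupby object that stores its kwargs."""

    def __init__(self, by: Any, **kwargs: Any) -> None:
        self._by = by
        self._kwargs = kwargs

    def rolling(
        self,
        window: Any,
        min_periods: Optional[int] = None,
        center: bool = False,
    ) -> RollingSpec:
        # Forward the rolling arguments unchanged and carry over the
        # groupby-level dropna flag, defaulting to True when it was not set.
        return RollingSpec(
            by=self._by,
            window=window,
            min_periods=min_periods,
            center=center,
            dropna=self._kwargs.get("dropna", True),
        )


if __name__ == "__main__":
    print(ToyGroupBy("A").rolling(2))
    print(ToyGroupBy("A", dropna=False).rolling(2, min_periods=1))
```

Keeping the rolling state in its own object means the aggregation (`sum`, `mean`, and so on) can be chosen later, which is presumably why the override returns an object rather than a result.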
