Merged

36 commits
94b2f90
SNOW-1757443: Groupby rolling functionality and test
sfc-gh-lmukhopadhyay Aug 21, 2025
9bc3b00
update for dropna
sfc-gh-lmukhopadhyay Aug 21, 2025
e1c4882
Merge branch 'main' into lmukhopadhyay-SNOW-1757443-groupby-rolling
sfc-gh-lmukhopadhyay Aug 21, 2025
48f7d2e
fix tests
sfc-gh-lmukhopadhyay Aug 25, 2025
a7aed41
Merge branch 'main' into lmukhopadhyay-SNOW-1757443-groupby-rolling
sfc-gh-lmukhopadhyay Aug 25, 2025
5ddfa3d
update changelog doc
sfc-gh-lmukhopadhyay Aug 25, 2025
0d55f3e
update series error and type fix
sfc-gh-lmukhopadhyay Aug 26, 2025
e0c4cfa
overrides type fix
sfc-gh-lmukhopadhyay Aug 26, 2025
1977500
Merge branch 'main' into lmukhopadhyay-SNOW-1757443-groupby-rolling
sfc-gh-lmukhopadhyay Aug 28, 2025
801b1be
fix override file
sfc-gh-lmukhopadhyay Aug 28, 2025
5878c65
add doctests
sfc-gh-lmukhopadhyay Aug 28, 2025
3e58be5
rm inherit docstr
sfc-gh-lmukhopadhyay Aug 28, 2025
bceab0a
Merge branch 'main' into lmukhopadhyay-SNOW-1757443-groupby-rolling
sfc-gh-lmukhopadhyay Aug 28, 2025
6b0ffdb
Merge branch 'main' into lmukhopadhyay-SNOW-1757443-groupby-rolling
sfc-gh-lmukhopadhyay Aug 29, 2025
e751775
Merge branch 'main' into lmukhopadhyay-SNOW-1757443-groupby-rolling
sfc-gh-lmukhopadhyay Aug 29, 2025
fabf678
Merge branch 'main' into lmukhopadhyay-SNOW-1757443-groupby-rolling
sfc-gh-lmukhopadhyay Sep 2, 2025
bff385b
Merge branch 'main' into lmukhopadhyay-SNOW-1757443-groupby-rolling
sfc-gh-lmukhopadhyay Sep 12, 2025
b0303c1
review changes and add neg unsupported tests
sfc-gh-lmukhopadhyay Sep 29, 2025
ba43f46
Merge branch 'main' into lmukhopadhyay-SNOW-1757443-groupby-rolling
sfc-gh-lmukhopadhyay Sep 29, 2025
55b601c
update changelog
sfc-gh-lmukhopadhyay Sep 29, 2025
ca730bd
Merge branch 'main' into lmukhopadhyay-SNOW-1757443-groupby-rolling
sfc-gh-lmukhopadhyay Sep 30, 2025
c361e58
update doctests
sfc-gh-lmukhopadhyay Sep 30, 2025
84ab734
Merge branch 'main' into lmukhopadhyay-SNOW-1757443-groupby-rolling
sfc-gh-lmukhopadhyay Sep 30, 2025
8bab816
update doctest
sfc-gh-lmukhopadhyay Sep 30, 2025
d7a8098
Merge branch 'main' into lmukhopadhyay-SNOW-1757443-groupby-rolling
sfc-gh-lmukhopadhyay Oct 1, 2025
4682324
Merge branch 'main' into lmukhopadhyay-SNOW-1757443-groupby-rolling
sfc-gh-lmukhopadhyay Oct 16, 2025
1b674e5
review changes, multiindex and dropna and sort support
sfc-gh-lmukhopadhyay Oct 17, 2025
54f1018
fix doctest
sfc-gh-lmukhopadhyay Oct 17, 2025
32ca8b1
update changelog
sfc-gh-lmukhopadhyay Oct 17, 2025
ced6106
change from review
sfc-gh-lmukhopadhyay Oct 17, 2025
5805a5d
review change using kwargs
sfc-gh-lmukhopadhyay Oct 17, 2025
a8319bc
review changes
sfc-gh-lmukhopadhyay Oct 21, 2025
39b41ec
resolve conf
sfc-gh-lmukhopadhyay Oct 21, 2025
78c3689
fix changelog format
sfc-gh-lmukhopadhyay Oct 21, 2025
30053fb
fix warning msg
sfc-gh-lmukhopadhyay Oct 21, 2025
151e04b
Merge branch 'main' into lmukhopadhyay-SNOW-1757443-groupby-rolling
sfc-gh-lmukhopadhyay Oct 22, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -36,6 +36,7 @@
backends. Previously, only some of these functions and methods were supported
on the Pandas backend.
- Added support for `Index.get_level_values()`.
- Added support for `DataFrame.groupby.rolling()`.

#### Improvements
- Set the default transfer limit in hybrid execution for data leaving Snowflake to 100k, which can be overridden with the SnowflakePandasTransferThreshold environment variable. This configuration is appropriate for scenarios with two available engines, "Pandas" and "Snowflake" on relational workloads.
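
As a hedged aside on that override: setting the variable before the Snowpark pandas session is created should be enough (the variable name is taken from the entry above; the value here is illustrative):

import os

# Illustrative override of the hybrid-execution transfer limit; the
# variable name comes from the changelog entry above.
os.environ["SnowflakePandasTransferThreshold"] = "200000"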
4 changes: 3 additions & 1 deletion docs/source/modin/supported/groupby_supported.rst
@@ -151,7 +151,9 @@ Computations/descriptive stats
| | | will be lost. ``rule`` frequencies 's', 'min', |
| | | 'h', and 'D' are supported. |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``rolling`` | N | |
| ``rolling`` | P | Implemented for DataFrameGroupBy objects. ``N`` for|
| | | non-integer ``window``, ``axis = 1``, or |
| | | ``min_periods = 0``. |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``sample`` | N | |
+-----------------------------+---------------------------------+----------------------------------------------------+
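
To make the support matrix concrete, a minimal sketch of the supported and unsupported calls (assumes the Snowpark pandas plugin import below; behavior follows the table above):

import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401  (registers the Snowflake backend)

df = pd.DataFrame({"grp": ["a", "a", "b", "b"], "val": [1, 2, 3, 4]})

df.groupby("grp").rolling(window=2).sum()  # supported: integer window
# Unsupported per the table above; these raise NotImplementedError:
# df.groupby("grp").rolling(window=2, min_periods=0).sum()
# df.groupby("grp").rolling(window="2s", on="ts").sum()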
@@ -5317,6 +5317,113 @@ def groupby_resample(
)
return SnowflakeQueryCompiler(resampled_frame_all_bins)

def groupby_rolling(
self,
rolling_kwargs: dict[str, Any],
rolling_method: AggFuncType,
groupby_kwargs: dict[str, Any],
is_series: bool,
agg_args: Any,
agg_kwargs: dict[str, Any],
) -> "SnowflakeQueryCompiler":
"""
Return a rolling grouper, providing rolling functionality per group.

This implementation supports both fixed window-based and time-based rolling operations
with groupby functionality.

Args:
rolling_kwargs: Dictionary containing rolling window parameters
rolling_method: The aggregation method to apply (e.g., 'mean', 'sum')
groupby_kwargs: Dictionary containing groupby parameters
is_series: Whether the operation is on a Series
agg_args: Additional arguments for aggregation
agg_kwargs: Additional keyword arguments for aggregation

Returns:
SnowflakeQueryCompiler: A new query compiler with the rolling operation applied
"""

dropna = groupby_kwargs.get("dropna", True)

# Validate parameters
if rolling_kwargs.get("axis", 0) != 0:
raise ErrorMessage.not_implemented(
"GroupBy rolling with axis != 0 is not supported yet in Snowpark pandas."
)

if rolling_kwargs.get("win_type") is not None:
raise ErrorMessage.not_implemented(
"GroupBy rolling with win_type parameter is not supported yet in Snowpark pandas."
)

if rolling_kwargs.get("method", "single") != "single":
raise ErrorMessage.not_implemented(
"GroupBy rolling with method != 'single' is not supported yet in Snowpark pandas."
)

window = rolling_kwargs.get("window")
min_periods = rolling_kwargs.get(
"min_periods", (1 if isinstance(window, str) else window)
)
window_kwargs = {
"window": window,
"min_periods": min_periods,
"center": rolling_kwargs.get("center", False),
"on": rolling_kwargs.get("on", None),
"axis": 0,
"closed": rolling_kwargs.get("closed", None),
}

if not isinstance(window, (int, float)):
raise ErrorMessage.not_implemented(
"GroupBy rolling only supports numeric window sizes in Snowpark pandas."
)
if min_periods and isinstance(window, int) and min_periods > window:
raise ValueError(f"min_periods {min_periods} must be <= window {window}")

# Extract groupby columns
by_labels = extract_groupby_column_pandas_labels(
self,
groupby_kwargs.get("by", None),
groupby_kwargs.get("level", None),
)

extended_qc = self
if dropna:
extended_qc = extended_qc.dropna(axis=0, how="any", subset=by_labels)

result_qc = extended_qc._window_agg(
window_func=WindowFunction.ROLLING,
agg_func=rolling_method,
window_kwargs=window_kwargs,
agg_kwargs=agg_kwargs,
partition_cols=by_labels,
)

if by_labels:
result_qc = result_qc.reset_index()

# Set index with group columns then original index
index_cols = list(by_labels) + ["index"]
result_qc = result_qc.set_index(keys=index_cols, drop=True)

frame = result_qc._modin_frame
expected_names = list(by_labels) + [None]
new_frame = InternalFrame.create(
ordered_dataframe=frame.ordered_dataframe,
data_column_pandas_labels=frame.data_column_pandas_labels,
data_column_pandas_index_names=frame.data_column_pandas_index_names,
data_column_snowflake_quoted_identifiers=frame.data_column_snowflake_quoted_identifiers,
index_column_pandas_labels=expected_names,  # group labels followed by None for the original, unnamed index level
index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers,
data_column_types=frame.cached_data_column_snowpark_pandas_types,
index_column_types=frame.cached_index_column_snowpark_pandas_types,
)

result_qc = SnowflakeQueryCompiler(new_frame)
return result_qc
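
For reference, the pandas semantics this mirrors: the result index is the group keys followed by the original row labels, which is exactly what the expected_names = list(by_labels) + [None] reconstruction above restores. A plain-pandas sketch:

import pandas as pd

df = pd.DataFrame({"A": ["x", "x", "y"], "B": [1.0, 2.0, 3.0]})
out = df.groupby("A").rolling(window=2).sum()
print(out)
# Roughly:
#          B
# A
# x 0    NaN
#   1    3.0
# y 2    NaN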

def groupby_shift(
self,
by: Any,
@@ -14932,6 +15039,7 @@ def _window_agg(
agg_func: AggFuncType,
window_kwargs: dict[str, Any],
agg_kwargs: dict[str, Any],
partition_cols: Optional[list[str]] = None,
) -> "SnowflakeQueryCompiler":
"""
Compute rolling window with given aggregation.
@@ -14940,6 +15048,7 @@
agg_func: callable, str, list or dict. the aggregation function used.
window_kwargs: keyword arguments passed to the window function.
agg_kwargs: keyword arguments passed for the aggregation function.
partition_cols: list of columns to partition by, if any.
Returns:
SnowflakeQueryCompiler: with a newly constructed internal dataframe
"""
@@ -14980,7 +15089,20 @@
window_expr = Window.orderBy(
col(row_position_quoted_identifier)
).rows_between(rows_between_start, rows_between_end)
if partition_cols:
# Get the actual Snowflake quoted identifiers for the partition columns
partition_identifiers = []
for col_label in partition_cols:
if col_label in frame.data_column_pandas_labels:
idx = frame.data_column_pandas_labels.index(col_label)
partition_identifiers.append(
frame.data_column_snowflake_quoted_identifiers[idx]
)

if partition_identifiers:
window_expr = window_expr.partitionBy(
*[col(pid) for pid in partition_identifiers]
)
if window_func == WindowFunction.ROLLING:
# min_periods defaults to the size of the window if window is specified by an integer
min_periods = window if min_periods is None else min_periods
@@ -15010,6 +15132,32 @@
for t in frame.cached_data_column_snowpark_pandas_types
)

# Determine which columns to apply the window function to
# For groupby rolling, we want to exclude the partition columns from aggregation
# but keep them in the result with their original values
if partition_cols:
# Only apply window functions to non-partition columns
agg_column_identifiers = [
quoted_identifier
for i, quoted_identifier in enumerate(
frame.data_column_snowflake_quoted_identifiers
)
if frame.data_column_pandas_labels[i] not in partition_cols
]

# Keep partition columns with original values
partition_column_expressions = {}
for col_label in partition_cols:
if col_label in frame.data_column_pandas_labels:
idx = frame.data_column_pandas_labels.index(col_label)
partition_column_expressions[
frame.data_column_snowflake_quoted_identifiers[idx]
] = col(frame.data_column_snowflake_quoted_identifiers[idx])
else:
# Regular rolling
agg_column_identifiers = frame.data_column_snowflake_quoted_identifiers
partition_column_expressions = {}

# Perform Aggregation over the window_expr
if agg_func == "sem":
if input_contains_timedelta:
@@ -15018,39 +15166,38 @@ def _window_agg(
# Standard error of mean (SEM) does not have native Snowflake engine support
# so calculate as STDDEV/SQRT(N-ddof)
ddof = agg_kwargs.get("ddof", 1)
new_frame = frame.update_snowflake_quoted_identifiers_with_expressions(
{
quoted_identifier: iff(
count(col(quoted_identifier)).over(window_expr) >= min_periods,
when(
# If STDDEV is Null (like when the window has 1 element), return NaN
# Note that in Python, np.nan / np.inf results in np.nan, so this check must come first
builtin("stddev")(col(quoted_identifier))
.over(window_expr)
.is_null(),
pandas_lit(None),
)
.when(
# Elif (N-ddof) is negative number, return NaN to mimic pandas sqrt of a negative number
count(col(quoted_identifier)).over(window_expr) - ddof < 0,
pandas_lit(None),
)
.when(
# Elif (N-ddof) is 0, return np.inf to mimic pandas division by 0
count(col(quoted_identifier)).over(window_expr) - ddof == 0,
pandas_lit(np.inf),
)
.otherwise(
# Else compute STDDEV/SQRT(N-ddof)
builtin("stddev")(col(quoted_identifier)).over(window_expr)
/ builtin("sqrt")(
count(col(quoted_identifier)).over(window_expr) - ddof
),
),
agg_expressions = {
quoted_identifier: iff(
count(col(quoted_identifier)).over(window_expr) >= min_periods,
when(
builtin("stddev")(col(quoted_identifier))
.over(window_expr)
.is_null(),
pandas_lit(None),
)
for quoted_identifier in frame.data_column_snowflake_quoted_identifiers
}
.when(
count(col(quoted_identifier)).over(window_expr) - ddof < 0,
pandas_lit(None),
)
.when(
count(col(quoted_identifier)).over(window_expr) - ddof == 0,
pandas_lit(np.inf),
)
.otherwise(
builtin("stddev")(col(quoted_identifier)).over(window_expr)
/ builtin("sqrt")(
count(col(quoted_identifier)).over(window_expr) - ddof
),
),
pandas_lit(None),
)
for quoted_identifier in agg_column_identifiers
}
# Combine partition column expressions (unchanged) with aggregated expressions
all_expressions = {**partition_column_expressions, **agg_expressions}

new_frame = frame.update_snowflake_quoted_identifiers_with_expressions(
all_expressions
).frame
elif agg_func == "corr":
if input_contains_timedelta:
@@ -15139,28 +15286,28 @@ def _window_agg(
and input_contains_timedelta
):
raise DataError(_TIMEDELTA_ROLLING_AGGREGATION_NOT_SUPPORTED)
# Build expressions for aggregated columns
agg_expressions = {
quoted_identifier: iff(
count(col(row_position_quoted_identifier)).over(window_expr)
>= min_periods
if agg_func == "count"
else count(col(quoted_identifier)).over(window_expr) >= min_periods,
snowflake_agg_func.snowpark_aggregation(
builtin("zeroifnull")(col(quoted_identifier))
if window_func == WindowFunction.EXPANDING and agg_func == "sum"
else col(quoted_identifier)
).over(window_expr),
pandas_lit(None),
)
for quoted_identifier in agg_column_identifiers
}

# Combine partition column expressions (unchanged) with aggregated expressions
all_expressions = {**partition_column_expressions, **agg_expressions}

new_frame = frame.update_snowflake_quoted_identifiers_with_expressions(
{
# If aggregation is count use count on row_position_quoted_identifier
# to include NULL values for min_periods comparison
quoted_identifier: iff(
count(col(row_position_quoted_identifier)).over(window_expr)
>= min_periods
if agg_func == "count"
else count(col(quoted_identifier)).over(window_expr)
>= min_periods,
snowflake_agg_func.snowpark_aggregation(
# Expanding is cumulative so replace NULL with 0 for sum aggregation
builtin("zeroifnull")(col(quoted_identifier))
if window_func == WindowFunction.EXPANDING
and agg_func == "sum"
else col(quoted_identifier)
).over(window_expr),
pandas_lit(None),
)
for quoted_identifier in frame.data_column_snowflake_quoted_identifiers
}
all_expressions
).frame
return self.__constructor__(new_frame)
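
As a standalone sketch of the Snowpark window pattern _window_agg assembles when partition_cols is given (column names are illustrative; session setup omitted):

from snowflake.snowpark import Window
from snowflake.snowpark.functions import col, sum as sum_

# Order by row position, bound the frame to the trailing 2-row window,
# then partition by the group column -- the same shape the code above
# builds with Window.orderBy(...).rows_between(...) plus partitionBy(...).
window_expr = (
    Window.orderBy(col("ROW_POS"))
    .rows_between(-1, 0)
    .partitionBy(col("GRP"))
)
# snowpark_df.select(
#     col("GRP"),
#     sum_(col("VAL")).over(window_expr).alias("ROLLING_SUM"),
# )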

@@ -847,9 +847,36 @@ def resample(


@register_df_groupby_override("rolling")
def rolling(self, *args, **kwargs):
def rolling(
self,
window,
min_periods: Union[int, None] = None,
center: bool = False,
win_type: Union[str, None] = None,
on: Union[str, None] = None,
axis: Union[int, str] = 0,
closed: Union[str, None] = None,
method: str = "single",
**kwargs,
):
# TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions
ErrorMessage.method_not_implemented_error(name="rolling", class_="GroupBy")
from snowflake.snowpark.modin.plugin.extensions.rolling_groupby_overrides import (
RollingGroupby,
)

return RollingGroupby(
dataframe=self._df,
by=self._by,
window=window,
min_periods=min_periods,
center=center,
win_type=win_type,
on=on,
axis=axis,
closed=closed,
method=method,
)
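
With the override registered, the previously raising path now returns a RollingGroupby, so chained aggregation works end to end (a sketch, reusing the plugin import shown earlier):

import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401

df = pd.DataFrame({"grp": ["a", "a", "b"], "val": [1.0, 2.0, 3.0]})
rolled = df.groupby("grp").rolling(window=2, min_periods=1)
rolled.mean()  # dispatches through groupby_rolling in the query compiler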


@register_df_groupby_override("sample")