Skip to content

Commit d9b42a1

Browse files
Merge branch 'main' into helmeleegy-SNOW-2504821
2 parents 874a3c7 + 85997e1 commit d9b42a1

File tree

7 files changed

+317
-16
lines changed

7 files changed

+317
-16
lines changed

CHANGELOG.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@
3737
- `pivot_table()` with `sort=True`, non-string `index` list, non-string `columns` list, non-string `values` list, or `aggfunc` dict with non-string values
3838
- `fillna()` with `downcast` parameter or using `limit` together with `value`
3939
- `dropna()` with `axis=1`
40+
- `groupby()` with `axis=1`, `by!=None and level!=None`, or by containing any non-pandas hashable labels.
41+
- `groupby_fillna()` with `downcast` parameter
42+
- `groupby_first()` with `min_count>1`
43+
- `groupby_last()` with `min_count>1`
44+
- `shift()` with `freq` parameter
4045

4146
#### Bug Fixes
4247

@@ -210,11 +215,6 @@
210215
- `skew()` with `axis=1` or `numeric_only=False` parameters
211216
- `round()` with `decimals` parameter as a Series
212217
- `corr()` with `method!=pearson` parameter
213-
- `df.groupby()` with `axis=1`, `by!=None and level!=None`, or by containing any non-pandas hashable labels.
214-
- `groupby_fillna()` with `downcast` parameter
215-
- `groupby_first()` with `min_count>1`
216-
- `groupby_last()` with `min_count>1`
217-
- `shift()` with `freq` parameter
218218
- Set `cte_optimization_enabled` to True for all Snowpark pandas sessions.
219219
- Add support for the following in faster pandas:
220220
- `isin`

src/snowflake/snowpark/functions.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7135,10 +7135,15 @@ def array_contains(
71357135
variant: Column containing the VARIANT to find.
71367136
array: Column containing the ARRAY to search.
71377137

7138+
If this is a semi-structured array, you're required to explicitly cast the following SQL types into a VARIANT:
7139+
7140+
- `String & Binary <https://docs.snowflake.com/en/sql-reference/data-types-text>`_
7141+
- `Date & Time <https://docs.snowflake.com/en/sql-reference/data-types-datetime>`_
7142+
71387143
Example::
71397144
>>> from snowflake.snowpark import Row
7140-
>>> df = session.create_dataframe([Row([1, 2]), Row([1, 3])], schema=["a"])
7141-
>>> df.select(array_contains(lit(2), "a").alias("result")).show()
7145+
>>> df = session.create_dataframe([Row(["apple", "banana"]), Row(["apple", "orange"])], schema=["a"])
7146+
>>> df.select(array_contains(lit("banana").cast("variant"), "a").alias("result")).show()
71427147
------------
71437148
|"RESULT" |
71447149
------------

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5548,7 +5548,7 @@ def _groupby_first_last(
55485548
return result
55495549

55505550
@register_query_compiler_method_not_implemented(
5551-
"DataFrameGroupBy",
5551+
["DataFrameGroupBy", "SeriesGroupBy"],
55525552
"first",
55535553
UnsupportedArgsRule(
55545554
unsupported_conditions=[
@@ -5594,7 +5594,7 @@ def groupby_first(
55945594
)
55955595

55965596
@register_query_compiler_method_not_implemented(
5597-
"DataFrameGroupBy",
5597+
["DataFrameGroupBy", "SeriesGroupBy"],
55985598
"last",
55995599
UnsupportedArgsRule(
56005600
unsupported_conditions=[
@@ -5640,7 +5640,7 @@ def groupby_last(
56405640
)
56415641

56425642
@register_query_compiler_method_not_implemented(
5643-
"DataFrameGroupBy",
5643+
["DataFrameGroupBy", "SeriesGroupBy"],
56445644
"rank",
56455645
UnsupportedArgsRule(
56465646
unsupported_conditions=[
@@ -6102,7 +6102,7 @@ def groupby_rolling(
61026102
return result_qc
61036103

61046104
@register_query_compiler_method_not_implemented(
6105-
"DataFrameGroupBy",
6105+
["DataFrameGroupBy", "SeriesGroupBy"],
61066106
"shift",
61076107
UnsupportedArgsRule(
61086108
unsupported_conditions=[
@@ -7107,7 +7107,7 @@ def groupby_value_counts(
71077107
)
71087108

71097109
@register_query_compiler_method_not_implemented(
7110-
"DataFrameGroupBy",
7110+
["DataFrameGroupBy", "SeriesGroupBy"],
71117111
"fillna",
71127112
UnsupportedArgsRule(
71137113
unsupported_conditions=[

src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@
5959
)
6060
from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
6161
HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS,
62+
UnsupportedArgsRule,
63+
_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE,
64+
register_query_compiler_method_not_implemented,
65+
)
66+
from snowflake.snowpark.modin.plugin._internal.groupby_utils import (
67+
check_is_groupby_supported_by_snowflake,
6268
)
6369
from snowflake.snowpark.modin.plugin._typing import DropKeep, ListLike
6470
from snowflake.snowpark.modin.plugin.extensions.snow_partition_iterator import (
@@ -1549,6 +1555,22 @@ def fillna(
15491555

15501556
# Snowpark pandas defines a custom GroupBy object
15511557
@register_series_accessor("groupby")
1558+
@register_query_compiler_method_not_implemented(
1559+
"Series",
1560+
"groupby",
1561+
UnsupportedArgsRule(
1562+
unsupported_conditions=[
1563+
(
1564+
lambda args: not check_is_groupby_supported_by_snowflake(
1565+
args.get("by"),
1566+
args.get("level"),
1567+
args.get("axis", 0),
1568+
),
1569+
f"Groupby {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}",
1570+
)
1571+
]
1572+
),
1573+
)
15521574
def groupby(
15531575
self,
15541576
by=None,

tests/integ/modin/groupby/test_groupby_default2pandas.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def test_groupby_with_numpy_array(basic_snowpark_pandas_df) -> None:
130130
@sql_count_checker(query_count=0)
131131
def test_groupby_series_with_numpy_array(native_series_multi_numeric, by_list) -> None:
132132
with pytest.raises(
133-
NotImplementedError, match=GROUPBY_UNSUPPORTED_GROUPING_ERROR_PATTERN
133+
NotImplementedError, match=_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE
134134
):
135135
pd.Series(native_series_multi_numeric).groupby(by=by_list).max()
136136

tests/integ/modin/groupby/test_groupby_rolling.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,18 +102,18 @@ def test_groupby_rolling_dropna_false():
102102
)
103103

104104

105-
@sql_count_checker(query_count=1)
105+
@sql_count_checker(query_count=0)
106106
def test_groupby_rolling_series_negative():
107107
date_idx = pd.date_range("1/1/2000", periods=8, freq="min")
108108
date_idx.names = ["grp_col"]
109109
snow_ser = pd.Series([1, 1, np.nan, 2])
110110
with pytest.raises(
111111
NotImplementedError,
112112
match=re.escape(
113-
"Groupby does not yet support axis == 1, by != None and level != None, or by containing any non-pandas hashable labels"
113+
"Snowpark pandas does not yet support the method GroupBy.rolling for Series"
114114
),
115115
):
116-
snow_ser.groupby(snow_ser.index).rolling(2).sum()
116+
snow_ser.groupby(level=0).rolling(2).sum()
117117

118118

119119
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)