Commit 4c36548

Merge branch 'main' into feature/aherrera/SNOW-2443512-StringAndBinaryPart2
2 parents: c376c45 + b2696ad

File tree: 11 files changed, +394 -44 lines

CHANGELOG.md

Lines changed: 6 additions & 5 deletions
@@ -49,6 +49,12 @@
 - `pivot_table()` with `sort=True`, non-string `index` list, non-string `columns` list, non-string `values` list, or `aggfunc` dict with non-string values
 - `fillna()` with `downcast` parameter or using `limit` together with `value`
 - `dropna()` with `axis=1`
+- `groupby()` with `axis=1`, `by!=None and level!=None`, or by containing any non-pandas hashable labels.
+- `groupby_fillna()` with `downcast` parameter
+- `groupby_first()` with `min_count>1`
+- `groupby_last()` with `min_count>1`
+- `shift()` with `freq` parameter
+- Slightly improved the performance of `agg`, `nunique`, `describe`, and related methods on 1-column DataFrame and Series objects.

 #### Bug Fixes

@@ -219,11 +225,6 @@
 - `skew()` with `axis=1` or `numeric_only=False` parameters
 - `round()` with `decimals` parameter as a Series
 - `corr()` with `method!=pearson` parameter
-- `df.groupby()` with `axis=1`, `by!=None and level!=None`, or by containing any non-pandas hashable labels.
-- `groupby_fillna()` with `downcast` parameter
-- `groupby_first()` with `min_count>1`
-- `groupby_last()` with `min_count>1`
-- `shift()` with `freq` parameter
 - Set `cte_optimization_enabled` to True for all Snowpark pandas sessions.
 - Add support for the following in faster pandas:
   - `isin`

src/snowflake/snowpark/functions.py

Lines changed: 7 additions & 2 deletions
@@ -7135,10 +7135,15 @@ def array_contains(
         variant: Column containing the VARIANT to find.
         array: Column containing the ARRAY to search.

+        If this is a semi-structured array, you're required to explicitly cast the following SQL types into a VARIANT:
+
+        - `String & Binary <https://docs.snowflake.com/en/sql-reference/data-types-text>`_
+        - `Date & Time <https://docs.snowflake.com/en/sql-reference/data-types-datetime>`_
+
     Example::
         >>> from snowflake.snowpark import Row
-        >>> df = session.create_dataframe([Row([1, 2]), Row([1, 3])], schema=["a"])
-        >>> df.select(array_contains(lit(2), "a").alias("result")).show()
+        >>> df = session.create_dataframe([Row(["apple", "banana"]), Row(["apple", "orange"])], schema=["a"])
+        >>> df.select(array_contains(lit("banana").cast("variant"), "a").alias("result")).show()
         ------------
         |"RESULT"  |
         ------------
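
A hedged usage sketch of the documented cast requirement (assuming an active Snowpark `session`; `to_variant` is the functional spelling of `.cast("variant")`):

    from snowflake.snowpark import Row
    from snowflake.snowpark.functions import array_contains, col, to_variant

    df = session.create_dataframe(
        [Row("banana", ["apple", "banana"]), Row("kiwi", ["apple", "orange"])],
        schema=["needle", "arr"],
    )
    # String values must be cast to VARIANT before searching a semi-structured
    # array; without the cast the SQL comparison fails to type-check.
    df.select(array_contains(to_variant(col("needle")), col("arr")).alias("result")).show()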

src/snowflake/snowpark/modin/plugin/_internal/frame.py

Lines changed: 7 additions & 1 deletion
@@ -1144,6 +1144,8 @@ def update_snowflake_quoted_identifiers_with_expressions(
         self,
         quoted_identifier_to_column_map: dict[str, SnowparkColumn],
         snowpark_pandas_types: Optional[list[Optional[SnowparkPandasType]]] = None,
+        *,
+        new_index_column_pandas_labels: Optional[list[Hashable]] = None,
     ) -> UpdatedInternalFrameResult:
         """
         Points Snowflake quoted identifiers to column expression given by `quoted_identifier_to_column_map`.
@@ -1171,6 +1173,8 @@ def update_snowflake_quoted_identifiers_with_expressions(
                 must be index columns and data columns in the original internal frame.
             data_column_snowpark_pandas_types: The optional Snowpark pandas types for the new
                 expressions, in the order of the keys of quoted_identifier_to_column_map.
+            new_index_column_pandas_labels: The optional list of labels to be used as
+                index_column_pandas_labels for the result.

         Returns:
             UpdatedInternalFrameResult: A tuple containing the new InternalFrame with updated column references, and a mapping
@@ -1252,7 +1256,9 @@ def update_snowflake_quoted_identifiers_with_expressions(
             data_column_pandas_labels=self.data_column_pandas_labels,
             data_column_snowflake_quoted_identifiers=new_data_column_snowflake_quoted_identifiers,
             data_column_pandas_index_names=self.data_column_pandas_index_names,
-            index_column_pandas_labels=self.index_column_pandas_labels,
+            index_column_pandas_labels=self.index_column_pandas_labels
+            if new_index_column_pandas_labels is None
+            else new_index_column_pandas_labels,
             index_column_snowflake_quoted_identifiers=new_index_column_snowflake_quoted_identifiers,
             data_column_types=[
                 new_type_mapping[k]
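
The keyword-only parameter lets a caller rewrite column expressions and rename the index labels in one pass; a minimal hypothetical sketch (the literal and labels here are illustrative — the 1x1 transpose in the query compiler below is the real call site):

    # `frame` is an InternalFrame; pandas_lit is the plugin's internal helper
    # that wraps a Python value as a literal Snowpark column. The result is an
    # UpdatedInternalFrameResult: (new frame, old-to-new identifier mapping).
    new_frame, identifier_map = frame.update_snowflake_quoted_identifiers_with_expressions(
        {frame.index_column_snowflake_quoted_identifiers[0]: pandas_lit("row_label")},
        # Keyword-only: replace index_column_pandas_labels on the result.
        new_index_column_pandas_labels=["new_index_name"],
    )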

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 55 additions & 8 deletions
@@ -5548,7 +5548,7 @@ def _groupby_first_last(
         return result

     @register_query_compiler_method_not_implemented(
-        "DataFrameGroupBy",
+        ["DataFrameGroupBy", "SeriesGroupBy"],
         "first",
         UnsupportedArgsRule(
             unsupported_conditions=[
@@ -5594,7 +5594,7 @@ def groupby_first(
         )

     @register_query_compiler_method_not_implemented(
-        "DataFrameGroupBy",
+        ["DataFrameGroupBy", "SeriesGroupBy"],
         "last",
         UnsupportedArgsRule(
             unsupported_conditions=[
@@ -5640,7 +5640,7 @@ def groupby_last(
         )

     @register_query_compiler_method_not_implemented(
-        "DataFrameGroupBy",
+        ["DataFrameGroupBy", "SeriesGroupBy"],
         "rank",
         UnsupportedArgsRule(
             unsupported_conditions=[
@@ -6102,7 +6102,7 @@ def groupby_rolling(
         return result_qc

     @register_query_compiler_method_not_implemented(
-        "DataFrameGroupBy",
+        ["DataFrameGroupBy", "SeriesGroupBy"],
         "shift",
         UnsupportedArgsRule(
             unsupported_conditions=[
@@ -7107,7 +7107,7 @@ def groupby_value_counts(
         )

     @register_query_compiler_method_not_implemented(
-        "DataFrameGroupBy",
+        ["DataFrameGroupBy", "SeriesGroupBy"],
         "fillna",
         UnsupportedArgsRule(
             unsupported_conditions=[
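
Each decorator now registers its UnsupportedArgsRule for both DataFrameGroupBy and SeriesGroupBy, so the same unsupported-argument conditions cover Series groupbys. A hedged sketch of the user-visible effect (assumes an active session; `min_count>1` is unsupported per the CHANGELOG entry above):

    import modin.pandas as pd
    import snowflake.snowpark.modin.plugin  # noqa: F401  # registers the Snowpark pandas backend

    ser = pd.Series([1.0, None, 3.0], index=["g1", "g1", "g2"])
    # Previously only the DataFrame variant was covered by the rule; the Series
    # variant now hits it too (switching to native pandas in hybrid execution,
    # or raising NotImplementedError otherwise).
    ser.groupby(level=0).first(min_count=2)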
@@ -11875,10 +11875,54 @@ def transpose_single_row(self) -> "SnowflakeQueryCompiler":
         self._raise_not_implemented_error_for_timedelta()

         frame = self._modin_frame
-
+        input_column_count = len(frame.data_columns_index)
         # Handle case where the dataframe has empty columns.
-        if len(frame.data_columns_index) == 0:
+        if input_column_count == 0:
             return transpose_empty_df(frame)
+        if input_column_count == 1:
+            # If the frame is 1x1, then the datatype is already preserved; we need only set the entry
+            # in the index columns to match the original index labels.
+            if len(frame.data_column_index_names) > 1:
+                # If the columns object has a multi-index name, we need to project new columns for
+                # the extra labels.
+                data_odf = frame.ordered_dataframe.select(
+                    frame.data_column_snowflake_quoted_identifiers
+                )
+                new_index_column_identifiers = (
+                    data_odf.generate_snowflake_quoted_identifiers(
+                        pandas_labels=frame.data_column_pandas_index_names
+                    )
+                )
+                new_odf = append_columns(
+                    data_odf,
+                    new_index_column_identifiers,
+                    list(map(pandas_lit, frame.data_column_pandas_labels[0])),
+                )
+                new_odf.row_count = 1
+                return SnowflakeQueryCompiler(
+                    InternalFrame.create(
+                        ordered_dataframe=new_odf,
+                        data_column_pandas_labels=[None],
+                        data_column_pandas_index_names=[None],
+                        data_column_snowflake_quoted_identifiers=frame.data_column_snowflake_quoted_identifiers,
+                        index_column_pandas_labels=frame.data_column_pandas_index_names,
+                        index_column_snowflake_quoted_identifiers=new_index_column_identifiers,
+                        data_column_types=frame.cached_data_column_snowpark_pandas_types,
+                        index_column_types=None,
+                    )
+                )
+            else:
+                return SnowflakeQueryCompiler(
+                    frame.update_snowflake_quoted_identifiers_with_expressions(
+                        {
+                            frame.index_column_snowflake_quoted_identifiers[
+                                0
+                            ]: pandas_lit(frame.data_column_pandas_labels[0]),
+                        },
+                        # Swap the name of the index/columns objects
+                        new_index_column_pandas_labels=frame.data_column_pandas_index_names,
+                    )[0]
+                ).set_columns([None])

         # This follows the same approach used in SnowflakeQueryCompiler.transpose().
         # However, as an optimization, only steps (1), (2), and (4) from the four steps described in
@@ -11909,6 +11953,7 @@ def transpose_single_row(self) -> "SnowflakeQueryCompiler":
             unpivot_result.variable_name_quoted_snowflake_identifier,
             unpivot_result.object_name_quoted_snowflake_identifier,
         )
+        new_internal_frame.ordered_dataframe.row_count = input_column_count

         return SnowflakeQueryCompiler(new_internal_frame)

@@ -11922,8 +11967,9 @@ def transpose(self) -> "SnowflakeQueryCompiler":
         """
         frame = self._modin_frame

+        original_col_count = len(frame.data_columns_index)
         # Handle case where the dataframe has empty columns.
-        if len(frame.data_columns_index) == 0:
+        if original_col_count == 0:
             return transpose_empty_df(frame)

         # The following approach to implementing transpose relies on combining unpivot and pivot operations to flip
@@ -12061,6 +12107,7 @@ def transpose(self) -> "SnowflakeQueryCompiler":
             unpivot_result.variable_name_quoted_snowflake_identifier,
             unpivot_result.object_name_quoted_snowflake_identifier,
         )
+        new_internal_frame.ordered_dataframe.row_count = original_col_count

         return SnowflakeQueryCompiler(new_internal_frame)
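
The 1x1 fast path skips the unpivot/pivot round trip entirely, and both transpose variants now record the result's row count up front (a transposed frame has as many rows as the input had columns), sparing a later COUNT query. A hedged sketch of the user-visible behavior (assumes an active session):

    import modin.pandas as pd
    import snowflake.snowpark.modin.plugin  # noqa: F401  # registers the Snowpark pandas backend

    # A 1x1 frame: transpose now only swaps the index and columns labels and
    # keeps the single cell's datatype, instead of unpivoting and re-pivoting.
    df = pd.DataFrame({"a": [1]}, index=["x"])
    print(df.T)  # index becomes ["a"], columns become ["x"]
    # The same path underlies 1-column aggregations such as df.agg("count"),
    # which transpose a single-row intermediate result (see the new test below).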

src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py

Lines changed: 22 additions & 0 deletions
@@ -59,6 +59,12 @@
 )
 from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
     HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS,
+    UnsupportedArgsRule,
+    _GROUPBY_UNSUPPORTED_GROUPING_MESSAGE,
+    register_query_compiler_method_not_implemented,
+)
+from snowflake.snowpark.modin.plugin._internal.groupby_utils import (
+    check_is_groupby_supported_by_snowflake,
 )
 from snowflake.snowpark.modin.plugin._typing import DropKeep, ListLike
 from snowflake.snowpark.modin.plugin.extensions.snow_partition_iterator import (
@@ -1549,6 +1555,22 @@ def fillna(

 # Snowpark pandas defines a custom GroupBy object
 @register_series_accessor("groupby")
+@register_query_compiler_method_not_implemented(
+    "Series",
+    "groupby",
+    UnsupportedArgsRule(
+        unsupported_conditions=[
+            (
+                lambda args: not check_is_groupby_supported_by_snowflake(
+                    args.get("by"),
+                    args.get("level"),
+                    args.get("axis", 0),
+                ),
+                f"Groupby {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}",
+            )
+        ]
+    ),
+)
 def groupby(
     self,
     by=None,
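
This routes Series.groupby through the same supported-grouping check already used on the DataFrame side; the updated test in test_groupby_default2pandas.py below exercises exactly this path. A hedged sketch (assumes an active session):

    import numpy as np
    import modin.pandas as pd
    import snowflake.snowpark.modin.plugin  # noqa: F401

    ser = pd.Series([10, 20, 30])
    # Grouping by a raw numpy array is one of the non-pandas-hashable-label
    # cases rejected by check_is_groupby_supported_by_snowflake, so this raises
    # NotImplementedError carrying _GROUPBY_UNSUPPORTED_GROUPING_MESSAGE.
    ser.groupby(by=np.array([0, 0, 1])).max()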

tests/integ/modin/frame/test_aggregate.py

Lines changed: 13 additions & 1 deletion
@@ -193,7 +193,9 @@ def test_string_sum_with_nulls():
     with pytest.raises(TypeError):
         pandas_df.sum(numeric_only=False)
     snow_result = snow_df.sum(numeric_only=False)
-    assert_series_equal(snow_result.to_pandas(), native_pd.Series(["ab"]))
+    assert_series_equal(
+        snow_result.to_pandas(), native_pd.Series(["ab"]), check_index_type=False
+    )


 class TestTimedelta:
@@ -628,6 +630,16 @@ def test_agg_with_multiindex(native_df_multiindex, func, expected_union_count):
     eval_snowpark_pandas_result(snow_df, native_df_multiindex, func)


+def test_agg_with_one_column_multiindex(native_df_multiindex):
+    # Triggers the special 1x1 transpose code path
+    native_df_multiindex = native_df_multiindex.iloc[:, 0:1]
+    snow_df = pd.DataFrame(native_df_multiindex)
+    with SqlCounter(query_count=1):
+        eval_snowpark_pandas_result(
+            snow_df, native_df_multiindex, lambda df: df.agg("count")
+        )
+
+
 @pytest.mark.parametrize(
     "func",
     [

tests/integ/modin/frame/test_iloc.py

Lines changed: 5 additions & 18 deletions
@@ -1077,21 +1077,9 @@ def iloc_helper(df):
         else:
             return native_pd.Series([]) if axis == "row" else df.iloc[:, []]

-    def determine_query_count():
-        # Multiple queries because of squeeze() - in range is 2, out-of-bounds is 1.
-        if axis == "col":
-            num_queries = 1
-        else:
-            if not -8 < key < 7:  # key is out of bound
-                num_queries = 2
-            else:
-                num_queries = 1
-        return num_queries
-
-    query_count = determine_query_count()
     # test df with default index
     num_cols = 7
-    with SqlCounter(query_count=query_count):
+    with SqlCounter(query_count=1):
         eval_snowpark_pandas_result(
             default_index_snowpark_pandas_df,
             default_index_native_df,
@@ -1101,21 +1089,20 @@ def determine_query_count():

     # test df with non-default index
     num_cols = 6  # set_index() makes the number of columns 6
-    with SqlCounter(query_count=query_count):
+    with SqlCounter(query_count=1):
         eval_snowpark_pandas_result(
             default_index_snowpark_pandas_df.set_index("D"),
             default_index_native_df.set_index("D"),
             iloc_helper,
             test_attrs=False,
         )

-    query_count = determine_query_count()
     # test df with MultiIndex
     # Index dtype is different between Snowpark and native pandas if key produces empty df.
     num_cols = 7
     native_df = default_index_native_df.set_index(multiindex_native)
     snowpark_df = pd.DataFrame(native_df)
-    with SqlCounter(query_count=query_count):
+    with SqlCounter(query_count=1):
         eval_snowpark_pandas_result(
             snowpark_df,
             native_df,
@@ -1129,7 +1116,7 @@ def determine_query_count():
         native_df_with_multiindex_columns
     )
     in_range = True if (-8 < key < 7) else False
-    with SqlCounter(query_count=query_count):
+    with SqlCounter(query_count=1):
         if axis == "row" or in_range:  # series result
             eval_snowpark_pandas_result(
                 snowpark_df_with_multiindex_columns,
@@ -1151,7 +1138,7 @@ def determine_query_count():
     # test df with MultiIndex on both index and columns
     native_df = native_df_with_multiindex_columns.set_index(multiindex_native)
     snowpark_df = pd.DataFrame(native_df)
-    with SqlCounter(query_count=query_count):
+    with SqlCounter(query_count=1):
         if axis == "row" or in_range:  # series result
             eval_snowpark_pandas_result(
                 snowpark_df,

tests/integ/modin/frame/test_squeeze.py

Lines changed: 1 addition & 5 deletions
@@ -31,11 +31,7 @@ def test_n_by_1(axis, dtype):


 @pytest.mark.parametrize("dtype", ["int", "timedelta64[ns]"])
 def test_1_by_n(axis, dtype):
-    if axis is None:
-        expected_query_count = 2
-    else:
-        expected_query_count = 1
-    with SqlCounter(query_count=expected_query_count):
+    with SqlCounter(query_count=1):
         eval_snowpark_pandas_result(
             *create_test_dfs({"a": [1], "b": [2], "c": [3]}, dtype=dtype),
             lambda df: df.squeeze(axis=axis),

tests/integ/modin/groupby/test_groupby_default2pandas.py

Lines changed: 1 addition & 1 deletion
@@ -130,7 +130,7 @@ def test_groupby_with_numpy_array(basic_snowpark_pandas_df) -> None:
 @sql_count_checker(query_count=0)
 def test_groupby_series_with_numpy_array(native_series_multi_numeric, by_list) -> None:
     with pytest.raises(
-        NotImplementedError, match=GROUPBY_UNSUPPORTED_GROUPING_ERROR_PATTERN
+        NotImplementedError, match=_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE
     ):
         pd.Series(native_series_multi_numeric).groupby(by=by_list).max()

tests/integ/modin/groupby/test_groupby_rolling.py

Lines changed: 3 additions & 3 deletions
@@ -102,18 +102,18 @@ def test_groupby_rolling_dropna_false():
     )


-@sql_count_checker(query_count=1)
+@sql_count_checker(query_count=0)
 def test_groupby_rolling_series_negative():
     date_idx = pd.date_range("1/1/2000", periods=8, freq="min")
     date_idx.names = ["grp_col"]
     snow_ser = pd.Series([1, 1, np.nan, 2])
     with pytest.raises(
         NotImplementedError,
         match=re.escape(
-            "Groupby does not yet support axis == 1, by != None and level != None, or by containing any non-pandas hashable labels"
+            "Snowpark pandas does not yet support the method GroupBy.rolling for Series"
         ),
     ):
-        snow_ser.groupby(snow_ser.index).rolling(2).sum()
+        snow_ser.groupby(level=0).rolling(2).sum()


 @pytest.mark.parametrize(
