Skip to content

Commit 2437edc

Browse files
SNOW-243798: Add support for agg/min/max/count/sum/mean/median/std/var and set_columns in faster pandas (#3902)
1 parent 7aa9bcf commit 2437edc

File tree

4 files changed

+141
-3
lines changed

4 files changed

+141
-3
lines changed

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -122,6 +122,16 @@
122122
- `drop`
123123
- `invert`
124124
- `duplicated`
125+
- `columns` (e.g., df.columns = ["A", "B"])
126+
- `agg`
127+
- `min`
128+
- `max`
129+
- `count`
130+
- `sum`
131+
- `mean`
132+
- `median`
133+
- `std`
134+
- `var`
125135
- Reuse row count from the relaxed query compiler in `get_axis_len`.
126136

127137
#### Bug Fixes

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 41 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2314,6 +2314,20 @@ def cache_result(self) -> "SnowflakeQueryCompiler":
23142314

23152315
@snowpark_pandas_type_immutable_check
23162316
def set_columns(self, new_pandas_labels: Axes) -> "SnowflakeQueryCompiler":
2317+
"""
2318+
Wrapper around _set_columns_internal to be supported in faster pandas.
2319+
"""
2320+
relaxed_query_compiler = None
2321+
if self._relaxed_query_compiler is not None:
2322+
relaxed_query_compiler = self._relaxed_query_compiler._set_columns_internal(
2323+
new_pandas_labels=new_pandas_labels
2324+
)
2325+
qc = self._set_columns_internal(new_pandas_labels=new_pandas_labels)
2326+
return self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)
2327+
2328+
def _set_columns_internal(
2329+
self, new_pandas_labels: Axes
2330+
) -> "SnowflakeQueryCompiler":
23172331
"""
23182332
Set pandas column labels with the new column labels
23192333

@@ -7103,6 +7117,33 @@ def agg(
71037117
axis: int,
71047118
args: Any,
71057119
kwargs: dict[str, Any],
7120+
) -> "SnowflakeQueryCompiler":
7121+
"""
7122+
Wrapper around _agg_internal to be supported in faster pandas.
7123+
"""
7124+
relaxed_query_compiler = None
7125+
if self._relaxed_query_compiler is not None:
7126+
relaxed_query_compiler = self._relaxed_query_compiler._agg_internal(
7127+
func=func,
7128+
axis=axis,
7129+
args=args,
7130+
kwargs=kwargs,
7131+
)
7132+
qc = self._agg_internal(
7133+
func=func,
7134+
axis=axis,
7135+
args=args,
7136+
kwargs=kwargs,
7137+
)
7138+
qc = self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)
7139+
return qc
7140+
7141+
def _agg_internal(
7142+
self,
7143+
func: AggFuncType,
7144+
axis: int,
7145+
args: Any,
7146+
kwargs: dict[str, Any],
71067147
) -> "SnowflakeQueryCompiler":
71077148
"""
71087149
Aggregate using one or more operations over the specified axis.

src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2642,7 +2642,7 @@ def seconds():
26422642
0 1
26432643
1 2
26442644
2 3
2645-
dtype: int64
2645+
dtype: int8
26462646
26472647
For TimedeltaIndex:
26482648
@@ -2702,7 +2702,7 @@ def microseconds():
27022702
0 1
27032703
1 2
27042704
2 3
2705-
dtype: int64
2705+
dtype: int8
27062706
27072707
For TimedeltaIndex:
27082708
@@ -2734,7 +2734,7 @@ def nanoseconds():
27342734
0 1
27352735
1 2
27362736
2 3
2737-
dtype: int64
2737+
dtype: int8
27382738
27392739
For TimedeltaIndex:
27402740

tests/integ/modin/test_faster_pandas.py

Lines changed: 87 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
import modin.pandas as pd
88
import pandas as native_pd
99
import pytest
10+
from pandas._testing import assert_almost_equal
1011

1112
from snowflake.snowpark._internal.utils import TempObjectType
1213
import snowflake.snowpark.modin.plugin # noqa: F401
@@ -194,6 +195,61 @@ def test_read_filter_join_flag_disabled(session):
194195
assert_frame_equal(snow_result, native_result)
195196

196197

198+
@pytest.mark.parametrize(
199+
"func",
200+
[
201+
"min",
202+
"max",
203+
"count",
204+
"sum",
205+
"mean",
206+
"median",
207+
"std",
208+
"var",
209+
],
210+
)
211+
@sql_count_checker(query_count=6)
212+
def test_agg(session, func):
213+
# create tables
214+
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
215+
session.create_dataframe(
216+
native_pd.DataFrame([[2, 12], [1, 11], [3, 13]], columns=["A", "B"])
217+
).write.save_as_table(table_name, table_type="temp")
218+
219+
# create snow dataframes
220+
df = pd.read_snowflake(table_name)
221+
snow_result1 = getattr(df, func)()
222+
snow_result2 = df.agg([func])
223+
snow_result3 = getattr(df["B"], func)()
224+
snow_result4 = df["B"].agg([func])
225+
226+
# verify that the input dataframe has a populated relaxed query compiler
227+
assert df._query_compiler._relaxed_query_compiler is not None
228+
assert df._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
229+
# verify that the output dataframe also has a populated relaxed query compiler
230+
assert snow_result1._query_compiler._relaxed_query_compiler is not None
231+
assert (
232+
snow_result1._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
233+
)
234+
assert snow_result2._query_compiler._relaxed_query_compiler is not None
235+
assert (
236+
snow_result2._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
237+
)
238+
239+
# create pandas dataframes
240+
native_df = df.to_pandas()
241+
native_result1 = getattr(native_df, func)()
242+
native_result2 = native_df.agg([func])
243+
native_result3 = getattr(native_df["B"], func)()
244+
native_result4 = native_df["B"].agg([func])
245+
246+
# compare results
247+
assert_series_equal(snow_result1, native_result1, check_dtype=False)
248+
assert_frame_equal(snow_result2, native_result2, check_dtype=False)
249+
assert_almost_equal(snow_result3, native_result3)
250+
assert_series_equal(snow_result4, native_result4, check_dtype=False)
251+
252+
197253
@sql_count_checker(query_count=3)
198254
def test_drop(session):
199255
# create tables
@@ -662,6 +718,37 @@ def test_set_2d_labels_from_different_df(session, input_df2):
662718
assert_frame_equal(snow_result, native_result)
663719

664720

721+
@sql_count_checker(query_count=3)
722+
def test_set_columns(session):
723+
# create tables
724+
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
725+
session.create_dataframe(
726+
native_pd.DataFrame([[2, 12], [1, 11], [3, 13]], columns=["A", "B"])
727+
).write.save_as_table(table_name, table_type="temp")
728+
729+
# create snow dataframes
730+
df = pd.read_snowflake(table_name)
731+
snow_result = df
732+
snow_result.columns = ["X", "Y"]
733+
734+
# verify that the input dataframe has a populated relaxed query compiler
735+
assert df._query_compiler._relaxed_query_compiler is not None
736+
assert df._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
737+
# verify that the output dataframe also has a populated relaxed query compiler
738+
assert snow_result._query_compiler._relaxed_query_compiler is not None
739+
assert (
740+
snow_result._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
741+
)
742+
743+
# create pandas dataframes
744+
native_df = df.to_pandas()
745+
native_result = native_df
746+
native_result.columns = ["X", "Y"]
747+
748+
# compare results
749+
assert_frame_equal(snow_result, native_result)
750+
751+
665752
@sql_count_checker(query_count=3)
666753
def test_dataframe_to_datetime(session):
667754
# create tables

0 commit comments

Comments
 (0)