Skip to content

Commit 12e1f1b

Browse files
SNOW-243798: Add support for agg/min/max/count/sum/mean/median/std/var in faster pandas
1 parent 04218cc commit 12e1f1b

File tree

2 files changed

+127
-0
lines changed

2 files changed

+127
-0
lines changed

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2314,6 +2314,20 @@ def cache_result(self) -> "SnowflakeQueryCompiler":
23142314

23152315
@snowpark_pandas_type_immutable_check
def set_columns(self, new_pandas_labels: Axes) -> "SnowflakeQueryCompiler":
    """
    Faster-pandas wrapper around ``_set_columns_internal``.

    Applies ``_set_columns_internal`` to the relaxed query compiler (when
    one is attached to this compiler) and to this compiler itself, then
    hands both results to ``_maybe_set_relaxed_qc``.

    Args:
        new_pandas_labels: the new pandas column labels.

    Returns:
        The value returned by ``_maybe_set_relaxed_qc``.
    """
    relaxed_source = self._relaxed_query_compiler
    # The relaxed compiler is processed first, mirroring the other wrappers.
    relaxed_result = (
        relaxed_source._set_columns_internal(new_pandas_labels=new_pandas_labels)
        if relaxed_source is not None
        else None
    )
    strict_result = self._set_columns_internal(new_pandas_labels=new_pandas_labels)
    return self._maybe_set_relaxed_qc(strict_result, relaxed_result)
2327+
2328+
def _set_columns_internal(
2329+
self, new_pandas_labels: Axes
2330+
) -> "SnowflakeQueryCompiler":
23172331
"""
23182332
Set pandas column labels with the new column labels
23192333

@@ -7103,6 +7117,33 @@ def agg(
71037117
axis: int,
71047118
args: Any,
71057119
kwargs: dict[str, Any],
7120+
) -> "SnowflakeQueryCompiler":
7121+
"""
7122+
Wrapper around _agg_internal to be supported in faster pandas.
7123+
"""
7124+
relaxed_query_compiler = None
7125+
if self._relaxed_query_compiler is not None:
7126+
relaxed_query_compiler = self._relaxed_query_compiler._agg_internal(
7127+
func=func,
7128+
axis=axis,
7129+
args=args,
7130+
kwargs=kwargs,
7131+
)
7132+
qc = self._agg_internal(
7133+
func=func,
7134+
axis=axis,
7135+
args=args,
7136+
kwargs=kwargs,
7137+
)
7138+
qc = self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)
7139+
return qc
7140+
7141+
def _agg_internal(
7142+
self,
7143+
func: AggFuncType,
7144+
axis: int,
7145+
args: Any,
7146+
kwargs: dict[str, Any],
71067147
) -> "SnowflakeQueryCompiler":
71077148
"""
71087149
Aggregate using one or more operations over the specified axis.

tests/integ/modin/test_faster_pandas.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,61 @@ def test_read_filter_join_flag_disabled(session):
194194
assert_frame_equal(snow_result, native_result)
195195

196196

197+
@pytest.mark.parametrize(
    "func",
    ["min", "max", "count", "sum", "mean", "median", "std", "var"],
)
@sql_count_checker(query_count=6)
def test_agg(session, func):
    """
    Check that named aggregations keep the relaxed query compiler attached.

    Each aggregation is run four ways (DataFrame method, ``DataFrame.agg``,
    Series method, ``Series.agg``) and compared against native pandas.
    """
    # create tables
    table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
    source_data = native_pd.DataFrame(
        [[2, 12], [1, 11], [3, 13]], columns=["A", "B"]
    )
    session.create_dataframe(source_data).write.save_as_table(
        table_name, table_type="temp"
    )

    # create snow dataframes
    df = pd.read_snowflake(table_name)
    snow_result1 = getattr(df, func)()
    snow_result2 = df.agg([func])
    snow_result3 = getattr(df["B"], func)()
    snow_result4 = df["B"].agg([func])

    # verify that the input dataframe has a populated relaxed query compiler
    input_relaxed_qc = df._query_compiler._relaxed_query_compiler
    assert input_relaxed_qc is not None
    assert input_relaxed_qc._dummy_row_pos_mode is True
    # verify that the output dataframe also has a populated relaxed query compiler
    for snow_result in (snow_result1, snow_result2):
        output_relaxed_qc = snow_result._query_compiler._relaxed_query_compiler
        assert output_relaxed_qc is not None
        assert output_relaxed_qc._dummy_row_pos_mode is True

    # create pandas dataframes
    native_df = df.to_pandas()
    native_column = native_df["B"]
    native_result1 = getattr(native_df, func)()
    native_result2 = native_df.agg([func])
    native_result3 = getattr(native_column, func)()
    native_result4 = native_column.agg([func])

    # compare results
    assert_series_equal(snow_result1, native_result1, check_dtype=False)
    assert_frame_equal(snow_result2, native_result2, check_dtype=False)
    assert snow_result3 == native_result3
    assert_series_equal(snow_result4, native_result4, check_dtype=False)
250+
251+
197252
@sql_count_checker(query_count=3)
198253
def test_drop(session):
199254
# create tables
@@ -604,6 +659,37 @@ def test_set_2d_labels_from_different_df(session, input_df2):
604659
assert_frame_equal(snow_result, native_result)
605660

606661

662+
@sql_count_checker(query_count=3)
def test_set_columns(session):
    """
    Check that assigning ``columns`` keeps the relaxed query compiler attached.

    The label assignment happens in place and ``snow_result`` is the same
    object as ``df``, so the pandas frame materialised afterwards already
    carries the new labels. The labels are therefore asserted explicitly;
    the final frame comparison alone could not detect a rename that failed.
    """
    # create tables
    table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
    session.create_dataframe(
        native_pd.DataFrame([[2, 12], [1, 11], [3, 13]], columns=["A", "B"])
    ).write.save_as_table(table_name, table_type="temp")

    # create snow dataframes
    new_labels = ["X", "Y"]
    df = pd.read_snowflake(table_name)
    snow_result = df  # alias: the assignment below mutates df as well
    snow_result.columns = new_labels

    # verify that the input dataframe has a populated relaxed query compiler
    assert df._query_compiler._relaxed_query_compiler is not None
    assert df._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
    # verify that the output dataframe also has a populated relaxed query compiler
    assert snow_result._query_compiler._relaxed_query_compiler is not None
    assert (
        snow_result._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
    )

    # create pandas dataframes
    native_df = df.to_pandas()
    # The rename must already be visible in the materialised data. This is a
    # check on a native pandas object, so it issues no additional SQL query.
    assert list(native_df.columns) == new_labels
    native_result = native_df
    native_result.columns = ["X", "Y"]

    # compare results
    assert_frame_equal(snow_result, native_result)
691+
692+
607693
@sql_count_checker(query_count=3)
608694
def test_dataframe_to_datetime(session):
609695
# create tables

0 commit comments

Comments
 (0)