Skip to content

Commit 1d779c2

Browse files
SNOW-2435494: Add support for iloc and head in faster pandas (#3910)
1 parent 2437edc commit 1d779c2

File tree

3 files changed

+75
-0
lines changed

3 files changed

+75
-0
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@
122122
- `drop`
123123
- `invert`
124124
- `duplicated`
125+
- `iloc`
126+
- `head`
125127
- `columns` (e.g., df.columns = ["A", "B"])
126128
- `agg`
127129
- `min`

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10762,6 +10762,36 @@ def take_2d_positional(
1076210762
self,
1076310763
index: Union["SnowflakeQueryCompiler", slice],
1076410764
columns: Union["SnowflakeQueryCompiler", slice, int, bool, list, AnyArrayLike],
10765+
) -> "SnowflakeQueryCompiler":
10766+
"""
10767+
Wrapper around _take_2d_positional_internal to be supported in faster pandas.
10768+
"""
10769+
relaxed_query_compiler = None
10770+
if self._relaxed_query_compiler is not None and (
10771+
index == slice(None, None, None)
10772+
or (
10773+
isinstance(index, slice)
10774+
and (index.start is None or index.start == 0)
10775+
and (index.step is None or index.step == 1)
10776+
)
10777+
):
10778+
relaxed_query_compiler = (
10779+
self._relaxed_query_compiler._take_2d_positional_internal(
10780+
index=index,
10781+
columns=columns,
10782+
)
10783+
)
10784+
10785+
qc = self._take_2d_positional_internal(
10786+
index=index,
10787+
columns=columns,
10788+
)
10789+
return self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)
10790+
10791+
def _take_2d_positional_internal(
10792+
self,
10793+
index: Union["SnowflakeQueryCompiler", slice],
10794+
columns: Union["SnowflakeQueryCompiler", slice, int, bool, list, AnyArrayLike],
1076510795
) -> "SnowflakeQueryCompiler":
1076610796
"""
1076710797
Index QueryCompiler with passed keys.

tests/integ/modin/test_faster_pandas.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,49 @@ def test_duplicated(session):
308308
assert_series_equal(snow_result, native_result)
309309

310310

311+
@sql_count_checker(query_count=5)
312+
def test_iloc_head(session):
313+
# create tables
314+
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
315+
session.create_dataframe(
316+
native_pd.DataFrame([[1, 11], [2, 12], [3, 13]], columns=["A", "B"])
317+
).write.save_as_table(table_name, table_type="temp")
318+
319+
# create snow dataframes
320+
df = pd.read_snowflake(table_name)
321+
snow_result1 = df.iloc[:, [1]]
322+
snow_result2 = df.iloc[0:2:1, [1]]
323+
snow_result3 = df.head()
324+
325+
# verify that the input dataframe has a populated relaxed query compiler
326+
assert df._query_compiler._relaxed_query_compiler is not None
327+
assert df._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
328+
# verify that the output dataframe also has a populated relaxed query compiler
329+
assert snow_result1._query_compiler._relaxed_query_compiler is not None
330+
assert (
331+
snow_result1._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
332+
)
333+
assert snow_result2._query_compiler._relaxed_query_compiler is not None
334+
assert (
335+
snow_result2._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
336+
)
337+
assert snow_result3._query_compiler._relaxed_query_compiler is not None
338+
assert (
339+
snow_result3._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
340+
)
341+
342+
# create pandas dataframes
343+
native_df = df.to_pandas()
344+
native_result1 = native_df.iloc[:, [1]]
345+
native_result2 = native_df.iloc[0:2:1, [1]]
346+
native_result3 = native_df.head()
347+
348+
# compare results
349+
assert_frame_equal(snow_result1, native_result1)
350+
assert_frame_equal(snow_result2, native_result2)
351+
assert_frame_equal(snow_result3, native_result3)
352+
353+
311354
@sql_count_checker(query_count=3)
312355
def test_invert(session):
313356
# create tables

0 commit comments

Comments
 (0)