Skip to content

Commit 81309c5

Browse files
Merge branch 'main' into helmeleegy-SNOW-2435290
2 parents f9a6075 + 1d779c2 commit 81309c5

File tree

3 files changed

+75
-0
lines changed

3 files changed

+75
-0
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@
122122
- `drop`
123123
- `invert`
124124
- `duplicated`
125+
- `iloc`
126+
- `head`
125127
- `columns` (e.g., df.columns = ["A", "B"])
126128
- `agg`
127129
- `min`

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10806,6 +10806,36 @@ def take_2d_positional(
1080610806
self,
1080710807
index: Union["SnowflakeQueryCompiler", slice],
1080810808
columns: Union["SnowflakeQueryCompiler", slice, int, bool, list, AnyArrayLike],
10809+
) -> "SnowflakeQueryCompiler":
10810+
"""
10811+
Wrapper around _take_2d_positional_internal to be supported in faster pandas.
10812+
"""
10813+
relaxed_query_compiler = None
10814+
if self._relaxed_query_compiler is not None and (
10815+
index == slice(None, None, None)
10816+
or (
10817+
isinstance(index, slice)
10818+
and (index.start is None or index.start == 0)
10819+
and (index.step is None or index.step == 1)
10820+
)
10821+
):
10822+
relaxed_query_compiler = (
10823+
self._relaxed_query_compiler._take_2d_positional_internal(
10824+
index=index,
10825+
columns=columns,
10826+
)
10827+
)
10828+
10829+
qc = self._take_2d_positional_internal(
10830+
index=index,
10831+
columns=columns,
10832+
)
10833+
return self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)
10834+
10835+
def _take_2d_positional_internal(
10836+
self,
10837+
index: Union["SnowflakeQueryCompiler", slice],
10838+
columns: Union["SnowflakeQueryCompiler", slice, int, bool, list, AnyArrayLike],
1080910839
) -> "SnowflakeQueryCompiler":
1081010840
"""
1081110841
Index QueryCompiler with passed keys.

tests/integ/modin/test_faster_pandas.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,49 @@ def test_groupby_agg(session, func):
363363
assert_frame_equal(snow_result4, native_result4, check_dtype=False)
364364

365365

366+
@sql_count_checker(query_count=5)
367+
def test_iloc_head(session):
368+
# create tables
369+
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
370+
session.create_dataframe(
371+
native_pd.DataFrame([[1, 11], [2, 12], [3, 13]], columns=["A", "B"])
372+
).write.save_as_table(table_name, table_type="temp")
373+
374+
# create snow dataframes
375+
df = pd.read_snowflake(table_name)
376+
snow_result1 = df.iloc[:, [1]]
377+
snow_result2 = df.iloc[0:2:1, [1]]
378+
snow_result3 = df.head()
379+
380+
# verify that the input dataframe has a populated relaxed query compiler
381+
assert df._query_compiler._relaxed_query_compiler is not None
382+
assert df._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
383+
# verify that the output dataframe also has a populated relaxed query compiler
384+
assert snow_result1._query_compiler._relaxed_query_compiler is not None
385+
assert (
386+
snow_result1._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
387+
)
388+
assert snow_result2._query_compiler._relaxed_query_compiler is not None
389+
assert (
390+
snow_result2._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
391+
)
392+
assert snow_result3._query_compiler._relaxed_query_compiler is not None
393+
assert (
394+
snow_result3._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
395+
)
396+
397+
# create pandas dataframes
398+
native_df = df.to_pandas()
399+
native_result1 = native_df.iloc[:, [1]]
400+
native_result2 = native_df.iloc[0:2:1, [1]]
401+
native_result3 = native_df.head()
402+
403+
# compare results
404+
assert_frame_equal(snow_result1, native_result1)
405+
assert_frame_equal(snow_result2, native_result2)
406+
assert_frame_equal(snow_result3, native_result3)
407+
408+
366409
@sql_count_checker(query_count=3)
367410
def test_invert(session):
368411
# create tables

0 commit comments

Comments
 (0)