Skip to content

Commit 6def605

Browse files
SNOW-2432963: Add support for duplicated in faster pandas
1 parent 04218cc commit 6def605

File tree

3 files changed

+50
-0
lines changed

3 files changed

+50
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@
112112
- `loc` (setting columns)
113113
- `to_datetime`
114114
- `drop`
115+
- `duplicated`
115116
- Reuse row count from the relaxed query compiler in `get_axis_len`.
116117

117118
#### Bug Fixes

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17024,6 +17024,26 @@ def duplicated(
1702417024
self,
1702517025
subset: Union[Hashable, Sequence[Hashable]] = None,
1702617026
keep: DropKeep = "first",
17027+
) -> "SnowflakeQueryCompiler":
17028+
"""
17029+
Wrapper around _duplicated_internal to be supported in faster pandas.
17030+
"""
17031+
relaxed_query_compiler = None
17032+
if self._relaxed_query_compiler is not None:
17033+
relaxed_query_compiler = self._relaxed_query_compiler._duplicated_internal(
17034+
subset=subset,
17035+
keep=keep,
17036+
)
17037+
qc = self._duplicated_internal(
17038+
subset=subset,
17039+
keep=keep,
17040+
)
17041+
return self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)
17042+
17043+
def _duplicated_internal(
17044+
self,
17045+
subset: Union[Hashable, Sequence[Hashable]] = None,
17046+
keep: DropKeep = "first",
1702717047
) -> "SnowflakeQueryCompiler":
1702817048
"""
1702917049
Return boolean Series denoting duplicate rows.

tests/integ/modin/test_faster_pandas.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,35 @@ def test_drop(session):
223223
assert_frame_equal(snow_result, native_result)
224224

225225

226+
@sql_count_checker(query_count=3, join_count=1)
227+
def test_duplicated(session, func):
228+
# create tables
229+
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
230+
session.create_dataframe(
231+
native_pd.DataFrame([[2, 12], [2, 12], [3, 13]], columns=["A", "B"])
232+
).write.save_as_table(table_name, table_type="temp")
233+
234+
# create snow dataframes
235+
df = pd.read_snowflake(table_name)
236+
snow_result = df.duplicated()
237+
238+
# verify that the input dataframe has a populated relaxed query compiler
239+
assert df._query_compiler._relaxed_query_compiler is not None
240+
assert df._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
241+
# verify that the output dataframe also has a populated relaxed query compiler
242+
assert snow_result._query_compiler._relaxed_query_compiler is not None
243+
assert (
244+
snow_result._query_compiler._relaxed_query_compiler._dummy_row_pos_mode is True
245+
)
246+
247+
# create pandas dataframes
248+
native_df = df.to_pandas()
249+
native_result = native_df.duplicated()
250+
251+
# compare results
252+
assert_series_equal(snow_result, native_result)
253+
254+
226255
@pytest.mark.parametrize("func", ["isna", "isnull", "notna", "notnull"])
227256
@sql_count_checker(query_count=3)
228257
def test_isna_notna(session, func):

0 commit comments

Comments
 (0)