Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

#### Improvements

- Enhanced autoswitching functionality from Snowflake to native Pandas for methods with unsupported argument combinations:
- Enhanced autoswitching functionality from Snowflake to native pandas for methods with unsupported argument combinations:
- `shift()` with `suffix` or non-integer `periods` parameters
- `sort_index()` with `axis=1` or `key` parameters
- `sort_values()` with `axis=1`
Expand All @@ -42,6 +42,7 @@

- Fixed a bug in `DataFrameGroupBy.agg` where func is a list of tuples used to set the names of the output columns.
- Fixed a bug where converting a modin datetime index with a timezone to a numpy array with `np.asarray` would cause a `TypeError`.
- Fixed a bug where `Series.isin` with a Series argument matched on index labels instead of checking membership in the argument's values regardless of index.

#### Improvements

Expand Down
54 changes: 48 additions & 6 deletions src/snowflake/snowpark/modin/plugin/_internal/isin_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,21 @@
SnowparkPandasType,
)
from snowflake.snowpark.modin.plugin._internal.type_utils import infer_series_type
from snowflake.snowpark.modin.plugin._internal.join_utils import join
from snowflake.snowpark.modin.plugin._internal.utils import (
append_columns,
generate_new_labels,
is_duplicate_free,
pandas_lit,
)
from snowflake.snowpark.modin.plugin._typing import ListLike
from snowflake.snowpark.types import DataType, DoubleType, VariantType, _IntegralType
from snowflake.snowpark.types import (
DataType,
DoubleType,
VariantType,
_IntegralType,
BooleanType,
)


def convert_values_to_list_of_literals_and_return_type(
Expand Down Expand Up @@ -120,6 +127,7 @@ def scalar_isin_expression(
def compute_isin_with_series(
frame: InternalFrame,
values_series: InternalFrame,
lhs_is_series: bool,
dummy_row_pos_mode: bool,
) -> InternalFrame:
"""
Expand All @@ -135,6 +143,36 @@ def compute_isin_with_series(
Returns:
InternalFrame
"""
# local import to avoid circular import
from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
SnowflakeQueryCompiler,
)

if lhs_is_series:
# If the LHS is a Series, directly compute distinct elements of the RHS, which will be used as
# the argument to ARRAY_CONTAINS for every element in the LHS. The only necessary join is
# between the original data column and the 1-element aggregated array column.
agg_label = generate_new_labels(
pandas_labels=["agg"], excluded=values_series.data_column_pandas_labels
)[0]
distinct_frame = (
SnowflakeQueryCompiler(values_series)
.agg("array_agg", 0, [], {})
._modin_frame
)
joined_frame = join(
frame, distinct_frame, how="inner", dummy_row_pos_mode=dummy_row_pos_mode
)[0]
assert len(joined_frame.data_column_snowflake_quoted_identifiers) == 2
return joined_frame.project_columns(
frame.data_column_pandas_labels,
column_objects=array_contains(
joined_frame.data_column_snowflake_quoted_identifiers[0],
joined_frame.data_column_snowflake_quoted_identifiers[1],
),
column_types=[SnowparkPandasType.to_pandas(BooleanType)],
)

# For each row in this dataframe
# align the index with the index of the values Series object.
# If it matches, return True, else False
Expand Down Expand Up @@ -178,18 +216,14 @@ def compute_isin_with_series(
}
).frame

# local import to avoid circular import
from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
SnowflakeQueryCompiler,
)

# return internal frame but remove temporary agg column.
return SnowflakeQueryCompiler(new_frame).drop(columns=[agg_label])._modin_frame


def compute_isin_with_dataframe(
frame: InternalFrame,
values_frame: InternalFrame,
lhs_is_series: bool,
dummy_row_pos_mode: bool,
) -> InternalFrame:
"""
Expand All @@ -205,6 +239,14 @@ def compute_isin_with_dataframe(
Returns:
InternalFrame
"""
if lhs_is_series:
# a series-DF isin operation always returns False at all positions
return frame.update_snowflake_quoted_identifiers_with_expressions(
{
quoted_identifier: pandas_lit(False)
for quoted_identifier in frame.data_column_snowflake_quoted_identifiers
}
)[0]
# similar logic to series, however do not create a single column but multiple columns
# set values via set_frame_2d_labels then

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14507,6 +14507,7 @@ def isin(
values: Union[
list[Any], np.ndarray, "SnowflakeQueryCompiler", dict[Hashable, ListLike]
],
self_is_series: bool = False,
) -> "SnowflakeQueryCompiler":
"""
Wrapper around _isin_internal to be supported in faster pandas.
Expand All @@ -14521,34 +14522,38 @@ def isin(
assert values._relaxed_query_compiler is not None
new_values = values._relaxed_query_compiler
relaxed_query_compiler = self._relaxed_query_compiler._isin_internal(
values=new_values
values=new_values,
self_is_series=self_is_series,
)

qc = self._isin_internal(values=values)
qc = self._isin_internal(values=values, self_is_series=self_is_series)
return self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)

def _isin_internal(
self,
values: Union[
list[Any], np.ndarray, "SnowflakeQueryCompiler", dict[Hashable, ListLike]
],
self_is_series: bool = False,
) -> "SnowflakeQueryCompiler": # noqa: PR02
"""
Check for each element of `self` whether it's contained in passed `values`.

Parameters
----------
values : list-like, np.array, SnowflakeQueryCompiler or dict of pandas labels -> listlike
Values to check elements of self in. If given as dict, match ListLike to column label given as key.
**kwargs : dict
Serves the compatibility purpose. Does not affect the result.

Returns
-------
SnowflakeQueryCompiler
Boolean mask for self of whether an element at the corresponding
position is contained in `values`.
"""
is_snowflake_query_compiler = isinstance(values, SnowflakeQueryCompiler) # type: ignore[union-attr]
is_series = is_snowflake_query_compiler and values.is_series_like() # type: ignore[union-attr]
is_rhs_series = is_snowflake_query_compiler and values.is_series_like() # type: ignore[union-attr]

# convert list-like values to [lit(...), ..., lit(...)] and determine type
# which is required to produce correct isin expression using array_contains(...) below
Expand Down Expand Up @@ -14621,13 +14626,19 @@ def _isin_internal(
# idempotent operation
return self

if is_series:
if is_rhs_series:
new_frame = compute_isin_with_series(
self._modin_frame, values._modin_frame, self._dummy_row_pos_mode
self._modin_frame,
values._modin_frame,
lhs_is_series=self_is_series,
dummy_row_pos_mode=self._dummy_row_pos_mode,
)
else:
new_frame = compute_isin_with_dataframe(
self._modin_frame, values._modin_frame, self._dummy_row_pos_mode
self._modin_frame,
values._modin_frame,
lhs_is_series=self_is_series,
dummy_row_pos_mode=self._dummy_row_pos_mode,
)

return SnowflakeQueryCompiler(new_frame)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,9 @@ def fillna(
# Snowpark pandas passes the query compiler object from a BasePandasDataset, which Modin does not do.
@register_base_override("isin")
def isin(
self, values: BasePandasDataset | ListLike | dict[Hashable, ListLike]
self,
values: BasePandasDataset | ListLike | dict[Hashable, ListLike],
self_is_series: bool = False,
) -> BasePandasDataset: # noqa: PR01, RT01, D200
"""
Whether elements in `BasePandasDataset` are contained in `values`.
Expand All @@ -1056,7 +1058,11 @@ def isin(
):
values = list(values)

return self.__constructor__(query_compiler=self._query_compiler.isin(values=values))
return self.__constructor__(
query_compiler=self._query_compiler.isin(
values=values, self_is_series=self_is_series
)
)


# Snowpark pandas uses the single `quantiles_along_axis0` query compiler method, while upstream
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ def isin(self, values: set | ListLike) -> Series:
if isinstance(values, set):
values = list(values)

return super(Series, self).isin(values)
return super(Series, self).isin(values, self_is_series=True)


# Snowpark pandas raises a warning before materializing data and passing to `plot`.
Expand Down
39 changes: 39 additions & 0 deletions tests/integ/modin/series/test_isin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
eval_snowpark_pandas_result,
try_cast_to_snowpark_pandas_dataframe,
try_cast_to_snowpark_pandas_series,
create_test_dfs,
create_test_series,
assert_snowpark_pandas_equal_to_pandas,
)
from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker

Expand Down Expand Up @@ -188,3 +191,39 @@ def test_isin_with_str_negative():
),
):
s.isin("test")


# Covers an edge case in SNOW-1524760 where `isin` is called with a RHS Series that has an index
# differing from that of the LHS.
# If the LHS is a DataFrame:
# - If the RHS argument is a DataFrame (even a 1-column one), the frames are joined on both row and
# column labels.
# - If the RHS is a Series, the frames are joined only on row labels.
# If the LHS is a Series:
# - If the RHS is a DataFrame, always returns False.
# - If the RHS is a Series, ignore both row and column labels, and check each LHS element for
#   membership in the RHS values.
# Note that since this test creates the LHS Series by indexing a DataFrame column, the resulting
# series will have a name.
def test_isin_ignores_index():
    """Compare `isin` against native pandas when the RHS Series index differs from the LHS index.

    Two callers are exercised with the same RHS Series: a DataFrame and a
    named Series obtained by selecting column "B" from that DataFrame.
    Each comparison is expected to issue exactly one query.
    """
    snow_rhs, native_rhs = create_test_series([4, 10], index=[99, 100])
    snow_df, native_df = create_test_dfs({"A": [1, 2, 3], "B": [4, 5, 6]})

    # df-series operation
    with SqlCounter(query_count=1):
        snow_frame_mask = snow_df.isin(snow_rhs)
        native_frame_mask = native_df.isin(native_rhs)
        assert_snowpark_pandas_equal_to_pandas(snow_frame_mask, native_frame_mask)

    # series-series operation (issue in the JIRA ticket)
    with SqlCounter(query_count=1):
        snow_series_mask = snow_df["B"].isin(snow_rhs)
        native_series_mask = native_df["B"].isin(native_rhs)
        assert_snowpark_pandas_equal_to_pandas(snow_series_mask, native_series_mask)
)


@sql_count_checker(query_count=1)
def test_isin_series_length_mismatch():
    """Compare `Series.isin` against native pandas when the argument Series is shorter than the caller."""
    shorter_values = native_pd.Series([1, 0])
    snow_series, native_series = create_test_series([0, 1, 1, 2, 1])

    def membership_mask(series):
        # The same native pandas Series is used as the argument for both callers.
        return series.isin(shorter_values)

    eval_snowpark_pandas_result(snow_series, native_series, membership_mask)
2 changes: 1 addition & 1 deletion tests/integ/modin/test_faster_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ def test_isin_list(session):
assert_frame_equal(snow_result, native_result, check_dtype=False)


@sql_count_checker(query_count=3)
@sql_count_checker(query_count=3, join_count=2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little surprised this changed, because you didn't need to change the join counts on other isin tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the query generated in this test got significantly more complicated because of the added ARRAY_AGG operation, but I believe the previous version was incorrect for a lot of cases. The actual query text of the other isin tests has joins as well, but I guess they're not being parsed correctly by the SQL counter.

def test_isin_series(session):
with session_parameter_override(
session, "dummy_row_pos_optimization_enabled", True
Expand Down