Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

#### Improvements

- Enhanced autoswitching functionality from Snowflake to native Pandas for methods with unsupported argument combinations:
- Enhanced autoswitching functionality from Snowflake to native pandas for methods with unsupported argument combinations:
- `shift()` with `suffix` or non-integer `periods` parameters
- `sort_index()` with `axis=1` or `key` parameters
- `sort_values()` with `axis=1`
Expand All @@ -42,6 +42,7 @@

- Fixed a bug in `DataFrameGroupBy.agg` where func is a list of tuples used to set the names of the output columns.
- Fixed a bug where converting a modin datetime index with a timezone to a numpy array with `np.asarray` would cause a `TypeError`.
- Fixed a bug where `Series.isin` with a Series argument matched on index labels instead of checking membership against the argument's values regardless of its index.

#### Improvements

Expand Down
54 changes: 48 additions & 6 deletions src/snowflake/snowpark/modin/plugin/_internal/isin_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,21 @@
SnowparkPandasType,
)
from snowflake.snowpark.modin.plugin._internal.type_utils import infer_series_type
from snowflake.snowpark.modin.plugin._internal.join_utils import join
from snowflake.snowpark.modin.plugin._internal.utils import (
append_columns,
generate_new_labels,
is_duplicate_free,
pandas_lit,
)
from snowflake.snowpark.modin.plugin._typing import ListLike
from snowflake.snowpark.types import DataType, DoubleType, VariantType, _IntegralType
from snowflake.snowpark.types import (
DataType,
DoubleType,
VariantType,
_IntegralType,
BooleanType,
)


def convert_values_to_list_of_literals_and_return_type(
Expand Down Expand Up @@ -120,6 +127,7 @@ def scalar_isin_expression(
def compute_isin_with_series(
frame: InternalFrame,
values_series: InternalFrame,
lhs_is_series: bool,
dummy_row_pos_mode: bool,
) -> InternalFrame:
"""
Expand All @@ -135,6 +143,36 @@ def compute_isin_with_series(
Returns:
InternalFrame
"""
# local import to avoid circular import
from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
SnowflakeQueryCompiler,
)

if lhs_is_series:
# If the LHS is a Series, directly compute distinct elements of the RHS, which will be used as
# the argument to ARRAY_CONTAINS for every element in the LHS. The only necessary join is
# between the original data column and the 1-element aggregated array column.
agg_label = generate_new_labels(
pandas_labels=["agg"], excluded=values_series.data_column_pandas_labels
)[0]
distinct_frame = (
SnowflakeQueryCompiler(values_series)
.agg("array_agg", 0, [], {})
._modin_frame
)
joined_frame = join(
frame, distinct_frame, how="inner", dummy_row_pos_mode=dummy_row_pos_mode
)[0]
assert len(joined_frame.data_column_snowflake_quoted_identifiers) == 2
return joined_frame.project_columns(
frame.data_column_pandas_labels,
column_objects=array_contains(
joined_frame.data_column_snowflake_quoted_identifiers[0],
joined_frame.data_column_snowflake_quoted_identifiers[1],
),
column_types=[SnowparkPandasType.to_pandas(BooleanType)],
)

# For each row in this dataframe
# align the index with the index of the values Series object.
# If it matches, return True, else False
Expand Down Expand Up @@ -178,18 +216,14 @@ def compute_isin_with_series(
}
).frame

# local import to avoid circular import
from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
SnowflakeQueryCompiler,
)

# return internal frame but remove temporary agg column.
return SnowflakeQueryCompiler(new_frame).drop(columns=[agg_label])._modin_frame


def compute_isin_with_dataframe(
frame: InternalFrame,
values_frame: InternalFrame,
lhs_is_series: bool,
dummy_row_pos_mode: bool,
) -> InternalFrame:
"""
Expand All @@ -205,6 +239,14 @@ def compute_isin_with_dataframe(
Returns:
InternalFrame
"""
if lhs_is_series:
# a series-DF isin operation always returns False at all positions
return frame.update_snowflake_quoted_identifiers_with_expressions(
{
quoted_identifier: pandas_lit(False)
for quoted_identifier in frame.data_column_snowflake_quoted_identifiers
}
)[0]
# similar logic to series, however do not create a single column but multiple columns
# set values via set_frame_2d_labels then

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14507,6 +14507,7 @@ def isin(
values: Union[
list[Any], np.ndarray, "SnowflakeQueryCompiler", dict[Hashable, ListLike]
],
self_is_series: bool = False,
) -> "SnowflakeQueryCompiler":
"""
Wrapper around _isin_internal to be supported in faster pandas.
Expand All @@ -14521,34 +14522,38 @@ def isin(
assert values._relaxed_query_compiler is not None
new_values = values._relaxed_query_compiler
relaxed_query_compiler = self._relaxed_query_compiler._isin_internal(
values=new_values
values=new_values,
self_is_series=self_is_series,
)

qc = self._isin_internal(values=values)
qc = self._isin_internal(values=values, self_is_series=self_is_series)
return self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)

def _isin_internal(
self,
values: Union[
list[Any], np.ndarray, "SnowflakeQueryCompiler", dict[Hashable, ListLike]
],
self_is_series: bool = False,
) -> "SnowflakeQueryCompiler": # noqa: PR02
"""
Check for each element of `self` whether it's contained in passed `values`.

Parameters
----------
values : list-like, np.array, SnowflakeQueryCompiler or dict of pandas labels -> listlike
Values to check elements of self in. If given as dict, match ListLike to column label given as key.
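self_is_series : bool, default False
Whether `self` is a Series rather than a DataFrame. Forwarded as `lhs_is_series` to the
series/dataframe isin helpers when `values` is a SnowflakeQueryCompiler.
**kwargs : dict
Serves the compatibility purpose. Does not affect the result.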

Returns
-------
SnowflakeQueryCompiler
Boolean mask for self of whether an element at the corresponding
position is contained in `values`.
"""
is_snowflake_query_compiler = isinstance(values, SnowflakeQueryCompiler) # type: ignore[union-attr]
is_series = is_snowflake_query_compiler and values.is_series_like() # type: ignore[union-attr]
is_rhs_series = is_snowflake_query_compiler and values.is_series_like() # type: ignore[union-attr]

# convert list-like values to [lit(...), ..., lit(...)] and determine type
# which is required to produce correct isin expression using array_contains(...) below
Expand Down Expand Up @@ -14621,13 +14626,19 @@ def _isin_internal(
# idempotent operation
return self

if is_series:
if is_rhs_series:
new_frame = compute_isin_with_series(
self._modin_frame, values._modin_frame, self._dummy_row_pos_mode
self._modin_frame,
values._modin_frame,
lhs_is_series=self_is_series,
dummy_row_pos_mode=self._dummy_row_pos_mode,
)
else:
new_frame = compute_isin_with_dataframe(
self._modin_frame, values._modin_frame, self._dummy_row_pos_mode
self._modin_frame,
values._modin_frame,
lhs_is_series=self_is_series,
dummy_row_pos_mode=self._dummy_row_pos_mode,
)

return SnowflakeQueryCompiler(new_frame)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,9 @@ def fillna(
# Snowpark pandas passes the query compiler object from a BasePandasDataset, which Modin does not do.
@register_base_override("isin")
def isin(
self, values: BasePandasDataset | ListLike | dict[Hashable, ListLike]
self,
values: BasePandasDataset | ListLike | dict[Hashable, ListLike],
self_is_series: bool = False,
) -> BasePandasDataset: # noqa: PR01, RT01, D200
"""
Whether elements in `BasePandasDataset` are contained in `values`.
Expand All @@ -1056,7 +1058,11 @@ def isin(
):
values = list(values)

return self.__constructor__(query_compiler=self._query_compiler.isin(values=values))
return self.__constructor__(
query_compiler=self._query_compiler.isin(
values=values, self_is_series=self_is_series
)
)


# Snowpark pandas uses the single `quantiles_along_axis0` query compiler method, while upstream
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ def isin(self, values: set | ListLike) -> Series:
if isinstance(values, set):
values = list(values)

return super(Series, self).isin(values)
return super(Series, self).isin(values, self_is_series=True)


# Snowpark pandas raises a warning before materializing data and passing to `plot`.
Expand Down
39 changes: 39 additions & 0 deletions tests/integ/modin/series/test_isin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
eval_snowpark_pandas_result,
try_cast_to_snowpark_pandas_dataframe,
try_cast_to_snowpark_pandas_series,
create_test_dfs,
create_test_series,
assert_snowpark_pandas_equal_to_pandas,
)
from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker

Expand Down Expand Up @@ -188,3 +191,39 @@ def test_isin_with_str_negative():
),
):
s.isin("test")


# Covers an edge case in SNOW-1524760 where `isin` is called with a RHS Series that has an index
# differing from that of the LHS.
# If the LHS is a DataFrame:
# - If the RHS argument is a DataFrame (even a 1-column one), the frames are joined on both row and
# column labels.
# - If the RHS is a Series, the frames are joined only on row labels.
# If the LHS is a Series:
# - If the RHS is a DataFrame, always returns False.
# - If the RHS is a Series, ignore both row and column labels, and check each LHS value for
#   membership among the values of the RHS.
# Note that since this test creates the LHS Series by indexing a DataFrame column, the resulting
# series will have a name.
def test_isin_ignores_index():
    """
    Compare `isin` against native pandas when the argument Series has an index
    that differs from the caller's, for both a DataFrame caller and a Series
    caller. Each comparison is expected to issue exactly one query.
    """
    snow_rhs, native_rhs = create_test_series([4, 10], index=[99, 100])
    snow_df, native_df = create_test_dfs({"A": [1, 2, 3], "B": [4, 5, 6]})

    # DataFrame caller, Series argument.
    with SqlCounter(query_count=1):
        snow_frame_result = snow_df.isin(snow_rhs)
        native_frame_result = native_df.isin(native_rhs)
        assert_snowpark_pandas_equal_to_pandas(
            snow_frame_result, native_frame_result
        )

    # Series caller, Series argument (the case reported in the JIRA ticket).
    with SqlCounter(query_count=1):
        snow_series_result = snow_df["B"].isin(snow_rhs)
        native_series_result = native_df["B"].isin(native_rhs)
        assert_snowpark_pandas_equal_to_pandas(
            snow_series_result, native_series_result
        )


@sql_count_checker(query_count=1)
def test_isin_series_length_mismatch():
    """
    Compare `Series.isin` against native pandas when the argument Series is
    shorter than the caller (2 values against 5).
    """
    rhs = native_pd.Series([1, 0])
    snow_lhs, native_lhs = create_test_series([0, 1, 1, 2, 1])

    def membership_mask(series):
        # The same native pandas `rhs` is used for both the Snowpark pandas
        # and the native pandas caller.
        return series.isin(rhs)

    eval_snowpark_pandas_result(snow_lhs, native_lhs, membership_mask)
Loading