Skip to content

Commit 367fc26

Browse files
authored
SNOW-1524760: Fix Series.isin behavior (#3973)
This PR fixes the behavior of `Series.isin(other_series)`, which ignores indices instead of joining on row/column labels. It also adds a fast path for `Series.isin(dataframe)`, which should always return false at every index. Per @sfc-gh-mvashishtha's investigation in the linked ticket: > It seems that pandas behavior is: > - ignore row and column labels for Series.isin(series) > - Series.isin(dataframe) always returns False, e.g. s = pandas.Series([1]); s.isin(s.to_frame()) > - DataFrame.isin(dataframe) joins on both row and column labels > - DataFrame.isin(series) ignores column labels but not row labels, e.g. pandas.DataFrame({'A': [1, 2]}).isin(pandas.Series([1, 2], name='B', index=[0,1])) gives True values because even though the column name is different, the index matches, but pandas.DataFrame({'A': [1, 2]}).isin(pandas.Series([1, 2], name='B', index=[-1, -2])) gives False values.
1 parent 8a555ef commit 367fc26

File tree

7 files changed

+116
-17
lines changed

7 files changed

+116
-17
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
#### Improvements
3030

31-
- Enhanced autoswitching functionality from Snowflake to native Pandas for methods with unsupported argument combinations:
31+
- Enhanced autoswitching functionality from Snowflake to native pandas for methods with unsupported argument combinations:
3232
- `shift()` with `suffix` or non-integer `periods` parameters
3333
- `sort_index()` with `axis=1` or `key` parameters
3434
- `sort_values()` with `axis=1`
@@ -42,6 +42,7 @@
4242

4343
- Fixed a bug in `DataFrameGroupBy.agg` where func is a list of tuples used to set the names of the output columns.
4444
- Fixed a bug where converting a modin datetime index with a timezone to a numpy array with `np.asarray` would cause a `TypeError`.
45+
- Fixed a bug where `Series.isin` with a Series argument matched index labels instead of the row position.
4546

4647
#### Improvements
4748

src/snowflake/snowpark/modin/plugin/_internal/isin_utils.py

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,21 @@
1919
SnowparkPandasType,
2020
)
2121
from snowflake.snowpark.modin.plugin._internal.type_utils import infer_series_type
22+
from snowflake.snowpark.modin.plugin._internal.join_utils import join
2223
from snowflake.snowpark.modin.plugin._internal.utils import (
2324
append_columns,
2425
generate_new_labels,
2526
is_duplicate_free,
2627
pandas_lit,
2728
)
2829
from snowflake.snowpark.modin.plugin._typing import ListLike
29-
from snowflake.snowpark.types import DataType, DoubleType, VariantType, _IntegralType
30+
from snowflake.snowpark.types import (
31+
DataType,
32+
DoubleType,
33+
VariantType,
34+
_IntegralType,
35+
BooleanType,
36+
)
3037

3138

3239
def convert_values_to_list_of_literals_and_return_type(
@@ -120,6 +127,7 @@ def scalar_isin_expression(
120127
def compute_isin_with_series(
121128
frame: InternalFrame,
122129
values_series: InternalFrame,
130+
lhs_is_series: bool,
123131
dummy_row_pos_mode: bool,
124132
) -> InternalFrame:
125133
"""
@@ -135,6 +143,36 @@ def compute_isin_with_series(
135143
Returns:
136144
InternalFrame
137145
"""
146+
# local import to avoid circular import
147+
from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
148+
SnowflakeQueryCompiler,
149+
)
150+
151+
if lhs_is_series:
152+
# If the LHS is a Series, directly compute distinct elements of the RHS, which will be used as
153+
# the argument to ARRAY_CONTAINS for every element in the LHS. The only necessary join is
154+
# between the original data column and the 1-element aggregated array column.
155+
agg_label = generate_new_labels(
156+
pandas_labels=["agg"], excluded=values_series.data_column_pandas_labels
157+
)[0]
158+
distinct_frame = (
159+
SnowflakeQueryCompiler(values_series)
160+
.agg("array_agg", 0, [], {})
161+
._modin_frame
162+
)
163+
joined_frame = join(
164+
frame, distinct_frame, how="inner", dummy_row_pos_mode=dummy_row_pos_mode
165+
)[0]
166+
assert len(joined_frame.data_column_snowflake_quoted_identifiers) == 2
167+
return joined_frame.project_columns(
168+
frame.data_column_pandas_labels,
169+
column_objects=array_contains(
170+
joined_frame.data_column_snowflake_quoted_identifiers[0],
171+
joined_frame.data_column_snowflake_quoted_identifiers[1],
172+
),
173+
column_types=[SnowparkPandasType.to_pandas(BooleanType)],
174+
)
175+
138176
# For each row in this dataframe
139177
# align the index with the index of the values Series object.
140178
# If it matches, return True, else False
@@ -178,18 +216,14 @@ def compute_isin_with_series(
178216
}
179217
).frame
180218

181-
# local import to avoid circular import
182-
from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
183-
SnowflakeQueryCompiler,
184-
)
185-
186219
# return internal frame but remove temporary agg column.
187220
return SnowflakeQueryCompiler(new_frame).drop(columns=[agg_label])._modin_frame
188221

189222

190223
def compute_isin_with_dataframe(
191224
frame: InternalFrame,
192225
values_frame: InternalFrame,
226+
lhs_is_series: bool,
193227
dummy_row_pos_mode: bool,
194228
) -> InternalFrame:
195229
"""
@@ -205,6 +239,14 @@ def compute_isin_with_dataframe(
205239
Returns:
206240
InternalFrame
207241
"""
242+
if lhs_is_series:
243+
# a series-DF isin operation always returns False at all positions
244+
return frame.update_snowflake_quoted_identifiers_with_expressions(
245+
{
246+
quoted_identifier: pandas_lit(False)
247+
for quoted_identifier in frame.data_column_snowflake_quoted_identifiers
248+
}
249+
)[0]
208250
# similar logic to series, however do not create a single column but multiple colunms
209251
# set values via set_frame_2d_labels then
210252

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14507,6 +14507,7 @@ def isin(
1450714507
values: Union[
1450814508
list[Any], np.ndarray, "SnowflakeQueryCompiler", dict[Hashable, ListLike]
1450914509
],
14510+
self_is_series: bool = False,
1451014511
) -> "SnowflakeQueryCompiler":
1451114512
"""
1451214513
Wrapper around _isin_internal to be supported in faster pandas.
@@ -14521,34 +14522,38 @@ def isin(
1452114522
assert values._relaxed_query_compiler is not None
1452214523
new_values = values._relaxed_query_compiler
1452314524
relaxed_query_compiler = self._relaxed_query_compiler._isin_internal(
14524-
values=new_values
14525+
values=new_values,
14526+
self_is_series=self_is_series,
1452514527
)
1452614528

14527-
qc = self._isin_internal(values=values)
14529+
qc = self._isin_internal(values=values, self_is_series=self_is_series)
1452814530
return self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)
1452914531

1453014532
def _isin_internal(
1453114533
self,
1453214534
values: Union[
1453314535
list[Any], np.ndarray, "SnowflakeQueryCompiler", dict[Hashable, ListLike]
1453414536
],
14537+
self_is_series: bool = False,
1453514538
) -> "SnowflakeQueryCompiler": # noqa: PR02
1453614539
"""
1453714540
Check for each element of `self` whether it's contained in passed `values`.
14541+
1453814542
Parameters
1453914543
----------
1454014544
values : list-like, np.array, SnowflakeQueryCompiler or dict of pandas labels -> listlike
1454114545
Values to check elements of self in. If given as dict, match ListLike to column label given as key.
1454214546
**kwargs : dict
1454314547
Serves the compatibility purpose. Does not affect the result.
14548+
1454414549
Returns
1454514550
-------
1454614551
SnowflakeQueryCompiler
1454714552
Boolean mask for self of whether an element at the corresponding
1454814553
position is contained in `values`.
1454914554
"""
1455014555
is_snowflake_query_compiler = isinstance(values, SnowflakeQueryCompiler) # type: ignore[union-attr]
14551-
is_series = is_snowflake_query_compiler and values.is_series_like() # type: ignore[union-attr]
14556+
is_rhs_series = is_snowflake_query_compiler and values.is_series_like() # type: ignore[union-attr]
1455214557

1455314558
# convert list-like values to [lit(...), ..., lit(...)] and determine type
1455414559
# which is required to produce correct isin expression using array_contains(...) below
@@ -14621,13 +14626,19 @@ def _isin_internal(
1462114626
# idempotent operation
1462214627
return self
1462314628

14624-
if is_series:
14629+
if is_rhs_series:
1462514630
new_frame = compute_isin_with_series(
14626-
self._modin_frame, values._modin_frame, self._dummy_row_pos_mode
14631+
self._modin_frame,
14632+
values._modin_frame,
14633+
lhs_is_series=self_is_series,
14634+
dummy_row_pos_mode=self._dummy_row_pos_mode,
1462714635
)
1462814636
else:
1462914637
new_frame = compute_isin_with_dataframe(
14630-
self._modin_frame, values._modin_frame, self._dummy_row_pos_mode
14638+
self._modin_frame,
14639+
values._modin_frame,
14640+
lhs_is_series=self_is_series,
14641+
dummy_row_pos_mode=self._dummy_row_pos_mode,
1463114642
)
1463214643

1463314644
return SnowflakeQueryCompiler(new_frame)

src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,9 @@ def fillna(
10371037
# Snowpark pandas passes the query compiler object from a BasePandasDataset, which Modin does not do.
10381038
@register_base_override("isin")
10391039
def isin(
1040-
self, values: BasePandasDataset | ListLike | dict[Hashable, ListLike]
1040+
self,
1041+
values: BasePandasDataset | ListLike | dict[Hashable, ListLike],
1042+
self_is_series: bool = False,
10411043
) -> BasePandasDataset: # noqa: PR01, RT01, D200
10421044
"""
10431045
Whether elements in `BasePandasDataset` are contained in `values`.
@@ -1056,7 +1058,11 @@ def isin(
10561058
):
10571059
values = list(values)
10581060

1059-
return self.__constructor__(query_compiler=self._query_compiler.isin(values=values))
1061+
return self.__constructor__(
1062+
query_compiler=self._query_compiler.isin(
1063+
values=values, self_is_series=self_is_series
1064+
)
1065+
)
10601066

10611067

10621068
# Snowpark pandas uses the single `quantiles_along_axis0` query compiler method, while upstream

src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -794,7 +794,7 @@ def isin(self, values: set | ListLike) -> Series:
794794
if isinstance(values, set):
795795
values = list(values)
796796

797-
return super(Series, self).isin(values)
797+
return super(Series, self).isin(values, self_is_series=True)
798798

799799

800800
# Snowpark pandas raises a warning before materializing data and passing to `plot`.

tests/integ/modin/series/test_isin.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
eval_snowpark_pandas_result,
1717
try_cast_to_snowpark_pandas_dataframe,
1818
try_cast_to_snowpark_pandas_series,
19+
create_test_dfs,
20+
create_test_series,
21+
assert_snowpark_pandas_equal_to_pandas,
1922
)
2023
from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker
2124

@@ -188,3 +191,39 @@ def test_isin_with_str_negative():
188191
),
189192
):
190193
s.isin("test")
194+
195+
196+
# Covers an edge case in SNOW-1524760 where `isin` is called with a RHS Series that has an index
197+
# differing from that of the LHS.
198+
# If the LHS is a DataFrame:
199+
# - If the RHS argument is a DataFrame (even a 1-column one), the frames are joined on both row and
200+
# column labels.
201+
# - If the RHS is a Series, the frames are joined only on row labels.
202+
# If the LHS is a Series:
203+
# - If the RHS is a DataFrame, always returns False.
204+
# - If the RHS is a Series, ignore both row and column labels, and join positionally.
205+
# Note that since this test creates the LHS Series by indexing a DataFrame column, the resulting
206+
# series will have a name.
207+
def test_isin_ignores_index():
208+
snow_rhs, native_rhs = create_test_series([4, 10], index=[99, 100])
209+
snow_df, native_df = create_test_dfs({"A": [1, 2, 3], "B": [4, 5, 6]})
210+
with SqlCounter(query_count=1):
211+
# df-series operation
212+
assert_snowpark_pandas_equal_to_pandas(
213+
snow_df.isin(snow_rhs),
214+
native_df.isin(native_rhs),
215+
)
216+
with SqlCounter(query_count=1):
217+
# series-series operation (issue in the JIRA ticket)
218+
assert_snowpark_pandas_equal_to_pandas(
219+
snow_df["B"].isin(snow_rhs),
220+
native_df["B"].isin(native_rhs),
221+
)
222+
223+
224+
@sql_count_checker(query_count=1)
225+
def test_isin_series_length_mismatch():
226+
rhs = native_pd.Series([1, 0])
227+
eval_snowpark_pandas_result(
228+
*create_test_series([0, 1, 1, 2, 1]), lambda s: s.isin(rhs)
229+
)

tests/integ/modin/test_faster_pandas.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,7 @@ def test_isin_list(session):
688688
assert_frame_equal(snow_result, native_result, check_dtype=False)
689689

690690

691-
@sql_count_checker(query_count=3)
691+
@sql_count_checker(query_count=3, join_count=2)
692692
def test_isin_series(session):
693693
with session_parameter_override(
694694
session, "dummy_row_pos_optimization_enabled", True

0 commit comments

Comments
 (0)