Skip to content

Commit 5ae5275

Browse files
SNOW-2454948: Revert 2 modin commits to fix CI. (#3945)
This commit reverts 2 commits that broke tests, b505e92 and 226b9d9. b505e92 broke several tests. 226b9d9 broke tests/integ/modin/hybrid/test_switch_operations.py::test_query_count_no_switch[True] and tests/integ/modin/hybrid/test_switch_operations.py::test_query_count_no_switch[False]. GitHub allowed us to merge the commits because I think I never saved the branch protection rule requiring modin-ubuntu-AWS tests to pass before merging to main. I'm struggling with the branch protection rules so I am still trying to fix them.
1 parent 30fa5f3 commit 5ae5275

File tree

4 files changed

+41
-419
lines changed

4 files changed

+41
-419
lines changed

CHANGELOG.md

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -95,16 +95,6 @@
9595
- `skew()` with `axis=1` or `numeric_only=False` parameters
9696
- `round()` with `decimals` parameter as a Series
9797
- `corr()` with `method!=pearson` parameter
98-
- `shift()` with `suffix` or non-integer `periods` parameters
99-
- `sort_index()` with `axis=1` or `key` parameters
100-
- `sort_values()` with `axis=1`
101-
- `melt()` with `col_level` parameter
102-
- `apply()` with `result_type` parameter for DataFrame
103-
- `pivot_table()` with `sort=True`, non-string `index` list, non-string `columns` list, non-string `values` list, or `aggfunc` dict with non-string values
104-
- `fillna()` with `downcast` parameter or using `limit` together with `value`
105-
- `dropna()` with `axis=1`
106-
107-
10898
- Set `cte_optimization_enabled` to True for all Snowpark pandas sessions.
10999
- Add support for the following in faster pandas:
110100
- `isin`
@@ -166,7 +156,6 @@
166156
- `groupby.var`
167157
- `groupby.nunique`
168158
- `groupby.size`
169-
- `groupby.apply`
170159
- `drop_duplicates`
171160
- Reuse row count from the relaxed query compiler in `get_axis_len`.
172161

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 17 additions & 219 deletions
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ def get_unsupported_args_reason(
626626

627627

628628
def register_query_compiler_method_not_implemented(
629-
api_cls_names: Union[list[Optional[str]], Optional[str]],
629+
api_cls_name: Optional[str],
630630
method_name: str,
631631
unsupported_args: Optional["UnsupportedArgsRule"] = None,
632632
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
@@ -644,30 +644,22 @@ def register_query_compiler_method_not_implemented(
644644
without meaningful benefit.
645645

646646
Args:
647-
api_cls_names: Frontend class names (e.g. "BasePandasDataset", "Series", "DataFrame", or None). It can be a list if multiple api_cls_names are needed.
647+
api_cls_name: Frontend class name (e.g., "BasePandasDataset", "Series", "DataFrame"), or None for top-level functions.
648648
method_name: Method name to register.
649649
unsupported_args: UnsupportedArgsRule for args-based auto-switching.
650650
If None, method is treated as completely unimplemented.
651651
"""
652+
reg_key = MethodKey(api_cls_name, method_name)
652653

653-
if isinstance(api_cls_names, str) or api_cls_names is None:
654-
api_cls_names = [api_cls_names]
655-
assert (
656-
api_cls_names
657-
), "api_cls_names must be a string (e.g., 'DataFrame', 'Series') or a list of strings (e.g., ['DataFrame', 'Series']) or None for top-level functions"
658-
659-
for api_cls_name in api_cls_names:
660-
reg_key = MethodKey(api_cls_name, method_name)
661-
662-
# register the method in the hybrid switch for unsupported args
663-
if unsupported_args is None:
664-
HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS.add(reg_key)
665-
else:
666-
HYBRID_SWITCH_FOR_UNSUPPORTED_ARGS[reg_key] = unsupported_args
654+
# register the method in the hybrid switch for unsupported args
655+
if unsupported_args is None:
656+
HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS.add(reg_key)
657+
else:
658+
HYBRID_SWITCH_FOR_UNSUPPORTED_ARGS[reg_key] = unsupported_args
667659

668-
register_function_for_pre_op_switch(
669-
class_name=api_cls_name, backend="Snowflake", method=method_name
670-
)
660+
register_function_for_pre_op_switch(
661+
class_name=api_cls_name, backend="Snowflake", method=method_name
662+
)
671663

672664
def decorator(query_compiler_method: Callable[..., Any]) -> Callable[..., Any]:
673665
@functools.wraps(query_compiler_method)
@@ -2613,22 +2605,6 @@ def _shift_index(self, periods: int, freq: Any) -> "SnowflakeQueryCompiler": #
26132605
# TODO: SNOW-1023324, implement shifting index only.
26142606
ErrorMessage.not_implemented("shifting index values not yet supported.")
26152607

2616-
@register_query_compiler_method_not_implemented(
2617-
"BasePandasDataset",
2618-
"shift",
2619-
UnsupportedArgsRule(
2620-
unsupported_conditions=[
2621-
(
2622-
lambda args: args.get("suffix") is not None,
2623-
"the 'suffix' parameter is not yet supported",
2624-
),
2625-
(
2626-
lambda args: not isinstance(args.get("periods"), int),
2627-
"only int 'periods' is currently supported",
2628-
),
2629-
]
2630-
),
2631-
)
26322608
def shift(
26332609
self,
26342610
periods: Union[int, Sequence[int]] = 1,
@@ -4216,19 +4192,6 @@ def first_last_valid_index(
42164192
)
42174193
return None
42184194

4219-
@register_query_compiler_method_not_implemented(
4220-
"BasePandasDataset",
4221-
"sort_index",
4222-
UnsupportedArgsRule(
4223-
unsupported_conditions=[
4224-
("axis", 1),
4225-
(
4226-
lambda args: args.get("key") is not None,
4227-
"the 'key' parameter is not yet supported",
4228-
),
4229-
]
4230-
),
4231-
)
42324195
def sort_index(
42334196
self,
42344197
*,
@@ -4291,7 +4254,7 @@ def sort_index(
42914254
1.0 c
42924255
dtype: object
42934256
"""
4294-
if axis == 1:
4257+
if axis in (1, "index"):
42954258
ErrorMessage.not_implemented(
42964259
"sort_index is not supported yet on axis=1 in Snowpark pandas."
42974260
)
@@ -4315,17 +4278,8 @@ def sort_index(
43154278
include_indexer=include_indexer,
43164279
)
43174280

4318-
@register_query_compiler_method_not_implemented(
4319-
"BasePandasDataset",
4320-
"sort_values",
4321-
UnsupportedArgsRule(
4322-
unsupported_conditions=[
4323-
("axis", 1),
4324-
]
4325-
),
4326-
)
43274281
def sort_columns_by_row_values(
4328-
self, rows: IndexLabel, ascending: bool = True, axis: int = 1, **kwargs: Any
4282+
self, rows: IndexLabel, ascending: bool = True, **kwargs: Any
43294283
) -> None:
43304284
"""
43314285
Reorder the columns based on the lexicographic order of the given rows.
@@ -4335,9 +4289,6 @@ def sort_columns_by_row_values(
43354289
The row or rows to sort by.
43364290
ascending : bool, default: True
43374291
Sort in ascending order (True) or descending order (False).
4338-
axis: Always set to 1. Required because the decorator compares frontend
4339-
method arguments during stay_cost computation (returning COST_IMPOSSIBLE)
4340-
but examines QC method arguments when calling the wrapped method.
43414292
**kwargs : dict
43424293
Serves the compatibility purpose. Does not affect the result.
43434294

@@ -4935,52 +4886,6 @@ def convert_func_to_agg_func_info(
49354886
return query_compiler if as_index else query_compiler.reset_index(drop=drop)
49364887

49374888
def groupby_apply(
4938-
self,
4939-
by: Any,
4940-
agg_func: AggFuncType,
4941-
axis: int,
4942-
groupby_kwargs: dict[str, Any],
4943-
agg_args: Any,
4944-
agg_kwargs: dict[str, Any],
4945-
series_groupby: bool,
4946-
include_groups: bool,
4947-
force_single_group: bool = False,
4948-
force_list_like_to_series: bool = False,
4949-
) -> "SnowflakeQueryCompiler":
4950-
"""
4951-
Wrapper around _groupby_apply_internal to be supported in faster pandas.
4952-
"""
4953-
relaxed_query_compiler = None
4954-
if self._relaxed_query_compiler is not None:
4955-
relaxed_query_compiler = (
4956-
self._relaxed_query_compiler._groupby_apply_internal(
4957-
by=by,
4958-
agg_func=agg_func,
4959-
axis=axis,
4960-
groupby_kwargs=groupby_kwargs,
4961-
agg_args=agg_args,
4962-
agg_kwargs=agg_kwargs,
4963-
series_groupby=series_groupby,
4964-
include_groups=include_groups,
4965-
force_single_group=force_single_group,
4966-
force_list_like_to_series=force_list_like_to_series,
4967-
)
4968-
)
4969-
qc = self._groupby_apply_internal(
4970-
by=by,
4971-
agg_func=agg_func,
4972-
axis=axis,
4973-
groupby_kwargs=groupby_kwargs,
4974-
agg_args=agg_args,
4975-
agg_kwargs=agg_kwargs,
4976-
series_groupby=series_groupby,
4977-
include_groups=include_groups,
4978-
force_single_group=force_single_group,
4979-
force_list_like_to_series=force_list_like_to_series,
4980-
)
4981-
return self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)
4982-
4983-
def _groupby_apply_internal(
49844889
self,
49854890
by: Any,
49864891
agg_func: Callable,
@@ -9056,18 +8961,6 @@ def cummax(
90568961
).frame
90578962
)
90588963

9059-
@register_query_compiler_method_not_implemented(
9060-
None,
9061-
"melt",
9062-
UnsupportedArgsRule(
9063-
unsupported_conditions=[
9064-
(
9065-
lambda args: args.get("col_level") is not None,
9066-
"col_level argument is not yet supported",
9067-
),
9068-
]
9069-
),
9070-
)
90718964
def melt(
90728965
self,
90738966
id_vars: list[str],
@@ -10115,18 +10008,6 @@ def align(
1011510008

1011610009
return left_qc, right_qc
1011710010

10118-
@register_query_compiler_method_not_implemented(
10119-
"DataFrame",
10120-
"apply",
10121-
UnsupportedArgsRule(
10122-
unsupported_conditions=[
10123-
(
10124-
lambda args: args.get("result_type") is not None,
10125-
"the 'result_type' parameter is not yet supported",
10126-
),
10127-
]
10128-
),
10129-
)
1013010011
def apply(
1013110012
self,
1013210013
func: Union[AggFuncType, UserDefinedFunction],
@@ -10799,63 +10680,6 @@ def pivot(
1079910680
sort=True,
1080010681
)
1080110682

10802-
@register_query_compiler_method_not_implemented(
10803-
None,
10804-
"pivot_table",
10805-
UnsupportedArgsRule(
10806-
unsupported_conditions=[
10807-
("sort", False),
10808-
(
10809-
lambda args: (
10810-
args.get("index") is not None
10811-
and (
10812-
not isinstance(args.get("index"), str)
10813-
and not all([isinstance(v, str) for v in args.get("index")])
10814-
and None not in args.get("index")
10815-
)
10816-
),
10817-
"non-string of list of string index is not yet supported for pivot_table",
10818-
),
10819-
(
10820-
lambda args: (
10821-
args.get("columns") is not None
10822-
and (
10823-
not isinstance(args.get("columns"), str)
10824-
and not all(
10825-
[isinstance(v, str) for v in args.get("columns")]
10826-
)
10827-
and None not in args.get("columns")
10828-
)
10829-
),
10830-
"non-string of list of string columns is not yet supported for pivot_table",
10831-
),
10832-
(
10833-
lambda args: (
10834-
args.get("values") is not None
10835-
and (
10836-
not isinstance(args.get("values"), str)
10837-
and not all(
10838-
[isinstance(v, str) for v in args.get("values")]
10839-
)
10840-
and None not in args.get("values")
10841-
)
10842-
),
10843-
"non-string of list of string values is not yet supported for pivot_table",
10844-
),
10845-
(
10846-
lambda args: (
10847-
isinstance(args.get("aggfunc"), dict)
10848-
and any(
10849-
not isinstance(af, str)
10850-
for af in args.get("aggfunc").values()
10851-
)
10852-
and args.get("index") is None
10853-
),
10854-
"dictionary aggfunc with non-string aggregation functions is not yet supported for pivot_table with margins or when index is None",
10855-
),
10856-
]
10857-
),
10858-
)
1085910683
def pivot_table(
1086010684
self,
1086110685
index: Any,
@@ -10931,7 +10755,7 @@ def pivot_table(
1093110755
index = [index]
1093210756

1093310757
# TODO: SNOW-857485 Support for non-str and list of non-str for index/columns/values
10934-
if index is not None and (
10758+
if index and (
1093510759
not isinstance(index, str)
1093610760
and not all([isinstance(v, str) for v in index])
1093710761
and None not in index
@@ -10940,7 +10764,7 @@ def pivot_table(
1094010764
f"Not implemented non-string of list of string {index}."
1094110765
)
1094210766

10943-
if values is not None and (
10767+
if values and (
1094410768
not isinstance(values, str)
1094510769
and not all([isinstance(v, str) for v in values])
1094610770
and None not in values
@@ -10949,7 +10773,7 @@ def pivot_table(
1094910773
f"Not implemented non-string of list of string {values}."
1095010774
)
1095110775

10952-
if columns is not None and (
10776+
if columns and (
1095310777
not isinstance(columns, str)
1095410778
and not all([isinstance(v, str) for v in columns])
1095510779
and None not in columns
@@ -13050,23 +12874,6 @@ def _make_fill_expression_for_column_wise_fillna(
1305012874
*columns_to_include,
1305112875
)
1305212876

13053-
@register_query_compiler_method_not_implemented(
13054-
["DataFrame", "Series"],
13055-
"fillna",
13056-
UnsupportedArgsRule(
13057-
unsupported_conditions=[
13058-
(
13059-
lambda kwargs: kwargs.get("value") is not None
13060-
and kwargs.get("limit") is not None,
13061-
"the 'limit' parameter with 'value' parameter is not yet supported",
13062-
),
13063-
(
13064-
lambda kwargs: kwargs.get("downcast") is not None,
13065-
"the 'downcast' parameter is not yet supported",
13066-
),
13067-
]
13068-
),
13069-
)
1307012877
def fillna(
1307112878
self,
1307212879
value: Optional[Union[Hashable, Mapping, "pd.DataFrame", "pd.Series"]] = None,
@@ -13312,15 +13119,6 @@ def fillna_expr(snowflake_quoted_id: str) -> SnowparkColumn:
1331213119
).frame
1331313120
)
1331413121

13315-
@register_query_compiler_method_not_implemented(
13316-
"DataFrame",
13317-
"dropna",
13318-
UnsupportedArgsRule(
13319-
unsupported_conditions=[
13320-
("axis", 1),
13321-
]
13322-
),
13323-
)
1332413122
def dropna(
1332513123
self,
1332613124
axis: int,
@@ -22018,7 +21816,7 @@ def _stack_helper(
2201821816
return qc
2201921817

2202021818
@register_query_compiler_method_not_implemented(
22021-
api_cls_names="DataFrame",
21819+
api_cls_name="DataFrame",
2202221820
method_name="corr",
2202321821
unsupported_args=UnsupportedArgsRule(
2202421822
unsupported_conditions=[

0 commit comments

Comments
 (0)