Skip to content

Commit b505e92

Browse files
SNOW-2359402 Enabled autoswitching on some unsupported args (#3906)
1 parent 569aef4 commit b505e92

File tree

3 files changed

+341
-41
lines changed

3 files changed

+341
-41
lines changed

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,16 @@
9595
- `skew()` with `axis=1` or `numeric_only=False` parameters
9696
- `round()` with `decimals` parameter as a Series
9797
- `corr()` with `method!=pearson` parameter
98+
- `shift()` with `suffix` or non-integer `periods` parameters
99+
- `sort_index()` with `axis=1` or `key` parameters
100+
- `sort_values()` with `axis=1`
101+
- `melt()` with `col_level` parameter
102+
- `apply()` with `result_type` parameter for DataFrame
103+
- `pivot_table()` with `sort=False`, non-string `index` list, non-string `columns` list, non-string `values` list, or `aggfunc` dict with non-string values when `index` is None
104+
- `fillna()` with `downcast` parameter or using `limit` together with `value`
105+
- `dropna()` with `axis=1`
106+
107+
98108
- Set `cte_optimization_enabled` to True for all Snowpark pandas sessions.
99109
- Add support for the following in faster pandas:
100110
- `isin`

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 173 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ def get_unsupported_args_reason(
626626

627627

628628
def register_query_compiler_method_not_implemented(
629-
api_cls_name: Optional[str],
629+
api_cls_names: Union[list[Optional[str]], Optional[str]],
630630
method_name: str,
631631
unsupported_args: Optional["UnsupportedArgsRule"] = None,
632632
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
@@ -644,22 +644,30 @@ def register_query_compiler_method_not_implemented(
644644
without meaningful benefit.
645645

646646
Args:
647-
api_cls_name: Frontend class name (e.g., "BasePandasDataset", "Series", "DataFrame", "None").
647+
api_cls_names: Frontend class names (e.g. "BasePandasDataset", "Series", "DataFrame", or None). It can be a list if multiple api_cls_names are needed.
648648
method_name: Method name to register.
649649
unsupported_args: UnsupportedArgsRule for args-based auto-switching.
650650
If None, method is treated as completely unimplemented.
651651
"""
652-
reg_key = MethodKey(api_cls_name, method_name)
653652

654-
# register the method in the hybrid switch for unsupported args
655-
if unsupported_args is None:
656-
HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS.add(reg_key)
657-
else:
658-
HYBRID_SWITCH_FOR_UNSUPPORTED_ARGS[reg_key] = unsupported_args
653+
if isinstance(api_cls_names, str) or api_cls_names is None:
654+
api_cls_names = [api_cls_names]
655+
assert (
656+
api_cls_names
657+
), "api_cls_names must be a string (e.g., 'DataFrame', 'Series') or a list of strings (e.g., ['DataFrame', 'Series']) or None for top-level functions"
659658

660-
register_function_for_pre_op_switch(
661-
class_name=api_cls_name, backend="Snowflake", method=method_name
662-
)
659+
for api_cls_name in api_cls_names:
660+
reg_key = MethodKey(api_cls_name, method_name)
661+
662+
# register the method in the hybrid switch for unsupported args
663+
if unsupported_args is None:
664+
HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS.add(reg_key)
665+
else:
666+
HYBRID_SWITCH_FOR_UNSUPPORTED_ARGS[reg_key] = unsupported_args
667+
668+
register_function_for_pre_op_switch(
669+
class_name=api_cls_name, backend="Snowflake", method=method_name
670+
)
663671

664672
def decorator(query_compiler_method: Callable[..., Any]) -> Callable[..., Any]:
665673
@functools.wraps(query_compiler_method)
@@ -2605,6 +2613,22 @@ def _shift_index(self, periods: int, freq: Any) -> "SnowflakeQueryCompiler": #
26052613
# TODO: SNOW-1023324, implement shifting index only.
26062614
ErrorMessage.not_implemented("shifting index values not yet supported.")
26072615

2616+
@register_query_compiler_method_not_implemented(
2617+
"BasePandasDataset",
2618+
"shift",
2619+
UnsupportedArgsRule(
2620+
unsupported_conditions=[
2621+
(
2622+
lambda args: args.get("suffix") is not None,
2623+
"the 'suffix' parameter is not yet supported",
2624+
),
2625+
(
2626+
lambda args: not isinstance(args.get("periods"), int),
2627+
"only int 'periods' is currently supported",
2628+
),
2629+
]
2630+
),
2631+
)
26082632
def shift(
26092633
self,
26102634
periods: Union[int, Sequence[int]] = 1,
@@ -4192,6 +4216,19 @@ def first_last_valid_index(
41924216
)
41934217
return None
41944218

4219+
@register_query_compiler_method_not_implemented(
4220+
"BasePandasDataset",
4221+
"sort_index",
4222+
UnsupportedArgsRule(
4223+
unsupported_conditions=[
4224+
("axis", 1),
4225+
(
4226+
lambda args: args.get("key") is not None,
4227+
"the 'key' parameter is not yet supported",
4228+
),
4229+
]
4230+
),
4231+
)
41954232
def sort_index(
41964233
self,
41974234
*,
@@ -4254,7 +4291,7 @@ def sort_index(
42544291
1.0 c
42554292
dtype: object
42564293
"""
4257-
if axis in (1, "index"):
4294+
if axis == 1:
42584295
ErrorMessage.not_implemented(
42594296
"sort_index is not supported yet on axis=1 in Snowpark pandas."
42604297
)
@@ -4278,8 +4315,17 @@ def sort_index(
42784315
include_indexer=include_indexer,
42794316
)
42804317

4318+
@register_query_compiler_method_not_implemented(
4319+
"BasePandasDataset",
4320+
"sort_values",
4321+
UnsupportedArgsRule(
4322+
unsupported_conditions=[
4323+
("axis", 1),
4324+
]
4325+
),
4326+
)
42814327
def sort_columns_by_row_values(
4282-
self, rows: IndexLabel, ascending: bool = True, **kwargs: Any
4328+
self, rows: IndexLabel, ascending: bool = True, axis: int = 1, **kwargs: Any
42834329
) -> None:
42844330
"""
42854331
Reorder the columns based on the lexicographic order of the given rows.
@@ -4289,6 +4335,9 @@ def sort_columns_by_row_values(
42894335
The row or rows to sort by.
42904336
ascending : bool, default: True
42914337
Sort in ascending order (True) or descending order (False).
4338+
axis: Always set to 1. Required because the decorator compares frontend
4339+
method arguments during stay_cost computation (returning COST_IMPOSSIBLE)
4340+
but examines QC method arguments when calling the wrapped method.
42924341
**kwargs : dict
42934342
Serves the compatibility purpose. Does not affect the result.
42944343

@@ -9007,6 +9056,18 @@ def cummax(
90079056
).frame
90089057
)
90099058

9059+
@register_query_compiler_method_not_implemented(
9060+
None,
9061+
"melt",
9062+
UnsupportedArgsRule(
9063+
unsupported_conditions=[
9064+
(
9065+
lambda args: args.get("col_level") is not None,
9066+
"col_level argument is not yet supported",
9067+
),
9068+
]
9069+
),
9070+
)
90109071
def melt(
90119072
self,
90129073
id_vars: list[str],
@@ -10054,6 +10115,18 @@ def align(
1005410115

1005510116
return left_qc, right_qc
1005610117

10118+
@register_query_compiler_method_not_implemented(
10119+
"DataFrame",
10120+
"apply",
10121+
UnsupportedArgsRule(
10122+
unsupported_conditions=[
10123+
(
10124+
lambda args: args.get("result_type") is not None,
10125+
"the 'result_type' parameter is not yet supported",
10126+
),
10127+
]
10128+
),
10129+
)
1005710130
def apply(
1005810131
self,
1005910132
func: Union[AggFuncType, UserDefinedFunction],
@@ -10726,6 +10799,63 @@ def pivot(
1072610799
sort=True,
1072710800
)
1072810801

10802+
@register_query_compiler_method_not_implemented(
10803+
None,
10804+
"pivot_table",
10805+
UnsupportedArgsRule(
10806+
unsupported_conditions=[
10807+
("sort", False),
10808+
(
10809+
lambda args: (
10810+
args.get("index") is not None
10811+
and (
10812+
not isinstance(args.get("index"), str)
10813+
and not all([isinstance(v, str) for v in args.get("index")])
10814+
and None not in args.get("index")
10815+
)
10816+
),
10817+
"non-string of list of string index is not yet supported for pivot_table",
10818+
),
10819+
(
10820+
lambda args: (
10821+
args.get("columns") is not None
10822+
and (
10823+
not isinstance(args.get("columns"), str)
10824+
and not all(
10825+
[isinstance(v, str) for v in args.get("columns")]
10826+
)
10827+
and None not in args.get("columns")
10828+
)
10829+
),
10830+
"non-string of list of string columns is not yet supported for pivot_table",
10831+
),
10832+
(
10833+
lambda args: (
10834+
args.get("values") is not None
10835+
and (
10836+
not isinstance(args.get("values"), str)
10837+
and not all(
10838+
[isinstance(v, str) for v in args.get("values")]
10839+
)
10840+
and None not in args.get("values")
10841+
)
10842+
),
10843+
"non-string of list of string values is not yet supported for pivot_table",
10844+
),
10845+
(
10846+
lambda args: (
10847+
isinstance(args.get("aggfunc"), dict)
10848+
and any(
10849+
not isinstance(af, str)
10850+
for af in args.get("aggfunc").values()
10851+
)
10852+
and args.get("index") is None
10853+
),
10854+
"dictionary aggfunc with non-string aggregation functions is not yet supported for pivot_table with margins or when index is None",
10855+
),
10856+
]
10857+
),
10858+
)
1072910859
def pivot_table(
1073010860
self,
1073110861
index: Any,
@@ -10801,7 +10931,7 @@ def pivot_table(
1080110931
index = [index]
1080210932

1080310933
# TODO: SNOW-857485 Support for non-str and list of non-str for index/columns/values
10804-
if index and (
10934+
if index is not None and (
1080510935
not isinstance(index, str)
1080610936
and not all([isinstance(v, str) for v in index])
1080710937
and None not in index
@@ -10810,7 +10940,7 @@ def pivot_table(
1081010940
f"Not implemented non-string of list of string {index}."
1081110941
)
1081210942

10813-
if values and (
10943+
if values is not None and (
1081410944
not isinstance(values, str)
1081510945
and not all([isinstance(v, str) for v in values])
1081610946
and None not in values
@@ -10819,7 +10949,7 @@ def pivot_table(
1081910949
f"Not implemented non-string of list of string {values}."
1082010950
)
1082110951

10822-
if columns and (
10952+
if columns is not None and (
1082310953
not isinstance(columns, str)
1082410954
and not all([isinstance(v, str) for v in columns])
1082510955
and None not in columns
@@ -12920,6 +13050,23 @@ def _make_fill_expression_for_column_wise_fillna(
1292013050
*columns_to_include,
1292113051
)
1292213052

13053+
@register_query_compiler_method_not_implemented(
13054+
["DataFrame", "Series"],
13055+
"fillna",
13056+
UnsupportedArgsRule(
13057+
unsupported_conditions=[
13058+
(
13059+
lambda kwargs: kwargs.get("value") is not None
13060+
and kwargs.get("limit") is not None,
13061+
"the 'limit' parameter with 'value' parameter is not yet supported",
13062+
),
13063+
(
13064+
lambda kwargs: kwargs.get("downcast") is not None,
13065+
"the 'downcast' parameter is not yet supported",
13066+
),
13067+
]
13068+
),
13069+
)
1292313070
def fillna(
1292413071
self,
1292513072
value: Optional[Union[Hashable, Mapping, "pd.DataFrame", "pd.Series"]] = None,
@@ -13165,6 +13312,15 @@ def fillna_expr(snowflake_quoted_id: str) -> SnowparkColumn:
1316513312
).frame
1316613313
)
1316713314

13315+
@register_query_compiler_method_not_implemented(
13316+
"DataFrame",
13317+
"dropna",
13318+
UnsupportedArgsRule(
13319+
unsupported_conditions=[
13320+
("axis", 1),
13321+
]
13322+
),
13323+
)
1316813324
def dropna(
1316913325
self,
1317013326
axis: int,
@@ -21862,7 +22018,7 @@ def _stack_helper(
2186222018
return qc
2186322019

2186422020
@register_query_compiler_method_not_implemented(
21865-
api_cls_name="DataFrame",
22021+
api_cls_names="DataFrame",
2186622022
method_name="corr",
2186722023
unsupported_args=UnsupportedArgsRule(
2186822024
unsupported_conditions=[

0 commit comments

Comments
 (0)