Commit a26070d

SNOW-2359402 Enabled autoswitching on some unsupported args (#3953)
1 parent ecd7b5c commit a26070d

10 files changed: +388 -50 lines

CHANGELOG.md

Lines changed: 12 additions & 0 deletions

@@ -26,6 +26,18 @@
 - Added support for mapping `np.percentile` with DataFrame and Series inputs to `Series.quantile`.
 - Added support for setting the `random_state` parameter to an integer when calling `DataFrame.sample` or `Series.sample`.
 
+#### Improvements
+
+- Enhanced autoswitching functionality from Snowflake to native Pandas for methods with unsupported argument combinations:
+  - `shift()` with `suffix` or non-integer `periods` parameters
+  - `sort_index()` with `axis=1` or `key` parameters
+  - `sort_values()` with `axis=1`
+  - `melt()` with `col_level` parameter
+  - `apply()` with `result_type` parameter for DataFrame
+  - `pivot_table()` with `sort=True`, non-string `index` list, non-string `columns` list, non-string `values` list, or `aggfunc` dict with non-string values
+  - `fillna()` with `downcast` parameter or using `limit` together with `value`
+  - `dropna()` with `axis=1`
+
 #### Bug Fixes
 
 - Fixed a bug in `DataFrameGroupBy.agg` where func is a list of tuples used to set the names of the output columns.
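The improvements listed above route unsupported argument combinations to native pandas instead of raising NotImplementedError. As the diffs below show, each method is registered with a list of (condition, reason) pairs that are evaluated against the call's arguments. The following is a minimal standalone sketch of that rule check; first_unsupported_reason and the example rule list are illustrative stand-ins, not the library's actual UnsupportedArgsRule API.

from typing import Any, Callable, Optional, Union

# A rule is either ("arg_name", unsupported_value) or (predicate, reason),
# where the reason may itself be a callable that builds the message.
Rule = Union[tuple[str, Any], tuple[Callable[[dict], bool], Any]]

def first_unsupported_reason(rules: list[Rule], args: dict) -> Optional[str]:
    """Return the reason for the first rule that matches args, else None."""
    for condition, detail in rules:
        if callable(condition):
            if condition(args):
                return detail(args) if callable(detail) else detail
        elif args.get(condition) == detail:
            return f"'{condition}={detail}' is not yet supported"
    return None

# Rules mirroring the fillna() registration in this commit.
fillna_rules: list[Rule] = [
    (
        lambda kw: kw.get("value") is not None and kw.get("limit") is not None,
        "the 'limit' parameter with 'value' parameter is not yet supported",
    ),
    (
        lambda kw: kw.get("downcast") is not None,
        "the 'downcast' parameter is not yet supported",
    ),
]

print(first_unsupported_reason(fillna_rules, {"value": 0, "limit": 1}))
# -> the 'limit' parameter with 'value' parameter is not yet supported
print(first_unsupported_reason(fillna_rules, {"value": 0}))  # -> None

When a rule matches, staying on the Snowflake backend is costed as impossible (see the stay_cost hunk in snowflake_query_compiler.py below), which is what triggers the switch to native pandas.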

src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py

Lines changed: 55 additions & 0 deletions

@@ -1866,3 +1866,58 @@ def generate_column_prefix_groupings(
     )
 
     return list(zip(margin_data_column_prefixes, margin_data_column_groupings))
+
+
+def check_pivot_table_unsupported_args(args: dict) -> Optional[str]:
+    """
+    Validate pivot_table arguments for unsupported conditions.
+
+    This helper function checks various argument combinations that are not yet
+    supported by Snowpark pandas pivot_table implementation.
+
+    Args:
+        args : dictionary of arguments passed to pivot_table
+
+    Returns:
+        Error message if an unsupported condition is found, None otherwise
+    """
+    # Check if index argument is a string or list of strings
+    index = args.get("index")
+    if (
+        index is not None
+        and not isinstance(index, str)
+        and not all(isinstance(v, str) for v in index)
+        and None not in index
+    ):
+        return "index argument should be a string or a list of strings"
+
+    # Check if columns argument is a string or list of strings
+    columns = args.get("columns")
+    if (
+        columns is not None
+        and not isinstance(columns, str)
+        and not all(isinstance(v, str) for v in columns)
+        and None not in columns
+    ):
+        return "columns argument should be a string or a list of strings"
+
+    # Check if values argument is a string or list of strings
+    values = args.get("values")
+    if (
+        values is not None
+        and not isinstance(values, str)
+        and not all(isinstance(v, str) for v in values)
+        and None not in values
+    ):
+        return "values argument should be a string or a list of strings"
+
+    # Check for dictionary aggfunc with non-string functions when index is None
+    aggfunc = args.get("aggfunc")
+    if (
+        isinstance(aggfunc, dict)
+        and any(not isinstance(af, str) for af in aggfunc.values())
+        and args.get("index") is None
+    ):
+        return "dictionary aggfunc with non-string aggregation functions is not yet supported for pivot_table when index is None"
+
+    return None
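A quick, hypothetical illustration of how check_pivot_table_unsupported_args behaves; the argument dictionaries below are made up for this example and are not taken from the repository's tests.

# Hypothetical argument dicts resembling what the pivot_table rule passes in.
ok_args = {"index": "A", "columns": ["B", "C"], "values": "D", "aggfunc": "sum"}
bad_args = {"index": [("A", 1)], "columns": "B", "values": "D", "aggfunc": "sum"}

# All arguments are strings or lists of strings: no unsupported condition.
assert check_pivot_table_unsupported_args(ok_args) is None

# A list containing a non-string (here a tuple) triggers the first check.
assert (
    check_pivot_table_unsupported_args(bad_args)
    == "index argument should be a string or a list of strings"
)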

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 135 additions & 19 deletions

@@ -298,6 +298,7 @@
     OrderingColumn,
 )
 from snowflake.snowpark.modin.plugin._internal.pivot_utils import (
+    check_pivot_table_unsupported_args,
     expand_pivot_result_with_pivot_table_margins,
     expand_pivot_result_with_pivot_table_margins_no_groupby_columns,
     generate_pivot_aggregation_value_label_snowflake_quoted_identifier_mappings,
@@ -626,7 +627,7 @@ def get_unsupported_args_reason(
 
 
 def register_query_compiler_method_not_implemented(
-    api_cls_name: Optional[str],
+    api_cls_names: Union[list[Optional[str]], Optional[str]],
     method_name: str,
     unsupported_args: Optional["UnsupportedArgsRule"] = None,
 ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
@@ -644,22 +645,30 @@ def register_query_compiler_method_not_implemented(
     without meaningful benefit.
 
     Args:
-        api_cls_name: Frontend class name (e.g., "BasePandasDataset", "Series", "DataFrame", "None").
+        api_cls_names: Frontend class names (e.g. "BasePandasDataset", "Series", "DataFrame", or None). It can be a list if multiple api_cls_names are needed.
         method_name: Method name to register.
         unsupported_args: UnsupportedArgsRule for args-based auto-switching.
            If None, method is treated as completely unimplemented.
    """
-    reg_key = MethodKey(api_cls_name, method_name)
-
-    # register the method in the hybrid switch for unsupported args
-    if unsupported_args is None:
-        HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS.add(reg_key)
-    else:
-        HYBRID_SWITCH_FOR_UNSUPPORTED_ARGS[reg_key] = unsupported_args
+    if isinstance(api_cls_names, str) or api_cls_names is None:
+        api_cls_names = [api_cls_names]
+    assert (
+        api_cls_names
+    ), "api_cls_names must be a string (e.g., 'DataFrame', 'Series') or a list of strings (e.g., ['DataFrame', 'Series']) or None for top-level functions"
 
-    register_function_for_pre_op_switch(
-        class_name=api_cls_name, backend="Snowflake", method=method_name
-    )
+    for api_cls_name in api_cls_names:
+        reg_key = MethodKey(api_cls_name, method_name)
+
+        # register the method in the hybrid switch for unsupported args
+        if unsupported_args is None:
+            HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS.add(reg_key)
+        else:
+            HYBRID_SWITCH_FOR_UNSUPPORTED_ARGS[reg_key] = unsupported_args
+
+        register_function_for_pre_op_switch(
+            class_name=api_cls_name, backend="Snowflake", method=method_name
+        )
 
     def decorator(query_compiler_method: Callable[..., Any]) -> Callable[..., Any]:
         @functools.wraps(query_compiler_method)
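A standalone sketch of what the multi-class registration above produces. MethodKey here is a simplified stand-in and the registry dict is local to the example, so this is illustrative rather than the module's real state.

from collections import namedtuple
from typing import Union

MethodKey = namedtuple("MethodKey", ["api_cls_name", "method_name"])
REGISTRY: dict[MethodKey, object] = {}

def register(api_cls_names: Union[list, str, None], method_name: str, rule: object) -> None:
    # Normalize a bare string or None into a one-element list, then register
    # one key per frontend class, mirroring the loop added above.
    if isinstance(api_cls_names, str) or api_cls_names is None:
        api_cls_names = [api_cls_names]
    for api_cls_name in api_cls_names:
        REGISTRY[MethodKey(api_cls_name, method_name)] = rule

register(["DataFrame", "Series"], "fillna", rule="fillna-rule")
register(None, "melt", rule="melt-rule")  # None covers top-level functions
for key in REGISTRY:
    print(key)
# MethodKey(api_cls_name='DataFrame', method_name='fillna')
# MethodKey(api_cls_name='Series', method_name='fillna')
# MethodKey(api_cls_name=None, method_name='melt')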
@@ -1154,11 +1163,13 @@ def stay_cost(
 
             return QCCoercionCost.COST_IMPOSSIBLE
 
-        return QCCoercionCost.COST_ZERO
-
         # Strongly discourage the use of these methods in snowflake
         if operation in HYBRID_ALL_EXPENSIVE_METHODS:
             return QCCoercionCost.COST_HIGH
+
+        if method_key in HYBRID_SWITCH_FOR_UNSUPPORTED_ARGS:
+            return QCCoercionCost.COST_ZERO
+
         return super().stay_cost(api_cls_name, operation, arguments)
 
     @classmethod
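One behavioral note on this hunk, stated with some hedging since only the diff context is visible: the unconditional return of COST_ZERO used to come before the expensive-methods check, so moving it after that check appears to let HYBRID_ALL_EXPENSIVE_METHODS take precedence for registered methods. A toy sketch of the resulting ordering; the constants, the set contents, and the final fallback are stand-ins for the real QCCoercionCost values, HYBRID_* registries, and the super().stay_cost() call.

COST_ZERO, COST_HIGH, COST_IMPOSSIBLE = 0, 100, 1_000_000  # stand-in costs

HYBRID_ALL_EXPENSIVE_METHODS = {"apply"}                    # stand-in contents
HYBRID_SWITCH_FOR_UNSUPPORTED_ARGS = {("DataFrame", "apply"), ("DataFrame", "fillna")}

def stay_cost(method_key: tuple, operation: str, args_are_unsupported: bool) -> int:
    if args_are_unsupported:
        # Unsupported argument combination: staying on Snowflake is impossible.
        return COST_IMPOSSIBLE
    if operation in HYBRID_ALL_EXPENSIVE_METHODS:
        # Expensive methods are discouraged even if registered below.
        return COST_HIGH
    if method_key in HYBRID_SWITCH_FOR_UNSUPPORTED_ARGS:
        # Registered method whose current arguments are supported: free to stay.
        return COST_ZERO
    return COST_ZERO  # the real code defers to super().stay_cost() here

print(stay_cost(("DataFrame", "apply"), "apply", False))    # 100 (COST_HIGH)
print(stay_cost(("DataFrame", "fillna"), "fillna", False))  # 0   (COST_ZERO)
print(stay_cost(("DataFrame", "fillna"), "fillna", True))   # 1000000 (COST_IMPOSSIBLE)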
@@ -2615,6 +2626,22 @@ def _shift_index(self, periods: int, freq: Any) -> "SnowflakeQueryCompiler": #
         # TODO: SNOW-1023324, implement shifting index only.
         ErrorMessage.not_implemented("shifting index values not yet supported.")
 
+    @register_query_compiler_method_not_implemented(
+        "BasePandasDataset",
+        "shift",
+        UnsupportedArgsRule(
+            unsupported_conditions=[
+                (
+                    lambda args: args.get("suffix") is not None,
+                    "the 'suffix' parameter is not yet supported",
+                ),
+                (
+                    lambda args: not isinstance(args.get("periods"), int),
+                    "only int 'periods' is currently supported",
+                ),
+            ]
+        ),
+    )
     def shift(
         self,
         periods: Union[int, Sequence[int]] = 1,
@@ -4204,6 +4231,19 @@ def first_last_valid_index(
         )
         return None
 
+    @register_query_compiler_method_not_implemented(
+        "BasePandasDataset",
+        "sort_index",
+        UnsupportedArgsRule(
+            unsupported_conditions=[
+                ("axis", 1),
+                (
+                    lambda args: args.get("key") is not None,
+                    "the 'key' parameter is not yet supported",
+                ),
+            ]
+        ),
+    )
     def sort_index(
         self,
         *,
@@ -4266,7 +4306,7 @@ def sort_index(
         1.0  c
         dtype: object
         """
-        if axis in (1, "index"):
+        if axis == 1:
             ErrorMessage.not_implemented(
                 "sort_index is not supported yet on axis=1 in Snowpark pandas."
             )
@@ -4290,8 +4330,17 @@ def sort_index(
             include_indexer=include_indexer,
         )
 
+    @register_query_compiler_method_not_implemented(
+        "BasePandasDataset",
+        "sort_values",
+        UnsupportedArgsRule(
+            unsupported_conditions=[
+                ("axis", 1),
+            ]
+        ),
+    )
     def sort_columns_by_row_values(
-        self, rows: IndexLabel, ascending: bool = True, **kwargs: Any
+        self, rows: IndexLabel, ascending: bool = True, axis: int = 1, **kwargs: Any
     ) -> None:
         """
         Reorder the columns based on the lexicographic order of the given rows.
@@ -4301,6 +4350,9 @@ def sort_columns_by_row_values(
            The row or rows to sort by.
        ascending : bool, default: True
            Sort in ascending order (True) or descending order (False).
+        axis: Always set to 1. Required because the decorator compares frontend
+            method arguments during stay_cost computation (returning COST_IMPOSSIBLE)
+            but examines QC method arguments when calling the wrapped method.
        **kwargs : dict
            Serves the compatibility purpose. Does not affect the result.

@@ -9085,6 +9137,18 @@ def cummax(
             ).frame
         )
 
+    @register_query_compiler_method_not_implemented(
+        None,
+        "melt",
+        UnsupportedArgsRule(
+            unsupported_conditions=[
+                (
+                    lambda args: args.get("col_level") is not None,
+                    "col_level argument is not yet supported",
+                ),
+            ]
+        ),
+    )
     def melt(
         self,
         id_vars: list[str],
@@ -10142,6 +10206,18 @@ def align(
 
         return left_qc, right_qc
 
+    @register_query_compiler_method_not_implemented(
+        "DataFrame",
+        "apply",
+        UnsupportedArgsRule(
+            unsupported_conditions=[
+                (
+                    lambda args: args.get("result_type") is not None,
+                    "the 'result_type' parameter is not yet supported",
+                ),
+            ]
+        ),
+    )
     def apply(
         self,
         func: Union[AggFuncType, UserDefinedFunction],
@@ -10814,6 +10890,20 @@ def pivot(
             sort=True,
         )
 
+    @register_query_compiler_method_not_implemented(
+        None,
+        "pivot_table",
+        UnsupportedArgsRule(
+            unsupported_conditions=[
+                ("observed", True),
+                ("sort", False),
+                (
+                    lambda args: check_pivot_table_unsupported_args(args) is not None,
+                    check_pivot_table_unsupported_args,
+                ),
+            ]
+        ),
+    )
     def pivot_table(
         self,
         index: Any,
@@ -10889,7 +10979,7 @@ def pivot_table(
             index = [index]
 
         # TODO: SNOW-857485 Support for non-str and list of non-str for index/columns/values
-        if index and (
+        if index is not None and (
             not isinstance(index, str)
             and not all([isinstance(v, str) for v in index])
             and None not in index
@@ -10898,7 +10988,7 @@ def pivot_table(
                 f"Not implemented non-string of list of string {index}."
             )
 
-        if values and (
+        if values is not None and (
             not isinstance(values, str)
             and not all([isinstance(v, str) for v in values])
             and None not in values
@@ -10907,7 +10997,7 @@ def pivot_table(
                 f"Not implemented non-string of list of string {values}."
             )
 
-        if columns and (
+        if columns is not None and (
             not isinstance(columns, str)
             and not all([isinstance(v, str) for v in columns])
             and None not in columns
@@ -13021,6 +13111,23 @@ def _make_fill_expression_for_column_wise_fillna(
             *columns_to_include,
         )
 
+    @register_query_compiler_method_not_implemented(
+        ["DataFrame", "Series"],
+        "fillna",
+        UnsupportedArgsRule(
+            unsupported_conditions=[
+                (
+                    lambda kwargs: kwargs.get("value") is not None
+                    and kwargs.get("limit") is not None,
+                    "the 'limit' parameter with 'value' parameter is not yet supported",
+                ),
+                (
+                    lambda kwargs: kwargs.get("downcast") is not None,
+                    "the 'downcast' parameter is not yet supported",
+                ),
+            ]
+        ),
+    )
     def fillna(
         self,
         value: Optional[Union[Hashable, Mapping, "pd.DataFrame", "pd.Series"]] = None,
@@ -13266,6 +13373,15 @@ def fillna_expr(snowflake_quoted_id: str) -> SnowparkColumn:
             ).frame
         )
 
+    @register_query_compiler_method_not_implemented(
+        "DataFrame",
+        "dropna",
+        UnsupportedArgsRule(
+            unsupported_conditions=[
+                ("axis", 1),
+            ]
+        ),
+    )
     def dropna(
         self,
         axis: int,
@@ -22121,7 +22237,7 @@ def _stack_helper(
         return qc
 
     @register_query_compiler_method_not_implemented(
-        api_cls_name="DataFrame",
+        api_cls_names="DataFrame",
         method_name="corr",
         unsupported_args=UnsupportedArgsRule(
             unsupported_conditions=[

tests/integ/modin/frame/test_apply.py

Lines changed: 1 addition & 1 deletion

@@ -377,7 +377,7 @@ class TestNotImplemented:
     @sql_count_checker(query_count=0)
     def test_result_type(self, result_type):
         snow_df = pd.DataFrame([[1, 2], [3, 4]])
-        msg = "Snowpark pandas apply API doesn't yet support 'result_type' parameter"
+        msg = "Snowpark pandas apply does not yet support the parameter combination because the 'result_type' parameter is not yet supported."
         with pytest.raises(NotImplementedError, match=msg):
             snow_df.apply(lambda x: [1, 2], axis=1, result_type=result_type)

tests/integ/modin/frame/test_apply_axis_0.py

Lines changed: 1 addition & 1 deletion

@@ -242,7 +242,7 @@ class TestNotImplemented:
     @sql_count_checker(query_count=0)
     def test_result_type(self, result_type):
         snow_df = pd.DataFrame([[1, 2], [3, 4]])
-        msg = "Snowpark pandas apply API doesn't yet support 'result_type' parameter"
+        msg = "Snowpark pandas apply does not yet support the parameter combination because the 'result_type' parameter is not yet supported."
         with pytest.raises(NotImplementedError, match=msg):
             snow_df.apply(lambda x: [1, 2], axis=0, result_type=result_type)
248248
