298298 OrderingColumn,
299299)
300300from snowflake.snowpark.modin.plugin._internal.pivot_utils import (
301+ check_pivot_table_unsupported_args,
301302 expand_pivot_result_with_pivot_table_margins,
302303 expand_pivot_result_with_pivot_table_margins_no_groupby_columns,
303304 generate_pivot_aggregation_value_label_snowflake_quoted_identifier_mappings,
@@ -626,7 +627,7 @@ def get_unsupported_args_reason(
626627
627628
628629def register_query_compiler_method_not_implemented(
629- api_cls_name: Optional[str],
630+ api_cls_names: Union[list[ Optional[str]], Optional[str] ],
630631 method_name: str,
631632 unsupported_args: Optional["UnsupportedArgsRule"] = None,
632633) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
@@ -644,22 +645,30 @@ def register_query_compiler_method_not_implemented(
644645 without meaningful benefit.
645646
646647 Args:
647- api_cls_name : Frontend class name (e.g., "BasePandasDataset", "Series", "DataFrame", " None") .
648+ api_cls_names : Frontend class names (e.g. "BasePandasDataset", "Series", "DataFrame", or None). It can be a list if multiple api_cls_names are needed .
648649 method_name: Method name to register.
649650 unsupported_args: UnsupportedArgsRule for args-based auto-switching.
650651 If None, method is treated as completely unimplemented.
651652 """
652- reg_key = MethodKey(api_cls_name, method_name)
653653
654- # register the method in the hybrid switch for unsupported args
655- if unsupported_args is None:
656- HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS.add(reg_key)
657- else:
658- HYBRID_SWITCH_FOR_UNSUPPORTED_ARGS[reg_key] = unsupported_args
654+ if isinstance(api_cls_names, str) or api_cls_names is None:
655+ api_cls_names = [api_cls_names]
656+ assert (
657+ api_cls_names
658+ ), "api_cls_names must be a string (e.g., 'DataFrame', 'Series') or a list of strings (e.g., ['DataFrame', 'Series']) or None for top-level functions"
659659
660- register_function_for_pre_op_switch(
661- class_name=api_cls_name, backend="Snowflake", method=method_name
662- )
660+ for api_cls_name in api_cls_names:
661+ reg_key = MethodKey(api_cls_name, method_name)
662+
663+ # register the method in the hybrid switch for unsupported args
664+ if unsupported_args is None:
665+ HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS.add(reg_key)
666+ else:
667+ HYBRID_SWITCH_FOR_UNSUPPORTED_ARGS[reg_key] = unsupported_args
668+
669+ register_function_for_pre_op_switch(
670+ class_name=api_cls_name, backend="Snowflake", method=method_name
671+ )
663672
664673 def decorator(query_compiler_method: Callable[..., Any]) -> Callable[..., Any]:
665674 @functools.wraps(query_compiler_method)
@@ -1154,11 +1163,13 @@ def stay_cost(
11541163
11551164 return QCCoercionCost.COST_IMPOSSIBLE
11561165
1157- return QCCoercionCost.COST_ZERO
1158-
11591166 # Strongly discourage the use of these methods in snowflake
11601167 if operation in HYBRID_ALL_EXPENSIVE_METHODS:
11611168 return QCCoercionCost.COST_HIGH
1169+
1170+ if method_key in HYBRID_SWITCH_FOR_UNSUPPORTED_ARGS:
1171+ return QCCoercionCost.COST_ZERO
1172+
11621173 return super().stay_cost(api_cls_name, operation, arguments)
11631174
11641175 @classmethod
@@ -2615,6 +2626,22 @@ def _shift_index(self, periods: int, freq: Any) -> "SnowflakeQueryCompiler": #
26152626 # TODO: SNOW-1023324, implement shifting index only.
26162627 ErrorMessage.not_implemented("shifting index values not yet supported.")
26172628
2629+ @register_query_compiler_method_not_implemented(
2630+ "BasePandasDataset",
2631+ "shift",
2632+ UnsupportedArgsRule(
2633+ unsupported_conditions=[
2634+ (
2635+ lambda args: args.get("suffix") is not None,
2636+ "the 'suffix' parameter is not yet supported",
2637+ ),
2638+ (
2639+ lambda args: not isinstance(args.get("periods"), int),
2640+ "only int 'periods' is currently supported",
2641+ ),
2642+ ]
2643+ ),
2644+ )
26182645 def shift(
26192646 self,
26202647 periods: Union[int, Sequence[int]] = 1,
@@ -4204,6 +4231,19 @@ def first_last_valid_index(
42044231 )
42054232 return None
42064233
4234+ @register_query_compiler_method_not_implemented(
4235+ "BasePandasDataset",
4236+ "sort_index",
4237+ UnsupportedArgsRule(
4238+ unsupported_conditions=[
4239+ ("axis", 1),
4240+ (
4241+ lambda args: args.get("key") is not None,
4242+ "the 'key' parameter is not yet supported",
4243+ ),
4244+ ]
4245+ ),
4246+ )
42074247 def sort_index(
42084248 self,
42094249 *,
@@ -4266,7 +4306,7 @@ def sort_index(
42664306 1.0 c
42674307 dtype: object
42684308 """
4269- if axis in (1, "index") :
4309+ if axis == 1 :
42704310 ErrorMessage.not_implemented(
42714311 "sort_index is not supported yet on axis=1 in Snowpark pandas."
42724312 )
@@ -4290,8 +4330,17 @@ def sort_index(
42904330 include_indexer=include_indexer,
42914331 )
42924332
4333+ @register_query_compiler_method_not_implemented(
4334+ "BasePandasDataset",
4335+ "sort_values",
4336+ UnsupportedArgsRule(
4337+ unsupported_conditions=[
4338+ ("axis", 1),
4339+ ]
4340+ ),
4341+ )
42934342 def sort_columns_by_row_values(
4294- self, rows: IndexLabel, ascending: bool = True, **kwargs: Any
4343+ self, rows: IndexLabel, ascending: bool = True, axis: int = 1, **kwargs: Any
42954344 ) -> None:
42964345 """
42974346 Reorder the columns based on the lexicographic order of the given rows.
@@ -4301,6 +4350,9 @@ def sort_columns_by_row_values(
43014350 The row or rows to sort by.
43024351 ascending : bool, default: True
43034352 Sort in ascending order (True) or descending order (False).
4353+ axis: Always set to 1. Required because the decorator compares frontend
4354+ method arguments during stay_cost computation (returning COST_IMPOSSIBLE)
4355+ but examines QC method arguments when calling the wrapped method.
43044356 **kwargs : dict
43054357 Serves the compatibility purpose. Does not affect the result.
43064358
@@ -9085,6 +9137,18 @@ def cummax(
90859137 ).frame
90869138 )
90879139
9140+ @register_query_compiler_method_not_implemented(
9141+ None,
9142+ "melt",
9143+ UnsupportedArgsRule(
9144+ unsupported_conditions=[
9145+ (
9146+ lambda args: args.get("col_level") is not None,
9147+ "col_level argument is not yet supported",
9148+ ),
9149+ ]
9150+ ),
9151+ )
90889152 def melt(
90899153 self,
90909154 id_vars: list[str],
@@ -10142,6 +10206,18 @@ def align(
1014210206
1014310207 return left_qc, right_qc
1014410208
10209+ @register_query_compiler_method_not_implemented(
10210+ "DataFrame",
10211+ "apply",
10212+ UnsupportedArgsRule(
10213+ unsupported_conditions=[
10214+ (
10215+ lambda args: args.get("result_type") is not None,
10216+ "the 'result_type' parameter is not yet supported",
10217+ ),
10218+ ]
10219+ ),
10220+ )
1014510221 def apply(
1014610222 self,
1014710223 func: Union[AggFuncType, UserDefinedFunction],
@@ -10814,6 +10890,20 @@ def pivot(
1081410890 sort=True,
1081510891 )
1081610892
10893+ @register_query_compiler_method_not_implemented(
10894+ None,
10895+ "pivot_table",
10896+ UnsupportedArgsRule(
10897+ unsupported_conditions=[
10898+ ("observed", True),
10899+ ("sort", False),
10900+ (
10901+ lambda args: check_pivot_table_unsupported_args(args) is not None,
10902+ check_pivot_table_unsupported_args,
10903+ ),
10904+ ]
10905+ ),
10906+ )
1081710907 def pivot_table(
1081810908 self,
1081910909 index: Any,
@@ -10889,7 +10979,7 @@ def pivot_table(
1088910979 index = [index]
1089010980
1089110981 # TODO: SNOW-857485 Support for non-str and list of non-str for index/columns/values
10892- if index and (
10982+ if index is not None and (
1089310983 not isinstance(index, str)
1089410984 and not all([isinstance(v, str) for v in index])
1089510985 and None not in index
@@ -10898,7 +10988,7 @@ def pivot_table(
1089810988 f"Not implemented non-string of list of string {index}."
1089910989 )
1090010990
10901- if values and (
10991+ if values is not None and (
1090210992 not isinstance(values, str)
1090310993 and not all([isinstance(v, str) for v in values])
1090410994 and None not in values
@@ -10907,7 +10997,7 @@ def pivot_table(
1090710997 f"Not implemented non-string of list of string {values}."
1090810998 )
1090910999
10910- if columns and (
11000+ if columns is not None and (
1091111001 not isinstance(columns, str)
1091211002 and not all([isinstance(v, str) for v in columns])
1091311003 and None not in columns
@@ -13021,6 +13111,23 @@ def _make_fill_expression_for_column_wise_fillna(
1302113111 *columns_to_include,
1302213112 )
1302313113
13114+ @register_query_compiler_method_not_implemented(
13115+ ["DataFrame", "Series"],
13116+ "fillna",
13117+ UnsupportedArgsRule(
13118+ unsupported_conditions=[
13119+ (
13120+ lambda kwargs: kwargs.get("value") is not None
13121+ and kwargs.get("limit") is not None,
13122+ "the 'limit' parameter with 'value' parameter is not yet supported",
13123+ ),
13124+ (
13125+ lambda kwargs: kwargs.get("downcast") is not None,
13126+ "the 'downcast' parameter is not yet supported",
13127+ ),
13128+ ]
13129+ ),
13130+ )
1302413131 def fillna(
1302513132 self,
1302613133 value: Optional[Union[Hashable, Mapping, "pd.DataFrame", "pd.Series"]] = None,
@@ -13266,6 +13373,15 @@ def fillna_expr(snowflake_quoted_id: str) -> SnowparkColumn:
1326613373 ).frame
1326713374 )
1326813375
13376+ @register_query_compiler_method_not_implemented(
13377+ "DataFrame",
13378+ "dropna",
13379+ UnsupportedArgsRule(
13380+ unsupported_conditions=[
13381+ ("axis", 1),
13382+ ]
13383+ ),
13384+ )
1326913385 def dropna(
1327013386 self,
1327113387 axis: int,
@@ -22121,7 +22237,7 @@ def _stack_helper(
2212122237 return qc
2212222238
2212322239 @register_query_compiler_method_not_implemented(
22124- api_cls_name ="DataFrame",
22240+ api_cls_names ="DataFrame",
2212522241 method_name="corr",
2212622242 unsupported_args=UnsupportedArgsRule(
2212722243 unsupported_conditions=[
0 commit comments