@@ -626,7 +626,7 @@ def get_unsupported_args_reason(
626626
627627
628628def register_query_compiler_method_not_implemented(
629- api_cls_name: Optional[str],
629+ api_cls_names: Union[list[ Optional[str]], Optional[str] ],
630630 method_name: str,
631631 unsupported_args: Optional["UnsupportedArgsRule"] = None,
632632) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
@@ -644,22 +644,30 @@ def register_query_compiler_method_not_implemented(
644644 without meaningful benefit.
645645
646646 Args:
647- api_cls_name : Frontend class name (e.g., "BasePandasDataset", "Series", "DataFrame", " None") .
647+ api_cls_names : Frontend class names (e.g. "BasePandasDataset", "Series", "DataFrame", or None). It can be a list if multiple api_cls_names are needed .
648648 method_name: Method name to register.
649649 unsupported_args: UnsupportedArgsRule for args-based auto-switching.
650650 If None, method is treated as completely unimplemented.
651651 """
652- reg_key = MethodKey(api_cls_name, method_name)
653652
654- # register the method in the hybrid switch for unsupported args
655- if unsupported_args is None:
656- HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS.add(reg_key)
657- else:
658- HYBRID_SWITCH_FOR_UNSUPPORTED_ARGS[reg_key] = unsupported_args
653+ if isinstance(api_cls_names, str) or api_cls_names is None:
654+ api_cls_names = [api_cls_names]
655+ assert (
656+ api_cls_names
657+ ), "api_cls_names must be a string (e.g., 'DataFrame', 'Series') or a list of strings (e.g., ['DataFrame', 'Series']) or None for top-level functions"
659658
660- register_function_for_pre_op_switch(
661- class_name=api_cls_name, backend="Snowflake", method=method_name
662- )
659+ for api_cls_name in api_cls_names:
660+ reg_key = MethodKey(api_cls_name, method_name)
661+
662+ # register the method in the hybrid switch for unsupported args
663+ if unsupported_args is None:
664+ HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS.add(reg_key)
665+ else:
666+ HYBRID_SWITCH_FOR_UNSUPPORTED_ARGS[reg_key] = unsupported_args
667+
668+ register_function_for_pre_op_switch(
669+ class_name=api_cls_name, backend="Snowflake", method=method_name
670+ )
663671
664672 def decorator(query_compiler_method: Callable[..., Any]) -> Callable[..., Any]:
665673 @functools.wraps(query_compiler_method)
@@ -2605,6 +2613,22 @@ def _shift_index(self, periods: int, freq: Any) -> "SnowflakeQueryCompiler": #
26052613 # TODO: SNOW-1023324, implement shifting index only.
26062614 ErrorMessage.not_implemented("shifting index values not yet supported.")
26072615
2616+ @register_query_compiler_method_not_implemented(
2617+ "BasePandasDataset",
2618+ "shift",
2619+ UnsupportedArgsRule(
2620+ unsupported_conditions=[
2621+ (
2622+ lambda args: args.get("suffix") is not None,
2623+ "the 'suffix' parameter is not yet supported",
2624+ ),
2625+ (
2626+ lambda args: not isinstance(args.get("periods"), int),
2627+ "only int 'periods' is currently supported",
2628+ ),
2629+ ]
2630+ ),
2631+ )
26082632 def shift(
26092633 self,
26102634 periods: Union[int, Sequence[int]] = 1,
@@ -4192,6 +4216,19 @@ def first_last_valid_index(
41924216 )
41934217 return None
41944218
4219+ @register_query_compiler_method_not_implemented(
4220+ "BasePandasDataset",
4221+ "sort_index",
4222+ UnsupportedArgsRule(
4223+ unsupported_conditions=[
4224+ ("axis", 1),
4225+ (
4226+ lambda args: args.get("key") is not None,
4227+ "the 'key' parameter is not yet supported",
4228+ ),
4229+ ]
4230+ ),
4231+ )
41954232 def sort_index(
41964233 self,
41974234 *,
@@ -4254,7 +4291,7 @@ def sort_index(
42544291 1.0 c
42554292 dtype: object
42564293 """
4257- if axis in (1, "index") :
4294+ if axis == 1 :
42584295 ErrorMessage.not_implemented(
42594296 "sort_index is not supported yet on axis=1 in Snowpark pandas."
42604297 )
@@ -4278,8 +4315,17 @@ def sort_index(
42784315 include_indexer=include_indexer,
42794316 )
42804317
4318+ @register_query_compiler_method_not_implemented(
4319+ "BasePandasDataset",
4320+ "sort_values",
4321+ UnsupportedArgsRule(
4322+ unsupported_conditions=[
4323+ ("axis", 1),
4324+ ]
4325+ ),
4326+ )
42814327 def sort_columns_by_row_values(
4282- self, rows: IndexLabel, ascending: bool = True, **kwargs: Any
4328+ self, rows: IndexLabel, ascending: bool = True, axis: int = 1, **kwargs: Any
42834329 ) -> None:
42844330 """
42854331 Reorder the columns based on the lexicographic order of the given rows.
@@ -4289,6 +4335,9 @@ def sort_columns_by_row_values(
42894335 The row or rows to sort by.
42904336 ascending : bool, default: True
42914337 Sort in ascending order (True) or descending order (False).
4338+ axis: Always set to 1. Required because the decorator compares frontend
4339+ method arguments during stay_cost computation (returning COST_IMPOSSIBLE)
4340+ but examines QC method arguments when calling the wrapped method.
42924341 **kwargs : dict
42934342 Serves the compatibility purpose. Does not affect the result.
42944343
@@ -9007,6 +9056,18 @@ def cummax(
90079056 ).frame
90089057 )
90099058
9059+ @register_query_compiler_method_not_implemented(
9060+ None,
9061+ "melt",
9062+ UnsupportedArgsRule(
9063+ unsupported_conditions=[
9064+ (
9065+ lambda args: args.get("col_level") is not None,
9066+ "col_level argument is not yet supported",
9067+ ),
9068+ ]
9069+ ),
9070+ )
90109071 def melt(
90119072 self,
90129073 id_vars: list[str],
@@ -10054,6 +10115,18 @@ def align(
1005410115
1005510116 return left_qc, right_qc
1005610117
10118+ @register_query_compiler_method_not_implemented(
10119+ "DataFrame",
10120+ "apply",
10121+ UnsupportedArgsRule(
10122+ unsupported_conditions=[
10123+ (
10124+ lambda args: args.get("result_type") is not None,
10125+ "the 'result_type' parameter is not yet supported",
10126+ ),
10127+ ]
10128+ ),
10129+ )
1005710130 def apply(
1005810131 self,
1005910132 func: Union[AggFuncType, UserDefinedFunction],
@@ -10726,6 +10799,63 @@ def pivot(
1072610799 sort=True,
1072710800 )
1072810801
10802+ @register_query_compiler_method_not_implemented(
10803+ None,
10804+ "pivot_table",
10805+ UnsupportedArgsRule(
10806+ unsupported_conditions=[
10807+ ("sort", False),
10808+ (
10809+ lambda args: (
10810+ args.get("index") is not None
10811+ and (
10812+ not isinstance(args.get("index"), str)
10813+ and not all([isinstance(v, str) for v in args.get("index")])
10814+ and None not in args.get("index")
10815+ )
10816+ ),
10817+ "non-string of list of string index is not yet supported for pivot_table",
10818+ ),
10819+ (
10820+ lambda args: (
10821+ args.get("columns") is not None
10822+ and (
10823+ not isinstance(args.get("columns"), str)
10824+ and not all(
10825+ [isinstance(v, str) for v in args.get("columns")]
10826+ )
10827+ and None not in args.get("columns")
10828+ )
10829+ ),
10830+ "non-string of list of string columns is not yet supported for pivot_table",
10831+ ),
10832+ (
10833+ lambda args: (
10834+ args.get("values") is not None
10835+ and (
10836+ not isinstance(args.get("values"), str)
10837+ and not all(
10838+ [isinstance(v, str) for v in args.get("values")]
10839+ )
10840+ and None not in args.get("values")
10841+ )
10842+ ),
10843+ "non-string of list of string values is not yet supported for pivot_table",
10844+ ),
10845+ (
10846+ lambda args: (
10847+ isinstance(args.get("aggfunc"), dict)
10848+ and any(
10849+ not isinstance(af, str)
10850+ for af in args.get("aggfunc").values()
10851+ )
10852+ and args.get("index") is None
10853+ ),
10854+ "dictionary aggfunc with non-string aggregation functions is not yet supported for pivot_table with margins or when index is None",
10855+ ),
10856+ ]
10857+ ),
10858+ )
1072910859 def pivot_table(
1073010860 self,
1073110861 index: Any,
@@ -10801,7 +10931,7 @@ def pivot_table(
1080110931 index = [index]
1080210932
1080310933 # TODO: SNOW-857485 Support for non-str and list of non-str for index/columns/values
10804- if index and (
10934+ if index is not None and (
1080510935 not isinstance(index, str)
1080610936 and not all([isinstance(v, str) for v in index])
1080710937 and None not in index
@@ -10810,7 +10940,7 @@ def pivot_table(
1081010940 f"Not implemented non-string of list of string {index}."
1081110941 )
1081210942
10813- if values and (
10943+ if values is not None and (
1081410944 not isinstance(values, str)
1081510945 and not all([isinstance(v, str) for v in values])
1081610946 and None not in values
@@ -10819,7 +10949,7 @@ def pivot_table(
1081910949 f"Not implemented non-string of list of string {values}."
1082010950 )
1082110951
10822- if columns and (
10952+ if columns is not None and (
1082310953 not isinstance(columns, str)
1082410954 and not all([isinstance(v, str) for v in columns])
1082510955 and None not in columns
@@ -12920,6 +13050,23 @@ def _make_fill_expression_for_column_wise_fillna(
1292013050 *columns_to_include,
1292113051 )
1292213052
13053+ @register_query_compiler_method_not_implemented(
13054+ ["DataFrame", "Series"],
13055+ "fillna",
13056+ UnsupportedArgsRule(
13057+ unsupported_conditions=[
13058+ (
13059+ lambda kwargs: kwargs.get("value") is not None
13060+ and kwargs.get("limit") is not None,
13061+ "the 'limit' parameter with 'value' parameter is not yet supported",
13062+ ),
13063+ (
13064+ lambda kwargs: kwargs.get("downcast") is not None,
13065+ "the 'downcast' parameter is not yet supported",
13066+ ),
13067+ ]
13068+ ),
13069+ )
1292313070 def fillna(
1292413071 self,
1292513072 value: Optional[Union[Hashable, Mapping, "pd.DataFrame", "pd.Series"]] = None,
@@ -13165,6 +13312,15 @@ def fillna_expr(snowflake_quoted_id: str) -> SnowparkColumn:
1316513312 ).frame
1316613313 )
1316713314
13315+ @register_query_compiler_method_not_implemented(
13316+ "DataFrame",
13317+ "dropna",
13318+ UnsupportedArgsRule(
13319+ unsupported_conditions=[
13320+ ("axis", 1),
13321+ ]
13322+ ),
13323+ )
1316813324 def dropna(
1316913325 self,
1317013326 axis: int,
@@ -21862,7 +22018,7 @@ def _stack_helper(
2186222018 return qc
2186322019
2186422020 @register_query_compiler_method_not_implemented(
21865- api_cls_name ="DataFrame",
22021+ api_cls_names ="DataFrame",
2186622022 method_name="corr",
2186722023 unsupported_args=UnsupportedArgsRule(
2186822024 unsupported_conditions=[
0 commit comments