@@ -626,7 +626,7 @@ def get_unsupported_args_reason(
626626
627627
628628def register_query_compiler_method_not_implemented(
629- api_cls_names: Union[list[ Optional[str]], Optional[str] ],
629+ api_cls_name: Optional[str],
630630 method_name: str,
631631 unsupported_args: Optional["UnsupportedArgsRule"] = None,
632632) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
@@ -644,30 +644,22 @@ def register_query_compiler_method_not_implemented(
644644 without meaningful benefit.
645645
646646 Args:
647- api_cls_names : Frontend class names (e.g. "BasePandasDataset", "Series", "DataFrame", or None). It can be a list if multiple api_cls_names are needed .
647+ api_cls_name : Frontend class name (e.g., "BasePandasDataset", "Series", "DataFrame", " None") .
648648 method_name: Method name to register.
649649 unsupported_args: UnsupportedArgsRule for args-based auto-switching.
650650 If None, method is treated as completely unimplemented.
651651 """
652+ reg_key = MethodKey(api_cls_name, method_name)
652653
653- if isinstance(api_cls_names, str) or api_cls_names is None:
654- api_cls_names = [api_cls_names]
655- assert (
656- api_cls_names
657- ), "api_cls_names must be a string (e.g., 'DataFrame', 'Series') or a list of strings (e.g., ['DataFrame', 'Series']) or None for top-level functions"
658-
659- for api_cls_name in api_cls_names:
660- reg_key = MethodKey(api_cls_name, method_name)
661-
662- # register the method in the hybrid switch for unsupported args
663- if unsupported_args is None:
664- HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS.add(reg_key)
665- else:
666- HYBRID_SWITCH_FOR_UNSUPPORTED_ARGS[reg_key] = unsupported_args
654+ # register the method in the hybrid switch for unsupported args
655+ if unsupported_args is None:
656+ HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS.add(reg_key)
657+ else:
658+ HYBRID_SWITCH_FOR_UNSUPPORTED_ARGS[reg_key] = unsupported_args
667659
668- register_function_for_pre_op_switch(
669- class_name=api_cls_name, backend="Snowflake", method=method_name
670- )
660+ register_function_for_pre_op_switch(
661+ class_name=api_cls_name, backend="Snowflake", method=method_name
662+ )
671663
672664 def decorator(query_compiler_method: Callable[..., Any]) -> Callable[..., Any]:
673665 @functools.wraps(query_compiler_method)
@@ -2613,22 +2605,6 @@ def _shift_index(self, periods: int, freq: Any) -> "SnowflakeQueryCompiler": #
26132605 # TODO: SNOW-1023324, implement shifting index only.
26142606 ErrorMessage.not_implemented("shifting index values not yet supported.")
26152607
2616- @register_query_compiler_method_not_implemented(
2617- "BasePandasDataset",
2618- "shift",
2619- UnsupportedArgsRule(
2620- unsupported_conditions=[
2621- (
2622- lambda args: args.get("suffix") is not None,
2623- "the 'suffix' parameter is not yet supported",
2624- ),
2625- (
2626- lambda args: not isinstance(args.get("periods"), int),
2627- "only int 'periods' is currently supported",
2628- ),
2629- ]
2630- ),
2631- )
26322608 def shift(
26332609 self,
26342610 periods: Union[int, Sequence[int]] = 1,
@@ -4216,19 +4192,6 @@ def first_last_valid_index(
42164192 )
42174193 return None
42184194
4219- @register_query_compiler_method_not_implemented(
4220- "BasePandasDataset",
4221- "sort_index",
4222- UnsupportedArgsRule(
4223- unsupported_conditions=[
4224- ("axis", 1),
4225- (
4226- lambda args: args.get("key") is not None,
4227- "the 'key' parameter is not yet supported",
4228- ),
4229- ]
4230- ),
4231- )
42324195 def sort_index(
42334196 self,
42344197 *,
@@ -4291,7 +4254,7 @@ def sort_index(
42914254 1.0 c
42924255 dtype: object
42934256 """
4294- if axis == 1 :
4257+ if axis in (1, "index") :
42954258 ErrorMessage.not_implemented(
42964259 "sort_index is not supported yet on axis=1 in Snowpark pandas."
42974260 )
@@ -4315,17 +4278,8 @@ def sort_index(
43154278 include_indexer=include_indexer,
43164279 )
43174280
4318- @register_query_compiler_method_not_implemented(
4319- "BasePandasDataset",
4320- "sort_values",
4321- UnsupportedArgsRule(
4322- unsupported_conditions=[
4323- ("axis", 1),
4324- ]
4325- ),
4326- )
43274281 def sort_columns_by_row_values(
4328- self, rows: IndexLabel, ascending: bool = True, axis: int = 1, **kwargs: Any
4282+ self, rows: IndexLabel, ascending: bool = True, **kwargs: Any
43294283 ) -> None:
43304284 """
43314285 Reorder the columns based on the lexicographic order of the given rows.
@@ -4335,9 +4289,6 @@ def sort_columns_by_row_values(
43354289 The row or rows to sort by.
43364290 ascending : bool, default: True
43374291 Sort in ascending order (True) or descending order (False).
4338- axis: Always set to 1. Required because the decorator compares frontend
4339- method arguments during stay_cost computation (returning COST_IMPOSSIBLE)
4340- but examines QC method arguments when calling the wrapped method.
43414292 **kwargs : dict
43424293 Serves the compatibility purpose. Does not affect the result.
43434294
@@ -4935,52 +4886,6 @@ def convert_func_to_agg_func_info(
49354886 return query_compiler if as_index else query_compiler.reset_index(drop=drop)
49364887
49374888 def groupby_apply(
4938- self,
4939- by: Any,
4940- agg_func: AggFuncType,
4941- axis: int,
4942- groupby_kwargs: dict[str, Any],
4943- agg_args: Any,
4944- agg_kwargs: dict[str, Any],
4945- series_groupby: bool,
4946- include_groups: bool,
4947- force_single_group: bool = False,
4948- force_list_like_to_series: bool = False,
4949- ) -> "SnowflakeQueryCompiler":
4950- """
4951- Wrapper around _groupby_apply_internal to be supported in faster pandas.
4952- """
4953- relaxed_query_compiler = None
4954- if self._relaxed_query_compiler is not None:
4955- relaxed_query_compiler = (
4956- self._relaxed_query_compiler._groupby_apply_internal(
4957- by=by,
4958- agg_func=agg_func,
4959- axis=axis,
4960- groupby_kwargs=groupby_kwargs,
4961- agg_args=agg_args,
4962- agg_kwargs=agg_kwargs,
4963- series_groupby=series_groupby,
4964- include_groups=include_groups,
4965- force_single_group=force_single_group,
4966- force_list_like_to_series=force_list_like_to_series,
4967- )
4968- )
4969- qc = self._groupby_apply_internal(
4970- by=by,
4971- agg_func=agg_func,
4972- axis=axis,
4973- groupby_kwargs=groupby_kwargs,
4974- agg_args=agg_args,
4975- agg_kwargs=agg_kwargs,
4976- series_groupby=series_groupby,
4977- include_groups=include_groups,
4978- force_single_group=force_single_group,
4979- force_list_like_to_series=force_list_like_to_series,
4980- )
4981- return self._maybe_set_relaxed_qc(qc, relaxed_query_compiler)
4982-
4983- def _groupby_apply_internal(
49844889 self,
49854890 by: Any,
49864891 agg_func: Callable,
@@ -9056,18 +8961,6 @@ def cummax(
90568961 ).frame
90578962 )
90588963
9059- @register_query_compiler_method_not_implemented(
9060- None,
9061- "melt",
9062- UnsupportedArgsRule(
9063- unsupported_conditions=[
9064- (
9065- lambda args: args.get("col_level") is not None,
9066- "col_level argument is not yet supported",
9067- ),
9068- ]
9069- ),
9070- )
90718964 def melt(
90728965 self,
90738966 id_vars: list[str],
@@ -10115,18 +10008,6 @@ def align(
1011510008
1011610009 return left_qc, right_qc
1011710010
10118- @register_query_compiler_method_not_implemented(
10119- "DataFrame",
10120- "apply",
10121- UnsupportedArgsRule(
10122- unsupported_conditions=[
10123- (
10124- lambda args: args.get("result_type") is not None,
10125- "the 'result_type' parameter is not yet supported",
10126- ),
10127- ]
10128- ),
10129- )
1013010011 def apply(
1013110012 self,
1013210013 func: Union[AggFuncType, UserDefinedFunction],
@@ -10799,63 +10680,6 @@ def pivot(
1079910680 sort=True,
1080010681 )
1080110682
10802- @register_query_compiler_method_not_implemented(
10803- None,
10804- "pivot_table",
10805- UnsupportedArgsRule(
10806- unsupported_conditions=[
10807- ("sort", False),
10808- (
10809- lambda args: (
10810- args.get("index") is not None
10811- and (
10812- not isinstance(args.get("index"), str)
10813- and not all([isinstance(v, str) for v in args.get("index")])
10814- and None not in args.get("index")
10815- )
10816- ),
10817- "non-string of list of string index is not yet supported for pivot_table",
10818- ),
10819- (
10820- lambda args: (
10821- args.get("columns") is not None
10822- and (
10823- not isinstance(args.get("columns"), str)
10824- and not all(
10825- [isinstance(v, str) for v in args.get("columns")]
10826- )
10827- and None not in args.get("columns")
10828- )
10829- ),
10830- "non-string of list of string columns is not yet supported for pivot_table",
10831- ),
10832- (
10833- lambda args: (
10834- args.get("values") is not None
10835- and (
10836- not isinstance(args.get("values"), str)
10837- and not all(
10838- [isinstance(v, str) for v in args.get("values")]
10839- )
10840- and None not in args.get("values")
10841- )
10842- ),
10843- "non-string of list of string values is not yet supported for pivot_table",
10844- ),
10845- (
10846- lambda args: (
10847- isinstance(args.get("aggfunc"), dict)
10848- and any(
10849- not isinstance(af, str)
10850- for af in args.get("aggfunc").values()
10851- )
10852- and args.get("index") is None
10853- ),
10854- "dictionary aggfunc with non-string aggregation functions is not yet supported for pivot_table with margins or when index is None",
10855- ),
10856- ]
10857- ),
10858- )
1085910683 def pivot_table(
1086010684 self,
1086110685 index: Any,
@@ -10931,7 +10755,7 @@ def pivot_table(
1093110755 index = [index]
1093210756
1093310757 # TODO: SNOW-857485 Support for non-str and list of non-str for index/columns/values
10934- if index is not None and (
10758+ if index and (
1093510759 not isinstance(index, str)
1093610760 and not all([isinstance(v, str) for v in index])
1093710761 and None not in index
@@ -10940,7 +10764,7 @@ def pivot_table(
1094010764 f"Not implemented non-string of list of string {index}."
1094110765 )
1094210766
10943- if values is not None and (
10767+ if values and (
1094410768 not isinstance(values, str)
1094510769 and not all([isinstance(v, str) for v in values])
1094610770 and None not in values
@@ -10949,7 +10773,7 @@ def pivot_table(
1094910773 f"Not implemented non-string of list of string {values}."
1095010774 )
1095110775
10952- if columns is not None and (
10776+ if columns and (
1095310777 not isinstance(columns, str)
1095410778 and not all([isinstance(v, str) for v in columns])
1095510779 and None not in columns
@@ -13050,23 +12874,6 @@ def _make_fill_expression_for_column_wise_fillna(
1305012874 *columns_to_include,
1305112875 )
1305212876
13053- @register_query_compiler_method_not_implemented(
13054- ["DataFrame", "Series"],
13055- "fillna",
13056- UnsupportedArgsRule(
13057- unsupported_conditions=[
13058- (
13059- lambda kwargs: kwargs.get("value") is not None
13060- and kwargs.get("limit") is not None,
13061- "the 'limit' parameter with 'value' parameter is not yet supported",
13062- ),
13063- (
13064- lambda kwargs: kwargs.get("downcast") is not None,
13065- "the 'downcast' parameter is not yet supported",
13066- ),
13067- ]
13068- ),
13069- )
1307012877 def fillna(
1307112878 self,
1307212879 value: Optional[Union[Hashable, Mapping, "pd.DataFrame", "pd.Series"]] = None,
@@ -13312,15 +13119,6 @@ def fillna_expr(snowflake_quoted_id: str) -> SnowparkColumn:
1331213119 ).frame
1331313120 )
1331413121
13315- @register_query_compiler_method_not_implemented(
13316- "DataFrame",
13317- "dropna",
13318- UnsupportedArgsRule(
13319- unsupported_conditions=[
13320- ("axis", 1),
13321- ]
13322- ),
13323- )
1332413122 def dropna(
1332513123 self,
1332613124 axis: int,
@@ -22018,7 +21816,7 @@ def _stack_helper(
2201821816 return qc
2201921817
2202021818 @register_query_compiler_method_not_implemented(
22021- api_cls_names ="DataFrame",
21819+ api_cls_name ="DataFrame",
2202221820 method_name="corr",
2202321821 unsupported_args=UnsupportedArgsRule(
2202421822 unsupported_conditions=[
0 commit comments