Skip to content

Commit 1585def

Browse files
authored
SNOW-2022995: Enable automatic backend switching for unimplemented pandas methods (#3512)
When AutoSwitchBackend (hybrid execution PrPr) is enabled, the client will automatically switch to the pandas backend when a method registered via a register_*_not_implemented annotation is called.
1 parent a80973d commit 1585def

File tree

8 files changed

+111
-8
lines changed

8 files changed

+111
-8
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ local ingestion. By default, local ingestion uses multithreading. Multiprocessin
4747
- Support `pd.read_snowflake` when the global modin backend is `Pandas`.
4848

4949
#### Improvements
50-
- Add a data type guard to the cost functions for hybrid execution which checks for data type compatibility
50+
- Add a data type guard to the cost functions for hybrid execution mode (PrPr) which checks for data type compatibility.
51+
- Added automatic switching to the pandas backend in hybrid execution mode (PrPr) for many methods that are not directly implemented in Snowpark pandas.
5152

5253
#### Dependency Updates
5354
- Added tqdm and ipywidgets as dependencies so that progress bars appear when switching between modin backends.

docs/source/modin/hybrid_execution.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@ Example Pre-Operation Switchpoints:
1919
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2020
apply, iterrows, itertuples, items, plot, quantile, __init__, plot, quantile, T, read_csv, read_json, concat, merge
2121

22+
Many methods that are not yet implemented in Snowpark pandas are also registered as
23+
pre-operation switch points, and will automatically move data to local pandas for execution when
24+
called. This includes most methods that are ordinarily completely unsupported by Snowpark pandas,
25+
and have `N` in their implemented status in the :doc:`DataFrame <supported/dataframe_supported>` and
26+
:doc:`Series <supported/series_supported>` supported API lists.
27+
2228
Post-Operation Switchpoints:
2329
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2430
read_snowflake, value_counts, tail, var, std, sum, sem, max, min, mean, agg, aggregate, count, nunique, cummax, cummin, cumprod, cumsum

src/snowflake/snowpark/modin/plugin/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,9 @@
185185
AutoSwitchBackend().disable()
186186

187187
# Hybrid Mode Registration
188+
# In hybrid execution mode, the client will automatically switch backends when a
189+
# wholly-unimplemented method is called. Those switch points are registered separately in
190+
# extensions files via the register_*_not_implemented family of methods.
188191
pre_op_switch_points: list[dict[str, Union[str, None]]] = [
189192
{"class_name": "DataFrame", "method": "__init__"},
190193
{"class_name": "Series", "method": "__init__"},

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,18 @@
1818
from datetime import timedelta, tzinfo
1919
from functools import reduce
2020
from types import MappingProxyType
21-
from typing import Any, Callable, List, Literal, Optional, TypeVar, Union, get_args
21+
from typing import (
22+
Any,
23+
Callable,
24+
List,
25+
Literal,
26+
Optional,
27+
TypeVar,
28+
Union,
29+
get_args,
30+
Set,
31+
Tuple,
32+
)
2233

2334
import modin.pandas as pd
2435
from modin.pandas import Series, DataFrame
@@ -481,6 +492,9 @@
481492
HYBRID_ALL_EXPENSIVE_METHODS = (
482493
HYBRID_HIGH_OVERHEAD_METHODS + HYBRID_ITERATIVE_STYLE_METHODS
483494
)
495+
# Set of (class name, method name) tuples for methods that are wholly unimplemented by
496+
# Snowpark pandas. This list is populated by the register_*_not_implemented decorators.
497+
HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS: Set[Tuple[str, str]] = set()
484498

485499
T = TypeVar("T", bound=Callable[..., Any])
486500

@@ -858,7 +872,10 @@ def stay_cost(
858872
operation: str,
859873
arguments: MappingProxyType[str, Any],
860874
) -> Optional[int]:
861-
if self._is_in_memory_init(api_cls_name, operation, arguments):
875+
if (
876+
self._is_in_memory_init(api_cls_name, operation, arguments)
877+
or (api_cls_name, operation) in HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS
878+
):
862879
return QCCoercionCost.COST_IMPOSSIBLE
863880
# Strongly discourage the use of these methods in snowflake
864881
if operation in HYBRID_ALL_EXPENSIVE_METHODS:
@@ -891,9 +908,11 @@ def move_to_me_cost(
891908
None if the cost cannot be determined.
892909
"""
893910
# in-memory intialization should not move to Snowflake
894-
if cls._is_in_memory_init(api_cls_name, operation, arguments):
895-
return QCCoercionCost.COST_IMPOSSIBLE
896-
if not cls._are_dtypes_compatible_with_snowflake(other_qc):
911+
if (
912+
cls._is_in_memory_init(api_cls_name, operation, arguments)
913+
or not cls._are_dtypes_compatible_with_snowflake(other_qc)
914+
or (api_cls_name, operation) in HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS
915+
):
897916
return QCCoercionCost.COST_IMPOSSIBLE
898917
# Strongly discourage the use of these methods in snowflake
899918
if operation in HYBRID_ALL_EXPENSIVE_METHODS:

src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@
6868
MODIN_IS_AT_LEAST_0_34_0,
6969
)
7070
from snowflake.snowpark.modin.plugin._typing import ListLike
71+
from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
72+
HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS,
73+
)
7174
from snowflake.snowpark.modin.plugin.extensions.utils import (
7275
ensure_index,
7376
extract_validate_and_try_convert_named_aggs_from_kwargs,
@@ -91,14 +94,23 @@
9194

9295
if MODIN_IS_AT_LEAST_0_33_0:
9396
from modin.pandas.api.extensions import register_base_accessor
97+
from modin.core.storage_formats.pandas.query_compiler_caster import (
98+
register_function_for_pre_op_switch,
99+
)
94100

95101
register_base_override = functools.partial(
96102
register_base_accessor, backend="Snowflake"
97103
)
98104

99105
def register_base_not_implemented():
100106
def decorator(base_method: Any):
101-
return register_base_override(name=base_method.__name__)(
107+
name = base_method.__name__
108+
HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS.add(("BasePandasDataset", name))
109+
if MODIN_IS_AT_LEAST_0_33_0:
110+
register_function_for_pre_op_switch(
111+
class_name="BasePandasDataset", backend="Snowflake", method=name
112+
)
113+
return register_base_override(name=name)(
102114
base_not_implemented()(base_method)
103115
)
104116

src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
from snowflake.snowpark.modin.plugin._typing import ListLike
8686
from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
8787
SnowflakeQueryCompiler,
88+
HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS,
8889
)
8990
from snowflake.snowpark.modin.plugin.extensions.index import Index
9091
from snowflake.snowpark.modin.plugin.extensions.snow_partition_iterator import (
@@ -123,6 +124,9 @@
123124
from modin.pandas.api.extensions import (
124125
register_dataframe_accessor as _register_dataframe_accessor,
125126
)
127+
from modin.core.storage_formats.pandas.query_compiler_caster import (
128+
register_function_for_pre_op_switch,
129+
)
126130

127131
register_dataframe_accessor = functools.partial(
128132
_register_dataframe_accessor, backend="Snowflake"
@@ -138,7 +142,13 @@
138142
def register_dataframe_not_implemented():
139143
def decorator(base_method: Any):
140144
func = dataframe_not_implemented()(base_method)
141-
register_dataframe_accessor(base_method.__name__)(func)
145+
name = base_method.__name__
146+
HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS.add(("DataFrame", name))
147+
if MODIN_IS_AT_LEAST_0_33_0:
148+
register_function_for_pre_op_switch(
149+
class_name="DataFrame", backend="Snowflake", method=name
150+
)
151+
register_dataframe_accessor(name)(func)
142152
return func
143153

144154
return decorator

src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@
5050
error_checking_for_init,
5151
MODIN_IS_AT_LEAST_0_33_0,
5252
)
53+
from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
54+
HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS,
55+
)
5356
from snowflake.snowpark.modin.plugin._typing import DropKeep, ListLike
5457
from snowflake.snowpark.modin.plugin.extensions.snow_partition_iterator import (
5558
SnowparkPandasRowPartitionIterator,
@@ -80,6 +83,9 @@
8083
from modin.pandas.api.extensions import (
8184
register_series_accessor as _register_series_accessor,
8285
)
86+
from modin.core.storage_formats.pandas.query_compiler_caster import (
87+
register_function_for_pre_op_switch,
88+
)
8389

8490
register_series_accessor = functools.partial(
8591
_register_series_accessor, backend="Snowflake"
@@ -96,6 +102,11 @@ def decorator(base_method: Any):
96102
if isinstance(base_method, property)
97103
else base_method.__name__
98104
)
105+
HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS.add(("Series", name))
106+
if MODIN_IS_AT_LEAST_0_33_0:
107+
register_function_for_pre_op_switch(
108+
class_name="Series", backend="Snowflake", method=name
109+
)
99110
register_series_accessor(name)(func)
100111
return func
101112

tests/integ/modin/hybrid/test_switch_operations.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
from unittest.mock import patch
99
import tqdm.auto
1010

11+
import pandas as native_pd
1112
import numpy as np
1213
from numpy.testing import assert_array_equal
14+
from modin.config import context as config_context
1315
import modin.pandas as pd
1416
import snowflake.snowpark.modin.plugin # noqa: F401
1517
from snowflake.snowpark.modin.plugin._internal.utils import (
@@ -265,3 +267,42 @@ def test_tqdm_usage_during_snowflake_to_pandas_switch():
265267
df.set_backend("Pandas")
266268

267269
mock_trange.assert_called_once()
270+
271+
272+
@pytest.mark.parametrize(
273+
"class_name, method_name, f_args",
274+
[
275+
("DataFrame", "to_json", ()), # declared in base_overrides
276+
("Series", "to_json", ()), # declared in base_overrides
277+
("DataFrame", "dot", ([6],)), # declared in dataframe_overrides
278+
("Series", "transform", (lambda x: x * 2,)), # declared in series_overrides
279+
],
280+
)
281+
@sql_count_checker(query_count=1)
282+
def test_unimplemented_autoswitches(class_name, method_name, f_args):
283+
# Unimplemented methods declared via register_*_not_implemented should automatically
284+
# default to local pandas execution.
285+
# This test needs to be modified if any of the APIs in question are ever natively implemented
286+
# for Snowpark pandas.
287+
data = [1, 2, 3]
288+
method = getattr(getattr(pd, class_name)(data).move_to("Snowflake"), method_name)
289+
# Attempting to call the method without switching should raise.
290+
with config_context(AutoSwitchBackend=False):
291+
with pytest.raises(
292+
NotImplementedError, match="Snowpark pandas does not yet support the method"
293+
):
294+
method(*f_args)
295+
# Attempting to call the method while switching is enabled should work fine.
296+
snow_result = method(*f_args)
297+
pandas_result = getattr(getattr(native_pd, class_name)(data), method_name)(*f_args)
298+
if isinstance(snow_result, (pd.DataFrame, pd.Series)):
299+
assert snow_result.get_backend() == "Pandas"
300+
assert_array_equal(snow_result.to_numpy(), pandas_result.to_numpy())
301+
else:
302+
# Series.to_json will output an extraneous level for the __reduced__ column, but that's OK
303+
# since we don't officially support the method.
304+
# See modin bug: https://github.com/modin-project/modin/issues/7624
305+
if class_name == "Series" and method_name == "to_json":
306+
assert snow_result == '{"__reduced__":{"0":1,"1":2,"2":3}}'
307+
else:
308+
assert snow_result == pandas_result

0 commit comments

Comments
 (0)