Commit 637905b

SNOW-1805851: Add scikit-learn interoperability tests. (#2796)
- Move the existing scikit-learn interoperability test into the `interoperability` test folder.
- Refactor the test file so that it groups test cases into scikit-learn method categories, like classification and regression.
- Add test cases so we have one test case for each category.
- Remove the scikit-learn version pin for the modin dev environment. Once we implemented the interchange protocol, we no longer needed that pin.
- Document the support for the methods that we've tested.
- Reorganize the interoperability documentation page so that it follows the rst guidelines for sections and subsections.

Signed-off-by: sfc-gh-mvashishtha <[email protected]>
Co-authored-by: Hazem Elmeleegy <[email protected]>
1 parent 20cfeed commit 637905b

File tree: 5 files changed, +313 −31 lines

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -50,6 +50,7 @@
 - Updated integration testing for `session.lineage.trace` to exclude deleted objects
 - Added documentation for `DataFrame.map`.
 - Improve performance of `DataFrame.apply` by mapping numpy functions to snowpark functions if possible.
+- Added documentation on the extent of Snowpark pandas interoperability with scikit-learn
 
 ## 1.26.0 (2024-12-05)
 

docs/source/modin/interoperability.rst

Lines changed: 100 additions & 5 deletions
@@ -1,5 +1,6 @@
+===========================================
 Interoperability with third party libraries
-=============================================
+===========================================
 
 Many third party libraries are interoperable with pandas, for example by accepting pandas dataframes objects as function
 inputs. Here we have a non-exhaustive list of third party library use cases with pandas and note whether each method
@@ -8,15 +9,17 @@ works in Snowpark pandas as well.
 Snowpark pandas supports the `dataframe interchange protocol <https://data-apis.org/dataframe-protocol/latest/>`_, which
 some libraries use to interoperate with Snowpark pandas to the same level of support as pandas.
 
-The following table is structured as follows: The first column contains a method name.
+plotly.express
+==============
+
+The following table is structured as follows: The first column contains the name of a method in the ``plotly.express`` module.
 The second column is a flag for whether or not interoperability is guaranteed with Snowpark pandas. For each of these
-methods, we validate that passing in a Snowpark pandas dataframe as the dataframe input parameter behaves equivalently
-to passing in a pandas dataframe.
+operations, we validate that passing in Snowpark pandas dataframes or series as the data inputs behaves equivalently
+to passing in pandas dataframes or series.
 
 .. note::
     ``Y`` stands for yes, i.e., interoperability is guaranteed with this method, and ``N`` stands for no.
 
-Plotly.express module methods
 
 .. note::
     Currently only plotly versions <6.0.0 are supported through the dataframe interchange protocol.
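(Editorial illustration, not part of the diff: a minimal sketch of passing a Snowpark pandas dataframe directly to plotly.express, as documented above. The column names and data are invented, a configured Snowpark session is assumed, and plotly must be <6.0.0 per the note above.)

import modin.pandas as pd
import plotly.express as px
import snowflake.snowpark.modin.plugin  # noqa: F401

# Hypothetical Snowpark pandas dataframe; in practice the data would usually
# come from a Snowflake table rather than local values.
df = pd.DataFrame({"x": [1, 2, 3, 4], "y": [10, 20, 15, 25]})

# plotly.express consumes the Snowpark pandas dataframe through the dataframe
# interchange protocol, just as it would a native pandas dataframe.
fig = px.scatter(df, x="x", y="y")
fig.show()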
@@ -56,3 +59,95 @@ Plotly.express module methods
 +-------------------------+---------------------------------------------+--------------------------------------------+
 | ``imshow`` | Y | |
 +-------------------------+---------------------------------------------+--------------------------------------------+
+
+
+scikit-learn
+============
+
+We break down scikit-learn interoperability by categories of scikit-learn
+operations.
+
+For each category, we provide a table of interoperability with the following
+structure: The first column describes a scikit-learn operation that may include
+multiple method calls. The second column is a flag for whether or not
+interoperability is guaranteed with Snowpark pandas. For each of these methods,
+we validate that passing in Snowpark pandas objects behaves equivalently to
+passing in pandas objects.
+
+.. note::
+    ``Y`` stands for yes, i.e., interoperability is guaranteed with this method, and ``N`` stands for no.
+
+.. note::
+    While some scikit-learn methods accept Snowpark pandas inputs, their
+    performance with Snowpark pandas inputs is often much worse than their
+    performance with native pandas inputs. Generally we recommend converting
+    Snowpark pandas inputs to pandas with ``to_pandas()`` before passing them
+    to scikit-learn.
+
+
+Classification
+--------------
+
++--------------------------------------------+---------------------------------------------+---------------------------------+
+| Operation | Interoperable with Snowpark pandas? (Y/N) | Notes for current implementation|
++--------------------------------------------+---------------------------------------------+---------------------------------+
+| Fitting a ``LinearDiscriminantAnalysis`` | Y | |
+| classifier with the ``fit()`` method and | | |
+| classifying data with the ``predict()`` | | |
+| method. | | |
++--------------------------------------------+---------------------------------------------+---------------------------------+
+
+
+Regression
+----------
+
++--------------------------------------------+---------------------------------------------+---------------------------------+
+| Operation | Interoperable with Snowpark pandas? (Y/N) | Notes for current implementation|
++--------------------------------------------+---------------------------------------------+---------------------------------+
+| Fitting a ``LogisticRegression`` model | Y | |
+| with the ``fit()`` method and predicting | | |
+| results with the ``predict()`` method. | | |
++--------------------------------------------+---------------------------------------------+---------------------------------+
+
+Clustering
+----------
+
++--------------------------------------------+---------------------------------------------+---------------------------------+
+| Clustering method | Interoperable with Snowpark pandas? (Y/N) | Notes for current implementation|
++--------------------------------------------+---------------------------------------------+---------------------------------+
+| ``KMeans.fit()`` | Y | |
++--------------------------------------------+---------------------------------------------+---------------------------------+
+
+
+Dimensionality reduction
+------------------------
+
++--------------------------------------------+---------------------------------------------+---------------------------------+
+| Operation | Interoperable with Snowpark pandas? (Y/N) | Notes for current implementation|
++--------------------------------------------+---------------------------------------------+---------------------------------+
+| Getting the principal components of a | Y | |
+| numerical dataset with ``PCA.fit()``. | | |
++--------------------------------------------+---------------------------------------------+---------------------------------+
+
+
+Model selection
+------------------------
+
++--------------------------------------------+---------------------------------------------+-----------------------------------------------+
+| Operation | Interoperable with Snowpark pandas? (Y/N) | Notes for current implementation |
++--------------------------------------------+---------------------------------------------+-----------------------------------------------+
+| Choosing parameters for a | Y | ``RandomizedSearchCV`` causes Snowpark pandas |
+| ``LogisticRegression`` model with | | to issue many queries. We strongly recommend |
+| ``RandomizedSearchCV.fit()``. | | converting Snowpark pandas inputs to pandas |
+| | | before using ``RandomizedSearchCV`` |
++--------------------------------------------+---------------------------------------------+-----------------------------------------------+
+
+Preprocessing
+-------------
+
++--------------------------------------------+---------------------------------------------+-----------------------------------------------+
+| Operation | Interoperable with Snowpark pandas? (Y/N) | Notes for current implementation |
++--------------------------------------------+---------------------------------------------+-----------------------------------------------+
+| Scaling training data with | Y | |
+| ``MaxAbsScaler.fit_transform()``. | | |
++--------------------------------------------+---------------------------------------------+-----------------------------------------------+
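(Editorial illustration, not part of the diff: the notes above recommend converting Snowpark pandas inputs to pandas before calling scikit-learn when performance matters. A minimal sketch of that conversion, with invented column names and data and an assumed Snowpark session.)

import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401
from sklearn.linear_model import LogisticRegression

# Hypothetical Snowpark pandas data; any Snowpark pandas DataFrame works the same way.
snow_df = pd.DataFrame(
    {"feature1": [1, 5, 3, 4], "feature2": [2, 4, 1, 3], "target": [0, 0, 1, 1]}
)

# Materialize the data as native pandas in one round trip, then let
# scikit-learn operate on local data instead of issuing many Snowflake queries.
local_df = snow_df.to_pandas()
model = LogisticRegression().fit(local_df[["feature1", "feature2"]], local_df["target"])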

setup.py

Lines changed: 1 addition & 1 deletion
@@ -200,7 +200,7 @@ def run(self):
             *DEVELOPMENT_REQUIREMENTS,
             "scipy",  # Snowpark pandas 3rd party library testing
             "statsmodels",  # Snowpark pandas 3rd party library testing
-            "scikit-learn==1.5.2",  # Snowpark pandas scikit-learn tests
+            "scikit-learn",  # Snowpark pandas 3rd party library testing
             # plotly version restricted due to foreseen change in query counts in version 6.0.0+
             "plotly<6.0.0",  # Snowpark pandas 3rd party library testing
         ],
Lines changed: 211 additions & 0 deletions
@@ -0,0 +1,211 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

from sklearn.decomposition import PCA
from sklearn.preprocessing import MaxAbsScaler

import snowflake.snowpark.modin.plugin  # noqa: F401
from tests.integ.utils.sql_counter import sql_count_checker
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RandomizedSearchCV
from sklearn.cluster import KMeans
from tests.integ.modin.utils import create_test_dfs, eval_snowpark_pandas_result
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import numpy as np
import pytest

"""
------
README
------

This test suite tests scikit-learn's interoperability with Snowpark pandas.

Generally, scikit-learn seems to work with Snowpark pandas inputs via a
combination of the dataframe interchange protocol and converting Snowpark
pandas inputs to numpy with methods like __array__() and np.asarray(). Some
scikit-learn methods may cause Snowpark pandas to execute many Snowflake
queries or to materialize Snowpark pandas data one or more times. We don't
plan to fix the performance of scikit-learn with Snowpark pandas inputs, and
we recommend that users convert their data to native pandas before passing it
to scikit-learn if scikit-learn takes too long with Snowpark pandas inputs.

We group the tests into the following use cases, listed under
https://scikit-learn.org/stable/index.html:

- Classification
- Regression
- Clustering
- Dimensionality reduction
- Model selection
- Preprocessing

Many scikit-learn methods produce non-deterministic results, and not all of
them provide a way to seed the results so that they are consistent for a test.
Generally, we only validate that 1) we can pass Snowpark pandas dataframes/series
into methods that accept native pandas inputs and 2) the outputs have the correct
type and, in case they are numpy arrays, the correct shape.

To test interoperability with a particular scikit-learn method:

1) Read about what the method does and how to use it
2) Start writing a test case under the test class for the category that the
   method belongs to (e.g. under TestClassification for
   LinearDiscriminantAnalysis)
3) Construct a use case that works with native pandas and produces a meaningful
   result (for example, train a model on pandas training data and use it to
   predict results for test data)
4) Write a test case checking that replacing the pandas input with Snowpark
   pandas produces results of the same type and, in the case of array-like
   outputs, of the same dimensions. `assert_numpy_results_valid` can validate
   numpy results. Avoid checking that the values in the result are the same
   values we would get if we used pandas, because many scikit-learn methods
   are non-deterministic.
5) Wrap the test with an empty sql_count_checker() decorator to see how many
   queries and joins it requires. If it requires a very large number of
   queries, see if you can simplify the test case so that it causes fewer
   queries, so that the test finishes quickly. If you can't reduce the number of
   queries to a reasonable level, pass the SQL count checker the
   `no_check=True` parameter, because the number of queries is likely to vary
   across scikit-learn and Snowpark pandas versions, and we don't gain much
   insight by adjusting the query count every time it changes.
6) Add a row describing interoperability with the new method in the
   [documentation](docs/source/modin/interoperability.rst)
"""


def assert_numpy_results_valid(snow_result, pandas_result) -> None:
    assert isinstance(snow_result, np.ndarray)
    assert isinstance(pandas_result, np.ndarray)
    # Generally a meaningful test case should produce a non-empty result
    assert pandas_result.size > 0
    assert snow_result.shape == pandas_result.shape


@pytest.fixture()
def test_dfs():
    data = {
        "feature1": [1, 5, 3, 4, 4, 6, 7, 2, 9, 70],
        "feature2": [2, 4, 1, 3, 5, 7, 6, 3, 10, 9],
        "target": [0, 0, 1, 0, 1, 1, 1, 0, 1, 0],
    }
    return create_test_dfs(data)


class TestClassification:
    @sql_count_checker(query_count=6)
    def test_linear_discriminant_analysis(self, test_dfs):
        def get_predictions(df) -> np.ndarray:
            X = df[["feature1", "feature2"]]
            y = df["target"]
            train_size = 8
            X_train, X_test = X.iloc[:train_size], X.iloc[train_size:]
            y_train = y.iloc[:train_size]
            return LinearDiscriminantAnalysis().fit(X_train, y_train).predict(X_test)

        eval_snowpark_pandas_result(
            *test_dfs, get_predictions, comparator=assert_numpy_results_valid
        )


class TestRegression:
    @sql_count_checker(query_count=6)
    def test_logistic_regression(self, test_dfs):
        def get_predictions(df) -> np.ndarray:
            X = df[["feature1", "feature2"]]
            y = df["target"]
            train_size = 8
            X_train, X_test = X.iloc[:train_size], X.iloc[train_size:]
            y_train = y.iloc[:train_size]
            return LogisticRegression().fit(X_train, y_train).predict(X_test)

        eval_snowpark_pandas_result(
            *test_dfs, get_predictions, comparator=assert_numpy_results_valid
        )


class TestClustering:
    @sql_count_checker(query_count=3)
    def test_clustering(self, test_dfs):
        def get_cluster_centers(df) -> np.ndarray:
            return KMeans(n_clusters=3).fit(df).cluster_centers_

        eval_snowpark_pandas_result(
            *test_dfs, get_cluster_centers, comparator=assert_numpy_results_valid
        )


class TestDimensionalityReduction:
    @sql_count_checker(query_count=3)
    def test_principal_component_analysis(self, test_dfs):
        def get_principal_components(df) -> np.ndarray:
            return PCA(n_components=2).fit(df).components_

        eval_snowpark_pandas_result(
            *test_dfs, get_principal_components, comparator=assert_numpy_results_valid
        )


class TestModelSelection:
    @sql_count_checker(
        # Model search is a complex, iterative process. Pushing it down to
        # Snowflake requires many queries (approximately 31 for this case).
        # Since the number of queries and the number of joins are so large, they
        # are likely to change due to changes in both scikit-learn and Snowpark
        # pandas. We don't get much insight from the exact number of queries, so
        # we skip the query count check. The recommended solution to this query
        # explosion is for users to convert the Snowpark pandas object to pandas
        # with to_pandas() and pass the result to scikit-learn.
        no_check=True
    )
    def test_randomized_search_cv(self, test_dfs):
        def get_best_estimator(df) -> dict:
            # Initialize the hyperparameter search with parameters that will
            # reduce the search time as much as possible.
            return (
                RandomizedSearchCV(
                    LogisticRegression(),
                    param_distributions={
                        "C": [0.001],
                    },
                    # cv=2 means 2-fold validation, which requires the fewest queries.
                    cv=2,
                    # Test just one combination of parameters.
                    n_iter=1,
                    # refit=False means that the search doesn't have to actually
                    # train a model using the parameters that it chooses. Setting
                    # refit=False should further reduce the number of queries.
                    refit=False,
                )
                .fit(df[["feature1", "feature2"]], df["target"])
                .best_params_
            )

        def validate_search_results(snow_estimator, pandas_estimator):
            assert isinstance(snow_estimator, dict)
            assert isinstance(pandas_estimator, dict)

        eval_snowpark_pandas_result(
            *test_dfs, get_best_estimator, comparator=validate_search_results
        )


class TestPreprocessing:
    @sql_count_checker(query_count=5)
    def test_maxabs(self, test_dfs):
        eval_snowpark_pandas_result(
            *test_dfs,
            MaxAbsScaler().fit_transform,
            comparator=assert_numpy_results_valid
        )


"""
------
README
------

Please see the README at the top of this file for instructions on adding test
cases.
"""

tests/integ/modin/test_scikit.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

0 commit comments
