Skip to content

Commit 7f59deb

Browse files
SNOW-2396077: Support dtype parameter in get_dummies. (#3879)
The "dtype" parameter controls the values of the indicator variables, e.g. dtype=int means that we use 1 and 0 instead of True and False, respectively. Signed-off-by: sfc-gh-mvashishtha <mahesh.vashishtha@snowflake.com> Co-authored-by: Hazem Elmeleegy <hazem.elmeleegy@snowflake.com>
1 parent b524000 commit 7f59deb

File tree

8 files changed

+301
-101
lines changed

8 files changed

+301
-101
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@
6464

6565
### Snowpark pandas API Updates
6666

67+
#### New Features
68+
- Added support for the `dtype` parameter of `pd.get_dummies`
69+
6770
#### Improvements
6871

6972
- Improved performance of `Series.to_snowflake` and `pd.to_snowflake(series)` for large data by uploading data via a parquet file. You can control the dataset size at which Snowpark pandas switches to parquet with the variable `modin.config.PandasToSnowflakeParquetThresholdBytes`.

docs/source/modin/supported/general_supported.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ Data manipulations
3232
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
3333
| ``from_dummies`` | N | | |
3434
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
35-
| ``get_dummies`` | P | ``sparse`` is ignored | ``Y`` if params ``dummy_na``, ``drop_first`` |
36-
| | | | and ``dtype`` are default, otherwise ``N`` |
35+
| ``get_dummies`` | P | ``sparse`` is ignored | ``Y`` if params ``dummy_na`` and ``drop_first`` |
36+
| | | | are default, otherwise ``N`` |
3737
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
3838
| ``json_normalize`` | Y | | |
3939
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+

src/snowflake/snowpark/modin/plugin/_internal/get_dummies_utils.py

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,18 @@
33
#
44

55
from collections.abc import Hashable
6+
from typing import Any
7+
8+
from pandas.api.types import (
9+
is_bool_dtype,
10+
is_datetime64_any_dtype,
11+
is_float_dtype,
12+
is_integer_dtype,
13+
is_object_dtype,
14+
is_timedelta64_dtype,
15+
is_string_dtype,
16+
)
17+
import pandas as native_pd
618

719
from snowflake.snowpark.functions import (
820
col,
@@ -37,6 +49,7 @@ def single_get_dummies_pivot(
3749
pivot_column_snowflake_quoted_identifier: str,
3850
columns_to_keep_snowflake_quoted_identifiers: list[str],
3951
columns_to_keep_pandas_labels: list[Hashable],
52+
dummy_false: Any,
4053
) -> InternalFrame:
4154
"""
4255
Helper function for get dummies to perform a single pivot on the encoded column.
@@ -51,6 +64,7 @@ def single_get_dummies_pivot(
5164
internal_frame to keep as the data column of final result internal frame.
5265
columns_to_keep_pandas_labels: The pandas label in the internal_frame to keep as the
5366
data_column of final result internal frame.
67+
dummy_false: The scalar value representing that a particular column value is not present.
5468
5569
Note: columns_to_keep_snowflake_quoted_identifiers must be the same length as columns_to_keep_pandas_labels
5670
Returns:
@@ -93,7 +107,7 @@ def single_get_dummies_pivot(
93107
columns_snowflake_quoted_identifier
94108
)
95109
# Perform pivot on the pivot column with dummy lit true column as value column.
96-
# With the above example, the result of pivot will be:
110+
# With the above example, the result of pivot will be (assuming dtype is bool):
97111
#
98112
# C a b
99113
# 0 1 True False
@@ -102,7 +116,7 @@ def single_get_dummies_pivot(
102116
pivoted_ordered_dataframe = ordered_dataframe.pivot(
103117
col(str(pivot_column_snowflake_quoted_identifier)),
104118
None,
105-
0,
119+
pandas_lit(dummy_false),
106120
min_(lit_true_column_snowflake_quoted_identifier),
107121
)
108122
pivoted_ordered_dataframe = pivoted_ordered_dataframe.sort(
@@ -179,11 +193,41 @@ def single_get_dummies_pivot(
179193
)
180194

181195

196+
def _get_dummies_true_and_false_values(dtype: Any) -> tuple[Any, Any]:
197+
"""
198+
Get the indicator values representing whether a column is equal to a particular value.
199+
200+
Args:
201+
dtype: The dtype of the indicator column.
202+
203+
Returns:
204+
A tuple of the indicator values. The first value represents that the
205+
value is present, and the second value represents that the value is not
206+
present.
207+
"""
208+
if is_object_dtype(dtype):
209+
raise ValueError("dtype=object is not a valid dtype for get_dummies")
210+
if is_string_dtype(dtype):
211+
return ("1", "")
212+
if is_bool_dtype(dtype) or dtype is None:
213+
return (True, False)
214+
if is_integer_dtype(dtype):
215+
return (1, 0)
216+
if is_float_dtype(dtype):
217+
return (1.0, 0.0)
218+
if is_datetime64_any_dtype(dtype):
219+
return (native_pd.Timestamp(1), native_pd.Timestamp(0))
220+
if is_timedelta64_dtype(dtype):
221+
ErrorMessage.not_implemented_for_timedelta(method="get_dummies")
222+
raise TypeError(f"data type '{dtype}' not understood")
223+
224+
182225
def get_dummies_helper(
183226
internal_frame: InternalFrame,
184227
columns: list[Hashable],
185228
prefixes: list[Hashable],
186229
prefix_sep: str,
230+
dtype: Any,
187231
dummy_row_pos_mode: bool = False,
188232
) -> InternalFrame:
189233
"""
@@ -222,11 +266,12 @@ def get_dummies_helper(
222266
f"get_dummies with duplicated columns {pandas_label}"
223267
)
224268

225-
# append a lit true column as value column for pivot
269+
dummy_true, dummy_false = _get_dummies_true_and_false_values(dtype)
270+
271+
# the dummy column is appended as the last data column of the new_internal_frame
226272
new_internal_frame = internal_frame.ensure_row_position_column(
227273
dummy_row_pos_mode
228-
).append_column(LIT_TRUE_COLUMN_PANDAS_LABEL, pandas_lit(True))
229-
# the dummy column is appended as the last data column of the new_internal_frame
274+
).append_column(LIT_TRUE_COLUMN_PANDAS_LABEL, pandas_lit(dummy_true))
230275
row_position_column_snowflake_quoted_identifier = (
231276
new_internal_frame.row_position_snowflake_quoted_identifier
232277
)
@@ -266,7 +311,7 @@ def get_dummies_helper(
266311

267312
# Do the first pivot with the first column and keep all remaining columns.
268313
# With the example given above, the first pivot is performed on column A, and we will
269-
# get the following result:
314+
# get the following result (assuming dtype is int):
270315
# C A_a A_b
271316
# 0 1 1 0
272317
# 1 2 0 1
@@ -278,6 +323,7 @@ def get_dummies_helper(
278323
pivot_column_snowflake_quoted_identifier=grouped_quoted_identifiers[0][0],
279324
columns_to_keep_snowflake_quoted_identifiers=remaining_data_column_snowflake_quoted_identifiers,
280325
columns_to_keep_pandas_labels=remaining_data_column_pandas_labels,
326+
dummy_false=dummy_false,
281327
)
282328

283329
# Perform pivot on rest columns and join on the row position column to form the final result.
@@ -294,6 +340,7 @@ def get_dummies_helper(
294340
pivot_column_snowflake_quoted_identifier=grouped_quoted_identifiers[i][0],
295341
columns_to_keep_snowflake_quoted_identifiers=[],
296342
columns_to_keep_pandas_labels=[],
343+
dummy_false=dummy_false,
297344
)
298345
result_internal_frame = join_utils.join(
299346
result_internal_frame,

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7003,10 +7003,6 @@ def groupby_pct_change(
70037003
unsupported_conditions=[
70047004
("dummy_na", True),
70057005
("drop_first", True),
7006-
(
7007-
lambda args: args.get("dtype") is not None,
7008-
"get_dummies with non-default dtype parameter is not supported yet in Snowpark pandas.",
7009-
),
70107006
]
70117007
),
70127008
)
@@ -7049,9 +7045,9 @@ def get_dummies(
70497045
"""
70507046
self._raise_not_implemented_error_for_timedelta()
70517047

7052-
if dummy_na is True or drop_first is True or dtype is not None:
7048+
if dummy_na is True or drop_first is True:
70537049
ErrorMessage.not_implemented(
7054-
"get_dummies with non-default dummy_na, drop_first, and dtype parameters"
7050+
"get_dummies with non-default dummy_na or drop_first parameters"
70557051
+ " is not supported yet in Snowpark pandas."
70567052
)
70577053
if columns is None:
@@ -7095,6 +7091,7 @@ def get_dummies(
70957091
columns=columns,
70967092
prefixes=prefix,
70977093
prefix_sep=prefix_sep,
7094+
dtype=dtype,
70987095
)
70997096
query_compiler = SnowflakeQueryCompiler(result_internal_frame)
71007097

tests/integ/modin/strings/test_get_dummies_dataframe.py renamed to tests/integ/modin/frame/test_get_dummies.py

Lines changed: 91 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,15 @@
99
import pytest
1010

1111
import snowflake.snowpark.modin.plugin # noqa: F401
12+
from pytest import param
1213
from snowflake.snowpark._internal.utils import (
1314
TempObjectType,
1415
random_name_for_temp_object,
1516
)
16-
from tests.integ.modin.utils import assert_snowpark_pandas_equal_to_pandas
17+
from tests.integ.modin.utils import (
18+
assert_snowpark_pandas_equal_to_pandas,
19+
eval_snowpark_pandas_result,
20+
)
1721
from tests.integ.utils.sql_counter import sql_count_checker
1822

1923

@@ -246,25 +250,98 @@ def test_get_dummies_pandas_after_read_snowflake(session):
246250
assert_snowpark_pandas_equal_to_pandas(snow_get_dummies, pandas_get_dummies)
247251

248252

249-
@sql_count_checker(query_count=0)
250-
def test_get_dummies_pandas_negative():
251-
252-
pandas_df = native_pd.DataFrame(
253-
{"A": ["a", "b", "a"], "B": ["b", "a", "c"], "C": [1, 2, 3]}
253+
class TestDtypeParameter:
254+
@pytest.mark.parametrize(
255+
"dtype",
256+
[
257+
np.int64,
258+
int,
259+
"int",
260+
float,
261+
np.float64,
262+
"float64",
263+
str,
264+
"str",
265+
np.str_,
266+
bool,
267+
"bool",
268+
np.bool_,
269+
"datetime64[ns]",
270+
param(
271+
"timedelta64[ns]",
272+
marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
273+
),
274+
None,
275+
],
254276
)
277+
@sql_count_checker(query_count=1)
278+
def test_valid_dtype(self, dtype):
279+
pandas_df = native_pd.DataFrame({"A": ["a", "b", "a"]})
280+
snow_df = pd.DataFrame(pandas_df)
281+
# note that we're using the default check_dtype=True to check that we
282+
# are producing the correct dtypes.
283+
assert_snowpark_pandas_equal_to_pandas(
284+
pd.get_dummies(snow_df, dtype=dtype),
285+
native_pd.get_dummies(pandas_df, dtype=dtype),
286+
)
255287

256-
snow_df = pd.DataFrame(pandas_df)
288+
@sql_count_checker(query_count=1)
289+
def test_valid_dtype_argument_int32(self):
290+
"""Test int32 separately because Snowpark pandas always produces int64 for integers."""
291+
pandas_df = native_pd.DataFrame({"A": ["a", "b", "a"]})
292+
snow_df = pd.DataFrame(pandas_df)
293+
snow_result = pd.get_dummies(snow_df, dtype=np.int32)
294+
pandas_result = native_pd.get_dummies(pandas_df, dtype=np.int32)
295+
# note that we're using the default check_dtype=True to check that we
296+
# are producing the correct dtypes.
297+
assert_snowpark_pandas_equal_to_pandas(
298+
snow_result, pandas_result.astype(np.int64)
299+
)
257300

258-
with pytest.raises(NotImplementedError):
259-
pd.get_dummies(
260-
snow_df,
261-
prefix=["col1", "col2"],
262-
dummy_na=True,
263-
drop_first=True,
264-
dtype=np.int32,
301+
@sql_count_checker(query_count=0)
302+
def test_invalid_dtype_argument(self):
303+
eval_snowpark_pandas_result(
304+
pd,
305+
native_pd,
306+
lambda module: module.get_dummies(
307+
module.DataFrame({"A": ["a", "b", "a"]}), dtype="invalid_dtype"
308+
),
309+
expect_exception=True,
310+
expect_exception_type=TypeError,
311+
expect_exception_match=re.escape(
312+
"data type 'invalid_dtype' not understood"
313+
),
314+
)
315+
316+
@sql_count_checker(query_count=0)
317+
@pytest.mark.parametrize("dtype", ["object", np.dtype("object")])
318+
def test_invalid_dtype_argument_object(self, dtype):
319+
eval_snowpark_pandas_result(
320+
pd,
321+
native_pd,
322+
lambda module: module.get_dummies(
323+
module.DataFrame({"A": ["a", "b", "a"]}), dtype=dtype
324+
),
325+
expect_exception=True,
326+
expect_exception_type=ValueError,
327+
expect_exception_match=re.escape(
328+
"dtype=object is not a valid dtype for get_dummies"
329+
),
265330
)
266331

267332

333+
@sql_count_checker(query_count=0)
334+
def test_dummy_na_negative():
335+
with pytest.raises(NotImplementedError):
336+
pd.get_dummies(pd.DataFrame(["a", None]), dummy_na=True)
337+
338+
339+
@sql_count_checker(query_count=0)
340+
def test_drop_first_negative():
341+
with pytest.raises(NotImplementedError):
342+
pd.get_dummies(pd.DataFrame(["a", "b"]), drop_first=True)
343+
344+
268345
@sql_count_checker(query_count=0)
269346
def test_get_dummies_pandas_negative_duplicated_columns():
270347
pandas_df = native_pd.DataFrame(

tests/integ/modin/hybrid/test_switch_operations.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,6 @@ def test_auto_switch_supported_post_op_switch_point_series(method, kwargs):
884884
"get_dummies",
885885
{"drop_first": True},
886886
),
887-
("get_dummies", {"dtype": int}),
888887
],
889888
)
890889
def test_auto_switch_unsupported_top_level_functions(method, kwargs):
@@ -1039,11 +1038,6 @@ def test_auto_switch_unsupported_series(method, kwargs):
10391038
{"drop_first": True},
10401039
"drop_first = True is not supported",
10411040
),
1042-
(
1043-
"get_dummies",
1044-
{"dtype": int},
1045-
"get_dummies with non-default dtype parameter is not supported yet in Snowpark pandas.",
1046-
),
10471041
],
10481042
)
10491043
@sql_count_checker(query_count=0)

0 commit comments

Comments
 (0)