
Commit 05abdbe

Merge branch 'main' into helmeleegy-SNOW-1842841
2 parents 19317e2 + 66dc14d commit 05abdbe

File tree

14 files changed: +444 -151 lines

CHANGELOG.md

Lines changed: 4 additions & 1 deletion
@@ -19,11 +19,13 @@
 #### Improvements
 
 - Updated README.md to include instructions on how to verify package signatures using `cosign`.
+- Added an option `keep_column_order` for keeping original column order in `DataFrame.with_column` and `DataFrame.with_columns`.
 
 #### Bug Fixes
 
 - Fixed a bug in local testing mode that caused a column to contain None when it should contain 0
-- Fixed a bug in StructField.from_json that prevented TimestampTypes with tzinfo from being parsed correctly.
+- Fixed a bug in `StructField.from_json` that prevented TimestampTypes with tzinfo from being parsed correctly.
+- Fixed a bug in function `date_format` that caused an error when the input column was date type or timestamp type.
 
 ### Snowpark pandas API Updates
 
@@ -49,6 +51,7 @@
 - %X: Locale’s appropriate time representation.
 - %%: A literal '%' character.
 - Added support for `Series.between`.
+- Added support for `include_groups=False` in `DataFrameGroupBy.apply`.
 - Added support for `DataFrame.pop` and `Series.pop`.
 
 #### Bug Fixes

docs/source/modin/supported/groupby_supported.rst

Lines changed: 2 additions & 2 deletions
@@ -39,8 +39,8 @@ Function application
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
 | ``apply``                   | P                               | ``axis`` other than 0 is not     | ``Y`` if the following are true, otherwise ``N``:  |
 |                             |                                 | implemented.                     | - ``func`` is a callable that always returns       |
-|                             |                                 | ``include_groups = False`` is    |   either a pandas DataFrame, a pandas Series, or   |
-|                             |                                 | not implemented.                 |   objects that are neither DataFrame nor Series.   |
+|                             |                                 |                                  |   either a pandas DataFrame, a pandas Series, or   |
+|                             |                                 |                                  |   objects that are neither DataFrame nor Series.   |
 |                             |                                 |                                  | - grouping on axis=0                               |
 |                             |                                 |                                  | - Not applying transform to a dataframe with a     |
 |                             |                                 |                                  |   non-unique index                                 |

pyproject.toml

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+[build-system]
+requires = [
+    "setuptools",
+    "protoc-wheel-0==21.1",  # Protocol buffer compiler for Snowpark IR
+    "mypy-protobuf",  # used in generating typed Python code from protobuf for Snowpark IR
+]
+build-backend = "setuptools.build_meta"

setup.py

Lines changed: 1 addition & 2 deletions
@@ -58,9 +58,8 @@
     "graphviz",  # used in plot tests
     "pytest-assume",  # sql counter check
     "decorator",  # sql counter check
-    "protoc-wheel-0==21.1",  # Protocol buffer compiler, for Snowpark IR
-    "mypy-protobuf",  # used in generating typed Python code from protobuf for Snowpark IR
     "lxml",  # used in read_xml tests
+    "tox",  # used for setting up testing environments
 ]
 
 # read the version

src/snowflake/snowpark/dataframe.py

Lines changed: 49 additions & 11 deletions
@@ -3688,6 +3688,8 @@ def with_column(
         self,
         col_name: str,
         col: Union[Column, TableFunctionCall],
+        *,
+        keep_column_order: bool = False,
         ast_stmt: proto.Expr = None,
         _emit_ast: bool = True,
     ) -> "DataFrame":
@@ -3730,6 +3732,7 @@ def with_column(
         Args:
             col_name: The name of the column to add or replace.
             col: The :class:`Column` or :class:`table_function.TableFunctionCall` with single column output to add or replace.
+            keep_column_order: If ``True``, the original order of the columns in the DataFrame is preserved when replacing a column.
         """
         if ast_stmt is None and _emit_ast:
             ast_stmt = self._session._ast_batch.assign()
@@ -3738,7 +3741,13 @@ def with_column(
             build_expr_from_snowpark_column_or_table_fn(expr.col, col)
             self._set_ast_ref(expr.df)
 
-        df = self.with_columns([col_name], [col], _ast_stmt=ast_stmt, _emit_ast=False)
+        df = self.with_columns(
+            [col_name],
+            [col],
+            keep_column_order=keep_column_order,
+            _ast_stmt=ast_stmt,
+            _emit_ast=False,
+        )
 
         if _emit_ast:
             df._ast_id = ast_stmt.var_id.bitfield1
@@ -3751,6 +3760,8 @@ def with_columns(
         self,
         col_names: List[str],
         values: List[Union[Column, TableFunctionCall]],
+        *,
+        keep_column_order: bool = False,
         _ast_stmt: proto.Expr = None,
         _emit_ast: bool = True,
     ) -> "DataFrame":
@@ -3797,6 +3808,7 @@ def with_columns(
             col_names: A list of the names of the columns to add or replace.
             values: A list of the :class:`Column` objects or :class:`table_function.TableFunctionCall` object
                 to add or replace.
+            keep_column_order: If ``True``, the original order of the columns in the DataFrame is preserved when replacing a column.
         """
         # Get a list of the new columns and their dedupped values
         qualified_names = [quote_name(n) for n in col_names]
@@ -3837,14 +3849,7 @@ def with_columns(
                 names = col_names[i : i + offset + 1]
             new_cols.append(col.as_(*names))
 
-        # Get a list of existing column names that are not being replaced
-        old_cols = [
-            Column(field)
-            for field in self._output
-            if field.name not in new_column_names
-        ]
-
-        # AST.
+        # AST
         if _ast_stmt is None and _emit_ast:
             _ast_stmt = self._session._ast_batch.assign()
             expr = with_src_position(
@@ -3856,8 +3861,41 @@ def with_columns(
             build_expr_from_snowpark_column_or_table_fn(expr.values.add(), value)
             self._set_ast_ref(expr.df)
 
-        # Put it all together
-        df = self.select([*old_cols, *new_cols], _ast_stmt=_ast_stmt, _emit_ast=False)
+        # If there's a table function call or keep_column_order=False,
+        # we do the original "remove old columns and append new ones" logic.
+        if num_table_func_calls > 0 or not keep_column_order:
+            old_cols = [
+                Column(field)
+                for field in self._output
+                if field.name not in new_column_names
+            ]
+            final_cols = [*old_cols, *new_cols]
+        else:
+            # keep_column_order=True and no table function calls:
+            # re-insert replaced columns in their original positions if they exist.
+            replaced_map = {
+                name: new_col for name, new_col in zip(qualified_names, new_cols)
+            }
+            final_cols = []
+            used = set()  # track which new cols we've inserted
+
+            for field in self._output:
+                field_quoted = quote_name(field.name)
+                # If this old column name is being replaced, insert the new col at the same position.
+                if field_quoted in replaced_map:
+                    final_cols.append(replaced_map[field_quoted])
+                    used.add(field_quoted)
+                else:
+                    # Keep the original col.
+                    final_cols.append(Column(field))
+
+            # For any new columns that didn't exist in the old schema, append them at the end.
+            for name, c in replaced_map.items():
+                if name not in used:
+                    final_cols.append(c)
+
+        # Construct the final DataFrame.
+        df = self.select(final_cols, _ast_stmt=_ast_stmt, _emit_ast=False)
 
         if _emit_ast:
             df._ast_id = _ast_stmt.var_id.bitfield1
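
For context, a minimal usage sketch of the new flag, assuming an existing `session` (the data and column names are illustrative, not from this commit):

```python
from snowflake.snowpark.functions import col

# df has columns A, B, C. By default, replacing B moves it to the end:
df = session.create_dataframe([[1, 2, 3]], schema=["A", "B", "C"])
df.with_column("B", col("B") + 1).columns  # ['A', 'C', 'B']

# With keep_column_order=True, the replaced column keeps its position:
df.with_column("B", col("B") + 1, keep_column_order=True).columns  # ['A', 'B', 'C']
```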

src/snowflake/snowpark/functions.py

Lines changed: 75 additions & 10 deletions
@@ -224,6 +224,7 @@
     StoredProcedureRegistration,
 )
 from snowflake.snowpark.types import (
+    ArrayType,
     DataType,
     FloatType,
     PandasDataFrameType,
@@ -3561,20 +3562,67 @@ def _concat_ws_ignore_nulls(sep: str, *cols: ColumnOrName) -> Column:
     |Hello      |
     -----------------------------------------------------
     <BLANKLINE>
+
+    >>> df = session.create_dataframe([
+    ...     (['Hello', 'World', None], None, '!'),
+    ...     (['Hi', 'World', "."], "I'm Dad", '.'),
+    ... ], schema=['a', 'b', 'c'])
+    >>> df.select(_concat_ws_ignore_nulls(", ", "a", "b", "c")).show()
+    -----------------------------------------------------
+    |"CONCAT_WS_IGNORE_NULLS(', ', ""A"",""B"",""C"")"  |
+    -----------------------------------------------------
+    |Hello, World, !                                    |
+    |Hi, World, ., I'm Dad, .                           |
+    -----------------------------------------------------
+    <BLANKLINE>
     """
     # TODO: SNOW-1831917 create ast
     columns = [_to_col_if_str(c, "_concat_ws_ignore_nulls") for c in cols]
     names = ",".join([c.get_name() for c in columns])
 
-    input_column_array = array_construct_compact(*columns, _emit_ast=False)
-    reduced_result = builtin("reduce", _emit_ast=False)(
-        input_column_array,
-        lit("", _emit_ast=False),
-        sql_expr(f"(l, r) -> l || '{sep}' || r"),
-    )
-    return substring(reduced_result, len(sep) + 1, _emit_ast=False).alias(
-        f"CONCAT_WS_IGNORE_NULLS('{sep}', {names})", _emit_ast=False
-    )
+    # The implementation of this function is as follows, with example input of
+    # sep = "," and row = [a, NULL], b, NULL, c:
+    # 1. Cast all columns to array.
+    #    [a, NULL], [b], NULL, [c]
+    # 2. Combine all arrays into an array of arrays after removing nulls (array_construct_compact).
+    #    [[a, NULL], [b], [c]]
+    # 3. Flatten the array of arrays into a single array (array_flatten).
+    #    [a, NULL, b, c]
+    # 4. Filter out nulls (array_remove_nulls).
+    #    [a, b, c]
+    # 5. Concatenate the non-null values into a single string (concat_strings_with_sep).
+    #    "a,b,c"
+
+    def array_remove_nulls(col: Column) -> Column:
+        """Expects an array and returns an array with nulls removed."""
+        return builtin("filter", _emit_ast=False)(
+            col, sql_expr("x -> NOT IS_NULL_VALUE(x)", _emit_ast=False)
+        )
+
+    def concat_strings_with_sep(col: Column) -> Column:
+        """
+        Expects an array of strings and returns a single string
+        with the values concatenated with the separator.
+        """
+        return substring(
+            builtin("reduce", _emit_ast=False)(
+                col, lit(""), sql_expr(f"(l, r) -> l || '{sep}' || r", _emit_ast=False)
+            ),
+            len(sep) + 1,
+            _emit_ast=False,
+        )
+
+    return concat_strings_with_sep(
+        array_remove_nulls(
+            array_flatten(
+                array_construct_compact(
+                    *[c.cast(ArrayType(), _emit_ast=False) for c in columns],
+                    _emit_ast=False,
+                ),
+                _emit_ast=False,
+            )
+        )
+    ).alias(f"CONCAT_WS_IGNORE_NULLS('{sep}', {names})", _emit_ast=False)
 
 
 @publicapi
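
As a sanity check on the five steps described in the new comments, here is a minimal pure-Python model of the same pipeline, with plain lists standing in for Snowflake ARRAY values (an illustration, not the library code):

```python
def concat_ws_ignore_nulls_model(sep, *values):
    # Steps 1-2: treat each non-NULL value as an array (scalars become
    # one-element arrays) and drop NULL inputs, like array_construct_compact.
    arrays = [v if isinstance(v, list) else [v] for v in values if v is not None]
    # Step 3: flatten the array of arrays, like array_flatten.
    flat = [x for arr in arrays for x in arr]
    # Step 4: drop NULL elements, like the filter/IS_NULL_VALUE step.
    non_null = [x for x in flat if x is not None]
    # Step 5: join with the separator, like reduce + substring.
    return sep.join(str(x) for x in non_null)

# Matches the first row of the new doctest above.
assert concat_ws_ignore_nulls_model(", ", ["Hello", "World", None], None, "!") == "Hello, World, !"
```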
@@ -3828,6 +3876,19 @@ def date_format(
     |2022/05/15 10:45:00  |
     -----------------------
     <BLANKLINE>
+
+    Example::
+        >>> df = session.sql("select '2023-10-10'::DATE as date_col, '2023-10-10 15:30:00'::TIMESTAMP as timestamp_col")
+        >>> df.select(
+        ...     date_format('date_col', 'YYYY/MM/DD').as_('formatted_dt'),
+        ...     date_format('timestamp_col', 'YYYY/MM/DD HH:mi:ss').as_('formatted_ts')
+        ... ).show()
+        ----------------------------------------
+        |"FORMATTED_DT"  |"FORMATTED_TS"       |
+        ----------------------------------------
+        |2023/10/10      |2023/10/10 15:30:00  |
+        ----------------------------------------
+        <BLANKLINE>
     """
 
     # AST.
@@ -3836,7 +3897,11 @@ def date_format(
         ast = proto.Expr()
         build_builtin_fn_apply(ast, "date_format", c, fmt)
 
-    ans = to_char(try_cast(c, TimestampType(), _emit_ast=False), fmt, _emit_ast=False)
+    ans = to_char(
+        try_cast(to_char(c, _emit_ast=False), TimestampType(), _emit_ast=False),
+        fmt,
+        _emit_ast=False,
+    )
     ans._ast = ast
     return ans
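A note on the `date_format` fix, as I read the change: Snowflake's TRY_CAST accepts only string input, so the old `try_cast(c, TimestampType())` raised an error whenever the column was already a DATE or TIMESTAMP. Converting the column to a string with `to_char` first makes the cast valid for any input type. A minimal sketch of the equivalent expression (assumes a live `session`):

```python
from snowflake.snowpark.functions import to_char, try_cast
from snowflake.snowpark.types import TimestampType

df = session.sql("select '2023-10-10'::DATE as d")

# Old expression: TRY_CAST("D" AS TIMESTAMP) -- fails because TRY_CAST
# requires a string source and "D" is already a DATE.
# New expression: stringify first, then TRY_CAST, then format with TO_CHAR.
df.select(to_char(try_cast(to_char(df["d"]), TimestampType()), "YYYY/MM/DD")).show()
```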
src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 38 additions & 20 deletions
@@ -3979,6 +3979,7 @@ def groupby_apply(
         agg_args: Any,
         agg_kwargs: dict[str, Any],
         series_groupby: bool,
+        include_groups: bool,
         force_single_group: bool = False,
         force_list_like_to_series: bool = False,
     ) -> "SnowflakeQueryCompiler":
@@ -4001,6 +4002,9 @@ def groupby_apply(
                 Keyword arguments to pass to agg_func when applying it to each group.
             series_groupby:
                 Whether we are performing a SeriesGroupBy.apply() instead of a DataFrameGroupBy.apply()
+            include_groups:
+                When True, will include grouping keys when calling func in the case that
+                they are columns of the DataFrame.
             force_single_group:
                 Force single group (empty set of group by labels) useful for DataFrame.apply() with axis=0
             force_list_like_to_series:
@@ -4019,14 +4023,6 @@ def groupby_apply(
                 + f"level={level}, and axis={axis}"
             )
 
-        if "include_groups" in agg_kwargs:
-            # exclude "include_groups" from the apply function kwargs
-            include_groups = agg_kwargs.pop("include_groups")
-            if not include_groups:
-                ErrorMessage.not_implemented(
-                    f"No support for groupby.apply with include_groups = {include_groups}"
-                )
-
         sort = groupby_kwargs.get("sort", True)
         as_index = groupby_kwargs.get("as_index", True)
         dropna = groupby_kwargs.get("dropna", True)
@@ -4051,17 +4047,36 @@ def groupby_apply(
         )
 
         snowflake_type_map = self._modin_frame.quoted_identifier_to_snowflake_type()
-
-        # For DataFrameGroupBy, `func` operates on this frame in its entirety.
-        # For SeriesGroupBy, this frame may also include some grouping columns
-        # that `func` should not take as input. In that case, the only column
-        # that `func` takes as input is the last data column, so grab just that
-        # column with a slice starting at index -1 and ending at None.
-        input_data_column_identifiers = (
-            self._modin_frame.data_column_snowflake_quoted_identifiers[
-                slice(-1, None) if series_groupby else slice(None)
-            ]
-        )
+        input_data_column_positions = [
+            i
+            for i, identifier in enumerate(
+                self._modin_frame.data_column_snowflake_quoted_identifiers
+            )
+            if (
+                (
+                    # For SeriesGroupBy, this frame may also include some
+                    # grouping columns that `func` should not take as input. In
+                    # that case, the only column that `func` takes as input is
+                    # the last data column, so take just that column.
+                    # include_groups has no effect.
+                    i
+                    == len(self._modin_frame.data_column_snowflake_quoted_identifiers)
+                    - 1
+                )
+                if series_groupby
+                else (
+                    # For DataFrameGroupBy, if include_groups, we apply the
+                    # function to all data columns. Otherwise, we exclude
+                    # data columns that we are grouping by.
+                    include_groups
+                    or identifier not in by_snowflake_quoted_identifiers_list
+                )
+            )
+        ]
+        input_data_column_identifiers = [
+            self._modin_frame.data_column_snowflake_quoted_identifiers[i]
+            for i in input_data_column_positions
+        ]
 
         # TODO(SNOW-1210489): When type hints show that `agg_func` returns a
         # scalar, we can use a vUDF instead of a vUDTF and we can skip the
@@ -4070,7 +4085,9 @@ def groupby_apply(
             agg_func,
             agg_args,
             agg_kwargs,
-            data_column_index=self._modin_frame.data_columns_index,
+            data_column_index=self._modin_frame.data_columns_index[
+                input_data_column_positions
+            ],
             index_column_names=self._modin_frame.index_column_pandas_labels,
             input_data_column_types=[
                 snowflake_type_map[quoted_identifier]
@@ -8511,6 +8528,7 @@ def wrapped_func(*args, **kwargs):  # type: ignore[no-untyped-def]  # pragma: no cover
             series_groupby=True,
             force_single_group=True,
             force_list_like_to_series=True,
+            include_groups=True,
         )
 
         data_col_result_frame = data_col_qc._modin_frame
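
For reference, a minimal sketch of the pandas-level behavior this enables (assumes an active Snowpark session and the Snowpark pandas plugin; the frame and column names are illustrative):

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401 -- registers the Snowflake backend

df = pd.DataFrame({"key": ["a", "a", "b"], "val": [1, 2, 3]})

# With include_groups=False, `func` no longer sees the grouping column "key":
# each group passed to the lambda contains only "val", matching pandas 2.2+
# semantics for DataFrameGroupBy.apply.
result = df.groupby("key").apply(lambda g: g.sum(), include_groups=False)
```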

src/snowflake/snowpark/modin/plugin/docstrings/groupby.py

Lines changed: 3 additions & 0 deletions
@@ -1078,6 +1078,9 @@ def apply():
         A callable that takes a dataframe or series as its first argument, and
         returns a dataframe, a series or a scalar. In addition the
         callable may take positional and keyword arguments.
+    include_groups : bool, default True
+        When True, will apply ``func`` to the groups in the case that they
+        are columns of the DataFrame.
     args, kwargs : tuple and dict
         Optional positional and keyword arguments to pass to ``func``.
 