
Commit b4d25e3

Merge branch 'main' into aalam-SNOW-1882151-revert-old-lit-behavior
2 parents 5c05366 + 1ed17cc commit b4d25e3

21 files changed: +432 -293 lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -80,6 +80,8 @@
 - %%: A literal '%' character.
 - Added support for `Series.between`.
 - Added support for `include_groups=False` in `DataFrameGroupBy.apply`.
+- Added support for `expand=True` in `Series.str.split`.
+- Added support for `DataFrame.pop` and `Series.pop`.
 
 #### Bug Fixes
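
For orientation, a minimal usage sketch of the two new Snowpark pandas APIs, assuming an active Snowpark session is already configured; the data and column names are illustrative, not from this commit:

    import modin.pandas as pd
    import snowflake.snowpark.modin.plugin  # noqa: F401 -- registers the Snowpark pandas backend

    df = pd.DataFrame({"name": ["falcon bird", "lion mammal"]})
    parts = df["name"].str.split(" ", expand=True)  # one output column per split token
    popped = df.pop("name")                         # drops "name" and returns it as a Series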

docs/source/modin/supported/dataframe_supported.rst

Lines changed: 1 addition & 1 deletion
@@ -311,7 +311,7 @@ Methods
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
 | ``plot``                    | D                               |                                  | Performed locally on the client                    |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
-| ``pop``                     | N                               |                                  |                                                    |
+| ``pop``                     | Y                               |                                  |                                                    |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
 | ``pow``                     | P                               | ``level``                        |                                                    |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+

docs/source/modin/supported/series_str_supported.rst

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ the method in the left column.
 | ``slice_replace``           | N                               |                                                    |
 +-----------------------------+---------------------------------+----------------------------------------------------+
 | ``split``                   | P                               | ``N`` if `pat` is non-string, `n` is non-numeric,  |
-|                             |                                 | `expand` is set, or `regex` is set.                |
+|                             |                                 | or `regex` is set.                                 |
 +-----------------------------+---------------------------------+----------------------------------------------------+
 | ``startswith``              | P                               | ``N`` if the `na` parameter is set to a non-bool   |
 |                             |                                 | value.                                             |

docs/source/modin/supported/series_supported.rst

Lines changed: 1 addition & 1 deletion
@@ -306,7 +306,7 @@ Methods
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
 | ``plot``                    | D                               |                                  | Performed locally on the client                    |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
-| ``pop``                     | N                               |                                  |                                                    |
+| ``pop``                     | Y                               |                                  |                                                    |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
 | ``pow``                     | P                               | ``level``                        |                                                    |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+

src/snowflake/snowpark/_internal/analyzer/analyzer.py

Lines changed: 4 additions & 0 deletions
@@ -6,6 +6,8 @@
 from collections import Counter, defaultdict
 from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Union
 
+from snowflake.connector import IntegrityError
+
 import snowflake.snowpark
 from snowflake.snowpark._internal.analyzer.analyzer_utils import (
     alias_expression,
@@ -975,6 +977,8 @@ def do_resolve_with_resolved_children(
 
         if logical_plan.data:
             if not logical_plan.is_large_local_data:
+                if logical_plan.is_contain_illegal_null_value:
+                    raise IntegrityError("NULL result in a non-nullable column")
                 return self.plan_builder.query(
                     values_statement(logical_plan.output, logical_plan.data),
                     logical_plan,
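
A sketch of the client-side failure this check introduces, assuming an active `session`; the schema is illustrative:

    from snowflake.connector import IntegrityError
    from snowflake.snowpark.types import IntegerType, StructField, StructType

    schema = StructType([StructField("id", IntegerType(), nullable=False)])
    try:
        # Small local data is inlined as a VALUES statement, so the new check
        # can raise before any SQL is sent when a non-nullable column holds None.
        session.create_dataframe([[None]], schema=schema).collect()
    except IntegrityError as err:
        print(err)  # NULL result in a non-nullable column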

src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py

Lines changed: 14 additions & 0 deletions
@@ -158,6 +158,20 @@ def is_large_local_data(self) -> bool:
 
         return len(self.data) * len(self.output) >= ARRAY_BIND_THRESHOLD
 
+    @property
+    def is_contain_illegal_null_value(self) -> bool:
+        from snowflake.snowpark._internal.analyzer.analyzer import ARRAY_BIND_THRESHOLD
+
+        rows_to_compare = min(
+            ARRAY_BIND_THRESHOLD // len(self.output) + 1, len(self.data)
+        )
+        for j in range(len(self.output)):
+            if not self.output[j].nullable:
+                for i in range(rows_to_compare):
+                    if self.data[i][j] is None:
+                        return True
+        return False
+
     @property
     def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         if self.is_large_local_data:
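
The bound on `rows_to_compare` keeps the scan limited to rows that could actually be inlined: once `len(self.data) * len(self.output)` reaches ARRAY_BIND_THRESHOLD, `is_large_local_data` is true and this property is never consulted. A worked illustration of the arithmetic, assuming a threshold of 512 (the numbers are hypothetical, for illustration only):

    ARRAY_BIND_THRESHOLD = 512             # assumed value, for illustration only
    num_cols = 3
    rows_to_compare = 512 // num_cols + 1  # = 171; a local frame with more rows
                                           # than this takes the large-data path
                                           # and skips the client-side NULL check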

src/snowflake/snowpark/functions.py

Lines changed: 33 additions & 7 deletions
@@ -758,17 +758,31 @@
         )
     target_tz = _to_col_if_str(target_timezone, "convert_timezone")
     source_time_to_convert = _to_col_if_str(source_time, "convert_timezone")
-
+    # Build AST here to prevent rearrangement of args in the encoded AST.
+    ast = (
+        build_function_expr(
+            "convert_timezone",
+            [target_timezone, source_time, source_timezone],
+            ignore_null_args=True,
+        )
+        if _emit_ast
+        else None
+    )
     if source_timezone is None:
         return call_builtin(
-            "convert_timezone", target_tz, source_time_to_convert, _emit_ast=_emit_ast
+            "convert_timezone",
+            target_tz,
+            source_time_to_convert,
+            _ast=ast,
+            _emit_ast=False,
         )
     return call_builtin(
         "convert_timezone",
         source_tz,
         target_tz,
         source_time_to_convert,
-        _emit_ast=_emit_ast,
+        _ast=ast,
+        _emit_ast=False,
     )
 
 
@@ -894,7 +908,7 @@ def count_distinct(*cols: ColumnOrName, _emit_ast: bool = True) -> Column:
     return Column(
         FunctionExpression("count", [c._expression for c in cs], is_distinct=True),
         _ast=ast,
-        _emit_ast=_emit_ast,
+        _emit_ast=False,
     )
 
 
@@ -3435,7 +3449,7 @@ def charindex(
     s = _to_col_if_str(source_expr, "charindex")
     # Build AST here to prevent `position` from being recorded as a literal instead of int/None.
     ast = (
-        build_function_expr("char_index", [t, s, position], ignore_null_args=True)
+        build_function_expr("charindex", [t, s, position], ignore_null_args=True)
         if _emit_ast
         else None
     )
@@ -4336,7 +4350,12 @@
     [Row(NEXT_DAY("A", 'FR')=datetime.date(2020, 8, 7)), Row(NEXT_DAY("A", 'FR')=datetime.date(2020, 12, 4))]
     """
     c = _to_col_if_str(date, "next_day")
-    return builtin("next_day", _emit_ast=_emit_ast)(c, Column._to_expr(day_of_week))
+    # Build AST here to prevent `date` from being recorded as a Column instead of a literal and
+    # `day_of_week` from being recorded as a literal instead of Column.
+    ast = build_function_expr("next_day", [date, day_of_week]) if _emit_ast else None
+    return builtin("next_day", _ast=ast, _emit_ast=False)(
+        c, Column._to_expr(day_of_week)
+    )
 
 
 @publicapi
@@ -4359,7 +4378,14 @@
     [Row(PREVIOUS_DAY("A", 'FR')=datetime.date(2020, 7, 31)), Row(PREVIOUS_DAY("A", 'FR')=datetime.date(2020, 11, 27))]
     """
     c = _to_col_if_str(date, "previous_day")
-    return builtin("previous_day", _emit_ast=_emit_ast)(c, Column._to_expr(day_of_week))
+    # Build AST here to prevent `date` from being recorded as a Column instead of a literal and
+    # `day_of_week` from being recorded as a literal instead of Column.
+    ast = (
+        build_function_expr("previous_day", [date, day_of_week]) if _emit_ast else None
+    )
+    return builtin("previous_day", _ast=ast, _emit_ast=False)(
+        c, Column._to_expr(day_of_week)
+    )
 
 
 @publicapi
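
For context, a sketch of the two call shapes of `convert_timezone` that the single pre-built AST now covers; the DataFrame `df` and its `ts` column are illustrative:

    from snowflake.snowpark.functions import convert_timezone, lit

    # Two-argument form: source_timezone omitted, so the SQL call is
    # CONVERT_TIMEZONE(<target_tz>, <source_time>).
    df.select(convert_timezone(lit("America/Los_Angeles"), df["ts"]))

    # Three-argument form: SQL argument order puts the source timezone first,
    # CONVERT_TIMEZONE(<source_tz>, <target_tz>, <source_time>), which is why
    # the AST is built once from the original Python argument order before the branch.
    df.select(convert_timezone(lit("America/Los_Angeles"), df["ts"], lit("UTC")))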

src/snowflake/snowpark/mock/_udtf.py

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ def _do_register_udtf(
         ast = with_src_position(stmt.expr.udtf, stmt)
         ast_id = stmt.var_id.bitfield1
 
-        object_name = kwargs["_registrated_object_name"]
+        object_name = kwargs["_registered_object_name"]
     udtf = MockUserDefinedTableFunction(
         handler,
         output_schema,

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 94 additions & 7 deletions
@@ -107,6 +107,7 @@
     dense_rank,
     first_value,
     floor,
+    get,
     greatest,
     hour,
     iff,
@@ -16813,10 +16814,6 @@ def str_split(
             ErrorMessage.not_implemented(
                 "Snowpark pandas doesn't support non-str 'pat' argument"
             )
-        if expand:
-            ErrorMessage.not_implemented(
-                "Snowpark pandas doesn't support 'expand' argument"
-            )
         if regex:
             ErrorMessage.not_implemented(
                 "Snowpark pandas doesn't support 'regex' argument"
@@ -16864,6 +16861,12 @@ def output_col(
             if np.isnan(n):
                 # Follow pandas behavior
                 return pandas_lit(np.nan)
+            elif n < -1 and not pandas.isnull(pat) and len(str(pat)) > 1:
+                # Follow pandas behavior, which based on our experiments, leaves the input column as is
+                # whenever the above condition is satisfied.
+                new_col = iff(
+                    column.is_null(), pandas_lit(None), array_construct(column)
+                )
             elif n <= 0:
                 # If all possible splits are requested, we just use SQL's split function.
                 new_col = builtin("split")(new_col, pandas_lit(new_pat))
@@ -16907,9 +16910,93 @@
             )
             return self._replace_non_str(column, new_col)
 
-        new_internal_frame = self._modin_frame.apply_snowpark_function_to_columns(
-            lambda col_name: output_col(col_name, pat, n)
-        )
+        def output_cols(
+            column: SnowparkColumn, pat: Optional[str], n: int, max_splits: int
+        ) -> list[SnowparkColumn]:
+            """
+            Returns the list of columns that the input column will be split into.
+            This is only used when expand=True.
+            Args:
+                column : SnowparkColumn
+                    Input column
+                pat : str
+                    String to split on
+                n : int
+                    Limit on the number of output splits
+                max_splits : int
+                    Maximum number of achievable splits across all values in the input column.
+                    This is needed to be able to pad rows with fewer splits than desired with nulls.
+            """
+            col = output_col(column, pat, n)
+            final_splits = 0
+
+            if np.isnan(n):
+                # Follow pandas behavior
+                final_splits = 1
+            elif n <= 0:
+                final_splits = max_splits
+            else:
+                final_splits = min(n + 1, max_splits)
+
+            if n < -1 and not pandas.isnull(pat) and len(str(pat)) > 1:
+                # Follow pandas behavior, which based on our experiments, leaves the input column as is
+                # whenever the above condition is satisfied.
+                final_splits = 1
+
+            return [
+                iff(
+                    array_size(col) > pandas_lit(i),
+                    get(col, pandas_lit(i)),
+                    pandas_lit(None),
+                )
+                for i in range(final_splits)
+            ]
+
+        def get_max_splits() -> int:
+            """
+            Returns the maximum number of splits achievable
+            across all values stored in the input column.
+            """
+            splits_as_list_frame = self.str_split(
+                pat=pat,
+                n=-1,
+                expand=False,
+                regex=regex,
+            )._modin_frame
+
+            split_counts_frame = splits_as_list_frame.append_column(
+                "split_counts",
+                array_size(
+                    col(
+                        splits_as_list_frame.data_column_snowflake_quoted_identifiers[0]
+                    )
+                ),
+            )
+
+            max_count_rows = split_counts_frame.ordered_dataframe.agg(
+                max_(
+                    col(split_counts_frame.data_column_snowflake_quoted_identifiers[-1])
+                ).as_("max_count")
+            ).collect()
+
+            return max_count_rows[0][0]
+
+        if expand:
+            cols = output_cols(
+                col(self._modin_frame.data_column_snowflake_quoted_identifiers[0]),
+                pat,
+                n,
+                get_max_splits(),
+            )
+            new_internal_frame = self._modin_frame.project_columns(
+                list(range(len(cols))),
+                cols,
+            )
+        else:
+            new_internal_frame = self._modin_frame.apply_snowpark_function_to_columns(
                lambda col_name: output_col(col_name, pat, n)
+            )
+
         return SnowflakeQueryCompiler(new_internal_frame)
 
     def str_rsplit(
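
A sketch of what the new `expand=True` path produces, mirroring pandas semantics; the data is illustrative:

    import modin.pandas as pd
    import numpy as np

    s = pd.Series(["a_b_c", "a_b", np.nan])
    print(s.str.split("_", expand=True))
    # Rows with fewer splits than the maximum are padded with nulls, matching
    # the iff(array_size(col) > i, get(col, i), NULL) projection above:
    #      0    1    2
    # 0    a    b    c
    # 1    a    b  NaN
    # 2  NaN  NaN  NaN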

src/snowflake/snowpark/modin/plugin/docstrings/base.py

Lines changed: 37 additions & 0 deletions
@@ -2346,6 +2346,43 @@ def pipe():
     def pop():
         """
         Return item and drop from frame. Raise KeyError if not found.
+
+        Parameters
+        ----------
+        item : label
+            Label of column to be popped.
+
+        Returns
+        -------
+        Series
+
+        Examples
+        --------
+        >>> df = pd.DataFrame([('falcon', 'bird', 389.0),
+        ...                    ('parrot', 'bird', 24.0),
+        ...                    ('lion', 'mammal', 80.5),
+        ...                    ('monkey', 'mammal', np.nan)],
+        ...                   columns=('name', 'class', 'max_speed'))
+        >>> df
+             name   class  max_speed
+        0  falcon    bird      389.0
+        1  parrot    bird       24.0
+        2    lion  mammal       80.5
+        3  monkey  mammal        NaN
+
+        >>> df.pop('class')
+        0      bird
+        1      bird
+        2    mammal
+        3    mammal
+        Name: class, dtype: object
+
+        >>> df
+             name  max_speed
+        0  falcon      389.0
+        1  parrot       24.0
+        2    lion       80.5
+        3  monkey        NaN
         """
 
     def pow():
