Skip to content

Commit 0ee1c34

Browse files
SNOW-1857250: More fixes to AST generation in functions.py (#2793)
1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. Fixes SNOW-1857250 2. Fill out the following pre-review checklist: - [ ] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. - [x] I acknowledge that I have ensured my changes to be thread-safe. Follow the link for more information: [Thread-safe Developer Guidelines](https://github.com/snowflakedb/snowpark-python/blob/main/CONTRIBUTING.md#thread-safe-development) 3. Please describe how your code solves the related issue. Fixed the AST generation where columns where recorded as literals instead of column parameters and parameters were being rearranged in some cases. This change gets rid of a lot of unnecessary `col()` in the unparsed code; this helps match the unparser code with the original code.
1 parent 33a176d commit 0ee1c34

File tree

3 files changed

+105
-250
lines changed

3 files changed

+105
-250
lines changed

src/snowflake/snowpark/functions.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -753,17 +753,31 @@ def convert_timezone(
753753
)
754754
target_tz = _to_col_if_str(target_timezone, "convert_timezone")
755755
source_time_to_convert = _to_col_if_str(source_time, "convert_timezone")
756-
756+
# Build AST here to prevent rearrangement of args in the encoded AST.
757+
ast = (
758+
build_function_expr(
759+
"convert_timezone",
760+
[target_timezone, source_time, source_timezone],
761+
ignore_null_args=True,
762+
)
763+
if _emit_ast
764+
else None
765+
)
757766
if source_timezone is None:
758767
return call_builtin(
759-
"convert_timezone", target_tz, source_time_to_convert, _emit_ast=_emit_ast
768+
"convert_timezone",
769+
target_tz,
770+
source_time_to_convert,
771+
_ast=ast,
772+
_emit_ast=False,
760773
)
761774
return call_builtin(
762775
"convert_timezone",
763776
source_tz,
764777
target_tz,
765778
source_time_to_convert,
766-
_emit_ast=_emit_ast,
779+
_ast=ast,
780+
_emit_ast=False,
767781
)
768782

769783

@@ -889,7 +903,7 @@ def count_distinct(*cols: ColumnOrName, _emit_ast: bool = True) -> Column:
889903
return Column(
890904
FunctionExpression("count", [c._expression for c in cs], is_distinct=True),
891905
_ast=ast,
892-
_emit_ast=_emit_ast,
906+
_emit_ast=False,
893907
)
894908

895909

@@ -4331,7 +4345,12 @@ def next_day(
43314345
[Row(NEXT_DAY("A", 'FR')=datetime.date(2020, 8, 7)), Row(NEXT_DAY("A", 'FR')=datetime.date(2020, 12, 4))]
43324346
"""
43334347
c = _to_col_if_str(date, "next_day")
4334-
return builtin("next_day", _emit_ast=_emit_ast)(c, Column._to_expr(day_of_week))
4348+
# Build AST here to prevent `date` from being recorded as a Column instead of a literal and
4349+
# `day_of_week` from being recorded as a literal instead of Column.
4350+
ast = build_function_expr("next_day", [date, day_of_week]) if _emit_ast else None
4351+
return builtin("next_day", _ast=ast, _emit_ast=False)(
4352+
c, Column._to_expr(day_of_week)
4353+
)
43354354

43364355

43374356
@publicapi
@@ -4354,7 +4373,14 @@ def previous_day(
43544373
[Row(PREVIOUS_DAY("A", 'FR')=datetime.date(2020, 7, 31)), Row(PREVIOUS_DAY("A", 'FR')=datetime.date(2020, 11, 27))]
43554374
"""
43564375
c = _to_col_if_str(date, "previous_day")
4357-
return builtin("previous_day", _emit_ast=_emit_ast)(c, Column._to_expr(day_of_week))
4376+
# Build AST here to prevent `date` from being recorded as a Column instead of a literal and
4377+
# `day_of_week` from being recorded as a literal instead of Column.
4378+
ast = (
4379+
build_function_expr("previous_day", [date, day_of_week]) if _emit_ast else None
4380+
)
4381+
return builtin("previous_day", _ast=ast, _emit_ast=False)(
4382+
c, Column._to_expr(day_of_week)
4383+
)
43584384

43594385

43604386
@publicapi

0 commit comments

Comments
 (0)