Skip to content

Commit 678a748

Browse files
authored
Fix CTE join with df.alias (#3999)
1 parent 68fd097 commit 678a748

File tree

5 files changed

+78
-20
lines changed

5 files changed

+78
-20
lines changed

CHANGELOG.md

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@
3535
- `try_hex_decode_string`
3636
- `unicode`
3737
- `uuid_string`
38-
38+
3939
- Conditional expressions:
4040
- `booland_agg`
4141
- `boolxor_agg`
4242
- `regr_valy`
4343
- `zeroifnull`
44-
44+
4545
- Numeric expressions:
4646
- `cot`
4747
- `mod`
@@ -60,6 +60,7 @@
6060

6161
#### Bug Fixes
6262

63+
- Fixed a bug in SQL generation when joining two `DataFrame`s created using `DataFrame.alias` while CTE optimization is enabled.
6364
- Fixed a bug in `XMLReader` where finding the start position of a row tag could return an incorrect file position.
6465

6566
### Snowpark pandas API Updates
@@ -122,13 +123,13 @@
122123
- `str.pad`
123124
- `str.len`
124125
- `str.ljust`
125-
- `str.rjust`
126-
- `str.split`
127-
- `str.replace`
128-
- `str.strip`
129-
- `str.lstrip`
130-
- `str.rstrip`
131-
- `str.translate`
126+
- `str.rjust`
127+
- `str.split`
128+
- `str.replace`
129+
- `str.strip`
130+
- `str.lstrip`
131+
- `str.rstrip`
132+
- `str.translate`
132133
- `dt.tz_localize`
133134
- `dt.tz_convert`
134135
- `dt.ceil`
@@ -137,11 +138,11 @@
137138
- `dt.normalize`
138139
- `dt.month_name`
139140
- `dt.day_name`
140-
- `dt.strftime`
141-
- `dt.dayofweek`
142-
- `dt.weekday`
143-
- `dt.dayofyear`
144-
- `dt.isocalendar`
141+
- `dt.strftime`
142+
- `dt.dayofweek`
143+
- `dt.weekday`
144+
- `dt.dayofyear`
145+
- `dt.isocalendar`
145146
- `rolling.min`
146147
- `rolling.max`
147148
- `rolling.count`

src/snowflake/snowpark/_internal/analyzer/analyzer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,8 @@ def do_resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan:
972972

973973
for c in logical_plan.children: # post-order traversal of the tree
974974
resolved = self.resolve(c)
975-
df_aliased_col_name_to_real_col_name.update(resolved.df_aliased_col_name_to_real_col_name) # type: ignore
975+
for alias, dict_ in resolved.df_aliased_col_name_to_real_col_name.items():
976+
df_aliased_col_name_to_real_col_name[alias].update(dict_)
976977
resolved_children[c] = resolved
977978

978979
if isinstance(logical_plan, Selectable):
@@ -1004,9 +1005,8 @@ def do_resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan:
10041005
res = self.do_resolve_with_resolved_children(
10051006
logical_plan, resolved_children, df_aliased_col_name_to_real_col_name
10061007
)
1007-
res.df_aliased_col_name_to_real_col_name.update(
1008-
df_aliased_col_name_to_real_col_name
1009-
)
1008+
for alias, dict_ in df_aliased_col_name_to_real_col_name.items():
1009+
res.df_aliased_col_name_to_real_col_name[alias].update(dict_)
10101010
return res
10111011

10121012
def do_resolve_with_resolved_children(

src/snowflake/snowpark/_internal/analyzer/select_statement.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,7 @@ def __init__(
880880
self._projection_in_str = None
881881
self._query_params = None
882882
self.expr_to_alias.update(self.from_.expr_to_alias)
883-
self.df_aliased_col_name_to_real_col_name.update(
883+
self.df_aliased_col_name_to_real_col_name = deepcopy(
884884
self.from_.df_aliased_col_name_to_real_col_name
885885
)
886886
self.api_calls = (

src/snowflake/snowpark/_internal/compiler/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,9 +252,10 @@ def update_resolvable_node(
252252
# df_aliased_col_name_to_real_col_name is updated at the frontend api
253253
# layer when alias is called, not produced during code generation. Should
254254
# always retain the original value of the map.
255-
node.df_aliased_col_name_to_real_col_name.update(
255+
node.df_aliased_col_name_to_real_col_name = copy.deepcopy(
256256
node.from_.df_aliased_col_name_to_real_col_name
257257
)
258+
258259
# projection_in_str for SelectStatement runs an analyzer.analyze() which
259260
# needs the correct expr_to_alias map setup. This map is setup during
260261
# snowflake plan generation and cached for later use. Calling snowflake_plan

tests/integ/test_cte.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@
2525
when_matched,
2626
to_timestamp,
2727
)
28+
from snowflake.snowpark.types import (
29+
StructType,
30+
StructField,
31+
IntegerType,
32+
StringType,
33+
TimestampType,
34+
)
2835
from tests.integ.scala.test_dataframe_reader_suite import get_reader
2936
from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker
3037
from tests.utils import IS_IN_STORED_PROC_LOCALFS, TestFiles, Utils
@@ -272,6 +279,55 @@ def test_join_with_alias_dataframe(session):
272279
assert last_query.count(WITH) == 1
273280

274281

282+
def test_join_with_alias_dataframe_2(session):
283+
# Reproduced from issue SNOW-2257191
284+
schema1 = StructType(
285+
[
286+
StructField("DST_Year", IntegerType(), True),
287+
StructField("DST_Start", TimestampType(), True),
288+
StructField("DST_End", TimestampType(), True),
289+
]
290+
)
291+
292+
schema2 = StructType(
293+
[
294+
StructField("MATTRANSID", StringType(), True),
295+
StructField("LOADSTARTTIME", TimestampType(), True),
296+
StructField("LOADENDTIME", TimestampType(), True),
297+
StructField("DUMPENDTIME", TimestampType(), True),
298+
StructField("__CURRENT", StringType(), True),
299+
StructField("__DELETED", StringType(), True),
300+
]
301+
)
302+
303+
schema3 = StructType(
304+
[
305+
StructField("MATTRANSID", StringType(), True),
306+
StructField("DUMPENDTIME", TimestampType(), True),
307+
StructField("LOADENDTIME", TimestampType(), True),
308+
StructField("__CURRENT", StringType(), True),
309+
StructField("__DELETED", StringType(), True),
310+
]
311+
)
312+
313+
df1 = session.create_dataframe([], schema=schema1).cache_result()
314+
df2 = session.create_dataframe([], schema=schema2).cache_result()
315+
df3 = session.create_dataframe([], schema=schema3).cache_result()
316+
317+
df4 = df2.alias("d2").join(
318+
df1, col("d2", "LoadStartTime").between(df1.DST_Start, df1.DST_End), "left"
319+
)
320+
321+
df5 = df3.alias("d3").join(
322+
df1, col("d3", "LoadEndTime").between(df1.DST_Start, df1.DST_End), "left"
323+
)
324+
325+
df6 = df5.join(df4, (df5.MatTransId == df4.MatTransId), "left")
326+
327+
# Assert that the generated sql compiles
328+
df6.collect()
329+
330+
275331
def test_join_with_set_operation(session):
276332
df1 = session.create_dataframe([[1, 2, 3], [4, 5, 6]], "a: int, b: int, c: int")
277333
df2 = session.create_dataframe([[1, 1], [4, 5]], "a: int, b: int")

0 commit comments

Comments
 (0)