diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ad6937409..87ef503a5d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ - Fixed a bug in `DataFrameReader.dbapi` (PrPr) where the `create_connection` defined as local function was incompatible with multiprocessing. - Fixed a bug in `DataFrameReader.dbapi` (PrPr) where databricks `TIMESTAMP` type was converted to Snowflake `TIMESTAMP_NTZ` type which should be `TIMESTAMP_LTZ` type. +- Fixed a bug where `DataFrame.create_or_replace_dynamic_table` raised an error when the dataframe contained a UDTF and the `SELECT *` in the UDTF was not parsed correctly. #### Improvements diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 19fc854e8c..fba22456e0 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -7,7 +7,7 @@ import re import sys import uuid -from collections import defaultdict +from collections import defaultdict, deque from enum import Enum from functools import cached_property from typing import ( @@ -30,6 +30,7 @@ from snowflake.snowpark._internal.analyzer.table_function import ( GeneratorTableFunction, TableFunctionRelation, + TableFunctionJoin, ) if TYPE_CHECKING: @@ -1324,6 +1325,72 @@ def create_or_replace_view( source_plan, ) + def find_and_update_table_function_plan( + self, plan: SnowflakePlan + ) -> Optional[SnowflakePlan]: + """This function is meant to find any UDTF function call in a create dynamic table plan and + replace '*' with explicit column identifiers in the select of the table function. Since we cannot + differentiate UDTF calls from other table functions, we apply this change to all table functions. 
+ """ + from snowflake.snowpark._internal.analyzer.select_statement import ( + SelectTableFunction, + Selectable, + ) + from snowflake.snowpark._internal.compiler.utils import ( + create_query_generator, + update_resolvable_node, + replace_child_and_update_ancestors, + ) + + visited = set() + node_parents_map = defaultdict(set) + deepcopied_plan = copy.deepcopy(plan) + query_generator = create_query_generator(plan) + queue = deque() + + queue.append(deepcopied_plan) + visited.add(deepcopied_plan) + + while queue: + node = queue.popleft() + visited.add(node) + for child_node in reversed(node.children_plan_nodes): + node_parents_map[child_node].add(node) + if child_node not in visited: + queue.append(child_node) + + # the bug only happens when creating a dynamic table on top of a table function; + # this is meant to decide whether the plan is a select from a table function + if isinstance(node, SelectTableFunction) and isinstance( + node.snowflake_plan.source_plan, TableFunctionJoin + ): + table_function_join_node = node.snowflake_plan.source_plan + # if the plan has only 1 child and the source_plan.right_cols == '*', then we need to update the + # plan with the output column identifiers. 
+ if len( + node.snowflake_plan.children_plan_nodes + ) == 1 and table_function_join_node.right_cols == ["*"]: + child_plan: Union[ + SnowflakePlan, Selectable + ] = node.snowflake_plan.children_plan_nodes[0] + if isinstance(child_plan, Selectable): + child_plan = child_plan.snowflake_plan + assert isinstance(child_plan, SnowflakePlan) + + new_plan = copy.deepcopy(node) + new_plan.snowflake_plan.source_plan.right_cols = ( # type: ignore + node.snowflake_plan.quoted_identifiers[ + len(child_plan.quoted_identifiers) : + ] + ) + + update_resolvable_node(new_plan, query_generator) + replace_child_and_update_ancestors( + node, new_plan, node_parents_map, query_generator + ) + + return deepcopied_plan + def create_or_replace_dynamic_table( self, name: str, @@ -1341,6 +1408,9 @@ def create_or_replace_dynamic_table( source_plan: Optional[LogicalPlan], iceberg_config: Optional[dict] = None, ) -> SnowflakePlan: + + child = self.find_and_update_table_function_plan(child) # type: ignore + if len(child.queries) != 1: raise SnowparkClientExceptionMessages.PLAN_CREATE_DYNAMIC_TABLE_FROM_DDL_DML_OPERATIONS() diff --git a/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py b/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py index 139b3abed5..3ef64da4b1 100644 --- a/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py +++ b/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py @@ -55,8 +55,7 @@ extract_child_from_with_query_block, is_active_transaction, is_with_query_block, - replace_child, - update_resolvable_node, + replace_child_and_update_ancestors, ) from snowflake.snowpark._internal.utils import ( TempObjectType, @@ -590,15 +589,6 @@ def _replace_child_and_update_ancestors( ) temp_table_selectable.post_actions = [drop_table_query] - parents = self._parent_map[child] - for parent in parents: - replace_child(parent, child, temp_table_selectable, self._query_generator) - - nodes_to_reset = list(parents) - while 
nodes_to_reset: - node = nodes_to_reset.pop() - - update_resolvable_node(node, self._query_generator) - - parents = self._parent_map[node] - nodes_to_reset.extend(parents) + replace_child_and_update_ancestors( + child, temp_table_selectable, self._parent_map, self._query_generator + ) diff --git a/src/snowflake/snowpark/_internal/compiler/utils.py b/src/snowflake/snowpark/_internal/compiler/utils.py index 8a0b4de892..005ff1cc1b 100644 --- a/src/snowflake/snowpark/_internal/compiler/utils.py +++ b/src/snowflake/snowpark/_internal/compiler/utils.py @@ -4,7 +4,7 @@ # import copy import tempfile -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Set, Union from snowflake.snowpark._internal.analyzer.binary_plan_node import BinaryNode from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( @@ -391,6 +391,31 @@ def is_with_query_block(node: LogicalPlan) -> bool: return False +def replace_child_and_update_ancestors( + child: LogicalPlan, + new_child: LogicalPlan, + parent_map: Dict[LogicalPlan, Set[TreeNode]], + query_generator: QueryGenerator, +): + """ + For the given child, this helper function updates all its parents with the new + child provided and updates all the ancestor nodes. + """ + parents = parent_map[child] + + for parent in parents: + replace_child(parent, child, new_child, query_generator) + + nodes_to_reset = list(parents) + while nodes_to_reset: + node = nodes_to_reset.pop() + + update_resolvable_node(node, query_generator) + + parents = parent_map[node] + nodes_to_reset.extend(parents) + + def plot_plan_if_enabled(root: LogicalPlan, filename: str) -> None: """A helper function to plot the query plan tree using graphviz useful for debugging. 
It plots the plan if the environment variable ENABLE_SNOWPARK_LOGICAL_PLAN_PLOTTING diff --git a/tests/integ/test_dataframe.py b/tests/integ/test_dataframe.py index 28eca83633..b106ba9a56 100644 --- a/tests/integ/test_dataframe.py +++ b/tests/integ/test_dataframe.py @@ -36,7 +36,10 @@ from snowflake.snowpark import Column, Row, Window from snowflake.snowpark._internal.analyzer.analyzer_utils import result_scan_statement from snowflake.snowpark._internal.analyzer.expression import Attribute, Star -from snowflake.snowpark._internal.utils import TempObjectType +from snowflake.snowpark._internal.utils import ( + TempObjectType, + random_name_for_temp_object, +) from snowflake.snowpark.dataframe_na_functions import _SUBSET_CHECK_ERROR_MESSAGE from snowflake.snowpark.exceptions import ( SnowparkColumnException, @@ -63,6 +66,7 @@ udtf, uniform, when, + cast, ) from snowflake.snowpark.types import ( FileType, @@ -3352,6 +3356,241 @@ def test_append_existing_table(session, local_testing_mode): Utils.drop_table(session, table_name) +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="Dynamic table is a SQL feature", +) +def test_dynamic_table_join_table_function(session): + class TestVolumeModels: + def process(self, s1: str, s2: float): + yield (1,) + + function_name = random_name_for_temp_object(TempObjectType.TABLE_FUNCTION) + table_name = random_name_for_temp_object(TempObjectType.TABLE) + stage_name = Utils.random_stage_name() + Utils.create_stage(session, stage_name, is_temporary=True) + + try: + test_udtf = session.udtf.register( + TestVolumeModels, + name=function_name, + is_permanent=True, + is_replace=True, + stage_location=stage_name, + packages=["numpy", "pandas", "snowflake-snowpark-python"], + output_schema=StructType([StructField("DUMMY", IntegerType())]), + ) + + ( + session.create_dataframe( + [ + [ + 100002, + 100316, + 9, + "2025-02-03", + 3.932, + "2025-02-03 23:41:29.093 -0800", + ] + ], + schema=[ + "SITEID", + 
"COMPETITOR_ID", + "GRADEID", + "OBSERVATION_DATE", + "OBSERVED_PRICE", + "LAST_UPDATED", + ], + ).write.save_as_table("dy_tb", mode="overwrite") + ) + df_input = session.table("dy_tb") + df_t = df_input.join_table_function( + test_udtf( + cast("SITEID", StringType()), cast("OBSERVED_PRICE", FloatType()) + ).over(partition_by=iter(["SITEID"])) + ) + + df_t.create_or_replace_dynamic_table( + table_name, + warehouse=session.get_current_warehouse(), + lag="1 minute", + is_transient=True, + ) + Utils.check_answer( + df_t, + [ + Row( + SITEID=100002, + COMPETITOR_ID=100316, + GRADEID=9, + OBSERVATION_DATE="2025-02-03", + OBSERVED_PRICE=3.932, + LAST_UPDATED="2025-02-03 23:41:29.093 -0800", + DUMMY=1, + ) + ], + ) + finally: + session.sql(f"DROP FUNCTION IF EXISTS {function_name}(VARCHAR, FLOAT)") + + +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="Dynamic table is a SQL feature", +) +def test_dynamic_table_join_table_function_with_more_layers(session): + class TestVolumeModels: + def process(self, s1: str, s2: float): + yield (1,) + + function_name = random_name_for_temp_object(TempObjectType.TABLE_FUNCTION) + table_name = random_name_for_temp_object(TempObjectType.TABLE) + stage_name = Utils.random_stage_name() + Utils.create_stage(session, stage_name, is_temporary=True) + + try: + test_udtf = session.udtf.register( + TestVolumeModels, + name=function_name, + is_permanent=True, + is_replace=True, + stage_location=stage_name, + packages=["numpy", "pandas", "snowflake-snowpark-python"], + output_schema=StructType([StructField("DUMMY", IntegerType())]), + ) + + ( + session.create_dataframe( + [ + [ + 100002, + 100316, + 9, + "2025-02-03", + 3.932, + "2025-02-03 23:41:29.093 -0800", + ] + ], + schema=[ + "SITEID", + "COMPETITOR_ID", + "GRADEID", + "OBSERVATION_DATE", + "OBSERVED_PRICE", + "LAST_UPDATED", + ], + ).write.save_as_table("dy_tb", mode="overwrite") + ) + df_input = session.table("dy_tb") + df_t = 
df_input.join_table_function( + test_udtf( + cast("SITEID", StringType()), cast("OBSERVED_PRICE", FloatType()) + ).over(partition_by=iter(["SITEID"])) + ) + + df_t = df_t.with_column("COL1", lit(1)).distinct() + df_t.create_or_replace_dynamic_table( + table_name, + warehouse=session.get_current_warehouse(), + lag="1 minute", + is_transient=True, + ) + Utils.check_answer( + df_t, + [ + Row( + SITEID=100002, + COMPETITOR_ID=100316, + GRADEID=9, + OBSERVATION_DATE="2025-02-03", + OBSERVED_PRICE=3.932, + LAST_UPDATED="2025-02-03 23:41:29.093 -0800", + DUMMY=1, + COL1=1, + ) + ], + ) + finally: + session.sql(f"DROP FUNCTION IF EXISTS {function_name}(VARCHAR, FLOAT)") + + +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="Dynamic table is a SQL feature", +) +def test_dynamic_table_join_table_function_nested(session): + class TestVolumeModels: + def process(self, s1: str, s2: float): + yield (1,) + + function_name = random_name_for_temp_object(TempObjectType.TABLE_FUNCTION) + table_name = random_name_for_temp_object(TempObjectType.TABLE) + stage_name = Utils.random_stage_name() + Utils.create_stage(session, stage_name, is_temporary=True) + + try: + test_udtf = session.udtf.register( + TestVolumeModels, + name=function_name, + is_permanent=True, + is_replace=True, + stage_location=stage_name, + packages=["numpy", "pandas", "snowflake-snowpark-python"], + output_schema=StructType([StructField("DUMMY", IntegerType())]), + ) + + ( + session.create_dataframe( + [ + [ + 100002, + 100316, + 9, + "2025-02-03", + 3.932, + "2025-02-03 23:41:29.093 -0800", + ] + ], + schema=[ + "SITEID", + "COMPETITOR_ID", + "GRADEID", + "OBSERVATION_DATE", + "OBSERVED_PRICE", + "LAST_UPDATED", + ], + ).write.save_as_table("dy_tb", mode="overwrite") + ) + df_input = session.table("dy_tb") + df_t = df_input.join_table_function( + test_udtf( + cast("SITEID", StringType()), cast("OBSERVED_PRICE", FloatType()) + ).over(partition_by=iter(["SITEID"])) + ).select( + 
col("SITEID"), col("OBSERVATION_DATE"), col("LAST_UPDATED"), col("DUMMY") + ) + finally: + session.sql(f"DROP FUNCTION IF EXISTS {function_name}(VARCHAR, FLOAT)") + + df_t.create_or_replace_dynamic_table( + table_name, + warehouse=session.get_current_warehouse(), + lag="1 minute", + is_transient=True, + ) + Utils.check_answer( + df_t, + [ + Row( + SITEID=100002, + OBSERVATION_DATE="2025-02-03", + LAST_UPDATED="2025-02-03 23:41:29.093 -0800", + DUMMY=1, + ) + ], + ) + + @pytest.mark.xfail( "config.getoption('local_testing_mode', default=False)", reason="Dynamic table is a SQL feature",