Skip to content
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

- Fixed a bug in `DataFrameReader.dbapi` (PrPr) where the `create_connection` defined as local function was incompatible with multiprocessing.
- Fixed a bug in `DataFrameReader.dbapi` (PrPr) where databricks `TIMESTAMP` type was converted to Snowflake `TIMESTAMP_NTZ` type which should be `TIMESTAMP_LTZ` type.
- Fixed a bug when create dynamic table on a table function cause error because '*' is not allowed in table function select.

#### Improvements

Expand Down
72 changes: 71 additions & 1 deletion src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import re
import sys
import uuid
from collections import defaultdict
from collections import defaultdict, deque
from enum import Enum
from functools import cached_property
from typing import (
Expand All @@ -30,6 +30,7 @@
from snowflake.snowpark._internal.analyzer.table_function import (
GeneratorTableFunction,
TableFunctionRelation,
TableFunctionJoin,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -1324,6 +1325,72 @@ def create_or_replace_view(
source_plan,
)

def find_and_update_table_function_plan(
self, plan: SnowflakePlan
) -> Optional[SnowflakePlan]:
"""This function is meant to find any udtf function call from a create dynamic table plan and
replace '*' with explicit column identifier in the select of table function. Since we cannot
differentiate udtf call from other table functions, we apply this change to all table functions.
"""
from snowflake.snowpark._internal.analyzer.select_statement import (
SelectTableFunction,
Selectable,
)
from snowflake.snowpark._internal.compiler.utils import (
create_query_generator,
update_resolvable_node,
replace_child_and_update_ancestors,
)

visited = set()
node_parents_map = defaultdict(set)
deepcopied_plan = copy.deepcopy(plan)
query_generator = create_query_generator(plan)
queue = deque()

queue.append(deepcopied_plan)
visited.add(deepcopied_plan)

while queue:
node = queue.popleft()
visited.add(node)
for child_node in reversed(node.children_plan_nodes):
node_parents_map[child_node].add(node)
if child_node not in visited:
queue.append(child_node)

# the bug only happen when create dynamic table on top of a table function
# this is meant to decide whether the plan is select from a table function
if isinstance(node, SelectTableFunction) and isinstance(
node.snowflake_plan.source_plan, TableFunctionJoin
):
table_function_join_node = node.snowflake_plan.source_plan
# if the plan has only 1 child and the source_plan.right_cols == '*', then we need to update the
# plan with the output column identifiers.
if len(
node.snowflake_plan.children_plan_nodes
) == 1 and table_function_join_node.right_cols == ["*"]:
child_plan: Union[
SnowflakePlan, Selectable
] = node.snowflake_plan.children_plan_nodes[0]
if isinstance(child_plan, Selectable):
child_plan = child_plan.snowflake_plan
assert isinstance(child_plan, SnowflakePlan)

new_plan = copy.deepcopy(node)
new_plan.snowflake_plan.source_plan.right_cols = (
node.snowflake_plan.quoted_identifiers[
len(child_plan.quoted_identifiers) :
]
)

update_resolvable_node(new_plan, query_generator)
replace_child_and_update_ancestors(
node, new_plan, node_parents_map, query_generator
)

return deepcopied_plan

def create_or_replace_dynamic_table(
self,
name: str,
Expand All @@ -1341,6 +1408,9 @@ def create_or_replace_dynamic_table(
source_plan: Optional[LogicalPlan],
iceberg_config: Optional[dict] = None,
) -> SnowflakePlan:

child = self.find_and_update_table_function_plan(child)

if len(child.queries) != 1:
raise SnowparkClientExceptionMessages.PLAN_CREATE_DYNAMIC_TABLE_FROM_DDL_DML_OPERATIONS()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@
extract_child_from_with_query_block,
is_active_transaction,
is_with_query_block,
replace_child,
update_resolvable_node,
replace_child_and_update_ancestors,
)
from snowflake.snowpark._internal.utils import (
TempObjectType,
Expand Down Expand Up @@ -590,15 +589,6 @@ def _replace_child_and_update_ancestors(
)
temp_table_selectable.post_actions = [drop_table_query]

parents = self._parent_map[child]
for parent in parents:
replace_child(parent, child, temp_table_selectable, self._query_generator)

nodes_to_reset = list(parents)
while nodes_to_reset:
node = nodes_to_reset.pop()

update_resolvable_node(node, self._query_generator)

parents = self._parent_map[node]
nodes_to_reset.extend(parents)
replace_child_and_update_ancestors(
child, temp_table_selectable, self._parent_map, self._query_generator
)
27 changes: 26 additions & 1 deletion src/snowflake/snowpark/_internal/compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#
import copy
import tempfile
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Set, Union

from snowflake.snowpark._internal.analyzer.binary_plan_node import BinaryNode
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
Expand Down Expand Up @@ -391,6 +391,31 @@ def is_with_query_block(node: LogicalPlan) -> bool:
return False


def replace_child_and_update_ancestors(
child: LogicalPlan,
new_child: LogicalPlan,
parent_map: Dict[LogicalPlan, Set[TreeNode]],
query_generator: QueryGenerator,
):
"""
For the given child, this helper function updates all its parents with the new
child provided and updates all the ancestor nodes.
"""
parents = parent_map[child]

for parent in parents:
replace_child(parent, child, new_child, query_generator)

nodes_to_reset = list(parents)
while nodes_to_reset:
node = nodes_to_reset.pop()

update_resolvable_node(node, query_generator)

parents = parent_map[node]
nodes_to_reset.extend(parents)


def plot_plan_if_enabled(root: LogicalPlan, filename: str) -> None:
"""A helper function to plot the query plan tree using graphviz useful for debugging.
It plots the plan if the environment variable ENABLE_SNOWPARK_LOGICAL_PLAN_PLOTTING
Expand Down
Loading
Loading