Skip to content
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

- Fixed a bug in `DataFrameReader.dbapi` (PrPr) where the `create_connection` defined as local function was incompatible with multiprocessing.
- Fixed a bug in `DataFrameReader.dbapi` (PrPr) where databricks `TIMESTAMP` type was converted to Snowflake `TIMESTAMP_NTZ` type which should be `TIMESTAMP_LTZ` type.
- Fixed a bug where creating a dynamic table on top of a table function caused an error, because `*` is not allowed in a table function `SELECT`.

#### Improvements

Expand Down
61 changes: 60 additions & 1 deletion src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import re
import sys
import uuid
from collections import defaultdict
from collections import defaultdict, deque
from enum import Enum
from functools import cached_property
from typing import (
Expand All @@ -30,6 +30,7 @@
from snowflake.snowpark._internal.analyzer.table_function import (
GeneratorTableFunction,
TableFunctionRelation,
TableFunctionJoin,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -1324,6 +1325,61 @@ def create_or_replace_view(
source_plan,
)

def find_table_function_in_sql_tree(self, plan: SnowflakePlan) -> SnowflakePlan:
    """Rewrite ``plan`` so table-function selects use explicit columns, not ``*``.

    Creating a dynamic table on top of a table function (e.g. a UDTF joined
    via ``join_table_function``) fails because ``*`` is not allowed in the
    table function's SELECT. This walks the plan tree breadth-first and
    replaces the ``*`` in the right-hand column list of every
    ``TableFunctionJoin`` with explicit column identifiers. Since a UDTF call
    cannot be distinguished from other table functions, the rewrite is
    applied to all table functions.

    Args:
        plan: the plan for the dynamic table's defining query.

    Returns:
        A newly resolved plan with explicit column lists in table-function
        selects. The input ``plan`` is deep-copied first and never mutated.
    """
    # Imported here to avoid a circular import with select_statement.
    from snowflake.snowpark._internal.analyzer.select_statement import (
        SelectTableFunction,
        Selectable,
    )

    # Work on a copy so the caller's plan tree is left untouched.
    root = copy.deepcopy(plan)

    # Breadth-first traversal; keep `root` separate from the cursor so the
    # final resolve below operates on the whole (rewritten) tree, not on
    # whichever node happened to be popped last.
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for child in node.children_plan_nodes:
            queue.append(child)

        # The bug only happens when the dynamic table is built on top of a
        # table function: detect a select whose source is a TableFunctionJoin.
        if isinstance(node, SelectTableFunction) and isinstance(
            node.snowflake_plan.source_plan, TableFunctionJoin
        ):
            # Rewrite only when the right column list is exactly ["*"] and
            # there is a single child plan. A table function can only appear
            # on the right side of the join, so only right_cols matters.
            if (
                node.snowflake_plan.source_plan.right_cols == ["*"]
                and len(node.snowflake_plan.children_plan_nodes) == 1
            ):
                child_plan = node.snowflake_plan.children_plan_nodes[0]
                if isinstance(child_plan, Selectable):
                    child_plan = child_plan.snowflake_plan
                # The table function's output columns are the identifiers of
                # the joined plan beyond those contributed by the child.
                node.snowflake_plan.source_plan.right_cols = (
                    node.snowflake_plan.quoted_identifiers[
                        len(child_plan.quoted_identifiers) :
                    ]
                )
                # Re-resolve this subtree so the new column list is applied.
                node._snowflake_plan = self.session._analyzer.resolve(
                    node.snowflake_plan.source_plan
                )

    # NOTE(review): a Selectable's snowflake_plan.source_plan can in theory be
    # None; believed impossible for a dynamic-table child plan — TODO confirm.
    # Resolve the copied root so the rewrites above take effect end to end.
    return self.session._analyzer.resolve(
        root.snowflake_plan.source_plan  # type: ignore
        if isinstance(root, Selectable)
        else root.source_plan
    )

def create_or_replace_dynamic_table(
self,
name: str,
Expand All @@ -1341,6 +1397,9 @@ def create_or_replace_dynamic_table(
source_plan: Optional[LogicalPlan],
iceberg_config: Optional[dict] = None,
) -> SnowflakePlan:

child = self.find_table_function_in_sql_tree(child)

if len(child.queries) != 1:
raise SnowparkClientExceptionMessages.PLAN_CREATE_DYNAMIC_TABLE_FROM_DDL_DML_OPERATIONS()

Expand Down
163 changes: 162 additions & 1 deletion tests/integ/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@
from snowflake.snowpark import Column, Row, Window
from snowflake.snowpark._internal.analyzer.analyzer_utils import result_scan_statement
from snowflake.snowpark._internal.analyzer.expression import Attribute, Star
from snowflake.snowpark._internal.utils import TempObjectType
from snowflake.snowpark._internal.utils import (
TempObjectType,
random_name_for_temp_object,
)
from snowflake.snowpark.dataframe_na_functions import _SUBSET_CHECK_ERROR_MESSAGE
from snowflake.snowpark.exceptions import (
SnowparkColumnException,
Expand All @@ -63,6 +66,7 @@
udtf,
uniform,
when,
cast,
)
from snowflake.snowpark.types import (
FileType,
Expand Down Expand Up @@ -3352,6 +3356,163 @@ def test_append_existing_table(session, local_testing_mode):
Utils.drop_table(session, table_name)


@pytest.mark.xfail(
    "config.getoption('local_testing_mode', default=False)",
    reason="Dynamic table is a SQL feature",
    run=False,
)
def test_dynamic_table_join_table_function(session):
    """Creating a dynamic table directly on top of a UDTF join must succeed.

    Regression test: '*' is not allowed inside a table function's SELECT, so
    the plan must expand it to explicit column identifiers.
    """

    class TestVolumeModels:
        def process(self, s1: str, s2: float):
            yield (1,)

    function_name = random_name_for_temp_object(TempObjectType.TABLE_FUNCTION)
    table_name = random_name_for_temp_object(TempObjectType.TABLE)
    stage_name = Utils.random_stage_name()
    Utils.create_stage(session, stage_name, is_temporary=True)

    try:
        test_udtf = session.udtf.register(
            TestVolumeModels,
            name=function_name,
            is_permanent=True,
            is_replace=True,
            stage_location=stage_name,
            packages=["numpy", "pandas", "snowflake-snowpark-python"],
            output_schema=StructType([StructField("DUMMY", IntegerType())]),
        )

        session.create_dataframe(
            [
                [
                    100002,
                    100316,
                    9,
                    "2025-02-03",
                    3.932,
                    "2025-02-03 23:41:29.093 -0800",
                ]
            ],
            schema=[
                "SITEID",
                "COMPETITOR_ID",
                "GRADEID",
                "OBSERVATION_DATE",
                "OBSERVED_PRICE",
                "LAST_UPDATED",
            ],
        ).write.save_as_table("dy_tb", mode="overwrite")
        df_input = session.table("dy_tb")
        df_t = df_input.join_table_function(
            test_udtf(
                cast("SITEID", StringType()), cast("OBSERVED_PRICE", FloatType())
            ).over(partition_by=["SITEID"])
        )

        # Creation and verification must happen while the UDTF still exists,
        # i.e. inside the try block, before cleanup runs.
        df_t.create_or_replace_dynamic_table(
            table_name,
            warehouse=session.get_current_warehouse(),
            lag="1 minute",
            is_transient=True,
        )
        Utils.check_answer(
            df_t,
            [
                Row(
                    SITEID=100002,
                    COMPETITOR_ID=100316,
                    GRADEID=9,
                    OBSERVATION_DATE="2025-02-03",
                    OBSERVED_PRICE=3.932,
                    LAST_UPDATED="2025-02-03 23:41:29.093 -0800",
                    DUMMY=1,
                )
            ],
        )
    finally:
        # .collect() is required: session.sql() only builds a lazy DataFrame
        # and would never execute the DROP on its own.
        session.sql(
            f"DROP FUNCTION IF EXISTS {function_name}(VARCHAR, FLOAT)"
        ).collect()
        session.sql(f"DROP DYNAMIC TABLE IF EXISTS {table_name}").collect()
        Utils.drop_table(session, "dy_tb")


@pytest.mark.xfail(
    "config.getoption('local_testing_mode', default=False)",
    reason="Dynamic table is a SQL feature",
    run=False,
)
def test_dynamic_table_join_table_function_nested(session):
    """Creating a dynamic table on a select over a UDTF join must succeed.

    Like test_dynamic_table_join_table_function, but with an extra select
    projection between the table-function join and the dynamic table, so the
    '*'-expansion rewrite is exercised on a nested plan.
    """

    class TestVolumeModels:
        def process(self, s1: str, s2: float):
            yield (1,)

    function_name = random_name_for_temp_object(TempObjectType.TABLE_FUNCTION)
    table_name = random_name_for_temp_object(TempObjectType.TABLE)
    stage_name = Utils.random_stage_name()
    Utils.create_stage(session, stage_name, is_temporary=True)

    try:
        test_udtf = session.udtf.register(
            TestVolumeModels,
            name=function_name,
            is_permanent=True,
            is_replace=True,
            stage_location=stage_name,
            packages=["numpy", "pandas", "snowflake-snowpark-python"],
            output_schema=StructType([StructField("DUMMY", IntegerType())]),
        )

        session.create_dataframe(
            [
                [
                    100002,
                    100316,
                    9,
                    "2025-02-03",
                    3.932,
                    "2025-02-03 23:41:29.093 -0800",
                ]
            ],
            schema=[
                "SITEID",
                "COMPETITOR_ID",
                "GRADEID",
                "OBSERVATION_DATE",
                "OBSERVED_PRICE",
                "LAST_UPDATED",
            ],
        ).write.save_as_table("dy_tb", mode="overwrite")
        df_input = session.table("dy_tb")
        df_t = df_input.join_table_function(
            test_udtf(
                cast("SITEID", StringType()), cast("OBSERVED_PRICE", FloatType())
            ).over(partition_by=["SITEID"])
        ).select(
            col("SITEID"), col("OBSERVATION_DATE"), col("LAST_UPDATED"), col("DUMMY")
        )

        # Creation and verification must happen while the UDTF still exists,
        # i.e. inside the try block, before cleanup runs.
        df_t.create_or_replace_dynamic_table(
            table_name,
            warehouse=session.get_current_warehouse(),
            lag="1 minute",
            is_transient=True,
        )
        Utils.check_answer(
            df_t,
            [
                Row(
                    SITEID=100002,
                    OBSERVATION_DATE="2025-02-03",
                    LAST_UPDATED="2025-02-03 23:41:29.093 -0800",
                    DUMMY=1,
                )
            ],
        )
    finally:
        # .collect() is required: session.sql() only builds a lazy DataFrame
        # and would never execute the DROP on its own.
        session.sql(
            f"DROP FUNCTION IF EXISTS {function_name}(VARCHAR, FLOAT)"
        ).collect()
        session.sql(f"DROP DYNAMIC TABLE IF EXISTS {table_name}").collect()
        Utils.drop_table(session, "dy_tb")


@pytest.mark.xfail(
"config.getoption('local_testing_mode', default=False)",
reason="Dynamic table is a SQL feature",
Expand Down
Loading