Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 57 additions & 13 deletions src/snowflake/snowpark/_internal/compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
SelectSnowflakePlan,
SelectStatement,
SelectTableFunction,
SelectableEntity,
SetStatement,
)
from snowflake.snowpark._internal.analyzer.snowflake_plan import (
Expand All @@ -28,6 +29,7 @@
LogicalPlan,
SnowflakeCreateTable,
TableCreationSource,
WithQueryBlock,
)
from snowflake.snowpark._internal.analyzer.table_merge_expression import (
TableDelete,
Expand Down Expand Up @@ -381,15 +383,27 @@ def plot_plan_if_enabled(root: LogicalPlan, filename: str) -> None:
):
return

if int(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this plotting threshold used for? It seems to be used for restricting the complexity score. Maybe call it SNOWPARK_LOGICAL_PLAN_PLOTTING_COMPLEXITY_THRESHOLD to be clearer.

os.environ.get("SNOWPARK_LOGICAL_PLAN_PLOTTING_THRESHOLD", 0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's simply make the default threshold -1, to make it clear that all nodes are plotted by default.

Was there a reason why we want to add this threshold?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. In my tests, I generally want to plot and debug "big" plans, but sometimes the plans get overwritten by smaller plans if they are present somewhere. That's why I added this variable. I don't think this is the best way - I'm open to suggestions.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you be more specific about "sometime the plans get overwritten by smaller plan if they are present somewhere"? I'm not quite getting this part. What information do you want to get to help your debugging process?

) > get_complexity_score(root):
return

import graphviz # pyright: ignore[reportMissingImports]

def get_stat(node: LogicalPlan):
def get_name(node: Optional[LogicalPlan]) -> str:
def get_name(node: Optional[LogicalPlan]) -> str: # pragma: no cover
if node is None:
return "EMPTY_SOURCE_PLAN" # pragma: no cover
addr = hex(id(node))
name = str(type(node)).split(".")[-1].split("'")[0]
return f"{name}({addr})"
suffix = ""
if isinstance(node, SnowflakeCreateTable):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment here describing the different printing formats used here.

table_name = node.table_name[-1].split(".")[-1] # pyright: ignore
suffix = f" :: {table_name}"
if isinstance(node, WithQueryBlock):
suffix = f" :: {node.name[18:]}"

return f"{name}({addr}){suffix}"

name = get_name(node)
if isinstance(node, SnowflakePlan):
Expand All @@ -411,20 +425,43 @@ def get_name(node: Optional[LogicalPlan]) -> str:
if node.offset:
properties.append("Offset") # pragma: no cover
name = f"{name} :: ({'| '.join(properties)})"
elif isinstance(node, SelectableEntity):
name = f"{name} :: ({node.entity.name.split('.')[-1]})"

def get_sql_text(node: LogicalPlan) -> str:  # pragma: no cover
    """Return the SQL text attached to ``node``.

    A ``Selectable`` yields its ``sql_query``; a ``SnowflakePlan`` yields the
    ``sql`` of the last entry in its ``queries``. Any other node yields an
    empty string.
    """
    sql = ""
    # Selectable is tested first, so it wins if a node matches both types.
    if isinstance(node, Selectable):
        sql = node.sql_query
    elif isinstance(node, SnowflakePlan):
        sql = node.queries[-1].sql
    return sql

score = get_complexity_score(node)
num_ref_ctes = "nil"
if isinstance(node, (SnowflakePlan, Selectable)):
num_ref_ctes = len(node.referenced_ctes)
sql_text = ""
if isinstance(node, Selectable):
sql_text = node.sql_query
elif isinstance(node, SnowflakePlan):
sql_text = node.queries[-1].sql
sql_text = get_sql_text(node)
sql_size = len(sql_text)
ref_ctes = None
if isinstance(node, (SnowflakePlan, Selectable)):
ref_ctes = list(
map(
lambda node, cnt: f"{node.name[18:]}:{cnt}",
node.referenced_ctes.keys(),
node.referenced_ctes.values(),
)
)
for with_query_block in node.referenced_ctes: # pragma: no cover
sql_size += len(get_sql_text(with_query_block.children[0]))
sql_preview = sql_text[:50]

return f"{name=}\n{score=}, {num_ref_ctes=}, {sql_size=}\n{sql_preview=}"
return f"{name=}\n{score=}, {ref_ctes=}, {sql_size=}\n{sql_preview=}"

def is_with_query_block(node: Optional[LogicalPlan]) -> bool:  # pragma: no cover
    """Tell whether ``node`` is, or wraps, a ``WithQueryBlock``.

    Wrappers are followed one level at a time: a ``SnowflakePlan`` through its
    ``source_plan`` and a ``SelectSnowflakePlan`` through its
    ``snowflake_plan``. Any other node (including ``None``) ends the search
    with ``False``.
    """
    current = node
    while True:
        if isinstance(current, WithQueryBlock):
            return True
        if isinstance(current, SnowflakePlan):
            # Step into the plan this SnowflakePlan was built from.
            current = current.source_plan
        elif isinstance(current, SelectSnowflakePlan):
            # Step into the SnowflakePlan held by the selectable.
            current = current.snowflake_plan
        else:
            return False

g = graphviz.Graph(format="png")

Expand All @@ -435,11 +472,18 @@ def get_name(node: Optional[LogicalPlan]) -> str:
for node in curr_level:
node_id = hex(id(node))
color = "lightblue" if node._is_valid_for_replacement else "red"
g.node(node_id, get_stat(node), color=color)
fillcolor = "lightgray" if is_with_query_block(node) else "white"
g.node(
node_id,
get_stat(node),
color=color,
style="filled",
fillcolor=fillcolor,
)
if isinstance(node, (Selectable, SnowflakePlan)):
children = node.children_plan_nodes
else:
children = node.children
children = node.children # pragma: no cover
for child in children:
child_id = hex(id(child))
edges.add((node_id, child_id))
Expand Down
20 changes: 17 additions & 3 deletions tests/integ/test_large_query_breakdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,16 +788,24 @@ def test_large_query_breakdown_enabled_parameter(session, caplog):

@pytest.mark.skipif(IS_IN_STORED_PROC, reason="requires graphviz")
@pytest.mark.parametrize("enabled", [False, True])
def test_plotter(session, large_query_df, enabled):
@pytest.mark.parametrize("plotting_score_threshold", [0, 10_000_000])
def test_plotter(large_query_df, enabled, plotting_score_threshold):
original_plotter_enabled = os.environ.get("ENABLE_SNOWPARK_LOGICAL_PLAN_PLOTTING")
original_score_threshold = os.environ.get(
"SNOWPARK_LOGICAL_PLAN_PLOTTING_THRESHOLD"
)
try:
os.environ["ENABLE_SNOWPARK_LOGICAL_PLAN_PLOTTING"] = str(enabled)
os.environ["SNOWPARK_LOGICAL_PLAN_PLOTTING_THRESHOLD"] = str(
plotting_score_threshold
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be better done with something like:
with mock.patch.dict(os.environ, {...}):

tmp_dir = tempfile.gettempdir()

with patch("graphviz.Graph.render") as mock_render:
large_query_df.collect()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we perhaps add a comment explaining that the actual complexity for large_query_df falls somewhere between 0 and 10M?

assert mock_render.called == enabled
if not enabled:
should_plot = enabled and (plotting_score_threshold == 0)
assert mock_render.called == should_plot
if not should_plot:
return

assert mock_render.call_count == 5
Expand All @@ -819,3 +827,9 @@ def test_plotter(session, large_query_df, enabled):
] = original_plotter_enabled
else:
del os.environ["ENABLE_SNOWPARK_LOGICAL_PLAN_PLOTTING"]
if original_score_threshold is not None:
os.environ[
"SNOWPARK_LOGICAL_PLAN_PLOTTING_THRESHOLD"
] = original_score_threshold
else:
del os.environ["SNOWPARK_LOGICAL_PLAN_PLOTTING_THRESHOLD"]