Skip to content

Commit e75b506

Browse files
authored
SNOW-1869362: Plan plotter improvements (#2813)
1 parent ba31301 commit e75b506

File tree

2 files changed

+77
-16
lines changed

2 files changed

+77
-16
lines changed

src/snowflake/snowpark/_internal/compiler/utils.py

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
SelectSnowflakePlan,
1616
SelectStatement,
1717
SelectTableFunction,
18+
SelectableEntity,
1819
SetStatement,
1920
)
2021
from snowflake.snowpark._internal.analyzer.snowflake_plan import (
@@ -28,6 +29,7 @@
2829
LogicalPlan,
2930
SnowflakeCreateTable,
3031
TableCreationSource,
32+
WithQueryBlock,
3133
)
3234
from snowflake.snowpark._internal.analyzer.table_merge_expression import (
3335
TableDelete,
@@ -381,15 +383,29 @@ def plot_plan_if_enabled(root: LogicalPlan, filename: str) -> None:
381383
):
382384
return
383385

386+
if int(
387+
os.environ.get("SNOWPARK_LOGICAL_PLAN_PLOTTING_COMPLEXITY_THRESHOLD", 0)
388+
) > get_complexity_score(root):
389+
return
390+
384391
import graphviz # pyright: ignore[reportMissingImports]
385392

386393
def get_stat(node: LogicalPlan):
387-
def get_name(node: Optional[LogicalPlan]) -> str:
394+
def get_name(node: Optional[LogicalPlan]) -> str: # pragma: no cover
388395
if node is None:
389396
return "EMPTY_SOURCE_PLAN" # pragma: no cover
390397
addr = hex(id(node))
391398
name = str(type(node)).split(".")[-1].split("'")[0]
392-
return f"{name}({addr})"
399+
suffix = ""
400+
if isinstance(node, SnowflakeCreateTable):
401+
# get the table name from the full qualified name
402+
table_name = node.table_name[-1].split(".")[-1] # pyright: ignore
403+
suffix = f" :: {table_name}"
404+
if isinstance(node, WithQueryBlock):
405+
# get the CTE identifier excluding SNOWPARK_TEMP_CTE_
406+
suffix = f" :: {node.name[18:]}"
407+
408+
return f"{name}({addr}){suffix}"
393409

394410
name = get_name(node)
395411
if isinstance(node, SnowflakePlan):
@@ -411,20 +427,44 @@ def get_name(node: Optional[LogicalPlan]) -> str:
411427
if node.offset:
412428
properties.append("Offset") # pragma: no cover
413429
name = f"{name} :: ({'| '.join(properties)})"
430+
elif isinstance(node, SelectableEntity):
431+
# get the table name from the full qualified name
432+
name = f"{name} :: ({node.entity.name.split('.')[-1]})"
433+
434+
def get_sql_text(node: LogicalPlan) -> str: # pragma: no cover
435+
if isinstance(node, Selectable):
436+
return node.sql_query
437+
if isinstance(node, SnowflakePlan):
438+
return node.queries[-1].sql
439+
return ""
414440

415441
score = get_complexity_score(node)
416-
num_ref_ctes = "nil"
417-
if isinstance(node, (SnowflakePlan, Selectable)):
418-
num_ref_ctes = len(node.referenced_ctes)
419-
sql_text = ""
420-
if isinstance(node, Selectable):
421-
sql_text = node.sql_query
422-
elif isinstance(node, SnowflakePlan):
423-
sql_text = node.queries[-1].sql
442+
sql_text = get_sql_text(node)
424443
sql_size = len(sql_text)
444+
ref_ctes = None
445+
if isinstance(node, (SnowflakePlan, Selectable)):
446+
ref_ctes = list(
447+
map(
448+
lambda node, cnt: f"{node.name[18:]}:{cnt}",
449+
node.referenced_ctes.keys(),
450+
node.referenced_ctes.values(),
451+
)
452+
)
453+
for with_query_block in node.referenced_ctes: # pragma: no cover
454+
sql_size += len(get_sql_text(with_query_block.children[0]))
425455
sql_preview = sql_text[:50]
426456

427-
return f"{name=}\n{score=}, {num_ref_ctes=}, {sql_size=}\n{sql_preview=}"
457+
return f"{name=}\n{score=}, {ref_ctes=}, {sql_size=}\n{sql_preview=}"
458+
459+
def is_with_query_block(node: Optional[LogicalPlan]) -> bool: # pragma: no cover
460+
if isinstance(node, WithQueryBlock):
461+
return True
462+
if isinstance(node, SnowflakePlan):
463+
return is_with_query_block(node.source_plan)
464+
if isinstance(node, SelectSnowflakePlan):
465+
return is_with_query_block(node.snowflake_plan)
466+
467+
return False
428468

429469
g = graphviz.Graph(format="png")
430470

@@ -435,11 +475,18 @@ def get_name(node: Optional[LogicalPlan]) -> str:
435475
for node in curr_level:
436476
node_id = hex(id(node))
437477
color = "lightblue" if node._is_valid_for_replacement else "red"
438-
g.node(node_id, get_stat(node), color=color)
478+
fillcolor = "lightgray" if is_with_query_block(node) else "white"
479+
g.node(
480+
node_id,
481+
get_stat(node),
482+
color=color,
483+
style="filled",
484+
fillcolor=fillcolor,
485+
)
439486
if isinstance(node, (Selectable, SnowflakePlan)):
440487
children = node.children_plan_nodes
441488
else:
442-
children = node.children
489+
children = node.children # pragma: no cover
443490
for child in children:
444491
child_id = hex(id(child))
445492
edges.add((node_id, child_id))

tests/integ/test_large_query_breakdown.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -788,16 +788,24 @@ def test_large_query_breakdown_enabled_parameter(session, caplog):
788788

789789
@pytest.mark.skipif(IS_IN_STORED_PROC, reason="requires graphviz")
790790
@pytest.mark.parametrize("enabled", [False, True])
791-
def test_plotter(session, large_query_df, enabled):
791+
@pytest.mark.parametrize("plotting_score_threshold", [0, 10_000_000])
792+
def test_plotter(large_query_df, enabled, plotting_score_threshold):
792793
original_plotter_enabled = os.environ.get("ENABLE_SNOWPARK_LOGICAL_PLAN_PLOTTING")
794+
original_score_threshold = os.environ.get(
795+
"SNOWPARK_LOGICAL_PLAN_PLOTTING_COMPLEXITY_THRESHOLD"
796+
)
793797
try:
794798
os.environ["ENABLE_SNOWPARK_LOGICAL_PLAN_PLOTTING"] = str(enabled)
799+
os.environ["SNOWPARK_LOGICAL_PLAN_PLOTTING_COMPLEXITY_THRESHOLD"] = str(
800+
plotting_score_threshold
801+
)
795802
tmp_dir = tempfile.gettempdir()
796803

797804
with patch("graphviz.Graph.render") as mock_render:
798805
large_query_df.collect()
799-
assert mock_render.called == enabled
800-
if not enabled:
806+
should_plot = enabled and (plotting_score_threshold == 0)
807+
assert mock_render.called == should_plot
808+
if not should_plot:
801809
return
802810

803811
assert mock_render.call_count == 5
@@ -819,3 +827,9 @@ def test_plotter(session, large_query_df, enabled):
819827
] = original_plotter_enabled
820828
else:
821829
del os.environ["ENABLE_SNOWPARK_LOGICAL_PLAN_PLOTTING"]
830+
if original_score_threshold is not None:
831+
os.environ[
832+
"SNOWPARK_LOGICAL_PLAN_PLOTTING_COMPLEXITY_THRESHOLD"
833+
] = original_score_threshold
834+
else:
835+
del os.environ["SNOWPARK_LOGICAL_PLAN_PLOTTING_COMPLEXITY_THRESHOLD"]

0 commit comments

Comments
 (0)