1515 SelectSnowflakePlan ,
1616 SelectStatement ,
1717 SelectTableFunction ,
18+ SelectableEntity ,
1819 SetStatement ,
1920)
2021from snowflake .snowpark ._internal .analyzer .snowflake_plan import (
2829 LogicalPlan ,
2930 SnowflakeCreateTable ,
3031 TableCreationSource ,
32+ WithQueryBlock ,
3133)
3234from snowflake .snowpark ._internal .analyzer .table_merge_expression import (
3335 TableDelete ,
@@ -381,15 +383,29 @@ def plot_plan_if_enabled(root: LogicalPlan, filename: str) -> None:
381383 ):
382384 return
383385
386+ if int (
387+ os .environ .get ("SNOWPARK_LOGICAL_PLAN_PLOTTING_COMPLEXITY_THRESHOLD" , 0 )
388+ ) > get_complexity_score (root ):
389+ return
390+
384391 import graphviz # pyright: ignore[reportMissingImports]
385392
386393 def get_stat (node : LogicalPlan ):
387- def get_name (node : Optional [LogicalPlan ]) -> str :
394+ def get_name (node : Optional [LogicalPlan ]) -> str : # pragma: no cover
388395 if node is None :
389396 return "EMPTY_SOURCE_PLAN" # pragma: no cover
390397 addr = hex (id (node ))
391398 name = str (type (node )).split ("." )[- 1 ].split ("'" )[0 ]
392- return f"{ name } ({ addr } )"
399+ suffix = ""
400+ if isinstance (node , SnowflakeCreateTable ):
401+ # get the table name from the full qualified name
402+ table_name = node .table_name [- 1 ].split ("." )[- 1 ] # pyright: ignore
403+ suffix = f" :: { table_name } "
404+ if isinstance (node , WithQueryBlock ):
405+ # get the CTE identifier excluding SNOWPARK_TEMP_CTE_
406+ suffix = f" :: { node .name [18 :]} "
407+
408+ return f"{ name } ({ addr } ){ suffix } "
393409
394410 name = get_name (node )
395411 if isinstance (node , SnowflakePlan ):
@@ -411,20 +427,44 @@ def get_name(node: Optional[LogicalPlan]) -> str:
411427 if node .offset :
412428 properties .append ("Offset" ) # pragma: no cover
413429 name = f"{ name } :: ({ '| ' .join (properties )} )"
430+ elif isinstance (node , SelectableEntity ):
431+ # get the table name from the full qualified name
432+ name = f"{ name } :: ({ node .entity .name .split ('.' )[- 1 ]} )"
433+
434+ def get_sql_text (node : LogicalPlan ) -> str : # pragma: no cover
435+ if isinstance (node , Selectable ):
436+ return node .sql_query
437+ if isinstance (node , SnowflakePlan ):
438+ return node .queries [- 1 ].sql
439+ return ""
414440
415441 score = get_complexity_score (node )
416- num_ref_ctes = "nil"
417- if isinstance (node , (SnowflakePlan , Selectable )):
418- num_ref_ctes = len (node .referenced_ctes )
419- sql_text = ""
420- if isinstance (node , Selectable ):
421- sql_text = node .sql_query
422- elif isinstance (node , SnowflakePlan ):
423- sql_text = node .queries [- 1 ].sql
442+ sql_text = get_sql_text (node )
424443 sql_size = len (sql_text )
444+ ref_ctes = None
445+ if isinstance (node , (SnowflakePlan , Selectable )):
446+ ref_ctes = list (
447+ map (
448+ lambda node , cnt : f"{ node .name [18 :]} :{ cnt } " ,
449+ node .referenced_ctes .keys (),
450+ node .referenced_ctes .values (),
451+ )
452+ )
453+ for with_query_block in node .referenced_ctes : # pragma: no cover
454+ sql_size += len (get_sql_text (with_query_block .children [0 ]))
425455 sql_preview = sql_text [:50 ]
426456
427- return f"{ name = } \n { score = } , { num_ref_ctes = } , { sql_size = } \n { sql_preview = } "
457+ return f"{ name = } \n { score = } , { ref_ctes = } , { sql_size = } \n { sql_preview = } "
458+
459+ def is_with_query_block (node : Optional [LogicalPlan ]) -> bool : # pragma: no cover
460+ if isinstance (node , WithQueryBlock ):
461+ return True
462+ if isinstance (node , SnowflakePlan ):
463+ return is_with_query_block (node .source_plan )
464+ if isinstance (node , SelectSnowflakePlan ):
465+ return is_with_query_block (node .snowflake_plan )
466+
467+ return False
428468
429469 g = graphviz .Graph (format = "png" )
430470
@@ -435,11 +475,18 @@ def get_name(node: Optional[LogicalPlan]) -> str:
435475 for node in curr_level :
436476 node_id = hex (id (node ))
437477 color = "lightblue" if node ._is_valid_for_replacement else "red"
438- g .node (node_id , get_stat (node ), color = color )
478+ fillcolor = "lightgray" if is_with_query_block (node ) else "white"
479+ g .node (
480+ node_id ,
481+ get_stat (node ),
482+ color = color ,
483+ style = "filled" ,
484+ fillcolor = fillcolor ,
485+ )
439486 if isinstance (node , (Selectable , SnowflakePlan )):
440487 children = node .children_plan_nodes
441488 else :
442- children = node .children
489+ children = node .children # pragma: no cover
443490 for child in children :
444491 child_id = hex (id (child ))
445492 edges .add ((node_id , child_id ))
0 commit comments