Skip to content

Commit 041e624

Browse files
Implementing dataframe profiler (#3504)
1 parent 88c85ec commit 041e624

18 files changed

+1263
-31
lines changed

CHANGELOG.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
- Added support for the following functions in `functions.py`:
1818
- `ai_embed`
1919
- `try_parse_json`
20+
- Added a dataframe profiler. To use, you can call get_execution_profile() on your desired dataframe. This profiler reports the queries executed to evaluate a dataframe, and statistics about each of the query operators.
2021

2122
#### Bug Fixes
2223

@@ -64,8 +65,7 @@
6465
- Added debuggability improvements to eagerly validate dataframe schema metadata. Enable it using `snowflake.snowpark.context.configure_development_features()`.
6566
- Added a new function `snowflake.snowpark.dataframe.map_in_pandas` that allows users map a function across a dataframe. The mapping function takes an iterator of pandas dataframes as input and provides one as output.
6667
- Added a TTL cache to describe queries. Repeated queries in a 15-second interval will use the cached value rather than re-querying Snowflake.
67-
- Added a parameter `fetch_with_process` to `DataFrameReader.dbapi` (PrPr) to enable multiprocessing for parallel data fetching in
68-
local ingestion. By default, local ingestion uses multithreading. Multiprocessing may improve performance for CPU-bound tasks like Parquet file generation.
68+
- Added a parameter `fetch_with_process` to `DataFrameReader.dbapi` (PrPr) to enable multiprocessing for parallel data fetching in local ingestion. By default, local ingestion uses multithreading. Multiprocessing may improve performance for CPU-bound tasks like Parquet file generation.
6969
- Added a new function `snowflake.snowpark.functions.model` that allows users to call methods of a model.
7070

7171
#### Improvements

src/snowflake/snowpark/_internal/debug_utils.py

Lines changed: 229 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
from functools import cached_property
66
import os
77
import sys
8-
from typing import Dict, List, Optional
8+
from typing import Dict, List, Optional, Set, Tuple
99
import itertools
1010
import re
1111
from typing import TYPE_CHECKING
12+
import snowflake.snowpark
1213
from snowflake.snowpark._internal.ast.batch import get_dependent_bind_ids
1314
from snowflake.snowpark._internal.ast.utils import __STRING_INTERNING_MAP__
1415
import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto
16+
from ast import literal_eval
1517
from snowflake.snowpark._internal.ast.utils import extract_src_from_expr
1618

1719
if TYPE_CHECKING:
@@ -220,6 +222,32 @@ def _format_source_location(src: Optional[proto.SrcPosition]) -> str:
220222
return lines_info
221223

222224

225+
def _extract_source_locations_from_plan(plan: "SnowflakePlan") -> List[str]:
    """
    Collect the unique Python source locations recorded on a SnowflakePlan.

    Walks the plan's dataframe AST ids, resolves each to its bind statement in
    the session's AST batch cache, and formats the statement's source position.

    Args:
        plan: The SnowflakePlan object to extract source locations from.

    Returns:
        List of unique source location strings (e.g., "file.py: line 42"),
        in first-seen order.
    """
    ordered_locations: List[str] = []
    seen: Set[str] = set()

    if plan.df_ast_ids is None:
        return ordered_locations

    for ast_id in plan.df_ast_ids:
        bind_stmt = plan.session._ast_batch._bind_stmt_cache.get(ast_id)
        if bind_stmt is None:
            continue
        location = _format_source_location(extract_src_from_expr(bind_stmt.bind.expr))
        # Deduplicate while preserving the order locations were first seen.
        if location and location not in seen:
            seen.add(location)
            ordered_locations.append(location)

    return ordered_locations
250+
223251
def get_python_source_from_sql_error(top_plan: "SnowflakePlan", error_msg: str) -> str:
224252
"""
225253
Extract SQL error line number and map it back to Python source code. We use the
@@ -249,17 +277,8 @@ def get_python_source_from_sql_error(top_plan: "SnowflakePlan", error_msg: str)
249277
)
250278

251279
plan = get_plan_from_line_numbers(top_plan, sql_line_number)
252-
source_locations = []
253-
found_locations = set()
254-
if plan.df_ast_ids is not None:
255-
for ast_id in plan.df_ast_ids:
256-
bind_stmt = plan.session._ast_batch._bind_stmt_cache.get(ast_id)
257-
if bind_stmt is not None:
258-
src = extract_src_from_expr(bind_stmt.bind.expr)
259-
location = _format_source_location(src)
260-
if location != "" and location not in found_locations:
261-
found_locations.add(location)
262-
source_locations.append(location)
280+
source_locations = _extract_source_locations_from_plan(plan)
281+
263282
if source_locations:
264283
if len(source_locations) == 1:
265284
return f"\nSQL compilation error corresponds to Python source at {source_locations[0]}.\n"
@@ -434,3 +453,201 @@ def sql_contains_object_creation(sql_query: str, target_object: str) -> bool:
434453
return f"\nObject '{object_name}' was first referenced at {location}.\n"
435454

436455
return ""
456+
457+
458+
class QueryProfiler:
    """
    A class for profiling Snowflake queries and analyzing operator statistics.

    It can generate tree visualizations and output tables of operator
    statistics, writing either to stdout or to a user-supplied output file.
    """

    def __init__(
        self, session: "snowflake.snowpark.Session", output_file: Optional[str] = None
    ) -> None:
        """
        Args:
            session: Snowpark session whose connection is used to run the
                operator-stats query.
            output_file: Optional path of a file to write profiling output to.
                Opened in append mode so repeated runs accumulate in one file.
                When omitted, output goes to stdout. Callers should invoke
                ``close()`` when done with a file-backed profiler.
        """
        self.session = session
        if output_file:
            self.file_handle = open(output_file, "a", encoding="utf-8")
        else:
            self.file_handle = None

    def _get_node_info(self, row: Dict) -> Dict:
        """Normalize one raw operator-stats row into a plain dict.

        Missing/NULL numeric fields default to 0 and missing text fields to
        "N/A" so downstream formatting never has to handle ``None``.
        """
        parent_operators = row.get("PARENT_OPERATORS")
        parent_operators = (
            str(parent_operators) if parent_operators is not None else None
        )
        node_info = {
            "id": row.get("OPERATOR_ID") or 0,
            "parent_operators": parent_operators,
            "type": row.get("OPERATOR_TYPE") or "N/A",
            "input_rows": row.get("INPUT_ROWS") or 0,
            "output_rows": row.get("OUTPUT_ROWS") or 0,
            "row_multiple": row.get("ROW_MULTIPLE") or 0,
            "exec_time": row.get("OVERALL_PERCENTAGE") or 0,
            "attributes": row.get("OPERATOR_ATTRIBUTES") or "N/A",
        }
        return node_info

    def build_operator_tree(self, operators_data: List[Dict]) -> Tuple[Dict, Dict, Set]:
        """
        Build a tree structure from raw operator data for query profiling.

        Args:
            operators_data (List[Dict]): A list of dictionaries containing operator statistics.
                The keys include operator id, operator type, parent operators, input rows,
                output rows, row multiple, overall percentage, and operator attributes.

        Returns:
            Tuple[Dict, Dict, Set]: A tuple containing:
                - nodes (Dict[int, Dict]): Dictionary mapping operator IDs to node information
                - children (Dict[int, List[int]]): Dictionary mapping operator IDs to lists of child operator IDs
                - root_nodes (Set[int]): Set of operator IDs that are root nodes (have no parents)
        """
        nodes = {}
        children = {}
        root_nodes = set()
        for row in operators_data:
            node_info = self._get_node_info(row)

            nodes[node_info["id"]] = node_info
            children[node_info["id"]] = []

            if node_info["parent_operators"] is None:
                root_nodes.add(node_info["id"])
            else:
                # parent_operators arrives as a string like "[1, 2, 3]";
                # literal_eval safely parses it into a list of ints.
                parent_ids = literal_eval(node_info["parent_operators"])
                for parent_id in parent_ids:
                    # A parent may be referenced before (or without) its own
                    # row appearing; create its child list on demand.
                    if parent_id not in children:
                        children[parent_id] = []
                    children[parent_id].append(node_info["id"])

        return nodes, children, root_nodes

    def _write_output(self, message: str) -> None:
        """Helper function to write output to either console or file."""
        if self.file_handle:
            self.file_handle.write(message + "\n")
        else:
            sys.stdout.write(message + "\n")

    def close(self) -> None:
        """Close the output file handle if one was opened (idempotent)."""
        if self.file_handle:
            self.file_handle.close()
            # Drop the handle so a second close() is a no-op and any later
            # writes fall back to stdout instead of hitting a closed file.
            self.file_handle = None

    def print_operator_tree(
        self,
        nodes: Dict[int, Dict],
        children: Dict[int, List[int]],
        node_id: int,
        prefix: str = "",
        is_last: bool = True,
    ) -> None:
        """
        Print a visual tree representation of query operators with their statistics.

        Args:
            nodes (Dict[int, Dict]): Dictionary mapping operator IDs to node information.
            children (Dict[int, List[int]]): Dictionary mapping operator IDs to lists of child operator IDs.
            node_id (int): The ID of the current operator node to print.
            prefix (str, optional): String prefix for tree formatting (used for indentation).
                Defaults to "".
            is_last (bool, optional): Whether this node is the last child of its parent.
                Used for proper tree connector formatting. Defaults to True.

        Returns:
            None: This function writes output to a file or prints and doesn't return a value.
        """
        node = nodes[node_id]

        connector = "└── " if is_last else "├── "

        node_info = (
            f"[{node['id']}] {node['type']} "
            f"(In: {node['input_rows']:,}, Out: {node['output_rows']:,}, "
            f"Mult: {node['row_multiple']:.2f}, Time: {node['exec_time']:.2f}%)"
        )

        self._write_output(f"{prefix}{connector}{node_info}")

        # Continue the vertical rail only while more siblings follow.
        extension = "    " if is_last else "│   "
        new_prefix = prefix + extension

        child_list = children.get(node_id, [])
        for i, child_id in enumerate(child_list):
            is_last_child = i == len(child_list) - 1
            self.print_operator_tree(
                nodes, children, child_id, new_prefix, is_last_child
            )

    def _fetch_operator_stats(self, query_id: str) -> List[Dict]:
        """Run GET_QUERY_OPERATOR_STATS for ``query_id`` and return rows as dicts.

        Returns:
            One dict per operator, keyed by upper-case column name.
        """
        # NOTE(review): query_id is interpolated directly into the SQL text.
        # Query IDs normally come from Snowflake itself, but callers must not
        # pass untrusted strings here; consider a server-side bind parameter.
        stats_query = f"""
        SELECT
            operator_id,
            operator_type,
            operator_attributes,
            operator_statistics:input_rows::number as input_rows,
            operator_statistics:output_rows::number as output_rows,
            CASE
                WHEN operator_statistics:input_rows::number > 0
                THEN operator_statistics:output_rows::number / operator_statistics:input_rows::number
                ELSE NULL
            END as row_multiple,
            execution_time_breakdown:overall_percentage::number as overall_percentage
        FROM TABLE(get_query_operator_stats('{query_id}'))
        ORDER BY step_id, operator_id
        """
        stats_connection = self.session._conn._conn.cursor()
        try:
            stats_connection.execute(stats_query)
            raw_results = stats_connection.fetchall()
            column_names = [desc[0] for desc in stats_connection.description]
        finally:
            # Always release the cursor, even when the query raises.
            stats_connection.close()
        return [dict(zip(column_names, row)) for row in raw_results]

    def profile_query(
        self,
        query_id: str,
    ) -> None:
        """
        Profile a query and save the results to a file.

        Args:
            query_id: The query ID to profile

        Returns:
            None - output either to the console or to the file specified by output_file
        """
        stats_result = self._fetch_operator_stats(query_id)

        nodes, children, root_nodes = self.build_operator_tree(stats_result)

        self._write_output(f"\n=== Analyzing Query {query_id} ===")
        self._write_output(f"\n{'='*80}")
        self._write_output("QUERY OPERATOR TREE")
        self._write_output(f"{'='*80}")

        root_list = sorted(list(root_nodes))
        for i, root_id in enumerate(root_list):
            is_last_root = i == len(root_list) - 1
            self.print_operator_tree(nodes, children, root_id, "", is_last_root)

        self._write_output(f"\n{'='*160}")
        self._write_output("DETAILED OPERATOR STATISTICS")
        self._write_output(f"{'='*160}")
        self._write_output(
            f"{'Operator':<15} {'Type':<15} {'Input Rows':<12} {'Output Rows':<12} {'Row Multiple':<12} {'Overall %':<12} {'Attributes':<50}",
        )
        self._write_output(f"{'='*160}")

        for row in stats_result:
            node_info = self._get_node_info(row)
            # Collapse newlines and runs of whitespace so attributes fit on one
            # table row. (The previous chained `.replace(" ", " ")` was a no-op.)
            operator_attrs = " ".join(str(node_info["attributes"]).split())

            self._write_output(
                f"{node_info['id']:<15} {node_info['type']:<15} {node_info['input_rows']:<12} {node_info['output_rows']:<12} {node_info['row_multiple']:<12.2f} {node_info['exec_time']:<12} {operator_attrs:<50}",
            )

        self._write_output(f"{'='*160}")

src/snowflake/snowpark/_internal/server_connection.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,10 @@ def execute_and_notify_query_listener(
434434
notify_kwargs = {}
435435
if DATAFRAME_AST_PARAMETER in kwargs and is_ast_enabled():
436436
notify_kwargs["dataframeAst"] = kwargs[DATAFRAME_AST_PARAMETER]
437-
437+
if "_statement_params" in kwargs and kwargs["_statement_params"]:
438+
statement_params = kwargs["_statement_params"]
439+
if "_PLAN_UUID" in statement_params:
440+
notify_kwargs["dataframe_uuid"] = statement_params["_PLAN_UUID"]
438441
try:
439442
results_cursor = self._cursor.execute(query, **kwargs)
440443
except Exception as ex:
@@ -456,14 +459,23 @@ def execute_and_notify_query_listener(
456459
def execute_async_and_notify_query_listener(
    self, query: str, **kwargs: Any
) -> Dict[str, Any]:
    """Submit *query* asynchronously and notify registered query listeners.

    When statement params carry a ``_PLAN_UUID``, it is forwarded to listeners
    as ``dataframe_uuid`` on both the success and error notifications.
    """
    notify_kwargs = {}
    statement_params = kwargs.get("_statement_params")
    if statement_params and "_PLAN_UUID" in statement_params:
        notify_kwargs["dataframe_uuid"] = statement_params["_PLAN_UUID"]

    try:
        results_cursor = self._cursor.execute_async(query, **kwargs)
    except Error as err:
        # Listeners still hear about failed submissions.
        self.notify_query_listeners(
            QueryRecord(err.sfqid, err.query), is_error=True, **notify_kwargs
        )
        raise err

    self.notify_query_listeners(
        QueryRecord(results_cursor["queryId"], query), **notify_kwargs
    )
    return results_cursor
468480

469481
def execute_and_get_sfqid(

src/snowflake/snowpark/_internal/utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2029,16 +2029,19 @@ def find_interval_containing_line(intervals, line_number):
20292029
return -1
20302030

20312031
# traverse the plan tree to find the plan that contains the line number
2032-
stack = [(plan_node, line_number)]
2032+
stack = [(plan_node, line_number, None)]
20332033
while stack:
2034-
node, line_number = stack.pop()
2034+
node, line_number, df_ast_ids = stack.pop()
20352035
if isinstance(node, Selectable):
20362036
node = node.get_snowflake_plan(skip_schema_query=False)
2037+
if node.df_ast_ids is not None:
2038+
df_ast_ids = node.df_ast_ids
20372039
query_line_intervals = node.queries[-1].query_line_intervals
20382040
idx = find_interval_containing_line(query_line_intervals, line_number)
20392041
if idx >= 0:
20402042
uuid = query_line_intervals[idx].uuid
20412043
if node.uuid == uuid:
2044+
node.df_ast_ids = df_ast_ids
20422045
return node
20432046
else:
20442047
for child in node.children_plan_nodes:
@@ -2047,6 +2050,7 @@ def find_interval_containing_line(intervals, line_number):
20472050
(
20482051
child,
20492052
line_number - query_line_intervals[idx].start,
2053+
df_ast_ids,
20502054
)
20512055
)
20522056
break

0 commit comments

Comments
 (0)