Skip to content

Commit c3f0374

Browse files
committed
refactor + add more comments
1 parent ee12e72 commit c3f0374

File tree

2 files changed

+41
-25
lines changed

2 files changed

+41
-25
lines changed

src/snowflake/snowpark/_internal/debug_utils.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,20 @@
1111
UNKNOWN_FILE = "__UNKNOWN_FILE__"
1212

1313

14-
class DataFrameTraceNode:
15-
"""A node in the trace of a tree that represents the lineage of a DataFrame."""
14+
class DataFrameLineageNode:
15+
"""A node representing a dataframe operation in the DAG that represents the lineage of a DataFrame."""
1616

1717
def __init__(self, batch_id: int, stmt_cache) -> None:
1818
self.batch_id = batch_id
1919
self.stmt_cache = stmt_cache
2020

2121
@cached_property
2222
def children(self) -> set[int]:
23+
"""Returns the batch_ids of the children of this node."""
2324
return get_dependent_bind_ids(self.stmt_cache[self.batch_id])
2425

2526
def get_src(self):
26-
"""The source Stmt of the DataFrame descried by the batch_id."""
27+
"""The source Stmt of the DataFrame described by the batch_id."""
2728
stmt = self.stmt_cache[self.batch_id]
2829
api_call = stmt.bind.expr.WhichOneof("variant")
2930
return (
@@ -49,7 +50,7 @@ def _read_file(
4950
code_lines = [line.rstrip() for line in code_lines]
5051
return "\n".join(code_lines)
5152

52-
def get_format_id(self) -> str:
53+
def get_source_id(self) -> str:
5354
"""Unique identifier of the location of the DataFrame in the source code."""
5455
src = self.get_src()
5556
if src is None:
@@ -62,8 +63,8 @@ def get_format_id(self) -> str:
6263
end_column = src.end_column
6364
return f"{fileno}:{start_line}:{start_column}-{end_line}:{end_column}"
6465

65-
def get_format_src(self) -> str:
66-
"""The snippet of the source code where the DataFrame was created."""
66+
def get_source_snippet(self) -> str:
67+
"""Read the source file and extract the snippet where the dataframe is created."""
6768
src = self.get_src()
6869
if src is None:
6970
return "No source"
@@ -83,7 +84,7 @@ def get_format_src(self) -> str:
8384
code_identifier = f"{filename}:{start_line}"
8485
else:
8586
code_identifier = (
86-
f"{filename}:{start_line}:{start_column}-{end_line}:{end_column}"
87+
f"{filename}|{start_line}:{start_column}-{end_line}:{end_column}"
8788
)
8889

8990
if filename != UNKNOWN_FILE and os.access(filename, os.R_OK):

src/snowflake/snowpark/dataframe.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
)
3030

3131
import snowflake.snowpark
32-
from snowflake.snowpark._internal.debug_utils import DataFrameTraceNode
32+
from snowflake.snowpark._internal.debug_utils import DataFrameLineageNode
3333
import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto
3434
from snowflake.connector.options import installed_pandas, pandas, pyarrow
3535

@@ -351,37 +351,52 @@ def _disambiguate(
351351

352352

353353
def _get_df_lineage(dataframes_involved: List["DataFrame"]) -> List[str]:
354-
curr: List[DataFrameTraceNode] = []
354+
"""Helper function to get the lineage of dataframes involved in the exception.
355+
It gathers the lineage in the following way:
356+
357+
1. For each dataframe, it checks if it has an AST ID and if so, it creates a
358+
DataFrameLineageNode for it.
359+
2. We use BFS to traverse the lineage, using the nodes created in step 1 as the first layer.
360+
3. During each iteration, we check if the node's source_id has been visited. If not,
361+
we add it to the visited set and append its source snippet to the lineage. This step
362+
is needed to avoid the same source_id being added to the lineage multiple times due to loops.
363+
4. We then explore the next layer by adding the children of the current node to the
364+
next layer. We check if the child ID has been visited and if not, we add it to the
365+
visited set and append the DataFrameLineageNode for it to the next layer.
366+
5. We repeat this process until there are no more nodes to explore.
367+
"""
368+
curr: List[DataFrameLineageNode] = []
355369
visited_batch_id = set()
356-
visited_format_id = set()
370+
visited_source_id = set()
357371

358372
for df in dataframes_involved:
359373
if (batch_id := df._ast_id) is not None:
360374
stmt_cache = df._session._ast_batch._bind_stmt_cache
361-
curr.append(DataFrameTraceNode(batch_id, stmt_cache))
375+
curr.append(DataFrameLineageNode(batch_id, stmt_cache))
362376
if batch_id not in visited_batch_id:
363377
visited_batch_id.add(batch_id)
364378

365-
trace = []
379+
lineage = []
366380

367381
while curr:
368-
next = []
382+
next: List[DataFrameLineageNode] = []
369383
for node in curr:
384+
# lineage updates: record each source location only once
385+
source_id = node.get_source_id()
386+
if source_id not in visited_source_id:
387+
visited_source_id.add(source_id)
388+
lineage.append(node.get_source_snippet())
389+
390+
# explore next layer
370391
for child_id in node.children:
371-
if child_id in visited_format_id:
392+
if child_id in visited_batch_id:
372393
continue
373-
visited_format_id.add(child_id)
374-
next.append(DataFrameTraceNode(child_id, node.stmt_cache))
375-
376-
# tracing updates
377-
format_id = node.get_format_id()
378-
if format_id not in visited_format_id:
379-
visited_format_id.add(format_id)
380-
trace.append(node.get_format_src())
394+
visited_batch_id.add(child_id)
395+
next.append(DataFrameLineageNode(child_id, node.stmt_cache))
381396

382397
curr = next
383398

384-
return trace
399+
return lineage
385400

386401

387402
def dataframe_exception_handler(func):
@@ -402,7 +417,7 @@ def wrapper(*args, **kwargs):
402417
traceback_lines = traceback.format_exception(error_type, error_value, tb)
403418
formatted_traceback = "".join(traceback_lines)
404419

405-
# compute the trace
420+
# get the dataframe lineage
406421
dataframes_involved = []
407422
for arg in args:
408423
if isinstance(arg, DataFrame):
@@ -428,7 +443,7 @@ def wrapper(*args, **kwargs):
428443
if lineage_trace_len > show_lineage_len:
429444
traceback_with_debug_info.append(
430445
f"... and {lineage_trace_len - show_lineage_len} more.\nYou can increase "
431-
"the trace length by setting SNOWPARK_PYTHON_DATAFRAME_LINEAGE_LENGTH_ON_ERROR "
446+
"the lineage length by setting SNOWPARK_PYTHON_DATAFRAME_LINEAGE_LENGTH_ON_ERROR "
432447
"environment variable."
433448
)
434449

0 commit comments

Comments
 (0)