2929)
3030
3131import snowflake .snowpark
32- from snowflake .snowpark ._internal .debug_utils import DataFrameTraceNode
32+ from snowflake .snowpark ._internal .debug_utils import DataFrameLineageNode
3333import snowflake .snowpark ._internal .proto .generated .ast_pb2 as proto
3434from snowflake .connector .options import installed_pandas , pandas , pyarrow
3535
@@ -351,37 +351,52 @@ def _disambiguate(
351351
352352
353353def _get_df_lineage (dataframes_involved : List ["DataFrame" ]) -> List [str ]:
354- curr : List [DataFrameTraceNode ] = []
354+ """Helper function to get the lineage of dataframes involved in the exception.
355+ It gathers the lineage in the following way:
356+
357+ 1. For each dataframe, it checks if it has an AST ID and if so, it creates a
358+ DataFrameLineageNode for it.
359+ 2. We use BFS to traverse the lineage using dataframes from 1. as the first layer.
360+ 3. During each iteration, we check if the node's source_id has been visited. If not,
361+ we add it to the visited set and append its source format to the trace. This step
362+ is needed to avoid source_id added multiple times in lineage due to loops.
363+ 4. We then explore the next layer by adding the children of the current node to the
364+ next layer. We check if the child ID has been visited and if not, we add it to the
365+ visited set and append the DataFrameLineageNode for it to the next layer.
366+ 5. We repeat this process until there are no more nodes to explore.
367+ """
368+ curr : List [DataFrameLineageNode ] = []
355369 visited_batch_id = set ()
356- visited_format_id = set ()
370+ visited_source_id = set ()
357371
358372 for df in dataframes_involved :
359373 if (batch_id := df ._ast_id ) is not None :
360374 stmt_cache = df ._session ._ast_batch ._bind_stmt_cache
361- curr .append (DataFrameTraceNode (batch_id , stmt_cache ))
375+ curr .append (DataFrameLineageNode (batch_id , stmt_cache ))
362376 if batch_id not in visited_batch_id :
363377 visited_batch_id .add (batch_id )
364378
365- trace = []
379+ lineage = []
366380
367381 while curr :
368- next = []
382+ next : List [ DataFrameLineageNode ] = []
369383 for node in curr :
384+ # tracing updates
385+ source_id = node .get_source_id ()
386+ if source_id not in visited_source_id :
387+ visited_source_id .add (source_id )
388+ lineage .append (node .get_source_snippet ())
389+
390+ # explore next layer
370391 for child_id in node .children :
371- if child_id in visited_format_id :
392+ if child_id in visited_batch_id :
372393 continue
373- visited_format_id .add (child_id )
374- next .append (DataFrameTraceNode (child_id , node .stmt_cache ))
375-
376- # tracing updates
377- format_id = node .get_format_id ()
378- if format_id not in visited_format_id :
379- visited_format_id .add (format_id )
380- trace .append (node .get_format_src ())
394+ visited_batch_id .add (child_id )
395+ next .append (DataFrameLineageNode (child_id , node .stmt_cache ))
381396
382397 curr = next
383398
384- return trace
399+ return lineage
385400
386401
387402def dataframe_exception_handler (func ):
@@ -402,7 +417,7 @@ def wrapper(*args, **kwargs):
402417 traceback_lines = traceback .format_exception (error_type , error_value , tb )
403418 formatted_traceback = "" .join (traceback_lines )
404419
405- # compute the trace
420+ # get the dataframe lineage
406421 dataframes_involved = []
407422 for arg in args :
408423 if isinstance (arg , DataFrame ):
@@ -428,7 +443,7 @@ def wrapper(*args, **kwargs):
428443 if lineage_trace_len > show_lineage_len :
429444 traceback_with_debug_info .append (
430445 f"... and { lineage_trace_len - show_lineage_len } more.\n You can increase "
431- "the trace length by setting SNOWPARK_PYTHON_DATAFRAME_LINEAGE_LENGTH_ON_ERROR "
446+ "the lineage length by setting SNOWPARK_PYTHON_DATAFRAME_LINEAGE_LENGTH_ON_ERROR "
432447 "environment variable."
433448 )
434449
0 commit comments