Skip to content

Commit b0ebd5e

Browse files
SNOW-2084165 Add dataframe operation lineage on SnowparkSQLException (#3339)
Co-authored-by: graphite-app[bot] <96075541+graphite-app[bot]@users.noreply.github.com>
1 parent 9f019d0 commit b0ebd5e

File tree

12 files changed

+760
-9
lines changed

12 files changed

+760
-9
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
- Added support for ignoring surrounding whitespace in the XML element using `ignoreSurroundingWhitespace` option.
3737
- Added support for parameter `return_dataframe` in `Session.call`, which can be used to set the return type of the functions to a `DataFrame` object.
3838
- Added a new argument to `DataFrame.describe` called `strings_include_math_stats` that triggers `stddev` and `mean` to be calculated for String columns.
39+
- Added debuggability improvements to show a trace of most recent dataframe transformations if an operation leads to a `SnowparkSQLException`. Enable it using `snowflake.snowpark.context.configure_development_features()`. This feature also depends on AST collection to be enabled in the session which can be done using `session.ast_enabled = True`.
3940
- Improved the error message for `Session.write_pandas()` and `Session.create_dataframe()` when the input pandas DataFrame does not have a column.
4041
- Added support for retrieving `Edge.properties` when retrieving lineage from `DGQL` in `DataFrame.lineage.trace`.
4142
- Added a parameter `table_exists` to `DataFrameWriter.save_as_table` that allows specifying if a table already exists. This allows skipping a table lookup that can be expensive.

src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#
55
import copy
66
import difflib
7+
from logging import getLogger
78
import re
89
import sys
910
import uuid
@@ -32,6 +33,9 @@
3233
TableFunctionRelation,
3334
TableFunctionJoin,
3435
)
36+
from snowflake.snowpark._internal.debug_utils import (
37+
get_df_transform_trace_message,
38+
)
3539

3640
if TYPE_CHECKING:
3741
from snowflake.snowpark._internal.analyzer.select_statement import (
@@ -134,6 +138,8 @@
134138
else:
135139
from collections.abc import Iterable
136140

141+
_logger = getLogger(__name__)
142+
137143

138144
class SnowflakePlan(LogicalPlan):
139145
class Decorator:
@@ -147,7 +153,15 @@ class Decorator:
147153

148154
@staticmethod
149155
def wrap_exception(func):
156+
"""This wrapper is used to wrap snowflake connector ProgrammingError into SnowparkSQLException.
157+
It also adds additional debug information to the raised exception when possible.
158+
"""
159+
150160
def wrap(*args, **kwargs):
161+
from snowflake.snowpark.context import (
162+
_enable_dataframe_trace_on_error,
163+
)
164+
151165
try:
152166
return func(*args, **kwargs)
153167
except snowflake.connector.errors.ProgrammingError as e:
@@ -158,9 +172,35 @@ def wrap(*args, **kwargs):
158172
query = getattr(e, "query", None)
159173
tb = sys.exc_info()[2]
160174
assert e.msg is not None
175+
176+
# extract df_ast_id, stmt_cache from args
177+
df_ast_id, stmt_cache = None, None
178+
for arg in args:
179+
if isinstance(arg, SnowflakePlan):
180+
df_ast_id = arg.df_ast_id
181+
stmt_cache = arg.session._ast_batch._bind_stmt_cache
182+
break
183+
df_transform_debug_trace = None
184+
try:
185+
if (
186+
_enable_dataframe_trace_on_error
187+
and df_ast_id is not None
188+
and stmt_cache is not None
189+
):
190+
df_transform_debug_trace = get_df_transform_trace_message(
191+
df_ast_id, stmt_cache
192+
)
193+
except Exception as trace_error:
194+
# If we encounter an error when getting the df_transform_debug_trace,
195+
# we will ignore the error and not add the debug trace to the error message.
196+
_logger.info(
197+
f"Error when getting the df_transform_debug_trace: {trace_error}"
198+
)
199+
pass
200+
161201
if "unexpected 'as'" in e.msg.lower():
162202
ne = SnowparkClientExceptionMessages.SQL_PYTHON_REPORT_UNEXPECTED_ALIAS(
163-
query
203+
query, debug_context=df_transform_debug_trace
164204
)
165205
raise ne.with_traceback(tb) from None
166206
elif e.sqlstate == "42000" and "invalid identifier" in e.msg:
@@ -171,7 +211,7 @@ def wrap(*args, **kwargs):
171211
)
172212
if not match: # pragma: no cover
173213
ne = SnowparkClientExceptionMessages.SQL_EXCEPTION_FROM_PROGRAMMING_ERROR(
174-
e
214+
e, debug_context=df_transform_debug_trace
175215
)
176216
raise ne.with_traceback(tb) from None
177217
col = match.group(1)
@@ -193,7 +233,9 @@ def wrap(*args, **kwargs):
193233
unaliased_cols[0] if unaliased_cols else "<colname>"
194234
)
195235
ne = SnowparkClientExceptionMessages.SQL_PYTHON_REPORT_INVALID_ID(
196-
orig_col_name, query
236+
orig_col_name,
237+
query,
238+
debug_context=df_transform_debug_trace,
197239
)
198240
raise ne.with_traceback(tb) from None
199241
elif (
@@ -210,7 +252,7 @@ def wrap(*args, **kwargs):
210252
> 1
211253
):
212254
ne = SnowparkClientExceptionMessages.SQL_PYTHON_REPORT_JOIN_AMBIGUOUS(
213-
col, col, query
255+
col, col, query, debug_context=df_transform_debug_trace
214256
)
215257
raise ne.with_traceback(tb) from None
216258
else:
@@ -220,7 +262,7 @@ def wrap(*args, **kwargs):
220262
)
221263
if not match: # pragma: no cover
222264
ne = SnowparkClientExceptionMessages.SQL_EXCEPTION_FROM_PROGRAMMING_ERROR(
223-
e
265+
e, debug_context=df_transform_debug_trace
224266
)
225267
raise ne.with_traceback(tb) from None
226268
col = match.group(1)
@@ -282,7 +324,7 @@ def add_single_quote(string: str) -> str:
282324

283325
e.msg = f"{e.msg}\n{msg}"
284326
ne = SnowparkClientExceptionMessages.SQL_EXCEPTION_FROM_PROGRAMMING_ERROR(
285-
e
327+
e, debug_context=df_transform_debug_trace
286328
)
287329
raise ne.with_traceback(tb) from None
288330
elif e.sqlstate == "42601" and "SELECT with no columns" in e.msg:
@@ -329,7 +371,7 @@ def search_read_file_node(
329371
raise ne.with_traceback(tb) from None
330372

331373
ne = SnowparkClientExceptionMessages.SQL_EXCEPTION_FROM_PROGRAMMING_ERROR(
332-
e
374+
e, debug_context=df_transform_debug_trace
333375
)
334376
raise ne.with_traceback(tb) from None
335377

src/snowflake/snowpark/_internal/ast/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -742,8 +742,7 @@ def with_src_position(
742742
# Once we've stepped out of the snowpark package, we should be in the code of interest.
743743
# However, the code of interest may execute in an environment that is not accessible via the filesystem.
744744
# e.g. Jupyter notebooks, REPLs, calls to exec, etc.
745-
filename = frame.f_code.co_filename if frame is not None else ""
746-
if frame is None or not Path(filename).is_file():
745+
if frame is None:
747746
src.file = __intern_string("")
748747
return expr_ast
749748

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
#
2+
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
3+
#
4+
5+
from functools import cached_property
6+
import os
7+
import sys
8+
from typing import Dict, List, Optional
9+
import itertools
10+
11+
from snowflake.snowpark._internal.ast.batch import get_dependent_bind_ids
12+
from snowflake.snowpark._internal.ast.utils import __STRING_INTERNING_MAP__
13+
import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto
14+
15+
# Sentinel used when a source file number cannot be resolved back to a filename
# via the string-interning map.
UNKNOWN_FILE = "__UNKNOWN_FILE__"
# Name of the environment variable that controls how many dataframe transform
# trace entries are shown in error messages (default is 5; see
# get_df_transform_trace_message).
SNOWPARK_PYTHON_DATAFRAME_TRANSFORM_TRACE_LENGTH = (
    "SNOWPARK_PYTHON_DATAFRAME_TRANSFORM_TRACE_LENGTH"
)
19+
20+
21+
class DataFrameTraceNode:
    """A node representing a dataframe operation in the DAG that represents the lineage of a DataFrame."""

    def __init__(self, batch_id: int, stmt_cache: Dict[int, proto.Stmt]) -> None:
        # batch_id identifies this node's bound statement; stmt_cache maps
        # batch ids to their protobuf Stmt objects for the whole session.
        self.batch_id = batch_id
        self.stmt_cache = stmt_cache

    @cached_property
    def children(self) -> set[int]:
        """Returns the batch_ids of the children of this node."""
        return get_dependent_bind_ids(self.stmt_cache[self.batch_id])

    def get_src(self) -> Optional[proto.SrcPosition]:
        """The source Stmt of the DataFrame described by the batch_id.

        Returns None when the bound expression carries no `src` field.
        """
        stmt = self.stmt_cache[self.batch_id]
        # The expression is a protobuf oneof; WhichOneof tells us which API
        # call variant is populated so we can read its `src` sub-message.
        api_call = stmt.bind.expr.WhichOneof("variant")
        return (
            getattr(stmt.bind.expr, api_call).src
            if api_call and getattr(stmt.bind.expr, api_call).HasField("src")
            else None
        )

    def _read_file(
        self, filename, start_line, end_line, start_column, end_column
    ) -> str:
        """Read the relevant code snippets of where the DataFrame was created. The filename given here
        must have read permissions for the executing user."""
        with open(filename) as f:
            code_lines = []
            if sys.version_info >= (3, 11):
                # Skip to start_line and read only the required lines
                lines = itertools.islice(f, start_line - 1, end_line)
                code_lines = list(lines)
                # Trim the first/last lines to the recorded column range.
                if start_line == end_line:
                    code_lines[0] = code_lines[0][start_column:end_column]
                else:
                    code_lines[0] = code_lines[0][start_column:]
                    code_lines[-1] = code_lines[-1][:end_column]
            else:
                # For python 3.9/3.10, we do not extract the end line from the source code
                # so we just read the start line and return.
                for line in itertools.islice(f, start_line - 1, start_line):
                    code_lines.append(line)

        code_lines = [line.rstrip() for line in code_lines]
        return "\n".join(code_lines)

    @cached_property
    def source_id(self) -> str:
        """Unique identifier of the location of the DataFrame creation in the source code."""
        src = self.get_src()
        if src is None:  # pragma: no cover
            return ""

        # The id encodes file number plus start/end line:column, so two
        # operations written at the same source location share one id.
        fileno = src.file
        start_line = src.start_line
        start_column = src.start_column
        end_line = src.end_line
        end_column = src.end_column
        return f"{fileno}:{start_line}:{start_column}-{end_line}:{end_column}"

    def get_source_snippet(self) -> str:
        """Read the source file and extract the snippet where the dataframe is created."""
        src = self.get_src()
        if src is None:  # pragma: no cover
            return "No source"

        # get the latest mapping of fileno to filename
        # (__STRING_INTERNING_MAP__ maps filename -> fileno, so invert it)
        _fileno_to_filename_map = {v: k for k, v in __STRING_INTERNING_MAP__.items()}
        fileno = src.file
        filename = _fileno_to_filename_map.get(fileno, UNKNOWN_FILE)

        start_line = src.start_line
        end_line = src.end_line
        start_column = src.start_column
        end_column = src.end_column

        # Build the code identifier to find the operations where the DataFrame was created
        # NOTE(review): presumably "(unknown)" is a placeholder for the display
        # filename — confirm against the message format consumers expect.
        if sys.version_info >= (3, 11):
            code_identifier = (
                f"(unknown)|{start_line}:{start_column}-{end_line}:{end_column}"
            )
        else:
            # Pre-3.11 frames don't carry end positions, so only the start
            # line is available for the identifier.
            code_identifier = f"(unknown)|{start_line}"

        if filename != UNKNOWN_FILE and os.access(filename, os.R_OK):
            # If the file is readable, read the code snippet
            code = self._read_file(
                filename, start_line, end_line, start_column, end_column
            )
            return f"{code_identifier}: {code}"
        return code_identifier  # pragma: no cover
114+
115+
def _get_df_transform_trace(
    batch_id: int,
    stmt_cache: Dict[int, proto.Stmt],
) -> List[DataFrameTraceNode]:
    """Helper function to get the transform trace of the dataframe involved in the exception.
    It gathers the lineage in the following way:

    1. Start by creating a DataFrameTraceNode for the given batch_id.
    2. We use BFS to traverse the lineage using the node created in 1. as the first layer.
    3. During each iteration, we check if the node's source_id has been visited. If not,
       we add it to the visited set and append its source format to the trace. This step
       is needed to avoid source_id added multiple times in lineage due to loops.
    4. We then explore the next layer by adding the children of the current node to the
       next layer. We check if the child ID has been visited and if not, we add it to the
       visited set and append the DataFrameTraceNode for it to the next layer.
    5. We repeat this process until there are no more nodes to explore.

    Args:
        batch_id: The batch ID of the dataframe involved in the exception.
        stmt_cache: The statement cache of the session.

    Returns:
        A list of DataFrameTraceNode objects representing the transform trace of the dataframe.
    """
    visited_batch_id = set()
    visited_source_id = set()

    visited_batch_id.add(batch_id)
    curr = [DataFrameTraceNode(batch_id, stmt_cache)]
    lineage = []

    while curr:
        # Renamed from `next` to avoid shadowing the builtin of the same name.
        next_layer: List[DataFrameTraceNode] = []
        for node in curr:
            # tracing updates: only record the first node seen for each
            # source location (loops in the DAG can revisit a location).
            source_id = node.source_id
            if source_id not in visited_source_id:
                visited_source_id.add(source_id)
                lineage.append(node)

            # explore next layer, skipping batch ids already enqueued
            for child_id in node.children:
                if child_id in visited_batch_id:
                    continue
                visited_batch_id.add(child_id)
                next_layer.append(DataFrameTraceNode(child_id, stmt_cache))

        curr = next_layer

    return lineage
165+
166+
167+
def get_df_transform_trace_message(
    df_ast_id: int, stmt_cache: Dict[int, proto.Stmt]
) -> Optional[str]:
    """Get the transform trace message for the dataframe involved in the exception.

    Args:
        df_ast_id: The AST ID of the dataframe involved in the exception.
        stmt_cache: The statement cache of the session.

    Returns:
        A string representing the transform trace message, or None when no
        trace nodes are found.
    """
    df_transform_trace_nodes = _get_df_transform_trace(df_ast_id, stmt_cache)
    if len(df_transform_trace_nodes) == 0:  # pragma: no cover
        return None

    df_transform_trace_length = len(df_transform_trace_nodes)
    # Robustness fix: a non-integer value in the environment variable used to
    # raise ValueError here; since this code runs while building an error
    # message, fall back to the default length instead of failing.
    try:
        show_trace_length = int(
            os.environ.get(SNOWPARK_PYTHON_DATAFRAME_TRANSFORM_TRACE_LENGTH, 5)
        )
    except ValueError:
        show_trace_length = 5

    debug_info_lines = [
        "\n\n--- Additional Debug Information ---\n",
        f"Trace of the most recent dataframe operations associated with the error (total {df_transform_trace_length}):\n",
    ]
    for node in df_transform_trace_nodes[:show_trace_length]:
        debug_info_lines.append(node.get_source_snippet())
    if df_transform_trace_length > show_trace_length:
        # Tell the user how to widen the trace when entries were truncated.
        debug_info_lines.append(
            f"... and {df_transform_trace_length - show_trace_length} more.\nYou can increase "
            f"the lineage length by setting {SNOWPARK_PYTHON_DATAFRAME_TRANSFORM_TRACE_LENGTH} "
            "environment variable."
        )
    return "\n".join(debug_info_lines)

0 commit comments

Comments
 (0)