Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

### Snowpark Python API Updates

#### New Features

- Added debuggability improvements to eagerly validate dataframe schema metadata. Enable it using `snowflake.snowpark.context.configure_development_features()`.

#### Improvements

- Added support for row validation using XSD schema using `rowValidationXSDPath` option when reading XML files with a row tag using `rowTag` option.
Expand Down
3 changes: 2 additions & 1 deletion src/snowflake/snowpark/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
from snowflake.connector.options import MissingOptionalDependency, ModuleLikeObject
from snowflake.connector.version import VERSION as connector_version
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark.context import _should_use_structured_type_semantics
from snowflake.snowpark.row import Row
from snowflake.snowpark.version import VERSION as snowpark_version

Expand Down Expand Up @@ -756,6 +755,8 @@ def _parse_result_meta(
an expected format. For example StructType columns are returned as dict objects, but are better
represented as Row objects.
"""
from snowflake.snowpark.context import _should_use_structured_type_semantics

if not result_meta:
return None, None
col_names = []
Expand Down
7 changes: 7 additions & 0 deletions src/snowflake/snowpark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import threading

_logger = logging.getLogger(__name__)

_use_scoped_temp_objects = True

# This is an internal-only global flag, used to determine whether to execute code in a client's local sandbox or connect to a Snowflake account.
Expand All @@ -33,11 +34,13 @@

# Following are internal-only global flags, used to enable development features.
_enable_dataframe_trace_on_error = False
_debug_eager_schema_validation = False


def configure_development_features(
*,
enable_dataframe_trace_on_error: bool = True,
enable_eager_schema_validation: bool = True,
) -> None:
"""
Configure development features for the session.
Expand All @@ -46,6 +49,8 @@ def configure_development_features(
enable_dataframe_trace_on_error: If True, upon failure, we will add most recent dataframe
operations to the error trace. This requires AST collection to be enabled in the
session which can be done using `session.ast_enabled = True`.
enable_eager_schema_validation: If True, dataframe schemas are eagerly validated by querying
for column metadata after every dataframe operation. This adds additional query overhead.

Note:
This feature is experimental since 1.33.0. Do not use it in production.
Expand All @@ -55,6 +60,8 @@ def configure_development_features(
)
global _enable_dataframe_trace_on_error
_enable_dataframe_trace_on_error = enable_dataframe_trace_on_error
global _debug_eager_schema_validation
_debug_eager_schema_validation = enable_eager_schema_validation


def _should_use_structured_type_semantics():
Expand Down
6 changes: 6 additions & 0 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)

import snowflake.snowpark
import snowflake.snowpark.context as context
import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto
from snowflake.connector.options import installed_pandas, pandas, pyarrow

Expand Down Expand Up @@ -645,6 +646,11 @@ def __init__(

self._alias: Optional[str] = None

if context._debug_eager_schema_validation:
# Getting the plan attributes may run a describe query
# and populates the schema for the dataframe.
self._plan.attributes

def _set_ast_ref(self, dataframe_expr_builder: Any) -> None:
"""
Given a field builder expression of the AST type Expr, points the builder to reference this dataframe.
Expand Down
102 changes: 102 additions & 0 deletions tests/integ/test_debug_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#!/usr/bin/env python3
#
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#

import pytest
import snowflake.snowpark.context as context
from copy import copy
from unittest.mock import patch, Mock


from snowflake.snowpark.functions import col, lit, max


@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="debug_mode not used in local testing mode",
)
@pytest.mark.parametrize("debug_mode", [True, False])
@pytest.mark.parametrize(
    "transform",
    [
        pytest.param(lambda x: copy(x), id="copy"),
        pytest.param(lambda x: x.to_df(["C", "D"]), id="to_df"),
        pytest.param(lambda x: x.distinct(), id="distinct"),
        pytest.param(lambda x: x.drop_duplicates(), id="drop_duplicates"),
        pytest.param(lambda x: x.limit(1), id="limit"),
        pytest.param(lambda x: x.union(x), id="union"),
        pytest.param(lambda x: x.union_all(x), id="union_all"),
        pytest.param(lambda x: x.union_by_name(x), id="union_by_name"),
        pytest.param(lambda x: x.union_all_by_name(x), id="union_all_by_name"),
        pytest.param(lambda x: x.intersect(x), id="intersect"),
        pytest.param(lambda x: x.natural_join(x), id="natural_join"),
        pytest.param(lambda x: x.cross_join(x), id="cross_join"),
        pytest.param(lambda x: x.sample(n=1), id="sample"),
        pytest.param(
            lambda x: x.with_column_renamed(col("A"), "B"), id="with_column_renamed"
        ),
        # Unpivot already validates names
        pytest.param(lambda x: x.unpivot("x", "y", ["A"]), id="unpivot"),
        # The following functions do not error early because their schema_query do not contain
        # information about the transformation being called.
        pytest.param(lambda x: x.drop(col("A")), id="drop"),
        pytest.param(lambda x: x.filter(col("A") == lit(1)), id="filter"),
        pytest.param(lambda x: x.sort(col("A").desc()), id="sort"),
    ],
)
def test_early_attributes(session, transform, debug_mode):
    """Verify that ``_debug_eager_schema_validation`` controls when plan
    attributes are resolved.

    For each dataframe transformation, checks that the resulting dataframe's
    plan metadata attributes are populated eagerly (non-None) when the debug
    flag is on, and remain unresolved (None) when it is off.
    """
    # Toggle the module-level debug flag only for the duration of this test.
    with patch.object(context, "_debug_eager_schema_validation", debug_mode):
        df = session.create_dataframe([(1, "A"), (2, "B"), (3, "C")], ["A", "B"])

        transformed = transform(df)

        # When debug mode is enabled the dataframe plan attributes are populated early
        if debug_mode:
            assert transformed._plan._metadata.attributes is not None
        else:
            assert transformed._plan._metadata.attributes is None


@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="debug_mode not used in local testing mode",
)
@pytest.mark.parametrize("debug_mode", [True, False])
@pytest.mark.parametrize(
    "transform",
    [
        pytest.param(lambda x: x.select("B"), id="select"),
        pytest.param(lambda x: x.select_expr("cast(b as str)"), id="select_expr"),
        pytest.param(lambda x: x.agg(max("B")), id="agg"),
        pytest.param(lambda x: x.join(copy(x), on=(col("A") == col("B"))), id="join"),
        pytest.param(
            lambda x: x.join_table_function("flatten", col("B")),
            id="join_table_function",
        ),
        pytest.param(lambda x: x.with_column("C", col("B")), id="with_column"),
        pytest.param(lambda x: x.with_columns(["C"], [col("B")]), id="with_columns"),
    ],
)
def test_early_error(session, transform, debug_mode):
    """Verify that invalid transformations fail eagerly under debug mode.

    Each transform references column ``B``, which does not exist on the
    single-column dataframe, so an error must be raised eventually. With
    eager schema validation enabled the failure should surface during the
    transformation itself (before ``show`` runs); without it, the failure
    only surfaces once ``show`` triggers execution.
    """
    # Toggle the module-level debug flag only for the duration of this test.
    with patch.object(context, "_debug_eager_schema_validation", debug_mode):
        df = session.create_dataframe([1, 2, 3], ["A"])

        # Replace DataFrame.show with a mock so we can observe whether it was
        # ever reached; __qualname__/__name__ are set because patched callables
        # may be introspected by wrapping/telemetry code.
        show_mock = Mock()
        show_mock.__qualname__ = "show"
        show_mock.__name__ = "show"

        with patch("snowflake.snowpark.dataframe.DataFrame.show", show_mock):
            try:
                transformed = transform(df)
                transformed.show()
            except Exception:
                # The exception itself is expected; the assertions below only
                # check *when* it fired (before vs. after show was called).
                pass
            # When debug mode is enabled the error is thrown before reaching show.
            # Without debug mode the error only shows up once show is called.
            if debug_mode:
                show_mock.assert_not_called()
                assert df._plan._metadata.attributes is not None
            else:
                show_mock.assert_called()
                assert df._plan._metadata.attributes is None
Loading