diff --git a/CHANGELOG.md b/CHANGELOG.md index e334790e33..9cb9ac9d06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ ### Snowpark Python API Updates +#### New Features + +- Added debuggability improvements to eagerly validate dataframe schema metadata. Enable it using `snowflake.snowpark.context.configure_development_features()`. + #### Improvements - Added support for row validation using XSD schema using `rowValidationXSDPath` option when reading XML files with a row tag using `rowTag` option. diff --git a/src/snowflake/snowpark/_internal/utils.py b/src/snowflake/snowpark/_internal/utils.py index c1f3f9f8c2..1476d810fe 100644 --- a/src/snowflake/snowpark/_internal/utils.py +++ b/src/snowflake/snowpark/_internal/utils.py @@ -55,7 +55,6 @@ from snowflake.connector.options import MissingOptionalDependency, ModuleLikeObject from snowflake.connector.version import VERSION as connector_version from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages -from snowflake.snowpark.context import _should_use_structured_type_semantics from snowflake.snowpark.row import Row from snowflake.snowpark.version import VERSION as snowpark_version @@ -760,6 +759,8 @@ def _parse_result_meta( an expected format. For example StructType columns are returned as dict objects, but are better represented as Row objects. """ + from snowflake.snowpark.context import _should_use_structured_type_semantics + if not result_meta: return None, None col_names = [] diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index 755b5fea5c..f0f94b760f 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -11,6 +11,7 @@ import threading _logger = logging.getLogger(__name__) + _use_scoped_temp_objects = True # This is an internal-only global flag, used to determine whether to execute code in a client's local sandbox or connect to a Snowflake account. 
@@ -33,6 +34,7 @@ # Following are internal-only global flags, used to enable development features. _enable_dataframe_trace_on_error = False +_debug_eager_schema_validation = False # This is an internal-only global flag, used to determine whether to enable query line tracking for tracing sql compilation errors. _enable_trace_sql_errors_to_dataframe = False @@ -41,6 +43,7 @@ def configure_development_features( *, enable_dataframe_trace_on_error: bool = True, + enable_eager_schema_validation: bool = True, enable_trace_sql_errors_to_dataframe: bool = True, ) -> None: """ @@ -50,6 +53,8 @@ def configure_development_features( enable_dataframe_trace_on_error: If True, upon failure, we will add most recent dataframe operations to the error trace. This requires AST collection to be enabled in the session which can be done using `session.ast_enabled = True`. + enable_eager_schema_validation: If True, dataframe schemas are eagerly validated by querying + for column metadata after every dataframe operation. This adds additional query overhead. enable_trace_sql_errors_to_dataframe: If True, we will enable query line tracking. Note: This feature is experimental since 1.33.0. Do not use it in production. @@ -58,7 +63,9 @@ def configure_development_features( "configure_development_features() is experimental since 1.33.0. 
Do not use it in production.", ) global _enable_dataframe_trace_on_error, _enable_trace_sql_errors_to_dataframe + global _debug_eager_schema_validation _enable_dataframe_trace_on_error = enable_dataframe_trace_on_error + _debug_eager_schema_validation = enable_eager_schema_validation _enable_trace_sql_errors_to_dataframe = enable_trace_sql_errors_to_dataframe diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index da1ddd2573..aa71ce37ea 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -27,6 +27,7 @@ ) import snowflake.snowpark +import snowflake.snowpark.context as context import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto from snowflake.connector.options import installed_pandas, pandas, pyarrow @@ -645,6 +646,11 @@ def __init__( self._alias: Optional[str] = None + if context._debug_eager_schema_validation: + # Getting the plan attributes may run a describe query + # and populate the schema for the dataframe. + self._plan.attributes + def _set_ast_ref(self, dataframe_expr_builder: Any) -> None: """ Given a field builder expression of the AST type Expr, points the builder to reference this dataframe. 
diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index e3be1e53dc..3ee15559c3 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -285,7 +285,9 @@ def session( session.ast_enabled = ast_enabled if not session._generate_multiline_queries: session._enable_multiline_queries() - context.configure_development_features(enable_trace_sql_errors_to_dataframe=True) + context.configure_development_features( + enable_trace_sql_errors_to_dataframe=True, enable_eager_schema_validation=False + ) if (RUNNING_ON_GH or RUNNING_ON_JENKINS) and not local_testing_mode: set_up_external_access_integration_resources( diff --git a/tests/integ/test_df_debug_trace.py b/tests/integ/test_df_debug_trace.py index 7235376dbf..d8150ca05d 100644 --- a/tests/integ/test_df_debug_trace.py +++ b/tests/integ/test_df_debug_trace.py @@ -30,12 +30,16 @@ @pytest.fixture(autouse=True) def setup(request, session): original = session.ast_enabled - context.configure_development_features(enable_dataframe_trace_on_error=True) + context.configure_development_features( + enable_dataframe_trace_on_error=True, enable_eager_schema_validation=False + ) set_ast_state(AstFlagSource.TEST, True) if SNOWPARK_PYTHON_DATAFRAME_TRANSFORM_TRACE_LENGTH in os.environ: del os.environ[SNOWPARK_PYTHON_DATAFRAME_TRANSFORM_TRACE_LENGTH] yield - context.configure_development_features(enable_dataframe_trace_on_error=False) + context.configure_development_features( + enable_dataframe_trace_on_error=False, enable_eager_schema_validation=False + ) set_ast_state(AstFlagSource.TEST, original) if SNOWPARK_PYTHON_DATAFRAME_TRANSFORM_TRACE_LENGTH in os.environ: del os.environ[SNOWPARK_PYTHON_DATAFRAME_TRANSFORM_TRACE_LENGTH] diff --git a/tests/integ/test_eager_schema_validation.py b/tests/integ/test_eager_schema_validation.py new file mode 100644 index 0000000000..e93e49439c --- /dev/null +++ b/tests/integ/test_eager_schema_validation.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2012-2025 
Snowflake Computing Inc. All rights reserved. +# + +import pytest +import snowflake.snowpark.context as context +from copy import copy +from unittest.mock import patch, Mock + + +from snowflake.snowpark.functions import col, lit, max + + +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="debug_mode not used in local testing mode", +) +@pytest.mark.parametrize("debug_mode", [True, False]) +@pytest.mark.parametrize( + "transform", + [ + pytest.param(lambda x: copy(x), id="copy"), + pytest.param(lambda x: x.to_df(["C", "D"]), id="to_df"), + pytest.param(lambda x: x.distinct(), id="distinct"), + pytest.param(lambda x: x.drop_duplicates(), id="drop_duplicates"), + pytest.param(lambda x: x.limit(1), id="limit"), + pytest.param(lambda x: x.union(x), id="union"), + pytest.param(lambda x: x.union_all(x), id="union_all"), + pytest.param(lambda x: x.union_by_name(x), id="union_by_name"), + pytest.param(lambda x: x.union_all_by_name(x), id="union_all_by_name"), + pytest.param(lambda x: x.intersect(x), id="intersect"), + pytest.param(lambda x: x.natural_join(x), id="natural_join"), + pytest.param(lambda x: x.cross_join(x), id="cross_join"), + pytest.param(lambda x: x.sample(n=1), id="sample"), + pytest.param( + lambda x: x.with_column_renamed(col("A"), "B"), id="with_column_renamed" + ), + # Unpivot already validates names + pytest.param(lambda x: x.unpivot("x", "y", ["A"]), id="unpivot"), + # The following functions do not error early because their schema_query do not contain + # information about the transformation being called. 
+ pytest.param(lambda x: x.drop(col("A")), id="drop"), + pytest.param(lambda x: x.filter(col("A") == lit(1)), id="filter"), + pytest.param(lambda x: x.sort(col("A").desc()), id="sort"), + ], +) +def test_early_attributes(session, transform, debug_mode): + with patch.object(context, "_debug_eager_schema_validation", debug_mode): + df = session.create_dataframe([(1, "A"), (2, "B"), (3, "C")], ["A", "B"]) + + transformed = transform(df) + + # When debug mode is enabled the dataframe plan attributes are populated early + if debug_mode: + assert transformed._plan._metadata.attributes is not None + else: + assert transformed._plan._metadata.attributes is None + + +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="debug_mode not used in local testing mode", +) +@pytest.mark.parametrize("debug_mode", [True, False]) +@pytest.mark.parametrize( + "transform", + [ + pytest.param(lambda x: x.select("B"), id="select"), + pytest.param(lambda x: x.select_expr("cast(b as str)"), id="select_expr"), + pytest.param(lambda x: x.agg(max("B")), id="agg"), + pytest.param(lambda x: x.join(copy(x), on=(col("A") == col("B"))), id="join"), + pytest.param( + lambda x: x.join_table_function("flatten", col("B")), + id="join_table_function", + ), + pytest.param(lambda x: x.with_column("C", col("B")), id="with_column"), + pytest.param(lambda x: x.with_columns(["C"], [col("B")]), id="with_columns"), + ], +) +def test_early_error(session, transform, debug_mode): + with patch.object(context, "_debug_eager_schema_validation", debug_mode): + df = session.create_dataframe([1, 2, 3], ["A"]) + + show_mock = Mock() + show_mock.__qualname__ = "show" + show_mock.__name__ = "show" + + with patch("snowflake.snowpark.dataframe.DataFrame.show", show_mock): + try: + transformed = transform(df) + transformed.show() + except Exception: + pass + # When debug mode is enabled the error is thrown before reaching show. 
+ # Without debug mode the error only shows up once show is called. + if debug_mode: + show_mock.assert_not_called() + assert df._plan._metadata.attributes is not None + else: + show_mock.assert_called() + assert df._plan._metadata.attributes is None diff --git a/tests/unit/test_selectable_queries.py b/tests/unit/test_selectable_queries.py index 8d64992503..d66843f8d7 100644 --- a/tests/unit/test_selectable_queries.py +++ b/tests/unit/test_selectable_queries.py @@ -25,9 +25,13 @@ @pytest.fixture(autouse=True) def setup(request): - context.configure_development_features(enable_trace_sql_errors_to_dataframe=True) + context.configure_development_features( + enable_trace_sql_errors_to_dataframe=True, enable_eager_schema_validation=False + ) yield - context.configure_development_features(enable_trace_sql_errors_to_dataframe=False) + context.configure_development_features( + enable_trace_sql_errors_to_dataframe=False, enable_eager_schema_validation=False + ) def test_select_statement_sql_query(mock_session, mock_analyzer):