diff --git a/CHANGELOG.md b/CHANGELOG.md index d0898b6603..eedfd7d94f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ #### Bug Fixes #### Improvements +- `snowflake.snowpark.context.configure_development_features` is effective for multiple sessions including newly created sessions after the configuration. No duplicate experimental warning any more. ### Snowpark pandas API Updates diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index 42d3452b28..a9e585d74b 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -77,21 +77,19 @@ def configure_development_features( _debug_eager_schema_validation = enable_eager_schema_validation if enable_dataframe_trace_on_error or enable_trace_sql_errors_to_dataframe: - try: - session = get_active_session() - if session is None: + _enable_dataframe_trace_on_error = enable_dataframe_trace_on_error + _enable_trace_sql_errors_to_dataframe = enable_trace_sql_errors_to_dataframe + with snowflake.snowpark.session._session_management_lock: + sessions = snowflake.snowpark.session._get_active_sessions( + require_at_least_one=False + ) + try: + for active_session in sessions: + active_session._set_ast_enabled_internal(True) + except Exception as e: # pragma: no cover _logger.warning( - "No active session found. Please create a session first and call " - "`configure_development_features()` after creating the session.", + f"Cannot enable AST collection in the session due to {str(e)}. Some development features may not work as expected.", ) - return - _enable_dataframe_trace_on_error = enable_dataframe_trace_on_error - _enable_trace_sql_errors_to_dataframe = enable_trace_sql_errors_to_dataframe - session.ast_enabled = True - except Exception as e: - _logger.warning( - f"Cannot enable AST collection in the session due to {str(e)}. Some development features may not work as expected.", - ) else: _enable_dataframe_trace_on_error = False _enable_trace_sql_errors_to_dataframe = False diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index b8906aae61..844b378de9 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -332,14 +332,16 @@ def _get_active_session() -> "Session": raise SnowparkClientExceptionMessages.SERVER_NO_DEFAULT_SESSION() -def _get_active_sessions() -> Set["Session"]: +def _get_active_sessions(require_at_least_one: bool = True) -> Set["Session"]: with _session_management_lock: if len(_active_sessions) >= 1: # TODO: This function is allowing unsafe access to a mutex protected data # structure, we should ONLY use it in tests return _active_sessions else: - raise SnowparkClientExceptionMessages.SERVER_NO_DEFAULT_SESSION() + if require_at_least_one: + raise SnowparkClientExceptionMessages.SERVER_NO_DEFAULT_SESSION() + return set() def _add_session(session: "Session") -> None: @@ -736,6 +738,16 @@ def __init__( ast_enabled = False set_ast_state(AstFlagSource.SERVER, ast_enabled) + + # development features require AST to be enabled + from snowflake.snowpark.context import ( + _enable_trace_sql_errors_to_dataframe, + _enable_dataframe_trace_on_error, + ) + + if _enable_trace_sql_errors_to_dataframe or _enable_dataframe_trace_on_error: + self._set_ast_enabled_internal(True) + # The complexity score lower bound is set to match COMPILATION_MEMORY_LIMIT # in Snowflake. This is the limit where we start seeing compilation errors. self._large_query_breakdown_complexity_bounds: Tuple[int, int] = ( @@ -973,9 +985,7 @@ def ast_enabled(self) -> bool: """ return is_ast_enabled() - @ast_enabled.setter - @experimental_parameter(version="1.33.0") - def ast_enabled(self, value: bool) -> None: + def _set_ast_enabled_internal(self, value: bool) -> None: # TODO: we could send here explicit telemetry if a user changes the behavior. # In addition, we could introduce a server-side parameter to enable AST capture or not. # self._conn._telemetry_client.send_ast_enabled_telemetry( @@ -998,6 +1008,11 @@ def ast_enabled(self, value: bool) -> None: self._auto_clean_up_temp_table_enabled = False set_ast_state(AstFlagSource.USER, value) + @ast_enabled.setter + @experimental_parameter(version="1.33.0") + def ast_enabled(self, value: bool) -> None: + self._set_ast_enabled_internal(value) + @property def cte_optimization_enabled(self) -> bool: """Set to ``True`` to enable the CTE optimization (defaults to ``False``). diff --git a/tests/integ/test_context.py b/tests/integ/test_context.py index 914923dec9..f335b5c6a2 100644 --- a/tests/integ/test_context.py +++ b/tests/integ/test_context.py @@ -5,6 +5,7 @@ from snowflake.snowpark.context import get_active_session import snowflake.snowpark.context as context +import snowflake.snowpark.session as session from unittest import mock @@ -14,35 +15,32 @@ def test_get_active_session(session): def test_context_configure_development_features(): try: - # Test when get_active_session() returns None - with mock.patch.object(context, "get_active_session", return_value=None): + # Test when _get_active_sessions() returns None + with mock.patch.object(session, "_get_active_sessions", return_value=set()): context.configure_development_features( enable_trace_sql_errors_to_dataframe=True ) - assert context._enable_trace_sql_errors_to_dataframe is False + assert context._enable_trace_sql_errors_to_dataframe is True assert context._enable_dataframe_trace_on_error is False + assert context._debug_eager_schema_validation is False - # Test when get_active_session() throws an exception - with mock.patch.object( - context, "get_active_session", side_effect=RuntimeError("test") - ): - context.configure_development_features( - enable_trace_sql_errors_to_dataframe=True - ) - assert context._enable_trace_sql_errors_to_dataframe is False - assert context._enable_dataframe_trace_on_error is False + # Test when _get_active_sessions() returns a valid session + mock_session1 = mock.MagicMock() + mock_session1._set_ast_enabled_internal = mock.MagicMock() + mock_session2 = mock.MagicMock() + mock_session2._set_ast_enabled_internal = mock.MagicMock() - # Test when get_active_session() returns a valid session - mock_session = mock.MagicMock() with mock.patch.object( - context, "get_active_session", return_value=mock_session + session, "_get_active_sessions", return_value=[mock_session1, mock_session2] ): context.configure_development_features( enable_trace_sql_errors_to_dataframe=True ) assert context._enable_trace_sql_errors_to_dataframe is True assert context._enable_dataframe_trace_on_error is False - assert mock_session.ast_enabled is True + mock_session1._set_ast_enabled_internal.assert_called_once_with(True) + mock_session2._set_ast_enabled_internal.assert_called_once_with(True) + finally: context.configure_development_features( enable_trace_sql_errors_to_dataframe=False diff --git a/tests/integ/test_session.py b/tests/integ/test_session.py index 522fb342f9..12c600ef47 100644 --- a/tests/integ/test_session.py +++ b/tests/integ/test_session.py @@ -1069,3 +1069,24 @@ def test_transaction(session): session.sql(f"DROP TABLE IF EXISTS {temp_table_name}").collect() except Exception: pass + + +def test_session_eanble_development_features(db_parameters): + from snowflake.snowpark import context + + with patch.object( + context, "_enable_trace_sql_errors_to_dataframe", return_value=True + ): + with Session.builder.configs(db_parameters).create() as new_session: + assert new_session.ast_enabled is True + + with patch.object(context, "_enable_dataframe_trace_on_error", return_value=True): + with Session.builder.configs(db_parameters).create() as new_session: + assert new_session.ast_enabled is True + + +def test_get_active_sessions_empty(): + from snowflake.snowpark import session as session_module + + with patch.object(session_module, "_active_sessions", return_value=set()): + assert session_module._get_active_sessions(require_at_least_one=False) == set()