Skip to content
Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#### Bug Fixes

#### Improvements
- `snowflake.snowpark.context.configure_development_features` is effective for multiple sessions including newly created sessions after the configuration. No duplicate experimental warning any more.

### Snowpark pandas API Updates

Expand Down
24 changes: 11 additions & 13 deletions src/snowflake/snowpark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,21 +77,19 @@ def configure_development_features(
_debug_eager_schema_validation = enable_eager_schema_validation

if enable_dataframe_trace_on_error or enable_trace_sql_errors_to_dataframe:
try:
session = get_active_session()
if session is None:
_enable_dataframe_trace_on_error = enable_dataframe_trace_on_error
_enable_trace_sql_errors_to_dataframe = enable_trace_sql_errors_to_dataframe
with snowflake.snowpark.session._session_management_lock:
sessions = snowflake.snowpark.session._get_active_sessions(
require_at_least_one=False
)
try:
for active_session in sessions:
active_session._set_ast_enabled_internal(True)
except Exception as e: # pragma: no cover
_logger.warning(
"No active session found. Please create a session first and call "
"`configure_development_features()` after creating the session.",
f"Cannot enable AST collection in the session due to {str(e)}. Some development features may not work as expected.",
)
return
_enable_dataframe_trace_on_error = enable_dataframe_trace_on_error
_enable_trace_sql_errors_to_dataframe = enable_trace_sql_errors_to_dataframe
session.ast_enabled = True
except Exception as e:
_logger.warning(
f"Cannot enable AST collection in the session due to {str(e)}. Some development features may not work as expected.",
)
else:
_enable_dataframe_trace_on_error = False
_enable_trace_sql_errors_to_dataframe = False
Expand Down
25 changes: 20 additions & 5 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,14 +332,16 @@ def _get_active_session() -> "Session":
raise SnowparkClientExceptionMessages.SERVER_NO_DEFAULT_SESSION()


def _get_active_sessions() -> Set["Session"]:
def _get_active_sessions(require_at_least_one: bool = True) -> Set["Session"]:
with _session_management_lock:
if len(_active_sessions) >= 1:
# TODO: This function is allowing unsafe access to a mutex protected data
# structure, we should ONLY use it in tests
return _active_sessions
else:
raise SnowparkClientExceptionMessages.SERVER_NO_DEFAULT_SESSION()
if require_at_least_one:
raise SnowparkClientExceptionMessages.SERVER_NO_DEFAULT_SESSION()
return set()


def _add_session(session: "Session") -> None:
Expand Down Expand Up @@ -736,6 +738,16 @@ def __init__(
ast_enabled = False

set_ast_state(AstFlagSource.SERVER, ast_enabled)

# development features require AST to be enabled
from snowflake.snowpark.context import (
_enable_trace_sql_errors_to_dataframe,
_enable_dataframe_trace_on_error,
)

if _enable_trace_sql_errors_to_dataframe or _enable_dataframe_trace_on_error:
self._set_ast_enabled_internal(True)

# The complexity score lower bound is set to match COMPILATION_MEMORY_LIMIT
# in Snowflake. This is the limit where we start seeing compilation errors.
self._large_query_breakdown_complexity_bounds: Tuple[int, int] = (
Expand Down Expand Up @@ -973,9 +985,7 @@ def ast_enabled(self) -> bool:
"""
return is_ast_enabled()

@ast_enabled.setter
@experimental_parameter(version="1.33.0")
def ast_enabled(self, value: bool) -> None:
def _set_ast_enabled_internal(self, value: bool) -> None:
# TODO: we could send here explicit telemetry if a user changes the behavior.
# In addition, we could introduce a server-side parameter to enable AST capture or not.
# self._conn._telemetry_client.send_ast_enabled_telemetry(
Expand All @@ -998,6 +1008,11 @@ def ast_enabled(self, value: bool) -> None:
self._auto_clean_up_temp_table_enabled = False
set_ast_state(AstFlagSource.USER, value)

@ast_enabled.setter
@experimental_parameter(version="1.33.0")
def ast_enabled(self, value: bool) -> None:
self._set_ast_enabled_internal(value)

@property
def cte_optimization_enabled(self) -> bool:
"""Set to ``True`` to enable the CTE optimization (defaults to ``False``).
Expand Down
30 changes: 14 additions & 16 deletions tests/integ/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from snowflake.snowpark.context import get_active_session
import snowflake.snowpark.context as context
import snowflake.snowpark.session as session
from unittest import mock


Expand All @@ -14,35 +15,32 @@ def test_get_active_session(session):

def test_context_configure_development_features():
try:
# Test when get_active_session() returns None
with mock.patch.object(context, "get_active_session", return_value=None):
# Test when _get_active_sessions() returns None
with mock.patch.object(session, "_get_active_sessions", return_value=set()):
context.configure_development_features(
enable_trace_sql_errors_to_dataframe=True
)
assert context._enable_trace_sql_errors_to_dataframe is False
assert context._enable_trace_sql_errors_to_dataframe is True
assert context._enable_dataframe_trace_on_error is False
assert context._debug_eager_schema_validation is False

# Test when get_active_session() throws an exception
with mock.patch.object(
context, "get_active_session", side_effect=RuntimeError("test")
):
context.configure_development_features(
enable_trace_sql_errors_to_dataframe=True
)
assert context._enable_trace_sql_errors_to_dataframe is False
assert context._enable_dataframe_trace_on_error is False
# Test when _get_active_sessions() returns a valid session
mock_session1 = mock.MagicMock()
mock_session1._set_ast_enabled_internal = mock.MagicMock()
mock_session2 = mock.MagicMock()
mock_session2._set_ast_enabled_internal = mock.MagicMock()

# Test when get_active_session() returns a valid session
mock_session = mock.MagicMock()
with mock.patch.object(
context, "get_active_session", return_value=mock_session
session, "_get_active_sessions", return_value=[mock_session1, mock_session2]
):
context.configure_development_features(
enable_trace_sql_errors_to_dataframe=True
)
assert context._enable_trace_sql_errors_to_dataframe is True
assert context._enable_dataframe_trace_on_error is False
assert mock_session.ast_enabled is True
mock_session1._set_ast_enabled_internal.assert_called_once_with(True)
mock_session2._set_ast_enabled_internal.assert_called_once_with(True)

finally:
context.configure_development_features(
enable_trace_sql_errors_to_dataframe=False
Expand Down
21 changes: 21 additions & 0 deletions tests/integ/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,3 +1069,24 @@ def test_transaction(session):
session.sql(f"DROP TABLE IF EXISTS {temp_table_name}").collect()
except Exception:
pass


def test_session_eanble_development_features(db_parameters):
from snowflake.snowpark import context

with patch.object(
context, "_enable_trace_sql_errors_to_dataframe", return_value=True
):
with Session.builder.configs(db_parameters).create() as new_session:
assert new_session.ast_enabled is True

with patch.object(context, "_enable_dataframe_trace_on_error", return_value=True):
with Session.builder.configs(db_parameters).create() as new_session:
assert new_session.ast_enabled is True


def test_get_active_sessions_empty():
from snowflake.snowpark import session as session_module

with patch.object(session_module, "_active_sessions", return_value=set()):
assert session_module._get_active_sessions(require_at_least_one=False) == set()
Loading