Skip to content

Commit 1d449ab

Browse files
authored
SNOW-2317965: Refactor context_configure_development_features (#4041)
1 parent 1713ffe commit 1d449ab

File tree

5 files changed

+67
-34
lines changed

5 files changed

+67
-34
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#### Bug Fixes
2424

2525
#### Improvements
26+
- `snowflake.snowpark.context.configure_development_features` is effective for multiple sessions including newly created sessions after the configuration. No duplicate experimental warning any more.
2627

2728
### Snowpark pandas API Updates
2829

src/snowflake/snowpark/context.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -77,21 +77,19 @@ def configure_development_features(
7777
_debug_eager_schema_validation = enable_eager_schema_validation
7878

7979
if enable_dataframe_trace_on_error or enable_trace_sql_errors_to_dataframe:
80-
try:
81-
session = get_active_session()
82-
if session is None:
80+
_enable_dataframe_trace_on_error = enable_dataframe_trace_on_error
81+
_enable_trace_sql_errors_to_dataframe = enable_trace_sql_errors_to_dataframe
82+
with snowflake.snowpark.session._session_management_lock:
83+
sessions = snowflake.snowpark.session._get_active_sessions(
84+
require_at_least_one=False
85+
)
86+
try:
87+
for active_session in sessions:
88+
active_session._set_ast_enabled_internal(True)
89+
except Exception as e: # pragma: no cover
8390
_logger.warning(
84-
"No active session found. Please create a session first and call "
85-
"`configure_development_features()` after creating the session.",
91+
f"Cannot enable AST collection in the session due to {str(e)}. Some development features may not work as expected.",
8692
)
87-
return
88-
_enable_dataframe_trace_on_error = enable_dataframe_trace_on_error
89-
_enable_trace_sql_errors_to_dataframe = enable_trace_sql_errors_to_dataframe
90-
session.ast_enabled = True
91-
except Exception as e:
92-
_logger.warning(
93-
f"Cannot enable AST collection in the session due to {str(e)}. Some development features may not work as expected.",
94-
)
9593
else:
9694
_enable_dataframe_trace_on_error = False
9795
_enable_trace_sql_errors_to_dataframe = False

src/snowflake/snowpark/session.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -332,14 +332,16 @@ def _get_active_session() -> "Session":
332332
raise SnowparkClientExceptionMessages.SERVER_NO_DEFAULT_SESSION()
333333

334334

335-
def _get_active_sessions() -> Set["Session"]:
335+
def _get_active_sessions(require_at_least_one: bool = True) -> Set["Session"]:
336336
with _session_management_lock:
337337
if len(_active_sessions) >= 1:
338338
# TODO: This function is allowing unsafe access to a mutex protected data
339339
# structure, we should ONLY use it in tests
340340
return _active_sessions
341341
else:
342-
raise SnowparkClientExceptionMessages.SERVER_NO_DEFAULT_SESSION()
342+
if require_at_least_one:
343+
raise SnowparkClientExceptionMessages.SERVER_NO_DEFAULT_SESSION()
344+
return set()
343345

344346

345347
def _add_session(session: "Session") -> None:
@@ -736,6 +738,16 @@ def __init__(
736738
ast_enabled = False
737739

738740
set_ast_state(AstFlagSource.SERVER, ast_enabled)
741+
742+
# development features require AST to be enabled
743+
from snowflake.snowpark.context import (
744+
_enable_trace_sql_errors_to_dataframe,
745+
_enable_dataframe_trace_on_error,
746+
)
747+
748+
if _enable_trace_sql_errors_to_dataframe or _enable_dataframe_trace_on_error:
749+
self._set_ast_enabled_internal(True)
750+
739751
# The complexity score lower bound is set to match COMPILATION_MEMORY_LIMIT
740752
# in Snowflake. This is the limit where we start seeing compilation errors.
741753
self._large_query_breakdown_complexity_bounds: Tuple[int, int] = (
@@ -973,9 +985,7 @@ def ast_enabled(self) -> bool:
973985
"""
974986
return is_ast_enabled()
975987

976-
@ast_enabled.setter
977-
@experimental_parameter(version="1.33.0")
978-
def ast_enabled(self, value: bool) -> None:
988+
def _set_ast_enabled_internal(self, value: bool) -> None:
979989
# TODO: we could send here explicit telemetry if a user changes the behavior.
980990
# In addition, we could introduce a server-side parameter to enable AST capture or not.
981991
# self._conn._telemetry_client.send_ast_enabled_telemetry(
@@ -998,6 +1008,11 @@ def ast_enabled(self, value: bool) -> None:
9981008
self._auto_clean_up_temp_table_enabled = False
9991009
set_ast_state(AstFlagSource.USER, value)
10001010

1011+
@ast_enabled.setter
1012+
@experimental_parameter(version="1.33.0")
1013+
def ast_enabled(self, value: bool) -> None:
1014+
self._set_ast_enabled_internal(value)
1015+
10011016
@property
10021017
def cte_optimization_enabled(self) -> bool:
10031018
"""Set to ``True`` to enable the CTE optimization (defaults to ``False``).

tests/integ/test_context.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from snowflake.snowpark.context import get_active_session
77
import snowflake.snowpark.context as context
8+
import snowflake.snowpark.session as session
89
from unittest import mock
910

1011

@@ -14,35 +15,32 @@ def test_get_active_session(session):
1415

1516
def test_context_configure_development_features():
1617
try:
17-
# Test when get_active_session() returns None
18-
with mock.patch.object(context, "get_active_session", return_value=None):
18+
# Test when _get_active_sessions() returns None
19+
with mock.patch.object(session, "_get_active_sessions", return_value=set()):
1920
context.configure_development_features(
2021
enable_trace_sql_errors_to_dataframe=True
2122
)
22-
assert context._enable_trace_sql_errors_to_dataframe is False
23+
assert context._enable_trace_sql_errors_to_dataframe is True
2324
assert context._enable_dataframe_trace_on_error is False
25+
assert context._debug_eager_schema_validation is False
2426

25-
# Test when get_active_session() throws an exception
26-
with mock.patch.object(
27-
context, "get_active_session", side_effect=RuntimeError("test")
28-
):
29-
context.configure_development_features(
30-
enable_trace_sql_errors_to_dataframe=True
31-
)
32-
assert context._enable_trace_sql_errors_to_dataframe is False
33-
assert context._enable_dataframe_trace_on_error is False
27+
# Test when _get_active_sessions() returns a valid session
28+
mock_session1 = mock.MagicMock()
29+
mock_session1._set_ast_enabled_internal = mock.MagicMock()
30+
mock_session2 = mock.MagicMock()
31+
mock_session2._set_ast_enabled_internal = mock.MagicMock()
3432

35-
# Test when get_active_session() returns a valid session
36-
mock_session = mock.MagicMock()
3733
with mock.patch.object(
38-
context, "get_active_session", return_value=mock_session
34+
session, "_get_active_sessions", return_value=[mock_session1, mock_session2]
3935
):
4036
context.configure_development_features(
4137
enable_trace_sql_errors_to_dataframe=True
4238
)
4339
assert context._enable_trace_sql_errors_to_dataframe is True
4440
assert context._enable_dataframe_trace_on_error is False
45-
assert mock_session.ast_enabled is True
41+
mock_session1._set_ast_enabled_internal.assert_called_once_with(True)
42+
mock_session2._set_ast_enabled_internal.assert_called_once_with(True)
43+
4644
finally:
4745
context.configure_development_features(
4846
enable_trace_sql_errors_to_dataframe=False

tests/integ/test_session.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,3 +1069,24 @@ def test_transaction(session):
10691069
session.sql(f"DROP TABLE IF EXISTS {temp_table_name}").collect()
10701070
except Exception:
10711071
pass
1072+
1073+
1074+
def test_session_eanble_development_features(db_parameters):
1075+
from snowflake.snowpark import context
1076+
1077+
with patch.object(
1078+
context, "_enable_trace_sql_errors_to_dataframe", return_value=True
1079+
):
1080+
with Session.builder.configs(db_parameters).create() as new_session:
1081+
assert new_session.ast_enabled is True
1082+
1083+
with patch.object(context, "_enable_dataframe_trace_on_error", return_value=True):
1084+
with Session.builder.configs(db_parameters).create() as new_session:
1085+
assert new_session.ast_enabled is True
1086+
1087+
1088+
def test_get_active_sessions_empty():
1089+
from snowflake.snowpark import session as session_module
1090+
1091+
with patch.object(session_module, "_active_sessions", return_value=set()):
1092+
assert session_module._get_active_sessions(require_at_least_one=False) == set()

0 commit comments

Comments
 (0)