Skip to content

Commit 015fffc

Browse files
authored
SNOW-1878372: Fix analyzer access across threads (#2912)
1 parent 8db5109 commit 015fffc

File tree

13 files changed

+148
-61
lines changed

13 files changed

+148
-61
lines changed

src/snowflake/snowpark/_internal/analyzer/analyzer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#
55
import uuid
66
from collections import Counter, defaultdict
7-
from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Union
7+
from typing import TYPE_CHECKING, DefaultDict, Dict, List, Union
88

99
from snowflake.connector import IntegrityError
1010

@@ -168,7 +168,7 @@ def __init__(self, session: "snowflake.snowpark.session.Session") -> None:
168168
self.plan_builder = SnowflakePlanBuilder(self.session)
169169
self.generated_alias_maps = {}
170170
self.subquery_plans = []
171-
self.alias_maps_to_use: Optional[Dict[uuid.UUID, str]] = None
171+
self.alias_maps_to_use: Dict[uuid.UUID, str] = {}
172172

173173
def analyze(
174174
self,
@@ -368,7 +368,6 @@ def analyze(
368368
return expr.sql
369369

370370
if isinstance(expr, Attribute):
371-
assert self.alias_maps_to_use is not None
372371
name = self.alias_maps_to_use.get(expr.expr_id, expr.name)
373372
return quote_name(name)
374373

src/snowflake/snowpark/_internal/analyzer/metadata_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def cache_metadata_if_select_statement(
197197

198198
if (
199199
isinstance(source_plan, SelectStatement)
200-
and source_plan.analyzer.session.reduce_describe_query_enabled
200+
and source_plan._session.reduce_describe_query_enabled
201201
):
202202
source_plan._attributes = metadata.attributes
203203
# When source_plan doesn't have a projection, it's a simple `SELECT * from ...`,

src/snowflake/snowpark/_internal/analyzer/select_statement.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,14 @@ def __init__(
229229
] = None, # Use Any because it's recursive.
230230
) -> None:
231231
super().__init__()
232-
self.analyzer = analyzer
232+
# With multi-threading support, each thread has its own analyzer which can be
233+
# accessed through session object. Therefore, we need to store the session in
234+
# the Selectable object and use the session to access the appropriate analyzer
235+
# for current thread.
236+
self._session = analyzer.session
237+
# We create this internal object to be used for setting query generator during
238+
# the optimization stage
239+
self._analyzer = None
233240
self.pre_actions: Optional[List["Query"]] = None
234241
self.post_actions: Optional[List["Query"]] = None
235242
self.flatten_disabled: bool = False
@@ -243,6 +250,23 @@ def __init__(
243250
self._cumulative_node_complexity: Optional[Dict[PlanNodeCategory, int]] = None
244251
self._encoded_node_id_with_query: Optional[str] = None
245252

253+
@property
254+
def analyzer(self) -> "Analyzer":
255+
"""Get the analyzer for used for the current thread"""
256+
return self._analyzer or self._session._analyzer
257+
258+
@analyzer.setter
259+
def analyzer(self, value: "Analyzer") -> None:
260+
"""For query optimization stage, we need to replace the analyzer with a query generator which
261+
is aware of schema for the final plan and can compile WithQueryBlocks. Therefore we update the
262+
setter to allow the analyzer to be set externally."""
263+
if not self._is_valid_for_replacement:
264+
raise ValueError(
265+
"Cannot set analyzer for a Selectable that is not valid for replacement"
266+
)
267+
268+
self._analyzer = value
269+
246270
@property
247271
@abstractmethod
248272
def sql_query(self) -> str:
@@ -258,7 +282,7 @@ def encoded_node_id_with_query(self) -> str:
258282
two selectable node with same queries. This is currently used by repeated subquery
259283
elimination to detect two nodes with same query, please use it with careful.
260284
"""
261-
with self.analyzer.session._plan_lock:
285+
with self._session._plan_lock:
262286
if self._encoded_node_id_with_query is None:
263287
self._encoded_node_id_with_query = encode_node_id_with_query(self)
264288
return self._encoded_node_id_with_query
@@ -310,7 +334,7 @@ def get_snowflake_plan(self, skip_schema_query) -> SnowflakePlan:
310334
queries,
311335
schema_query,
312336
post_actions=self.post_actions,
313-
session=self.analyzer.session,
337+
session=self._session,
314338
expr_to_alias=self.expr_to_alias,
315339
df_aliased_col_name_to_real_col_name=self.df_aliased_col_name_to_real_col_name,
316340
source_plan=self,
@@ -328,7 +352,7 @@ def plan_state(self) -> Dict[PlanState, Any]:
328352

329353
@property
330354
def cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
331-
with self.analyzer.session._plan_lock:
355+
with self._session._plan_lock:
332356
if self._cumulative_node_complexity is None:
333357
self._cumulative_node_complexity = sum_node_complexities(
334358
self.individual_node_complexity,
@@ -361,7 +385,7 @@ def column_states(self) -> ColumnStateDict:
361385
Refer to class ColumnStateDict.
362386
"""
363387
if self._column_states is None:
364-
if self.analyzer.session.reduce_describe_query_enabled:
388+
if self._session.reduce_describe_query_enabled:
365389
# data types are not needed in SQL simplifier, so we
366390
# just create dummy data types here.
367391
column_attrs = [
@@ -512,7 +536,7 @@ def __init__(
512536
self.pre_actions[0].query_id_place_holder
513537
)
514538
self._schema_query = analyzer_utils.schema_value_statement(
515-
analyze_attributes(sql, self.analyzer.session)
539+
analyze_attributes(sql, self._session)
516540
) # Change to subqueryable schema query so downstream query plan can describe the SQL
517541
self._query_param = None
518542
else:
@@ -1165,7 +1189,7 @@ def filter(self, col: Expression) -> "SelectStatement":
11651189
new = SelectStatement(
11661190
from_=self.to_subqueryable(), where=col, analyzer=self.analyzer
11671191
)
1168-
if self.analyzer.session.reduce_describe_query_enabled:
1192+
if self._session.reduce_describe_query_enabled:
11691193
new._attributes = self._attributes
11701194

11711195
return new
@@ -1200,7 +1224,7 @@ def sort(self, cols: List[Expression]) -> "SelectStatement":
12001224
order_by=cols,
12011225
analyzer=self.analyzer,
12021226
)
1203-
if self.analyzer.session.reduce_describe_query_enabled:
1227+
if self._session.reduce_describe_query_enabled:
12041228
new._attributes = self._attributes
12051229

12061230
return new
@@ -1284,7 +1308,7 @@ def limit(self, n: int, *, offset: int = 0) -> "SelectStatement":
12841308
new.pre_actions = new.from_.pre_actions
12851309
new.post_actions = new.from_.post_actions
12861310
new._merge_projection_complexity_with_subquery = False
1287-
if self.analyzer.session.reduce_describe_query_enabled:
1311+
if self._session.reduce_describe_query_enabled:
12881312
new._attributes = self._attributes
12891313

12901314
return new
@@ -1604,7 +1628,7 @@ def can_select_projection_complexity_be_merged(
16041628
on top of subquery.
16051629
subquery: the subquery where the current select is performed on top of
16061630
"""
1607-
if not subquery.analyzer.session._large_query_breakdown_enabled:
1631+
if not subquery._session._large_query_breakdown_enabled:
16081632
return False
16091633

16101634
# only merge of nested select statement is supported, and subquery must be

src/snowflake/snowpark/mock/_nop_analyzer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def attributes(self):
102102
class NopSelectableEntity(MockSelectableEntity):
103103
@property
104104
def attributes(self):
105-
return resolve_attributes(self.entity_plan, session=self.analyzer.session)
105+
return resolve_attributes(self.entity_plan, session=self._session)
106106

107107

108108
class NopAnalyzer(MockAnalyzer):

src/snowflake/snowpark/mock/_plan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -986,15 +986,15 @@ def execute_mock_plan(
986986
res_df = execute_mock_plan(
987987
MockExecutionPlan(
988988
first_operand.selectable,
989-
source_plan.analyzer.session,
989+
source_plan._session,
990990
),
991991
expr_to_alias,
992992
)
993993
for i in range(1, len(source_plan.set_operands)):
994994
operand = source_plan.set_operands[i]
995995
operator = operand.operator
996996
cur_df = execute_mock_plan(
997-
MockExecutionPlan(operand.selectable, source_plan.analyzer.session),
997+
MockExecutionPlan(operand.selectable, source_plan._session),
998998
expr_to_alias,
999999
)
10001000
if len(res_df.columns) != len(cur_df.columns):

src/snowflake/snowpark/mock/_select_statement.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(
6464
analyzer: "Analyzer",
6565
) -> None:
6666
super().__init__()
67-
self.analyzer = analyzer
67+
self._session = analyzer.session
6868
self.pre_actions = None
6969
self.post_actions = None
7070
self.flatten_disabled: bool = False
@@ -76,6 +76,10 @@ def __init__(
7676
str, Dict[str, str]
7777
] = defaultdict(dict)
7878

79+
@property
80+
def analyzer(self) -> "Analyzer":
81+
return self._session._analyzer
82+
7983
@property
8084
def sql_query(self) -> str:
8185
"""Returns the sql query of this Selectable logical plan."""
@@ -97,7 +101,7 @@ def execution_plan(self):
97101
from snowflake.snowpark.mock._plan import MockExecutionPlan
98102

99103
if self._execution_plan is None:
100-
self._execution_plan = MockExecutionPlan(self, self.analyzer.session)
104+
self._execution_plan = MockExecutionPlan(self, self._session)
101105
return self._execution_plan
102106

103107
@property

tests/integ/compiler/test_query_generator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def reset_node(node: LogicalPlan, query_generator: QueryGenerator) -> None:
7070
def reset_selectable(selectable_node: Selectable) -> None:
7171
# reset the analyzer to use the current query generator instance to
7272
# ensure the new query generator is used during the resolve process
73+
selectable_node._is_valid_for_replacement = True
7374
selectable_node.analyzer = query_generator
7475
if not isinstance(selectable_node, (SelectSnowflakePlan, SelectSQL)):
7576
selectable_node._snowflake_plan = None

tests/integ/test_multithreading.py

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -660,11 +660,12 @@ def change_config_value(session_):
660660
session_.conf.set(config, value)
661661

662662
caplog.clear()
663-
change_config_value(threadsafe_session)
664-
assert (
665-
f"You might have more than one threads sharing the Session object trying to update {config}"
666-
not in caplog.text
667-
)
663+
if threading.active_count() == 1:
664+
change_config_value(threadsafe_session)
665+
assert (
666+
f"You might have more than one threads sharing the Session object trying to update {config}"
667+
not in caplog.text
668+
)
668669

669670
with caplog.at_level(logging.WARNING):
670671
with ThreadPoolExecutor(max_workers=5) as executor:
@@ -857,9 +858,13 @@ def process_data(df_, thread_id):
857858
).csv(f"{stage_with_prefix}/{filename}")
858859

859860
with threadsafe_session.query_history() as history:
861+
futures = []
860862
with ThreadPoolExecutor(max_workers=5) as executor:
861863
for i in range(10):
862-
executor.submit(process_data, df, i)
864+
futures.append(executor.submit(process_data, df, i))
865+
866+
for future in as_completed(futures):
867+
future.result()
863868

864869
queries_sent = [query.sql_text for query in history.queries]
865870

@@ -953,3 +958,58 @@ def call_critical_lazy_methods(df_):
953958
# called only once and the cached result should be used for the rest of
954959
# the calls.
955960
mock_find_duplicate_subtrees.assert_called_once()
961+
962+
963+
def create_and_join(_session):
964+
df1 = _session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
965+
df2 = _session.create_dataframe([[1, 7], [3, 8]], schema=["a", "b"])
966+
df3 = df1.join(df2)
967+
expected = [Row(1, 2, 1, 7), Row(1, 2, 3, 8), Row(3, 4, 1, 7), Row(3, 4, 3, 8)]
968+
Utils.check_answer(df3, expected)
969+
return [df1, df2, df3]
970+
971+
972+
def join_again(df1, df2, df3):
973+
df3 = df1.join(df2).select(df1.a)
974+
expected = [Row(1, 2, 1, 7), Row(1, 2, 3, 8), Row(3, 4, 1, 7), Row(3, 4, 3, 8)]
975+
Utils.check_answer(df3, expected)
976+
977+
978+
def create_aliased_df(_session):
979+
df1 = _session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
980+
df2 = df1.join(df1.filter(col("a") == 1)).select(df1.a.alias("a1"))
981+
Utils.check_answer(df2, [Row(A1=1), Row(A1=3)])
982+
return [df2]
983+
984+
985+
def select_aliased_col(df2):
986+
df2 = df2.select(df2.a1)
987+
Utils.check_answer(df2, [Row(A1=1), Row(A1=3)])
988+
989+
990+
@pytest.mark.xfail(
991+
"config.getoption('local_testing_mode', default=False)",
992+
reason="SNOW-1373887: Support basic diamond shaped joins in Local Testing",
993+
run=False,
994+
)
995+
@pytest.mark.parametrize(
996+
"f1,f2", [(create_and_join, join_again), (create_aliased_df, select_aliased_col)]
997+
)
998+
def test_SNOW_1878372(threadsafe_session, f1, f2):
999+
class ReturnableThread(threading.Thread):
1000+
def __init__(self, target, *args, **kwargs) -> None:
1001+
super().__init__(*args, **kwargs)
1002+
self._target = target
1003+
self.result = None
1004+
1005+
def run(self):
1006+
if self._target is not None:
1007+
self.result = self._target(*self._args, **self._kwargs)
1008+
1009+
t1 = ReturnableThread(target=f1, args=(threadsafe_session,))
1010+
t1.start()
1011+
t1.join()
1012+
1013+
t2 = ReturnableThread(target=f2, args=tuple(t1.result))
1014+
t2.start()
1015+
t2.join()

tests/unit/compiler/test_large_query_breakdown.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,14 @@
3939
from snowflake.snowpark._internal.compiler.large_query_breakdown import (
4040
LargeQueryBreakdown,
4141
)
42+
from snowflake.snowpark.session import Session
4243

44+
dummy_session = mock.create_autospec(Session)
45+
dummy_analyzer = mock.create_autospec(Analyzer)
46+
dummy_analyzer.session = dummy_session
4347
empty_logical_plan = LogicalPlan()
4448
empty_expression = Expression()
45-
empty_selectable = SelectSQL("dummy_query", analyzer=mock.create_autospec(Analyzer))
49+
empty_selectable = SelectSQL("dummy_query", analyzer=dummy_analyzer)
4650

4751

4852
@pytest.mark.parametrize(

tests/unit/compiler/test_replace_child_and_update_node.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,8 @@ def test_select_statement(
455455
new_replaced_plan = plan.children_plan_nodes[0]
456456
assert isinstance(new_replaced_plan, SelectSnowflakePlan)
457457
assert new_replaced_plan._snowflake_plan.source_plan == new_plan
458-
assert new_replaced_plan.analyzer == mock_query_generator
458+
# new_replaced_plan is created with QueryGenerator.to_selectable
459+
assert new_replaced_plan.analyzer == mock_analyzer
459460

460461
post_actions = [Query("drop table if exists table_name")]
461462
new_replaced_plan.post_actions = post_actions

0 commit comments

Comments
 (0)