-
Notifications
You must be signed in to change notification settings - Fork 144
Expand file tree
/
Copy pathplan_compiler.py
More file actions
238 lines (211 loc) · 11.5 KB
/
plan_compiler.py
File metadata and controls
238 lines (211 loc) · 11.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
import copy
import logging
import time
from typing import Any, Dict, List
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
get_complexity_score,
)
from snowflake.snowpark._internal.analyzer.snowflake_plan import (
PlanQueryType,
Query,
SnowflakePlan,
)
from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan
from snowflake.snowpark._internal.compiler.large_query_breakdown import (
LargeQueryBreakdown,
)
from snowflake.snowpark._internal.compiler.repeated_subquery_elimination import (
RepeatedSubqueryElimination,
)
from snowflake.snowpark._internal.compiler.telemetry_constants import (
CompilationStageTelemetryField,
)
from snowflake.snowpark._internal.compiler.utils import (
create_query_generator,
plot_plan_if_enabled,
)
from snowflake.snowpark._internal.telemetry import TelemetryField
from snowflake.snowpark._internal.utils import random_name_for_temp_object
from snowflake.snowpark.mock._connection import MockServerConnection
_logger = logging.getLogger(__name__)
class PlanCompiler:
"""
This class is responsible for compiling a SnowflakePlan to list of queries and post actions that
will be sent over to the server for execution.
The entry point function is compile(), which applies the following steps:
1) Run pre-check for the Snowflake plan, which mainly checks if optimizations can be applied.
2) Run pre-process step if optimization can be applied, which extracts and copies the set of
logical plans associated with the original plan to apply optimizations on.
3) Applies steps of optimizations. Each optimization takes a set of logical plan and produces a
new set of logical plans. Note that the optimizations will not maintain schema/attributes,
so none of the optimizations should rely on the schema/attributes.
4) Generate queries.
"""
def __init__(self, plan: SnowflakePlan) -> None:
self._plan = plan
def should_start_query_compilation(self) -> bool:
"""
Whether optimization should be applied to the plan or not.
Optimization can be applied if
1) there is source logical plan attached to the current snowflake plan
2) the query compilation stage is enabled
3) optimizations are enabled in the current session, such as cte_optimization_enabled
Returns
-------
True if optimization should be applied. Otherwise, return False.
"""
current_session = self._plan.session
return (
not isinstance(current_session._conn, MockServerConnection)
and (self._plan.source_plan is not None)
and current_session._query_compilation_stage_enabled
and (
current_session.cte_optimization_enabled
or current_session.large_query_breakdown_enabled
)
)
def compile(self) -> Dict[PlanQueryType, List[Query]]:
# initialize the queries with the original queries without optimization
final_plan = self._plan
queries = {
PlanQueryType.QUERIES: final_plan.queries,
PlanQueryType.POST_ACTIONS: final_plan.post_actions,
}
if self.should_start_query_compilation():
session = self._plan.session
try:
# preparation for compilation
# 1. make a copy of the original plan
start_time = time.time()
complexity_score_before_compilation = get_complexity_score(self._plan)
logical_plans: List[LogicalPlan] = [copy.deepcopy(self._plan)]
plot_plan_if_enabled(self._plan, "original_plan")
plot_plan_if_enabled(logical_plans[0], "deep_copied_plan")
deep_copy_end_time = time.time()
# 2. create a code generator with the original plan
query_generator = create_query_generator(self._plan)
extra_optimization_status: Dict[str, Any] = {}
# 3. apply each optimizations if needed
# CTE optimization
cte_start_time = time.time()
if session.cte_optimization_enabled:
repeated_subquery_eliminator = RepeatedSubqueryElimination(
logical_plans, query_generator
)
elimination_result = repeated_subquery_eliminator.apply()
logical_plans = elimination_result.logical_plans
# add the extra repeated subquery elimination status
extra_optimization_status[
CompilationStageTelemetryField.CTE_NODE_CREATED.value
] = elimination_result.total_num_of_ctes
cte_end_time = time.time()
complexity_scores_after_cte = [
get_complexity_score(logical_plan) for logical_plan in logical_plans
]
for i, plan in enumerate(logical_plans):
plot_plan_if_enabled(plan, f"cte_optimized_plan_{i}")
# Large query breakdown
breakdown_failure_summary, skipped_summary = {}, {}
if session.large_query_breakdown_enabled:
large_query_breakdown = LargeQueryBreakdown(
session,
query_generator,
logical_plans,
session.large_query_breakdown_complexity_bounds,
)
breakdown_result = large_query_breakdown.apply()
logical_plans = breakdown_result.logical_plans
breakdown_failure_summary = breakdown_result.breakdown_summary
skipped_summary = breakdown_result.skipped_summary
large_query_breakdown_end_time = time.time()
complexity_scores_after_large_query_breakdown = [
get_complexity_score(logical_plan) for logical_plan in logical_plans
]
for i, plan in enumerate(logical_plans):
plot_plan_if_enabled(plan, f"large_query_breakdown_plan_{i}")
# 4. do a final pass of code generation
queries = query_generator.generate_queries(logical_plans)
# log telemetry data
deep_copy_time = deep_copy_end_time - start_time
cte_time = cte_end_time - cte_start_time
large_query_breakdown_time = (
large_query_breakdown_end_time - cte_end_time
)
total_time = time.time() - start_time
summary_value = {
TelemetryField.CTE_OPTIMIZATION_ENABLED.value: session.cte_optimization_enabled,
TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: session.large_query_breakdown_enabled,
CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: session.large_query_breakdown_complexity_bounds,
CompilationStageTelemetryField.TIME_TAKEN_FOR_COMPILATION.value: total_time,
CompilationStageTelemetryField.TIME_TAKEN_FOR_DEEP_COPY_PLAN.value: deep_copy_time,
CompilationStageTelemetryField.TIME_TAKEN_FOR_CTE_OPTIMIZATION.value: cte_time,
CompilationStageTelemetryField.TIME_TAKEN_FOR_LARGE_QUERY_BREAKDOWN.value: large_query_breakdown_time,
CompilationStageTelemetryField.COMPLEXITY_SCORE_BEFORE_COMPILATION.value: complexity_score_before_compilation,
CompilationStageTelemetryField.COMPLEXITY_SCORE_AFTER_CTE_OPTIMIZATION.value: complexity_scores_after_cte,
CompilationStageTelemetryField.COMPLEXITY_SCORE_AFTER_LARGE_QUERY_BREAKDOWN.value: complexity_scores_after_large_query_breakdown,
CompilationStageTelemetryField.BREAKDOWN_FAILURE_SUMMARY.value: breakdown_failure_summary,
CompilationStageTelemetryField.TYPE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION_SKIPPED.value: skipped_summary,
}
# add the extra optimization status
summary_value.update(extra_optimization_status)
session._conn._telemetry_client.send_query_compilation_summary_telemetry(
session_id=session.session_id,
plan_uuid=self._plan.uuid,
compilation_stage_summary=summary_value,
)
except Exception as e:
# if any error occurs during the compilation, we should fall back to the original plan
_logger.debug(f"Skipping optimization due to error: {e}")
session._conn._telemetry_client.send_query_compilation_stage_failed_telemetry(
session_id=session.session_id,
plan_uuid=self._plan.uuid,
error_type=type(e).__name__,
error_message=str(e),
)
return self.replace_temp_obj_placeholders(queries)
def replace_temp_obj_placeholders(
self, queries: Dict[PlanQueryType, List[Query]]
) -> Dict[PlanQueryType, List[Query]]:
"""
When thread-safe session is enabled, we use temporary object name placeholders instead of a temporary name
when generating snowflake plan. We replace the temporary object name placeholders with actual temporary object
names here. This is done to prevent the following scenario:
1. A dataframe is created and resolved in main thread.
2. The resolve plan contains queries that create and drop temp objects.
3. If the plan with same temp object names is executed my multiple threads, the temp object names will conflict.
One thread can drop the object before another thread finished using it.
To prevent this, we generate queries with temp object name placeholders and replace them with actual temp object
here.
"""
session = self._plan.session
if session._conn._thread_safe_session_enabled:
# This dictionary will store the mapping between placeholder name and actual temp object name.
placeholders = {}
# Final execution queries
execution_queries = {}
for query_type, query_list in queries.items():
execution_queries[query_type] = []
for query in query_list:
# If the query contains a temp object name placeholder, we generate a random
# name for the temp object and add it to the placeholders dictionary.
if query.temp_obj_name_placeholder:
(
placeholder_name,
temp_obj_type,
) = query.temp_obj_name_placeholder
placeholders[placeholder_name] = random_name_for_temp_object(
temp_obj_type
)
copied_query = copy.copy(query)
for placeholder_name, target_temp_name in placeholders.items():
# Copy the original query and replace all the placeholder names with the
# actual temp object names.
copied_query.sql = copied_query.sql.replace(
placeholder_name, target_temp_name
)
execution_queries[query_type].append(copied_query)
return execution_queries
return queries