
Commit 01b8f33

SNOW-1877318 combine telemetry usages (#2855)
1 parent 7b79ae0 commit 01b8f33

6 files changed: +94 -73 lines changed

src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py (44 additions, 26 deletions)

@@ -4,7 +4,7 @@

 import logging
 from collections import defaultdict
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 from snowflake.snowpark._internal.analyzer.analyzer_utils import (
     drop_table_if_exists_statement,
@@ -46,7 +46,7 @@
 from snowflake.snowpark._internal.compiler.query_generator import QueryGenerator
 from snowflake.snowpark._internal.compiler.telemetry_constants import (
     CompilationStageTelemetryField,
-    InvalidNodesInBreakdownCategory,
+    NodeBreakdownCategory,
     SkipLargeQueryBreakdownCategory,
 )
 from snowflake.snowpark._internal.compiler.utils import (
@@ -255,35 +255,56 @@ def _try_to_breakdown_plan(self, root: TreeNode) -> List[LogicalPlan]:
             return [root]

         plans = []
-        final_partition_breakdown_summary = {}
+        self._current_breakdown_summary: Dict[str, Any] = {
+            CompilationStageTelemetryField.NUM_PARTITIONS_MADE.value: 0,
+            CompilationStageTelemetryField.NUM_PIPELINE_BREAKER_USED.value: 0,
+            CompilationStageTelemetryField.NUM_RELAXED_BREAKER_USED.value: 0,
+        }
         while complexity_score > self.complexity_score_upper_bound:
             child, validity_statistics = self._find_node_to_breakdown(root)
+            self._update_current_breakdown_summary(validity_statistics)
+
             if child is None:
-                final_partition_breakdown_summary = {
-                    k.value: validity_statistics.get(k, 0)
-                    for k in InvalidNodesInBreakdownCategory
-                }
                 _logger.debug(
                     f"Could not find a valid node for partitioning. "
-                    f"Skipping with root {complexity_score=} {final_partition_breakdown_summary=}"
+                    f"Skipping with root {complexity_score=} {self._current_breakdown_summary=}"
                 )
                 break

             partition = self._get_partitioned_plan(root, child)
             plans.append(partition)
             complexity_score = get_complexity_score(root)

-        final_partition_breakdown_summary[
-            CompilationStageTelemetryField.NUM_PARTITIONS_MADE.value
-        ] = len(plans)
-        self._breakdown_summary.append(final_partition_breakdown_summary)
-
+        self._breakdown_summary.append(self._current_breakdown_summary)
         plans.append(root)
         return plans

+    def _update_current_breakdown_summary(
+        self, validity_statistics: Dict[NodeBreakdownCategory, int]
+    ) -> None:
+        """Method to update the breakdown summary based on the validity statistics of the current root."""
+        if validity_statistics.get(NodeBreakdownCategory.VALID_NODE, 0) > 0:
+            self._current_breakdown_summary[
+                CompilationStageTelemetryField.NUM_PARTITIONS_MADE.value
+            ] += 1
+            self._current_breakdown_summary[
+                CompilationStageTelemetryField.NUM_PIPELINE_BREAKER_USED.value
+            ] += 1
+        elif validity_statistics.get(NodeBreakdownCategory.VALID_NODE_RELAXED, 0) > 0:
+            self._current_breakdown_summary[
+                CompilationStageTelemetryField.NUM_PARTITIONS_MADE.value
+            ] += 1
+            self._current_breakdown_summary[
+                CompilationStageTelemetryField.NUM_RELAXED_BREAKER_USED.value
+            ] += 1
+        else:  # no valid nodes found
+            self._current_breakdown_summary[
+                CompilationStageTelemetryField.FAILED_PARTITION_SUMMARY.value
+            ] = {k.value: validity_statistics.get(k, 0) for k in NodeBreakdownCategory}
+
     def _find_node_to_breakdown(
         self, root: TreeNode
-    ) -> Tuple[Optional[TreeNode], Dict[InvalidNodesInBreakdownCategory, int]]:
+    ) -> Tuple[Optional[TreeNode], Dict[NodeBreakdownCategory, int]]:
         """This method traverses the plan tree and partitions the plan based if a valid partition node
         if found. The steps involved are:
@@ -307,7 +328,7 @@ def _find_node_to_breakdown(
                     validity_status, score = self._is_node_valid_to_breakdown(
                         child, root
                     )
-                    if validity_status == InvalidNodesInBreakdownCategory.VALID_NODE:
+                    if validity_status == NodeBreakdownCategory.VALID_NODE:
                         # If the score for valid node is higher than the last candidate,
                         # update the candidate node and score.
                         if score > candidate_score:
@@ -317,10 +338,7 @@ def _find_node_to_breakdown(
                         # don't traverse subtrees if parent is a valid candidate
                         next_level.append(child)

-                    if (
-                        validity_status
-                        == InvalidNodesInBreakdownCategory.VALID_NODE_RELAXED
-                    ):
+                    if validity_status == NodeBreakdownCategory.VALID_NODE_RELAXED:
                         # Update the relaxed candidate node and score.
                         if score > relaxed_candidate_score:
                             relaxed_candidate_score = score
@@ -370,7 +388,7 @@ def _get_partitioned_plan(self, root: TreeNode, child: TreeNode) -> SnowflakePla

     def _is_node_valid_to_breakdown(
         self, node: TreeNode, root: TreeNode
-    ) -> Tuple[InvalidNodesInBreakdownCategory, int]:
+    ) -> Tuple[NodeBreakdownCategory, int]:
         """Method to check if a node is valid to breakdown based on complexity score and node type.

         Returns:
@@ -381,29 +399,29 @@ def _is_node_valid_to_breakdown(
         """
         score = get_complexity_score(node)
         is_valid = True
-        validity_status = InvalidNodesInBreakdownCategory.VALID_NODE
+        validity_status = NodeBreakdownCategory.VALID_NODE

         # check score bounds
         if score < self.complexity_score_lower_bound:
             is_valid = False
-            validity_status = InvalidNodesInBreakdownCategory.SCORE_BELOW_LOWER_BOUND
+            validity_status = NodeBreakdownCategory.SCORE_BELOW_LOWER_BOUND

         if score > self.complexity_score_upper_bound:
             is_valid = False
-            validity_status = InvalidNodesInBreakdownCategory.SCORE_ABOVE_UPPER_BOUND
+            validity_status = NodeBreakdownCategory.SCORE_ABOVE_UPPER_BOUND

         # check pipeline breaker condition
         if is_valid and not self._is_node_pipeline_breaker(node):
             if self._is_relaxed_pipeline_breaker(node):
-                validity_status = InvalidNodesInBreakdownCategory.VALID_NODE_RELAXED
+                validity_status = NodeBreakdownCategory.VALID_NODE_RELAXED
             else:
                 is_valid = False
-                validity_status = InvalidNodesInBreakdownCategory.NON_PIPELINE_BREAKER
+                validity_status = NodeBreakdownCategory.NON_PIPELINE_BREAKER

         # check external CTE ref condition
         if is_valid and self._contains_external_cte_ref(node, root):
             is_valid = False
-            validity_status = InvalidNodesInBreakdownCategory.EXTERNAL_CTE_REF
+            validity_status = NodeBreakdownCategory.EXTERNAL_CTE_REF

         if is_valid:
             _logger.debug(
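For reference, the bookkeeping added above can be exercised in isolation. The sketch below is a standalone paraphrase of _update_current_breakdown_summary, not the class method itself: it uses the literal key strings in place of CompilationStageTelemetryField, and the value strings for VALID_NODE, VALID_NODE_RELAXED, and EXTERNAL_CTE_REF are taken from the test expectations later in this commit.

from enum import Enum
from typing import Any, Dict


class NodeBreakdownCategory(Enum):
    SCORE_BELOW_LOWER_BOUND = "num_nodes_below_lower_bound"
    SCORE_ABOVE_UPPER_BOUND = "num_nodes_above_upper_bound"
    NON_PIPELINE_BREAKER = "num_non_pipeline_breaker_nodes"
    EXTERNAL_CTE_REF = "num_external_cte_ref_nodes"
    VALID_NODE = "num_valid_nodes"
    VALID_NODE_RELAXED = "num_valid_nodes_relaxed"


def update_breakdown_summary(
    summary: Dict[str, Any], stats: Dict[NodeBreakdownCategory, int]
) -> None:
    """Fold one partition attempt's validity statistics into the running summary."""
    if stats.get(NodeBreakdownCategory.VALID_NODE, 0) > 0:
        # A strict pipeline breaker was found and used for this partition.
        summary["num_partitions_made"] += 1
        summary["num_pipeline_breaker_used"] += 1
    elif stats.get(NodeBreakdownCategory.VALID_NODE_RELAXED, 0) > 0:
        # Only a relaxed pipeline breaker was available.
        summary["num_partitions_made"] += 1
        summary["num_relaxed_breaker_used"] += 1
    else:
        # No usable node: record why every candidate node was rejected.
        summary["failed_partition_summary"] = {
            k.value: stats.get(k, 0) for k in NodeBreakdownCategory
        }


summary = {
    "num_partitions_made": 0,
    "num_pipeline_breaker_used": 0,
    "num_relaxed_breaker_used": 0,
}
update_breakdown_summary(summary, {NodeBreakdownCategory.VALID_NODE: 1})
update_breakdown_summary(summary, {NodeBreakdownCategory.SCORE_BELOW_LOWER_BOUND: 3})
assert summary["num_partitions_made"] == 1
assert summary["failed_partition_summary"]["num_nodes_below_lower_bound"] == 3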

src/snowflake/snowpark/_internal/compiler/plan_compiler.py (4 additions, 4 deletions)

@@ -125,7 +125,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:
                 plot_plan_if_enabled(plan, f"cte_optimized_plan_{i}")

             # Large query breakdown
-            breakdown_failure_summary, skipped_summary = {}, {}
+            breakdown_summary, skipped_summary = {}, {}
             if session.large_query_breakdown_enabled:
                 large_query_breakdown = LargeQueryBreakdown(
                     session,
@@ -135,7 +135,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:
                 )
                 breakdown_result = large_query_breakdown.apply()
                 logical_plans = breakdown_result.logical_plans
-                breakdown_failure_summary = breakdown_result.breakdown_summary
+                breakdown_summary = breakdown_result.breakdown_summary
                 skipped_summary = breakdown_result.skipped_summary

             large_query_breakdown_end_time = time.time()
@@ -166,8 +166,8 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:
                 CompilationStageTelemetryField.COMPLEXITY_SCORE_BEFORE_COMPILATION.value: complexity_score_before_compilation,
                 CompilationStageTelemetryField.COMPLEXITY_SCORE_AFTER_CTE_OPTIMIZATION.value: complexity_scores_after_cte,
                 CompilationStageTelemetryField.COMPLEXITY_SCORE_AFTER_LARGE_QUERY_BREAKDOWN.value: complexity_scores_after_large_query_breakdown,
-                CompilationStageTelemetryField.BREAKDOWN_FAILURE_SUMMARY.value: breakdown_failure_summary,
-                CompilationStageTelemetryField.TYPE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION_SKIPPED.value: skipped_summary,
+                CompilationStageTelemetryField.BREAKDOWN_SUMMARY.value: breakdown_summary,
+                CompilationStageTelemetryField.LARGE_QUERY_BREAKDOWN_OPTIMIZATION_SKIPPED.value: skipped_summary,
             }
             # add the extra optimization status
             summary_value.update(extra_optimization_status)
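The renamed keys feed into the compilation stage summary sent with the telemetry event. A rough sketch of the resulting payload with a single successful partition; the key strings come from CompilationStageTelemetryField, while the numeric scores are illustrative only:

summary_value = {
    "complexity_score_before_compilation": 120000,
    "complexity_score_after_cte_optimization": [95000],
    "complexity_score_after_large_query_breakdown": [60000, 40000],
    # previously reported as "breakdown_failure_summary"
    "breakdown_summary": [
        {
            "num_partitions_made": 1,
            "num_pipeline_breaker_used": 1,
            "num_relaxed_breaker_used": 0,
        }
    ],
    # previously reported under "snowpark_large_query_breakdown_optimization_skipped"
    "query_breakdown_optimization_skipped_reason": {},
}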

src/snowflake/snowpark/_internal/compiler/telemetry_constants.py (8 additions, 6 deletions)

@@ -19,11 +19,7 @@ class CompilationStageTelemetryField(Enum):
     QUERY_PLAN_COMPLEXITY = "query_plan_complexity"

     # types
-    TYPE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION_SKIPPED = (
-        "snowpark_large_query_breakdown_optimization_skipped"
-    )
     TYPE_COMPILATION_STAGE_STATISTICS = "snowpark_compilation_stage_statistics"
-    TYPE_COMPILATION_STAGE_FAILED = "snowpark_compilation_stage_failed"
     TYPE_LARGE_QUERY_BREAKDOWN_UPDATE_COMPLEXITY_BOUNDS = (
         "snowpark_large_query_breakdown_update_complexity_bounds"
     )
@@ -37,22 +33,28 @@
     TIME_TAKEN_FOR_DEEP_COPY_PLAN = "time_taken_for_deep_copy_plan_sec"
     TIME_TAKEN_FOR_CTE_OPTIMIZATION = "time_taken_for_cte_optimization_sec"
     TIME_TAKEN_FOR_LARGE_QUERY_BREAKDOWN = "time_taken_for_large_query_breakdown_sec"
+    LARGE_QUERY_BREAKDOWN_OPTIMIZATION_SKIPPED = (
+        "query_breakdown_optimization_skipped_reason"
+    )

     # keys for repeated subquery elimination
     CTE_NODE_CREATED = "cte_node_created"

     # keys for large query breakdown
-    BREAKDOWN_FAILURE_SUMMARY = "breakdown_failure_summary"
+    BREAKDOWN_SUMMARY = "breakdown_summary"
     COMPLEXITY_SCORE_AFTER_CTE_OPTIMIZATION = "complexity_score_after_cte_optimization"
     COMPLEXITY_SCORE_AFTER_LARGE_QUERY_BREAKDOWN = (
         "complexity_score_after_large_query_breakdown"
     )
     COMPLEXITY_SCORE_BEFORE_COMPILATION = "complexity_score_before_compilation"
     COMPLEXITY_SCORE_BOUNDS = "complexity_score_bounds"
     NUM_PARTITIONS_MADE = "num_partitions_made"
+    NUM_PIPELINE_BREAKER_USED = "num_pipeline_breaker_used"
+    NUM_RELAXED_BREAKER_USED = "num_relaxed_breaker_used"
+    FAILED_PARTITION_SUMMARY = "failed_partition_summary"


-class InvalidNodesInBreakdownCategory(Enum):
+class NodeBreakdownCategory(Enum):
     SCORE_BELOW_LOWER_BOUND = "num_nodes_below_lower_bound"
     SCORE_ABOVE_UPPER_BOUND = "num_nodes_above_upper_bound"
     NON_PIPELINE_BREAKER = "num_non_pipeline_breaker_nodes"
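Put together, one entry of the breakdown_summary list now nests the per-category rejection counts under failed_partition_summary when no valid node is found. A sketch of such an entry, using the key strings above with counts borrowed from the test expectations later in this commit:

failed_entry = {
    "num_partitions_made": 0,
    "num_pipeline_breaker_used": 0,
    "num_relaxed_breaker_used": 0,
    "failed_partition_summary": {
        "num_nodes_below_lower_bound": 28,
        "num_nodes_above_upper_bound": 1,
        "num_non_pipeline_breaker_nodes": 0,
        "num_external_cte_ref_nodes": 6,
        "num_valid_nodes": 0,
        "num_valid_nodes_relaxed": 0,
    },
}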

src/snowflake/snowpark/_internal/telemetry.py (2 additions, 5 deletions)

@@ -87,9 +87,6 @@ class TelemetryField(Enum):
     NUM_TEMP_TABLES_CLEANED = "num_temp_tables_cleaned"
     NUM_TEMP_TABLES_CREATED = "num_temp_tables_created"
     TEMP_TABLE_CLEANER_ENABLED = "temp_table_cleaner_enabled"
-    TYPE_TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION = (
-        "snowpark_temp_table_cleanup_abnormal_exception"
-    )
     TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_TABLE_NAME = (
         "temp_table_cleanup_abnormal_exception_table_name"
     )
@@ -487,7 +484,7 @@ def send_query_compilation_stage_failed_telemetry(
     ) -> None:
         message = {
             **self._create_basic_telemetry_data(
-                CompilationStageTelemetryField.TYPE_COMPILATION_STAGE_FAILED.value
+                CompilationStageTelemetryField.TYPE_COMPILATION_STAGE_STATISTICS.value
             ),
             TelemetryField.KEY_DATA.value: {
                 TelemetryField.SESSION_ID.value: session_id,
@@ -526,7 +523,7 @@ def send_temp_table_cleanup_abnormal_exception_telemetry(
     ) -> None:
         message = {
             **self._create_basic_telemetry_data(
-                TelemetryField.TYPE_TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION.value
+                TelemetryField.TYPE_TEMP_TABLE_CLEANUP.value
            ),
             TelemetryField.KEY_DATA.value: {
                 TelemetryField.SESSION_ID.value: session_id,
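For downstream consumers, the effect of this file's change is that compilation-stage failures no longer arrive under a dedicated event type. A hedged illustration; the message dicts and the "succeeded" field are made up for the example, and only the event-type string is taken from this diff:

compilation_stage_event_type = "snowpark_compilation_stage_statistics"
messages = [
    {"type": compilation_stage_event_type, "data": {"succeeded": True}},
    {"type": compilation_stage_event_type, "data": {"succeeded": False}},
]
# Filtering on the single event type now captures both outcomes.
compilation_events = [m for m in messages if m["type"] == compilation_stage_event_type]
assert len(compilation_events) == 2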

tests/integ/test_large_query_breakdown.py (35 additions, 31 deletions)

@@ -96,7 +96,14 @@ def check_result_with_and_without_breakdown(session, df):
 def check_summary_breakdown_value(patch_send, expected_summary):
     _, kwargs = patch_send.call_args
     summary_value = kwargs["compilation_stage_summary"]
-    assert summary_value["breakdown_failure_summary"] == expected_summary
+    assert summary_value["breakdown_summary"] == expected_summary
+
+
+def check_optimization_skipped_reason(patch_send, expected_reason):
+    summary_value = patch_send.call_args[1]["compilation_stage_summary"]
+    assert (
+        summary_value["query_breakdown_optimization_skipped_reason"] == expected_reason
+    )


 def test_no_pipeline_breaker_nodes(session):
@@ -134,6 +141,8 @@ def test_no_pipeline_breaker_nodes(session):
     expected_summary = [
         {
             "num_partitions_made": 1,
+            "num_pipeline_breaker_used": 0,
+            "num_relaxed_breaker_used": 1,
         }
     ]
     check_summary_breakdown_value(patch_send, expected_summary)
@@ -174,13 +183,17 @@ def test_large_query_breakdown_external_cte_ref(session):
     patch_send.assert_called_once()
     expected_summary = [
         {
-            "num_external_cte_ref_nodes": 6 if sql_simplifier_enabled else 2,
-            "num_non_pipeline_breaker_nodes": 0 if sql_simplifier_enabled else 2,
-            "num_nodes_below_lower_bound": 28,
-            "num_nodes_above_upper_bound": 1 if sql_simplifier_enabled else 0,
-            "num_valid_nodes": 0,
-            "num_valid_nodes_relaxed": 0,
+            "failed_partition_summary": {
+                "num_external_cte_ref_nodes": 6 if sql_simplifier_enabled else 2,
+                "num_non_pipeline_breaker_nodes": 0 if sql_simplifier_enabled else 2,
+                "num_nodes_below_lower_bound": 28,
+                "num_nodes_above_upper_bound": 1 if sql_simplifier_enabled else 0,
+                "num_valid_nodes": 0,
+                "num_valid_nodes_relaxed": 0,
+            },
             "num_partitions_made": 0,
+            "num_pipeline_breaker_used": 0,
+            "num_relaxed_breaker_used": 0,
         }
     ]
     check_summary_breakdown_value(patch_send, expected_summary)
@@ -213,14 +226,12 @@ def test_breakdown_at_with_query_node(session):

 def test_large_query_breakdown_with_cte_optimization(session):
     """Test large query breakdown works with cte optimized plan"""
-    if not session.cte_optimization_enabled:
-        pytest.skip("CTE optimization is not enabled")
+    session._cte_optimization_enabled = True

     if not session.sql_simplifier_enabled:
         # the complexity bounds are updated since nested selected calculation is not supported
         # when sql simplifier disabled
         set_bounds(session, 60, 90)
-        session._cte_optimization_enabled = True
     df0 = session.sql("select 2 as b, 32 as c")
     df1 = session.sql("select 1 as a, 2 as b").filter(col("a") == 1)
     df1 = df1.join(df0, on=["b"], how="inner")
@@ -231,7 +242,7 @@ def test_large_query_breakdown_with_cte_optimization(session):
         df2 = df2.with_column("a", col("a") + i + col("a"))
         df3 = df3.with_column("b", col("b") + i + col("b"))

-    df2 = df2.group_by("a").agg(sum_distinct(col("b")).alias("b"))
+    df2 = df2.select("b", "a")
    df3 = df3.group_by("b").agg(sum_distinct(col("a")).alias("a"))

     df4 = df2.union_all(df3).filter(col("a") > 2).with_column("a", col("a") + 1)
@@ -256,14 +267,15 @@ def test_large_query_breakdown_with_cte_optimization(session):
     assert len(queries["post_actions"]) == 1
     assert queries["post_actions"][0].startswith("DROP TABLE If EXISTS")

-    patch_send.assert_called_once()
-    _, kwargs = patch_send.call_args
-    summary_value = kwargs["compilation_stage_summary"]
-    assert summary_value["breakdown_failure_summary"] == [
+    expected_summary = [
         {
             "num_partitions_made": 1,
+            "num_pipeline_breaker_used": 1,
+            "num_relaxed_breaker_used": 0,
         }
     ]
+    check_summary_breakdown_value(patch_send, expected_summary)
+    patch_send.assert_called_once()


 def test_save_as_table(session, large_query_df):
@@ -547,10 +559,7 @@ def test_optimization_skipped_with_transaction(session, large_query_df, caplog):
     ) as patch_send:
         large_query_df.collect()

-        summary_value = patch_send.call_args[1]["compilation_stage_summary"]
-        assert summary_value["snowpark_large_query_breakdown_optimization_skipped"] == {
-            "active transaction": 1,
-        }
+        check_optimization_skipped_reason(patch_send, {"active transaction": 1})

     assert len(history.queries) == 2, history.queries
     assert history.queries[0].sql_text == "SELECT CURRENT_TRANSACTION()"
@@ -582,10 +591,9 @@ def test_optimization_skipped_with_views_and_dynamic_tables(session, caplog):
                 "Skipping large query breakdown optimization for view/dynamic table plan"
                 in caplog.text
             )
-            summary_value = patch_send.call_args[1]["compilation_stage_summary"]
-            assert summary_value["snowpark_large_query_breakdown_optimization_skipped"] == {
-                "view or dynamic table command": 1,
-            }
+            check_optimization_skipped_reason(
+                patch_send, {"view or dynamic table command": 1}
+            )

         with caplog.at_level(logging.DEBUG):
             with patch.object(
@@ -598,10 +606,9 @@ def test_optimization_skipped_with_views_and_dynamic_tables(session, caplog):
                 in caplog.text
             )
             patch_send.assert_called_once()
-            summary_value = patch_send.call_args[1]["compilation_stage_summary"]
-            assert summary_value["snowpark_large_query_breakdown_optimization_skipped"] == {
-                "view or dynamic table command": 1,
-            }
+            check_optimization_skipped_reason(
+                patch_send, {"view or dynamic table command": 1}
+            )
     finally:
         Utils.drop_dynamic_table(session, table_name)
         Utils.drop_view(session, view_name)
@@ -656,10 +663,7 @@ def test_optimization_skipped_with_no_active_db_or_schema(
                 in caplog.text
             )
             patch_send.assert_called_once()
-            summary_value = patch_send.call_args[1]["compilation_stage_summary"]
-            assert summary_value["snowpark_large_query_breakdown_optimization_skipped"] == {
-                f"no active {db_or_schema}": 1,
-            }
+            check_optimization_skipped_reason(patch_send, {f"no active {db_or_schema}": 1})


 def test_async_job_with_large_query_breakdown(large_query_df):
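A minimal, self-contained illustration of what the two assertion helpers at the top of this test file inspect, assuming a patch_send mock that captured a call carrying the compilation_stage_summary keyword; everything except the dict keys shown in this diff is illustrative:

from unittest.mock import MagicMock

patch_send = MagicMock()
patch_send(
    compilation_stage_summary={
        "breakdown_summary": [
            {
                "num_partitions_made": 1,
                "num_pipeline_breaker_used": 1,
                "num_relaxed_breaker_used": 0,
            }
        ],
        "query_breakdown_optimization_skipped_reason": {},
    }
)

# Mirrors check_summary_breakdown_value / check_optimization_skipped_reason.
summary_value = patch_send.call_args[1]["compilation_stage_summary"]
assert summary_value["breakdown_summary"][0]["num_partitions_made"] == 1
assert summary_value["query_breakdown_optimization_skipped_reason"] == {}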
