|
6 | 6 | import functools |
7 | 7 | import inspect |
8 | 8 | import re |
| 9 | +import time |
9 | 10 | from contextlib import nullcontext |
10 | 11 | from enum import Enum, unique |
11 | 12 | from typing import Any, Callable, Optional, TypeVar, Union, cast |
|
15 | 16 | from modin.config import MetricsMode |
16 | 17 | from typing_extensions import ParamSpec |
17 | 18 |
|
| 19 | +from snowflake.snowpark.modin.config.envvars import ( |
| 20 | + SnowflakeModinTelemetryEnabled, |
| 21 | + SnowflakeModinTelemetryFlushInterval, |
| 22 | +) |
18 | 23 | import snowflake.snowpark.session |
19 | 24 | from snowflake.connector.telemetry import TelemetryField as PCTelemetryField |
20 | 25 | from snowflake.snowpark._internal.telemetry import TelemetryField, safe_telemetry |
@@ -62,6 +67,19 @@ class PropertyMethodType(Enum): |
62 | 67 | FDEL = "delete" |
63 | 68 |
|
64 | 69 |
|
| 70 | +class ModinTelemetrySender: |
| 71 | + """ |
| 72 | + Class designed to allow for easier testing of telemetry |
| 73 | + """ |
| 74 | + |
| 75 | + @classmethod |
| 76 | + def _send_telemetry(cls, session: Session, message: dict) -> None: |
| 77 | + """ |
| 78 | + Internal method to allow for easier testing |
| 79 | + """ |
| 80 | + return session._conn._telemetry_client.send(message) |
| 81 | + |
| 82 | + |
65 | 83 | @safe_telemetry |
66 | 84 | def _send_modin_api_telemetry( |
67 | 85 | session: Session, event: str, value: Union[int, float], aggregatable: bool |
@@ -94,7 +112,7 @@ def _send_modin_api_telemetry( |
94 | 112 | TelemetryField.KEY_DATA.value: data, |
95 | 113 | PCTelemetryField.KEY_SOURCE.value: "modin", |
96 | 114 | } |
97 | | - session._conn._telemetry_client.send(message) |
| 115 | + ModinTelemetrySender()._send_telemetry(session, message) |
98 | 116 |
|
99 | 117 |
|
100 | 118 | @safe_telemetry |
@@ -146,7 +164,7 @@ def _send_snowpark_pandas_telemetry_helper( |
146 | 164 | TelemetryField.KEY_DATA.value: data, |
147 | 165 | PCTelemetryField.KEY_SOURCE.value: "SnowparkPandas", |
148 | 166 | } |
149 | | - session._conn._telemetry_client.send(message) |
| 167 | + ModinTelemetrySender()._send_telemetry(session, message) |
150 | 168 |
|
151 | 169 |
|
152 | 170 | def _not_equal_to_default(arg_val: Any, default_val: Any) -> bool: |
@@ -644,28 +662,101 @@ def __new__( |
644 | 662 | return type.__new__(cls, name, bases, attrs) |
645 | 663 |
|
646 | 664 |
|
| 665 | +_modin_event_log: list = [[]] |
| 666 | +_last_modin_metric_flush: float = 0 |
| 667 | +_modin_metric_flush_interval = 0 |
| 668 | + |
| 669 | +MODIN_SWITCH_DECISION_METRIC_PREFIXES = ( |
| 670 | + "modin.hybrid.merge.decision", |
| 671 | + "modin.hybrid.auto.decision", |
| 672 | +) |
| 673 | +MODIN_PERFORMANCE_METRIC_PREFIXES = ("modin.query-compiler",) |
| 674 | + |
| 675 | + |
| 676 | +def _check_and_reset_metric_flush_time() -> bool: |
| 677 | + """ |
| 678 | + Return False if we still need to aggregate more metrics |
| 679 | + Return True if we should flush the metrics, and reset the clock |
| 680 | +
|
| 681 | + """ |
| 682 | + global _last_modin_metric_flush |
| 683 | + global _modin_metric_flush_interval |
| 684 | + |
| 685 | + # Support a changing flush interval |
| 686 | + current_flush_interval = SnowflakeModinTelemetryFlushInterval.get() |
| 687 | + current_time = time.time() |
| 688 | + if current_time > _last_modin_metric_flush + current_flush_interval: |
| 689 | + _last_modin_metric_flush = current_time |
| 690 | + return True |
| 691 | + |
| 692 | + return False |
| 693 | + |
| 694 | + |
| 695 | +def _flush_modin_metrics() -> None: |
| 696 | + """ |
| 697 | + Flush the collected modin metrics through the normal telemetry channel. |
| 698 | + Aggregate all metrics with the same name into simple statistics. Set |
| 699 | + the aggregatable field to True only for the count statistic. |
| 700 | +
|
| 701 | + This will output metrics of the form: |
| 702 | + modin.query-compiler.snowflakequerycompiler.value_counts.stat.mean |
| 703 | + modin.query-compiler.snowflakequerycompiler.value_counts.stat.median |
| 704 | + modin.query-compiler.snowflakequerycompiler.value_counts.stat.count |
| 705 | + modin.hybrid.auto.decision.Pandas.count |
| 706 | + modin.hybrid.auto.decision.Snowflake.mean |
| 707 | + ... |
| 708 | + """ |
| 709 | + global _modin_event_log |
| 710 | + try: |
| 711 | + summary_stat_names = ["count", "median", "mean"] |
| 712 | + processing_df = native_pd.DataFrame( |
| 713 | + _modin_event_log, columns=["metric", "value"] |
| 714 | + ) |
| 715 | + summary_stats = processing_df.groupby("metric").agg(summary_stat_names) |
| 716 | + session = snowflake.snowpark.session._get_active_session() |
| 717 | + for row in summary_stats.iterrows(): |
| 718 | + for stat in summary_stats: |
| 719 | + stat_specific_metric = f"{row[0]}.stat.{stat[1]}" |
| 720 | + |
| 721 | + _send_modin_api_telemetry( |
| 722 | + session=session, |
| 723 | + event=stat_specific_metric, |
| 724 | + value=row[1][stat], |
| 725 | + aggregatable=stat == ("value", "count"), |
| 726 | + ) |
| 727 | + except Exception: |
| 728 | + pass |
| 729 | + _modin_event_log = [] |
| 730 | + |
| 731 | + |
647 | 732 | def modin_telemetry_watcher(metric_name: str, metric_value: Union[int, float]) -> None: |
648 | 733 | """ |
649 | 734 | Telemetry hook that collects modin telemetry events of interest for |
650 | 735 | transmission to Snowflake. |
651 | 736 | """ |
652 | | - useful_metrics = ( |
653 | | - "modin.hybrid.merge.decision", |
654 | | - "modin.pandas-api", |
655 | | - "modin.query-compiler", |
656 | | - "modin.hybrid.auto.decision", |
657 | | - ) |
658 | | - if metric_name.startswith(useful_metrics): |
659 | | - try: |
660 | | - session = snowflake.snowpark.session._get_active_session() |
661 | | - _send_modin_api_telemetry( |
662 | | - session=session, |
663 | | - event=metric_name, |
664 | | - value=metric_value, |
665 | | - aggregatable=False, |
666 | | - ) |
667 | | - except Exception: |
668 | | - pass |
| 737 | + simplified_metric = metric_name |
| 738 | + |
| 739 | + metric_valid = False |
| 740 | + # ignore telemetry from dunder and internal metrics |
| 741 | + if metric_name.startswith(MODIN_PERFORMANCE_METRIC_PREFIXES): |
| 742 | + parts = metric_name.split(".") |
| 743 | + if parts[3].startswith("_"): |
| 744 | + return |
| 745 | + metric_valid = True |
| 746 | + |
| 747 | + if metric_name.startswith(MODIN_SWITCH_DECISION_METRIC_PREFIXES): |
| 748 | + # strip off the groups |
| 749 | + simplified_metric = ".".join(metric_name.split(".")[0:5]) |
| 750 | + metric_valid = True |
| 751 | + |
| 752 | + if not metric_valid: |
| 753 | + return |
| 754 | + |
| 755 | + _modin_event_log.append([simplified_metric, metric_value]) |
| 756 | + # We will lose telemetry at the tail end of the process, but |
| 757 | + # that's OK - this telemetry is meant to be lossy |
| 758 | + if _check_and_reset_metric_flush_time(): |
| 759 | + _flush_modin_metrics() |
669 | 760 |
|
670 | 761 |
|
671 | 762 | hybrid_switch_log = native_pd.DataFrame({}) |
@@ -736,5 +827,6 @@ def hybrid_describe_telemetry_watcher( |
736 | 827 |
|
737 | 828 | def connect_modin_telemetry() -> None: |
738 | 829 | MetricsMode.enable() |
739 | | - add_metric_handler(modin_telemetry_watcher) |
| 830 | + if SnowflakeModinTelemetryEnabled.get(): |
| 831 | + add_metric_handler(modin_telemetry_watcher) |
740 | 832 | add_metric_handler(hybrid_describe_telemetry_watcher) |
0 commit comments