Skip to content

Commit fd5333c

Browse files
SNOW-2230705: add support for UDF profiler (#3909)
1 parent 5ae5275 commit fd5333c

File tree

8 files changed

+371
-86
lines changed

8 files changed

+371
-86
lines changed

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# Release History
22

3+
## 1.42.0 (YYYY-MM-DD)
4+
5+
### Snowpark Python API Updates
6+
7+
#### New Features
8+
9+
- Added support for `Session.udf_profiler`.
10+
311
## 1.41.0 (YYYY-MM-DD)
412

513
### Snowpark Python API Updates

src/snowflake/snowpark/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
"QueryListener",
4141
"AsyncJob",
4242
"StoredProcedureProfiler",
43+
"UDFProfiler",
4344
]
4445

4546

@@ -54,6 +55,7 @@
5455
from snowflake.snowpark.async_job import AsyncJob
5556
from snowflake.snowpark.column import CaseExpr, Column
5657
from snowflake.snowpark.stored_procedure_profiler import StoredProcedureProfiler
58+
from snowflake.snowpark.udf_profiler import UDFProfiler
5759
from snowflake.snowpark.dataframe import DataFrame
5860
from snowflake.snowpark.dataframe_ai_functions import DataFrameAIFunctions
5961
from snowflake.snowpark.dataframe_analytics_functions import DataFrameAnalyticsFunctions
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
#
2+
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
3+
#
4+
5+
import logging
6+
import threading
7+
from typing import List, Literal, Optional
8+
9+
import snowflake.snowpark
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
class SnowparkProfiler:
    """
    Base class for stored procedure profiler and UDF profiler.

    Subclasses are expected to set ``_output_sql`` (a ``str.format`` template
    containing ``{query_id}``) and ``_profiler_module_name`` (the session
    parameter that holds the registered modules), and to override
    :meth:`_is_procedure_or_function_call`.
    """

    def __init__(
        self,
        session: "snowflake.snowpark.Session",
    ) -> None:
        self._session = session
        # Query listener created lazily by set_active_profiler(); None while idle.
        self._query_history = None
        # Guards the enable/disable bookkeeping below.
        self._lock = threading.RLock()
        # Number of set_active_profiler() calls not yet balanced by disable().
        self._active_profiler_number = 0
        # Not read in this class; presumably maintained by subclasses -- TODO confirm.
        self._has_target_stage = False
        self._is_enabled = False

        # Session parameter that selects the active profiler type.
        self._active_profiler_name = "ACTIVE_PYTHON_PROFILER"
        # Filled in by subclasses.
        self._output_sql = ""
        self._profiler_module_name = ""

    def register_modules(self, modules: Optional[List[str]] = None) -> None:
        """
        Register modules to generate profiles for them.

        Args:
            modules: List of names of modules (stored procedures or UDFs) to profile. Registered modules
                will be overwritten by this input. Input None or an empty list will remove registered modules.
        """
        module_string = ",".join(modules) if modules is not None else ""
        sql_statement = (
            f"alter session set {self._profiler_module_name}='{module_string}'"
        )
        self._session.sql(sql_statement)._internal_collect_with_tag_no_telemetry()

    def set_active_profiler(
        self, active_profiler_type: Literal["LINE", "MEMORY"] = "LINE"
    ) -> None:
        """
        Set active profiler.

        Args:
            active_profiler_type: String that represent active_profiler, must be either 'LINE' or 'MEMORY'
                (case-insensitive). Active profiler is 'LINE' by default.

        Raises:
            ValueError: If ``active_profiler_type`` is neither 'LINE' nor 'MEMORY'.
        """
        profiler_type = active_profiler_type.upper()
        if profiler_type not in ["LINE", "MEMORY"]:
            raise ValueError(
                f"active_profiler expect 'LINE', 'MEMORY', got {active_profiler_type} instead"
            )
        sql_statement = (
            f"alter session set {self._active_profiler_name} = '{profiler_type}'"
        )
        try:
            self._session.sql(sql_statement)._internal_collect_with_tag_no_telemetry()
        except Exception as e:
            # Best effort: a failure to set the parameter must not interrupt user code.
            logger.warning(
                f"Set active profiler failed because of {e}. Active profiler is previously set value or default 'LINE' now."
            )
        with self._lock:
            self._active_profiler_number += 1
            if self._query_history is None:
                self._query_history = self._session.query_history(
                    include_thread_id=True, include_error=True
                )
            self._is_enabled = True

    def disable(self) -> None:
        """
        Disable profiler.

        Calling this without a matching :meth:`set_active_profiler` call is safe:
        the active-profiler counter never drops below zero, so a later
        enable/disable pair still removes the query listener.
        """
        with self._lock:
            # Clamp at zero; an unbalanced disable() used to push the counter
            # negative, after which the "== 0" cleanup below was never reached
            # again and the query listener leaked.
            if self._active_profiler_number > 0:
                self._active_profiler_number -= 1
            if self._active_profiler_number == 0 and self._query_history is not None:
                self._session._conn.remove_query_listener(self._query_history)  # type: ignore
                self._query_history = None
            self._is_enabled = False
        sql_statement = f"alter session set {self._active_profiler_name} = ''"
        self._session.sql(sql_statement)._internal_collect_with_tag_no_telemetry()

    @staticmethod
    def _is_procedure_or_function_call(query: str) -> bool:
        """
        Return whether ``query`` is a call this profiler can report on.

        The base implementation matches nothing; subclasses override it.
        """
        return False

    def _get_last_query_id(self) -> Optional[str]:
        """
        Return the query id of the most recent matching query issued from the
        current thread, or None if the recorded history contains no such query.
        """
        current_thread = threading.get_ident()
        # Walk a reversed copy so the newest recorded query is examined first.
        for query in self._query_history.queries[::-1]:  # type: ignore
            query_thread = getattr(query, "thread_id", None)
            if query_thread == current_thread and self._is_procedure_or_function_call(
                query.sql_text
            ):
                return query.query_id
        return None

    def get_output(self) -> str:
        """
        Return the profiles of last executed stored procedure or UDF in current thread. If the profiler is
        disabled, or there is no previous stored procedure or UDF call, a warning is logged and an empty
        string is returned.

        Note:
            Please call this function right after the stored procedure or UDF you want to profile to avoid any error.

        """
        # return empty string when profiler is not enabled to not interrupt user's code
        if not self._is_enabled:
            logger.warning(
                "You are seeing this warning because you try to get profiler output while profiler is disabled. Please use profiler.set_active_profiler() to enable profiler."
            )
            return ""
        query_id = self._get_last_query_id()
        if query_id is None:
            logger.warning(
                "You are seeing this warning because last executed stored procedure or UDF does not exist. Please run the store procedure or UDF before get profiler output."
            )
            return ""
        sql = self._output_sql.format(query_id=query_id)
        return self._session.sql(sql)._internal_collect_with_tag_no_telemetry()[0][0]  # type: ignore

src/snowflake/snowpark/session.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
from snowflake.connector import ProgrammingError, SnowflakeConnection
4545
from snowflake.connector.options import installed_pandas, pandas, pyarrow
4646
from snowflake.connector.pandas_tools import write_pandas
47+
48+
from snowflake.snowpark import UDFProfiler
4749
from snowflake.snowpark._internal.analyzer import analyzer_utils
4850
from snowflake.snowpark._internal.analyzer.analyzer import Analyzer
4951
from snowflake.snowpark._internal.analyzer.analyzer_utils import (
@@ -807,6 +809,7 @@ def __init__(
807809
self._runtime_version_from_requirement: str = None
808810
self._temp_table_auto_cleaner: TempTableAutoCleaner = TempTableAutoCleaner(self)
809811
self._sp_profiler = StoredProcedureProfiler(session=self)
812+
self._udf_profiler = UDFProfiler(session=self)
810813
self._dataframe_profiler = DataframeProfiler(session=self)
811814
self._catalog = None
812815

@@ -4314,6 +4317,14 @@ def stored_procedure_profiler(self) -> StoredProcedureProfiler:
43144317
"""
43154318
return self._sp_profiler
43164319

4320+
@property
4321+
def udf_profiler(self) -> UDFProfiler:
4322+
"""
4323+
Returns a :class:`udf_profiler.UDFProfiler` object that you can use to profile UDFs.
4324+
See details of how to use this object in :class:`udf_profiler.UDFProfiler`.
4325+
"""
4326+
return self._udf_profiler
4327+
43174328
@property
43184329
def dataframe_profiler(self) -> DataframeProfiler:
43194330
"""

src/snowflake/snowpark/stored_procedure_profiler.py

Lines changed: 11 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,20 @@
22
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
33
#
44
import logging
5-
import threading
6-
from typing import List, Literal, Optional
5+
from typing import Literal
76

87
import snowflake.snowpark
8+
from snowflake.snowpark._internal.snowpark_profiler import SnowparkProfiler
99
from snowflake.snowpark._internal.utils import (
10-
SNOWFLAKE_ANONYMOUS_CALL_WITH_PATTERN,
1110
parse_table_name,
1211
strip_double_quotes_in_like_statement_in_table_name,
12+
SNOWFLAKE_ANONYMOUS_CALL_WITH_PATTERN,
1313
)
1414

1515
logger = logging.getLogger(__name__)
1616

1717

18-
class StoredProcedureProfiler:
18+
class StoredProcedureProfiler(SnowparkProfiler):
1919
"""
2020
Set up profiler to receive profiles of stored procedures. This feature cannot be used in owner's right stored
2121
procedure because owner's right stored procedure will not be able to set session-level parameters.
@@ -25,24 +25,11 @@ def __init__(
2525
self,
2626
session: "snowflake.snowpark.Session",
2727
) -> None:
28-
self._session = session
29-
self._query_history = None
30-
self._lock = threading.RLock()
31-
self._active_profiler_number = 0
32-
self._has_target_stage = False
33-
self._is_enabled = False
34-
35-
def register_modules(self, stored_procedures: Optional[List[str]] = None) -> None:
36-
"""
37-
Register stored procedures to generate profiles for them.
38-
39-
Args:
40-
stored_procedures: List of names of stored procedures. Registered modules will be overwritten by this input.
41-
Input None or an empty list will remove registered modules.
42-
"""
43-
sp_string = ",".join(stored_procedures) if stored_procedures is not None else ""
44-
sql_statement = f"alter session set python_profiler_modules='{sp_string}'"
45-
self._session.sql(sql_statement)._internal_collect_with_tag_no_telemetry()
28+
super().__init__(session)
29+
self._output_sql = (
30+
"select snowflake.core.get_python_profiler_output('{query_id}')"
31+
)
32+
self._profiler_module_name = "python_profiler_modules"
4633

4734
def set_target_stage(self, stage: str) -> None:
4835
"""
@@ -84,71 +71,11 @@ def set_active_profiler(
8471
logger.info(
8572
"Target stage for profiler not found, using default stage of current session."
8673
)
87-
if active_profiler_type.upper() not in ["LINE", "MEMORY"]:
88-
raise ValueError(
89-
f"active_profiler expect 'LINE', 'MEMORY', got {active_profiler_type} instead"
90-
)
91-
sql_statement = f"alter session set ACTIVE_PYTHON_PROFILER = '{active_profiler_type.upper()}'"
92-
try:
93-
self._session.sql(sql_statement)._internal_collect_with_tag_no_telemetry()
94-
except Exception as e:
95-
logger.warning(
96-
f"Set active profiler failed because of {e}. Active profiler is previously set value or default 'LINE' now."
97-
)
98-
with self._lock:
99-
self._active_profiler_number += 1
100-
if self._query_history is None:
101-
self._query_history = self._session.query_history(
102-
include_thread_id=True, include_error=True
103-
)
104-
self._is_enabled = True
105-
106-
def disable(self) -> None:
107-
"""
108-
Disable profiler.
109-
"""
110-
with self._lock:
111-
self._active_profiler_number -= 1
112-
if self._active_profiler_number == 0:
113-
self._session._conn.remove_query_listener(self._query_history) # type: ignore
114-
self._query_history = None
115-
self._is_enabled = False
116-
sql_statement = "alter session set ACTIVE_PYTHON_PROFILER = ''"
117-
self._session.sql(sql_statement)._internal_collect_with_tag_no_telemetry()
74+
super().set_active_profiler(active_profiler_type)
11875

11976
@staticmethod
120-
def _is_sp_call(query: str) -> bool:
77+
def _is_procedure_or_function_call(query: str) -> bool:
12178
query = query.upper().strip(" ")
12279
return SNOWFLAKE_ANONYMOUS_CALL_WITH_PATTERN.match(
12380
query
12481
) is not None or query.startswith("CALL")
125-
126-
def _get_last_query_id(self) -> Optional[str]:
127-
current_thread = threading.get_ident()
128-
for query in self._query_history.queries[::-1]: # type: ignore
129-
query_thread = getattr(query, "thread_id", None)
130-
if query_thread == current_thread and self._is_sp_call(query.sql_text):
131-
return query.query_id
132-
return None
133-
134-
def get_output(self) -> str:
135-
"""
136-
Return the profiles of last executed stored procedure in current thread. If there is no previous
137-
stored procedure call, an error will be raised.
138-
Please call this function right after the stored procedure you want to profile to avoid any error.
139-
140-
"""
141-
# return empty string when profiler is not enabled to not interrupt user's code
142-
if not self._is_enabled:
143-
logger.warning(
144-
"You are seeing this warning because you try to get profiler output while profiler is disabled. Please use profiler.set_active_profiler() to enable profiler."
145-
)
146-
return ""
147-
query_id = self._get_last_query_id()
148-
if query_id is None:
149-
logger.warning(
150-
"You are seeing this warning because last executed stored procedure does not exist. Please run the store procedure before get profiler output."
151-
)
152-
return ""
153-
sql = f"select snowflake.core.get_python_profiler_output('{query_id}')"
154-
return self._session.sql(sql)._internal_collect_with_tag_no_telemetry()[0][0] # type: ignore
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#
2+
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
3+
#
4+
5+
import logging
6+
import re
7+
8+
import snowflake.snowpark
9+
from snowflake.snowpark._internal.snowpark_profiler import SnowparkProfiler
10+
11+
12+
logger = logging.getLogger(__name__)
13+
14+
# Matches, case-insensitively, a statement that begins (after optional leading
# whitespace) with "WITH <identifier> AS FUNCTION".
# NOTE(review): judging by the name this is the anonymous-function call form -- confirm.
SNOWFLAKE_ANONYMOUS_FUNCTION_PATTERN = re.compile(
    r"^\s*WITH\s+\w+\s+AS\s+FUNCTION", re.IGNORECASE
)
17+
18+
19+
class UDFProfiler(SnowparkProfiler):
    """
    Set up profiler to receive profiles of UDFs.
    """

    def __init__(
        self,
        session: "snowflake.snowpark.Session",
    ) -> None:
        super().__init__(session)

        # Template used to fetch the profile; ``{query_id}`` is filled in with
        # str.format before the statement is run.
        self._output_sql = "select * from table(SNOWFLAKE.LOCAL.GET_PYTHON_UDF_PROFILER_OUTPUT('{query_id}'));"
        # Session parameter that holds the registered module names.
        self._profiler_module_name = "PYTHON_UDF_PROFILER_MODULES"

    @staticmethod
    def _is_procedure_or_function_call(query: str) -> bool:
        """
        Return True if ``query`` starts with ``SELECT`` or matches
        ``SNOWFLAKE_ANONYMOUS_FUNCTION_PATTERN`` (case-insensitive).
        """
        # Strip every kind of surrounding whitespace, not only spaces, so that a
        # statement starting with a newline or tab before SELECT is recognised,
        # matching the leading-whitespace tolerance of the regex branch.
        query = query.upper().strip()
        return (
            SNOWFLAKE_ANONYMOUS_FUNCTION_PATTERN.match(query) is not None
            or query.startswith("SELECT")
        )

tests/integ/test_stored_procedure_profiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def test_set_incorrect_active_profiler(
156156
)
157157
profiler_session.stored_procedure_profiler.set_active_profiler("LINE")
158158
profiler_session.stored_procedure_profiler.get_output()
159-
assert "last executed stored procedure does not exist" in caplog.text
159+
assert "last executed stored procedure or UDF does not exist" in caplog.text
160160

161161
with pytest.raises(ValueError) as e:
162162
profiler_session.stored_procedure_profiler.set_active_profiler(
@@ -188,7 +188,7 @@ def test_set_incorrect_active_profiler(
188188
def test_sp_call_match(profiler_session, sp_call_sql):
189189
pro = profiler_session.stored_procedure_profiler
190190

191-
assert pro._is_sp_call(sp_call_sql)
191+
assert pro._is_procedure_or_function_call(sp_call_sql)
192192

193193

194194
@pytest.mark.skipif(

0 commit comments

Comments
 (0)