Commit 3b3bab1

SNOW-2691838: add telemetry tracking for snowpark datasource JDBC (#4002)
1 parent 23285ef commit 3b3bab1

File tree

3 files changed: +60 -2 lines changed

    src/snowflake/snowpark/_internal/data_source/utils.py
    src/snowflake/snowpark/dataframe_reader.py
    tests/integ/datasource/test_jdbc.py

src/snowflake/snowpark/_internal/data_source/utils.py

Lines changed: 18 additions & 0 deletions
@@ -4,6 +4,7 @@
 import math
 import os
 import queue
+import re
 import time
 import traceback
 import threading
@@ -46,10 +47,15 @@
 _MAX_RETRY_TIME = 3
 _MAX_WORKER_SCALE = 2  # 2 * max_workers
 STATEMENT_PARAMS_DATA_SOURCE = "SNOWPARK_PYTHON_DATASOURCE"
+STATEMENT_PARAMS_DATA_SOURCE_JDBC = "SNOWPARK_PYTHON_DATASOURCE_JDBC"
 DATA_SOURCE_DBAPI_SIGNATURE = "DataFrameReader.dbapi"
+DATA_SOURCE_JDBC_SIGNATURE = "DataFrameReader.jdbc"
 DATA_SOURCE_SQL_COMMENT = (
     f"/* Python:snowflake.snowpark.{DATA_SOURCE_DBAPI_SIGNATURE} */"
 )
+DATA_SOURCE_JDBC_SQL_COMMENT = (
+    f"/* Python:snowflake.snowpark.{DATA_SOURCE_JDBC_SIGNATURE} */"
+)
 PARTITION_TASK_COMPLETE_SIGNAL_PREFIX = "PARTITION_COMPLETE_"
 PARTITION_TASK_ERROR_SIGNAL = "ERROR"
@@ -109,6 +115,18 @@ class DRIVER_TYPE(str, Enum):
 }


+def get_jdbc_dbms(jdbc_url: str) -> str:
+    """
+    Extract the DBMS name from a JDBC connection URL.
+    """
+    if not jdbc_url.startswith("jdbc:"):
+        return "connection url does not start with jdbc"
+
+    # Extract the DBMS type (first component after "jdbc:")
+    match = re.match(r"^jdbc:([^:]+):", jdbc_url)
+    return match.group(1).lower() if match else "unrecognized DBMS"
+
+
 def detect_dbms(dbapi2_conn) -> Tuple[DBMS_TYPE, DRIVER_TYPE]:
     """Detects the DBMS type from a DBAPI2 connection."""

src/snowflake/snowpark/dataframe_reader.py

Lines changed: 27 additions & 2 deletions
@@ -39,6 +39,9 @@
     DATA_SOURCE_DBAPI_SIGNATURE,
     create_data_source_table_and_stage,
     local_ingestion,
+    STATEMENT_PARAMS_DATA_SOURCE_JDBC,
+    DATA_SOURCE_JDBC_SIGNATURE,
+    get_jdbc_dbms,
 )
 from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
 from snowflake.snowpark._internal.telemetry import set_api_call_source
@@ -1629,11 +1632,21 @@ def jdbc(
         java_version = udtf_configs.get("java_version", 17)
         secret = udtf_configs.get("secret", None)

+        telemetry_json_string = defaultdict()
+        telemetry_json_string["function_name"] = DATA_SOURCE_JDBC_SIGNATURE
+        telemetry_json_string["ingestion_mode"] = "udtf_ingestion"
+        telemetry_json_string["dbms_type"] = get_jdbc_dbms(url)
+        telemetry_json_string["imports"] = udtf_configs["imports"]
+        statements_params_for_telemetry = {STATEMENT_PARAMS_DATA_SOURCE_JDBC: "1"}
+
         if external_access_integration is None or imports is None or secret is None:
             raise ValueError(
                 "external_access_integration, secret and imports must be specified in udtf configs"
             )

+        start_time = time.perf_counter()
+        logger.debug(f"ingestion start at: {start_time}")
+
         if session_init_statement and isinstance(session_init_statement, str):
             session_init_statement = [session_init_statement]
@@ -1665,12 +1678,24 @@
         partitions_table = random_name_for_temp_object(TempObjectType.TABLE)
         self._session.create_dataframe(
             [[query] for query in partitions], schema=["partition"]
-        ).write.save_as_table(partitions_table, table_type="temp")
+        ).write.save_as_table(
+            partitions_table,
+            table_type="temp",
+            statement_params=statements_params_for_telemetry,
+        )

         df = jdbc_client.read(partitions_table)
-        return jdbc_client.to_result_snowpark_df(
+        res_df = jdbc_client.to_result_snowpark_df(
             df, jdbc_client.schema, _emit_ast=_emit_ast
         )
+        end_time = time.perf_counter()
+        telemetry_json_string["end_to_end_duration"] = end_time - start_time
+        telemetry_json_string["schema"] = res_df.schema.simple_string()
+        self._session._conn._telemetry_client.send_data_source_perf_telemetry(
+            telemetry_json_string
+        )
+        set_api_call_source(res_df, DATA_SOURCE_JDBC_SIGNATURE)
+        return res_df

     @private_preview(version="1.29.0")
     @publicapi
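
Conceptually, the jdbc() change follows a simple pattern: build a telemetry payload up front (including the DBMS type from get_jdbc_dbms(url)), tag server-side statements via statement_params, time the whole ingestion with time.perf_counter(), and report the payload once the result DataFrame exists. The sketch below shows that shape in isolation; do_read and record_telemetry are stand-ins invented for this note (the real code calls the internal client's send_data_source_perf_telemetry), so this is not the actual DataFrameReader implementation.

import time
from collections import defaultdict

DATA_SOURCE_JDBC_SIGNATURE = "DataFrameReader.jdbc"

def read_jdbc_with_telemetry(url, udtf_configs, do_read, record_telemetry):
    # defaultdict() with no factory behaves like a plain dict here, as in the diff.
    telemetry = defaultdict()
    telemetry["function_name"] = DATA_SOURCE_JDBC_SIGNATURE
    telemetry["ingestion_mode"] = "udtf_ingestion"
    telemetry["imports"] = udtf_configs["imports"]

    start = time.perf_counter()
    result = do_read(url)  # stand-in for the partitioned UDTF read
    telemetry["end_to_end_duration"] = time.perf_counter() - start

    record_telemetry(telemetry)  # stand-in for the telemetry client call
    return result

# Example wiring with trivial stand-ins (illustrative values only):
rows = read_jdbc_with_telemetry(
    "jdbc:oracle:thin:@//host:1521/service",
    {"imports": ["ojdbc.jar"]},
    do_read=lambda url: ["row1", "row2"],
    record_telemetry=print,
)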

tests/integ/datasource/test_jdbc.py

Lines changed: 15 additions & 0 deletions
@@ -2,6 +2,7 @@
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
 import re
+from unittest.mock import patch

 import pytest

@@ -395,6 +396,20 @@ def test_postgres_session_init_statement(
     ).collect()


+def test_telemetry(session, udtf_configs):
+    with patch(
+        "snowflake.snowpark._internal.telemetry.TelemetryClient.send_data_source_perf_telemetry"
+    ) as mock_telemetry:
+        df = session.read.jdbc(url=URL, udtf_configs=udtf_configs, query=SELECT_QUERY)
+        telemetry_json = mock_telemetry.call_args[0][0]
+        assert telemetry_json["function_name"] == "DataFrameReader.jdbc"
+        assert telemetry_json["ingestion_mode"] == "udtf_ingestion"
+        assert telemetry_json["dbms_type"] == "oracle"
+        assert "ojdbc17-23.9.0.25.07.jar" in telemetry_json["imports"][0]
+        assert telemetry_json["end_to_end_duration"] > 0
+        assert telemetry_json["schema"] == df.schema.simple_string()
+
+
 def test_connect_mysql(session, mysql_udtf_configs):
     df = session.read.jdbc(
         url=MYSQL_URL,
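
The new test works by patching the telemetry client method on its class and reading the captured positional argument back out of the mock. The toy example below (written for this note, with a made-up FakeTelemetryClient; it is not the Snowpark test) shows why mock_telemetry.call_args[0][0] yields the payload dict: the class attribute is replaced by an unbound mock, so the only positional argument recorded is the payload itself.

from unittest.mock import patch

class FakeTelemetryClient:
    def send_data_source_perf_telemetry(self, payload):
        pass  # a real client would emit the telemetry event

def ingest(client):
    # Code under test: it reports a payload through the client.
    client.send_data_source_perf_telemetry({"function_name": "DataFrameReader.jdbc"})

with patch.object(FakeTelemetryClient, "send_data_source_perf_telemetry") as mock_send:
    ingest(FakeTelemetryClient())
    payload = mock_send.call_args[0][0]  # first positional arg of the recorded call
    assert payload["function_name"] == "DataFrameReader.jdbc"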

0 commit comments
