Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/snowflake/snowpark/_internal/data_source/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import math
import os
import queue
import re
import time
import traceback
import threading
Expand Down Expand Up @@ -46,10 +47,15 @@
_MAX_RETRY_TIME = 3
_MAX_WORKER_SCALE = 2 # 2 * max_workers
STATEMENT_PARAMS_DATA_SOURCE = "SNOWPARK_PYTHON_DATASOURCE"
STATEMENT_PARAMS_DATA_SOURCE_JDBC = "SNOWPARK_PYTHON_DATASOURCE_JDBC"
DATA_SOURCE_DBAPI_SIGNATURE = "DataFrameReader.dbapi"
DATA_SOURCE_JDBC_SIGNATURE = "DataFrameReader.jdbc"
DATA_SOURCE_SQL_COMMENT = (
f"/* Python:snowflake.snowpark.{DATA_SOURCE_DBAPI_SIGNATURE} */"
)
DATA_SOURCE_JDBC_SQL_COMMENT = (
f"/* Python:snowflake.snowpark.{DATA_SOURCE_JDBC_SIGNATURE} */"
)
PARTITION_TASK_COMPLETE_SIGNAL_PREFIX = "PARTITION_COMPLETE_"
PARTITION_TASK_ERROR_SIGNAL = "ERROR"

Expand Down Expand Up @@ -109,6 +115,25 @@ class DRIVER_TYPE(str, Enum):
}


def get_jdbc_dbms(jdbc_url: str) -> "Optional[str]":
    """
    Extract the DBMS name from a JDBC connection URL.

    A JDBC URL has the form ``jdbc:<dbms>:...`` (e.g. ``jdbc:oracle:thin:@host``),
    so the DBMS name is the first colon-delimited component after the ``jdbc:``
    prefix. The name is lower-cased so telemetry values compare consistently.

    Args:
        jdbc_url: The JDBC connection URL supplied by the user.

    Returns:
        The lower-cased DBMS name, or ``None`` when the URL does not start
        with ``jdbc:`` or has no DBMS component after the prefix.
    """
    # Not a JDBC URL at all -- report "unknown" uniformly rather than
    # returning an error sentence that would be recorded as a DBMS name.
    if not jdbc_url.startswith("jdbc:"):
        return None

    # Capture the first component after "jdbc:", up to the next colon.
    match = re.match(r"^jdbc:([^:]+):", jdbc_url)
    return match.group(1).lower() if match else None


def get_jdbc_jar_file(imports: List[str]) -> "Optional[str]":
    """
    Extract the name of the JDBC driver jar used to establish the JDBC connection.

    The UDTF ``imports`` entries may be stage paths or local file paths
    (e.g. ``@stage/ojdbc17-23.9.0.25.07.jar``), so only the basename is
    inspected and returned.

    Args:
        imports: The ``imports`` list from the UDTF configuration.

    Returns:
        The file name of the first ``.jar`` entry, or ``None`` when the
        list contains no jar file.
    """
    for import_path in imports:
        file_name = os.path.basename(import_path)
        # Case-insensitive match so "FOO.JAR" is recognized too.
        if file_name.lower().endswith(".jar"):
            return file_name
    return None


def detect_dbms(dbapi2_conn) -> Tuple[DBMS_TYPE, DRIVER_TYPE]:
"""Detects the DBMS type from a DBAPI2 connection."""

Expand Down
29 changes: 27 additions & 2 deletions src/snowflake/snowpark/dataframe_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
DATA_SOURCE_DBAPI_SIGNATURE,
create_data_source_table_and_stage,
local_ingestion,
STATEMENT_PARAMS_DATA_SOURCE_JDBC,
DATA_SOURCE_JDBC_SIGNATURE,
get_jdbc_dbms,
)
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.telemetry import set_api_call_source
Expand Down Expand Up @@ -1629,11 +1632,21 @@ def jdbc(
java_version = udtf_configs.get("java_version", 17)
secret = udtf_configs.get("secret", None)

telemetry_json_string = defaultdict()
telemetry_json_string["function_name"] = DATA_SOURCE_JDBC_SIGNATURE
telemetry_json_string["ingestion_mode"] = "udtf_ingestion"
telemetry_json_string["dbms_type"] = get_jdbc_dbms(url)
telemetry_json_string["imports"] = udtf_configs["imports"]
statements_params_for_telemetry = {STATEMENT_PARAMS_DATA_SOURCE_JDBC: "1"}

if external_access_integration is None or imports is None or secret is None:
raise ValueError(
"external_access_integration, secret and imports must be specified in udtf configs"
)

start_time = time.perf_counter()
logger.debug(f"ingestion start at: {start_time}")

if session_init_statement and isinstance(session_init_statement, str):
session_init_statement = [session_init_statement]

Expand Down Expand Up @@ -1665,12 +1678,24 @@ def jdbc(
partitions_table = random_name_for_temp_object(TempObjectType.TABLE)
self._session.create_dataframe(
[[query] for query in partitions], schema=["partition"]
).write.save_as_table(partitions_table, table_type="temp")
).write.save_as_table(
partitions_table,
table_type="temp",
statement_params=statements_params_for_telemetry,
)

df = jdbc_client.read(partitions_table)
return jdbc_client.to_result_snowpark_df(
res_df = jdbc_client.to_result_snowpark_df(
df, jdbc_client.schema, _emit_ast=_emit_ast
)
end_time = time.perf_counter()
telemetry_json_string["end_to_end_duration"] = end_time - start_time
telemetry_json_string["schema"] = res_df.schema.simple_string()
self._session._conn._telemetry_client.send_data_source_perf_telemetry(
telemetry_json_string
)
set_api_call_source(res_df, DATA_SOURCE_JDBC_SIGNATURE)
return res_df

@private_preview(version="1.29.0")
@publicapi
Expand Down
15 changes: 15 additions & 0 deletions tests/integ/datasource/test_jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#
import re
from unittest.mock import patch

import pytest

Expand Down Expand Up @@ -340,3 +341,17 @@ def test_postgres_session_init_statement(
query_timeout=1,
session_init_statement=session_init_statement,
).collect()


def test_telemetry(session, udtf_configs):
    """Verify that a JDBC read emits one data-source performance telemetry record."""
    with patch(
        "snowflake.snowpark._internal.telemetry.TelemetryClient.send_data_source_perf_telemetry"
    ) as telemetry_mock:
        result_df = session.read.jdbc(
            url=URL, udtf_configs=udtf_configs, query=SELECT_QUERY
        )
        # The telemetry payload is the first positional argument of the call.
        (payload,) = telemetry_mock.call_args[0]
        expected = {
            "function_name": "DataFrameReader.jdbc",
            "ingestion_mode": "udtf_ingestion",
            "dbms_type": "oracle",
        }
        for key, value in expected.items():
            assert payload[key] == value
        assert "ojdbc17-23.9.0.25.07.jar" in payload["imports"][0]
        assert payload["end_to_end_duration"] > 0
        assert payload["schema"] == result_df.schema.simple_string()
Loading