Commit 642f7a8

SNOW-2398158: dbapi add usage telemetry to udtf (#3867)

1 parent bfe436f
6 files changed: +94 −26 lines

src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py

Lines changed: 3 additions & 1 deletion

@@ -3,7 +3,7 @@
 #
 from enum import Enum
 import datetime
-from typing import List, Callable, Any, Optional, TYPE_CHECKING
+from typing import Dict, List, Callable, Any, Optional, TYPE_CHECKING
 from snowflake.connector.options import pandas as pd

 from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted
@@ -152,6 +152,7 @@ def udtf_ingestion(
         packages: Optional[List[str]] = None,
         session_init_statement: Optional[List[str]] = None,
         query_timeout: Optional[int] = 0,
+        statement_params: Optional[Dict[str, str]] = None,
         _emit_ast: bool = True,
     ) -> "snowflake.snowpark.DataFrame":
         from snowflake.snowpark._internal.data_source.utils import UDTF_PACKAGE_MAP
@@ -175,6 +176,7 @@ def udtf_ingestion(
             external_access_integrations=[external_access_integrations],
             packages=packages or UDTF_PACKAGE_MAP.get(self.dbms_type),
             imports=imports,
+            statement_params=statement_params,
         )
         logger.debug(f"register ingestion udtf takes: {udtf_register_time()} seconds")
         call_udtf_sql = f"""
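
The new statement_params argument is pure plumbing: udtf_ingestion does not inspect the dict, it only forwards it to the UDTF registration call so the server records the telemetry flags alongside the registration queries. A minimal sketch of that pass-through pattern (register_udtf below is a stand-in stub, not the Snowpark registration API, and the flag name is illustrative):

from typing import Dict, Optional

def register_udtf(handler_name: str, statement_params: Optional[Dict[str, str]] = None) -> None:
    # Stand-in for the real registration call; just shows what would be sent.
    print(f"register {handler_name} with statement_params={statement_params}")

def udtf_ingestion_sketch(statement_params: Optional[Dict[str, str]] = None) -> None:
    # Like the patched BaseDriver.udtf_ingestion: forward the caller's dict untouched.
    register_udtf("ingestion_udtf", statement_params=statement_params)

udtf_ingestion_sketch({"SNOWPARK_DATA_SOURCE": "1"})  # assumed flag name, for illustration only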

src/snowflake/snowpark/_internal/data_source/utils.py

Lines changed: 18 additions & 0 deletions

@@ -549,3 +549,21 @@ def process_parquet_queue_with_threads(
     )

     return fetch_to_local_end_time, upload_to_sf_start_time, upload_to_sf_end_time
+
+
+def track_data_source_statement_params(
+    dataframe, statement_params: Optional[Dict] = None
+) -> Optional[Dict]:
+    """
+    Helper method to initialize and update data source tracking statement_params based on dataframe attributes.
+    """
+    statement_params = statement_params or {}
+    if (
+        dataframe._plan
+        and dataframe._plan.api_calls
+        and dataframe._plan.api_calls[0].get("name") == DATA_SOURCE_DBAPI_SIGNATURE
+    ):
+        # Track data source ingestion
+        statement_params[STATEMENT_PARAMS_DATA_SOURCE] = "1"
+
+    return statement_params if statement_params else None
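
To see what the helper does in isolation: it flags a DataFrame as dbapi-sourced only when the first recorded API call matches the dbapi signature, and returns None when there is nothing to send. A self-contained sketch with faked plan objects (FakePlan/FakeDataFrame are test doubles, not Snowpark classes, and the constant values are assumptions; the real constants live in data_source/utils.py):

from typing import Dict, Optional

DATA_SOURCE_DBAPI_SIGNATURE = "DataFrameReader.dbapi"  # assumed value
STATEMENT_PARAMS_DATA_SOURCE = "SNOWPARK_DATA_SOURCE"  # assumed value

def track_data_source_statement_params(
    dataframe, statement_params: Optional[Dict] = None
) -> Optional[Dict]:
    statement_params = statement_params or {}
    if (
        dataframe._plan
        and dataframe._plan.api_calls
        and dataframe._plan.api_calls[0].get("name") == DATA_SOURCE_DBAPI_SIGNATURE
    ):
        statement_params[STATEMENT_PARAMS_DATA_SOURCE] = "1"
    return statement_params if statement_params else None

class FakePlan:
    def __init__(self, api_calls):
        self.api_calls = api_calls

class FakeDataFrame:
    def __init__(self, api_calls):
        self._plan = FakePlan(api_calls)

# A dbapi-sourced frame gets the tracking flag; anything else yields None.
print(track_data_source_statement_params(FakeDataFrame([{"name": "DataFrameReader.dbapi"}])))
print(track_data_source_statement_params(FakeDataFrame([{"name": "Session.sql"}])))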

src/snowflake/snowpark/dataframe.py

Lines changed: 6 additions & 0 deletions

@@ -172,6 +172,9 @@
     string_half_width,
     warning,
 )
+from snowflake.snowpark._internal.data_source.utils import (
+    track_data_source_statement_params,
+)
 from snowflake.snowpark.async_job import AsyncJob, _AsyncResultType
 from snowflake.snowpark.column import Column, _to_col_if_sql_expr, _to_col_if_str
 from snowflake.snowpark.dataframe_ai_functions import DataFrameAIFunctions
@@ -836,6 +839,9 @@ def _internal_collect_with_tag_no_telemetry(
         # When executing a DataFrame in any method of snowpark (either public or private),
         # we should always call this method instead of collect(), to make sure the
         # query tag is set properly.
+        statement_params = track_data_source_statement_params(
+            self, statement_params or self._statement_params
+        )
         return self._session._conn.execute(
             self._plan,
             block=block,
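
Note the fallback order here: per-call statement_params take precedence over the DataFrame-level default before the tracking flag is merged in. A toy illustration of that `or` fallback (plain Python, no Snowpark involved):

def effective_params(call_level, frame_level):
    # Mirrors `statement_params or self._statement_params`: any truthy
    # call-level dict wins; None or {} falls back to the frame default.
    return call_level or frame_level

print(effective_params({"QUERY_TAG": "etl"}, {"QUERY_TAG": "default"}))  # call-level wins
print(effective_params(None, {"QUERY_TAG": "default"}))                  # frame default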

src/snowflake/snowpark/dataframe_reader.py

Lines changed: 7 additions & 2 deletions

@@ -1876,7 +1876,11 @@ def create_sqlite_connection(timeout=5.0, isolation_level=None, **kwargs):
         partitions_table = random_name_for_temp_object(TempObjectType.TABLE)
         self._session.create_dataframe(
             [[query] for query in partitioned_queries], schema=["partition"]
-        ).write.save_as_table(partitions_table, table_type="temp")
+        ).write.save_as_table(
+            partitions_table,
+            table_type="temp",
+            statement_params=statements_params_for_telemetry,
+        )
         df = partitioner.driver.udtf_ingestion(
             self._session,
             struct_schema,
@@ -1887,7 +1891,8 @@ def create_sqlite_connection(timeout=5.0, isolation_level=None, **kwargs):
             packages=udtf_configs.get("packages", None),
             session_init_statement=session_init_statement,
             query_timeout=query_timeout,
-            _emit_ast=_emit_ast,
+            statement_params=statements_params_for_telemetry,
+            _emit_ast=False,  # internal API, no need to emit AST
         )
         end_time = time.perf_counter()
         telemetry_json_string["end_to_end_duration"] = end_time - start_time
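
For reference, statement_params is an existing public argument on DataFrameWriter.save_as_table, so the temp-table write above is just threading the reader's telemetry dict through it. A hedged sketch of the same shape (save_partitions is an illustrative helper, assuming a live Session and a list of partition queries):

from typing import Dict, List

def save_partitions(
    session, partitioned_queries: List[str], table_name: str, telemetry: Dict[str, str]
) -> None:
    # Tag the temp-table write with the same statement params used for the
    # rest of the dbapi ingestion pipeline, as the patched reader does.
    session.create_dataframe(
        [[q] for q in partitioned_queries], schema=["partition"]
    ).write.save_as_table(table_name, table_type="temp", statement_params=telemetry)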

src/snowflake/snowpark/dataframe_writer.py

Lines changed: 3 additions & 22 deletions

@@ -27,8 +27,7 @@
     build_table_name,
 )
 from snowflake.snowpark._internal.data_source.utils import (
-    STATEMENT_PARAMS_DATA_SOURCE,
-    DATA_SOURCE_DBAPI_SIGNATURE,
+    track_data_source_statement_params,
 )
 from snowflake.snowpark._internal.open_telemetry import open_telemetry_context_manager
 from snowflake.snowpark._internal.telemetry import (
@@ -109,24 +108,6 @@ def __init__(
         self._ast = writer
         dataframe._set_ast_ref(self._ast.dataframe_writer.df)

-    @staticmethod
-    def _track_data_source_statement_params(
-        dataframe, statement_params: Optional[Dict] = None
-    ) -> Optional[Dict]:
-        """
-        Helper method to initialize and update data source tracking statement_params based on dataframe attributes.
-        """
-        statement_params = statement_params or {}
-        if (
-            dataframe._plan
-            and dataframe._plan.api_calls
-            and dataframe._plan.api_calls[0].get("name") == DATA_SOURCE_DBAPI_SIGNATURE
-        ):
-            # Track data source ingestion
-            statement_params[STATEMENT_PARAMS_DATA_SOURCE] = "1"
-
-        return statement_params if statement_params else None
-
     @publicapi
     def mode(self, save_mode: str, _emit_ast: bool = True) -> "DataFrameWriter":
         """Set the save mode of this :class:`DataFrameWriter`.
@@ -372,7 +353,7 @@ def save_as_table(
         >>> df.write.mode("overwrite").save_as_table("my_table", iceberg_config=iceberg_config) # doctest: +SKIP
         """

-        statement_params = self._track_data_source_statement_params(
+        statement_params = track_data_source_statement_params(
             self._dataframe, statement_params or self._dataframe._statement_params
         )
         if _emit_ast and self._ast is not None:
@@ -688,7 +669,7 @@ def _internal_copy_into_location(
         # This method is not intended to be used directly by users.
         # AST.
         kwargs = {}
-        statement_params = self._track_data_source_statement_params(
+        statement_params = track_data_source_statement_params(
             self._dataframe, statement_params or self._dataframe._statement_params
         )
         if _emit_ast and self._ast is not None:

tests/integ/test_data_source_api.py

Lines changed: 57 additions & 1 deletion

@@ -444,7 +444,7 @@ def test_telemetry_tracking(caplog, session, fetch_with_process):
     called, comment_showed = 0, 0

     def assert_datasource_statement_params_run_query(*args, **kwargs):
-        # assert we set statement_parameters to track datasourcee api usage
+        # assert we set statement_parameters to track datasource api usage
         nonlocal comment_showed
         statement_parameters = kwargs.get("_statement_params")
         query = args[0]
@@ -496,6 +496,62 @@ def assert_datasource_statement_params_run_query(*args, **kwargs):
     assert called == 2


+def test_telemetry_tracking_for_udtf(caplog, session):
+    original_func = session._conn.run_query
+    called = 0
+
+    def assert_datasource_statement_params_run_query(*args, **kwargs):
+        # assert we set statement_parameters to track datasource udtf api usage
+        query = args[0]
+        if not query.lower().startswith("put"):
+            assert kwargs.get("_statement_params")[STATEMENT_PARAMS_DATA_SOURCE] == "1"
+            nonlocal called
+            called += 1
+        return original_func(*args, **kwargs)
+
+    def create_connection():
+        class FakeConnection:
+            def cursor(self):
+                class FakeCursor:
+                    def execute(self, query):
+                        pass
+
+                    @property
+                    def description(self):
+                        return [("c1", int, None, None, None, None, None)]
+
+                    def fetchmany(self, *args, **kwargs):
+                        return None
+
+                return FakeCursor()
+
+        return FakeConnection()
+
+    with mock.patch(
+        "snowflake.snowpark._internal.server_connection.ServerConnection.run_query",
+        side_effect=assert_datasource_statement_params_run_query,
+    ):
+        df = session.read.dbapi(
+            create_connection,
+            table="Fake",
+            custom_schema="c1 INT",
+            udtf_configs={
+                "external_access_integration": ORACLEDB_TEST_EXTERNAL_ACCESS_INTEGRATION,
+                "packages": ["snowflake-snowpark-python"],
+            },
+        )
+        df.select("*").collect()
+        # called 7/8 times, coming from:
+        # 1. dbapi internal save of the empty table: 1 time
+        # 2. dbapi UDTF registration: 5/6 times depending on the python version
+        #    (operations entailing stage, package, registration); the delta is
+        #    because different python versions register differently, inline or
+        #    via upload to a stage
+        # 3. select: 1 time
+        assert (
+            "'name': 'DataFrameReader.dbapi'" in str(df._plan.api_calls[0])
+            and called >= 7
+        )
+
+
 @pytest.mark.skipif(
     IS_WINDOWS,
     reason="sqlite3 file can not be shared across processes on windows",
