Commit bd46c66

SNOW-1961756: Enable AST capture from Session.read.dbapi (#3134)
1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR.

   [SNOW-1961756](https://snowflakecomputing.atlassian.net/browse/SNOW-1961756)

2. Fill out the following pre-review checklist:

   - [ ] I am adding new automated test(s) to verify correctness of my new code
   - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
   - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes.
   - [x] I acknowledge that I have ensured my changes to be thread-safe. Follow the link for more information: [Thread-safe Developer Guidelines](https://github.com/snowflakedb/snowpark-python/blob/main/CONTRIBUTING.md#thread-safe-development)

3. Please describe how your code solves the related issue.

   Propagate `_emit_ast=False` across all internal calls to public APIs in the new `dbapi` functionality. Importantly, we capture uses of `session.read.dbapi` as calls to `session.table`, since we only support this functionality on the client side. Capturing the call itself would mean creating an AST entity for `session.read.dbapi` just to have it show up in query history, which does not seem valuable from an AST or server-side execution perspective.

[SNOW-1961756]: https://snowflakecomputing.atlassian.net/browse/SNOW-1961756?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ
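To make the capture decision concrete, here is a hedged sketch of the resulting behavior, not an excerpt from this PR. The `connection_parameters` values and the `table=` keyword are assumptions for illustration (the full `dbapi` signature is elided in the diff below), and the temporary table name is generated at runtime.

```python
# Hedged usage sketch; connection_parameters values and the table= keyword
# are placeholders/assumptions, not part of this PR.
import sqlite3

from snowflake.snowpark import Session

connection_parameters = {"account": "...", "user": "...", "password": "..."}
session = Session.builder.configs(connection_parameters).create()


def create_connection():
    # Any DBAPI-compliant connection factory works; sqlite3 is only an example.
    return sqlite3.connect("external.db")


df = session.read.dbapi(create_connection, table="EVENTS")

# Client-side, dbapi ingests the external rows into a generated temporary
# Snowflake table. The captured AST then records the call as if the user had
# written session.read.table("<generated temp table name>"), so no
# dbapi-specific AST entity is needed.
```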
1 parent 8398094 commit bd46c66

2 files changed: +23 −16 lines changed


src/snowflake/snowpark/dataframe_reader.py

Lines changed: 17 additions & 9 deletions
```diff
@@ -458,7 +458,7 @@ def table(self, name: Union[str, Iterable[str]], _emit_ast: bool = True) -> Table:
             ast.reader.CopyFrom(self._ast)
             build_table_name(ast.name, name)
 
-        table = self._session.table(name)
+        table = self._session.table(name, _emit_ast=False)
 
         if _emit_ast:
             table._ast_id = stmt.var_id.bitfield1
```
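This hunk is the pattern the whole PR applies: the outermost public API records exactly one AST statement, then delegates inward with `_emit_ast=False` so nested public calls stay silent. A minimal, self-contained sketch of that pattern, using hypothetical toy classes that only mimic the shape of the Snowpark internals:

```python
# Toy model of "emit once at the outermost public call" (hypothetical classes,
# not the real Snowpark internals).
class AstBatch:
    def __init__(self):
        self.statements = []

    def record(self, description: str) -> None:
        self.statements.append(description)


class ToySession:
    def __init__(self):
        self._ast_batch = AstBatch()

    def table(self, name: str, _emit_ast: bool = True) -> str:
        if _emit_ast:
            self._ast_batch.record(f"session.table({name!r})")
        return f"<Table {name}>"


class ToyReader:
    def __init__(self, session: ToySession):
        self._session = session

    def table(self, name: str, _emit_ast: bool = True) -> str:
        if _emit_ast:
            self._session._ast_batch.record(f"session.read.table({name!r})")
        # Delegate with _emit_ast=False: the statement above is the only capture.
        return self._session.table(name, _emit_ast=False)


session = ToySession()
ToyReader(session).table("MY_TABLE")
print(session._ast_batch.statements)  # ["session.read.table('MY_TABLE')"]
```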
```diff
@@ -1087,6 +1087,7 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
         return df
 
     @private_preview(version="1.29.0")
+    @publicapi
     def dbapi(
         self,
         create_connection: Callable[[], "Connection"],
@@ -1103,6 +1104,7 @@ def dbapi(
         custom_schema: Optional[Union[str, StructType]] = None,
         predicates: Optional[List[str]] = None,
         session_init_statement: Optional[str] = None,
+        _emit_ast: bool = True,
     ) -> DataFrame:
         """Reads data from a database table using a DBAPI connection with optional partitioning, parallel processing, and query customization.
         By default, the function reads the entire table at a time without a query timeout.
```
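The parameters between these two hunks are elided, so here is a hedged sketch of how the options that do appear in this hunk might be used, reusing `session` and `create_connection` from the sketch above; `table=` is again an assumed keyword.

```python
# Hedged sketch; only custom_schema, predicates, and session_init_statement
# appear in this hunk, and table= is an assumed keyword.
df = session.read.dbapi(
    create_connection,
    table="ORDERS",
    custom_schema="ID INT, AMOUNT DOUBLE, STATUS STRING",
    # Predicates typically shard the read: one partition per predicate.
    predicates=["STATUS = 'OPEN'", "STATUS = 'CLOSED'"],
    # Executed on each new source connection before reading begins.
    session_init_statement="PRAGMA busy_timeout = 5000;",
)
```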
```diff
@@ -1222,17 +1224,17 @@ def dbapi(
         )
         params = (snowflake_table_name,)
         logger.debug(f"Creating temporary Snowflake table: {snowflake_table_name}")
-        self._session.sql(create_table_sql, params=params).collect(
-            statement_params=statements_params_for_telemetry
+        self._session.sql(create_table_sql, params=params, _emit_ast=False).collect(
+            statement_params=statements_params_for_telemetry, _emit_ast=False
         )
         # create temp stage
         snowflake_stage_name = random_name_for_temp_object(TempObjectType.STAGE)
         sql_create_temp_stage = (
             f"create {get_temp_type_for_object(self._session._use_scoped_temp_objects, True)} stage"
             f" if not exists {snowflake_stage_name} {DATA_SOURCE_SQL_COMMENT}"
         )
-        self._session.sql(sql_create_temp_stage).collect(
-            statement_params=statements_params_for_telemetry
+        self._session.sql(sql_create_temp_stage, _emit_ast=False).collect(
+            statement_params=statements_params_for_telemetry, _emit_ast=False
         )
 
         try:
@@ -1327,7 +1329,11 @@ def ingestion_thread_cleanup_callback(parquet_file_path, _):
         self._session._conn._telemetry_client.send_data_source_perf_telemetry(
             DATA_SOURCE_DBAPI_SIGNATURE, time.perf_counter() - start_time
         )
-        res_df = self.table(snowflake_table_name)
+        # Knowingly generating AST for `session.read.dbapi` calls as simply `session.read.table` calls
+        # with the new name for the temporary table into which the external db data was ingressed.
+        # Leaving this functionality as client-side only means capturing an AST specifically for
+        # this API in a new entity is not valuable from a server-side execution or AST perspective.
+        res_df = self.table(snowflake_table_name, _emit_ast=_emit_ast)
         set_api_call_source(res_df, DATA_SOURCE_DBAPI_SIGNATURE)
         return res_df
@@ -1476,9 +1482,11 @@ def _upload_and_copy_into_table(
             ON_ERROR={on_error}
             {DATA_SOURCE_SQL_COMMENT}
         """
-        self._session.sql(put_query).collect(statement_params=statements_params)
-        self._session.sql(copy_into_table_query).collect(
-            statement_params=statements_params
+        self._session.sql(put_query, _emit_ast=False).collect(
+            statement_params=statements_params, _emit_ast=False
+        )
+        self._session.sql(copy_into_table_query, _emit_ast=False).collect(
+            statement_params=statements_params, _emit_ast=False
         )
 
     def _upload_and_copy_into_table_with_retry(
```
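Note that the ingestion hunks opt out at both call sites: `session.sql(...)` returns a DataFrame whose `collect()` is itself a public API, so each half of the round trip takes `_emit_ast=False` separately. Continuing the toy model from above (hypothetical classes, not the real internals):

```python
# Toy model: both sql() and collect() are capture points, so internal
# statements must pass _emit_ast=False to each.
class AstBatch:
    def __init__(self):
        self.statements = []

    def record(self, description: str) -> None:
        self.statements.append(description)


class ToyDataFrame:
    def __init__(self, session: "ToySession", query: str):
        self._session, self._query = session, query

    def collect(self, statement_params=None, _emit_ast: bool = True) -> list:
        if _emit_ast:
            self._session._ast_batch.record(f"collect({self._query!r})")
        return []  # no real execution in this sketch


class ToySession:
    def __init__(self):
        self._ast_batch = AstBatch()

    def sql(self, query: str, _emit_ast: bool = True) -> ToyDataFrame:
        if _emit_ast:
            self._ast_batch.record(f"session.sql({query!r})")
        return ToyDataFrame(self, query)


s = ToySession()
s.sql("PUT file://<local parquet> @<temp stage>", _emit_ast=False).collect(_emit_ast=False)
print(s._ast_batch.statements)  # [] -- nothing was captured
```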

src/snowflake/snowpark/dataframe_writer.py

Lines changed: 6 additions & 7 deletions
```diff
@@ -337,6 +337,9 @@ def save_as_table(
         """
 
         kwargs = {}
+        statement_params = self._track_data_source_statement_params(
+            self._dataframe, statement_params or self._dataframe._statement_params
+        )
         if _emit_ast:
             # Add an Assign node that applies WriteTable() to the input, followed by its Eval.
             repr = self._dataframe._session._ast_batch.assign()
@@ -466,9 +469,6 @@ def save_as_table(
         else:
             table_exists = None
 
-        statement_params = self._track_data_source_statement_params(
-            self._dataframe, statement_params or self._dataframe._statement_params
-        )
         create_table_logic_plan = SnowflakeCreateTable(
             table_name,
             column_names,
@@ -590,6 +590,9 @@ def copy_into_location(
         """
 
         kwargs = {}
+        statement_params = self._track_data_source_statement_params(
+            self._dataframe, statement_params or self._dataframe._statement_params
+        )
         if _emit_ast:
             # Add an Assign node that applies WriteCopyIntoLocation() to the input, followed by its Eval.
             repr = self._dataframe._session._ast_batch.assign()
@@ -649,10 +652,6 @@ def copy_into_location(
 
         cur_format_type_options.update(format_type_aliased_options)
 
-        statement_params = self._track_data_source_statement_params(
-            self._dataframe, statement_params or self._dataframe._statement_params
-        )
-
         df = self._dataframe._with_plan(
             CopyIntoLocationNode(
                 self._dataframe._plan,
```
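The writer changes do not alter what the tracking call computes; they hoist `_track_data_source_statement_params` above the `_emit_ast` branch in both `save_as_table` and `copy_into_location`, presumably so the resolved `statement_params` are already in scope on the AST-emitting path. A minimal hypothetical sketch of why such a hoist matters:

```python
# Hypothetical sketch, not the Snowpark implementation: resolve a derived
# value before the first branch that may consume it, so the AST-emitting
# path and the execution path see the same value.
def save_as_table_sketch(statement_params=None, _emit_ast=True):
    # Hoisted, as in the diff above: tracking now runs unconditionally.
    statement_params = {**(statement_params or {}), "data_source": "tracked"}
    ast_view = dict(statement_params) if _emit_ast else None  # AST path usage
    return statement_params, ast_view


print(save_as_table_sketch({"QUERY_TAG": "demo"}))
```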
