Skip to content

Commit d7caffe

Browse files
SNOW-2430625: encapsulate local ingestion and UDTF ingestion; remove the driver reference from the top level (#3897)
1 parent 7d53a8f commit d7caffe

File tree

3 files changed

+75
-26
lines changed

3 files changed

+75
-26
lines changed

src/snowflake/snowpark/_internal/data_source/datasource_partitioner.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import decimal
77
from collections import defaultdict
88
from functools import cached_property
9-
from typing import Optional, Union, List, Callable
9+
from typing import Optional, Union, List, Callable, Dict
1010
import logging
1111
import pytz
1212
from dateutil import parser
@@ -30,6 +30,10 @@
3030
DateType,
3131
DataType,
3232
)
33+
from typing import TYPE_CHECKING
34+
35+
if TYPE_CHECKING:
36+
import snowflake.snowpark
3337

3438
logger = logging.getLogger(__name__)
3539

@@ -171,6 +175,34 @@ def partitions(self) -> List[str]:
171175
self.num_partitions,
172176
)
173177

178+
def _udtf_ingestion(
    self,
    session: "snowflake.snowpark.Session",
    schema: StructType,
    partition_table: str,
    external_access_integrations: str,
    fetch_size: int = 1000,
    imports: Optional[List[str]] = None,
    packages: Optional[List[str]] = None,
    session_init_statement: Optional[List[str]] = None,
    query_timeout: Optional[int] = 0,
    statement_params: Optional[Dict[str, str]] = None,
    _emit_ast: bool = True,
) -> "snowflake.snowpark.DataFrame":
    """Run UDTF-based ingestion by delegating to the underlying driver.

    Thin pass-through so callers talk to the partitioner instead of
    reaching into ``self.driver`` directly; every argument is forwarded
    unchanged to ``driver.udtf_ingestion``.
    """
    # Forward positionally, in the exact order the driver expects.
    forwarded = (
        session,
        schema,
        partition_table,
        external_access_integrations,
        fetch_size,
        imports,
        packages,
        session_init_statement,
        query_timeout,
        statement_params,
        _emit_ast,
    )
    return self.driver.udtf_ingestion(*forwarded)
205+
174206
@staticmethod
175207
def generate_partitions(
176208
select_query: str,

src/snowflake/snowpark/_internal/data_source/utils.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@
3232
)
3333
import snowflake
3434
from snowflake.snowpark._internal.data_source import DataSourceReader
35+
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
36+
from snowflake.snowpark._internal.utils import get_temp_type_for_object
3537
from snowflake.snowpark.exceptions import SnowparkDataframeReaderException
36-
38+
from snowflake.snowpark.types import StructType
3739

3840
logger = logging.getLogger(__name__)
3941

@@ -551,6 +553,37 @@ def process_parquet_queue_with_threads(
551553
return fetch_to_local_end_time, upload_to_sf_start_time, upload_to_sf_end_time
552554

553555

556+
def create_data_source_table_and_stage(
    session: "snowflake.snowpark.Session",
    schema: StructType,
    snowflake_table_name: str,
    snowflake_stage_name: str,
    statements_params_for_telemetry: dict,
) -> None:
    """Create the temporary landing table and temp stage for data-source ingestion.

    Both objects are created with ``DATA_SOURCE_SQL_COMMENT`` appended so the
    statements are identifiable as data-source ingestion SQL. The table name is
    bound via ``identifier(?)`` rather than interpolated into the SQL text.
    """
    # One "<name> <type> [NOT NULL]" clause per schema field. Nullable fields
    # leave the constraint slot empty (the surrounding spacing is kept exactly
    # as before so the emitted SQL text is unchanged).
    column_clauses = " , ".join(
        f'{field.name} {convert_sp_to_sf_type(field.datatype)} '
        f'{"" if field.nullable else "NOT NULL"}'
        for field in schema.fields
    )
    create_table_sql = (
        f"CREATE TEMPORARY TABLE identifier(?) "
        f"({column_clauses}){DATA_SOURCE_SQL_COMMENT}"
    )
    logger.debug(f"Creating temporary Snowflake table: {snowflake_table_name}")
    session.sql(
        create_table_sql, params=(snowflake_table_name,), _emit_ast=False
    ).collect(statement_params=statements_params_for_telemetry, _emit_ast=False)

    # Create the temp stage (scoped if the session uses scoped temp objects).
    temp_object_type = get_temp_type_for_object(
        session._use_scoped_temp_objects, True
    )
    create_stage_sql = (
        f"create {temp_object_type} stage"
        f" if not exists {snowflake_stage_name} {DATA_SOURCE_SQL_COMMENT}"
    )
    session.sql(create_stage_sql, _emit_ast=False).collect(
        statement_params=statements_params_for_telemetry, _emit_ast=False
    )
585+
586+
554587
def track_data_source_statement_params(
555588
dataframe, statement_params: Optional[Dict] = None
556589
) -> Optional[Dict]:

src/snowflake/snowpark/dataframe_reader.py

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@
4242
worker_process,
4343
process_parquet_queue_with_threads,
4444
STATEMENT_PARAMS_DATA_SOURCE,
45-
DATA_SOURCE_SQL_COMMENT,
4645
DATA_SOURCE_DBAPI_SIGNATURE,
4746
_MAX_WORKER_SCALE,
47+
create_data_source_table_and_stage,
4848
)
4949
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
5050
from snowflake.snowpark._internal.telemetry import set_api_call_source
@@ -69,7 +69,6 @@
6969
get_aliased_option_name,
7070
get_copy_into_table_options,
7171
get_stage_parts,
72-
get_temp_type_for_object,
7372
is_in_stored_procedure,
7473
parse_positional_args_to_list_variadic,
7574
private_preview,
@@ -1992,7 +1991,7 @@ def create_oracledb_connection():
19921991
table_type="temp",
19931992
statement_params=statements_params_for_telemetry,
19941993
)
1995-
df = partitioner.driver.udtf_ingestion(
1994+
df = partitioner._udtf_ingestion(
19961995
self._session,
19971996
struct_schema,
19981997
partitions_table,
@@ -2018,29 +2017,14 @@ def create_oracledb_connection():
20182017
return df
20192018

20202019
# parquet ingestion
2021-
snowflake_table_type = "TEMPORARY"
20222020
snowflake_table_name = random_name_for_temp_object(TempObjectType.TABLE)
2023-
create_table_sql = (
2024-
"CREATE "
2025-
f"{snowflake_table_type} "
2026-
"TABLE "
2027-
f"identifier(?) "
2028-
f"""({" , ".join([f'{field.name} {convert_sp_to_sf_type(field.datatype)} {"NOT NULL" if not field.nullable else ""}' for field in struct_schema.fields])})"""
2029-
f"""{DATA_SOURCE_SQL_COMMENT}"""
2030-
)
2031-
params = (snowflake_table_name,)
2032-
logger.debug(f"Creating temporary Snowflake table: {snowflake_table_name}")
2033-
self._session.sql(create_table_sql, params=params, _emit_ast=False).collect(
2034-
statement_params=statements_params_for_telemetry, _emit_ast=False
2035-
)
2036-
# create temp stage
20372021
snowflake_stage_name = random_name_for_temp_object(TempObjectType.STAGE)
2038-
sql_create_temp_stage = (
2039-
f"create {get_temp_type_for_object(self._session._use_scoped_temp_objects, True)} stage"
2040-
f" if not exists {snowflake_stage_name} {DATA_SOURCE_SQL_COMMENT}"
2041-
)
2042-
self._session.sql(sql_create_temp_stage, _emit_ast=False).collect(
2043-
statement_params=statements_params_for_telemetry, _emit_ast=False
2022+
create_data_source_table_and_stage(
2023+
session=self._session,
2024+
schema=struct_schema,
2025+
snowflake_table_name=snowflake_table_name,
2026+
snowflake_stage_name=snowflake_stage_name,
2027+
statements_params_for_telemetry=statements_params_for_telemetry,
20442028
)
20452029

20462030
data_fetching_thread_pool_executor = None

0 commit comments

Comments (0)