Commit 212a4f9

custom data source placeholder

1 parent 27c9645 commit 212a4f9
5 files changed, +139 -20 lines changed

src/snowflake/snowpark/_internal/data_source/utils.py
Lines changed: 30 additions & 3 deletions

@@ -162,7 +162,7 @@ def detect_dbms_pyodbc(dbapi2_conn):
 
 def _task_fetch_data_from_source(
     worker: DataSourceReader,
-    partition: str,
+    partition: Any,
     partition_idx: int,
     parquet_queue: Union[mp.Queue, queue.Queue],
     stop_event: threading.Event = None,
@@ -202,7 +202,7 @@ def convert_to_parquet_bytesio(fetched_data, fetch_idx):
 
 def _task_fetch_data_from_source_with_retry(
     worker: DataSourceReader,
-    partition: str,
+    partition: Any,
     partition_idx: int,
     parquet_queue: Union[mp.Queue, queue.Queue],
    stop_event: threading.Event = None,
@@ -609,7 +609,7 @@ def track_data_source_statement_params(
 def local_ingestion(
     session: "snowflake.snowpark.Session",
     data_source: "snowflake.snowpark.DataSource",
-    partitioned_queries: List[str],
+    partitioned_queries: List[Any],
     schema: StructType,
     max_workers: int,
     snowflake_stage_name: str,
@@ -732,3 +732,30 @@ def local_ingestion(
     data_fetching_thread_pool_executor.shutdown(wait=True)
 
     logger.debug("All data has been successfully loaded into the Snowflake table.")
+
+
+def custom_data_source_udtf_class_builder(
+    data_source_class: type,
+    schema: StructType,
+):
+    import cloudpickle
+    import pickle
+
+    pickled_schema = cloudpickle.dumps(schema, protocol=pickle.HIGHEST_PROTOCOL)
+
+    class UDTFIngestion:
+        def process(self, pickled_partition: bytearray):
+            import cloudpickle
+
+            data_source_instance = data_source_class()
+            unpickled_schema = cloudpickle.loads(pickled_schema)
+            partition = cloudpickle.loads(pickled_partition)
+            reader = data_source_instance.reader(unpickled_schema)
+            for result in reader.read(partition):
+                if isinstance(result, list):
+                    yield from result
+                else:
+                    yield from list(reader.read(partition))
+                    break
+
+    return UDTFIngestion
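The builder above closes over the user's data source class and a cloudpickled copy of the schema, and the returned UDTFIngestion class unpickles one partition per input row, instantiates the source, and yields whatever its reader produces (flattening list results). A minimal local sketch of that behavior; RangeDataSource and RangeReader are hypothetical stand-ins for a user-defined source, and the builder is assumed importable from the utils module changed above:

# Sketch only: RangeDataSource / RangeReader are made-up names; the builder
# itself is the function added in the hunk above.
import pickle

import cloudpickle

from snowflake.snowpark._internal.data_source.utils import (
    custom_data_source_udtf_class_builder,
)
from snowflake.snowpark.types import IntegerType, StructField, StructType


class RangeReader:
    def read(self, partition):
        start, end = partition  # a partition can be any picklable object
        for i in range(start, end):
            yield (i,)  # one row per yield (a list of rows also works)


class RangeDataSource:
    def reader(self, schema):
        return RangeReader()


schema = StructType([StructField("ID", IntegerType())])
UDTFIngestion = custom_data_source_udtf_class_builder(RangeDataSource, schema)

# Drive the generated handler the way the UDTF runtime would: one pickled
# partition in, rows out.
pickled_partition = cloudpickle.dumps((0, 3), protocol=pickle.HIGHEST_PROTOCOL)
rows = list(UDTFIngestion().process(pickled_partition))
assert rows == [(0,), (1,), (2,)]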

src/snowflake/snowpark/data_source.py
Lines changed: 0 additions & 7 deletions

@@ -5,10 +5,6 @@
 from typing import Any, Tuple, Iterator, Union, List
 
 from snowflake.snowpark.types import StructType
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    import snowflake.snowpark
 
 
 class InputPartition:
@@ -47,6 +43,3 @@ def _partitions(self) -> List[InputPartition]:
         if self._internal_partitions is None:
             self._internal_partitions = self.reader(self.schema()).partitions()
         return self._internal_partitions
-
-    def _udtf_ingestion(self) -> "snowflake.snowpark.DataFrame":
-        pass

src/snowflake/snowpark/dataframe_reader.py
Lines changed: 95 additions & 0 deletions

@@ -39,6 +39,7 @@
     DATA_SOURCE_DBAPI_SIGNATURE,
     create_data_source_table_and_stage,
     local_ingestion,
+    custom_data_source_udtf_class_builder,
 )
 from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
 from snowflake.snowpark._internal.telemetry import set_api_call_source
@@ -2154,3 +2155,97 @@ def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame:
     def register_custom_data_source(self, data_source: DataSource):
         self._data_source_format.append(data_source.name())
         self._custom_data_source_format[data_source.name()] = data_source
+
+    def _custom_data_source(
+        self,
+        data_source_name: str,
+        max_workers: Optional[int] = None,
+        udtf_configs: Optional[dict] = None,
+        fetch_with_process: bool = False,
+        _emit_ast: bool = True,
+    ) -> DataFrame:
+
+        data_source_class = self._custom_data_source_format[data_source_name]
+        data_source_instance = data_source_class()
+        partitions = data_source_instance._partitions()
+        if partitions is None:
+            partitions = [None]
+        schema = DataSourcePartitioner.formatting_custom_schema(
+            data_source_instance.schema()
+        )
+        statements_params_for_telemetry = {STATEMENT_PARAMS_DATA_SOURCE: "1"}
+        telemetry_json_string = {}
+
+        # udtf ingestion
+
+        if udtf_configs is not None:
+            import cloudpickle
+            import pickle
+
+            pickled_partitions = [
+                cloudpickle.dumps(par, protocol=pickle.HIGHEST_PROTOCOL)
+                for par in partitions
+            ]
+            partitions_table = random_name_for_temp_object(TempObjectType.TABLE)
+            self._session.create_dataframe(
+                [[par] for par in pickled_partitions], schema=["partition"]
+            ).write.save_as_table(partitions_table, table_type="temp")
+
+            udtf_name = random_name_for_temp_object(TempObjectType.TABLE_FUNCTION)
+            self._session.udtf.register(
+                custom_data_source_udtf_class_builder(data_source_class, schema),
+                name=udtf_name,
+                output_schema=StructType(
+                    [
+                        StructField(field.name, VariantType(), field.nullable)
+                        for field in schema.fields
+                    ]
+                ),
+                external_access_integrations=[
+                    udtf_configs["external_access_integrations"]
+                ],
+                packages=udtf_configs["packages"],
+                imports=udtf_configs["imports"],
+            )
+            call_udtf_sql = f"""
+            select * from {partitions_table}, table({udtf_name}(partition))
+            """
+            res = self._session.sql(call_udtf_sql, _emit_ast=_emit_ast)
+            cols = [
+                res[field.name].cast(field.datatype).alias(field.name)
+                for field in schema.fields
+            ]
+            return res.select(cols, _emit_ast=True)
+
+        # parquet ingestion
+
+        snowflake_table_name = random_name_for_temp_object(TempObjectType.TABLE)
+        snowflake_stage_name = random_name_for_temp_object(TempObjectType.STAGE)
+
+        create_data_source_table_and_stage(
+            session=self._session,
+            schema=schema,
+            snowflake_table_name=snowflake_table_name,
+            snowflake_stage_name=snowflake_stage_name,
+            statements_params_for_telemetry=statements_params_for_telemetry,
+        )
+
+        local_ingestion(
+            session=self._session,
+            schema=schema,
+            data_source=data_source_instance,
+            partitioned_queries=partitions,
+            max_workers=max_workers,
+            fetch_with_process=fetch_with_process,
+            snowflake_stage_name=snowflake_stage_name,
+            snowflake_table_name=snowflake_table_name,
+            statements_params_for_telemetry=statements_params_for_telemetry,
+            telemetry_json_string=telemetry_json_string,
+            _emit_ast=_emit_ast,
+        )
+
+        res_df = self._session.table(snowflake_table_name, _emit_ast=_emit_ast)
+
+        set_api_call_source(res_df, DATA_SOURCE_DBAPI_SIGNATURE)
+
+        return res_df
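Taken together, the new _custom_data_source method chooses between two paths: with udtf_configs, partitions are cloudpickled into a temporary table and joined laterally against a generated UDTF that runs the source's reader inside Snowflake; without it, rows are fetched locally and loaded through the existing parquet/stage path via local_ingestion. A hedged, non-runnable usage sketch follows; MyApiSource, the integration name, and the package list are illustrative assumptions, not part of this commit:

# Hypothetical caller of the new private API; assumes an existing
# snowflake.snowpark.Session and that MyApiSource.name() is callable on the
# class (the registered object is later instantiated inside _custom_data_source).
reader = session.read
reader.register_custom_data_source(MyApiSource)

# UDTF ingestion path: reader logic runs server-side next to the data.
df_udtf = reader._custom_data_source(
    "my_api_source",  # must match MyApiSource.name()
    udtf_configs={
        "external_access_integrations": "MY_ACCESS_INTEGRATION",
        "packages": ["cloudpickle"],
        "imports": [],
    },
)

# Parquet ingestion path: omit udtf_configs to fetch partitions locally and
# upload parquet files to a temporary stage and table.
df_parquet = reader._custom_data_source("my_api_source", max_workers=4)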

tests/integ/test_data_source_api.py
Lines changed: 9 additions & 7 deletions

@@ -25,7 +25,9 @@
 from snowflake.snowpark._internal.data_source.datasource_partitioner import (
     DataSourcePartitioner,
 )
-from snowflake.snowpark._internal.data_source.datasource_reader import DataSourceReader
+from snowflake.snowpark._internal.data_source.datasource_reader import (
+    DbapiDataSourceReader,
+)
 from snowflake.snowpark._internal.data_source.drivers import (
     PyodbcDriver,
     SqliteDriver,
@@ -177,7 +179,7 @@ def test_dbapi_retry(session, fetch_with_process):
         SnowparkDataframeReaderException, match="\\[RuntimeError\\] Test error"
     ):
         _task_fetch_data_from_source_with_retry(
-            worker=DataSourceReader(
+            worker=DbapiDataSourceReader(
                 PyodbcDriver,
                 sql_server_create_connection,
                 StructType([StructField("col1", IntegerType(), False)]),
@@ -220,7 +222,7 @@ def test_dbapi_non_retryable_error(session, fetch_with_process):
     parquet_queue = multiprocessing.Queue() if fetch_with_process else queue.Queue()
     with pytest.raises(SnowparkDataframeReaderException, match="mock error"):
         _task_fetch_data_from_source_with_retry(
-            worker=DataSourceReader(
+            worker=DbapiDataSourceReader(
                 PyodbcDriver,
                 sql_server_create_connection,
                 StructType([StructField("col1", IntegerType(), False)]),
@@ -712,7 +714,7 @@ def test_task_fetch_from_data_source_with_fetch_size(
     parquet_queue = multiprocessing.Queue() if fetch_with_process else queue.Queue()
 
     params = {
-        "worker": DataSourceReader(
+        "worker": DbapiDataSourceReader(
            PyodbcDriver,
            sql_server_create_connection_small_data,
            schema=schema,
@@ -1147,7 +1149,7 @@ def test_fetch_merge_count_unit(fetch_size, fetch_merge_count, expected_batch_cn
    with tempfile.TemporaryDirectory() as temp_dir:
        dbpath = os.path.join(temp_dir, "testsqlite3.db")
        table_name, columns, example_data, _ = sqlite3_db(dbpath)
-        reader = DataSourceReader(
+        reader = DbapiDataSourceReader(
            SqliteDriver,
            functools.partial(create_connection_to_sqlite3_db, dbpath),
            schema=SQLITE3_DB_CUSTOM_SCHEMA_STRUCT_TYPE,
@@ -1268,8 +1270,8 @@ def test_worker_process_unit(fetch_with_process):
        dbpath = os.path.join(temp_dir, "testsqlite3.db")
        table_name, columns, example_data, _ = sqlite3_db(dbpath)
 
-        # Create DataSourceReader for sqlite3
-        reader = DataSourceReader(
+        # Create DbapiDataSourceReader for sqlite3
+        reader = DbapiDataSourceReader(
            SqliteDriver,
            functools.partial(create_connection_to_sqlite3_db, dbpath),
            schema=SQLITE3_DB_CUSTOM_SCHEMA_STRUCT_TYPE,

tests/unit/test_data_source.py
Lines changed: 5 additions & 3 deletions

@@ -6,7 +6,9 @@
 import re
 from unittest.mock import Mock, patch
 from snowflake.snowpark._internal.data_source.drivers.base_driver import BaseDriver
-from snowflake.snowpark._internal.data_source.datasource_reader import DataSourceReader
+from snowflake.snowpark._internal.data_source.datasource_reader import (
+    DbapiDataSourceReader,
+)
 from snowflake.snowpark._internal.data_source.utils import DBMS_TYPE
 from snowflake.snowpark.types import StructType, StructField, StringType
 
@@ -67,7 +69,7 @@ def test_close_error_handling(cursor_fails, conn_fails):
     ],
 )
 def test_datasource_reader_close_error_handling(cursor_fails, conn_fails):
-    """Test that DataSourceReader handles cursor/connection close errors gracefully."""
+    """Test that DbapiDataSourceReader handles cursor/connection close errors gracefully."""
     # Setup mocks
     mock_create_connection = Mock()
     expected_schema = StructType([StructField("test_col", StringType())])
@@ -94,7 +96,7 @@ def test_datasource_reader_close_error_handling(cursor_fails, conn_fails):
     mock_driver_class = Mock(return_value=mock_driver)
 
     # Create reader with the mock driver class
-    reader = DataSourceReader(
+    reader = DbapiDataSourceReader(
        driver_class=mock_driver_class,
        create_connection=mock_create_connection,
        schema=expected_schema,
