
Commit 27c9645

migrate data source partitioner to inherit from DataSource
1 parent eb4f3cd commit 27c9645

6 files changed: 27 additions & 12 deletions


src/snowflake/snowpark/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -41,6 +41,9 @@
     "AsyncJob",
     "StoredProcedureProfiler",
     "UDFProfiler",
+    "DataSource",
+    "DataSourceReader",
+    "InputPartition",
 ]


@@ -85,6 +88,7 @@
     WhenNotMatchedClause,
 )
 from snowflake.snowpark.window import Window, WindowSpec
+from snowflake.snowpark.data_source import DataSourceReader, DataSource, InputPartition

 _deprecation_warning_msg = (
     "Python Runtime 3.8 reached its End-Of-Life (EOL) on October 14, 2024, there will be no further bug fixes "

src/snowflake/snowpark/_internal/data_source/datasource_partitioner.py

Lines changed: 9 additions & 5 deletions

@@ -19,7 +19,10 @@
     DRIVER_MAP,
 )

-from snowflake.snowpark._internal.data_source.datasource_reader import DataSourceReader
+from snowflake.snowpark._internal.data_source.datasource_reader import (
+    DbapiDataSourceReader,
+)
+from snowflake.snowpark.data_source import DataSource
 from snowflake.snowpark._internal.type_utils import type_string_to_type_object
 from snowflake.snowpark._internal.data_source.datasource_typing import Connection
 from snowflake.snowpark._internal.utils import generate_random_alphanumeric
@@ -38,7 +41,7 @@
 logger = logging.getLogger(__name__)


-class DataSourcePartitioner:
+class DataSourcePartitioner(DataSource):
     def __init__(
         self,
         create_connection: Callable[..., "Connection"],
@@ -56,6 +59,7 @@ def __init__(
         fetch_merge_count: Optional[int] = 1,
         connection_parameters: Optional[dict] = None,
     ) -> None:
+        super().__init__()
         self.create_connection = create_connection
         self.table_or_query = table_or_query
         self.is_query = is_query
@@ -91,11 +95,11 @@ def __init__(
             else None
         )

-    def reader(self) -> DataSourceReader:
-        return DataSourceReader(
+    def reader(self, schema) -> DbapiDataSourceReader:
+        return DbapiDataSourceReader(
             self.driver_class,
             self.create_connection,
-            self.schema,
+            schema,
             self.dbms_type,
             self.fetch_size,
             self.query_timeout,
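
At the call level the change is small but visible: the reader no longer pulls the schema from the partitioner's own state, and the returned type carries its DBAPI-specific name. A before/after sketch, where partitioner and struct_schema are placeholder variables:

    # Before this commit: the schema came from the partitioner's stored state.
    reader = partitioner.reader()               # read self.schema internally
    # After: the caller passes the schema in explicitly.
    reader = partitioner.reader(struct_schema)  # returns a DbapiDataSourceReader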

src/snowflake/snowpark/_internal/data_source/datasource_reader.py

Lines changed: 3 additions & 1 deletion

@@ -7,6 +7,7 @@

 from typing import List, Any, Iterator, Type, Callable, Optional

+from snowflake.snowpark.data_source import DataSourceReader
 from snowflake.snowpark._internal.data_source.datasource_typing import Connection
 from snowflake.snowpark._internal.data_source.drivers.base_driver import BaseDriver
 from snowflake.snowpark.exceptions import SnowparkDataframeReaderException
@@ -17,7 +18,7 @@
 logger = logging.getLogger(__name__)


-class DataSourceReader:
+class DbapiDataSourceReader(DataSourceReader):
     def __init__(
         self,
         driver_class: Type[BaseDriver],
@@ -30,6 +31,7 @@ def __init__(
         fetch_merge_count: Optional[int] = 1,
         connection_parameters: Optional[dict] = None,
     ) -> None:
+        super().__init__(schema)
         # we use cloudpickle to pickle the callback function so that local function and function defined in
         # __main__ can be pickled and unpickled in subprocess
         self.pickled_create_connection_callback = cloudpickle.dumps(
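
The super().__init__(schema) forwarding matters because the public base class drives ingestion through self.reader(self.schema()).partitions() (see data_source.py below), so every reader has to carry its schema. The skeleton of the new class, with the argument list abbreviated to the parts this hunk shows:

    class DbapiDataSourceReader(DataSourceReader):
        def __init__(self, driver_class, create_connection, schema, *rest) -> None:
            super().__init__(schema)  # the public base class now owns the schema
            # ...DBAPI-specific setup (cloudpickled connection callback, etc.)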

src/snowflake/snowpark/_internal/data_source/utils.py

Lines changed: 4 additions & 3 deletions

@@ -608,8 +608,9 @@ def track_data_source_statement_params(

 def local_ingestion(
     session: "snowflake.snowpark.Session",
-    partitioner: "snowflake.snowpark._internal.data_source.datasource_partitioner.DataSourcePartitioner",
+    data_source: "snowflake.snowpark.DataSource",
     partitioned_queries: List[str],
+    schema: StructType,
     max_workers: int,
     snowflake_stage_name: str,
     snowflake_table_name: str,
@@ -651,7 +652,7 @@
             partition_queue,
             parquet_queue,
             process_or_thread_error_indicator,
-            partitioner.reader(),
+            data_source.reader(schema),
         ),
     )
     process.start()
@@ -667,7 +668,7 @@
             partition_queue,
             parquet_queue,
             process_or_thread_error_indicator,
-            partitioner.reader(),
+            data_source.reader(schema),
             data_fetching_thread_stop_event,
         )
         for _worker_id in range(max_workers)
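
local_ingestion now depends on the public snowflake.snowpark.DataSource interface rather than the internal partitioner, and each worker process or thread builds its own reader via data_source.reader(schema). The new call shape, abbreviated to the parameters this hunk shows and with placeholder values:

    local_ingestion(
        session=session,
        data_source=my_data_source,   # any object implementing DataSource
        partitioned_queries=queries,
        schema=struct_schema,         # forwarded to data_source.reader(schema)
        max_workers=4,
        snowflake_stage_name=stage_name,
        snowflake_table_name=table_name,
        # ...remaining keyword arguments unchanged by this commit
    )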

src/snowflake/snowpark/data_source.py

Lines changed: 5 additions & 2 deletions

@@ -5,7 +5,10 @@
 from typing import Any, Tuple, Iterator, Union, List

 from snowflake.snowpark.types import StructType
-import snowflake
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import snowflake.snowpark


 class InputPartition:
@@ -45,5 +48,5 @@ def _partitions(self) -> List[InputPartition]:
         self._internal_partitions = self.reader(self.schema()).partitions()
         return self._internal_partitions

-    def udtf_ingestion(self) -> "snowflake.snowpark.DataFrame":
+    def _udtf_ingestion(self) -> "snowflake.snowpark.DataFrame":
         pass
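
Replacing the bare import snowflake with a TYPE_CHECKING guard is the standard idiom for imports needed only by string annotations: type checkers evaluate the guarded block, but at runtime it is skipped, which avoids importing the full snowflake.snowpark package (and any circular dependency) when the module loads. The idiom in isolation, with to_df as an illustrative function:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated by mypy/pyright, never executed at runtime.
        import snowflake.snowpark

    def to_df() -> "snowflake.snowpark.DataFrame":
        # The string annotation is resolved lazily, so no runtime import is needed.
        ...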

src/snowflake/snowpark/dataframe_reader.py

Lines changed: 2 additions & 1 deletion

@@ -2027,7 +2027,8 @@ def create_oracledb_connection():

         local_ingestion(
             session=self._session,
-            partitioner=partitioner,
+            schema=struct_schema,
+            data_source=partitioner,
             partitioned_queries=partitioned_queries,
             max_workers=max_workers,
             fetch_with_process=fetch_with_process,
