
Commit eb4f3cd

Merge branch 'main' into yuwang-custom-data-source
2 parents b45a4f2 + 62acef0 commit eb4f3cd

File tree

8 files changed: +694 −125 lines


CHANGELOG.md

Lines changed: 13 additions & 0 deletions
@@ -144,6 +144,19 @@
 - `cumsum`
 - `cummin`
 - `cummax`
+- `groupby.groups`
+- `groupby.indices`
+- `groupby.first`
+- `groupby.last`
+- `groupby.rank`
+- `groupby.shift`
+- `groupby.cumcount`
+- `groupby.cumsum`
+- `groupby.cummin`
+- `groupby.cummax`
+- `groupby.any`
+- `groupby.all`
+- `groupby.unique`
 - Make faster pandas disabled by default (opt-in instead of opt-out).
 - Improve performance of `drop_duplicates` by avoiding joins when `keep!=False` in faster pandas.
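
The new entries extend faster pandas coverage to the standard pandas GroupBy surface. A minimal sketch of the calls involved, assuming the usual Snowpark pandas imports and an already-created Snowpark session (data and column names here are illustrative):

    import modin.pandas as pd
    import snowflake.snowpark.modin.plugin  # noqa: F401 -- registers the Snowpark pandas backend

    # Assumes an active Snowpark session has already been created.
    df = pd.DataFrame({"grp": ["a", "a", "b", "b"], "val": [1, 2, 3, 4]})
    g = df.groupby("grp")

    print(g.first())          # first row per group
    print(g.cumsum())         # running sum within each group
    print(g["val"].rank())    # rank of each value within its group
    print(g["val"].unique())  # distinct values per group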

src/snowflake/snowpark/_internal/data_source/utils.py

Lines changed: 133 additions & 2 deletions
@@ -12,7 +12,7 @@
 from threading import BoundedSemaphore
 from io import BytesIO
 from enum import Enum
-from typing import Any, Tuple, Optional, Callable, Dict, Union, Set
+from typing import Any, Tuple, Optional, Callable, Dict, Union, Set, List
 import logging
 from snowflake.snowpark._internal.data_source.dbms_dialects import (
     Sqlite3Dialect,
@@ -30,13 +30,17 @@
     Psycopg2Driver,
     PymysqlDriver,
 )
-import snowflake
 from snowflake.snowpark._internal.data_source import DataSourceReader
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.snowpark._internal.utils import get_temp_type_for_object
 from snowflake.snowpark.exceptions import SnowparkDataframeReaderException
 from snowflake.snowpark.types import StructType
 
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import snowflake.snowpark
+
 logger = logging.getLogger(__name__)
 
 _MAX_RETRY_TIME = 3
@@ -600,3 +604,130 @@ def track_data_source_statement_params(
     statement_params[STATEMENT_PARAMS_DATA_SOURCE] = "1"
 
     return statement_params if statement_params else None
+
+
+def local_ingestion(
+    session: "snowflake.snowpark.Session",
+    partitioner: "snowflake.snowpark._internal.data_source.datasource_partitioner.DataSourcePartitioner",
+    partitioned_queries: List[str],
+    max_workers: int,
+    snowflake_stage_name: str,
+    snowflake_table_name: str,
+    statements_params_for_telemetry: Dict,
+    telemetry_json_string: Dict,
+    fetch_with_process: bool = False,
+    _emit_ast: bool = True,
+) -> None:
+    data_fetching_thread_pool_executor = None
+    data_fetching_thread_stop_event = None
+    workers = []
+    try:
+        # Determine the number of processes or threads to use
+        max_workers = max_workers or os.cpu_count()
+        queue_class = mp.Queue if fetch_with_process else queue.Queue
+        process_or_thread_error_indicator = queue_class()
+        # a queue of partitions to be processed, this is filled by the partitioner before starting the workers
+        partition_queue = queue_class()
+        # a queue of parquet BytesIO objects to be uploaded
+        # Set max size for parquet_queue to prevent overfilling when thread consumers are slower than process producers
+        # process workers will block on this queue if it's full until the upload threads consume the BytesIO objects
+        parquet_queue = queue_class(_MAX_WORKER_SCALE * max_workers)
+        for partition_idx, query in enumerate(partitioned_queries):
+            partition_queue.put((partition_idx, query))
+
+        # Start worker processes
+        logger.debug(
+            f"Starting {max_workers} worker processes to fetch data from the data source."
+        )
+
+        fetch_to_local_start_time = time.perf_counter()
+        logger.debug(f"fetch to local start at: {fetch_to_local_start_time}")
+
+        if fetch_with_process:
+            for _worker_id in range(max_workers):
+                process = mp.Process(
+                    target=worker_process,
+                    args=(
+                        partition_queue,
+                        parquet_queue,
+                        process_or_thread_error_indicator,
+                        partitioner.reader(),
+                    ),
+                )
+                process.start()
+                workers.append(process)
+        else:
+            data_fetching_thread_pool_executor = ThreadPoolExecutor(
+                max_workers=max_workers
+            )
+            data_fetching_thread_stop_event = threading.Event()
+            workers = [
+                data_fetching_thread_pool_executor.submit(
+                    worker_process,
+                    partition_queue,
+                    parquet_queue,
+                    process_or_thread_error_indicator,
+                    partitioner.reader(),
+                    data_fetching_thread_stop_event,
+                )
+                for _worker_id in range(max_workers)
+            ]
+
+        # Process BytesIO objects from queue and upload them using utility method
+        (
+            fetch_to_local_end_time,
+            upload_to_sf_start_time,
+            upload_to_sf_end_time,
+        ) = process_parquet_queue_with_threads(
+            session=session,
+            parquet_queue=parquet_queue,
+            process_or_thread_error_indicator=process_or_thread_error_indicator,
+            workers=workers,
+            total_partitions=len(partitioned_queries),
+            snowflake_stage_name=snowflake_stage_name,
+            snowflake_table_name=snowflake_table_name,
+            max_workers=max_workers,
+            statements_params=statements_params_for_telemetry,
+            on_error="abort_statement",
+            fetch_with_process=fetch_with_process,
+        )
+        logger.debug(f"upload and copy into start at: {upload_to_sf_start_time}")
+        logger.debug(
+            f"fetch to local total time: {fetch_to_local_end_time - fetch_to_local_start_time}"
+        )
+
+        telemetry_json_string["fetch_to_local_duration"] = (
+            fetch_to_local_end_time - fetch_to_local_start_time
+        )
+        telemetry_json_string["upload_and_copy_into_sf_table_duration"] = (
+            upload_to_sf_end_time - upload_to_sf_start_time
+        )
+
+    except BaseException as exc:
+        if fetch_with_process:
+            # Graceful shutdown - terminate all processes
+            for process in workers:
+                if process.is_alive():
+                    process.terminate()
+                process.join(timeout=5)
+        else:
+            if data_fetching_thread_stop_event:
+                data_fetching_thread_stop_event.set()
+            for future in workers:
+                if not future.done():
+                    future.cancel()
+                    logger.debug(
+                        f"Cancelled a remaining data fetching future {future} due to error in another thread."
+                    )
+
+        if isinstance(exc, SnowparkDataframeReaderException):
+            raise exc
+
+        raise SnowparkDataframeReaderException(
+            f"Error occurred while ingesting data from the data source: {exc!r}"
+        )
+    finally:
+        if data_fetching_thread_pool_executor:
+            data_fetching_thread_pool_executor.shutdown(wait=True)
+
+    logger.debug("All data has been successfully loaded into the Snowflake table.")

src/snowflake/snowpark/dataframe_reader.py

Lines changed: 15 additions & 119 deletions
@@ -2,17 +2,13 @@
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
 import json
-import multiprocessing as mp
 import os
 import re
 import sys
 import time
-import queue
 from collections import defaultdict
-from concurrent.futures import ThreadPoolExecutor
 from logging import getLogger
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Callable
-import threading
 from datetime import datetime
 
 import snowflake.snowpark
@@ -39,12 +35,10 @@
 )
 from snowflake.snowpark._internal.data_source.datasource_typing import Connection
 from snowflake.snowpark._internal.data_source.utils import (
-    worker_process,
-    process_parquet_queue_with_threads,
     STATEMENT_PARAMS_DATA_SOURCE,
     DATA_SOURCE_DBAPI_SIGNATURE,
-    _MAX_WORKER_SCALE,
     create_data_source_table_and_stage,
+    local_ingestion,
 )
 from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
 from snowflake.snowpark._internal.telemetry import set_api_call_source
@@ -2018,7 +2012,7 @@ def create_oracledb_connection():
             set_api_call_source(df, DATA_SOURCE_DBAPI_SIGNATURE)
             return df
 
-        # parquet ingestion
+        # create table and stage for data source
         snowflake_table_name = random_name_for_temp_object(TempObjectType.TABLE)
         snowflake_stage_name = random_name_for_temp_object(TempObjectType.STAGE)
         create_data_source_table_and_stage(
@@ -2029,119 +2023,21 @@ def create_oracledb_connection():
            statements_params_for_telemetry=statements_params_for_telemetry,
         )
 
-        data_fetching_thread_pool_executor = None
-        data_fetching_thread_stop_event = None
-        workers = []
-        try:
-            # Determine the number of processes or threads to use
-            max_workers = max_workers or os.cpu_count()
-            queue_class = mp.Queue if fetch_with_process else queue.Queue
-            process_or_thread_error_indicator = queue_class()
-            # a queue of partitions to be processed, this is filled by the partitioner before starting the workers
-            partition_queue = queue_class()
-            # a queue of parquet BytesIO objects to be uploaded
-            # Set max size for parquet_queue to prevent overfilling when thread consumers are slower than process producers
-            # process workers will block on this queue if it's full until the upload threads consume the BytesIO objects
-            parquet_queue = queue_class(_MAX_WORKER_SCALE * max_workers)
-            for partition_idx, query in enumerate(partitioned_queries):
-                partition_queue.put((partition_idx, query))
-
-            # Start worker processes
-            logger.debug(
-                f"Starting {max_workers} worker processes to fetch data from the data source."
-            )
-
-            fetch_to_local_start_time = time.perf_counter()
-            logger.debug(f"fetch to local start at: {fetch_to_local_start_time}")
-
-            if fetch_with_process:
-                for _worker_id in range(max_workers):
-                    process = mp.Process(
-                        target=worker_process,
-                        args=(
-                            partition_queue,
-                            parquet_queue,
-                            process_or_thread_error_indicator,
-                            partitioner.reader(),
-                        ),
-                    )
-                    process.start()
-                    workers.append(process)
-            else:
-                data_fetching_thread_pool_executor = ThreadPoolExecutor(
-                    max_workers=max_workers
-                )
-                data_fetching_thread_stop_event = threading.Event()
-                workers = [
-                    data_fetching_thread_pool_executor.submit(
-                        worker_process,
-                        partition_queue,
-                        parquet_queue,
-                        process_or_thread_error_indicator,
-                        partitioner.reader(),
-                        data_fetching_thread_stop_event,
-                    )
-                    for _worker_id in range(max_workers)
-                ]
-
-            # Process BytesIO objects from queue and upload them using utility method
-            (
-                fetch_to_local_end_time,
-                upload_to_sf_start_time,
-                upload_to_sf_end_time,
-            ) = process_parquet_queue_with_threads(
-                session=self._session,
-                parquet_queue=parquet_queue,
-                process_or_thread_error_indicator=process_or_thread_error_indicator,
-                workers=workers,
-                total_partitions=len(partitioned_queries),
-                snowflake_stage_name=snowflake_stage_name,
-                snowflake_table_name=snowflake_table_name,
-                max_workers=max_workers,
-                statements_params=statements_params_for_telemetry,
-                on_error="abort_statement",
-                fetch_with_process=fetch_with_process,
-            )
-            logger.debug(f"upload and copy into start at: {upload_to_sf_start_time}")
-            logger.debug(
-                f"fetch to local total time: {fetch_to_local_end_time - fetch_to_local_start_time}"
-            )
+        # parquet ingestion: ingest external data source into a temporary table
 
-            telemetry_json_string["fetch_to_local_duration"] = (
-                fetch_to_local_end_time - fetch_to_local_start_time
-            )
-            telemetry_json_string["upload_and_copy_into_sf_table_duration"] = (
-                upload_to_sf_end_time - upload_to_sf_start_time
-            )
-
-        except BaseException as exc:
-            if fetch_with_process:
-                # Graceful shutdown - terminate all processes
-                for process in workers:
-                    if process.is_alive():
-                        process.terminate()
-                    process.join(timeout=5)
-            else:
-                if data_fetching_thread_stop_event:
-                    data_fetching_thread_stop_event.set()
-                for future in workers:
-                    if not future.done():
-                        future.cancel()
-                        logger.debug(
-                            f"Cancelled a remaining data fetching future {future} due to error in another thread."
-                        )
-
-            if isinstance(exc, SnowparkDataframeReaderException):
-                raise exc
-
-            raise SnowparkDataframeReaderException(
-                f"Error occurred while ingesting data from the data source: {exc!r}"
-            )
-        finally:
-            if data_fetching_thread_pool_executor:
-                data_fetching_thread_pool_executor.shutdown(wait=True)
+        local_ingestion(
+            session=self._session,
+            partitioner=partitioner,
+            partitioned_queries=partitioned_queries,
+            max_workers=max_workers,
+            fetch_with_process=fetch_with_process,
+            snowflake_stage_name=snowflake_stage_name,
+            snowflake_table_name=snowflake_table_name,
+            statements_params_for_telemetry=statements_params_for_telemetry,
+            telemetry_json_string=telemetry_json_string,
+            _emit_ast=_emit_ast,
+        )
 
-        logger.debug("All data has been successfully loaded into the Snowflake table.")
         end_time = time.perf_counter()
         telemetry_json_string["end_to_end_duration"] = end_time - start_time
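
For context, the refactored call sits inside DataFrameReader.dbapi, which users reach roughly as below. This is a hedged sketch: create_sqlite3_connection and connection_parameters are hypothetical, and the table / max_workers / fetch_with_process arguments are assumed from the variables visible in this diff rather than from the full signature.

    import sqlite3

    from snowflake.snowpark import Session

    def create_sqlite3_connection():
        # Hypothetical helper: any DBAPI-compliant connection to the source database works here.
        return sqlite3.connect("/tmp/source.db")

    session = Session.builder.configs(connection_parameters).create()  # connection_parameters assumed to exist

    # Each partition is fetched locally (threads by default, processes when fetch_with_process=True),
    # written to parquet, uploaded to a temp stage, and copied into a temp table via local_ingestion.
    df = session.read.dbapi(
        create_sqlite3_connection,
        table="SOURCE_TABLE",       # assumed parameter name
        max_workers=4,              # the max_workers plumbing shown in this diff
        fetch_with_process=False,   # thread-based fetch path
    )
    df.show()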

src/snowflake/snowpark/lineage.py

Lines changed: 2 additions & 0 deletions
@@ -135,6 +135,7 @@ class _SnowflakeDomain:
     TABLE = "TABLE"
     MODULE = "MODULE"
     DATASET = "DATASET"
+    EXPERIMENT = "EXPERIMENT"
     VIEW = "VIEW"
     COLUMN = "COLUMN"
     SNOWSERVICE_INSTANCE = "SNOWSERVICE_INSTANCE"
@@ -274,6 +275,7 @@ def __init__(self, session: "snowflake.snowpark.session.Session") -> None:
             _UserDomain.FEATURE_VIEW,
             _UserDomain.MODEL,
             _SnowflakeDomain.DATASET,
+            _SnowflakeDomain.EXPERIMENT,
         }
 
     def _get_lineage(
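
The second hunk adds EXPERIMENT to the set of domains the lineage tracer treats as versioned objects. A hedged sketch of how that could surface through the public API, assuming the documented session.lineage.trace(object_name, object_domain, ...) call and a hypothetical experiment name:

    # Hypothetical fully qualified name; "EXPERIMENT" is the domain added in this change.
    lineage_df = session.lineage.trace(
        "MY_DB.MY_SCHEMA.MY_EXPERIMENT",
        "EXPERIMENT",
        direction="downstream",  # assumed string form of the direction argument
        distance=2,
    )
    lineage_df.show()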

0 commit comments
