Skip to content
47 changes: 25 additions & 22 deletions src/snowflake/connector/aio/_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,11 +664,8 @@ async def execute(
)
logger.debug("PUT OR GET: %s", self.is_file_transfer)
if self.is_file_transfer:
from ._file_transfer_agent import SnowflakeFileTransferAgent

# Decide whether to use the old, or new code path
sf_file_transfer_agent = SnowflakeFileTransferAgent(
self,
sf_file_transfer_agent = self._create_file_transfer_agent(
query,
ret,
put_callback=_put_callback,
Expand All @@ -684,9 +681,6 @@ async def execute(
skip_upload_on_content_match=_skip_upload_on_content_match,
source_from_stream=file_stream,
multipart_threshold=data.get("threshold"),
use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
unsafe_file_write=self._connection.unsafe_file_write,
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
)
await sf_file_transfer_agent.execute()
data = sf_file_transfer_agent.result()
Expand Down Expand Up @@ -1082,8 +1076,6 @@ async def _download(
_do_reset (bool, optional): Whether to reset the cursor before
downloading, by default we will reset the cursor.
"""
from ._file_transfer_agent import SnowflakeFileTransferAgent

if _do_reset:
self.reset()

Expand All @@ -1097,11 +1089,9 @@ async def _download(
)

# Execute the file operation based on the interpretation above.
file_transfer_agent = SnowflakeFileTransferAgent(
self,
file_transfer_agent = self._create_file_transfer_agent(
"", # empty command because it is triggered by directly calling this util not by a SQL query
ret,
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
)
await file_transfer_agent.execute()
await self._init_result_and_meta(file_transfer_agent.result())
Expand All @@ -1122,8 +1112,6 @@ async def _upload(
_do_reset (bool, optional): Whether to reset the cursor before
uploading, by default we will reset the cursor.
"""
from ._file_transfer_agent import SnowflakeFileTransferAgent

if _do_reset:
self.reset()

Expand All @@ -1137,12 +1125,10 @@ async def _upload(
)

# Execute the file operation based on the interpretation above.
file_transfer_agent = SnowflakeFileTransferAgent(
self,
file_transfer_agent = self._create_file_transfer_agent(
"", # empty command because it is triggered by directly calling this util not by a SQL query
ret,
force_put_overwrite=False, # _upload should respect user decision on overwriting
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
)
await file_transfer_agent.execute()
await self._init_result_and_meta(file_transfer_agent.result())
Expand Down Expand Up @@ -1191,8 +1177,6 @@ async def _upload_stream(
_do_reset (bool, optional): Whether to reset the cursor before
uploading, by default we will reset the cursor.
"""
from ._file_transfer_agent import SnowflakeFileTransferAgent

if _do_reset:
self.reset()

Expand All @@ -1207,13 +1191,11 @@ async def _upload_stream(
)

# Execute the file operation based on the interpretation above.
file_transfer_agent = SnowflakeFileTransferAgent(
self,
file_transfer_agent = self._create_file_transfer_agent(
"", # empty command because it is triggered by directly calling this util not by a SQL query
ret,
source_from_stream=input_stream,
force_put_overwrite=False, # _upload should respect user decision on overwriting
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
)
await file_transfer_agent.execute()
await self._init_result_and_meta(file_transfer_agent.result())
Expand Down Expand Up @@ -1320,6 +1302,27 @@ async def query_result(self, qid: str) -> SnowflakeCursor:
)
return self

def _create_file_transfer_agent(
    self,
    command: str,
    ret: dict[str, Any],
    /,
    **kwargs,
) -> SnowflakeFileTransferAgent:
    """Build a file-transfer agent bound to this cursor.

    Connection-wide transfer settings (regional S3 URL usage, unsafe file
    writes, and worker-error re-raising) are wired in here so that every
    call site inherits them consistently; per-call options arrive via
    ``**kwargs`` and must not repeat the keywords set below.
    """
    # Imported lazily to avoid a circular import at module load time.
    from snowflake.connector.aio._file_transfer_agent import (
        SnowflakeFileTransferAgent,
    )

    conn = self._connection
    return SnowflakeFileTransferAgent(
        self,
        command,
        ret,
        use_s3_regional_url=conn.enable_stage_s3_privatelink_for_us_east_1,
        unsafe_file_write=conn.unsafe_file_write,
        reraise_error_in_file_transfer_work_function=conn._reraise_error_in_file_transfer_work_function,
        **kwargs,
    )


class DictCursor(DictCursorSync, SnowflakeCursor):
    # Async cursor that returns rows as dicts: combines the sync DictCursor's
    # row-mapping behavior with the async SnowflakeCursor in this module.
    pass
55 changes: 35 additions & 20 deletions src/snowflake/connector/aio/_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from ..session_manager import BaseHttpConfig
from ..session_manager import SessionManager as SessionManagerSync
from ..session_manager import SessionPool as SessionPoolSync
from ..session_manager import _ConfigDirectAccessMixin

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -328,7 +329,29 @@ async def delete(
)


class SessionManager(_RequestVerbsUsingSessionMixin, SessionManagerSync):
class _AsyncHttpConfigDirectAccessMixin(_ConfigDirectAccessMixin, abc.ABC):
    """Async counterpart of ``_ConfigDirectAccessMixin``.

    Exposes aiohttp-specific fields of the immutable ``AioHttpConfig``
    (currently ``connector_factory``) as direct attributes on the session
    manager. Subclasses must implement the abstract ``config`` property pair.
    """

    @property
    @abc.abstractmethod
    def config(self) -> AioHttpConfig: ...

    @config.setter
    @abc.abstractmethod
    def config(self, value) -> None: ...  # setters return None, not the config

    @property
    def connector_factory(self) -> Callable[..., aiohttp.BaseConnector]:
        return self.config.connector_factory

    @connector_factory.setter
    def connector_factory(self, value: Callable[..., aiohttp.BaseConnector]) -> None:
        # AioHttpConfig is treated as immutable: replace it with a modified
        # copy rather than mutating in place.
        self.config: AioHttpConfig = self.config.copy_with(connector_factory=value)


class SessionManager(
_RequestVerbsUsingSessionMixin,
SessionManagerSync,
_AsyncHttpConfigDirectAccessMixin,
):
"""
Async HTTP session manager for aiohttp.ClientSession instances.

Expand Down Expand Up @@ -363,14 +386,6 @@ def from_config(cls, cfg: AioHttpConfig, **overrides: Any) -> SessionManager:
cfg = cfg.copy_with(**overrides)
return cls(config=cfg)

@property
def connector_factory(self) -> Callable[..., aiohttp.BaseConnector]:
return self._cfg.connector_factory

@connector_factory.setter
def connector_factory(self, value: Callable[..., aiohttp.BaseConnector]) -> None:
self._cfg: AioHttpConfig = self._cfg.copy_with(connector_factory=value)

def make_session(self) -> aiohttp.ClientSession:
"""Create a new aiohttp.ClientSession with configured connector."""
connector = self._cfg.get_connector(
Expand Down Expand Up @@ -432,18 +447,18 @@ async def close(self):

def clone(
self,
*,
use_pooling: bool | None = None,
connector_factory: ConnectorFactory | None = None,
**http_config_overrides,
) -> SessionManager:
"""Return a new async SessionManager sharing this instance's config."""
overrides: dict[str, Any] = {}
if use_pooling is not None:
overrides["use_pooling"] = use_pooling
if connector_factory is not None:
overrides["connector_factory"] = connector_factory

return self.from_config(self._cfg, **overrides)
"""Return a new *stateless* SessionManager sharing this instance’s config.

"Shallow clone" - the configuration object (HttpConfig) is reused as-is,
while *stateful* aspects such as the per-host SessionPool mapping are
reset, so the two managers do not share live `requests.Session`
objects.
Optional kwargs (e.g. *use_pooling* / *adapter_factory* / max_retries etc.) - overrides to create a modified
copy of the HttpConfig before instantiation.
"""
return self.from_config(self._cfg, **http_config_overrides)


async def request(
Expand Down
2 changes: 1 addition & 1 deletion src/snowflake/connector/auth/_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def base_auth_data(
"SOCKET_TIMEOUT": socket_timeout,
"PLATFORM": detect_platforms(
platform_detection_timeout_seconds=platform_detection_timeout_seconds,
session_manager=session_manager,
session_manager=session_manager.clone(max_retries=0),
),
},
},
Expand Down
64 changes: 30 additions & 34 deletions src/snowflake/connector/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@
from pyarrow import Table

from .connection import SnowflakeConnection
from .file_transfer_agent import SnowflakeProgressPercentage
from .file_transfer_agent import (
SnowflakeFileTransferAgent,
SnowflakeProgressPercentage,
)
from .result_batch import ResultBatch

T = TypeVar("T", bound=collections.abc.Sequence)
Expand Down Expand Up @@ -1064,11 +1067,7 @@ def execute(
)
logger.debug("PUT OR GET: %s", self.is_file_transfer)
if self.is_file_transfer:
from .file_transfer_agent import SnowflakeFileTransferAgent

# Decide whether to use the old, or new code path
sf_file_transfer_agent = SnowflakeFileTransferAgent(
self,
sf_file_transfer_agent = self._create_file_transfer_agent(
query,
ret,
put_callback=_put_callback,
Expand All @@ -1084,13 +1083,6 @@ def execute(
skip_upload_on_content_match=_skip_upload_on_content_match,
source_from_stream=file_stream,
multipart_threshold=data.get("threshold"),
use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
iobound_tpe_limit=self._connection.iobound_tpe_limit,
unsafe_file_write=self._connection.unsafe_file_write,
snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer(
self._connection
),
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
)
sf_file_transfer_agent.execute()
data = sf_file_transfer_agent.result()
Expand Down Expand Up @@ -1785,8 +1777,6 @@ def _download(
_do_reset (bool, optional): Whether to reset the cursor before
downloading, by default we will reset the cursor.
"""
from .file_transfer_agent import SnowflakeFileTransferAgent

if _do_reset:
self.reset()

Expand All @@ -1800,14 +1790,9 @@ def _download(
)

# Execute the file operation based on the interpretation above.
file_transfer_agent = SnowflakeFileTransferAgent(
self,
file_transfer_agent = self._create_file_transfer_agent(
"", # empty command because it is triggered by directly calling this util not by a SQL query
ret,
snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer(
self._connection
),
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
)
file_transfer_agent.execute()
self._init_result_and_meta(file_transfer_agent.result())
Expand All @@ -1828,7 +1813,6 @@ def _upload(
_do_reset (bool, optional): Whether to reset the cursor before
uploading, by default we will reset the cursor.
"""
from .file_transfer_agent import SnowflakeFileTransferAgent

if _do_reset:
self.reset()
Expand All @@ -1843,15 +1827,10 @@ def _upload(
)

# Execute the file operation based on the interpretation above.
file_transfer_agent = SnowflakeFileTransferAgent(
self,
file_transfer_agent = self._create_file_transfer_agent(
"", # empty command because it is triggered by directly calling this util not by a SQL query
ret,
force_put_overwrite=False, # _upload should respect user decision on overwriting
snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer(
self._connection
),
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
)
file_transfer_agent.execute()
self._init_result_and_meta(file_transfer_agent.result())
Expand Down Expand Up @@ -1898,7 +1877,6 @@ def _upload_stream(
_do_reset (bool, optional): Whether to reset the cursor before
uploading, by default we will reset the cursor.
"""
from .file_transfer_agent import SnowflakeFileTransferAgent

if _do_reset:
self.reset()
Expand All @@ -1914,19 +1892,37 @@ def _upload_stream(
)

# Execute the file operation based on the interpretation above.
file_transfer_agent = SnowflakeFileTransferAgent(
self,
file_transfer_agent = self._create_file_transfer_agent(
"", # empty command because it is triggered by directly calling this util not by a SQL query
ret,
source_from_stream=input_stream,
force_put_overwrite=False, # _upload_stream should respect user decision on overwriting
)
file_transfer_agent.execute()
self._init_result_and_meta(file_transfer_agent.result())

def _create_file_transfer_agent(
    self,
    command: str,
    ret: dict[str, Any],
    /,
    **kwargs,
) -> SnowflakeFileTransferAgent:
    """Build a file-transfer agent bound to this cursor.

    Connection-wide transfer settings (regional S3 URL usage, io-bound
    thread-pool limit, unsafe file writes, server DOP cap, and worker-error
    re-raising) are wired in here so every call site inherits them
    consistently; per-call options arrive via ``**kwargs`` and must not
    repeat the keywords set below.
    """
    # Imported lazily to avoid a circular import at module load time.
    from .file_transfer_agent import SnowflakeFileTransferAgent

    return SnowflakeFileTransferAgent(
        self,
        command,
        ret,
        use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
        iobound_tpe_limit=self._connection.iobound_tpe_limit,
        unsafe_file_write=self._connection.unsafe_file_write,
        snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer(
            self._connection
        ),
        # Fixed: this keyword appeared twice (once via `self.connection`,
        # once via `self._connection`), which is a SyntaxError; keep the
        # `self._connection` form used by the sibling arguments.
        reraise_error_in_file_transfer_work_function=self._connection._reraise_error_in_file_transfer_work_function,
        **kwargs,
    )
file_transfer_agent.execute()
self._init_result_and_meta(file_transfer_agent.result())


class DictCursor(SnowflakeCursor):
Expand Down
2 changes: 1 addition & 1 deletion src/snowflake/connector/platform_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def detect_platforms(
logger.debug(
"No session manager provided. HTTP settings may not be preserved. Using default."
)
session_manager = SessionManager(use_pooling=False)
session_manager = SessionManager(use_pooling=False, max_retries=0)

# Run environment-only checks synchronously (no network calls, no threading overhead)
platforms = {
Expand Down
Loading
Loading