Skip to content

Commit 9ca88cc

Browse files
[async] Fixed #2429 and #2568:
conn-rest-conn -> conn and made http_config passed from async to base_auth_data. Fix - limited outgoing requests from tests Fixed not awaited coroutine error Fixed old mistakes in _okta.py - conn.rest.connection -> conn Fixed errors with wif tests Fixed errors with okta auth step4_negative test case Fixed errors with sessionManager runtime import and None value Fixed errors with no session manager in workload_identity in tests Fixed Connection closed issue
1 parent cf2a731 commit 9ca88cc

18 files changed

+143
-86
lines changed

src/snowflake/connector/aio/_network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -849,7 +849,7 @@ async def _request_exec(
849849
) from err
850850

851851
@contextlib.asynccontextmanager
852-
async def _use_session(
852+
async def use_session(
853853
self, url: str | None = None
854854
) -> AsyncGenerator[aiohttp.ClientSession]:
855855
async with self._session_manager.use_session(url) as session:

src/snowflake/connector/aio/_s3_storage_client.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -425,8 +425,27 @@ async def _has_expired_token(self, response: aiohttp.ClientResponse) -> bool:
425425
"""
426426
if response.status != 400:
427427
return False
428-
message = await response.text()
428+
# Read body once; avoid a second read which can raise RuntimeError("Connection closed.")
429+
try:
430+
message = await response.text()
431+
except RuntimeError as e:
432+
logger.debug(
433+
"S3 token-expiry check: failed to read error body, treating as not expired. error=%s",
434+
type(e),
435+
)
436+
return False
429437
if not message:
438+
logger.debug(
439+
"S3 token-expiry check: empty error body, treating as not expired"
440+
)
441+
return False
442+
try:
443+
err = ET.fromstring(message)
444+
except ET.ParseError:
445+
logger.debug(
446+
"S3 token-expiry check: non-XML error body (len=%d), treating as not expired.",
447+
len(message),
448+
)
430449
return False
431-
err = ET.fromstring(await response.read())
432-
return err.find("Code").text == EXPIRED_TOKEN
450+
code = err.find("Code")
451+
return code is not None and code.text == EXPIRED_TOKEN

src/snowflake/connector/aio/_wif_util.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import logging
55
import os
66
from base64 import b64encode
7-
from typing import TYPE_CHECKING
87

98
import aioboto3
109
from aiobotocore.utils import AioInstanceMetadataRegionFetcher
@@ -22,9 +21,7 @@
2221
extract_iss_and_sub_without_signature_verification,
2322
get_aws_sts_hostname,
2423
)
25-
26-
if TYPE_CHECKING:
27-
from ._session_manager import SessionManager
24+
from ._session_manager import SessionManager
2825

2926
logger = logging.getLogger(__name__)
3027

@@ -175,6 +172,9 @@ async def create_attestation(
175172
If an explicit entra_resource was provided to the connector, this will be used. Otherwise, the default Snowflake Entra resource will be used.
176173
"""
177174
entra_resource = entra_resource or DEFAULT_ENTRA_SNOWFLAKE_RESOURCE
175+
session_manager = (
176+
session_manager.clone() if session_manager else SessionManager(use_pooling=True)
177+
)
178178

179179
if provider == AttestationProvider.AWS:
180180
return await create_aws_attestation()

src/snowflake/connector/aio/auth/_auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ async def authenticate(
103103
self._rest._connection._network_timeout,
104104
self._rest._connection._socket_timeout,
105105
self._rest._connection._platform_detection_timeout_seconds,
106-
session_manager=self._rest.session_manager.clone(use_pooling=False),
106+
http_config=self._rest.session_manager.config, # AioHttpConfig extends BaseHttpConfig
107107
)
108108

109109
body = copy.deepcopy(body_template)

src/snowflake/connector/aio/auth/_okta.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ async def _step1(
123123
conn._ocsp_mode(),
124124
conn.login_timeout,
125125
conn._network_timeout,
126+
http_config=conn._session_manager.config, # AioHttpConfig extends BaseHttpConfig
126127
)
127128

128129
body["data"]["AUTHENTICATOR"] = authenticator
@@ -131,12 +132,12 @@ async def _step1(
131132
account,
132133
authenticator,
133134
)
134-
ret = await conn._rest._post_request(
135+
ret = await conn.rest._post_request(
135136
url,
136137
headers,
137138
json.dumps(body),
138-
timeout=conn._rest._connection.login_timeout,
139-
socket_timeout=conn._rest._connection.login_timeout,
139+
timeout=conn.login_timeout,
140+
socket_timeout=conn.login_timeout,
140141
)
141142

142143
if not ret["success"]:
@@ -171,19 +172,19 @@ async def _step3(
171172
"username": user,
172173
"password": password,
173174
}
174-
ret = await conn._rest.fetch(
175+
ret = await conn.rest.fetch(
175176
"post",
176177
token_url,
177178
headers,
178179
data=json.dumps(data),
179-
timeout=conn._rest._connection.login_timeout,
180-
socket_timeout=conn._rest._connection.login_timeout,
180+
timeout=conn.login_timeout,
181+
socket_timeout=conn.login_timeout,
181182
catch_okta_unauthorized_error=True,
182183
)
183184
one_time_token = ret.get("sessionToken", ret.get("cookieToken"))
184185
if not one_time_token:
185186
Error.errorhandler_wrapper(
186-
conn._rest._connection,
187+
conn,
187188
None,
188189
DatabaseError,
189190
{
@@ -221,7 +222,7 @@ async def _step4(
221222
HTTP_HEADER_ACCEPT: "*/*",
222223
}
223224
remaining_timeout = timeout_time - time.time() if timeout_time else None
224-
response_html = await conn._rest.fetch(
225+
response_html = await conn.rest.fetch(
225226
"get",
226227
sso_url,
227228
headers,

src/snowflake/connector/aio/auth/_webbrowser.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -365,13 +365,13 @@ async def _get_sso_url(
365365
body = Auth.base_auth_data(
366366
user,
367367
account,
368-
conn._rest._connection.application,
369-
conn._rest._connection._internal_application_name,
370-
conn._rest._connection._internal_application_version,
371-
conn._rest._connection._ocsp_mode(),
372-
conn._rest._connection.login_timeout,
373-
conn._rest._connection._network_timeout,
374-
session_manager=conn.rest.session_manager.clone(use_pooling=False),
368+
conn.application,
369+
conn._internal_application_name,
370+
conn._internal_application_version,
371+
conn._ocsp_mode(),
372+
conn.login_timeout,
373+
conn._network_timeout,
374+
http_config=conn._session_manager.config, # AioHttpConfig extends BaseHttpConfig
375375
)
376376

377377
body["data"]["AUTHENTICATOR"] = authenticator
@@ -383,8 +383,8 @@ async def _get_sso_url(
383383
url,
384384
headers,
385385
json.dumps(body),
386-
timeout=conn._rest._connection.login_timeout,
387-
socket_timeout=conn._rest._connection.login_timeout,
386+
timeout=conn.login_timeout,
387+
socket_timeout=conn.login_timeout,
388388
)
389389
if not ret["success"]:
390390
await self._handle_failure(conn=conn, ret=ret)

src/snowflake/connector/auth/_auth.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@
5353
ReauthenticationRequest,
5454
)
5555
from ..platform_detection import detect_platforms
56-
from ..session_manager import SessionManager
56+
from ..session_manager import BaseHttpConfig, HttpConfig
57+
from ..session_manager import SessionManager as SyncSessionManager
5758
from ..sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED
5859
from ..token_cache import TokenCache, TokenKey, TokenType
5960
from ..version import VERSION
@@ -104,8 +105,17 @@ def base_auth_data(
104105
network_timeout: int | None = None,
105106
socket_timeout: int | None = None,
106107
platform_detection_timeout_seconds: float | None = None,
107-
session_manager: SessionManager | None = None,
108+
session_manager: SyncSessionManager | None = None,
109+
http_config: BaseHttpConfig | None = None,
108110
):
111+
# Create sync SessionManager for platform detection if config is provided
112+
# Platform detection runs in threads and uses sync SessionManager
113+
if http_config is not None and session_manager is None:
114+
# Extract base fields (automatically excludes subclass-specific fields)
115+
# Note: It won't be possible to pass adapter_factory from outer async-code to this part of code
116+
sync_config = HttpConfig(**http_config.to_base_dict())
117+
session_manager = SyncSessionManager(config=sync_config)
118+
109119
return {
110120
"data": {
111121
"CLIENT_APP_ID": internal_application_name,

src/snowflake/connector/platform_detection.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import logging
34
import os
45
import re
56
from concurrent.futures.thread import ThreadPoolExecutor
@@ -13,6 +14,8 @@
1314
from .session_manager import SessionManager
1415
from .vendored.requests import RequestException, Timeout
1516

17+
logger = logging.getLogger(__name__)
18+
1619

1720
class _DetectionState(Enum):
1821
"""Internal enum to represent the detection state of a platform."""
@@ -399,6 +402,9 @@ def detect_platforms(
399402

400403
if session_manager is None:
401404
# This should never happen - we expect session manager to be passed from the outer scope
405+
logger.debug(
406+
"No session manager provided. HTTP settings may not be preserved. Using default."
407+
)
402408
session_manager = SessionManager(use_pooling=False)
403409

404410
# Run environment-only checks synchronously (no network calls, no threading overhead)

src/snowflake/connector/session_manager.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import functools
77
import itertools
88
import logging
9-
from dataclasses import dataclass, field, replace
9+
from dataclasses import asdict, dataclass, field, fields, replace
1010
from typing import TYPE_CHECKING, Any, Callable, Generator, Generic, Mapping, TypeVar
1111

1212
from .compat import urlparse
@@ -117,6 +117,11 @@ def copy_with(self, **overrides: Any) -> BaseHttpConfig:
117117
"""Return a new config with overrides applied."""
118118
return replace(self, **overrides)
119119

120+
def to_base_dict(self) -> dict[str, Any]:
121+
"""Extract only BaseHttpConfig fields as a dict, excluding subclass-specific fields."""
122+
base_field_names = {f.name for f in fields(BaseHttpConfig)}
123+
return {k: v for k, v in asdict(self).items() if k in base_field_names}
124+
120125

121126
@dataclass(frozen=True)
122127
class HttpConfig(BaseHttpConfig):

test/integ/aio_it/pandas_it/test_arrow_pandas_async.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,8 +1402,8 @@ async def test_sessions_used(conn_cnx, fetch_fn_name, pass_connection):
14021402

14031403
# check that sessions are used when connection is supplied
14041404
with mock.patch(
1405-
"snowflake.connector.aio._network.SnowflakeRestful._use_session",
1406-
side_effect=cnx._rest._use_session,
1405+
"snowflake.connector.aio._network.SnowflakeRestful.use_session",
1406+
side_effect=cnx._rest.use_session,
14071407
) as get_session_mock:
14081408
await fetch_fn(connection=connection)
14091409
assert get_session_mock.call_count == (1 if pass_connection else 0)

0 commit comments

Comments
 (0)