Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ repos:
hooks:
- id: pyupgrade
args: [--py38-plus]
language_version: python3.13
- repo: local
hooks:
- id: check-no-native-http
Expand Down
115 changes: 86 additions & 29 deletions src/snowflake/connector/aio/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
SQLSTATE_CONNECTION_REJECTED,
SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED,
)
from ..time_util import TimeoutBackoffCtx
from ..time_util import DEFAULT_MASTER_VALIDITY_IN_SECONDS, TimeoutBackoffCtx
from ._description import CLIENT_NAME
from ._session_manager import (
SessionManager,
Expand Down Expand Up @@ -143,7 +143,7 @@ def __init__(
session_manager: SessionManager | None = None,
):
super().__init__(host, port, protocol, inject_client_pause, connection)
self._lock_token = asyncio.Lock()
self._token_async_lock = asyncio.Lock()

if session_manager is None:
session_manager = (
Expand All @@ -155,16 +155,53 @@ def __init__(
)
self._session_manager = session_manager

async def close(self) -> None:
if hasattr(self, "_token"):
del self._token
if hasattr(self, "_master_token"):
del self._master_token
if hasattr(self, "_id_token"):
del self._id_token
if hasattr(self, "_mfa_token"):
del self._mfa_token
@property
def id_token(self):
return super().id_token

@id_token.setter
def id_token(self, value) -> None:
raise TypeError("Use set_id_token_async() in async connections.")

@property
def mfa_token(self) -> str | None:
return super().mfa_token

@mfa_token.setter
def mfa_token(self, value: str) -> None:
raise TypeError("Use set_mfa_token_async() in async connections.")

@property
def master_validity_in_seconds(self) -> int:
return super().master_validity_in_seconds

@master_validity_in_seconds.setter
def master_validity_in_seconds(self, value) -> None:
raise TypeError(
"Use set_master_validity_in_seconds_async() in async connections."
)

async def set_id_token_async(self, value) -> None:
async with self._token_async_lock:
with self._lock_token:
self._token_state = self._get_token_state().copy(id_token=value)

async def set_mfa_token_async(self, value: str) -> None:
async with self._token_async_lock:
with self._lock_token:
self._token_state = self._get_token_state().copy(mfa_token=value)

async def set_master_validity_in_seconds_async(self, value) -> None:
async with self._token_async_lock:
with self._lock_token:
target = value if value else DEFAULT_MASTER_VALIDITY_IN_SECONDS
self._token_state = self._get_token_state().copy(
master_validity_in_seconds=target
)

async def close(self) -> None:
async with self._token_async_lock:
self._remove_token_state()
await self._session_manager.close()

async def request(
Expand All @@ -182,7 +219,8 @@ async def request(
logger.debug("%s %s", method.upper(), url)
if body is None:
body = {}
if self.master_token is None and self.token is None:
state = self._get_token_state()
if state.master_token is None and state.session_token is None:
Error.errorhandler_wrapper(
self._connection,
None,
Expand Down Expand Up @@ -225,7 +263,7 @@ async def request(
url,
headers,
json.dumps(body, cls=SnowflakeRestfulJsonEncoder),
token=self.token,
token=state.session_token,
_no_results=_no_results,
timeout=timeout,
_include_retry_params=_include_retry_params,
Expand All @@ -235,10 +273,15 @@ async def request(
return await self._get_request(
url,
headers,
token=self.token,
token=state.session_token,
timeout=timeout,
)

# TODO(future): Decide legacy vs new token flow and serialization model. Current gaps:
# - Legacy consumers/tests still read connection._token (and friends) which are set at init, but not kept in sync after renewals; either mirror updates here (and delete on close) or audit/remove those reads to rely solely on _TokenState-backed properties.
# - Mutations are serialized via _token_async_lock + _lock_token; sync setters are disabled to avoid blocking the loop, but if we want to keep them, we need a coherent locking story.
# - The race we aim to avoid: mixed token snapshots across concurrent requests/renew/close; consider passing token snapshots via a context var instead of shared mutable state.
# - Post-close: we currently allow recreating state via _get_token_state(); if we prefer hard-fail after close, add an explicit closed guard instead of AttributeError.
async def update_tokens(
self,
session_token,
Expand All @@ -248,21 +291,26 @@ async def update_tokens(
mfa_token=None,
) -> None:
"""Updates session and master tokens and optionally temporary credential."""
async with self._lock_token:
self._token = session_token
self._master_token = master_token
self._id_token = id_token
self._mfa_token = mfa_token
self._master_validity_in_seconds = master_validity_in_seconds
async with self._token_async_lock:
with self._lock_token:
new_state = self._get_token_state().copy(
session_token=session_token,
master_token=master_token,
master_validity_in_seconds=master_validity_in_seconds,
id_token=id_token,
mfa_token=mfa_token,
)
self._token_state = new_state

async def _renew_session(self):
"""Renew a session and master token."""
return await self._token_request(REQUEST_TYPE_RENEW)

async def _token_request(self, request_type):
state = self._get_token_state()
logger.debug(
"updating session. master_token: {}".format(
"****" if self.master_token else None
"****" if state.master_token else None
)
)
headers = {
Expand All @@ -278,9 +326,9 @@ async def _token_request(self, request_type):

# NOTE: ensure an empty key if master token is not set.
# This avoids HTTP 400.
header_token = self.master_token or ""
header_token = state.master_token or ""
body = {
"oldSessionToken": self.token,
"oldSessionToken": state.session_token,
"requestType": request_type,
}
ret = await self._post_request(
Expand Down Expand Up @@ -331,6 +379,7 @@ async def _token_request(self, request_type):
)

async def _heartbeat(self) -> Any | dict[Any, Any] | None:
state = self._get_token_state()
headers = {
HTTP_HEADER_CONTENT_TYPE: CONTENT_TYPE_APPLICATION_JSON,
HTTP_HEADER_ACCEPT: CONTENT_TYPE_APPLICATION_JSON,
Expand All @@ -345,15 +394,16 @@ async def _heartbeat(self) -> Any | dict[Any, Any] | None:
url,
headers,
None,
token=self.token,
token=state.session_token,
)
if not ret.get("success"):
logger.error("Failed to heartbeat. code: %s, url: %s", ret.get("code"), url)
return ret

async def delete_session(self, retry: bool = False) -> None:
"""Deletes the session."""
if self.master_token is None:
state = self._get_token_state()
if state.master_token is None:
Error.errorhandler_wrapper(
self._connection,
None,
Expand Down Expand Up @@ -385,7 +435,7 @@ async def delete_session(self, retry: bool = False) -> None:
url,
headers,
json.dumps(body, cls=SnowflakeRestfulJsonEncoder),
token=self.token,
token=state.session_token,
timeout=5,
no_retry=True,
)
Expand Down Expand Up @@ -441,10 +491,11 @@ async def _get_request(
)
)
if ret.get("success"):
refreshed_state = self._get_token_state()
return await self._get_request(
url,
headers,
token=self.token,
token=refreshed_state.session_token,
is_fetch_query_status=is_fetch_query_status,
)

Expand Down Expand Up @@ -499,8 +550,13 @@ async def _post_request(
)
)
if ret.get("success"):
refreshed_state = self._get_token_state()
return await self._post_request(
url, headers, body, token=self.token, timeout=timeout
url,
headers,
body,
token=refreshed_state.session_token,
timeout=timeout,
)

if isinstance(ret.get("data"), dict) and ret["data"].get("queryId"):
Expand All @@ -516,10 +572,11 @@ async def _post_request(
# ping pong
result_url = ret["data"]["getResultUrl"]
logger.debug("ping pong starting...")
refreshed_state = self._get_token_state()
ret = await self._get_request(
result_url,
headers,
token=self.token,
token=refreshed_state.session_token,
timeout=timeout,
is_fetch_query_status=bool(
re.match(r"^/queries/.+/result$", result_url)
Expand Down
Loading
Loading