Skip to content

Commit 58c8644

Browse files
[async] apply #2417
1 parent a813cd3 commit 58c8644

File tree

4 files changed

+22
-68
lines changed

4 files changed

+22
-68
lines changed

src/snowflake/connector/aio/_connection.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -348,12 +348,6 @@ async def __open_connection(self):
348348
host=self.host, port=self.port
349349
),
350350
scope=self._oauth_scope,
351-
token_cache=(
352-
auth.get_token_cache()
353-
if self._client_store_temporary_credential
354-
else None
355-
),
356-
refresh_token_enabled=self._oauth_enable_refresh_tokens,
357351
connection=self,
358352
)
359353
elif self._authenticator == PROGRAMMATIC_ACCESS_TOKEN:

src/snowflake/connector/aio/auth/_oauth_credentials.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from ...auth.oauth_credentials import (
99
AuthByOauthCredentials as AuthByOauthCredentialsSync,
1010
)
11-
from ...token_cache import TokenCache
1211
from ._by_plugin import AuthByPlugin as AuthByPluginAsync
1312

1413
if TYPE_CHECKING:
@@ -27,8 +26,6 @@ def __init__(
2726
client_secret: str,
2827
token_request_url: str,
2928
scope: str,
30-
token_cache: TokenCache | None = None,
31-
refresh_token_enabled: bool = False,
3229
connection: SnowflakeConnection | None = None,
3330
**kwargs,
3431
) -> None:
@@ -43,8 +40,6 @@ def __init__(
4340
client_secret=client_secret,
4441
token_request_url=token_request_url,
4542
scope=scope,
46-
token_cache=token_cache,
47-
refresh_token_enabled=refresh_token_enabled,
4843
connection=connection,
4944
**kwargs,
5045
)

test/unit/aio/test_auth_oauth_credentials_async.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ async def test_auth_oauth_credentials():
2121
client_secret="test_client_secret",
2222
token_request_url="https://example.com/token",
2323
scope="session:role:test_role",
24-
refresh_token_enabled=False,
2524
)
2625

2726
body = {"data": {}}

test/unit/aio/test_oauth_token_async.py

Lines changed: 22 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,7 @@ async def test_client_creds_successful_flow_async(
577577
wiremock_client: WiremockClient,
578578
wiremock_oauth_client_creds_dir,
579579
wiremock_generic_mappings_dir,
580+
temp_cache_async,
580581
) -> None:
581582
wiremock_client.import_mapping(
582583
wiremock_oauth_client_creds_dir / "successful_flow.json"
@@ -587,6 +588,15 @@ async def test_client_creds_successful_flow_async(
587588
wiremock_client.add_mapping(
588589
wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
589590
)
591+
user = "testUser"
592+
access_token_key = TokenKey(
593+
user, wiremock_client.wiremock_host, TokenType.OAUTH_ACCESS_TOKEN
594+
)
595+
refresh_token_key = TokenKey(
596+
user, wiremock_client.wiremock_host, TokenType.OAUTH_REFRESH_TOKEN
597+
)
598+
temp_cache_async.store(access_token_key, "unused-access-token-123")
599+
temp_cache_async.store(refresh_token_key, "unused-refresh-token-123")
590600
with mock.patch("secrets.token_urlsafe", return_value="abc123"):
591601
cnx = SnowflakeConnection(
592602
user="testUser",
@@ -599,10 +609,17 @@ async def test_client_creds_successful_flow_async(
599609
oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
600610
host=wiremock_client.wiremock_host,
601611
port=wiremock_client.wiremock_http_port,
612+
oauth_enable_refresh_tokens=True,
613+
client_store_temporary_credential=True,
602614
)
603615

604616
await cnx.connect()
605617
await cnx.close()
618+
# cached tokens are expected not to change since Client Credentials must not use token cache
619+
cached_access_token = temp_cache_async.retrieve(access_token_key)
620+
cached_refresh_token = temp_cache_async.retrieve(refresh_token_key)
621+
assert cached_access_token == "unused-access-token-123"
622+
assert cached_refresh_token == "unused-refresh-token-123"
606623

607624

608625
@pytest.mark.skipolddriver
@@ -643,57 +660,6 @@ async def test_client_creds_token_request_error_async(
643660
)
644661

645662

646-
@pytest.mark.skipolddriver
647-
async def test_client_creds_successful_refresh_token_flow_async(
648-
wiremock_client: WiremockClient,
649-
wiremock_oauth_refresh_token_dir,
650-
wiremock_generic_mappings_dir,
651-
temp_cache_async,
652-
) -> None:
653-
wiremock_client.import_mapping(
654-
wiremock_generic_mappings_dir / "snowflake_login_failed.json"
655-
)
656-
wiremock_client.add_mapping(
657-
wiremock_oauth_refresh_token_dir / "refresh_successful.json"
658-
)
659-
wiremock_client.add_mapping(
660-
wiremock_generic_mappings_dir / "snowflake_login_successful.json"
661-
)
662-
wiremock_client.add_mapping(
663-
wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
664-
)
665-
user = "testUser"
666-
access_token_key = TokenKey(
667-
user, wiremock_client.wiremock_host, TokenType.OAUTH_ACCESS_TOKEN
668-
)
669-
refresh_token_key = TokenKey(
670-
user, wiremock_client.wiremock_host, TokenType.OAUTH_REFRESH_TOKEN
671-
)
672-
temp_cache_async.store(access_token_key, "expired-access-token-123")
673-
temp_cache_async.store(refresh_token_key, "refresh-token-123")
674-
cnx = SnowflakeConnection(
675-
user=user,
676-
authenticator="OAUTH_CLIENT_CREDENTIALS",
677-
oauth_client_id="123",
678-
account="testAccount",
679-
protocol="http",
680-
role="ANALYST",
681-
oauth_client_secret="testClientSecret",
682-
oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
683-
host=wiremock_client.wiremock_host,
684-
port=wiremock_client.wiremock_http_port,
685-
oauth_enable_refresh_tokens=True,
686-
client_store_temporary_credential=True,
687-
)
688-
await cnx.connect()
689-
await cnx.close()
690-
691-
new_access_token = temp_cache_async.retrieve(access_token_key)
692-
new_refresh_token = temp_cache_async.retrieve(refresh_token_key)
693-
assert new_access_token == "access-token-123"
694-
assert new_refresh_token == "refresh-token-123"
695-
696-
697663
@pytest.mark.skipolddriver
698664
async def test_client_creds_expired_refresh_token_flow_async(
699665
wiremock_client: WiremockClient,
@@ -744,8 +710,8 @@ async def test_client_creds_expired_refresh_token_flow_async(
744710
)
745711
await cnx.connect()
746712
await cnx.close()
747-
748-
new_access_token = temp_cache_async.retrieve(access_token_key)
749-
new_refresh_token = temp_cache_async.retrieve(refresh_token_key)
750-
assert new_access_token == "access-token-123"
751-
assert new_refresh_token == "refresh-token-123"
713+
# the cache state is expected not to change, since Client Credentials must not use token caching
714+
cached_access_token = temp_cache_async.retrieve(access_token_key)
715+
cached_refresh_token = temp_cache_async.retrieve(refresh_token_key)
716+
assert cached_access_token == "expired-access-token-123"
717+
assert cached_refresh_token == "expired-refresh-token-123"

0 commit comments

Comments
 (0)