Skip to content

Commit b1cecd2

Browse files
Add PAT to async authentication (#2122)
1 parent 584dc94 commit b1cecd2

File tree

4 files changed

+108
-0
lines changed

4 files changed

+108
-0
lines changed

src/snowflake/connector/aio/_connection.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
EXTERNAL_BROWSER_AUTHENTICATOR,
6262
KEY_PAIR_AUTHENTICATOR,
6363
OAUTH_AUTHENTICATOR,
64+
PROGRAMMATIC_ACCESS_TOKEN,
6465
REQUEST_ID,
6566
USR_PWD_MFA_AUTHENTICATOR,
6667
ReauthenticationRequest,
@@ -82,6 +83,7 @@
8283
AuthByKeyPair,
8384
AuthByOAuth,
8485
AuthByOkta,
86+
AuthByPAT,
8587
AuthByPlugin,
8688
AuthByUsrPwdMfa,
8789
AuthByWebBrowser,
@@ -298,6 +300,8 @@ async def __open_connection(self):
298300
timeout=self.login_timeout,
299301
backoff_generator=self._backoff_generator,
300302
)
303+
elif self._authenticator == PROGRAMMATIC_ACCESS_TOKEN:
304+
self.auth_class = AuthByPAT(self._token)
301305
elif self._authenticator == USR_PWD_MFA_AUTHENTICATOR:
302306
self._session_parameters[PARAMETER_CLIENT_REQUEST_MFA_TOKEN] = (
303307
self._client_request_mfa_token if IS_LINUX else True

src/snowflake/connector/aio/auth/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ._keypair import AuthByKeyPair
1313
from ._oauth import AuthByOAuth
1414
from ._okta import AuthByOkta
15+
from ._pat import AuthByPAT
1516
from ._usrpwdmfa import AuthByUsrPwdMfa
1617
from ._webbrowser import AuthByWebBrowser
1718

@@ -24,13 +25,15 @@
2425
AuthByUsrPwdMfa,
2526
AuthByWebBrowser,
2627
AuthByIdToken,
28+
AuthByPAT,
2729
)
2830
)
2931

3032
__all__ = [
3133
"AuthByPlugin",
3234
"AuthByDefault",
3335
"AuthByKeyPair",
36+
"AuthByPAT",
3437
"AuthByOAuth",
3538
"AuthByOkta",
3639
"AuthByUsrPwdMfa",
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#!/usr/bin/env python
2+
#
3+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
4+
#
5+
6+
from __future__ import annotations
7+
8+
from typing import Any
9+
10+
from ...auth.pat import AuthByPAT as AuthByPATSync
11+
from ._by_plugin import AuthByPlugin as AuthByPluginAsync
12+
13+
14+
class AuthByPAT(AuthByPluginAsync, AuthByPATSync):
15+
def __init__(self, pat_token: str, **kwargs) -> None:
16+
"""Initializes an instance with a PAT Token."""
17+
AuthByPATSync.__init__(self, pat_token, **kwargs)
18+
19+
async def reset_secrets(self) -> None:
20+
AuthByPATSync.reset_secrets(self)
21+
22+
async def prepare(self, **kwargs: Any) -> None:
23+
AuthByPATSync.prepare(self, **kwargs)
24+
25+
async def reauthenticate(self, **kwargs: Any) -> dict[str, bool]:
26+
return AuthByPATSync.reauthenticate(self, **kwargs)
27+
28+
async def update_body(self, body: dict[Any, Any]) -> None:
29+
AuthByPATSync.update_body(self, body)

test/unit/aio/test_auth_pat_async.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#!/usr/bin/env python
2+
#
3+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
4+
#
5+
6+
from __future__ import annotations
7+
8+
from snowflake.connector.aio.auth import AuthByPAT
9+
from snowflake.connector.auth.by_plugin import AuthType
10+
from snowflake.connector.network import PROGRAMMATIC_ACCESS_TOKEN
11+
12+
13+
async def test_auth_pat():
14+
"""Simple test if AuthByPAT class."""
15+
token = "patToken"
16+
auth = AuthByPAT(token)
17+
assert auth.type_ == AuthType.PAT
18+
assert auth.assertion_content == token
19+
body = {"data": {}}
20+
await auth.update_body(body)
21+
assert body["data"]["TOKEN"] == token, body
22+
assert body["data"]["AUTHENTICATOR"] == PROGRAMMATIC_ACCESS_TOKEN, body
23+
24+
await auth.reset_secrets()
25+
assert auth.assertion_content is None
26+
27+
28+
async def test_auth_pat_reauthenticate():
29+
"""Test PAT reauthenticate."""
30+
token = "patToken"
31+
auth = AuthByPAT(token)
32+
result = await auth.reauthenticate()
33+
assert result == {"success": False}
34+
35+
36+
async def test_pat_authenticator_creates_auth_by_pat(monkeypatch):
37+
"""Test that using PROGRAMMATIC_ACCESS_TOKEN authenticator creates AuthByPAT instance."""
38+
import snowflake.connector.aio
39+
from snowflake.connector.aio._network import SnowflakeRestful
40+
41+
# Mock the network request - this prevents actual network calls and connection errors
42+
async def mock_post_request(request, url, headers, json_body, **kwargs):
43+
return {
44+
"success": True,
45+
"message": None,
46+
"data": {
47+
"token": "TOKEN",
48+
"masterToken": "MASTER_TOKEN",
49+
"idToken": None,
50+
"parameters": [{"name": "SERVICE_NAME", "value": "FAKE_SERVICE_NAME"}],
51+
},
52+
}
53+
54+
# Apply the mock using monkeypatch
55+
monkeypatch.setattr(SnowflakeRestful, "_post_request", mock_post_request)
56+
57+
# Create connection with PAT authenticator
58+
conn = snowflake.connector.aio.SnowflakeConnection(
59+
user="user",
60+
account="account",
61+
database="TESTDB",
62+
warehouse="TESTWH",
63+
authenticator=PROGRAMMATIC_ACCESS_TOKEN,
64+
token="test_pat_token",
65+
)
66+
67+
await conn.connect()
68+
69+
# Verify that the auth_class is an instance of AuthByPAT
70+
assert isinstance(conn.auth_class, AuthByPAT)
71+
72+
await conn.close()

0 commit comments

Comments
 (0)