Skip to content

Commit 9e04925

Browse files
sfc-gh-kwagnersfc-gh-alingsfc-gh-mkeller
authored
SNOW-630142 Custom Auth (#1215)
Co-authored-by: Adam Ling <[email protected]> Co-authored-by: Mark Keller <[email protected]>
1 parent bd39dc7 commit 9e04925

21 files changed

+948
-430
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#
2+
# Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved.
3+
#
4+
5+
from __future__ import annotations
6+
7+
from ._auth import Auth, get_public_key_fingerprint, get_token_from_private_key
8+
from .by_plugin import AuthByPlugin, AuthType
9+
from .default import AuthByDefault
10+
from .keypair import AuthByKeyPair
11+
from .oauth import AuthByOAuth
12+
from .okta import AuthByOkta
13+
from .usrpwdmfa import AuthByUsrPwdMfa
14+
from .webbrowser import AuthByWebBrowser
15+
16+
FIRST_PARTY_AUTHENTICATORS = frozenset(
17+
(
18+
AuthByDefault,
19+
AuthByKeyPair,
20+
AuthByOAuth,
21+
AuthByOkta,
22+
AuthByUsrPwdMfa,
23+
AuthByWebBrowser,
24+
)
25+
)
26+
27+
__all__ = [
28+
"AuthByPlugin",
29+
"AuthByDefault",
30+
"AuthByKeyPair",
31+
"AuthByOAuth",
32+
"AuthByOkta",
33+
"AuthByUsrPwdMfa",
34+
"AuthByWebBrowser",
35+
"Auth",
36+
"AuthType",
37+
"FIRST_PARTY_AUTHENTICATORS",
38+
"get_public_key_fingerprint",
39+
"get_token_from_private_key",
40+
]

src/snowflake/connector/auth.py renamed to src/snowflake/connector/auth/_auth.py

Lines changed: 81 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#!/usr/bin/env python
21
#
32
# Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved.
43
#
@@ -16,6 +15,7 @@
1615
from os import getenv, makedirs, mkdir, path, remove, removedirs, rmdir
1716
from os.path import expanduser
1817
from threading import Lock, Thread
18+
from typing import TYPE_CHECKING, Any, Callable
1919

2020
from cryptography.hazmat.backends import default_backend
2121
from cryptography.hazmat.primitives.serialization import (
@@ -26,44 +26,45 @@
2626
load_pem_private_key,
2727
)
2828

29-
from .auth_keypair import AuthByKeyPair
30-
from .auth_usrpwdmfa import AuthByUsrPwdMfa
31-
from .compat import IS_LINUX, IS_MACOS, IS_WINDOWS, urlencode
32-
from .constants import (
29+
from ..compat import IS_LINUX, IS_MACOS, IS_WINDOWS, urlencode
30+
from ..constants import (
31+
DAY_IN_SECONDS,
3332
HTTP_HEADER_ACCEPT,
3433
HTTP_HEADER_CONTENT_TYPE,
3534
HTTP_HEADER_SERVICE_NAME,
3635
HTTP_HEADER_USER_AGENT,
3736
PARAMETER_CLIENT_REQUEST_MFA_TOKEN,
3837
PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL,
3938
)
40-
from .description import (
39+
from ..description import (
4140
COMPILER,
4241
IMPLEMENTATION,
4342
OPERATING_SYSTEM,
4443
PLATFORM,
4544
PYTHON_VERSION,
4645
)
47-
from .errorcode import ER_FAILED_TO_CONNECT_TO_DB
48-
from .errors import (
46+
from ..errorcode import ER_FAILED_TO_CONNECT_TO_DB
47+
from ..errors import (
4948
BadGatewayError,
5049
DatabaseError,
5150
Error,
5251
ForbiddenError,
5352
ProgrammingError,
5453
ServiceUnavailableError,
5554
)
56-
from .network import (
55+
from ..network import (
5756
ACCEPT_TYPE_APPLICATION_SNOWFLAKE,
5857
CONTENT_TYPE_APPLICATION_JSON,
5958
ID_TOKEN_INVALID_LOGIN_REQUEST_GS_CODE,
60-
KEY_PAIR_AUTHENTICATOR,
6159
PYTHON_CONNECTOR_USER_AGENT,
6260
ReauthenticationRequest,
6361
)
64-
from .options import installed_keyring, keyring
65-
from .sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED
66-
from .version import VERSION
62+
from ..options import installed_keyring, keyring
63+
from ..sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED
64+
from ..version import VERSION
65+
66+
if TYPE_CHECKING:
67+
from . import AuthByPlugin
6768

6869
logger = logging.getLogger(__name__)
6970

@@ -153,19 +154,19 @@ def base_auth_data(
153154

154155
def authenticate(
155156
self,
156-
auth_instance,
157-
account,
158-
user,
159-
database=None,
160-
schema=None,
161-
warehouse=None,
162-
role=None,
163-
passcode=None,
164-
passcode_in_password=False,
165-
mfa_callback=None,
166-
password_callback=None,
167-
session_parameters=None,
168-
timeout=120,
157+
auth_instance: AuthByPlugin,
158+
account: str,
159+
user: str,
160+
database: str | None = None,
161+
schema: str | None = None,
162+
warehouse: str | None = None,
163+
role: str | None = None,
164+
passcode: str | None = None,
165+
passcode_in_password: bool = False,
166+
mfa_callback: Callable[[], None] | None = None,
167+
password_callback: Callable[[], str] | None = None,
168+
session_parameters: dict[Any, Any] | None = None,
169+
timeout: int = 120,
169170
) -> dict[str, str | int | bool]:
170171
logger.debug("authenticate")
171172

@@ -242,15 +243,7 @@ def authenticate(
242243
# login_timeout comes from user configuration.
243244
# Between login timeout and auth specific
244245
# timeout use whichever value is smaller
245-
if hasattr(auth_instance, "get_timeout"):
246-
logger.debug(
247-
f"Authenticator, {type(auth_instance).__name__}, implements get_timeout"
248-
)
249-
auth_timeout = min(
250-
self._rest._connection.login_timeout, auth_instance.get_timeout()
251-
)
252-
else:
253-
auth_timeout = self._rest._connection.login_timeout
246+
auth_timeout = min(self._rest._connection.login_timeout, auth_instance.timeout)
254247
logger.debug(f"Timeout set to {auth_timeout}")
255248

256249
try:
@@ -386,15 +379,19 @@ def post_request_wrapper(self, url, headers, body):
386379
)
387380
)
388381

389-
if type(auth_instance) is AuthByKeyPair:
382+
from . import AuthByKeyPair
383+
384+
if isinstance(auth_instance, AuthByKeyPair):
390385
logger.debug(
391386
"JWT Token authentication failed. "
392387
"Token expires at: %s. "
393388
"Current Time: %s",
394389
str(auth_instance._jwt_token_exp),
395390
str(datetime.utcnow()),
396391
)
397-
if type(auth_instance) is AuthByUsrPwdMfa:
392+
from . import AuthByUsrPwdMfa
393+
394+
if isinstance(auth_instance, AuthByUsrPwdMfa):
398395
delete_temporary_credential(self._rest._host, user, MFA_TOKEN)
399396
Error.errorhandler_wrapper(
400397
self._rest._connection,
@@ -483,16 +480,33 @@ def _read_temporary_credential(self, host, user, cred_type):
483480
logger.debug("OS not supported for Local Secure Storage")
484481
return cred
485482

486-
def read_temporary_credentials(self, host, user, session_parameters):
483+
def read_temporary_credentials(
484+
self,
485+
host: str,
486+
user: str,
487+
session_parameters: dict[str, Any],
488+
) -> None:
487489
if session_parameters.get(PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL, False):
488-
self._rest.id_token = self._read_temporary_credential(host, user, ID_TOKEN)
490+
self._rest.id_token = self._read_temporary_credential(
491+
host,
492+
user,
493+
ID_TOKEN,
494+
)
489495

490496
if session_parameters.get(PARAMETER_CLIENT_REQUEST_MFA_TOKEN, False):
491497
self._rest.mfa_token = self._read_temporary_credential(
492-
host, user, MFA_TOKEN
498+
host,
499+
user,
500+
MFA_TOKEN,
493501
)
494502

495-
def _write_temporary_credential(self, host, user, cred_type, cred):
503+
def _write_temporary_credential(
504+
self,
505+
host: str,
506+
user: str,
507+
cred_type: str,
508+
cred: str | None,
509+
) -> None:
496510
if not cred:
497511
logger.debug(
498512
"no credential is given when try to store temporary credential"
@@ -522,9 +536,18 @@ def _write_temporary_credential(self, host, user, cred_type, cred):
522536
else:
523537
logger.debug("OS not supported for Local Secure Storage")
524538

525-
def write_temporary_credentials(self, host, user, session_parameters, response):
526-
if self._rest._connection.consent_cache_id_token and session_parameters.get(
527-
PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL, False
539+
def write_temporary_credentials(
540+
self,
541+
host: str,
542+
user: str,
543+
session_parameters: dict[str, Any],
544+
response: dict[str, Any],
545+
) -> None:
546+
if (
547+
self._rest._connection.auth_class.consent_cache_id_token
548+
and session_parameters.get(
549+
PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL, False
550+
)
528551
):
529552
self._write_temporary_credential(
530553
host, user, ID_TOKEN, response["data"].get("idToken")
@@ -534,10 +557,9 @@ def write_temporary_credentials(self, host, user, session_parameters, response):
534557
self._write_temporary_credential(
535558
host, user, MFA_TOKEN, response["data"].get("mfaToken")
536559
)
537-
return
538560

539561

540-
def flush_temporary_credentials():
562+
def flush_temporary_credentials() -> None:
541563
"""Flush temporary credentials in memory into disk. Need to hold TEMPORARY_CREDENTIAL_LOCK."""
542564
global TEMPORARY_CREDENTIAL
543565
global TEMPORARY_CREDENTIAL_FILE
@@ -566,7 +588,7 @@ def flush_temporary_credentials():
566588
unlock_temporary_credential_file()
567589

568590

569-
def write_temporary_credential_file(host, cred_name, cred):
591+
def write_temporary_credential_file(host, cred_name, cred) -> None:
570592
"""Writes temporary credential file when OS is Linux."""
571593
if not CACHE_DIR:
572594
# no cache is enabled
@@ -581,7 +603,7 @@ def write_temporary_credential_file(host, cred_name, cred):
581603
flush_temporary_credentials()
582604

583605

584-
def read_temporary_credential_file():
606+
def read_temporary_credential_file() -> None:
585607
"""Reads temporary credential file when OS is Linux."""
586608
if not CACHE_DIR:
587609
# no cache is enabled
@@ -616,10 +638,9 @@ def read_temporary_credential_file():
616638
)
617639
finally:
618640
unlock_temporary_credential_file()
619-
return None
620641

621642

622-
def lock_temporary_credential_file():
643+
def lock_temporary_credential_file() -> bool:
623644
global TEMPORARY_CREDENTIAL_FILE_LOCK
624645
try:
625646
mkdir(TEMPORARY_CREDENTIAL_FILE_LOCK)
@@ -632,7 +653,7 @@ def lock_temporary_credential_file():
632653
return False
633654

634655

635-
def unlock_temporary_credential_file():
656+
def unlock_temporary_credential_file() -> bool:
636657
global TEMPORARY_CREDENTIAL_FILE_LOCK
637658
try:
638659
rmdir(TEMPORARY_CREDENTIAL_FILE_LOCK)
@@ -709,10 +730,13 @@ def get_token_from_private_key(
709730
format=PrivateFormat.PKCS8,
710731
encryption_algorithm=NoEncryption(),
711732
)
712-
auth_instance = AuthByKeyPair(private_key, 1440 * 60) # token valid for 24 hours
713-
return auth_instance.authenticate(
714-
KEY_PAIR_AUTHENTICATOR, None, account, user, key_password
715-
)
733+
from . import AuthByKeyPair
734+
735+
auth_instance = AuthByKeyPair(
736+
private_key,
737+
DAY_IN_SECONDS,
738+
) # token valid for 24 hours
739+
return auth_instance.prepare(account=account, user=user)
716740

717741

718742
def get_public_key_fingerprint(private_key_file: str, password: str) -> str:
@@ -729,4 +753,6 @@ def get_public_key_fingerprint(private_key_file: str, password: str) -> str:
729753
private_key = load_der_private_key(
730754
data=private_key, password=None, backend=default_backend()
731755
)
756+
from . import AuthByKeyPair
757+
732758
return AuthByKeyPair.calculate_public_key_fingerprint(private_key)

0 commit comments

Comments
 (0)