Skip to content

Commit db8e265

Browse files
malthesfc-gh-sfan
andauthored
Add support for specifying 'RSAPublicKey' instance instead of raw bytes (#1477)
* Add support for specifying 'RSAPublicKey' instance instead of raw bytes This can be used to externalize the JWT encoding process. * Add test for private key abstraction layer * Add 'isinstance' check to make sure private key has an expected type * Revert method signature change Note that while this method does not require a private key, the change is inconsequential because we're anyway expecting something that implements a private key at the class level (either bytes or an abstract implementation) * Be more specific in type error message * Add failing test for non-bytes, non-RSAPrivateKey value * Fix linting issues * add changelog * Move cases which are now handled by type testing over to unit test --------- Co-authored-by: sfc-gh-sfan <[email protected]>
1 parent 1380e41 commit db8e265

File tree

5 files changed

+93
-24
lines changed

5 files changed

+93
-24
lines changed

DESCRIPTION.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
88

99
# Release Notes
1010

11+
- v3.1.1(TBD)
12+
13+
- Support `RSAPublicKey` when constructing `AuthByKeyPair` in addition to raw bytes.
14+
1115
- v3.1.0(July 31,2023)
1216

1317
- Added a feature that lets you add connection definitions to the `connections.toml` configuration file. A connection definition refers to a collection of connection parameters, for example, if you wanted to define a connection named `prod``:

src/snowflake/connector/auth/keypair.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -47,17 +47,18 @@ class AuthByKeyPair(AuthByPlugin):
4747

4848
def __init__(
4949
self,
50-
private_key: bytes,
50+
private_key: bytes | RSAPrivateKey,
5151
lifetime_in_seconds: int = LIFETIME,
5252
) -> None:
5353
"""Inits AuthByKeyPair class with private key.
5454
5555
Args:
56-
private_key: a byte array of der formats of private key
56+
private_key: a byte array of der formats of private key, or an
57+
object that implements the `RSAPrivateKey` interface.
5758
lifetime_in_seconds: number of seconds the JWT token will be valid
5859
"""
5960
super().__init__()
60-
self._private_key: bytes | None = private_key
61+
self._private_key: bytes | RSAPrivateKey | None = private_key
6162
self._jwt_token = ""
6263
self._jwt_token_exp = 0
6364
self._lifetime = timedelta(
@@ -102,25 +103,32 @@ def prepare(
102103

103104
now = datetime.utcnow()
104105

105-
try:
106-
private_key = load_der_private_key(
107-
data=self._private_key,
108-
password=None,
109-
backend=default_backend(),
110-
)
111-
except Exception as e:
112-
raise ProgrammingError(
113-
msg=f"Failed to load private key: {e}\nPlease provide a valid "
114-
"unencrypted rsa private key in DER format as bytes object",
115-
errno=ER_INVALID_PRIVATE_KEY,
116-
)
106+
if isinstance(self._private_key, bytes):
107+
try:
108+
private_key = load_der_private_key(
109+
data=self._private_key,
110+
password=None,
111+
backend=default_backend(),
112+
)
113+
except Exception as e:
114+
raise ProgrammingError(
115+
msg=f"Failed to load private key: {e}\nPlease provide a valid "
116+
"unencrypted rsa private key in DER format as bytes object",
117+
errno=ER_INVALID_PRIVATE_KEY,
118+
)
117119

118-
if not isinstance(private_key, RSAPrivateKey):
119-
raise ProgrammingError(
120-
msg=f"Private key type ({private_key.__class__.__name__}) not supported."
121-
"\nPlease provide a valid rsa private key in DER format as bytes "
122-
"object",
123-
errno=ER_INVALID_PRIVATE_KEY,
120+
if not isinstance(private_key, RSAPrivateKey):
121+
raise ProgrammingError(
122+
msg=f"Private key type ({private_key.__class__.__name__}) not supported."
123+
"\nPlease provide a valid rsa private key in DER format as bytes "
124+
"object",
125+
errno=ER_INVALID_PRIVATE_KEY,
126+
)
127+
elif isinstance(self._private_key, RSAPrivateKey):
128+
private_key = self._private_key
129+
else:
130+
raise TypeError(
131+
f"Expected bytes or RSAPrivateKey, got {type(self._private_key)}"
124132
)
125133

126134
public_key_fp = self.calculate_public_key_fingerprint(private_key)

src/snowflake/connector/connection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from typing import Any, Callable, Generator, Iterable, NamedTuple, Sequence
2626
from uuid import UUID
2727

28+
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
29+
2830
from . import errors, proxy
2931
from ._query_context_cache import QueryContextCache
3032
from .auth import (
@@ -146,7 +148,7 @@ def DefaultConverterClass() -> type:
146148
), # network timeout (infinite by default)
147149
"passcode_in_password": (False, bool), # Snowflake MFA
148150
"passcode": (None, (type(None), str)), # Snowflake MFA
149-
"private_key": (None, (type(None), str)),
151+
"private_key": (None, (type(None), str, RSAPrivateKey)),
150152
"token": (None, (type(None), str)), # OAuth or JWT Token
151153
"authenticator": (DEFAULT_AUTHENTICATOR, (type(None), str)),
152154
"mfa_callback": (None, (type(None), Callable)),

test/integ/test_key_pair_authentication.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,8 +248,6 @@ def test_bad_private_key(db_parameters):
248248
)
249249

250250
bad_private_key_test_cases = [
251-
"abcd",
252-
1234,
253251
b"abcd",
254252
dsa_private_key_der,
255253
encrypted_rsa_private_key_der,

test/unit/test_auth_keypair.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from cryptography.hazmat.backends import default_backend
1111
from cryptography.hazmat.primitives import serialization
1212
from cryptography.hazmat.primitives.asymmetric import rsa
13+
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
14+
from cryptography.hazmat.primitives.serialization import load_der_private_key
15+
from pytest import raises
1316

1417
from snowflake.connector.auth import Auth
1518
from snowflake.connector.constants import OCSPMode
@@ -59,6 +62,60 @@ def test_auth_keypair():
5962
assert rest.master_token == "MASTER_TOKEN"
6063

6164

65+
def test_auth_keypair_abc():
66+
"""Simple Key Pair test using abstraction layer."""
67+
private_key_der, public_key_der_encoded = generate_key_pair(2048)
68+
application = "testapplication"
69+
account = "testaccount"
70+
user = "testuser"
71+
72+
private_key = load_der_private_key(
73+
data=private_key_der,
74+
password=None,
75+
backend=default_backend(),
76+
)
77+
78+
assert isinstance(private_key, RSAPrivateKey)
79+
80+
auth_instance = AuthByKeyPair(private_key=private_key)
81+
auth_instance.handle_timeout(
82+
authenticator="SNOWFLAKE_JWT",
83+
service_name=None,
84+
account=account,
85+
user=user,
86+
password=None,
87+
)
88+
89+
# success test case
90+
rest = _init_rest(application, _create_mock_auth_keypair_rest_response())
91+
auth = Auth(rest)
92+
auth.authenticate(auth_instance, account, user)
93+
assert not rest._connection.errorhandler.called # not error
94+
assert rest.token == "TOKEN"
95+
assert rest.master_token == "MASTER_TOKEN"
96+
97+
98+
def test_auth_keypair_bad_type():
99+
"""Simple Key Pair test using abstraction layer."""
100+
account = "testaccount"
101+
user = "testuser"
102+
103+
class Bad:
104+
pass
105+
106+
for bad_private_key in ("abcd", 1234, Bad()):
107+
auth_instance = AuthByKeyPair(private_key=bad_private_key)
108+
with raises(TypeError) as ex:
109+
auth_instance.handle_timeout(
110+
authenticator="SNOWFLAKE_JWT",
111+
service_name=None,
112+
account=account,
113+
user=user,
114+
password=None,
115+
)
116+
assert str(type(bad_private_key)) in str(ex)
117+
118+
62119
def _init_rest(application, post_requset):
63120
connection = MagicMock()
64121
connection._login_timeout = 120

0 commit comments

Comments
 (0)