Skip to content

Commit 9b6b0a6

Browse files
Test 1731 (#1808)
Co-authored-by: Benny Lu <[email protected]> Co-authored-by: Benny Lu <[email protected]>
1 parent 7b4b708 commit 9b6b0a6

File tree

3 files changed

+108
-2
lines changed

3 files changed

+108
-2
lines changed

DESCRIPTION.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
1313

1414
- Added support for Vector types
1515
- Changed urllib3 version pin to only affect Python versions < 3.10.
16+
- Support for `private_key_file` and `private_key_file_pwd` connection parameters
1617

1718
- v3.5.0(November 13,2023)
1819

src/snowflake/connector/connection.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from typing import Any, Callable, Generator, Iterable, Iterator, NamedTuple, Sequence
2727
from uuid import UUID
2828

29+
from cryptography.hazmat.backends import default_backend
30+
from cryptography.hazmat.primitives import serialization
2931
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
3032

3133
from . import errors, proxy
@@ -120,6 +122,26 @@ def DefaultConverterClass() -> type:
120122
return SnowflakeConverter
121123

122124

125+
def _get_private_bytes_from_file(
126+
private_key_file: str | bytes | os.PathLike[str] | os.PathLike[bytes],
127+
private_key_file_pwd: bytes | None = None,
128+
) -> bytes:
129+
with open(private_key_file, "rb") as key:
130+
private_key = serialization.load_pem_private_key(
131+
key.read(),
132+
password=private_key_file_pwd,
133+
backend=default_backend(),
134+
)
135+
136+
pkb = private_key.private_bytes(
137+
encoding=serialization.Encoding.DER,
138+
format=serialization.PrivateFormat.PKCS8,
139+
encryption_algorithm=serialization.NoEncryption(),
140+
)
141+
142+
return pkb
143+
144+
123145
SUPPORTED_PARAMSTYLES = {
124146
"qmark",
125147
"numeric",
@@ -155,6 +177,8 @@ def DefaultConverterClass() -> type:
155177
"passcode_in_password": (False, bool), # Snowflake MFA
156178
"passcode": (None, (type(None), str)), # Snowflake MFA
157179
"private_key": (None, (type(None), str, RSAPrivateKey)),
180+
"private_key_file": (None, (type(None), str)),
181+
"private_key_file_pwd": (None, (type(None), str)),
158182
"token": (None, (type(None), str)), # OAuth or JWT Token
159183
"authenticator": (DEFAULT_AUTHENTICATOR, (type(None), str)),
160184
"mfa_callback": (None, (type(None), Callable)),
@@ -935,8 +959,16 @@ def __open_connection(self):
935959
)
936960

937961
elif self._authenticator == KEY_PAIR_AUTHENTICATOR:
962+
private_key = self._private_key
963+
964+
if self._private_key_file:
965+
private_key = _get_private_bytes_from_file(
966+
self._private_key_file,
967+
self._private_key_file_pwd,
968+
)
969+
938970
self.auth_class = AuthByKeyPair(
939-
private_key=self._private_key,
971+
private_key=private_key,
940972
timeout=self._login_timeout,
941973
backoff_generator=self._backoff_generator,
942974
)
@@ -1091,7 +1123,7 @@ def __config(self, **kwargs):
10911123
{"msg": "User is empty", "errno": ER_NO_USER},
10921124
)
10931125

1094-
if self._private_key:
1126+
if self._private_key or self._private_key_file:
10951127
self._authenticator = KEY_PAIR_AUTHENTICATOR
10961128

10971129
if (

test/unit/test_connection.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,15 @@
88
import json
99
import os
1010
import sys
11+
from pathlib import Path
1112
from textwrap import dedent
13+
from unittest import mock
1214
from unittest.mock import MagicMock, patch
1315

1416
import pytest
17+
from cryptography.hazmat.backends import default_backend
18+
from cryptography.hazmat.primitives import serialization
19+
from cryptography.hazmat.primitives.asymmetric import rsa
1520

1621
import snowflake.connector
1722
from snowflake.connector.errors import (
@@ -353,3 +358,71 @@ def test_handle_timeout(mockSessionRequest, next_action):
353358
# 9 seconds should be enough for authenticator to attempt twice
354359
# however, loosen restrictions to avoid thread scheduling causing failure
355360
assert 1 < mockSessionRequest.call_count < 4
361+
362+
363+
def test__get_private_bytes_from_file(tmp_path: Path):
364+
private_key_file = tmp_path / "key.pem"
365+
366+
private_key = rsa.generate_private_key(
367+
backend=default_backend(), public_exponent=65537, key_size=2048
368+
)
369+
370+
private_key_pem = private_key.private_bytes(
371+
encoding=serialization.Encoding.PEM,
372+
format=serialization.PrivateFormat.PKCS8,
373+
encryption_algorithm=serialization.NoEncryption(),
374+
)
375+
376+
pkb = private_key.private_bytes(
377+
encoding=serialization.Encoding.DER,
378+
format=serialization.PrivateFormat.PKCS8,
379+
encryption_algorithm=serialization.NoEncryption(),
380+
)
381+
382+
private_key_file.write_bytes(private_key_pem)
383+
384+
private_key = snowflake.connector.connection._get_private_bytes_from_file(
385+
private_key_file=str(private_key_file)
386+
)
387+
388+
assert pkb == private_key
389+
390+
391+
def test_private_key_file_reading(tmp_path: Path):
392+
key_file = tmp_path / "key.pem"
393+
394+
private_key = rsa.generate_private_key(
395+
backend=default_backend(), public_exponent=65537, key_size=2048
396+
)
397+
398+
private_key_pem = private_key.private_bytes(
399+
encoding=serialization.Encoding.PEM,
400+
format=serialization.PrivateFormat.PKCS8,
401+
encryption_algorithm=serialization.NoEncryption(),
402+
)
403+
404+
key_file.write_bytes(private_key_pem)
405+
406+
pkb = private_key.private_bytes(
407+
encoding=serialization.Encoding.DER,
408+
format=serialization.PrivateFormat.PKCS8,
409+
encryption_algorithm=serialization.NoEncryption(),
410+
)
411+
412+
exc_msg = "stop execution"
413+
414+
with mock.patch(
415+
"snowflake.connector.auth.keypair.AuthByKeyPair.__init__",
416+
side_effect=Exception(exc_msg),
417+
) as m:
418+
with pytest.raises(
419+
Exception,
420+
match=exc_msg,
421+
):
422+
snowflake.connector.connect(
423+
account="test_account",
424+
user="test_user",
425+
private_key_file=str(key_file),
426+
)
427+
assert m.call_count == 1
428+
assert m.call_args_list[0].kwargs["private_key"] == pkb

0 commit comments

Comments
 (0)