Skip to content

Commit 06d4492

Browse files
rosspb3sfc-gh-stan
andauthored
Fix for an issue where specifying a private_key_file_pwd in a connections.toml file causes an error. (#1878)
* Fix for issue SNOW-1045815 where specifying a private_key_file_pwd in a connections.toml file causes an error. * Fix lint --------- Co-authored-by: Sophie Tan <[email protected]>
1 parent 066f1ec commit 06d4492

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

src/snowflake/connector/connection.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,10 @@ def DefaultConverterClass() -> type:
124124

125125
def _get_private_bytes_from_file(
126126
private_key_file: str | bytes | os.PathLike[str] | os.PathLike[bytes],
127-
private_key_file_pwd: bytes | None = None,
127+
private_key_file_pwd: bytes | str | None = None,
128128
) -> bytes:
129+
if private_key_file_pwd is not None and isinstance(private_key_file_pwd, str):
130+
private_key_file_pwd = private_key_file_pwd.encode("utf-8")
129131
with open(private_key_file, "rb") as key:
130132
private_key = serialization.load_pem_private_key(
131133
key.read(),
@@ -178,7 +180,7 @@ def _get_private_bytes_from_file(
178180
"passcode": (None, (type(None), str)), # Snowflake MFA
179181
"private_key": (None, (type(None), str, RSAPrivateKey)),
180182
"private_key_file": (None, (type(None), str)),
181-
"private_key_file_pwd": (None, (type(None), str)),
183+
"private_key_file_pwd": (None, (type(None), str, bytes)),
182184
"token": (None, (type(None), str)), # OAuth or JWT Token
183185
"authenticator": (DEFAULT_AUTHENTICATOR, (type(None), str)),
184186
"mfa_callback": (None, (type(None), Callable)),

test/unit/test_connection.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import os
1010
import sys
1111
from pathlib import Path
12+
from secrets import token_urlsafe
1213
from textwrap import dedent
1314
from unittest import mock
1415
from unittest.mock import MagicMock, patch
@@ -428,6 +429,49 @@ def test_private_key_file_reading(tmp_path: Path):
428429
assert m.call_args_list[0].kwargs["private_key"] == pkb
429430

430431

432+
def test_encrypted_private_key_file_reading(tmp_path: Path):
433+
key_file = tmp_path / "key.pem"
434+
private_key_password = token_urlsafe(25)
435+
private_key = rsa.generate_private_key(
436+
backend=default_backend(), public_exponent=65537, key_size=2048
437+
)
438+
439+
private_key_pem = private_key.private_bytes(
440+
encoding=serialization.Encoding.PEM,
441+
format=serialization.PrivateFormat.PKCS8,
442+
encryption_algorithm=serialization.BestAvailableEncryption(
443+
private_key_password.encode("utf-8")
444+
),
445+
)
446+
447+
key_file.write_bytes(private_key_pem)
448+
449+
pkb = private_key.private_bytes(
450+
encoding=serialization.Encoding.DER,
451+
format=serialization.PrivateFormat.PKCS8,
452+
encryption_algorithm=serialization.NoEncryption(),
453+
)
454+
455+
exc_msg = "stop execution"
456+
457+
with mock.patch(
458+
"snowflake.connector.auth.keypair.AuthByKeyPair.__init__",
459+
side_effect=Exception(exc_msg),
460+
) as m:
461+
with pytest.raises(
462+
Exception,
463+
match=exc_msg,
464+
):
465+
snowflake.connector.connect(
466+
account="test_account",
467+
user="test_user",
468+
private_key_file=str(key_file),
469+
private_key_file_pwd=private_key_password,
470+
)
471+
assert m.call_count == 1
472+
assert m.call_args_list[0].kwargs["private_key"] == pkb
473+
474+
431475
def test_expired_detection():
432476
with mock.patch(
433477
"snowflake.connector.network.SnowflakeRestful._post_request",

0 commit comments

Comments
 (0)