|
9 | 9 | import os |
10 | 10 | import sys |
11 | 11 | from pathlib import Path |
| 12 | +from secrets import token_urlsafe |
12 | 13 | from textwrap import dedent |
13 | 14 | from unittest import mock |
14 | 15 | from unittest.mock import MagicMock, patch |
@@ -428,6 +429,49 @@ def test_private_key_file_reading(tmp_path: Path): |
428 | 429 | assert m.call_args_list[0].kwargs["private_key"] == pkb |
429 | 430 |
|
430 | 431 |
|
| 432 | +def test_encrypted_private_key_file_reading(tmp_path: Path): |
| 433 | + key_file = tmp_path / "key.pem" |
| 434 | + private_key_password = token_urlsafe(25) |
| 435 | + private_key = rsa.generate_private_key( |
| 436 | + backend=default_backend(), public_exponent=65537, key_size=2048 |
| 437 | + ) |
| 438 | + |
| 439 | + private_key_pem = private_key.private_bytes( |
| 440 | + encoding=serialization.Encoding.PEM, |
| 441 | + format=serialization.PrivateFormat.PKCS8, |
| 442 | + encryption_algorithm=serialization.BestAvailableEncryption( |
| 443 | + private_key_password.encode("utf-8") |
| 444 | + ), |
| 445 | + ) |
| 446 | + |
| 447 | + key_file.write_bytes(private_key_pem) |
| 448 | + |
| 449 | + pkb = private_key.private_bytes( |
| 450 | + encoding=serialization.Encoding.DER, |
| 451 | + format=serialization.PrivateFormat.PKCS8, |
| 452 | + encryption_algorithm=serialization.NoEncryption(), |
| 453 | + ) |
| 454 | + |
| 455 | + exc_msg = "stop execution" |
| 456 | + |
| 457 | + with mock.patch( |
| 458 | + "snowflake.connector.auth.keypair.AuthByKeyPair.__init__", |
| 459 | + side_effect=Exception(exc_msg), |
| 460 | + ) as m: |
| 461 | + with pytest.raises( |
| 462 | + Exception, |
| 463 | + match=exc_msg, |
| 464 | + ): |
| 465 | + snowflake.connector.connect( |
| 466 | + account="test_account", |
| 467 | + user="test_user", |
| 468 | + private_key_file=str(key_file), |
| 469 | + private_key_file_pwd=private_key_password, |
| 470 | + ) |
| 471 | + assert m.call_count == 1 |
| 472 | + assert m.call_args_list[0].kwargs["private_key"] == pkb |
| 473 | + |
| 474 | + |
431 | 475 | def test_expired_detection(): |
432 | 476 | with mock.patch( |
433 | 477 | "snowflake.connector.network.SnowflakeRestful._post_request", |
|
0 commit comments