|
5 | 5 |
|
6 | 6 | from __future__ import annotations |
7 | 7 |
|
| 8 | +import datetime |
8 | 9 | import logging |
9 | 10 | import os |
10 | 11 | import platform |
|
13 | 14 | from os import environ, path |
14 | 15 | from unittest import mock |
15 | 16 |
|
| 17 | +from asn1crypto import x509 as asn1crypto509 |
| 18 | +from cryptography import x509 |
| 19 | +from cryptography.hazmat.backends import default_backend |
| 20 | +from cryptography.hazmat.primitives import hashes |
| 21 | +from cryptography.hazmat.primitives.asymmetric import rsa |
| 22 | +from cryptography.hazmat.primitives.serialization import Encoding |
| 23 | + |
16 | 24 | try: |
17 | 25 | from snowflake.connector.util_text import random_string |
18 | 26 | except ImportError: |
@@ -546,3 +554,68 @@ def test_building_new_retry(): |
546 | 554 | ) |
547 | 555 |
|
548 | 556 | del os.environ["SF_OCSP_ACTIVATE_NEW_ENDPOINT"] |
| 557 | + |
| 558 | + |
| 559 | +@pytest.mark.parametrize( |
| 560 | + "hash_algorithm", |
| 561 | + [ |
| 562 | + hashes.SHA256(), |
| 563 | + hashes.SHA384(), |
| 564 | + hashes.SHA512(), |
| 565 | + hashes.SHA3_256(), |
| 566 | + hashes.SHA3_384(), |
| 567 | + hashes.SHA3_512(), |
| 568 | + ], |
| 569 | +) |
| 570 | +def test_signature_verification(hash_algorithm): |
| 571 | + # Generate a private key |
| 572 | + private_key = rsa.generate_private_key( |
| 573 | + public_exponent=65537, key_size=1024, backend=default_backend() |
| 574 | + ) |
| 575 | + |
| 576 | + # Generate a public key |
| 577 | + public_key = private_key.public_key() |
| 578 | + |
| 579 | + # Create a certificate |
| 580 | + subject = x509.Name( |
| 581 | + [ |
| 582 | + x509.NameAttribute(x509.NameOID.COUNTRY_NAME, "US"), |
| 583 | + ] |
| 584 | + ) |
| 585 | + |
| 586 | + issuer = subject |
| 587 | + |
| 588 | + cert = ( |
| 589 | + x509.CertificateBuilder() |
| 590 | + .subject_name(subject) |
| 591 | + .issuer_name(issuer) |
| 592 | + .public_key(public_key) |
| 593 | + .serial_number(x509.random_serial_number()) |
| 594 | + .not_valid_before(datetime.datetime.now()) |
| 595 | + .not_valid_after(datetime.datetime.now() + datetime.timedelta(days=365)) |
| 596 | + .add_extension( |
| 597 | + x509.SubjectAlternativeName([x509.DNSName("example.com")]), |
| 598 | + critical=False, |
| 599 | + ) |
| 600 | + .sign(private_key, hash_algorithm, default_backend()) |
| 601 | + ) |
| 602 | + |
| 603 | + # in snowflake, we use lib asn1crypto to load certificate, not using lib cryptography |
| 604 | + asy1_509_cert = asn1crypto509.Certificate.load(cert.public_bytes(Encoding.DER)) |
| 605 | + |
| 606 | + # sha3 family is not recognized by asn1crypto library |
| 607 | + if hash_algorithm.name.startswith("sha3-"): |
| 608 | + with pytest.raises(ValueError): |
| 609 | + SFOCSP().verify_signature( |
| 610 | + asy1_509_cert.hash_algo, |
| 611 | + cert.signature, |
| 612 | + asy1_509_cert, |
| 613 | + asy1_509_cert["tbs_certificate"], |
| 614 | + ) |
| 615 | + else: |
| 616 | + SFOCSP().verify_signature( |
| 617 | + asy1_509_cert.hash_algo, |
| 618 | + cert.signature, |
| 619 | + asy1_509_cert, |
| 620 | + asy1_509_cert["tbs_certificate"], |
| 621 | + ) |
0 commit comments