|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import os |
| 4 | +import base64 |
| 5 | +from cryptography import x509 |
| 6 | +from cryptography.hazmat.primitives import hashes |
| 7 | +from cryptography.hazmat.primitives.asymmetric import ec |
| 8 | +from cryptography.hazmat.primitives.kdf.hkdf import HKDF |
| 9 | +from cryptography.hazmat.primitives.ciphers import ( |
| 10 | + Cipher, algorithms, modes |
| 11 | +) |
| 12 | + |
| 13 | + |
| 14 | +def aes_gcm_encrypt_bytes(key: bytes, iv: bytes, plain_bytes: bytes, associated_data: bytes = b"") -> bytes: |
| 15 | + # aes_gcm_encrypt_bytes encrypt message using AES-GCM |
| 16 | + encryptor = Cipher( |
| 17 | + algorithms.AES(key), |
| 18 | + modes.GCM(iv), |
| 19 | + ).encryptor() |
| 20 | + # associated_data will be authenticated but not encrypted, |
| 21 | + # it must also be passed in on decryption. |
| 22 | + encryptor.authenticate_additional_data(associated_data) |
| 23 | + # Encrypt the plaintext and get the associated ciphertext. |
| 24 | + # GCM does not require padding. |
| 25 | + ciphertext = encryptor.update(plain_bytes) + encryptor.finalize() |
| 26 | + return ciphertext + encryptor.tag |
| 27 | + |
| 28 | + |
| 29 | +def aes_gcm_encrypt_base64_string(key: bytes, nonce: bytes, plaintext: str) -> str: |
| 30 | + """aes_gcm_encrypt_base64_string Encrypt message from base64 string to string using AES-GCM |
| 31 | + """ |
| 32 | + plain_bytes = plaintext.encode() |
| 33 | + # Encrypt message to string using AES-GCM |
| 34 | + c = aes_gcm_encrypt_bytes(key, nonce, plain_bytes) |
| 35 | + return base64.b64encode(c).decode() |
| 36 | + |
| 37 | + |
| 38 | +def aes_gcm_decrypt_bytes(key: bytes, iv: bytes, cipher_bytes: bytes, associated_data: bytes = b"") -> bytes: |
| 39 | + """aes_gcm_decrypt_bytes Decrypt message from bytes to bytes using AES-GCM |
| 40 | + """ |
| 41 | + tag_length = 16 # default aes gcm tag length |
| 42 | + cipher = cipher_bytes[:-tag_length] |
| 43 | + tag = cipher_bytes[-tag_length:] |
| 44 | + # Construct a Cipher object, with the key, iv, and additionally the |
| 45 | + # GCM tag used for authenticating the message. |
| 46 | + decryptor = Cipher( |
| 47 | + algorithms.AES(key), |
| 48 | + modes.GCM(iv, tag), |
| 49 | + ).decryptor() |
| 50 | + # We put associated_data back in or the tag will fail to verify |
| 51 | + # when we finalize the decryptor. |
| 52 | + decryptor.authenticate_additional_data(associated_data) |
| 53 | + # Decryption gets us the authenticated plaintext. |
| 54 | + # If the tag does not match an InvalidTag exception will be raised. |
| 55 | + return decryptor.update(cipher) + decryptor.finalize() |
| 56 | + |
| 57 | + |
| 58 | +def aes_gcm_decrypt_base64_string(key: bytes, nonce: bytes, ciphertext: str) -> str: |
| 59 | + # Decrypt message(base64.std.string) using AES-GCM |
| 60 | + cipher_bytes = base64.decodebytes(ciphertext.encode()) |
| 61 | + return aes_gcm_decrypt_bytes(key, nonce, cipher_bytes).decode() |
| 62 | + |
| 63 | + |
| 64 | +def marshal_cryptography_pub_key(key: ec.EllipticCurvePublicNumbers) -> bytes: |
| 65 | + # python version of crypto/elliptic/elliptic.go Marshal |
| 66 | + # without point on curve check |
| 67 | + return bytes([4]) + key.x.to_bytes(32, 'big') + key.y.to_bytes(32, 'big') |
| 68 | + |
| 69 | + |
| 70 | +class key_agreement_client(): |
| 71 | + def __init__(self, certificate_pem_string: str) -> None: |
| 72 | + """ Load cert and extract public key |
| 73 | + """ |
| 74 | + pem_data = certificate_pem_string.encode() |
| 75 | + self._cert = x509.load_pem_x509_certificate(pem_data) |
| 76 | + cert_pub = self._cert.public_key().public_numbers() |
| 77 | + self._curve = ec._CURVE_TYPES[self._cert.public_key().curve.name]() |
| 78 | + self._public_key = ec.EllipticCurvePublicNumbers( |
| 79 | + cert_pub.x, cert_pub.y, self._curve).public_key() |
| 80 | + |
| 81 | + def encrypt_string(self, plaintext: str) -> tuple[bytes, bytes, str, str]: |
| 82 | + """encrypt_string encrypt plaintext with ECIES DH protocol |
| 83 | + """ |
| 84 | + key, nonce, token = self.generate_ecies_key_pair() |
| 85 | + # Encrypt message using AES-GCM |
| 86 | + ciphertext = aes_gcm_encrypt_base64_string(key, nonce, plaintext) |
| 87 | + return key, nonce, token, ciphertext |
| 88 | + |
| 89 | + def generate_ecies_key_pair(self) -> tuple[bytes, bytes, str]: |
| 90 | + """generate_ecies_key_pair generate ECIES key pair |
| 91 | + """ |
| 92 | + # Generate an ephemeral elliptic curve scalar and point |
| 93 | + peer_private_key = ec.generate_private_key(self._curve) |
| 94 | + dh = peer_private_key.exchange(ec.ECDH(), self._public_key) |
| 95 | + R = peer_private_key.public_key().public_numbers() |
| 96 | + |
| 97 | + # Derive symmetric key and nonce via HKDF |
| 98 | + length = 32 + 12 |
| 99 | + buf = HKDF( |
| 100 | + algorithm=hashes.SHA256(), |
| 101 | + length=length, |
| 102 | + salt=None, |
| 103 | + info=None, |
| 104 | + ).derive(dh) |
| 105 | + key = buf[:32] |
| 106 | + nonce = buf[32:length] |
| 107 | + |
| 108 | + token = marshal_cryptography_pub_key(R) |
| 109 | + return key, nonce, base64.b64encode(token).decode() |
0 commit comments