Skip to content

Commit f059f15

Browse files
author
liuhuiqi.7
committed
feat(ark e2e): update utils to achieve key agreement
Change-Id: I11bea480e941be2cad4d585d398fd765e4fa9574
1 parent 28e8263 commit f059f15

File tree

1 file changed

+109
-0
lines changed

1 file changed

+109
-0
lines changed
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from __future__ import annotations
2+
3+
import os
4+
import base64
5+
from cryptography import x509
6+
from cryptography.hazmat.primitives import hashes
7+
from cryptography.hazmat.primitives.asymmetric import ec
8+
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
9+
from cryptography.hazmat.primitives.ciphers import (
10+
Cipher, algorithms, modes
11+
)
12+
13+
14+
def aes_gcm_encrypt_bytes(key: bytes, iv: bytes, plain_bytes: bytes, associated_data: bytes = b"") -> bytes:
15+
# aes_gcm_encrypt_bytes encrypt message using AES-GCM
16+
encryptor = Cipher(
17+
algorithms.AES(key),
18+
modes.GCM(iv),
19+
).encryptor()
20+
# associated_data will be authenticated but not encrypted,
21+
# it must also be passed in on decryption.
22+
encryptor.authenticate_additional_data(associated_data)
23+
# Encrypt the plaintext and get the associated ciphertext.
24+
# GCM does not require padding.
25+
ciphertext = encryptor.update(plain_bytes) + encryptor.finalize()
26+
return ciphertext + encryptor.tag
27+
28+
29+
def aes_gcm_encrypt_base64_string(key: bytes, nonce: bytes, plaintext: str) -> str:
30+
"""aes_gcm_encrypt_base64_string Encrypt message from base64 string to string using AES-GCM
31+
"""
32+
plain_bytes = plaintext.encode()
33+
# Encrypt message to string using AES-GCM
34+
c = aes_gcm_encrypt_bytes(key, nonce, plain_bytes)
35+
return base64.b64encode(c).decode()
36+
37+
38+
def aes_gcm_decrypt_bytes(key: bytes, iv: bytes, cipher_bytes: bytes, associated_data: bytes = b"") -> bytes:
39+
"""aes_gcm_decrypt_bytes Decrypt message from bytes to bytes using AES-GCM
40+
"""
41+
tag_length = 16 # default aes gcm tag length
42+
cipher = cipher_bytes[:-tag_length]
43+
tag = cipher_bytes[-tag_length:]
44+
# Construct a Cipher object, with the key, iv, and additionally the
45+
# GCM tag used for authenticating the message.
46+
decryptor = Cipher(
47+
algorithms.AES(key),
48+
modes.GCM(iv, tag),
49+
).decryptor()
50+
# We put associated_data back in or the tag will fail to verify
51+
# when we finalize the decryptor.
52+
decryptor.authenticate_additional_data(associated_data)
53+
# Decryption gets us the authenticated plaintext.
54+
# If the tag does not match an InvalidTag exception will be raised.
55+
return decryptor.update(cipher) + decryptor.finalize()
56+
57+
58+
def aes_gcm_decrypt_base64_string(key: bytes, nonce: bytes, ciphertext: str) -> str:
59+
# Decrypt message(base64.std.string) using AES-GCM
60+
cipher_bytes = base64.decodebytes(ciphertext.encode())
61+
return aes_gcm_decrypt_bytes(key, nonce, cipher_bytes).decode()
62+
63+
64+
def marshal_cryptography_pub_key(key: ec.EllipticCurvePublicNumbers) -> bytes:
65+
# python version of crypto/elliptic/elliptic.go Marshal
66+
# without point on curve check
67+
return bytes([4]) + key.x.to_bytes(32, 'big') + key.y.to_bytes(32, 'big')
68+
69+
70+
class key_agreement_client():
71+
def __init__(self, certificate_pem_string: str) -> None:
72+
""" Load cert and extract public key
73+
"""
74+
pem_data = certificate_pem_string.encode()
75+
self._cert = x509.load_pem_x509_certificate(pem_data)
76+
cert_pub = self._cert.public_key().public_numbers()
77+
self._curve = ec._CURVE_TYPES[self._cert.public_key().curve.name]()
78+
self._public_key = ec.EllipticCurvePublicNumbers(
79+
cert_pub.x, cert_pub.y, self._curve).public_key()
80+
81+
def encrypt_string(self, plaintext: str) -> tuple[bytes, bytes, str, str]:
82+
"""encrypt_string encrypt plaintext with ECIES DH protocol
83+
"""
84+
key, nonce, token = self.generate_ecies_key_pair()
85+
# Encrypt message using AES-GCM
86+
ciphertext = aes_gcm_encrypt_base64_string(key, nonce, plaintext)
87+
return key, nonce, token, ciphertext
88+
89+
def generate_ecies_key_pair(self) -> tuple[bytes, bytes, str]:
90+
"""generate_ecies_key_pair generate ECIES key pair
91+
"""
92+
# Generate an ephemeral elliptic curve scalar and point
93+
peer_private_key = ec.generate_private_key(self._curve)
94+
dh = peer_private_key.exchange(ec.ECDH(), self._public_key)
95+
R = peer_private_key.public_key().public_numbers()
96+
97+
# Derive symmetric key and nonce via HKDF
98+
length = 32 + 12
99+
buf = HKDF(
100+
algorithm=hashes.SHA256(),
101+
length=length,
102+
salt=None,
103+
info=None,
104+
).derive(dh)
105+
key = buf[:32]
106+
nonce = buf[32:length]
107+
108+
token = marshal_cryptography_pub_key(R)
109+
return key, nonce, base64.b64encode(token).decode()

0 commit comments

Comments
 (0)