1212from logging import getLogger
1313from typing import IO , TYPE_CHECKING
1414
15+ from Cryptodome .Cipher import AES
1516from cryptography .hazmat .backends import default_backend
1617from cryptography .hazmat .primitives .ciphers import Cipher , algorithms , modes
1718
@@ -68,6 +69,7 @@ def encrypt_stream(
6869 The encryption metadata.
6970 """
7071 logger = getLogger (__name__ )
72+ use_openssl_only = os .getenv ("SF_USE_OPENSSL_ONLY" , "False" ) == "True"
7173 decoded_key = base64 .standard_b64decode (
7274 encryption_material .query_stage_master_key
7375 )
@@ -77,9 +79,14 @@ def encrypt_stream(
7779 # Generate key for data encryption
7880 iv_data = SnowflakeEncryptionUtil .get_secure_random (block_size )
7981 file_key = SnowflakeEncryptionUtil .get_secure_random (key_size )
80- backend = default_backend ()
81- cipher = Cipher (algorithms .AES (file_key ), modes .CBC (iv_data ), backend = backend )
82- encryptor = cipher .encryptor ()
82+ if not use_openssl_only :
83+ data_cipher = AES .new (key = file_key , mode = AES .MODE_CBC , IV = iv_data )
84+ else :
85+ backend = default_backend ()
86+ cipher = Cipher (
87+ algorithms .AES (file_key ), modes .CBC (iv_data ), backend = backend
88+ )
89+ encryptor = cipher .encryptor ()
8390
8491 padded = False
8592 while True :
@@ -89,17 +96,30 @@ def encrypt_stream(
8996 elif len (chunk ) % block_size != 0 :
9097 chunk = PKCS5_PAD (chunk , block_size )
9198 padded = True
92- out .write (encryptor .update (chunk ))
99+ if not use_openssl_only :
100+ out .write (data_cipher .encrypt (chunk ))
101+ else :
102+ out .write (encryptor .update (chunk ))
93103 if not padded :
94- out .write (encryptor .update (block_size * chr (block_size ).encode (UTF8 )))
95- out .write (encryptor .finalize ())
104+ if not use_openssl_only :
105+ out .write (
106+ data_cipher .encrypt (block_size * chr (block_size ).encode (UTF8 ))
107+ )
108+ else :
109+ out .write (encryptor .update (block_size * chr (block_size ).encode (UTF8 )))
110+ if use_openssl_only :
111+ out .write (encryptor .finalize ())
96112
97113 # encrypt key with QRMK
98- cipher = Cipher (algorithms .AES (decoded_key ), modes .ECB (), backend = backend )
99- encryptor = cipher .encryptor ()
100- enc_kek = (
101- encryptor .update (PKCS5_PAD (file_key , block_size )) + encryptor .finalize ()
102- )
114+ if not use_openssl_only :
115+ key_cipher = AES .new (key = decoded_key , mode = AES .MODE_ECB )
116+ enc_kek = key_cipher .encrypt (PKCS5_PAD (file_key , block_size ))
117+ else :
118+ cipher = Cipher (algorithms .AES (decoded_key ), modes .ECB (), backend = backend )
119+ encryptor = cipher .encryptor ()
120+ enc_kek = (
121+ encryptor .update (PKCS5_PAD (file_key , block_size )) + encryptor .finalize ()
122+ )
103123
104124 mat_desc = MaterialDescriptor (
105125 smk_id = encryption_material .smk_id ,
@@ -158,6 +178,7 @@ def decrypt_stream(
158178 ) -> None :
159179 """To read from `src` stream then decrypt to `out` stream."""
160180
181+ use_openssl_only = os .getenv ("SF_USE_OPENSSL_ONLY" , "False" ) == "True"
161182 key_base64 = metadata .key
162183 iv_base64 = metadata .iv
163184 decoded_key = base64 .standard_b64decode (
@@ -166,26 +187,37 @@ def decrypt_stream(
166187 key_bytes = base64 .standard_b64decode (key_base64 )
167188 iv_bytes = base64 .standard_b64decode (iv_base64 )
168189
169- backend = default_backend ()
170- cipher = Cipher (algorithms .AES (decoded_key ), modes .ECB (), backend = backend )
171- decryptor = cipher .decryptor ()
172- file_key = PKCS5_UNPAD (decryptor .update (key_bytes ) + decryptor .finalize ())
173- cipher = Cipher (algorithms .AES (file_key ), modes .CBC (iv_bytes ), backend = backend )
174- decryptor = cipher .decryptor ()
190+ if not use_openssl_only :
191+ key_cipher = AES .new (key = decoded_key , mode = AES .MODE_ECB )
192+ file_key = PKCS5_UNPAD (key_cipher .decrypt (key_bytes ))
193+ data_cipher = AES .new (key = file_key , mode = AES .MODE_CBC , IV = iv_bytes )
194+ else :
195+ backend = default_backend ()
196+ cipher = Cipher (algorithms .AES (decoded_key ), modes .ECB (), backend = backend )
197+ decryptor = cipher .decryptor ()
198+ file_key = PKCS5_UNPAD (decryptor .update (key_bytes ) + decryptor .finalize ())
199+ cipher = Cipher (
200+ algorithms .AES (file_key ), modes .CBC (iv_bytes ), backend = backend
201+ )
202+ decryptor = cipher .decryptor ()
175203
176204 last_decrypted_chunk = None
177205 chunk = src .read (chunk_size )
178206 while len (chunk ) != 0 :
179207 if last_decrypted_chunk is not None :
180208 out .write (last_decrypted_chunk )
181- d = decryptor .update (chunk )
209+ if not use_openssl_only :
210+ d = data_cipher .decrypt (chunk )
211+ else :
212+ d = decryptor .update (chunk )
182213 last_decrypted_chunk = d
183214 chunk = src .read (chunk_size )
184215
185216 if last_decrypted_chunk is not None :
186217 offset = PKCS5_OFFSET (last_decrypted_chunk )
187218 out .write (last_decrypted_chunk [:- offset ])
188- out .write (decryptor .finalize ())
219+ if use_openssl_only :
220+ out .write (decryptor .finalize ())
189221
190222 @staticmethod
191223 def decrypt_file (
0 commit comments