20
20
21
21
# Valid content types (ordered from newest, to most obsolete)
22
22
versions = {
23
- "aes128gcm" : {"padding " : 2 },
24
- "aesgcm" : {"padding " : 2 },
25
- "aesgcm128" : {"padding " : 1 }
23
+ "aes128gcm" : {"pad " : 2 },
24
+ "aesgcm" : {"pad " : 2 },
25
+ "aesgcm128" : {"pad " : 1 },
26
26
}
27
27
28
28
@@ -34,8 +34,8 @@ def __init__(self, message):
34
34
# TODO: turn this into a class so that we don't grow/stomp keys.
35
35
36
36
37
- def derive_key (mode , salt = None , key = None , dh = None , keyid = None ,
38
- auth_secret = None , version = "aesgcm" , ** kwargs ):
37
+ def derive_key (mode , version , salt = None , key = None , dh = None , auth_secret = None ,
38
+ keyid = None , keymap = None , keylabels = None ):
39
39
"""Derive the encryption key
40
40
41
41
:param mode: operational mode (encrypt or decrypt)
@@ -48,6 +48,10 @@ def derive_key(mode, salt=None, key=None, dh=None, keyid=None,
48
48
:type dh: str
49
49
:param keyid: key identifier label
50
50
:type keyid: str
51
+ :param keymap: map of keyids to keys
52
+ :type keymap: map
53
+ :param keylabels: map of keyids to labels
54
+ :type keylabels: map
51
55
:param auth_secret: authorization secret
52
56
:type auth_secret: str
53
57
:param version: Content Type identifier
@@ -61,69 +65,47 @@ def derive_key(mode, salt=None, key=None, dh=None, keyid=None,
61
65
def build_info (base , info_context ):
62
66
return b"Content-Encoding: " + base + b"\0 " + info_context
63
67
64
- def derive_dh (mode , keyid , dh , version = "aesgcm" ):
65
-
68
+ def derive_dh (mode , version , dh , keyid , keymap , keylabels ):
66
69
def length_prefix (key ):
67
70
return struct .pack ("!H" , len (key )) + key
68
71
69
72
if keyid is None :
70
73
raise ECEException (u"'keyid' is not specified with 'dh'" )
71
- if keyid not in keys :
74
+ if keyid not in keymap :
72
75
raise ECEException (u"'keyid' doesn't identify a key: " + keyid )
73
76
if mode == "encrypt" :
74
- sender_pub_key = key or keys [keyid ].get_pubkey ()
77
+ sender_pub_key = key or keymap [keyid ].get_pubkey ()
75
78
receiver_pub_key = dh
76
79
elif mode == "decrypt" :
77
80
sender_pub_key = dh
78
- receiver_pub_key = key or keys [keyid ].get_pubkey ()
81
+ receiver_pub_key = key or keymap [keyid ].get_pubkey ()
79
82
else :
80
83
raise ECEException (u"unknown 'mode' specified: " + mode )
81
84
if version == "aes128gcm" :
82
85
context = b"WebPush: info\x00 " + receiver_pub_key + sender_pub_key
83
86
else :
84
- label = labels .get (keyid , 'P-256' ).encode ('utf-8' )
87
+ label = keylabels .get (keyid , 'P-256' ).encode ('utf-8' )
85
88
context = (label + b"\0 " + length_prefix (receiver_pub_key ) +
86
89
length_prefix (sender_pub_key ))
87
90
88
- return keys [keyid ].get_ecdh_key (dh ), context
89
-
90
- # handle the older, ill formatted args.
91
- pad_size = kwargs .get ('padSize' , 2 )
92
- auth_secret = kwargs .get ('authSecret' , auth_secret )
93
- secret = key
94
-
95
- # handle old cases where version is explicitly None.
96
- if not version :
97
- if pad_size == 1 :
98
- version = "aesgcm128"
99
- else :
100
- version = "aesgcm"
91
+ return keymap [keyid ].get_ecdh_key (dh ), context
101
92
102
93
if version not in versions :
103
- raise ECEException (u"invalid version specified " )
94
+ raise ECEException (u"Invalid version" )
104
95
if salt is None or len (salt ) != 16 :
105
96
raise ECEException (u"'salt' must be a 16 octet value" )
106
97
if dh is not None :
107
- (secret , context ) = derive_dh (mode = mode , keyid = keyid , dh = dh ,
108
- version = version )
109
- elif keyid in keys :
110
- if isinstance (keys [keyid ], ecc .ECC ):
111
- secret = keys [keyid ].get_privkey ()
112
- else :
113
- secret = keys [keyid ]
98
+ (secret , context ) = derive_dh (mode = mode , version = version , dh = dh ,
99
+ keyid = keyid , keymap = keymap ,
100
+ keylabels = keylabels )
101
+ elif keyid in keymap :
102
+ secret = keymap [keyid ]
103
+ else :
104
+ secret = key
105
+
114
106
if secret is None :
115
107
raise ECEException (u"unable to determine the secret" )
116
108
117
- if auth_secret is not None :
118
- hkdf_auth = HKDF (
119
- algorithm = hashes .SHA256 (),
120
- length = 32 ,
121
- salt = auth_secret ,
122
- info = build_info (b"auth" , b"" ),
123
- backend = default_backend ()
124
- )
125
- secret = hkdf_auth .derive (secret )
126
-
127
109
if version == "aesgcm" :
128
110
keyinfo = build_info (b"aesgcm" , context )
129
111
nonceinfo = build_info (b"nonce" , context )
@@ -134,6 +116,20 @@ def length_prefix(key):
134
116
keyinfo = b"Content-Encoding: aes128gcm\x00 "
135
117
nonceinfo = b"Content-Encoding: nonce\x00 "
136
118
119
+ if auth_secret is not None :
120
+ if version == "aes128gcm" :
121
+ info = context
122
+ else :
123
+ info = build_info (b'auth' , b'' )
124
+ hkdf_auth = HKDF (
125
+ algorithm = hashes .SHA256 (),
126
+ length = 32 ,
127
+ salt = auth_secret ,
128
+ info = info ,
129
+ backend = default_backend ()
130
+ )
131
+ secret = hkdf_auth .derive (secret )
132
+
137
133
hkdf_key = HKDF (
138
134
algorithm = hashes .SHA256 (),
139
135
length = 16 ,
@@ -161,8 +157,8 @@ def iv(base, counter):
161
157
return base [:4 ] + struct .pack ("!Q" , counter ^ mask )
162
158
163
159
164
- def decrypt (content , salt , key = None , keyid = None , dh = None , rs = 4096 ,
165
- auth_secret = None , version = "aesgcm" , ** kwargs ):
160
+ def decrypt (content , salt , key = None , keyid = None , keymap = None , keylabels = None ,
161
+ dh = None , rs = 4096 , auth_secret = None , version = "aesgcm" , ** kwargs ):
166
162
"""
167
163
Decrypt a data block
168
164
@@ -218,12 +214,17 @@ def decrypt_record(key, nonce, counter, content):
218
214
data = data [pad_size + pad :]
219
215
return data
220
216
217
+ if version not in versions :
218
+ raise ECEException (u"Invalid version" )
219
+
221
220
# handle old, malformed args
222
- pad_size = kwargs .get ('padSize' , 2 )
221
+ pad_size = kwargs .get ('padSize' , versions [ version ][ 'pad' ] )
223
222
auth_secret = kwargs .get ('authSecret' , auth_secret )
223
+ if keymap is None :
224
+ keymap = keys
225
+ if keylabels is None :
226
+ keylabels = labels
224
227
225
- if version not in versions :
226
- raise ECEException (u"Invalid version" )
227
228
if version == "aes128gcm" :
228
229
try :
229
230
content_header = parse_content_header (content )
@@ -232,14 +233,12 @@ def decrypt_record(key, nonce, counter, content):
232
233
ex .message )
233
234
salt = content_header ['salt' ]
234
235
keyid = content_header ['key_id' ] or '' if keyid is None else keyid
235
- pad_size = 2
236
236
content = content_header ['content' ]
237
237
238
- (key_ , nonce_ ) = derive_key (mode = "decrypt" , salt = salt ,
239
- key = key , keyid = keyid , dh = dh ,
240
- auth_secret = auth_secret ,
241
- padSize = pad_size ,
242
- version = version )
238
+ (key_ , nonce_ ) = derive_key (mode = "decrypt" , version = version ,
239
+ salt = salt , key = key ,
240
+ dh = dh , auth_secret = auth_secret ,
241
+ keyid = keyid , keymap = keymap , keylabels = keylabels )
243
242
if rs <= pad_size :
244
243
raise ECEException (u"Record size too small" )
245
244
rs += 16 # account for tags
@@ -257,8 +256,8 @@ def decrypt_record(key, nonce, counter, content):
257
256
return result
258
257
259
258
260
- def encrypt (content , salt = None , key = None , keyid = None , dh = None , rs = 4096 ,
261
- auth_secret = None , pad_size = 2 , version = "aesgcm" , ** kwargs ):
259
+ def encrypt (content , salt = None , key = None , keyid = None , keymap = None , keylabels = None ,
260
+ dh = None , rs = 4096 , auth_secret = None , version = "aesgcm" , ** kwargs ):
262
261
"""
263
262
Encrypt a data block
264
263
@@ -288,7 +287,7 @@ def encrypt_record(key, nonce, counter, buf):
288
287
modes .GCM (iv (nonce , counter )),
289
288
backend = default_backend ()
290
289
).encryptor ()
291
- data = encryptor .update (b"\0 \0 " + buf )
290
+ data = encryptor .update (( b"\0 " * pad_size ) + buf )
292
291
data += encryptor .finalize ()
293
292
data += encryptor .tag
294
293
return data
@@ -324,26 +323,33 @@ def compose_aes128gcm(salt, content, rs=4096, key_id=""):
324
323
header += key_id .encode ('utf-8' )
325
324
return header + content
326
325
326
+ if version not in versions :
327
+ raise ECEException (u"Invalid version" )
328
+
327
329
# handle the older, ill formatted args.
328
- pad_size = kwargs .get ('padSize' , pad_size )
330
+ pad_size = kwargs .get ('padSize' , versions [ version ][ 'pad' ] )
329
331
auth_secret = kwargs .get ('authSecret' , auth_secret )
332
+ if keymap is None :
333
+ keymap = keys
334
+ if keylabels is None :
335
+ keylabels = labels
330
336
if salt is None :
331
337
salt = os .urandom (16 )
332
338
version = "aes128gcm"
333
339
334
- (key_ , nonce_ ) = derive_key (mode = "encrypt" , salt = salt ,
335
- key = key , keyid = keyid , dh = dh ,
336
- auth_secret = auth_secret , padSize = pad_size ,
337
- version = version )
340
+ (key_ , nonce_ ) = derive_key (mode = "encrypt" , version = version ,
341
+ salt = salt , key = key ,
342
+ dh = dh , auth_secret = auth_secret ,
343
+ keyid = keyid , keymap = keymap , keylabels = keylabels )
338
344
if rs <= pad_size :
339
345
raise ECEException (u"Record size too small" )
340
346
rs -= pad_size # account for padding
341
-
347
+
342
348
result = b""
343
349
counter = 0
344
350
345
- # the extra padSize on the loop ensures that we produce a padding only
346
- # record if the data length is an exact multiple of rs-padSize
351
+ # the extra pad_size on the loop ensures that we produce a padding only
352
+ # record if the data length is an exact multiple of rs-pad_size
347
353
for i in list (range (0 , len (content ) + pad_size , rs )):
348
354
result += encrypt_record (key_ , nonce_ , counter , content [i :i + rs ])
349
355
counter += 1
0 commit comments