Skip to content

Commit 7fe4c5f

Browse files
committed
Fix python code a lot
1 parent 1c40e74 commit 7fe4c5f

File tree

4 files changed

+204
-118
lines changed

4 files changed

+204
-118
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@ node_modules
88
.tox
99
.eggs
1010
.coverage
11+
12+
/encrypt_data.json

nodejs/test.js

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@ var assert = require('assert');
1111
// If args contains 'verbose' show logs.
1212
// If args contains 'text=...' set the input string to the UTF-8 encoding of that string.
1313
// If args contains 'max=<n>' set the maximum input size to that value.
14+
// If args contains 'dump[=file]' log info to ../encrypt_data.json or the specified file.
1415
var args = process.argv.slice(2);
1516
var minLen = 3;
1617
var maxLen = 100;
1718
var plaintext;
19+
var dumpFile;
20+
var dumpData = [];
1821
var log = function() {};
1922
args.forEach(function(arg) {
2023
if (arg === 'verbose') {
@@ -26,16 +29,13 @@ args.forEach(function(arg) {
2629
if (!isNaN(v) && v > minLen) {
2730
maxLen = v;
2831
}
32+
} else if (arg === 'dump') {
33+
dumpFile = '../encrypt_data.json';
34+
} else if (arg.substring(0, 5) === 'dump=') {
35+
dumpFile = arg.substring(5);
2936
}
3037
});
3138

32-
if (process.argv.length >= 3) {
33-
if (!isNaN(parseInt(process.argv[2], 10))) {
34-
maxLen = parseInt(process.argv[2], 10);
35-
} else {
36-
plaintext = new Buffer(process.argv[2], 'ascii');
37-
}
38-
}
3939
function filterTests(fullList) {
4040
var filtered = fullList.filter(function(t) {
4141
return args.some(function(f) {
@@ -59,12 +59,22 @@ function logbuf(msg, buf) {
5959
}
6060
}
6161

62+
function saveDump(data){
63+
if (dumpFile && data.version) {
64+
dumpData.push(data);
65+
}
66+
}
67+
6268
function validate() {
6369
['hello', null, 1, NaN, [], {}].forEach(function(v) {
6470
try {
65-
encrypt('hello', {});
71+
ece.encrypt('hello', {});
6672
throw new Error('should insist on a buffer');
67-
} catch (e) {}
73+
} catch (e) {
74+
if (e.toString() != "Error: buffer argument must be a Buffer") {
75+
throw new Error("encrypt failed to reject " + JSON.stringify(v));
76+
}
77+
}
6878
});
6979
}
7080

@@ -81,7 +91,7 @@ function generateInput(len) {
8191
return input;
8292
}
8393

84-
function encryptDecrypt(input, encryptParams, decryptParams) {
94+
function encryptDecrypt(input, encryptParams, decryptParams, keys) {
8595
// Fill out a default rs.
8696
encryptParams.rs = encryptParams.rs || (input.length + minLen);
8797
if (decryptParams.version === 'aes128gcm') {
@@ -106,6 +116,17 @@ function encryptDecrypt(input, encryptParams, decryptParams) {
106116
logbuf('Decrypted', decrypted);
107117
assert.equal(Buffer.compare(input, decrypted), 0);
108118
log('----- OK');
119+
120+
saveDump({
121+
version: encryptParams.version,
122+
input: base64.encode(input),
123+
encrypted: base64.encode(encrypted),
124+
params: {
125+
encrypt: encryptParams,
126+
decrypt: decryptParams,
127+
},
128+
keys: keys
129+
});
109130
}
110131

111132
function useExplicitKey(version) {
@@ -175,7 +196,12 @@ function useKeyId(version) {
175196
keyid: keyid,
176197
keymap: keymap
177198
};
178-
encryptDecrypt(input, params, params);
199+
200+
var keyData = {
201+
keyid: keyid,
202+
key: base64.encode(keyid)
203+
}
204+
encryptDecrypt(input, params, params, keyData);
179205
}
180206

181207
function useDH(version) {
@@ -218,7 +244,20 @@ function useDH(version) {
218244
decryptParams.keymap = { k: staticKey };
219245
decryptParams.keylabels = encryptParams.keylabels;
220246
}
221-
encryptDecrypt(input, encryptParams, decryptParams);
247+
248+
249+
// keyData is used for cross library verification dumps
250+
var keyData = {
251+
sender: {
252+
private: base64.encode(ephemeralKey.getPrivateKey()),
253+
public: base64.encode(ephemeralKey.getPublicKey())
254+
},
255+
receiver: {
256+
private: base64.encode(staticKey.getPrivateKey()),
257+
public: base64.encode(staticKey.getPublicKey())
258+
}
259+
};
260+
encryptDecrypt(input, encryptParams, decryptParams, keyData);
222261
}
223262

224263
// Use the examples from the draft as a sanity check.
@@ -283,3 +322,8 @@ filterTests([ 'aesgcm128', 'aesgcm', 'aes128gcm' ])
283322
checkExamples();
284323

285324
log('All tests passed.');
325+
326+
327+
if (dumpFile) {
328+
require('fs').writeFileSync(dumpFile, JSON.stringify(dumpData, undefined, ' '));
329+
}

python/http_ece/__init__.py

Lines changed: 70 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020

2121
# Valid content types (ordered from newest, to most obsolete)
2222
versions = {
23-
"aes128gcm": {"padding": 2},
24-
"aesgcm": {"padding": 2},
25-
"aesgcm128": {"padding": 1}
23+
"aes128gcm": {"pad": 2},
24+
"aesgcm": {"pad": 2},
25+
"aesgcm128": {"pad": 1},
2626
}
2727

2828

@@ -34,8 +34,8 @@ def __init__(self, message):
3434
# TODO: turn this into a class so that we don't grow/stomp keys.
3535

3636

37-
def derive_key(mode, salt=None, key=None, dh=None, keyid=None,
38-
auth_secret=None, version="aesgcm", **kwargs):
37+
def derive_key(mode, version, salt=None, key=None, dh=None, auth_secret=None,
38+
keyid=None, keymap=None, keylabels=None):
3939
"""Derive the encryption key
4040
4141
:param mode: operational mode (encrypt or decrypt)
@@ -48,6 +48,10 @@ def derive_key(mode, salt=None, key=None, dh=None, keyid=None,
4848
:type dh: str
4949
:param keyid: key identifier label
5050
:type keyid: str
51+
:param keymap: map of keyids to keys
52+
:type keymap: map
53+
:param keylabels: map of keyids to labels
54+
:type keylabels: map
5155
:param auth_secret: authorization secret
5256
:type auth_secret: str
5357
:param version: Content Type identifier
@@ -61,69 +65,47 @@ def derive_key(mode, salt=None, key=None, dh=None, keyid=None,
6165
def build_info(base, info_context):
6266
return b"Content-Encoding: " + base + b"\0" + info_context
6367

64-
def derive_dh(mode, keyid, dh, version="aesgcm"):
65-
68+
def derive_dh(mode, version, dh, keyid, keymap, keylabels):
6669
def length_prefix(key):
6770
return struct.pack("!H", len(key)) + key
6871

6972
if keyid is None:
7073
raise ECEException(u"'keyid' is not specified with 'dh'")
71-
if keyid not in keys:
74+
if keyid not in keymap:
7275
raise ECEException(u"'keyid' doesn't identify a key: " + keyid)
7376
if mode == "encrypt":
74-
sender_pub_key = key or keys[keyid].get_pubkey()
77+
sender_pub_key = key or keymap[keyid].get_pubkey()
7578
receiver_pub_key = dh
7679
elif mode == "decrypt":
7780
sender_pub_key = dh
78-
receiver_pub_key = key or keys[keyid].get_pubkey()
81+
receiver_pub_key = key or keymap[keyid].get_pubkey()
7982
else:
8083
raise ECEException(u"unknown 'mode' specified: " + mode)
8184
if version == "aes128gcm":
8285
context = b"WebPush: info\x00" + receiver_pub_key + sender_pub_key
8386
else:
84-
label = labels.get(keyid, 'P-256').encode('utf-8')
87+
label = keylabels.get(keyid, 'P-256').encode('utf-8')
8588
context = (label + b"\0" + length_prefix(receiver_pub_key) +
8689
length_prefix(sender_pub_key))
8790

88-
return keys[keyid].get_ecdh_key(dh), context
89-
90-
# handle the older, ill formatted args.
91-
pad_size = kwargs.get('padSize', 2)
92-
auth_secret = kwargs.get('authSecret', auth_secret)
93-
secret = key
94-
95-
# handle old cases where version is explicitly None.
96-
if not version:
97-
if pad_size == 1:
98-
version = "aesgcm128"
99-
else:
100-
version = "aesgcm"
91+
return keymap[keyid].get_ecdh_key(dh), context
10192

10293
if version not in versions:
103-
raise ECEException(u"invalid version specified")
94+
raise ECEException(u"Invalid version")
10495
if salt is None or len(salt) != 16:
10596
raise ECEException(u"'salt' must be a 16 octet value")
10697
if dh is not None:
107-
(secret, context) = derive_dh(mode=mode, keyid=keyid, dh=dh,
108-
version=version)
109-
elif keyid in keys:
110-
if isinstance(keys[keyid], ecc.ECC):
111-
secret = keys[keyid].get_privkey()
112-
else:
113-
secret = keys[keyid]
98+
(secret, context) = derive_dh(mode=mode, version=version, dh=dh,
99+
keyid=keyid, keymap=keymap,
100+
keylabels=keylabels)
101+
elif keyid in keymap:
102+
secret = keymap[keyid]
103+
else:
104+
secret = key
105+
114106
if secret is None:
115107
raise ECEException(u"unable to determine the secret")
116108

117-
if auth_secret is not None:
118-
hkdf_auth = HKDF(
119-
algorithm=hashes.SHA256(),
120-
length=32,
121-
salt=auth_secret,
122-
info=build_info(b"auth", b""),
123-
backend=default_backend()
124-
)
125-
secret = hkdf_auth.derive(secret)
126-
127109
if version == "aesgcm":
128110
keyinfo = build_info(b"aesgcm", context)
129111
nonceinfo = build_info(b"nonce", context)
@@ -134,6 +116,20 @@ def length_prefix(key):
134116
keyinfo = b"Content-Encoding: aes128gcm\x00"
135117
nonceinfo = b"Content-Encoding: nonce\x00"
136118

119+
if auth_secret is not None:
120+
if version == "aes128gcm":
121+
info = context
122+
else:
123+
info = build_info(b'auth', b'')
124+
hkdf_auth = HKDF(
125+
algorithm=hashes.SHA256(),
126+
length=32,
127+
salt=auth_secret,
128+
info=info,
129+
backend=default_backend()
130+
)
131+
secret = hkdf_auth.derive(secret)
132+
137133
hkdf_key = HKDF(
138134
algorithm=hashes.SHA256(),
139135
length=16,
@@ -161,8 +157,8 @@ def iv(base, counter):
161157
return base[:4] + struct.pack("!Q", counter ^ mask)
162158

163159

164-
def decrypt(content, salt, key=None, keyid=None, dh=None, rs=4096,
165-
auth_secret=None, version="aesgcm", **kwargs):
160+
def decrypt(content, salt, key=None, keyid=None, keymap=None, keylabels=None,
161+
dh=None, rs=4096, auth_secret=None, version="aesgcm", **kwargs):
166162
"""
167163
Decrypt a data block
168164
@@ -218,12 +214,17 @@ def decrypt_record(key, nonce, counter, content):
218214
data = data[pad_size + pad:]
219215
return data
220216

217+
if version not in versions:
218+
raise ECEException(u"Invalid version")
219+
221220
# handle old, malformed args
222-
pad_size = kwargs.get('padSize', 2)
221+
pad_size = kwargs.get('padSize', versions[version]['pad'])
223222
auth_secret = kwargs.get('authSecret', auth_secret)
223+
if keymap is None:
224+
keymap = keys
225+
if keylabels is None:
226+
keylabels = labels
224227

225-
if version not in versions:
226-
raise ECEException(u"Invalid version")
227228
if version == "aes128gcm":
228229
try:
229230
content_header = parse_content_header(content)
@@ -232,14 +233,12 @@ def decrypt_record(key, nonce, counter, content):
232233
ex.message)
233234
salt = content_header['salt']
234235
keyid = content_header['key_id'] or '' if keyid is None else keyid
235-
pad_size = 2
236236
content = content_header['content']
237237

238-
(key_, nonce_) = derive_key(mode="decrypt", salt=salt,
239-
key=key, keyid=keyid, dh=dh,
240-
auth_secret=auth_secret,
241-
padSize=pad_size,
242-
version=version)
238+
(key_, nonce_) = derive_key(mode="decrypt", version=version,
239+
salt=salt, key=key,
240+
dh=dh, auth_secret=auth_secret,
241+
keyid=keyid, keymap=keymap, keylabels=keylabels)
243242
if rs <= pad_size:
244243
raise ECEException(u"Record size too small")
245244
rs += 16 # account for tags
@@ -257,8 +256,8 @@ def decrypt_record(key, nonce, counter, content):
257256
return result
258257

259258

260-
def encrypt(content, salt=None, key=None, keyid=None, dh=None, rs=4096,
261-
auth_secret=None, pad_size=2, version="aesgcm", **kwargs):
259+
def encrypt(content, salt=None, key=None, keyid=None, keymap=None, keylabels=None,
260+
dh=None, rs=4096, auth_secret=None, version="aesgcm", **kwargs):
262261
"""
263262
Encrypt a data block
264263
@@ -288,7 +287,7 @@ def encrypt_record(key, nonce, counter, buf):
288287
modes.GCM(iv(nonce, counter)),
289288
backend=default_backend()
290289
).encryptor()
291-
data = encryptor.update(b"\0\0" + buf)
290+
data = encryptor.update((b"\0" * pad_size) + buf)
292291
data += encryptor.finalize()
293292
data += encryptor.tag
294293
return data
@@ -324,26 +323,33 @@ def compose_aes128gcm(salt, content, rs=4096, key_id=""):
324323
header += key_id.encode('utf-8')
325324
return header + content
326325

326+
if version not in versions:
327+
raise ECEException(u"Invalid version")
328+
327329
# handle the older, ill formatted args.
328-
pad_size = kwargs.get('padSize', pad_size)
330+
pad_size = kwargs.get('padSize', versions[version]['pad'])
329331
auth_secret = kwargs.get('authSecret', auth_secret)
332+
if keymap is None:
333+
keymap = keys
334+
if keylabels is None:
335+
keylabels = labels
330336
if salt is None:
331337
salt = os.urandom(16)
332338
version = "aes128gcm"
333339

334-
(key_, nonce_) = derive_key(mode="encrypt", salt=salt,
335-
key=key, keyid=keyid, dh=dh,
336-
auth_secret=auth_secret, padSize=pad_size,
337-
version=version)
340+
(key_, nonce_) = derive_key(mode="encrypt", version=version,
341+
salt=salt, key=key,
342+
dh=dh, auth_secret=auth_secret,
343+
keyid=keyid, keymap=keymap, keylabels=keylabels)
338344
if rs <= pad_size:
339345
raise ECEException(u"Record size too small")
340346
rs -= pad_size # account for padding
341-
347+
342348
result = b""
343349
counter = 0
344350

345-
# the extra padSize on the loop ensures that we produce a padding only
346-
# record if the data length is an exact multiple of rs-padSize
351+
# the extra pad_size on the loop ensures that we produce a padding only
352+
# record if the data length is an exact multiple of rs-pad_size
347353
for i in list(range(0, len(content) + pad_size, rs)):
348354
result += encrypt_record(key_, nonce_, counter, content[i:i + rs])
349355
counter += 1

0 commit comments

Comments
 (0)