Skip to content

Commit 56a5a20

Browse files
committed
correctly handle malformed key shares in TLS 1.3
1 parent 684f4af commit 56a5a20

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

tlslite/keyexchange.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .utils.x25519 import x25519, x448, X25519_G, X448_G, X25519_ORDER_SIZE, \
2424
X448_ORDER_SIZE
2525
from .utils.compat import int_types
26+
from .utils.codec import DecodeError
2627

2728

2829
class KeyExchange(object):
@@ -907,7 +908,7 @@ def calc_shared_key(self, private, peer_share):
907908
try:
908909
ecdhYc = decodeX962Point(peer_share,
909910
curve)
910-
except AssertionError:
911+
except (AssertionError, DecodeError):
911912
raise TLSIllegalParameterException("Invalid ECC point")
912913

913914
S = ecdhYc * private

tlslite/utils/ecc.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,26 @@
33
# See the LICENSE file for legal information regarding use of this file.
44
"""Methods for dealing with ECC points"""
55

6-
from .codec import Parser, Writer
6+
from .codec import Parser, Writer, DecodeError
77
from .cryptomath import bytesToNumber, numberToByteArray, numBytes
88
from .compat import ecdsaAllCurves
99
import ecdsa
1010

11+
1112
def decodeX962Point(data, curve=ecdsa.NIST256p):
1213
"""Decode a point from a X9.62 encoding"""
1314
parser = Parser(data)
1415
encFormat = parser.get(1)
15-
assert encFormat == 4
16+
if encFormat != 4:
17+
raise DecodeError("Not an uncompressed point encoding")
1618
bytelength = getPointByteSize(curve)
1719
xCoord = bytesToNumber(parser.getFixBytes(bytelength))
1820
yCoord = bytesToNumber(parser.getFixBytes(bytelength))
21+
if parser.getRemainingLength():
22+
raise DecodeError("Invalid length of point encoding for curve")
1923
return ecdsa.ellipticcurve.Point(curve.curve, xCoord, yCoord)
2024

25+
2126
def encodeX962Point(point):
2227
"""Encode a point in X9.62 format"""
2328
bytelength = numBytes(point.curve().p())

0 commit comments

Comments
 (0)