Skip to content

Commit 8dfbf3e

Browse files
authored
Merge pull request #388 from tomato42/tls13-ecdh-fixes
TLS 1.3 ECDH and key_shares fixes
2 parents 23fdd47 + 5aa5a14 commit 8dfbf3e

File tree

3 files changed

+48
-17
lines changed

3 files changed

+48
-17
lines changed

tlslite/keyexchange.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .utils.x25519 import x25519, x448, X25519_G, X448_G, X25519_ORDER_SIZE, \
2424
X448_ORDER_SIZE
2525
from .utils.compat import int_types
26+
from .utils.codec import DecodeError
2627

2728

2829
class KeyExchange(object):
@@ -907,7 +908,7 @@ def calc_shared_key(self, private, peer_share):
907908
try:
908909
ecdhYc = decodeX962Point(peer_share,
909910
curve)
910-
except AssertionError:
911+
except (AssertionError, DecodeError):
911912
raise TLSIllegalParameterException("Invalid ECC point")
912913

913914
S = ecdhYc * private

tlslite/tlsconnection.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -691,20 +691,6 @@ def _clientSendClientHello(self, settings, session, srpUsername,
691691
extensions.append(TLSExtension().create(ExtensionType.
692692
extended_master_secret,
693693
bytearray(0)))
694-
groups = []
695-
#Send the ECC extensions only if we advertise ECC ciphers
696-
if next((cipher for cipher in cipherSuites \
697-
if cipher in CipherSuite.ecdhAllSuites), None) is not None:
698-
groups.extend(self._curveNamesToList(settings))
699-
extensions.append(ECPointFormatsExtension().\
700-
create([ECPointFormat.uncompressed]))
701-
# Advertise FFDHE groups if we have DHE ciphers
702-
if next((cipher for cipher in cipherSuites
703-
if cipher in CipherSuite.dhAllSuites), None) is not None:
704-
groups.extend(self._groupNamesToList(settings))
705-
# Send the extension only if it will be non empty
706-
if groups:
707-
extensions.append(SupportedGroupsExtension().create(groups))
708694
# In TLS1.2 advertise support for additional signature types
709695
if settings.maxVersion >= (3, 3):
710696
sigList = self._sigHashesToList(settings)
@@ -717,6 +703,7 @@ def _clientSendClientHello(self, settings, session, srpUsername,
717703

718704
session_id = bytearray()
719705
# when TLS 1.3 advertised, add key shares, set fake session_id
706+
shares = None
720707
if next((i for i in settings.versions if i > (3, 3)), None):
721708
# if we have a client cert configured, do indicate we're willing
722709
# to perform Post Handshake Authentication
@@ -746,6 +733,27 @@ def _clientSendClientHello(self, settings, session, srpUsername,
746733
[getattr(PskKeyExchangeMode, i) for i in settings.psk_modes])
747734
extensions.append(ext)
748735

736+
groups = []
737+
#Send the ECC extensions only if we advertise ECC ciphers
738+
if next((cipher for cipher in cipherSuites \
739+
if cipher in CipherSuite.ecdhAllSuites), None) is not None:
740+
groups.extend(self._curveNamesToList(settings))
741+
extensions.append(ECPointFormatsExtension().\
742+
create([ECPointFormat.uncompressed]))
743+
# Advertise FFDHE groups if we have DHE ciphers
744+
if next((cipher for cipher in cipherSuites
745+
if cipher in CipherSuite.dhAllSuites), None) is not None:
746+
groups.extend(self._groupNamesToList(settings))
747+
# Send the extension only if it will be non empty
748+
if groups:
749+
if shares:
750+
# put the groups used for key shares first, and in order
751+
# (req. from RFC 8446, section 4.2.8)
752+
share_ids = [i.group for i in shares]
753+
diff = set(groups) - set(share_ids)
754+
groups = share_ids + [i for i in groups if i in diff]
755+
extensions.append(SupportedGroupsExtension().create(groups))
756+
749757
if settings.use_heartbeat_extension:
750758
extensions.append(HeartbeatExtension().create(
751759
HeartbeatMode.PEER_ALLOWED_TO_SEND))
@@ -3119,6 +3127,23 @@ def _serverGetClientHello(self, settings, private_key, cert_chain,
31193127
.format(GroupName.toStr(mismatch))):
31203128
yield result
31213129

3130+
key_share_ids = [i.group for i in key_share.client_shares]
3131+
if len(set(key_share_ids)) != len(key_share_ids):
3132+
for result in self._sendError(
3133+
AlertDescription.illegal_parameter,
3134+
"Client sent multiple key shares for the same "
3135+
"group"):
3136+
yield result
3137+
3138+
group_ids = sup_groups.groups
3139+
diff = set(group_ids) - set(key_share_ids)
3140+
if key_share_ids != [i for i in group_ids if i not in diff]:
3141+
for result in self._sendError(
3142+
AlertDescription.illegal_parameter,
3143+
"Client sent key shares in different order than "
3144+
"the advertised groups."):
3145+
yield result
3146+
31223147
sig_algs = clientHello.getExtension(
31233148
ExtensionType.signature_algorithms)
31243149
if (not psk_modes or not psk) and sig_algs:

tlslite/utils/ecc.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,26 @@
33
# See the LICENSE file for legal information regarding use of this file.
44
"""Methods for dealing with ECC points"""
55

6-
from .codec import Parser, Writer
6+
from .codec import Parser, Writer, DecodeError
77
from .cryptomath import bytesToNumber, numberToByteArray, numBytes
88
from .compat import ecdsaAllCurves
99
import ecdsa
1010

11+
1112
def decodeX962Point(data, curve=ecdsa.NIST256p):
1213
"""Decode a point from a X9.62 encoding"""
1314
parser = Parser(data)
1415
encFormat = parser.get(1)
15-
assert encFormat == 4
16+
if encFormat != 4:
17+
raise DecodeError("Not an uncompressed point encoding")
1618
bytelength = getPointByteSize(curve)
1719
xCoord = bytesToNumber(parser.getFixBytes(bytelength))
1820
yCoord = bytesToNumber(parser.getFixBytes(bytelength))
21+
if parser.getRemainingLength():
22+
raise DecodeError("Invalid length of point encoding for curve")
1923
return ecdsa.ellipticcurve.Point(curve.curve, xCoord, yCoord)
2024

25+
2126
def encodeX962Point(point):
2227
"""Encode a point in X9.62 format"""
2328
bytelength = numBytes(point.curve().p())

0 commit comments

Comments
 (0)