Skip to content

Commit 1e860df

Browse files
authored
TLS 1.3: add hybrid post-quantum key exchange support (X25519MLKEM768) (#472)
* Support of multiple key shares * Add MLKEM Support * gofmt * Fix for tests: handshake server and key agreement * Fix for tests: crypto/mlkem encapsulate return order * Fix for tests: unsupported hybrid group * do not advertise ML-KEM group by default to pass tests * keep single TLS1.3 key_share by default (to pass tests); add X25519 fallback only for ML-KEM
1 parent df961ee commit 1e860df

File tree

7 files changed

+252
-54
lines changed

7 files changed

+252
-54
lines changed

tls/common.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,15 +123,19 @@ const (
123123
// CurveID is the type of a TLS identifier for an elliptic curve. See
124124
// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8.
125125
//
126-
// In TLS 1.3, this type is called NamedGroup, but at this time this library
127-
// only supports Elliptic Curve based groups. See RFC 8446, Section 4.2.7.
126+
// In TLS 1.3, this type is called NamedGroup. This library historically used it
127+
// for elliptic curves, but it can represent any TLS 1.3 (EC / hybrid / PQ) group.
128128
type CurveID uint16
129129

130130
const (
131131
CurveP256 CurveID = 23
132132
CurveP384 CurveID = 24
133133
CurveP521 CurveID = 25
134134
X25519 CurveID = 29
135+
136+
// Hybrid PQ key exchange groups (TLS 1.3 NamedGroup)
137+
SecP256r1MLKEM768 CurveID = 4587
138+
X25519MLKEM768 CurveID = 4588
135139
)
136140

137141
func (curveID *CurveID) MarshalJSON() ([]byte, error) {

tls/handshake_client.go

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ func (c *ClientFingerprintConfiguration) marshal(config *Config) ([]byte, error)
222222
return hello, nil
223223
}
224224

225-
func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) {
225+
func (c *Conn) makeClientHello() (*clientHelloMsg, map[CurveID]tls13KeyShare, error) {
226226
config := c.config
227227
if len(config.ServerName) == 0 && !config.InsecureSkipVerify {
228228
return nil, nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config")
@@ -306,22 +306,54 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) {
306306
hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms
307307
}
308308

309-
var params ecdheParameters
309+
var keySharesByGroup map[CurveID]tls13KeyShare
310310
if hello.supportedVersions[0] == VersionTLS13 {
311311
hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13()...)
312312

313-
curveID := config.curvePreferences()[0]
314-
if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok {
315-
return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve")
313+
prefs := config.curvePreferences()
314+
if len(prefs) == 0 {
315+
return nil, nil, errors.New("tls: no supported key exchange mechanisms (no curve preferences)")
316316
}
317-
params, err = generateECDHEParameters(config.rand(), curveID)
318-
if err != nil {
319-
return nil, nil, err
317+
318+
// By default, send a single key_share.
319+
// If ML-KEM hybrid is explicitly enabled as the top preference, also send X25519 as fallback.
320+
shareGroups := []CurveID{prefs[0]}
321+
if prefs[0] == X25519MLKEM768 {
322+
// Ensure compatibility with servers that don't support the hybrid group.
323+
if prefs[0] != X25519 {
324+
shareGroups = append(shareGroups, X25519)
325+
}
326+
}
327+
328+
hello.keyShares = make([]keyShare, 0, len(shareGroups))
329+
keySharesByGroup = make(map[CurveID]tls13KeyShare, len(shareGroups))
330+
331+
seen := make(map[CurveID]struct{}, len(shareGroups))
332+
for _, group := range shareGroups {
333+
if _, ok := seen[group]; ok {
334+
continue
335+
}
336+
seen[group] = struct{}{}
337+
338+
ks, genErr := generateTLS13KeyShare(config.rand(), group)
339+
if genErr != nil {
340+
// If a group is not supported/implemented, skip it.
341+
continue
342+
}
343+
344+
hello.keyShares = append(hello.keyShares, keyShare{
345+
group: group,
346+
data: ks.PublicKey(),
347+
})
348+
keySharesByGroup[group] = ks
349+
}
350+
351+
if len(hello.keyShares) == 0 {
352+
return nil, nil, errors.New("tls: no supported key exchange mechanisms (no key shares)")
320353
}
321-
hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}}
322354
}
323355

324-
return hello, params, nil
356+
return hello, keySharesByGroup, nil
325357
}
326358

327359
func (c *Conn) clientHandshake() (err error) {
@@ -333,7 +365,7 @@ func (c *Conn) clientHandshake() (err error) {
333365
var session *ClientSessionState
334366
var sessionCache ClientSessionCache
335367
var cacheKey string
336-
var ecdheParams ecdheParameters
368+
var keySharesByGroup map[CurveID]tls13KeyShare
337369

338370
// This may be a renegotiation handshake, in which case some fields
339371
// need to be reset.
@@ -422,7 +454,7 @@ func (c *Conn) clientHandshake() (err error) {
422454
sessionCache = nil
423455
} else {
424456

425-
hello, ecdheParams, err = c.makeClientHello()
457+
hello, keySharesByGroup, err = c.makeClientHello()
426458
if err != nil {
427459
return err
428460
}
@@ -489,13 +521,13 @@ func (c *Conn) clientHandshake() (err error) {
489521

490522
if c.vers == VersionTLS13 {
491523
hs := &clientHandshakeStateTLS13{
492-
c: c,
493-
serverHello: serverHello,
494-
hello: hello,
495-
ecdheParams: ecdheParams,
496-
session: session,
497-
earlySecret: earlySecret,
498-
binderKey: binderKey,
524+
c: c,
525+
serverHello: serverHello,
526+
hello: hello,
527+
keySharesByGroup: keySharesByGroup,
528+
session: session,
529+
earlySecret: earlySecret,
530+
binderKey: binderKey,
499531
}
500532

501533
// In TLS 1.3, session tickets are delivered after the handshake.

tls/handshake_client_tls13.go

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ import (
1616
)
1717

1818
type clientHandshakeStateTLS13 struct {
19-
c *Conn
20-
serverHello *serverHelloMsg
21-
hello *clientHelloMsg
22-
ecdheParams ecdheParameters
19+
c *Conn
20+
serverHello *serverHelloMsg
21+
hello *clientHelloMsg
22+
keySharesByGroup map[CurveID]tls13KeyShare
2323

2424
session *ClientSessionState
2525
earlySecret []byte
@@ -34,7 +34,7 @@ type clientHandshakeStateTLS13 struct {
3434
trafficSecret []byte // client_application_traffic_secret_0
3535
}
3636

37-
// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheParams, and,
37+
// handshake requires hs.c, hs.hello, hs.serverHello, hs.keySharesByGroup, and,
3838
// optionally, hs.session, hs.earlySecret and hs.binderKey to be set.
3939
func (hs *clientHandshakeStateTLS13) handshake() error {
4040
// The server must not select TLS 1.3 in a renegotiation. See RFC 8446,
@@ -45,7 +45,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
4545
}
4646

4747
// Consistency check on the presence of a keyShare and its parameters.
48-
if hs.ecdheParams == nil || len(hs.hello.keyShares) != 1 {
48+
if len(hs.hello.keyShares) == 0 || hs.keySharesByGroup == nil {
4949
return hs.c.sendAlert(AlertInternalError)
5050
}
5151

@@ -219,21 +219,20 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
219219
c.sendAlert(AlertIllegalParameter)
220220
return errors.New("tls: server selected unsupported group")
221221
}
222-
if hs.ecdheParams.CurveID() == curveID {
222+
if _, ok := hs.keySharesByGroup[curveID]; ok {
223223
c.sendAlert(AlertIllegalParameter)
224224
return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share")
225225
}
226-
if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok {
227-
c.sendAlert(AlertInternalError)
228-
return errors.New("tls: CurvePreferences includes unsupported curve")
229-
}
230-
params, err := generateECDHEParameters(c.config.rand(), curveID)
226+
ks, err := generateTLS13KeyShare(c.config.rand(), curveID)
231227
if err != nil {
232228
c.sendAlert(AlertInternalError)
233229
return err
234230
}
235-
hs.ecdheParams = params
236-
hs.hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}}
231+
if hs.keySharesByGroup == nil {
232+
hs.keySharesByGroup = make(map[CurveID]tls13KeyShare)
233+
}
234+
hs.keySharesByGroup[curveID] = ks
235+
hs.hello.keyShares = []keyShare{{group: curveID, data: ks.PublicKey()}}
237236
}
238237

239238
hs.hello.raw = nil
@@ -307,7 +306,9 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error {
307306
c.sendAlert(AlertIllegalParameter)
308307
return errors.New("tls: server did not send a key share")
309308
}
310-
if hs.serverHello.serverShare.group != hs.ecdheParams.CurveID() {
309+
310+
ks, ok := hs.keySharesByGroup[hs.serverHello.serverShare.group]
311+
if !ok || ks == nil {
311312
c.sendAlert(AlertIllegalParameter)
312313
return errors.New("tls: server selected unsupported group")
313314
}
@@ -345,10 +346,16 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error {
345346
func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
346347
c := hs.c
347348

348-
sharedKey := hs.ecdheParams.SharedKey(hs.serverHello.serverShare.data)
349-
if sharedKey == nil {
349+
ks, ok := hs.keySharesByGroup[hs.serverHello.serverShare.group]
350+
if !ok || ks == nil {
351+
c.sendAlert(AlertIllegalParameter)
352+
return errors.New("tls: server selected unsupported group")
353+
}
354+
355+
sharedKey, err := ks.SharedKey(hs.serverHello.serverShare.data)
356+
if err != nil {
350357
c.sendAlert(AlertIllegalParameter)
351-
return errors.New("tls: invalid server key share")
358+
return err
352359
}
353360

354361
earlySecret := hs.earlySecret
@@ -365,7 +372,7 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
365372
serverHandshakeTrafficLabel, hs.transcript)
366373
c.in.setTrafficSecret(hs.suite, serverSecret)
367374

368-
err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret)
375+
err = c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret)
369376
if err != nil {
370377
c.sendAlert(AlertInternalError)
371378
return err

tls/handshake_server_tls13.go

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -216,21 +216,13 @@ GroupSelection:
216216
clientKeyShare = &hs.clientHello.keyShares[0]
217217
}
218218

219-
if _, ok := curveForCurveID(selectedGroup); selectedGroup != X25519 && !ok {
220-
c.sendAlert(AlertInternalError)
221-
return errors.New("tls: CurvePreferences includes unsupported curve")
222-
}
223-
params, err := generateECDHEParameters(c.config.rand(), selectedGroup)
219+
serverShareData, sharedKey, err := generateTLS13ServerShareAndSharedKey(c.config.rand(), selectedGroup, clientKeyShare.data)
224220
if err != nil {
225-
c.sendAlert(AlertInternalError)
226-
return err
227-
}
228-
hs.hello.serverShare = keyShare{group: selectedGroup, data: params.PublicKey()}
229-
hs.sharedKey = params.SharedKey(clientKeyShare.data)
230-
if hs.sharedKey == nil {
231221
c.sendAlert(AlertIllegalParameter)
232-
return errors.New("tls: invalid client key share")
222+
return err
233223
}
224+
hs.hello.serverShare = keyShare{group: selectedGroup, data: serverShareData}
225+
hs.sharedKey = sharedKey
234226

235227
c.serverName = hs.clientHello.serverName
236228
return nil

tls/key_agreement.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,9 @@ type ecdheKeyAgreement struct {
388388
func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
389389
var curveID CurveID
390390
for _, c := range clientHello.supportedCurves {
391+
if c == X25519MLKEM768 {
392+
continue // ML-KEM hybrid group is TLS 1.3 (key_share) only.
393+
}
391394
if config.supportsCurve(c) {
392395
curveID = c
393396
break

0 commit comments

Comments
 (0)