Skip to content

Commit 2668e81

Browse files
authored
Merge pull request #68 from wolfSSL/devin/support_mldsa
Support ML-DSA
2 parents fa4142d + 24993ad commit 2668e81

File tree

4 files changed

+499
-5
lines changed

4 files changed

+499
-5
lines changed

docs/asymmetric.rst

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,54 @@ ML-KEM
112112
>>>
113113
>>> ss_recv = mlkem_priv.decapsulate(ct)
114114
>>> ss_send == ss_recv
115+
True
116+
117+
ML-DSA
118+
------
119+
120+
.. autoclass:: MlDsaType
121+
:show-inheritance:
122+
123+
.. autoclass:: MlDsaPublic
124+
:private-members:
125+
:members:
126+
:inherited-members:
127+
128+
.. autoclass:: MlDsaPrivate
129+
:members:
130+
:inherited-members:
131+
132+
**Example:**
133+
134+
>>> ######## Simple Usage
135+
>>> from wolfcrypt.ciphers import MlDsaType, MlDsaPrivate, MlDsaPublic
136+
>>>
137+
>>> mldsa_type = MlDsaType.ML_DSA_44
138+
>>>
139+
>>> mldsa_priv = MlDsaPrivate.make_key(mldsa_type)
140+
>>> pub_key = mldsa_priv.encode_pub_key()
141+
>>>
142+
>>> mldsa_pub = MlDsaPublic(mldsa_type)
143+
>>> mldsa_pub.decode_key(pub_key)
144+
>>>
145+
>>> msg = b"This is an example message"
146+
>>>
147+
>>> sig = mldsa_priv.sign(msg)
148+
>>> mldsa_pub.verify(sig, msg)
149+
True
150+
>>>
151+
>>> ######## Export and Import Keys
152+
>>> exported_key_pair = mldsa_priv.encode_priv_key(), mldsa_priv.encode_pub_key()
153+
>>> exported_pub_key = mldsa_pub.encode_key()
154+
>>> exported_key_pair[1] == exported_pub_key
155+
True
156+
>>>
157+
>>> mldsa_priv2 = MlDsaPrivate(mldsa_type)
158+
>>> mldsa_priv2.decode_key(exported_key_pair[0], exported_key_pair[1])
159+
>>>
160+
>>> mldsa_pub2 = MlDsaPublic(mldsa_type)
161+
>>> mldsa_pub2.decode_key(exported_pub_key)
162+
>>>
163+
>>> sig2 = mldsa_priv2.sign(msg)
164+
>>> mldsa_pub2.verify(sig2, msg)
115165
True

scripts/build_ffi.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,9 @@ def make_flags(prefix, fips):
235235
# ML-KEM
236236
flags.append("--enable-kyber")
237237

238+
# ML-DSA
239+
flags.append("--enable-dilithium")
240+
238241
# disabling other configs enabled by default
239242
flags.append("--disable-oldtls")
240243
flags.append("--disable-oldnames")
@@ -371,6 +374,7 @@ def get_features(local_wolfssl, features):
371374
features["AESGCM_STREAM"] = 1 if '#define WOLFSSL_AESGCM_STREAM' in defines else 0
372375
features["RSA_PSS"] = 1 if '#define WC_RSA_PSS' in defines else 0
373376
features["CHACHA20_POLY1305"] = 1 if '#define HAVE_CHACHA' and '#define HAVE_POLY1305' in defines else 0
377+
features["ML_DSA"] = 1 if '#define HAVE_DILITHIUM' in defines else 0
374378

375379
if '#define HAVE_FIPS' in defines:
376380
if not fips:
@@ -447,6 +451,7 @@ def build_ffi(local_wolfssl, features):
447451
#include <wolfssl/wolfcrypt/chacha20_poly1305.h>
448452
#include <wolfssl/wolfcrypt/kyber.h>
449453
#include <wolfssl/wolfcrypt/wc_kyber.h>
454+
#include <wolfssl/wolfcrypt/dilithium.h>
450455
"""
451456

452457
init_source_string = """
@@ -484,6 +489,7 @@ def build_ffi(local_wolfssl, features):
484489
int RSA_PSS_ENABLED = """ + str(features["RSA_PSS"]) + """;
485490
int CHACHA20_POLY1305_ENABLED = """ + str(features["CHACHA20_POLY1305"]) + """;
486491
int ML_KEM_ENABLED = """ + str(features["ML_KEM"]) + """;
492+
int ML_DSA_ENABLED = """ + str(features["ML_DSA"]) + """;
487493
"""
488494

489495
ffibuilder.set_source( "wolfcrypt._ffi", init_source_string,
@@ -520,6 +526,7 @@ def build_ffi(local_wolfssl, features):
520526
extern int RSA_PSS_ENABLED;
521527
extern int CHACHA20_POLY1305_ENABLED;
522528
extern int ML_KEM_ENABLED;
529+
extern int ML_DSA_ENABLED;
523530
524531
typedef unsigned char byte;
525532
typedef unsigned int word32;
@@ -929,12 +936,16 @@ def build_ffi(local_wolfssl, features):
929936
int wolfCrypt_GetPrivateKeyReadEnable_fips(enum wc_KeyType);
930937
"""
931938

939+
if features["ML_KEM"] or features["ML_DSA"]:
940+
cdef += """
941+
static const int INVALID_DEVID;
942+
"""
943+
932944
if features["ML_KEM"]:
933945
cdef += """
934946
static const int WC_ML_KEM_512;
935947
static const int WC_ML_KEM_768;
936948
static const int WC_ML_KEM_1024;
937-
static const int INVALID_DEVID;
938949
typedef struct {...; } KyberKey;
939950
int wc_KyberKey_CipherTextSize(KyberKey* key, word32* len);
940951
int wc_KyberKey_SharedSecretSize(KyberKey* key, word32* len);
@@ -950,7 +961,29 @@ def build_ffi(local_wolfssl, features):
950961
int wc_KyberKey_EncapsulateWithRandom(KyberKey* key, unsigned char* ct, unsigned char* ss, const unsigned char* rand, int len);
951962
int wc_KyberKey_Decapsulate(KyberKey* key, unsigned char* ss, const unsigned char* ct, word32 len);
952963
int wc_KyberKey_EncodePrivateKey(KyberKey* key, unsigned char* out, word32 len);
953-
int wc_KyberKey_DecodePrivateKey(KyberKey* key, const unsigned char* in, word32 len);
964+
int wc_KyberKey_DecodePrivateKey(KyberKey* key, const unsigned char* in, word32 len);
965+
"""
966+
967+
if features["ML_DSA"]:
968+
cdef += """
969+
static const int WC_ML_DSA_44;
970+
static const int WC_ML_DSA_65;
971+
static const int WC_ML_DSA_87;
972+
typedef struct {...; } dilithium_key;
973+
int wc_dilithium_init_ex(dilithium_key* key, void* heap, int devId);
974+
int wc_dilithium_set_level(dilithium_key* key, byte level);
975+
void wc_dilithium_free(dilithium_key* key);
976+
int wc_dilithium_make_key(dilithium_key* key, WC_RNG* rng);
977+
int wc_dilithium_export_private(dilithium_key* key, byte* out, word32* outLen);
978+
int wc_dilithium_import_private(const byte* priv, word32 privSz, dilithium_key* key);
979+
int wc_dilithium_export_public(dilithium_key* key, byte* out, word32* outLen);
980+
int wc_dilithium_import_public(const byte* in, word32 inLen, dilithium_key* key);
981+
int wc_dilithium_sign_msg(const byte* msg, word32 msgLen, byte* sig, word32* sigLen, dilithium_key* key, WC_RNG* rng);
982+
int wc_dilithium_verify_msg(const byte* sig, word32 sigLen, const byte* msg, word32 msgLen, int* res, dilithium_key* key);
983+
typedef dilithium_key MlDsaKey;
984+
int wc_MlDsaKey_GetPrivLen(MlDsaKey* key, int* len);
985+
int wc_MlDsaKey_GetPubLen(MlDsaKey* key, int* len);
986+
int wc_MlDsaKey_GetSigLen(MlDsaKey* key, int* len);
954987
"""
955988

956989
ffibuilder.cdef(cdef)
@@ -983,7 +1016,8 @@ def main(ffibuilder):
9831016
"AESGCM_STREAM": 1,
9841017
"RSA_PSS": 1,
9851018
"CHACHA20_POLY1305": 1,
986-
"ML_KEM": 1
1019+
"ML_KEM": 1,
1020+
"ML_DSA": 1
9871021
}
9881022

9891023
# Ed448 requires SHAKE256, which isn't part of the Windows build, yet.

tests/test_mldsa.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# test_mldsa.py
2+
#
3+
# Copyright (C) 2025 wolfSSL Inc.
4+
#
5+
# This file is part of wolfSSL. (formerly known as CyaSSL)
6+
#
7+
# wolfSSL is free software; you can redistribute it and/or modify
8+
# it under the terms of the GNU General Public License as published by
9+
# the Free Software Foundation; either version 2 of the License, or
10+
# (at your option) any later version.
11+
#
12+
# wolfSSL is distributed in the hope that it will be useful,
13+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
# GNU General Public License for more details.
16+
#
17+
# You should have received a copy of the GNU General Public License
18+
# along with this program; if not, write to the Free Software
19+
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
20+
21+
# pylint: disable=redefined-outer-name
22+
23+
from wolfcrypt._ffi import lib as _lib
24+
25+
if _lib.ML_DSA_ENABLED:
26+
import pytest
27+
28+
from wolfcrypt.ciphers import MlDsaPrivate, MlDsaPublic, MlDsaType
29+
from wolfcrypt.random import Random
30+
31+
@pytest.fixture
32+
def rng():
33+
return Random()
34+
35+
@pytest.fixture(
36+
params=[MlDsaType.ML_DSA_44, MlDsaType.ML_DSA_65, MlDsaType.ML_DSA_87]
37+
)
38+
def mldsa_type(request):
39+
return request.param
40+
41+
def test_init_base(mldsa_type):
42+
mldsa_priv = MlDsaPrivate(mldsa_type)
43+
assert isinstance(mldsa_priv, MlDsaPrivate)
44+
45+
mldsa_pub = MlDsaPublic(mldsa_type)
46+
assert isinstance(mldsa_pub, MlDsaPublic)
47+
48+
def test_size_properties(mldsa_type):
49+
refvals = {
50+
MlDsaType.ML_DSA_44: {
51+
"sig_size": 2420,
52+
"pub_key_size": 1312,
53+
"priv_key_size": 2560,
54+
},
55+
MlDsaType.ML_DSA_65: {
56+
"sig_size": 3309,
57+
"pub_key_size": 1952,
58+
"priv_key_size": 4032,
59+
},
60+
MlDsaType.ML_DSA_87: {
61+
"sig_size": 4627,
62+
"pub_key_size": 2592,
63+
"priv_key_size": 4896,
64+
},
65+
}
66+
67+
mldsa_pub = MlDsaPublic(mldsa_type)
68+
assert mldsa_pub.sig_size == refvals[mldsa_type]["sig_size"]
69+
assert mldsa_pub.key_size == refvals[mldsa_type]["pub_key_size"]
70+
71+
mldsa_priv = MlDsaPrivate(mldsa_type)
72+
assert mldsa_priv.sig_size == refvals[mldsa_type]["sig_size"]
73+
assert mldsa_priv.pub_key_size == refvals[mldsa_type]["pub_key_size"]
74+
assert mldsa_priv.priv_key_size == refvals[mldsa_type]["priv_key_size"]
75+
76+
def test_initializations(mldsa_type, rng):
77+
mldsa_priv = MlDsaPrivate.make_key(mldsa_type, rng)
78+
assert type(mldsa_priv) is MlDsaPrivate
79+
80+
mldsa_priv2 = MlDsaPrivate(mldsa_type)
81+
assert type(mldsa_priv2) is MlDsaPrivate
82+
83+
mldsa_pub = MlDsaPublic(mldsa_type)
84+
assert type(mldsa_pub) is MlDsaPublic
85+
86+
def test_key_import_export(mldsa_type, rng):
87+
# Generate key pair and export keys
88+
mldsa_priv = MlDsaPrivate.make_key(mldsa_type, rng)
89+
priv_key = mldsa_priv.encode_priv_key()
90+
pub_key = mldsa_priv.encode_pub_key()
91+
assert len(priv_key) == mldsa_priv.priv_key_size
92+
assert len(pub_key) == mldsa_priv.pub_key_size
93+
94+
# Export key pair from imported one
95+
mldsa_priv2 = MlDsaPrivate(mldsa_type)
96+
mldsa_priv2.decode_key(priv_key, pub_key)
97+
priv_key2 = mldsa_priv2.encode_priv_key()
98+
pub_key2 = mldsa_priv2.encode_pub_key()
99+
assert priv_key == priv_key2
100+
assert pub_key == pub_key2
101+
102+
# Export private key from imported one
103+
mldsa_priv3 = MlDsaPrivate(mldsa_type)
104+
mldsa_priv3.decode_key(priv_key)
105+
priv_key3 = mldsa_priv3.encode_priv_key()
106+
assert priv_key == priv_key3
107+
108+
# Export public key from imported one
109+
mldsa_pub = MlDsaPublic(mldsa_type)
110+
mldsa_pub.decode_key(pub_key)
111+
pub_key3 = mldsa_pub.encode_key()
112+
assert pub_key == pub_key3
113+
114+
def test_sign_verify(mldsa_type, rng):
115+
# Generate a key pair and export public key
116+
mldsa_priv = MlDsaPrivate.make_key(mldsa_type, rng)
117+
pub_key = mldsa_priv.encode_pub_key()
118+
119+
# Import public key
120+
mldsa_pub = MlDsaPublic(mldsa_type)
121+
mldsa_pub.decode_key(pub_key)
122+
123+
# Sign a message
124+
message = b"This is a test message for ML-DSA signature"
125+
signature = mldsa_priv.sign(message, rng)
126+
assert len(signature) == mldsa_priv.sig_size
127+
128+
# Verify the signature by MlDsaPrivate
129+
assert mldsa_priv.verify(signature, message)
130+
131+
# Verify the signature by MlDsaPublic
132+
assert mldsa_pub.verify(signature, message)
133+
134+
# Verify with wrong message
135+
wrong_message = b"This is a wrong message for ML-DSA signature"
136+
assert not mldsa_pub.verify(signature, wrong_message)

0 commit comments

Comments
 (0)