Skip to content

Commit 074a3db

Browse files
committed
ML-KEM: derive secret fix
Fixes for deriving secret for ML-KEM.
1 parent 5c421a1 commit 074a3db

File tree

2 files changed

+43
-3
lines changed

2 files changed

+43
-3
lines changed

wolfcrypt/src/wc_mlkem.c

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,42 @@
105105

106106
#ifdef WOLFSSL_WC_MLKEM
107107

108+
#ifdef DEBUG_MLKEM
109+
void print_polys(const char* name, const sword16* a, int d1, int d2);
110+
void print_polys(const char* name, const sword16* a, int d1, int d2)
111+
{
112+
int i;
113+
int j;
114+
int k;
115+
116+
fprintf(stderr, "%s: %d %d\n", name, d1, d2);
117+
for (i = 0; i < d1; i++) {
118+
for (j = 0; j < d2; j++) {
119+
for (k = 0; k < 256; k++) {
120+
fprintf(stderr, "%9d,", a[(i*d2*256) + (j*256) + k]);
121+
if ((k % 8) == 7) fprintf(stderr, "\n");
122+
}
123+
fprintf(stderr, "\n");
124+
}
125+
}
126+
}
127+
#endif
128+
129+
#ifdef DEBUG_MLKEM
130+
void print_data(const char* name, const byte* d, int len);
131+
void print_data(const char* name, const byte* d, int len)
132+
{
133+
int i;
134+
135+
fprintf(stderr, "%s\n", name);
136+
for (i = 0; i < len; i++) {
137+
fprintf(stderr, "0x%02x,", d[i]);
138+
if ((i % 16) == 15) fprintf(stderr, "\n");
139+
}
140+
fprintf(stderr, "\n");
141+
}
142+
#endif
143+
108144
/******************************************************************************/
109145

110146
/* Use SHA3-256 to generate 32-bytes of hash. */

wolfcrypt/src/wc_mlkem_poly.c

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3184,16 +3184,20 @@ int mlkem_derive_secret(wc_Shake* shake256, const byte* z, const byte* ct,
31843184

31853185
#ifdef USE_INTEL_SPEEDUP
31863186
XMEMCPY(shake256->t, z, WC_ML_KEM_SYM_SZ);
3187-
XMEMCPY(shake256->t, ct, WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ);
3188-
shake256->i = WC_ML_KEM_SYM_SZ;
3187+
XMEMCPY(shake256->t + WC_ML_KEM_SYM_SZ, ct,
3188+
WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ);
3189+
shake256->i = WC_ML_KEM_SYM_SZ + WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ;
31893190
ct += WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ;
31903191
ctSz -= WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ;
31913192
ret = wc_Shake256_Update(shake256, ct, ctSz);
31923193
if (ret == 0) {
31933194
ret = wc_Shake256_Final(shake256, ss, WC_ML_KEM_SS_SZ);
31943195
}
31953196
#else
3196-
ret = wc_Shake256_Update(shake256, z, WC_ML_KEM_SYM_SZ);
3197+
ret = wc_InitShake256(shake256, NULL, INVALID_DEVID);
3198+
if (ret == 0) {
3199+
ret = wc_Shake256_Update(shake256, z, WC_ML_KEM_SYM_SZ);
3200+
}
31973201
if (ret == 0) {
31983202
ret = wc_Shake256_Update(shake256, ct, ctSz);
31993203
}

0 commit comments

Comments
 (0)