24
24
#include "pycore_typeobject.h" // _PyType_GetModuleState()
25
25
#include "hashlib.h"
26
26
27
+ /*
28
+ * Assert that 'LEN' can be safely casted to uint32_t.
29
+ *
30
+ * The 'LEN' parameter should be convertible to Py_ssize_t.
31
+ */
32
+ #if !defined(NDEBUG ) && (PY_SSIZE_T_MAX > UINT32_MAX )
33
+ #define CHECK_HACL_UINT32_T_LENGTH (LEN ) assert((LEN) < (Py_ssize_t)UINT32_MAX)
34
+ #else
35
+ #define CHECK_HACL_UINT32_T_LENGTH (LEN )
36
+ #endif
37
+
27
38
#define SHA3_MAX_DIGESTSIZE 64 /* 64 Bytes (512 Bits) for 224 to 512 */
28
39
29
40
typedef struct {
@@ -472,50 +483,23 @@ SHA3_TYPE_SPEC(sha3_384_spec, "sha3_384", sha3_384_slots);
472
483
SHA3_TYPE_SLOTS (sha3_512_slots , sha3_512__doc__ , SHA3_methods , SHA3_getseters );
473
484
SHA3_TYPE_SPEC (sha3_512_spec , "sha3_512" , sha3_512_slots );
474
485
475
- static PyObject *
476
- _SHAKE_digest ( SHA3object * self , Py_ssize_t digestlen , int hex )
486
+ static int
487
+ sha3_shake_check_digest_length ( Py_ssize_t length )
477
488
{
478
- unsigned char * digest = NULL ;
479
- PyObject * result = NULL ;
480
-
481
- if (digestlen < 0 ) {
489
+ if (length < 0 ) {
482
490
PyErr_SetString (PyExc_ValueError , "negative digest length" );
483
- return NULL ;
491
+ return -1 ;
484
492
}
485
- if ((size_t )digestlen >= (1 << 29 )) {
493
+ if ((size_t )length >= (1 << 29 )) {
486
494
/*
487
495
* Raise OverflowError to match the semantics of OpenSSL SHAKE
488
496
* when the digest length exceeds the range of a 'Py_ssize_t';
489
497
* the exception message will however be different in this case.
490
498
*/
491
499
PyErr_SetString (PyExc_OverflowError , "digest length is too large" );
492
- return NULL ;
493
- }
494
-
495
- digest = (unsigned char * )PyMem_Malloc (digestlen );
496
- if (digest == NULL ) {
497
- return PyErr_NoMemory ();
498
- }
499
-
500
- /* Get the raw (binary) digest value. The HACL functions errors out if:
501
- * - the algorithm is not shake -- not the case here
502
- * - the output length is zero -- we follow the existing behavior and return
503
- * an empty digest, without raising an error */
504
- if (digestlen > 0 ) {
505
- #if PY_SSIZE_T_MAX > UINT32_MAX
506
- assert (digestlen <= (Py_ssize_t )UINT32_MAX );
507
- #endif
508
- (void )Hacl_Hash_SHA3_squeeze (self -> hash_state , digest ,
509
- (uint32_t )digestlen );
510
- }
511
- if (hex ) {
512
- result = _Py_strhex ((const char * )digest , digestlen );
513
- }
514
- else {
515
- result = PyBytes_FromStringAndSize ((const char * )digest , digestlen );
500
+ return -1 ;
516
501
}
517
- PyMem_Free (digest );
518
- return result ;
502
+ return 0 ;
519
503
}
520
504
521
505
@@ -531,7 +515,26 @@ static PyObject *
531
515
_sha3_shake_128_digest_impl (SHA3object * self , Py_ssize_t length )
532
516
/*[clinic end generated code: output=6c53fb71a6cff0a0 input=be03ade4b31dd54c]*/
533
517
{
534
- return _SHAKE_digest (self , length , 0 );
518
+ if (sha3_shake_check_digest_length (length ) < 0 ) {
519
+ return NULL ;
520
+ }
521
+
522
+ /*
523
+ * Hacl_Hash_SHA3_squeeze() fails if the algorithm is not SHAKE,
524
+ * or if the length is 0. In the latter case, we follow OpenSSL's
525
+ * behavior and return an empty digest, without raising an error.
526
+ */
527
+ if (length == 0 ) {
528
+ return Py_GetConstant (Py_CONSTANT_EMPTY_BYTES );
529
+ }
530
+
531
+ CHECK_HACL_UINT32_T_LENGTH (length );
532
+ PyObject * digest = PyBytes_FromStringAndSize (NULL , length );
533
+ uint8_t * buffer = (uint8_t * )PyBytes_AS_STRING (digest );
534
+ ENTER_HASHLIB (self );
535
+ (void )Hacl_Hash_SHA3_squeeze (self -> hash_state , buffer , (uint32_t )length );
536
+ LEAVE_HASHLIB (self );
537
+ return digest ;
535
538
}
536
539
537
540
@@ -547,7 +550,27 @@ static PyObject *
547
550
_sha3_shake_128_hexdigest_impl (SHA3object * self , Py_ssize_t length )
548
551
/*[clinic end generated code: output=a27412d404f64512 input=0d84d05d7a8ccd37]*/
549
552
{
550
- return _SHAKE_digest (self , length , 1 );
553
+ if (sha3_shake_check_digest_length (length ) < 0 ) {
554
+ return NULL ;
555
+ }
556
+
557
+ /* See _sha3_shake_128_digest_impl() for the fast path rationale. */
558
+ if (length == 0 ) {
559
+ return Py_GetConstant (Py_CONSTANT_EMPTY_STR );
560
+ }
561
+
562
+ CHECK_HACL_UINT32_T_LENGTH (length );
563
+ uint8_t * buffer = PyMem_Malloc (length );
564
+ if (buffer == NULL ) {
565
+ return PyErr_NoMemory ();
566
+ }
567
+
568
+ ENTER_HASHLIB (self );
569
+ (void )Hacl_Hash_SHA3_squeeze (self -> hash_state , buffer , (uint32_t )length );
570
+ LEAVE_HASHLIB (self );
571
+ PyObject * digest = _Py_strhex ((const char * )buffer , length );
572
+ PyMem_Free (buffer );
573
+ return digest ;
551
574
}
552
575
553
576
static PyObject *
0 commit comments