diff --git a/src/wh_client_crypto.c b/src/wh_client_crypto.c index f7bb8a79..f1106be6 100644 --- a/src/wh_client_crypto.c +++ b/src/wh_client_crypto.c @@ -1901,7 +1901,8 @@ int wh_Client_RsaFunction(whClientContext* ctx, RsaKey* key, int rsa_type, if ((in != NULL) && (in_len > 0)) { memcpy(req_in, in, in_len); } - req->outLen = *inout_out_len; + /* Set output length only when provided to avoid NULL dereference */ + req->outLen = (inout_out_len != NULL) ? *inout_out_len : 0; /* Send Request */ ret = wh_Client_SendRequest(ctx, group, action, req_len, diff --git a/src/wh_client_nvm.c b/src/wh_client_nvm.c index 0f6a4d86..5742b97c 100644 --- a/src/wh_client_nvm.c +++ b/src/wh_client_nvm.c @@ -626,7 +626,7 @@ int wh_Client_NvmReadRequest(whClientContext* c, int wh_Client_NvmReadResponse(whClientContext* c, int32_t *out_rc, whNvmSize *out_len, uint8_t* data) { - uint8_t buffer[WH_MESSAGE_NVM_MAX_READ_LEN] = {0}; + uint8_t buffer[WOLFHSM_CFG_COMM_DATA_LEN] = {0}; whMessageNvm_ReadResponse* msg = (whMessageNvm_ReadResponse*)buffer; uint16_t hdr_len = sizeof(*msg); uint8_t* payload = buffer + hdr_len; @@ -645,13 +645,14 @@ int wh_Client_NvmReadResponse(whClientContext* c, int32_t *out_rc, &resp_size, buffer); if (rc == 0) { /* Validate response */ - if ( (resp_group != WH_MESSAGE_GROUP_NVM) || - (resp_action != WH_MESSAGE_NVM_ACTION_READ) || - (resp_size < hdr_len) || - (resp_size - hdr_len > WH_MESSAGE_NVM_MAX_READ_LEN) ){ + if ((resp_group != WH_MESSAGE_GROUP_NVM) || + (resp_action != WH_MESSAGE_NVM_ACTION_READ) || + (resp_size < hdr_len) || (resp_size > sizeof(buffer)) || + (resp_size - hdr_len > WH_MESSAGE_NVM_MAX_READ_LEN)) { /* Invalid message */ rc = WH_ERROR_ABORTED; - } else { + } + else { /* Valid message */ if (out_rc != NULL) { *out_rc = msg->rc; diff --git a/src/wh_nvm_flash.c b/src/wh_nvm_flash.c index 2df53e72..9ab85e2b 100644 --- a/src/wh_nvm_flash.c +++ b/src/wh_nvm_flash.c @@ -1190,9 +1190,13 @@ int wh_NvmFlash_DestroyObjects(void* c, whNvmId list_count, /* Write each used object to new partition */ for (entry = 0; entry < WOLFHSM_CFG_NVM_OBJECT_COUNT; entry++) { if (d->objects[entry].state.status == NF_STATUS_USED) { - /* TODO: Handle errors here better. Break out of loop? */ ret = nfObject_Copy(context, entry, dest_part, &dest_object, &dest_data); + if (ret != WH_ERROR_OK) { + /* Abort reclaim to avoid activating a partially copied + * partition */ + return ret; + } } } diff --git a/src/wh_server_counter.c b/src/wh_server_counter.c index 00a34c86..04717705 100644 --- a/src/wh_server_counter.c +++ b/src/wh_server_counter.c @@ -56,8 +56,8 @@ int wh_Server_HandleCounter(whServerContext* server, uint16_t magic, switch (action) { case WH_COUNTER_INIT: { - whMessageCounter_InitRequest req; - whMessageCounter_InitResponse resp; + whMessageCounter_InitRequest req = {0}; + whMessageCounter_InitResponse resp = {0}; /* translate request */ (void)wh_MessageCounter_TranslateInitRequest( @@ -65,7 +65,8 @@ int wh_Server_HandleCounter(whServerContext* server, uint16_t magic, /* write 0 to nvm with the supplied id and user_id */ meta->id = WH_MAKE_KEYID(WH_KEYTYPE_COUNTER, - server->comm->client_id, req.counterId); + (uint16_t)server->comm->client_id, + (uint16_t)req.counterId); /* use the label buffer to hold the counter value */ *counter = req.counter; ret = wh_Nvm_AddObjectWithReclaim(server->nvm, meta, 0, NULL); @@ -83,19 +84,20 @@ int wh_Server_HandleCounter(whServerContext* server, uint16_t magic, } break; case WH_COUNTER_INCREMENT: { - whMessageCounter_IncrementRequest req; - whMessageCounter_IncrementResponse resp; + whMessageCounter_IncrementRequest req = {0}; + whMessageCounter_IncrementResponse resp = {0}; /* translate request */ (void)wh_MessageCounter_TranslateIncrementRequest( magic, (whMessageCounter_IncrementRequest*)req_packet, &req); /* read the counter, stored in the metadata label */ - ret = wh_Nvm_GetMetadata(server->nvm, - WH_MAKE_KEYID(WH_KEYTYPE_COUNTER, - server->comm->client_id, - req.counterId), - meta); + ret = wh_Nvm_GetMetadata( + server->nvm, + WH_MAKE_KEYID(WH_KEYTYPE_COUNTER, + (uint16_t)server->comm->client_id, + (uint16_t)req.counterId), + meta); resp.rc = ret; /* increment and write the counter back */ @@ -128,19 +130,20 @@ int wh_Server_HandleCounter(whServerContext* server, uint16_t magic, } break; case WH_COUNTER_READ: { - whMessageCounter_ReadRequest req; - whMessageCounter_ReadResponse resp; + whMessageCounter_ReadRequest req = {0}; + whMessageCounter_ReadResponse resp = {0}; /* translate request */ (void)wh_MessageCounter_TranslateReadRequest( magic, (whMessageCounter_ReadRequest*)req_packet, &req); /* read the counter, stored in the metadata label */ - ret = wh_Nvm_GetMetadata(server->nvm, - WH_MAKE_KEYID(WH_KEYTYPE_COUNTER, - server->comm->client_id, - req.counterId), - meta); + ret = wh_Nvm_GetMetadata( + server->nvm, + WH_MAKE_KEYID(WH_KEYTYPE_COUNTER, + (uint16_t)server->comm->client_id, + (uint16_t)req.counterId), + meta); resp.rc = ret; /* return counter to the caller */ @@ -158,15 +161,16 @@ int wh_Server_HandleCounter(whServerContext* server, uint16_t magic, } break; case WH_COUNTER_DESTROY: { - whMessageCounter_DestroyRequest req; - whMessageCounter_DestroyResponse resp; + whMessageCounter_DestroyRequest req = {0}; + whMessageCounter_DestroyResponse resp = {0}; /* translate request */ (void)wh_MessageCounter_TranslateDestroyRequest( magic, (whMessageCounter_DestroyRequest*)req_packet, &req); counterId = WH_MAKE_KEYID(WH_KEYTYPE_COUNTER, - server->comm->client_id, req.counterId); + (uint16_t)server->comm->client_id, + (uint16_t)req.counterId); ret = wh_Nvm_DestroyObjects(server->nvm, 1, &counterId); resp.rc = ret; diff --git a/src/wh_server_crypto.c b/src/wh_server_crypto.c index c6e37c69..30b75a72 100644 --- a/src/wh_server_crypto.c +++ b/src/wh_server_crypto.c @@ -289,16 +289,25 @@ static int _HandleRsaKeyGen(whServerContext* ctx, uint16_t magic, printf("[server] RsaKeyGen UniqueId: keyId:%u, ret:%d\n", key_id, ret); #endif + if (ret != WH_ERROR_OK) { + /* Early return on unique ID generation failure */ + wc_FreeRsaKey(rsa); + return ret; + } } - ret = wh_Server_CacheImportRsaKey(ctx, rsa, key_id, flags, - label_size, label); + if (ret == 0) { + ret = wh_Server_CacheImportRsaKey(ctx, rsa, key_id, flags, + label_size, label); + } #ifdef DEBUG_CRYPTOCB_VERBOSE printf("[server] RsaKeyGen CacheKeyRsa: keyId:%u, ret:%d\n", key_id, ret); #endif - res.keyId = WH_KEYID_ID(key_id); - res.len = 0; + if (ret == 0) { + res.keyId = WH_KEYID_ID(key_id); + res.len = 0; + } } } wc_FreeRsaKey(rsa); @@ -729,9 +738,16 @@ static int _HandleEccKeyGen(whServerContext* ctx, uint16_t magic, printf("[server] %s UniqueId: keyId:%u, ret:%d\n", __func__, key_id, ret); #endif + if (ret != WH_ERROR_OK) { + /* Early return on unique ID generation failure */ + wc_ecc_free(key); + return ret; + } + } + if (ret == 0) { + ret = wh_Server_EccKeyCacheImport(ctx, key, key_id, flags, + label_size, label); } - ret = wh_Server_EccKeyCacheImport(ctx, key, key_id, flags, - label_size, label); #ifdef DEBUG_CRYPTOCB printf("[server] %s CacheImport: keyId:%u, ret:%d\n", __func__, key_id, ret); @@ -1146,10 +1162,17 @@ static int _HandleCurve25519KeyGen(whServerContext* ctx, uint16_t magic, printf("[server] %s UniqueId: keyId:%u, ret:%d\n", __func__, key_id, ret); #endif + if (ret != WH_ERROR_OK) { + /* Early return on unique ID generation failure */ + wc_curve25519_free(key); + return ret; + } } - ret = wh_Server_CacheImportCurve25519Key( - ctx, key, key_id, flags, label_size, label); + if (ret == 0) { + ret = wh_Server_CacheImportCurve25519Key( + ctx, key, key_id, flags, label_size, label); + } #ifdef DEBUG_CRYPTOCB printf("[server] %s CacheImport: keyId:%u, ret:%d\n", __func__, key_id, ret); @@ -1693,12 +1716,16 @@ static int _HandleCmac(whServerContext* ctx, uint16_t magic, uint16_t seq, if (moveToBigCache == 1) { ret = wh_Server_KeystoreEvictKey(ctx, keyId); } - meta->id = keyId; - meta->len = sizeof(ctx->crypto->algoCtx.cmac); - ret = wh_Server_KeystoreCacheKey( - ctx, meta, (uint8_t*)ctx->crypto->algoCtx.cmac); - res.keyId = WH_KEYID_ID(keyId); - res.outSz = 0; + if (ret == 0) { + meta->id = keyId; + meta->len = sizeof(ctx->crypto->algoCtx.cmac); + ret = wh_Server_KeystoreCacheKey( + ctx, meta, (uint8_t*)ctx->crypto->algoCtx.cmac); + if (ret == 0) { + res.keyId = WH_KEYID_ID(keyId); + res.outSz = 0; + } + } #ifdef DEBUG_CRYPTOCB_VERBOSE printf("[server] cmac saved state in keyid:%x %x len:%u ret:%d type:%d\n", keyId, WH_KEYID_ID(keyId), meta->len, ret, ctx->crypto->algoCtx.cmac->type); @@ -1734,7 +1761,8 @@ static int _HandleSha256(whServerContext* ctx, uint16_t magic, int ret = 0; wc_Sha256 sha256[1]; whMessageCrypto_Sha256Request req; - whMessageCrypto_Sha2Response res; + whMessageCrypto_Sha2Response res = {0}; + /* Translate the request */ ret = wh_MessageCrypto_TranslateSha256Request(magic, cryptoDataIn, &req); if (ret != 0) { @@ -1751,6 +1779,10 @@ static int _HandleSha256(whServerContext* ctx, uint16_t magic, sha256->hiLen = req.resumeState.hiLen; if (req.isLastBlock) { + /* Validate lastBlockLen to prevent potential buffer overread */ + if ((unsigned int)req.lastBlockLen > WC_SHA256_BLOCK_SIZE) { + return WH_ERROR_BADARGS; + } /* wolfCrypt (or cryptoCb) is responsible for last block padding */ if (ret == 0) { ret = wc_Sha256Update(sha256, req.inBlock, req.lastBlockLen); @@ -2113,9 +2145,17 @@ static int _HandleMlDsaKeyGen(whServerContext* ctx, uint16_t magic, printf("[server] %s UniqueId: keyId:%u, ret:%d\n", __func__, key_id, ret); #endif + if (ret != WH_ERROR_OK) { + /* Early return on unique ID generation failure + */ + wc_MlDsaKey_Free(key); + return ret; + } + } + if (ret == 0) { + ret = wh_Server_MlDsaKeyCacheImport( + ctx, key, key_id, flags, label_size, label); } - ret = wh_Server_MlDsaKeyCacheImport( - ctx, key, key_id, flags, label_size, label); #ifdef DEBUG_CRYPTOCB printf("[server] %s CacheImport: keyId:%u, ret:%d\n", __func__, key_id, ret); @@ -2175,6 +2215,16 @@ static int _HandleMlDsaSign(whServerContext* ctx, uint16_t magic, uint32_t options = req.options; int evict = !!(options & WH_MESSAGE_CRYPTO_MLDSA_SIGN_OPTIONS_EVICT); + /* Validate input length against available data to prevent buffer overread + */ + if (inSize < sizeof(whMessageCrypto_MlDsaSignRequest)) { + return WH_ERROR_BADARGS; + } + word32 available_data = inSize - sizeof(whMessageCrypto_MlDsaSignRequest); + if (in_len > available_data) { + return WH_ERROR_BADARGS; + } + /* Response message */ byte* res_out = (uint8_t*)(cryptoDataOut) + sizeof(whMessageCrypto_MlDsaSignResponse); @@ -2244,6 +2294,17 @@ static int _HandleMlDsaVerify(whServerContext* ctx, uint16_t magic, uint32_t sig_len = req.sigSz; byte* req_sig = (uint8_t*)(cryptoDataIn) + sizeof(whMessageCrypto_MlDsaVerifyRequest); + + /* Validate lengths against available payload (overflow-safe) */ + if (inSize < sizeof(whMessageCrypto_MlDsaVerifyRequest)) { + return WH_ERROR_BADARGS; + } + uint32_t available = inSize - sizeof(whMessageCrypto_MlDsaVerifyRequest); + if ((sig_len > available) || (hash_len > available) || + (sig_len > (available - hash_len))) { + return WH_ERROR_BADARGS; + } + byte* req_hash = req_sig + sig_len; int evict = !!(options & WH_MESSAGE_CRYPTO_MLDSA_VERIFY_OPTIONS_EVICT); @@ -3216,6 +3277,12 @@ static int _HandleMlDsaKeyGenDma(whServerContext* ctx, uint16_t magic, printf("[server] %s UniqueId: keyId:%u, ret:%d\n", __func__, keyId, ret); #endif + if (ret != WH_ERROR_OK) { + /* Early return on unique ID generation failure + */ + wc_MlDsaKey_Free(key); + return ret; + } } if (ret == 0) { diff --git a/src/wh_server_keystore.c b/src/wh_server_keystore.c index 61d0391d..eca690df 100644 --- a/src/wh_server_keystore.c +++ b/src/wh_server_keystore.c @@ -448,7 +448,7 @@ int wh_Server_KeystoreReadKey(whServerContext* server, whKeyId keyId, } /* cache key if free slot, will only kick out other commited keys */ if (ret == 0 && out != NULL) { - wh_Server_KeystoreCacheKey(server, meta, out); + (void)wh_Server_KeystoreCacheKey(server, meta, out); } #ifdef WOLFHSM_CFG_SHE_EXTENSION /* use empty key of zeros if we couldn't find the master ecu key */ @@ -690,7 +690,7 @@ int wh_Server_HandleKeyRequest(whServerContext* server, uint16_t magic, case WH_KEY_EVICT: { whMessageKeystore_EvictRequest req; - whMessageKeystore_EvictResponse resp; + whMessageKeystore_EvictResponse resp = {0}; (void)wh_MessageKeystore_TranslateEvictRequest( magic, (whMessageKeystore_EvictRequest*)req_packet, &req); @@ -702,9 +702,14 @@ int wh_Server_HandleKeyRequest(whServerContext* server, uint16_t magic, /* TODO: Are there any fatal server errors? */ ret = WH_ERROR_OK; - (void)wh_MessageKeystore_TranslateEvictResponse( - magic, &resp, (whMessageKeystore_EvictResponse*)resp_packet); - *out_resp_size = sizeof(resp); + if (ret == WH_ERROR_OK) { + resp.ok = 0; + + (void)wh_MessageKeystore_TranslateEvictResponse( + magic, &resp, + (whMessageKeystore_EvictResponse*)resp_packet); + *out_resp_size = sizeof(resp); + } } break; case WH_KEY_EXPORT: { diff --git a/wolfhsm/wh_message_counter.h b/wolfhsm/wh_message_counter.h index a2f7679e..b8732305 100644 --- a/wolfhsm/wh_message_counter.h +++ b/wolfhsm/wh_message_counter.h @@ -41,7 +41,7 @@ typedef struct { /* Counter Init Response */ typedef struct { - uint32_t rc; + int32_t rc; uint32_t counter; } whMessageCounter_InitResponse; @@ -62,7 +62,7 @@ typedef struct { /* Counter Increment Response */ typedef struct { - uint32_t rc; + int32_t rc; uint32_t counter; } whMessageCounter_IncrementResponse; @@ -83,7 +83,7 @@ typedef struct { /* Counter Read Response */ typedef struct { - uint32_t rc; + int32_t rc; uint32_t counter; } whMessageCounter_ReadResponse; @@ -104,7 +104,7 @@ typedef struct { /* Counter Destroy Response */ typedef struct { - uint32_t rc; + int32_t rc; uint8_t WH_PAD[4]; } whMessageCounter_DestroyResponse;