diff --git a/src/tpm2_cryptocb.c b/src/tpm2_cryptocb.c index c3543267..510b0d1e 100644 --- a/src/tpm2_cryptocb.c +++ b/src/tpm2_cryptocb.c @@ -193,9 +193,10 @@ int wolfTPM2_CryptoDevCb(int devId, wc_CryptoInfo* info, void* ctx) && tlsCtx->ecdhKey == NULL ) { #ifdef DEBUG_WOLFTPM - printf("No crypto callback key pointer set!\n"); + printf("No crypto callback TPM key set, " + "fallback to software crypto\n"); #endif - return BAD_FUNC_ARG; + return exit_rc; } /* Make sure an ECDH key has been set and curve is supported */ @@ -205,6 +206,7 @@ int wolfTPM2_CryptoDevCb(int devId, wc_CryptoInfo* info, void* ctx) } rc = TPM2_GetTpmCurve(curve_id); if (rc < 0) { + /* curve not available, so fallback to sw crypto */ return exit_rc; } curve_id = rc; @@ -215,9 +217,14 @@ int wolfTPM2_CryptoDevCb(int devId, wc_CryptoInfo* info, void* ctx) if (tlsCtx->ecdhKey == NULL) #endif { - /* Create an ECC key for ECDSA - if one isn't already created */ key = (tlsCtx->ecdsaKey != NULL) ? (WOLFTPM2_KEY*)tlsCtx->ecdsaKey : tlsCtx->eccKey; + if (key == NULL) { + /* fallback to software crypto */ + return exit_rc; + } + + /* Create an ECC key for ECDSA - if one isn't already created */ if (key->handle.hndl == 0 || key->handle.hndl == TPM_RH_NULL ) { diff --git a/src/tpm2_wrap.c b/src/tpm2_wrap.c index 04e39b39..08f517b4 100644 --- a/src/tpm2_wrap.c +++ b/src/tpm2_wrap.c @@ -4132,11 +4132,13 @@ int wolfTPM2_SignHash(WOLFTPM2_DEV* dev, WOLFTPM2_KEY* key, sigAlg = TPM_ALG_ECDSA; } if (hashAlg == 0 || hashAlg == TPM_ALG_NULL) { - if (digestSz == 64) + /* determine hash type based on curve */ + int curve_id = pub->parameters.eccDetail.curveID; + if (curve_id == TPM_ECC_NIST_P521) hashAlg = TPM_ALG_SHA512; - else if (digestSz == 48) + else if (curve_id == TPM_ECC_NIST_P384) hashAlg = TPM_ALG_SHA384; - else if (digestSz == 32) + else hashAlg = TPM_ALG_SHA256; } } @@ -4273,12 +4275,16 @@ int wolfTPM2_VerifyHash_ex(WOLFTPM2_DEV* dev, WOLFTPM2_KEY* key, int wolfTPM2_VerifyHash(WOLFTPM2_DEV* dev, WOLFTPM2_KEY* key, const byte* sig, int sigSz, const byte* digest, int digestSz) { + int curve_id = 0; int hashAlg = TPM_ALG_NULL; - /* detect hash algorithm based on digest size */ - if (digestSz >= TPM_SHA512_DIGEST_SIZE) + /* detect hash algorithm based on key curve */ + if (key != NULL) { + curve_id = key->pub.publicArea.parameters.eccDetail.curveID; + } + if (curve_id == TPM_ECC_NIST_P521) hashAlg = TPM_ALG_SHA512; - else if (digestSz >= TPM_SHA384_DIGEST_SIZE) + else if (curve_id == TPM_ECC_NIST_P384) hashAlg = TPM_ALG_SHA384; else hashAlg = TPM_ALG_SHA256;