diff --git a/subsys/net/lib/sockets/sockets_tls.c b/subsys/net/lib/sockets/sockets_tls.c index 6f4919dc99e72..695af67164685 100644 --- a/subsys/net/lib/sockets/sockets_tls.c +++ b/subsys/net/lib/sockets/sockets_tls.c @@ -1464,10 +1464,162 @@ static int tls_mbedtls_init(struct tls_context *context, bool is_server) return 0; } +static int tls_check_cert(struct tls_credential *cert) +{ +#if defined(MBEDTLS_X509_CRT_PARSE_C) + mbedtls_x509_crt cert_ctx; + int err; + + mbedtls_x509_crt_init(&cert_ctx); + + if (crt_is_pem(cert->buf, cert->len)) { + err = mbedtls_x509_crt_parse(&cert_ctx, cert->buf, cert->len); + } else { + /* For DER case, use the no copy version of the parsing function + * to avoid unnecessary heap allocations. + */ + err = mbedtls_x509_crt_parse_der_nocopy(&cert_ctx, cert->buf, + cert->len); + } + + if (err != 0) { + NET_ERR("Failed to parse %s on tag %d, err: -0x%x", + "certificate", cert->tag, -err); + return -EINVAL; + } + + mbedtls_x509_crt_free(&cert_ctx); + + return err; +#else + NET_ERR("TLS with certificates disabled. " + "Reconfigure mbed TLS to support certificate based key exchange."); + + return -ENOTSUP; +#endif /* MBEDTLS_X509_CRT_PARSE_C */ +} + +static int tls_check_priv_key(struct tls_credential *priv_key) +{ +#if defined(MBEDTLS_X509_CRT_PARSE_C) + mbedtls_pk_context key_ctx; + int err; + + mbedtls_pk_init(&key_ctx); + + err = mbedtls_pk_parse_key(&key_ctx, priv_key->buf, + priv_key->len, NULL, 0, + tls_ctr_drbg_random, NULL); + if (err != 0) { + NET_ERR("Failed to parse %s on tag %d, err: -0x%x", + "private key", priv_key->tag, -err); + err = -EINVAL; + } + + mbedtls_pk_free(&key_ctx); + + return err; +#else + NET_ERR("TLS with certificates disabled. " + "Reconfigure mbed TLS to support certificate based key exchange."); + + return -ENOTSUP; +#endif /* MBEDTLS_X509_CRT_PARSE_C */ +} + +static int tls_check_psk(struct tls_credential *psk) +{ +#if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED) + struct tls_credential *psk_id; + + psk_id = credential_get(psk->tag, TLS_CREDENTIAL_PSK_ID); + if (psk_id == NULL) { + NET_ERR("No matching PSK ID found for tag %d", psk->tag); + return -EINVAL; + } + + if (psk->len == 0 || psk_id->len == 0) { + NET_ERR("PSK or PSK ID empty on tag %d", psk->tag); + return -EINVAL; + } + + return 0; +#else + NET_ERR("TLS with PSK disabled. " + "Reconfigure mbed TLS to support PSK based key exchange."); + + return -ENOTSUP; +#endif +} + +/* TODO add decent logs */ +static int tls_check_credentials(const sec_tag_t *sec_tags, int sec_tag_count) +{ + int err = 0; + + credentials_lock(); + + for (int i = 0; i < sec_tag_count; i++) { + sec_tag_t tag = sec_tags[i]; + struct tls_credential *cred = NULL; + bool tag_found = false; + + while ((cred = credential_next_get(tag, cred)) != NULL) { + tag_found = true; + + switch (cred->type) { + case TLS_CREDENTIAL_CA_CERTIFICATE: + __fallthrough; + case TLS_CREDENTIAL_PUBLIC_CERTIFICATE: + err = tls_check_cert(cred); + if (err != 0) { + goto exit; + } + + break; + case TLS_CREDENTIAL_PRIVATE_KEY: + err = tls_check_priv_key(cred); + if (err != 0) { + goto exit; + } + + break; + case TLS_CREDENTIAL_PSK: + err = tls_check_psk(cred); + if (err != 0) { + goto exit; + } + + break; + case TLS_CREDENTIAL_PSK_ID: + /* Ignore PSK ID - it will be verified together + * with PSK. + */ + break; + default: + return -EINVAL; + } + } + + /* If no credential is found with such a tag, report an error. */ + if (!tag_found) { + NET_ERR("No TLS credential found with tag %d", tag); + err = -ENOENT; + goto exit; + } + } + +exit: + credentials_unlock(); + + return err; +} + static int tls_opt_sec_tag_list_set(struct tls_context *context, const void *optval, socklen_t optlen) { int sec_tag_cnt; + int ret; if (!optval) { return -EINVAL; @@ -1483,6 +1635,11 @@ static int tls_opt_sec_tag_list_set(struct tls_context *context, return -EINVAL; } + ret = tls_check_credentials((const sec_tag_t *)optval, sec_tag_cnt); + if (ret < 0) { + return ret; + } + memcpy(context->options.sec_tag_list.sec_tags, optval, optlen); context->options.sec_tag_list.sec_tag_count = sec_tag_cnt; diff --git a/tests/net/socket/tls/prj.conf b/tests/net/socket/tls/prj.conf index 3bbb28310f66f..6c4e061e808c0 100644 --- a/tests/net/socket/tls/prj.conf +++ b/tests/net/socket/tls/prj.conf @@ -16,6 +16,7 @@ CONFIG_NET_SOCKETS_SOCKOPT_TLS=y CONFIG_NET_SOCKETS_ENABLE_DTLS=y CONFIG_NET_SOCKETS_DTLS_SENDMSG_BUF_SIZE=128 CONFIG_NET_SOCKETS_TLS_MAX_CONTEXTS=4 +CONFIG_TLS_MAX_CREDENTIALS_NUMBER=10 CONFIG_NET_CONTEXT_RCVTIMEO=y CONFIG_NET_CONTEXT_SNDTIMEO=y CONFIG_NET_CONTEXT_RCVBUF=y diff --git a/tests/net/socket/tls/src/main.c b/tests/net/socket/tls/src/main.c index 44254e2f2975f..8186ba7dbe913 100644 --- a/tests/net/socket/tls/src/main.c +++ b/tests/net/socket/tls/src/main.c @@ -1814,6 +1814,88 @@ ZTEST(net_socket_tls, test_poll_dtls_pollerr) k_msleep(10); } +#define BAD_CA_CERT_TAG 11 +#define BAD_OWN_CERT_TAG 12 +#define BAD_PRIV_KEY_TAG 13 +#define BAD_PSK_TAG 14 +#define BAD_NO_CRED_TAG 15 + +static void remove_bad_cred(void) +{ + (void)tls_credential_delete(BAD_CA_CERT_TAG, TLS_CREDENTIAL_CA_CERTIFICATE); + (void)tls_credential_delete(BAD_OWN_CERT_TAG, TLS_CREDENTIAL_PUBLIC_CERTIFICATE); + (void)tls_credential_delete(BAD_PRIV_KEY_TAG, TLS_CREDENTIAL_PRIVATE_KEY); + (void)tls_credential_delete(BAD_PSK_TAG, TLS_CREDENTIAL_PSK); + (void)tls_credential_delete(BAD_PSK_TAG, TLS_CREDENTIAL_PSK_ID); +} + +static void test_bad_cred_common(bool test_dtls) +{ + static uint8_t bad_ca_cert[] = "bad ca cert"; + static uint8_t bad_own_cert[] = "bad own cert"; + static uint8_t bad_priv_key[] = "bad priv key"; + static uint8_t bad_psk[] = "bad psk"; /* PSK is not bad per se, but will + * try to use it without matching PSK ID. + */ + sec_tag_t bad_tags[] = { + BAD_CA_CERT_TAG, + BAD_OWN_CERT_TAG, + BAD_PRIV_KEY_TAG, + BAD_PSK_TAG, + BAD_NO_CRED_TAG, + }; + + /* Preconfigure "bad" credentials */ + remove_bad_cred(); + + zassert_ok(tls_credential_add(BAD_CA_CERT_TAG, TLS_CREDENTIAL_CA_CERTIFICATE, + bad_ca_cert, sizeof(bad_ca_cert)), + "Failed to add ca cert"); + zassert_ok(tls_credential_add(BAD_OWN_CERT_TAG, TLS_CREDENTIAL_PUBLIC_CERTIFICATE, + bad_own_cert, sizeof(bad_own_cert)), + "Failed to add own cert"); + zassert_ok(tls_credential_add(BAD_PRIV_KEY_TAG, TLS_CREDENTIAL_PRIVATE_KEY, + bad_priv_key, sizeof(bad_priv_key)), + "Failed to add priv key"); + zassert_ok(tls_credential_add(BAD_PSK_TAG, TLS_CREDENTIAL_PSK, bad_psk, + sizeof(bad_psk)), "Failed to add psk"); + + if (test_dtls) { + s_sock = zsock_socket(AF_INET, SOCK_DGRAM, IPPROTO_DTLS_1_2); + } else { + s_sock = zsock_socket(AF_INET, SOCK_STREAM, IPPROTO_TLS_1_2); + } + + zassert_true(s_sock >= 0, "socket open failed"); + + for (int i = 0; i < ARRAY_SIZE(bad_tags); i++) { + sec_tag_t test_tag = bad_tags[i]; + int ret; + + ret = zsock_setsockopt(s_sock, SOL_TLS, TLS_SEC_TAG_LIST, + &test_tag, sizeof(test_tag)); + zassert_equal(ret, -1, "zsock_setsockopt should've failed with invalid credential"); + if (test_tag == BAD_NO_CRED_TAG) { + zassert_equal(errno, ENOENT, "Unfound credential should fail with ENOENT"); + } else { + zassert_equal(errno, EINVAL, "Bad credential should fail with EINVAL"); + } + } + + test_sockets_close(); + remove_bad_cred(); +} + +ZTEST(net_socket_tls, test_tls_bad_cred) +{ + test_bad_cred_common(false); +} + +ZTEST(net_socket_tls, test_dtls_bad_cred) +{ + test_bad_cred_common(true); +} + static void *tls_tests_setup(void) { k_work_queue_init(&tls_test_work_queue);