Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 157 additions & 0 deletions subsys/net/lib/sockets/sockets_tls.c
Original file line number Diff line number Diff line change
Expand Up @@ -1464,10 +1464,162 @@ static int tls_mbedtls_init(struct tls_context *context, bool is_server)
return 0;
}

static int tls_check_cert(struct tls_credential *cert)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
mbedtls_x509_crt cert_ctx;
int err;

mbedtls_x509_crt_init(&cert_ctx);

if (crt_is_pem(cert->buf, cert->len)) {
err = mbedtls_x509_crt_parse(&cert_ctx, cert->buf, cert->len);
} else {
/* For DER case, use the no copy version of the parsing function
* to avoid unnecessary heap allocations.
*/
err = mbedtls_x509_crt_parse_der_nocopy(&cert_ctx, cert->buf,
cert->len);
}

if (err != 0) {
NET_ERR("Failed to parse %s on tag %d, err: -0x%x",
"certificate", cert->tag, -err);
return -EINVAL;
}

mbedtls_x509_crt_free(&cert_ctx);

return err;
#else
NET_ERR("TLS with certificates disabled. "
"Reconfigure mbed TLS to support certificate based key exchange.");

return -ENOTSUP;
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}

static int tls_check_priv_key(struct tls_credential *priv_key)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
mbedtls_pk_context key_ctx;
int err;

mbedtls_pk_init(&key_ctx);

err = mbedtls_pk_parse_key(&key_ctx, priv_key->buf,
priv_key->len, NULL, 0,
tls_ctr_drbg_random, NULL);
if (err != 0) {
NET_ERR("Failed to parse %s on tag %d, err: -0x%x",
"private key", priv_key->tag, -err);
err = -EINVAL;
}

mbedtls_pk_free(&key_ctx);

return err;
#else
NET_ERR("TLS with certificates disabled. "
"Reconfigure mbed TLS to support certificate based key exchange.");

return -ENOTSUP;
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}

static int tls_check_psk(struct tls_credential *psk)
{
#if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
struct tls_credential *psk_id;

psk_id = credential_get(psk->tag, TLS_CREDENTIAL_PSK_ID);
if (psk_id == NULL) {
NET_ERR("No matching PSK ID found for tag %d", psk->tag);
return -EINVAL;
}

if (psk->len == 0 || psk_id->len == 0) {
NET_ERR("PSK or PSK ID empty on tag %d", psk->tag);
return -EINVAL;
}

return 0;
#else
NET_ERR("TLS with PSK disabled. "
"Reconfigure mbed TLS to support PSK based key exchange.");

return -ENOTSUP;
#endif
}

/* TODO add decent logs */
static int tls_check_credentials(const sec_tag_t *sec_tags, int sec_tag_count)
{
int err = 0;

credentials_lock();

for (int i = 0; i < sec_tag_count; i++) {
sec_tag_t tag = sec_tags[i];
struct tls_credential *cred = NULL;
bool tag_found = false;

while ((cred = credential_next_get(tag, cred)) != NULL) {
tag_found = true;

switch (cred->type) {
case TLS_CREDENTIAL_CA_CERTIFICATE:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should add here __fallthrough;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be, fixed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

__fallthrough is needed when cases have a body, for combined cases it's not needed.

Both cases are fine here:

switch (some_var) {
case 1:
	/* some logic */
	__fallthrough;
case 2:
	/* ... */
}

switch (some_var) {
case 1:
case 2:
	/* ... */
}

__fallthrough;
case TLS_CREDENTIAL_PUBLIC_CERTIFICATE:
err = tls_check_cert(cred);
if (err != 0) {
goto exit;
}

break;
case TLS_CREDENTIAL_PRIVATE_KEY:
err = tls_check_priv_key(cred);
if (err != 0) {
goto exit;
}

break;
case TLS_CREDENTIAL_PSK:
err = tls_check_psk(cred);
if (err != 0) {
goto exit;
}

break;
case TLS_CREDENTIAL_PSK_ID:
/* Ignore PSK ID - it will be verified together
* with PSK.
*/
break;
default:
return -EINVAL;
}
}

/* If no credential is found with such a tag, report an error. */
if (!tag_found) {
NET_ERR("No TLS credential found with tag %d", tag);
err = -ENOENT;
goto exit;
}
}

exit:
credentials_unlock();

return err;
}

static int tls_opt_sec_tag_list_set(struct tls_context *context,
const void *optval, socklen_t optlen)
{
int sec_tag_cnt;
int ret;

if (!optval) {
return -EINVAL;
Expand All @@ -1483,6 +1635,11 @@ static int tls_opt_sec_tag_list_set(struct tls_context *context,
return -EINVAL;
}

ret = tls_check_credentials((const sec_tag_t *)optval, sec_tag_cnt);
if (ret < 0) {
return ret;
}

memcpy(context->options.sec_tag_list.sec_tags, optval, optlen);
context->options.sec_tag_list.sec_tag_count = sec_tag_cnt;

Expand Down
1 change: 1 addition & 0 deletions tests/net/socket/tls/prj.conf
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ CONFIG_NET_SOCKETS_SOCKOPT_TLS=y
CONFIG_NET_SOCKETS_ENABLE_DTLS=y
CONFIG_NET_SOCKETS_DTLS_SENDMSG_BUF_SIZE=128
CONFIG_NET_SOCKETS_TLS_MAX_CONTEXTS=4
CONFIG_TLS_MAX_CREDENTIALS_NUMBER=10
CONFIG_NET_CONTEXT_RCVTIMEO=y
CONFIG_NET_CONTEXT_SNDTIMEO=y
CONFIG_NET_CONTEXT_RCVBUF=y
Expand Down
82 changes: 82 additions & 0 deletions tests/net/socket/tls/src/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -1814,6 +1814,88 @@ ZTEST(net_socket_tls, test_poll_dtls_pollerr)
k_msleep(10);
}

#define BAD_CA_CERT_TAG 11
#define BAD_OWN_CERT_TAG 12
#define BAD_PRIV_KEY_TAG 13
#define BAD_PSK_TAG 14
#define BAD_NO_CRED_TAG 15

static void remove_bad_cred(void)
{
(void)tls_credential_delete(BAD_CA_CERT_TAG, TLS_CREDENTIAL_CA_CERTIFICATE);
(void)tls_credential_delete(BAD_OWN_CERT_TAG, TLS_CREDENTIAL_PUBLIC_CERTIFICATE);
(void)tls_credential_delete(BAD_PRIV_KEY_TAG, TLS_CREDENTIAL_PRIVATE_KEY);
(void)tls_credential_delete(BAD_PSK_TAG, TLS_CREDENTIAL_PSK);
(void)tls_credential_delete(BAD_PSK_TAG, TLS_CREDENTIAL_PSK_ID);
}

static void test_bad_cred_common(bool test_dtls)
{
static uint8_t bad_ca_cert[] = "bad ca cert";
static uint8_t bad_own_cert[] = "bad own cert";
static uint8_t bad_priv_key[] = "bad priv key";
static uint8_t bad_psk[] = "bad psk"; /* PSK is not bad per se, but will
* try to use it without matching PSK ID.
*/
sec_tag_t bad_tags[] = {
BAD_CA_CERT_TAG,
BAD_OWN_CERT_TAG,
BAD_PRIV_KEY_TAG,
BAD_PSK_TAG,
BAD_NO_CRED_TAG,
};

/* Preconfigure "bad" credentials */
remove_bad_cred();

zassert_ok(tls_credential_add(BAD_CA_CERT_TAG, TLS_CREDENTIAL_CA_CERTIFICATE,
bad_ca_cert, sizeof(bad_ca_cert)),
"Failed to add ca cert");
zassert_ok(tls_credential_add(BAD_OWN_CERT_TAG, TLS_CREDENTIAL_PUBLIC_CERTIFICATE,
bad_own_cert, sizeof(bad_own_cert)),
"Failed to add own cert");
zassert_ok(tls_credential_add(BAD_PRIV_KEY_TAG, TLS_CREDENTIAL_PRIVATE_KEY,
bad_priv_key, sizeof(bad_priv_key)),
"Failed to add priv key");
zassert_ok(tls_credential_add(BAD_PSK_TAG, TLS_CREDENTIAL_PSK, bad_psk,
sizeof(bad_psk)), "Failed to add psk");

if (test_dtls) {
s_sock = zsock_socket(AF_INET, SOCK_DGRAM, IPPROTO_DTLS_1_2);
} else {
s_sock = zsock_socket(AF_INET, SOCK_STREAM, IPPROTO_TLS_1_2);
}

zassert_true(s_sock >= 0, "socket open failed");

for (int i = 0; i < ARRAY_SIZE(bad_tags); i++) {
sec_tag_t test_tag = bad_tags[i];
int ret;

ret = zsock_setsockopt(s_sock, SOL_TLS, TLS_SEC_TAG_LIST,
&test_tag, sizeof(test_tag));
zassert_equal(ret, -1, "zsock_setsockopt should've failed with invalid credential");
if (test_tag == BAD_NO_CRED_TAG) {
zassert_equal(errno, ENOENT, "Unfound credential should fail with ENOENT");
} else {
zassert_equal(errno, EINVAL, "Bad credential should fail with EINVAL");
}
}

test_sockets_close();
remove_bad_cred();
}

ZTEST(net_socket_tls, test_tls_bad_cred)
{
test_bad_cred_common(false);
}

ZTEST(net_socket_tls, test_dtls_bad_cred)
{
test_bad_cred_common(true);
}

static void *tls_tests_setup(void)
{
k_work_queue_init(&tls_test_work_queue);
Expand Down