1414#include <zephyr/ztest.h>
1515
1616#include <mbedtls/x509.h>
17+ #include <mbedtls/x509_crt.h>
1718
1819LOG_MODULE_REGISTER (tls_test , CONFIG_NET_SOCKETS_LOG_LEVEL );
1920
@@ -150,6 +151,7 @@ static void server_thread_fn(void *arg0, void *arg1, void *arg2)
150151{
151152 const int server_fd = POINTER_TO_INT (arg0 );
152153 const int echo = POINTER_TO_INT (arg1 );
154+ const int expect_failure = POINTER_TO_INT (arg2 );
153155
154156 int r ;
155157 int client_fd ;
@@ -168,6 +170,10 @@ static void server_thread_fn(void *arg0, void *arg1, void *arg2)
168170 NET_DBG ("Accepting client connection.." );
169171 k_sem_give (& server_sem );
170172 r = accept (server_fd , (struct sockaddr * )& sa , & addrlen );
173+ if (expect_failure ) {
174+ zassert_equal (r , -1 , "accept() should've failed" );
175+ return ;
176+ }
171177 zassert_not_equal (r , -1 , "accept() failed (%d)" , r );
172178 client_fd = r ;
173179
@@ -199,7 +205,7 @@ static void server_thread_fn(void *arg0, void *arg1, void *arg2)
199205}
200206
201207static int test_configure_server (k_tid_t * server_thread_id , int peer_verify ,
202- int echo )
208+ int echo , int expect_failure )
203209{
204210 static const sec_tag_t server_tag_list_verify_none [] = {
205211 SERVER_CERTIFICATE_TAG ,
@@ -282,7 +288,8 @@ static int test_configure_server(k_tid_t *server_thread_id, int peer_verify,
282288 * server_thread_id = k_thread_create (& server_thread , server_stack ,
283289 STACK_SIZE , server_thread_fn ,
284290 INT_TO_POINTER (server_fd ),
285- INT_TO_POINTER (echo ), NULL ,
291+ INT_TO_POINTER (echo ),
292+ INT_TO_POINTER (expect_failure ),
286293 K_PRIO_PREEMPT (8 ), 0 , K_NO_WAIT );
287294
288295 r = k_sem_take (& server_sem , K_MSEC (TIMEOUT ));
@@ -380,7 +387,8 @@ static void test_common(int peer_verify)
380387 /*
381388 * Server socket setup
382389 */
383- server_fd = test_configure_server (& server_thread_id , peer_verify , true);
390+ server_fd = test_configure_server (& server_thread_id , peer_verify , true,
391+ false);
384392
385393 /*
386394 * Client socket setup
@@ -444,7 +452,7 @@ static void test_tls_cert_verify_result_opt_common(uint32_t expect)
444452 }
445453
446454 server_fd = test_configure_server (& server_thread_id , TLS_PEER_VERIFY_NONE ,
447- false);
455+ false, false );
448456 client_fd = test_configure_client (& sa , false, hostname );
449457
450458 ret = zsock_setsockopt (client_fd , SOL_TLS , TLS_PEER_VERIFY ,
@@ -473,6 +481,71 @@ ZTEST(net_socket_tls_api_extension, test_tls_cert_verify_result_opt_bad_cn)
473481 test_tls_cert_verify_result_opt_common (MBEDTLS_X509_BADCERT_CN_MISMATCH );
474482}
475483
484+ struct test_cert_verify_ctx {
485+ bool cb_called ;
486+ int result ;
487+ };
488+
489+ static int cert_verify_cb (void * ctx , mbedtls_x509_crt * crt , int depth ,
490+ uint32_t * flags )
491+ {
492+ struct test_cert_verify_ctx * test_ctx = (struct test_cert_verify_ctx * )ctx ;
493+
494+ test_ctx -> cb_called = true;
495+
496+ if (test_ctx -> result == 0 ) {
497+ * flags = 0 ;
498+ } else {
499+ * flags |= MBEDTLS_X509_BADCERT_NOT_TRUSTED ;
500+ }
501+
502+ return test_ctx -> result ;
503+ }
504+
505+ static void test_tls_cert_verify_cb_opt_common (int result )
506+ {
507+ int server_fd , client_fd , ret ;
508+ k_tid_t server_thread_id ;
509+ struct sockaddr_in sa ;
510+ struct test_cert_verify_ctx ctx = {
511+ .cb_called = false,
512+ .result = result ,
513+ };
514+ struct tls_cert_verify_cb cb = {
515+ .cb = cert_verify_cb ,
516+ .ctx = & ctx ,
517+ };
518+
519+ server_fd = test_configure_server (& server_thread_id , TLS_PEER_VERIFY_NONE ,
520+ false, result == 0 ? false : true);
521+ client_fd = test_configure_client (& sa , false, "localhost" );
522+
523+ ret = zsock_setsockopt (client_fd , SOL_TLS , TLS_CERT_VERIFY_CALLBACK ,
524+ & cb , sizeof (cb ));
525+ zassert_ok (ret , "failed to set TLS_CERT_VERIFY_CALLBACK (%d)" , errno );
526+
527+ ret = zsock_connect (client_fd , (struct sockaddr * )& sa , sizeof (sa ));
528+ zassert_true (ctx .cb_called , "callback not called" );
529+ if (result == 0 ) {
530+ zassert_equal (ret , 0 , "failed to connect (%d)" , errno );
531+ } else {
532+ zassert_equal (ret , -1 , "connect() should fail" );
533+ zassert_equal (errno , ECONNABORTED , "invalid errno" );
534+ }
535+
536+ test_shutdown (client_fd , server_fd , server_thread_id );
537+ }
538+
539+ ZTEST (net_socket_tls_api_extension , test_tls_cert_verify_cb_opt_ok )
540+ {
541+ test_tls_cert_verify_cb_opt_common (0 );
542+ }
543+
544+ ZTEST (net_socket_tls_api_extension , test_tls_cert_verify_cb_opt_bad_cert )
545+ {
546+ test_tls_cert_verify_cb_opt_common (MBEDTLS_ERR_X509_CERT_VERIFY_FAILED );
547+ }
548+
476549static void * setup (void )
477550{
478551 int r ;
0 commit comments