3737#include "py/stream.h"
3838#include "py/objstr.h"
3939#include "py/reader.h"
40+ #include "py/mphal.h"
4041#include "py/gc.h"
4142#include "extmod/vfs.h"
4243
4748#include "mbedtls/pk.h"
4849#include "mbedtls/entropy.h"
4950#include "mbedtls/ctr_drbg.h"
51+ #ifdef MBEDTLS_SSL_PROTO_DTLS
52+ #include "mbedtls/timing.h"
53+ #endif
5054#include "mbedtls/debug.h"
5155#include "mbedtls/error.h"
5256#if MBEDTLS_VERSION_NUMBER >= 0x03000000
6569
6670#define MP_STREAM_POLL_RDWR (MP_STREAM_POLL_RD | MP_STREAM_POLL_WR)
6771
72+ #define MP_ENDPOINT_IS_SERVER (1 << 0)
73+ #define MP_TRANSPORT_IS_DTLS (1 << 1)
74+
75+ #define MP_PROTOCOL_TLS_CLIENT 0
76+ #define MP_PROTOCOL_TLS_SERVER MP_ENDPOINT_IS_SERVER
77+ #define MP_PROTOCOL_DTLS_CLIENT MP_TRANSPORT_IS_DTLS
78+ #define MP_PROTOCOL_DTLS_SERVER MP_ENDPOINT_IS_SERVER | MP_TRANSPORT_IS_DTLS
79+
6880// This corresponds to an SSLContext object.
6981typedef struct _mp_obj_ssl_context_t {
7082 mp_obj_base_t base ;
@@ -91,6 +103,12 @@ typedef struct _mp_obj_ssl_socket_t {
91103
92104 uintptr_t poll_mask ; // Indicates which read or write operations the protocol needs next
93105 int last_error ; // The last error code, if any
106+
107+ #ifdef MBEDTLS_SSL_PROTO_DTLS
108+ mp_uint_t timer_start_ms ;
109+ mp_uint_t timer_fin_ms ;
110+ mp_uint_t timer_int_ms ;
111+ #endif
94112} mp_obj_ssl_socket_t ;
95113
96114static const mp_obj_type_t ssl_context_type ;
@@ -242,7 +260,10 @@ static mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args
242260 mp_arg_check_num (n_args , n_kw , 1 , 1 , false);
243261
244262 // This is the "protocol" argument.
245- mp_int_t endpoint = mp_obj_get_int (args [0 ]);
263+ mp_int_t protocol = mp_obj_get_int (args [0 ]);
264+
265+ int endpoint = (protocol & MP_ENDPOINT_IS_SERVER ) ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT ;
266+ int transport = (protocol & MP_TRANSPORT_IS_DTLS ) ? MBEDTLS_SSL_TRANSPORT_DATAGRAM : MBEDTLS_SSL_TRANSPORT_STREAM ;
246267
247268 // Create SSLContext object.
248269 #if MICROPY_PY_SSL_FINALISER
@@ -282,7 +303,7 @@ static mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args
282303 }
283304
284305 ret = mbedtls_ssl_config_defaults (& self -> conf , endpoint ,
285- MBEDTLS_SSL_TRANSPORT_STREAM , MBEDTLS_SSL_PRESET_DEFAULT );
306+ transport , MBEDTLS_SSL_PRESET_DEFAULT );
286307 if (ret != 0 ) {
287308 mbedtls_raise_error (ret );
288309 }
@@ -525,6 +546,39 @@ static int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) {
525546 }
526547}
527548
549+ #ifdef MBEDTLS_SSL_PROTO_DTLS
550+ static void _mbedtls_timing_set_delay (void * ctx , uint32_t int_ms , uint32_t fin_ms ) {
551+ mp_obj_ssl_socket_t * o = (mp_obj_ssl_socket_t * )ctx ;
552+
553+ o -> timer_int_ms = int_ms ;
554+ o -> timer_fin_ms = fin_ms ;
555+
556+ if (fin_ms != 0 ) {
557+ o -> timer_start_ms = mp_hal_ticks_ms ();
558+ }
559+ }
560+
561+ static int _mbedtls_timing_get_delay (void * ctx ) {
562+ mp_obj_ssl_socket_t * o = (mp_obj_ssl_socket_t * )ctx ;
563+
564+ if (o -> timer_fin_ms == 0 ) {
565+ return -1 ;
566+ }
567+
568+ mp_uint_t elapsed_ms = mp_hal_ticks_ms () - o -> timer_start_ms ;
569+
570+ if (elapsed_ms >= o -> timer_fin_ms ) {
571+ return 2 ;
572+ }
573+
574+ if (elapsed_ms >= o -> timer_int_ms ) {
575+ return 1 ;
576+ }
577+
578+ return 0 ;
579+ }
580+ #endif
581+
528582static mp_obj_t ssl_socket_make_new (mp_obj_ssl_context_t * ssl_context , mp_obj_t sock ,
529583 bool server_side , bool do_handshake_on_connect , mp_obj_t server_hostname ) {
530584
@@ -577,6 +631,10 @@ static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t
577631 mp_raise_ValueError (MP_ERROR_TEXT ("CERT_REQUIRED requires server_hostname" ));
578632 }
579633
634+ #ifdef MBEDTLS_SSL_PROTO_DTLS
635+ mbedtls_ssl_set_timer_cb (& o -> ssl , o , _mbedtls_timing_set_delay , _mbedtls_timing_get_delay );
636+ #endif
637+
580638 mbedtls_ssl_set_bio (& o -> ssl , & o -> sock , _mbedtls_ssl_send , _mbedtls_ssl_recv , NULL );
581639
582640 if (do_handshake_on_connect ) {
@@ -788,6 +846,12 @@ static const mp_rom_map_elem_t ssl_socket_locals_dict_table[] = {
788846 { MP_ROM_QSTR (MP_QSTR_readinto ), MP_ROM_PTR (& mp_stream_readinto_obj ) },
789847 { MP_ROM_QSTR (MP_QSTR_readline ), MP_ROM_PTR (& mp_stream_unbuffered_readline_obj ) },
790848 { MP_ROM_QSTR (MP_QSTR_write ), MP_ROM_PTR (& mp_stream_write_obj ) },
849+ #ifdef MBEDTLS_SSL_PROTO_DTLS
850+ { MP_ROM_QSTR (MP_QSTR_recv ), MP_ROM_PTR (& mp_stream_read1_obj ) },
851+ { MP_ROM_QSTR (MP_QSTR_recv_into ), MP_ROM_PTR (& mp_stream_readinto_obj ) },
852+ { MP_ROM_QSTR (MP_QSTR_send ), MP_ROM_PTR (& mp_stream_write1_obj ) },
853+ { MP_ROM_QSTR (MP_QSTR_sendall ), MP_ROM_PTR (& mp_stream_write_obj ) },
854+ #endif
791855 { MP_ROM_QSTR (MP_QSTR_setblocking ), MP_ROM_PTR (& socket_setblocking_obj ) },
792856 { MP_ROM_QSTR (MP_QSTR_close ), MP_ROM_PTR (& mp_stream_close_obj ) },
793857 #if MICROPY_PY_SSL_FINALISER
@@ -879,8 +943,12 @@ static const mp_rom_map_elem_t mp_module_tls_globals_table[] = {
879943
880944 // Constants.
881945 { MP_ROM_QSTR (MP_QSTR_MBEDTLS_VERSION ), MP_ROM_PTR (& mbedtls_version_obj )},
882- { MP_ROM_QSTR (MP_QSTR_PROTOCOL_TLS_CLIENT ), MP_ROM_INT (MBEDTLS_SSL_IS_CLIENT ) },
883- { MP_ROM_QSTR (MP_QSTR_PROTOCOL_TLS_SERVER ), MP_ROM_INT (MBEDTLS_SSL_IS_SERVER ) },
946+ { MP_ROM_QSTR (MP_QSTR_PROTOCOL_TLS_CLIENT ), MP_ROM_INT (MP_PROTOCOL_TLS_CLIENT ) },
947+ { MP_ROM_QSTR (MP_QSTR_PROTOCOL_TLS_SERVER ), MP_ROM_INT (MP_PROTOCOL_TLS_SERVER ) },
948+ #ifdef MBEDTLS_SSL_PROTO_DTLS
949+ { MP_ROM_QSTR (MP_QSTR_PROTOCOL_DTLS_CLIENT ), MP_ROM_INT (MP_PROTOCOL_DTLS_CLIENT ) },
950+ { MP_ROM_QSTR (MP_QSTR_PROTOCOL_DTLS_SERVER ), MP_ROM_INT (MP_PROTOCOL_DTLS_SERVER ) },
951+ #endif
884952 { MP_ROM_QSTR (MP_QSTR_CERT_NONE ), MP_ROM_INT (MBEDTLS_SSL_VERIFY_NONE ) },
885953 { MP_ROM_QSTR (MP_QSTR_CERT_OPTIONAL ), MP_ROM_INT (MBEDTLS_SSL_VERIFY_OPTIONAL ) },
886954 { MP_ROM_QSTR (MP_QSTR_CERT_REQUIRED ), MP_ROM_INT (MBEDTLS_SSL_VERIFY_REQUIRED ) },
0 commit comments