2222
2323#include "mbedtls/version.h"
2424
25+ #define MP_STREAM_POLL_RDWR (MP_STREAM_POLL_RD | MP_STREAM_POLL_WR)
26+
2527#if defined(MBEDTLS_ERROR_C )
2628#include "../../lib/mbedtls_errors/mp_mbedtls_errors.c"
2729#endif
@@ -220,6 +222,7 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
220222 o -> base .type = & ssl_sslsocket_type ;
221223 o -> ssl_context = self ;
222224 o -> sock_obj = socket ;
225+ o -> poll_mask = 0 ;
223226
224227 mp_load_method (socket , MP_QSTR_accept , o -> accept_args );
225228 mp_load_method (socket , MP_QSTR_bind , o -> bind_args );
@@ -330,7 +333,8 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
330333 }
331334}
332335
333- mp_uint_t common_hal_ssl_sslsocket_recv_into (ssl_sslsocket_obj_t * self , uint8_t * buf , uint32_t len ) {
336+ mp_uint_t common_hal_ssl_sslsocket_recv_into (ssl_sslsocket_obj_t * self , uint8_t * buf , mp_uint_t len ) {
337+ self -> poll_mask = 0 ;
334338 int ret = mbedtls_ssl_read (& self -> ssl , buf , len );
335339 DEBUG_PRINT ("recv_into mbedtls_ssl_read() -> %d\n" , ret );
336340 if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY ) {
@@ -342,17 +346,24 @@ mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t
342346 DEBUG_PRINT ("returning %d\n" , ret );
343347 return ret ;
344348 }
349+ if (ret == MBEDTLS_ERR_SSL_WANT_WRITE ) {
350+ self -> poll_mask = MP_STREAM_POLL_WR ;
351+ }
345352 DEBUG_PRINT ("raising errno [error case] %d\n" , ret );
346353 mbedtls_raise_error (ret );
347354}
348355
349- mp_uint_t common_hal_ssl_sslsocket_send (ssl_sslsocket_obj_t * self , const uint8_t * buf , uint32_t len ) {
356+ mp_uint_t common_hal_ssl_sslsocket_send (ssl_sslsocket_obj_t * self , const uint8_t * buf , mp_uint_t len ) {
357+ self -> poll_mask = 0 ;
350358 int ret = mbedtls_ssl_write (& self -> ssl , buf , len );
351359 DEBUG_PRINT ("send mbedtls_ssl_write() -> %d\n" , ret );
352360 if (ret >= 0 ) {
353361 DEBUG_PRINT ("returning %d\n" , ret );
354362 return ret ;
355363 }
364+ if (ret == MBEDTLS_ERR_SSL_WANT_READ ) {
365+ self -> poll_mask = MP_STREAM_POLL_RD ;
366+ }
356367 DEBUG_PRINT ("raising errno [error case] %d\n" , ret );
357368 mbedtls_raise_error (ret );
358369}
@@ -448,3 +459,37 @@ void common_hal_ssl_sslsocket_setsockopt(ssl_sslsocket_obj_t *self, mp_obj_t lev
448459void common_hal_ssl_sslsocket_settimeout (ssl_sslsocket_obj_t * self , mp_obj_t timeout_obj ) {
449460 ssl_socket_settimeout (self , timeout_obj );
450461}
462+
463+ static bool poll_common (ssl_sslsocket_obj_t * self , uintptr_t arg ) {
464+ // Take into account that the library might have buffered data already
465+ int has_pending = 0 ;
466+ if (arg & MP_STREAM_POLL_RD ) {
467+ has_pending = mbedtls_ssl_check_pending (& self -> ssl );
468+ if (has_pending ) {
469+ // Shortcut if we only need to read and we have buffered data, no need to go to the underlying socket
470+ return true;
471+ }
472+ }
473+
474+ // If the library signaled us that it needs reading or writing, only
475+ // check that direction
476+ if (self -> poll_mask && (arg & MP_STREAM_POLL_RDWR )) {
477+ arg = (arg & ~MP_STREAM_POLL_RDWR ) | self -> poll_mask ;
478+ }
479+
480+ // If direction the library needed is available, return a fake
481+ // result to the caller so that it reenters a read or a write to
482+ // allow the handshake to progress
483+ const mp_stream_p_t * stream_p = mp_get_stream_raise (self -> sock_obj , MP_STREAM_OP_IOCTL );
484+ int errcode ;
485+ mp_int_t ret = stream_p -> ioctl (self -> sock_obj , MP_STREAM_POLL , arg , & errcode );
486+ return ret != 0 ;
487+ }
488+
489+ bool common_hal_ssl_sslsocket_readable (ssl_sslsocket_obj_t * self ) {
490+ return poll_common (self , MP_STREAM_POLL_RD );
491+ }
492+
493+ bool common_hal_ssl_sslsocket_writable (ssl_sslsocket_obj_t * self ) {
494+ return poll_common (self , MP_STREAM_POLL_WR );
495+ }
0 commit comments