From a22cafdc684c570dda13b0051bef8cb8b2022214 Mon Sep 17 00:00:00 2001 From: Kazuho Oku Date: Tue, 31 Mar 2026 10:03:00 +0900 Subject: [PATCH] Avoid UB by using a wrapper object to obtain SNI lazily. The cost of the wrapper is assumed to be negligible because, under many calling conventions, an object containing two pointers is passed using two registers, just like two separate pointer arguments. --- include/picotls.h | 47 +++++++++++++++++++++++++++++++++-------------- lib/picotls.c | 11 +++++------ 2 files changed, 38 insertions(+), 20 deletions(-) diff --git a/include/picotls.h b/include/picotls.h index fef405e8..ecc394b3 100644 --- a/include/picotls.h +++ b/include/picotls.h @@ -1405,8 +1405,27 @@ uint64_t ptls_decode_quicint(const uint8_t **src, const uint8_t *end); ptls_decode_assert_block_close((src), end); \ } while (0) +typedef struct st_ptls_log_getsni_t { + const char *(*cb)(void *arg); + void *arg; +} ptls_log_getsni_t; + +/** + * Creates a lazy callback object for obtaining SNI. The object is used to delay materialization of SNI to only when it is needed. + */ +#define PTLS_LOG_DEFINE_GETSNI(suffix, type, body) \ + static inline const char *ptls_log_getsni_cb_##suffix(void *_arg) \ + { \ + type arg = (type)_arg; \ + body \ + } \ + static inline ptls_log_getsni_t ptls_log_getsni_##suffix(type arg) \ + { \ + return (ptls_log_getsni_t){ptls_log_getsni_cb_##suffix, arg}; \ + } + #if PTLS_HAVE_LOG -#define PTLS_LOG__DO_LOG(module, name, conn_state, get_sni, get_sni_arg, add_time, block) \ +#define PTLS_LOG__DO_LOG(module, name, conn_state, get_sni, add_time, block) \ do { \ int ptlslog_include_appdata = 0; \ do { \ @@ -1414,12 +1433,11 @@ uint64_t ptls_decode_quicint(const uint8_t **src, const uint8_t *end); do { \ block \ } while (0); \ - ptlslog_include_appdata = \ - ptls_log__do_write_end(&logpoint, (conn_state), (get_sni), (get_sni_arg), ptlslog_include_appdata); \ + ptlslog_include_appdata = ptls_log__do_write_end(&logpoint, (conn_state), (get_sni), ptlslog_include_appdata); \ } while (PTLS_UNLIKELY(ptlslog_include_appdata)); \ } while (0) #else -#define PTLS_LOG__DO_LOG(module, name, conn_state, get_sni, get_sni_arg, add_time, block) /* don't generate code */ +#define PTLS_LOG__DO_LOG(module, name, conn_state, get_sni, add_time, block) /* don't generate code */ #endif #define PTLS_LOG_DEFINE_POINT(_module, _name, _var) \ @@ -1430,7 +1448,7 @@ uint64_t ptls_decode_quicint(const uint8_t **src, const uint8_t *end); PTLS_LOG_DEFINE_POINT(module, name, logpoint); \ if (PTLS_LIKELY(ptls_log_point_maybe_active(&logpoint) == 0)) \ break; \ - PTLS_LOG__DO_LOG(module, name, NULL, NULL, NULL, 1, {block}); \ + PTLS_LOG__DO_LOG(module, name, NULL, (ptls_log_getsni_t){NULL}, 1, {block}); \ } while (0) #define PTLS_LOG_CONN(name, tls, block) \ @@ -1441,10 +1459,10 @@ uint64_t ptls_decode_quicint(const uint8_t **src, const uint8_t *end); break; \ ptls_t *_tls = (tls); \ ptls_log_conn_state_t *conn_state = ptls_get_log_state(_tls); \ - active &= ptls_log_conn_maybe_active(conn_state, (const char *(*)(void *))ptls_get_server_name, _tls); \ + active &= ptls_log_conn_maybe_active(conn_state, ptls_log_getsni_ptls(_tls)); \ if (PTLS_LIKELY(active == 0)) \ break; \ - PTLS_LOG__DO_LOG(picotls, name, conn_state, (const char *(*)(void *))ptls_get_server_name, _tls, 1, { \ + PTLS_LOG__DO_LOG(picotls, name, conn_state, ptls_log_getsni_ptls(_tls), 1, { \ PTLS_LOG_ELEMENT_PTR(tls, _tls); \ do { \ block \ @@ -1571,7 +1589,7 @@ static uint32_t ptls_log_point_maybe_active(struct st_ptls_log_point_t *point); /** * returns a bitmap indicating the loggers active for given connection */ -static uint32_t ptls_log_conn_maybe_active(ptls_log_conn_state_t *conn, const char *(*get_sni)(void *), void *get_sni_arg); +static uint32_t ptls_log_conn_maybe_active(ptls_log_conn_state_t *conn, ptls_log_getsni_t getsni); /** * Returns the number of log events that were unable to be emitted. @@ -1588,8 +1606,7 @@ size_t ptls_log_num_lost(void); int ptls_log_add_fd(int fd, float sample_ratio, const char *points, const char *snis, const char *addresses, int appdata); void ptls_log__recalc_point(int caller_locked, struct st_ptls_log_point_t *point); -void ptls_log__recalc_conn(int caller_locked, struct st_ptls_log_conn_state_t *conn, const char *(*get_sni)(void *), - void *get_sni_arg); +void ptls_log__recalc_conn(int caller_locked, struct st_ptls_log_conn_state_t *conn, ptls_log_getsni_t getsni); void ptls_log__do_push_element_safestr(const char *prefix, size_t prefix_len, const char *s, size_t l); void ptls_log__do_push_element_unsafestr(const char *prefix, size_t prefix_len, const char *s, size_t l); void ptls_log__do_push_element_hexdump(const char *prefix, size_t prefix_len, const void *s, size_t l); @@ -1603,8 +1620,8 @@ void ptls_log__do_push_appdata_element_unsafestr(int includes_appdata, const cha void ptls_log__do_push_appdata_element_hexdump(int includes_appdata, const char *prefix, size_t prefix_len, const void *s, size_t l); void ptls_log__do_write_start(struct st_ptls_log_point_t *point, int add_time); -int ptls_log__do_write_end(struct st_ptls_log_point_t *point, struct st_ptls_log_conn_state_t *conn, const char *(*get_sni)(void *), - void *get_sni_arg, int includes_appdata); +int ptls_log__do_write_end(struct st_ptls_log_point_t *point, struct st_ptls_log_conn_state_t *conn, ptls_log_getsni_t getsni, + int includes_appdata); /** * create a client object to handle new TLS connection @@ -1963,6 +1980,8 @@ extern ptls_get_time_t ptls_get_time; */ static void ptls_hash_clone_memcpy(void *dst, const void *src, size_t size); +PTLS_LOG_DEFINE_GETSNI(ptls, ptls_t *, { return ptls_get_server_name(arg); }) + /* inline functions */ inline uint32_t ptls_log_point_maybe_active(struct st_ptls_log_point_t *point) @@ -1981,11 +2000,11 @@ inline void ptls_log_recalc_conn_state(ptls_log_conn_state_t *state) state->state.generation = 0; } -inline uint32_t ptls_log_conn_maybe_active(ptls_log_conn_state_t *conn, const char *(*get_sni)(void *), void *get_sni_arg) +inline uint32_t ptls_log_conn_maybe_active(ptls_log_conn_state_t *conn, ptls_log_getsni_t getsni) { #if PTLS_HAVE_LOG if (PTLS_UNLIKELY(conn->state.generation != ptls_log._generation)) - ptls_log__recalc_conn(0, conn, get_sni, get_sni_arg); + ptls_log__recalc_conn(0, conn, getsni); return conn->state.active_conns; #else return 0; diff --git a/lib/picotls.c b/lib/picotls.c index 36899d40..813f7d13 100644 --- a/lib/picotls.c +++ b/lib/picotls.c @@ -6928,8 +6928,7 @@ void ptls_log__recalc_point(int caller_locked, struct st_ptls_log_point_t *point pthread_mutex_unlock(&logctx.mutex); } -void ptls_log__recalc_conn(int caller_locked, struct st_ptls_log_conn_state_t *conn, const char *(*get_sni)(void *), - void *get_sni_arg) +void ptls_log__recalc_conn(int caller_locked, struct st_ptls_log_conn_state_t *conn, ptls_log_getsni_t getsni) { if (!caller_locked) pthread_mutex_lock(&logctx.mutex); @@ -6937,7 +6936,7 @@ void ptls_log__recalc_conn(int caller_locked, struct st_ptls_log_conn_state_t *c if (conn->state.generation != ptls_log._generation) { /* update active bitmap */ uint32_t new_active = 0; - const char *sni = get_sni != NULL ? get_sni(get_sni_arg) : NULL; + const char *sni = getsni.cb != NULL ? getsni.cb(getsni.arg) : NULL; for (size_t slot = 0; slot < PTLS_ELEMENTSOF(logctx.conns); ++slot) { if (logctx.conns[slot].points != NULL && conn->random_ < logctx.conns[slot].sample_ratio && is_in_stringlist(logctx.conns[slot].snis, sni) && @@ -7090,8 +7089,8 @@ void ptls_log__do_write_start(struct st_ptls_log_point_t *point, int add_time) logbuf.buf.off = (size_t)written; } -int ptls_log__do_write_end(struct st_ptls_log_point_t *point, struct st_ptls_log_conn_state_t *conn, const char *(*get_sni)(void *), - void *get_sni_arg, int includes_appdata) +int ptls_log__do_write_end(struct st_ptls_log_point_t *point, struct st_ptls_log_conn_state_t *conn, ptls_log_getsni_t getsni, + int includes_appdata) { if (!expand_logbuf_or_invalidate("}\n", 2, 0)) return 0; @@ -7105,7 +7104,7 @@ int ptls_log__do_write_end(struct st_ptls_log_point_t *point, struct st_ptls_log ptls_log__recalc_point(1, point); uint32_t active = point->state.active_conns; if (conn != NULL && conn->state.generation != ptls_log._generation) { - ptls_log__recalc_conn(1, conn, get_sni, get_sni_arg); + ptls_log__recalc_conn(1, conn, getsni); active &= conn->state.active_conns; }