diff --git a/include/picotls.h b/include/picotls.h index fef405e8..ecc394b3 100644 --- a/include/picotls.h +++ b/include/picotls.h @@ -1405,8 +1405,27 @@ uint64_t ptls_decode_quicint(const uint8_t **src, const uint8_t *end); ptls_decode_assert_block_close((src), end); \ } while (0) +typedef struct st_ptls_log_getsni_t { + const char *(*cb)(void *arg); + void *arg; +} ptls_log_getsni_t; + +/** + * Creates a lazy callback object for obtaining SNI. The object is used to delay materialization of SNI to only when it is needed. + */ +#define PTLS_LOG_DEFINE_GETSNI(suffix, type, body) \ + static inline const char *ptls_log_getsni_cb_##suffix(void *_arg) \ + { \ + type arg = (type)_arg; \ + body \ + } \ + static inline ptls_log_getsni_t ptls_log_getsni_##suffix(type arg) \ + { \ + return (ptls_log_getsni_t){ptls_log_getsni_cb_##suffix, arg}; \ + } + #if PTLS_HAVE_LOG -#define PTLS_LOG__DO_LOG(module, name, conn_state, get_sni, get_sni_arg, add_time, block) \ +#define PTLS_LOG__DO_LOG(module, name, conn_state, get_sni, add_time, block) \ do { \ int ptlslog_include_appdata = 0; \ do { \ @@ -1414,12 +1433,11 @@ uint64_t ptls_decode_quicint(const uint8_t **src, const uint8_t *end); do { \ block \ } while (0); \ - ptlslog_include_appdata = \ - ptls_log__do_write_end(&logpoint, (conn_state), (get_sni), (get_sni_arg), ptlslog_include_appdata); \ + ptlslog_include_appdata = ptls_log__do_write_end(&logpoint, (conn_state), (get_sni), ptlslog_include_appdata); \ } while (PTLS_UNLIKELY(ptlslog_include_appdata)); \ } while (0) #else -#define PTLS_LOG__DO_LOG(module, name, conn_state, get_sni, get_sni_arg, add_time, block) /* don't generate code */ +#define PTLS_LOG__DO_LOG(module, name, conn_state, get_sni, add_time, block) /* don't generate code */ #endif #define PTLS_LOG_DEFINE_POINT(_module, _name, _var) \ @@ -1430,7 +1448,7 @@ uint64_t ptls_decode_quicint(const uint8_t **src, const uint8_t *end); PTLS_LOG_DEFINE_POINT(module, name, logpoint); \ if (PTLS_LIKELY(ptls_log_point_maybe_active(&logpoint) == 0)) \ break; \ - PTLS_LOG__DO_LOG(module, name, NULL, NULL, NULL, 1, {block}); \ + PTLS_LOG__DO_LOG(module, name, NULL, (ptls_log_getsni_t){NULL}, 1, {block}); \ } while (0) #define PTLS_LOG_CONN(name, tls, block) \ @@ -1441,10 +1459,10 @@ uint64_t ptls_decode_quicint(const uint8_t **src, const uint8_t *end); break; \ ptls_t *_tls = (tls); \ ptls_log_conn_state_t *conn_state = ptls_get_log_state(_tls); \ - active &= ptls_log_conn_maybe_active(conn_state, (const char *(*)(void *))ptls_get_server_name, _tls); \ + active &= ptls_log_conn_maybe_active(conn_state, ptls_log_getsni_ptls(_tls)); \ if (PTLS_LIKELY(active == 0)) \ break; \ - PTLS_LOG__DO_LOG(picotls, name, conn_state, (const char *(*)(void *))ptls_get_server_name, _tls, 1, { \ + PTLS_LOG__DO_LOG(picotls, name, conn_state, ptls_log_getsni_ptls(_tls), 1, { \ PTLS_LOG_ELEMENT_PTR(tls, _tls); \ do { \ block \ @@ -1571,7 +1589,7 @@ static uint32_t ptls_log_point_maybe_active(struct st_ptls_log_point_t *point); /** * returns a bitmap indicating the loggers active for given connection */ -static uint32_t ptls_log_conn_maybe_active(ptls_log_conn_state_t *conn, const char *(*get_sni)(void *), void *get_sni_arg); +static uint32_t ptls_log_conn_maybe_active(ptls_log_conn_state_t *conn, ptls_log_getsni_t getsni); /** * Returns the number of log events that were unable to be emitted. @@ -1588,8 +1606,7 @@ size_t ptls_log_num_lost(void); int ptls_log_add_fd(int fd, float sample_ratio, const char *points, const char *snis, const char *addresses, int appdata); void ptls_log__recalc_point(int caller_locked, struct st_ptls_log_point_t *point); -void ptls_log__recalc_conn(int caller_locked, struct st_ptls_log_conn_state_t *conn, const char *(*get_sni)(void *), - void *get_sni_arg); +void ptls_log__recalc_conn(int caller_locked, struct st_ptls_log_conn_state_t *conn, ptls_log_getsni_t getsni); void ptls_log__do_push_element_safestr(const char *prefix, size_t prefix_len, const char *s, size_t l); void ptls_log__do_push_element_unsafestr(const char *prefix, size_t prefix_len, const char *s, size_t l); void ptls_log__do_push_element_hexdump(const char *prefix, size_t prefix_len, const void *s, size_t l); @@ -1603,8 +1620,8 @@ void ptls_log__do_push_appdata_element_unsafestr(int includes_appdata, const cha void ptls_log__do_push_appdata_element_hexdump(int includes_appdata, const char *prefix, size_t prefix_len, const void *s, size_t l); void ptls_log__do_write_start(struct st_ptls_log_point_t *point, int add_time); -int ptls_log__do_write_end(struct st_ptls_log_point_t *point, struct st_ptls_log_conn_state_t *conn, const char *(*get_sni)(void *), - void *get_sni_arg, int includes_appdata); +int ptls_log__do_write_end(struct st_ptls_log_point_t *point, struct st_ptls_log_conn_state_t *conn, ptls_log_getsni_t getsni, + int includes_appdata); /** * create a client object to handle new TLS connection @@ -1963,6 +1980,8 @@ extern ptls_get_time_t ptls_get_time; */ static void ptls_hash_clone_memcpy(void *dst, const void *src, size_t size); +PTLS_LOG_DEFINE_GETSNI(ptls, ptls_t *, { return ptls_get_server_name(arg); }) + /* inline functions */ inline uint32_t ptls_log_point_maybe_active(struct st_ptls_log_point_t *point) @@ -1981,11 +2000,11 @@ inline void ptls_log_recalc_conn_state(ptls_log_conn_state_t *state) state->state.generation = 0; } -inline uint32_t ptls_log_conn_maybe_active(ptls_log_conn_state_t *conn, const char *(*get_sni)(void *), void *get_sni_arg) +inline uint32_t ptls_log_conn_maybe_active(ptls_log_conn_state_t *conn, ptls_log_getsni_t getsni) { #if PTLS_HAVE_LOG if (PTLS_UNLIKELY(conn->state.generation != ptls_log._generation)) - ptls_log__recalc_conn(0, conn, get_sni, get_sni_arg); + ptls_log__recalc_conn(0, conn, getsni); return conn->state.active_conns; #else return 0; diff --git a/lib/picotls.c b/lib/picotls.c index 36899d40..813f7d13 100644 --- a/lib/picotls.c +++ b/lib/picotls.c @@ -6928,8 +6928,7 @@ void ptls_log__recalc_point(int caller_locked, struct st_ptls_log_point_t *point pthread_mutex_unlock(&logctx.mutex); } -void ptls_log__recalc_conn(int caller_locked, struct st_ptls_log_conn_state_t *conn, const char *(*get_sni)(void *), - void *get_sni_arg) +void ptls_log__recalc_conn(int caller_locked, struct st_ptls_log_conn_state_t *conn, ptls_log_getsni_t getsni) { if (!caller_locked) pthread_mutex_lock(&logctx.mutex); @@ -6937,7 +6936,7 @@ void ptls_log__recalc_conn(int caller_locked, struct st_ptls_log_conn_state_t *c if (conn->state.generation != ptls_log._generation) { /* update active bitmap */ uint32_t new_active = 0; - const char *sni = get_sni != NULL ? get_sni(get_sni_arg) : NULL; + const char *sni = getsni.cb != NULL ? getsni.cb(getsni.arg) : NULL; for (size_t slot = 0; slot < PTLS_ELEMENTSOF(logctx.conns); ++slot) { if (logctx.conns[slot].points != NULL && conn->random_ < logctx.conns[slot].sample_ratio && is_in_stringlist(logctx.conns[slot].snis, sni) && @@ -7090,8 +7089,8 @@ void ptls_log__do_write_start(struct st_ptls_log_point_t *point, int add_time) logbuf.buf.off = (size_t)written; } -int ptls_log__do_write_end(struct st_ptls_log_point_t *point, struct st_ptls_log_conn_state_t *conn, const char *(*get_sni)(void *), - void *get_sni_arg, int includes_appdata) +int ptls_log__do_write_end(struct st_ptls_log_point_t *point, struct st_ptls_log_conn_state_t *conn, ptls_log_getsni_t getsni, + int includes_appdata) { if (!expand_logbuf_or_invalidate("}\n", 2, 0)) return 0; @@ -7105,7 +7104,7 @@ int ptls_log__do_write_end(struct st_ptls_log_point_t *point, struct st_ptls_log ptls_log__recalc_point(1, point); uint32_t active = point->state.active_conns; if (conn != NULL && conn->state.generation != ptls_log._generation) { - ptls_log__recalc_conn(1, conn, get_sni, get_sni_arg); + ptls_log__recalc_conn(1, conn, getsni); active &= conn->state.active_conns; }