Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 33 additions & 14 deletions include/picotls.h
Original file line number Diff line number Diff line change
Expand Up @@ -1405,21 +1405,39 @@ uint64_t ptls_decode_quicint(const uint8_t **src, const uint8_t *end);
ptls_decode_assert_block_close((src), end); \
} while (0)

typedef struct st_ptls_log_getsni_t {
const char *(*cb)(void *arg);
void *arg;
} ptls_log_getsni_t;

/**
* Creates a lazy callback object for obtaining SNI. The object is used to delay materialization of SNI to only when it is needed.
*/
#define PTLS_LOG_DEFINE_GETSNI(suffix, type, body) \
static inline const char *ptls_log_getsni_cb_##suffix(void *_arg) \
{ \
type arg = (type)_arg; \
body \
} \
static inline ptls_log_getsni_t ptls_log_getsni_##suffix(type arg) \
{ \
return (ptls_log_getsni_t){ptls_log_getsni_cb_##suffix, arg}; \
}

#if PTLS_HAVE_LOG
#define PTLS_LOG__DO_LOG(module, name, conn_state, get_sni, get_sni_arg, add_time, block) \
#define PTLS_LOG__DO_LOG(module, name, conn_state, get_sni, add_time, block) \
do { \
int ptlslog_include_appdata = 0; \
do { \
ptls_log__do_write_start(&logpoint, (add_time)); \
do { \
block \
} while (0); \
ptlslog_include_appdata = \
ptls_log__do_write_end(&logpoint, (conn_state), (get_sni), (get_sni_arg), ptlslog_include_appdata); \
ptlslog_include_appdata = ptls_log__do_write_end(&logpoint, (conn_state), (get_sni), ptlslog_include_appdata); \
} while (PTLS_UNLIKELY(ptlslog_include_appdata)); \
} while (0)
#else
#define PTLS_LOG__DO_LOG(module, name, conn_state, get_sni, get_sni_arg, add_time, block) /* don't generate code */
#define PTLS_LOG__DO_LOG(module, name, conn_state, get_sni, add_time, block) /* don't generate code */
#endif

#define PTLS_LOG_DEFINE_POINT(_module, _name, _var) \
Expand All @@ -1430,7 +1448,7 @@ uint64_t ptls_decode_quicint(const uint8_t **src, const uint8_t *end);
PTLS_LOG_DEFINE_POINT(module, name, logpoint); \
if (PTLS_LIKELY(ptls_log_point_maybe_active(&logpoint) == 0)) \
break; \
PTLS_LOG__DO_LOG(module, name, NULL, NULL, NULL, 1, {block}); \
PTLS_LOG__DO_LOG(module, name, NULL, (ptls_log_getsni_t){NULL}, 1, {block}); \
} while (0)

#define PTLS_LOG_CONN(name, tls, block) \
Expand All @@ -1441,10 +1459,10 @@ uint64_t ptls_decode_quicint(const uint8_t **src, const uint8_t *end);
break; \
ptls_t *_tls = (tls); \
ptls_log_conn_state_t *conn_state = ptls_get_log_state(_tls); \
active &= ptls_log_conn_maybe_active(conn_state, (const char *(*)(void *))ptls_get_server_name, _tls); \
active &= ptls_log_conn_maybe_active(conn_state, ptls_log_getsni_ptls(_tls)); \
if (PTLS_LIKELY(active == 0)) \
break; \
PTLS_LOG__DO_LOG(picotls, name, conn_state, (const char *(*)(void *))ptls_get_server_name, _tls, 1, { \
PTLS_LOG__DO_LOG(picotls, name, conn_state, ptls_log_getsni_ptls(_tls), 1, { \
PTLS_LOG_ELEMENT_PTR(tls, _tls); \
do { \
block \
Expand Down Expand Up @@ -1571,7 +1589,7 @@ static uint32_t ptls_log_point_maybe_active(struct st_ptls_log_point_t *point);
/**
* returns a bitmap indicating the loggers active for given connection
*/
static uint32_t ptls_log_conn_maybe_active(ptls_log_conn_state_t *conn, const char *(*get_sni)(void *), void *get_sni_arg);
static uint32_t ptls_log_conn_maybe_active(ptls_log_conn_state_t *conn, ptls_log_getsni_t getsni);

/**
* Returns the number of log events that were unable to be emitted.
Expand All @@ -1588,8 +1606,7 @@ size_t ptls_log_num_lost(void);
int ptls_log_add_fd(int fd, float sample_ratio, const char *points, const char *snis, const char *addresses, int appdata);

void ptls_log__recalc_point(int caller_locked, struct st_ptls_log_point_t *point);
void ptls_log__recalc_conn(int caller_locked, struct st_ptls_log_conn_state_t *conn, const char *(*get_sni)(void *),
void *get_sni_arg);
void ptls_log__recalc_conn(int caller_locked, struct st_ptls_log_conn_state_t *conn, ptls_log_getsni_t getsni);
void ptls_log__do_push_element_safestr(const char *prefix, size_t prefix_len, const char *s, size_t l);
void ptls_log__do_push_element_unsafestr(const char *prefix, size_t prefix_len, const char *s, size_t l);
void ptls_log__do_push_element_hexdump(const char *prefix, size_t prefix_len, const void *s, size_t l);
Expand All @@ -1603,8 +1620,8 @@ void ptls_log__do_push_appdata_element_unsafestr(int includes_appdata, const cha
void ptls_log__do_push_appdata_element_hexdump(int includes_appdata, const char *prefix, size_t prefix_len, const void *s,
size_t l);
void ptls_log__do_write_start(struct st_ptls_log_point_t *point, int add_time);
int ptls_log__do_write_end(struct st_ptls_log_point_t *point, struct st_ptls_log_conn_state_t *conn, const char *(*get_sni)(void *),
void *get_sni_arg, int includes_appdata);
int ptls_log__do_write_end(struct st_ptls_log_point_t *point, struct st_ptls_log_conn_state_t *conn, ptls_log_getsni_t getsni,
int includes_appdata);

/**
* create a client object to handle new TLS connection
Expand Down Expand Up @@ -1963,6 +1980,8 @@ extern ptls_get_time_t ptls_get_time;
*/
static void ptls_hash_clone_memcpy(void *dst, const void *src, size_t size);

PTLS_LOG_DEFINE_GETSNI(ptls, ptls_t *, { return ptls_get_server_name(arg); })

/* inline functions */

inline uint32_t ptls_log_point_maybe_active(struct st_ptls_log_point_t *point)
Expand All @@ -1981,11 +2000,11 @@ inline void ptls_log_recalc_conn_state(ptls_log_conn_state_t *state)
state->state.generation = 0;
}

inline uint32_t ptls_log_conn_maybe_active(ptls_log_conn_state_t *conn, const char *(*get_sni)(void *), void *get_sni_arg)
inline uint32_t ptls_log_conn_maybe_active(ptls_log_conn_state_t *conn, ptls_log_getsni_t getsni)
{
#if PTLS_HAVE_LOG
if (PTLS_UNLIKELY(conn->state.generation != ptls_log._generation))
ptls_log__recalc_conn(0, conn, get_sni, get_sni_arg);
ptls_log__recalc_conn(0, conn, getsni);
return conn->state.active_conns;
#else
return 0;
Expand Down
11 changes: 5 additions & 6 deletions lib/picotls.c
Original file line number Diff line number Diff line change
Expand Up @@ -6928,16 +6928,15 @@ void ptls_log__recalc_point(int caller_locked, struct st_ptls_log_point_t *point
pthread_mutex_unlock(&logctx.mutex);
}

void ptls_log__recalc_conn(int caller_locked, struct st_ptls_log_conn_state_t *conn, const char *(*get_sni)(void *),
void *get_sni_arg)
void ptls_log__recalc_conn(int caller_locked, struct st_ptls_log_conn_state_t *conn, ptls_log_getsni_t getsni)
{
if (!caller_locked)
pthread_mutex_lock(&logctx.mutex);

if (conn->state.generation != ptls_log._generation) {
/* update active bitmap */
uint32_t new_active = 0;
const char *sni = get_sni != NULL ? get_sni(get_sni_arg) : NULL;
const char *sni = getsni.cb != NULL ? getsni.cb(getsni.arg) : NULL;
for (size_t slot = 0; slot < PTLS_ELEMENTSOF(logctx.conns); ++slot) {
if (logctx.conns[slot].points != NULL && conn->random_ < logctx.conns[slot].sample_ratio &&
is_in_stringlist(logctx.conns[slot].snis, sni) &&
Expand Down Expand Up @@ -7090,8 +7089,8 @@ void ptls_log__do_write_start(struct st_ptls_log_point_t *point, int add_time)
logbuf.buf.off = (size_t)written;
}

int ptls_log__do_write_end(struct st_ptls_log_point_t *point, struct st_ptls_log_conn_state_t *conn, const char *(*get_sni)(void *),
void *get_sni_arg, int includes_appdata)
int ptls_log__do_write_end(struct st_ptls_log_point_t *point, struct st_ptls_log_conn_state_t *conn, ptls_log_getsni_t getsni,
int includes_appdata)
{
if (!expand_logbuf_or_invalidate("}\n", 2, 0))
return 0;
Expand All @@ -7105,7 +7104,7 @@ int ptls_log__do_write_end(struct st_ptls_log_point_t *point, struct st_ptls_log
ptls_log__recalc_point(1, point);
uint32_t active = point->state.active_conns;
if (conn != NULL && conn->state.generation != ptls_log._generation) {
ptls_log__recalc_conn(1, conn, get_sni, get_sni_arg);
ptls_log__recalc_conn(1, conn, getsni);
active &= conn->state.active_conns;
}

Expand Down
Loading