diff --git a/rmw_implementation/src/functions.cpp b/rmw_implementation/src/functions.cpp index 6201e037b..884954ae9 100644 --- a/rmw_implementation/src/functions.cpp +++ b/rmw_implementation/src/functions.cpp @@ -136,6 +136,29 @@ get_library() return g_rmw_lib; } +std::shared_ptr +get_wrap_library() +{ + static bool tried_load = false; + static std::shared_ptr wrap_lib = nullptr; + if (!wrap_lib && !tried_load) { + tried_load = true; + std::string wrapper_var; + try { + wrapper_var = rcpputils::get_env_var("RMW_IMPLEMENTATION_WRAPPER"); + } catch (const std::exception & e) { + RMW_SET_ERROR_MSG_WITH_FORMAT_STRING( + "failed to fetch RMW_IMPLEMENTATION_WRAPPER " + "from environment due to %s", e.what()); + return nullptr; + } + if (!wrapper_var.empty()) { + wrap_lib = attempt_to_load_one_rmw(wrapper_var); + } + } + return wrap_lib; +} + void * lookup_symbol(std::shared_ptr lib, const std::string & symbol_name) { @@ -175,6 +198,25 @@ get_symbol(const char * symbol_name) } } +void * get_wrap_symbol(const char * symbol_name) +{ + try { + auto lib = get_wrap_library(); + if (!lib) { + return nullptr; + } + if (!lib->has_symbol(symbol_name)) { + return nullptr; + } + return lib->get_symbol(symbol_name); + } catch (const std::exception & e) { + RMW_SET_ERROR_MSG_WITH_FORMAT_STRING( + "failed to get wrapper symbol '%s' due to %s", + symbol_name, e.what()); + return nullptr; + } +} + #ifdef __cplusplus extern "C" { @@ -202,6 +244,27 @@ extern "C" #define ARGS_6(t6, ...) t6 v6, EXPAND(ARGS_5(__VA_ARGS__)) #define ARGS_7(t7, ...) t7 v7, EXPAND(ARGS_6(__VA_ARGS__)) +// Macros for "wrapped function" args, allow us to prepend one extra argument more than _NR value +// "tX" params are "type" for the type declaration of the argument +// "vX" params are "variable" for the variable name declaration +#define WARGS_0(t0, tvoid) t0 v0 +#define WARGS_1(t1, t0) t1 v1, t0 v0 +#define WARGS_2(t2, ...) t2 v2, EXPAND(WARGS_1(__VA_ARGS__)) +#define WARGS_3(t3, ...) t3 v3, EXPAND(WARGS_2(__VA_ARGS__)) +#define WARGS_4(t4, ...) t4 v4, EXPAND(WARGS_3(__VA_ARGS__)) +#define WARGS_5(t5, ...) t5 v5, EXPAND(WARGS_4(__VA_ARGS__)) +#define WARGS_6(t6, ...) t6 v6, EXPAND(WARGS_5(__VA_ARGS__)) +#define WARGS_7(t7, ...) t7 v7, EXPAND(WARGS_6(__VA_ARGS__)) + +#define WARG_VALUES_0(impl, ...) impl +#define WARG_VALUES_1(impl, ...) impl, EXPAND(ARG_VALUES_1(__VA_ARGS__)) +#define WARG_VALUES_2(impl, ...) impl, EXPAND(ARG_VALUES_2(__VA_ARGS__)) +#define WARG_VALUES_3(impl, ...) impl, EXPAND(ARG_VALUES_3(__VA_ARGS__)) +#define WARG_VALUES_4(impl, ...) impl, EXPAND(ARG_VALUES_4(__VA_ARGS__)) +#define WARG_VALUES_5(impl, ...) impl, EXPAND(ARG_VALUES_5(__VA_ARGS__)) +#define WARG_VALUES_6(impl, ...) impl, EXPAND(ARG_VALUES_6(__VA_ARGS__)) +#define WARG_VALUES_7(impl, ...) impl, EXPAND(ARG_VALUES_7(__VA_ARGS__)) + #define CALL_SYMBOL(symbol_name, ReturnType, error_value, ArgTypes, arg_values) \ if (!symbol_ ## symbol_name) { \ /* only necessary for functions called before rmw_init */ \ @@ -218,8 +281,14 @@ extern "C" // cppcheck-suppress preprocessorErrorDirective #define RMW_INTERFACE_FN(name, ReturnType, error_value, _NR, ...) \ void * symbol_ ## name = nullptr; \ + void * symbol_wrap_ ## name = nullptr; \ ReturnType name(EXPAND(ARGS_ ## _NR(__VA_ARGS__))) \ { \ + if (symbol_wrap_ ## name) { \ + typedef ReturnType (* WrapFunctionSignature)(EXPAND (WARGS_ ## _NR(void *, __VA_ARGS__))); \ + auto wrap_func = reinterpret_cast(symbol_wrap_ ## name); \ + return wrap_func(EXPAND(WARG_VALUES_ ## _NR(symbol_ ## name, __VA_ARGS__))); \ + } \ CALL_SYMBOL( \ name, ReturnType, error_value, ARG_TYPES(__VA_ARGS__), \ EXPAND(ARG_VALUES_ ## _NR(__VA_ARGS__))); \ @@ -789,7 +858,9 @@ RMW_INTERFACE_FN( const char *, rcutils_allocator_t *, rosidl_dynamic_typesupport_serialization_support_t *)) -#define GET_SYMBOL(x) symbol_ ## x = get_symbol(#x); +#define GET_SYMBOL(x) \ + symbol_ ## x = get_symbol(#x); \ + symbol_wrap_ ## x = get_wrap_symbol(STRINGIFY(wrap_ ## x)); void prefetch_symbols(void) { @@ -891,6 +962,7 @@ void prefetch_symbols(void) } void * symbol_rmw_init = nullptr; +void * symbol_wrap_rmw_init = nullptr; rmw_ret_t rmw_init(const rmw_init_options_t * options, rmw_context_t * context) @@ -902,10 +974,19 @@ rmw_init(const rmw_init_options_t * options, rmw_context_t * context) if (!symbol_rmw_init) { return RMW_RET_ERROR; } + if (!symbol_wrap_rmw_init) { + symbol_wrap_rmw_init = get_wrap_symbol("wrap_rmw_init"); + } + rmw_ret_t ret = RMW_RET_OK; + if (symbol_wrap_rmw_init) { + typedef rmw_ret_t (* WrapperInitSignature)(void *, const rmw_init_options_t *, rmw_context_t *); + auto wrap_init_func = reinterpret_cast(symbol_wrap_rmw_init); + ret = wrap_init_func(get_library().get(), options, context); + } typedef rmw_ret_t (* FunctionSignature)(const rmw_init_options_t *, rmw_context_t *); FunctionSignature func = reinterpret_cast(symbol_rmw_init); - return func(options, context); + return ret != RMW_RET_OK ? ret : func(options, context); } #ifdef __cplusplus