Skip to content

Commit ea4976f

Browse files
committed
test: Add more comprehensive DisconnectTest
Add more comprehensive disconnect test to test disconnect handling at more points during server IPC execution to catch more bugs and catch them more reliably in a way that does not rely on thread execution order.
1 parent fe1cd8c commit ea4976f

File tree

7 files changed

+225
-95
lines changed

7 files changed

+225
-95
lines changed

include/mp/proxy-io.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include <capnp/rpc-twoparty.h>
1414

15+
#include <any>
1516
#include <assert.h>
1617
#include <condition_variable>
1718
#include <functional>
@@ -270,6 +271,9 @@ class EventLoop
270271

271272
//! External context pointer.
272273
void* m_context;
274+
275+
//! External callback for hooks / tests.
276+
std::function<void(std::any)> m_signal;
273277
};
274278

275279
//! Single element task queue used to handle recursive capnp calls. (If server

include/mp/type-context.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99
#include <mp/util.h>
1010

1111
namespace mp {
12+
struct SignalCallStart{};
13+
struct SignalCallSetup{};
14+
struct SignalCallTeardown{};
15+
struct SignalCallEnd{};
16+
1217
template <typename Output>
1318
void CustomBuildField(TypeList<>,
1419
Priority<1>,
@@ -70,6 +75,7 @@ auto PassField(Priority<1>, TypeList<>, ServerContext& server_context, const Fn&
7075
Context::Reader context_arg = Accessor::get(params);
7176
ServerContext server_context{server, call_context, req};
7277
{
78+
if (server.m_context.loop->m_signal) server.m_context.loop->m_signal(SignalCallStart{});
7379
// Before invoking the function, store a reference to the
7480
// callbackThread provided by the client in the
7581
// thread_local.request_threads map. This way, if this
@@ -102,6 +108,7 @@ auto PassField(Priority<1>, TypeList<>, ServerContext& server_context, const Fn&
102108
const bool erase_thread{inserted};
103109
KJ_DEFER(if (erase_thread) {
104110
std::unique_lock<std::mutex> lock(thread_context.waiter->m_mutex);
111+
if (server.m_context.loop->m_signal) server.m_context.loop->m_signal(SignalCallTeardown{});
105112
// Call erase here with a Connection* argument instead
106113
// of an iterator argument, because the `request_thread`
107114
// iterator may be invalid if the connection is closed
@@ -111,8 +118,11 @@ auto PassField(Priority<1>, TypeList<>, ServerContext& server_context, const Fn&
111118
// erases the thread from the map, and also because the
112119
// ProxyServer<Thread> destructor calls
113120
// request_threads.clear().
121+
if (server.m_context.loop->m_signal) server.m_context.loop->m_signal(SignalCallTeardown{});
114122
request_threads.erase(server.m_context.connection);
115-
});
123+
}
124+
if (server.m_context.loop->m_signal) server.m_context.loop->m_signal(SignalCallEnd{});
125+
);
116126
fn.invoke(server_context, args...);
117127
}
118128
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {

src/mp/proxy.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <mp/proxy-io.h>
88
#include <mp/proxy-types.h>
99
#include <mp/proxy.capnp.h>
10+
#include <mp/type-context.h>
1011
#include <mp/type-threadmap.h>
1112
#include <mp/util.h>
1213

@@ -308,6 +309,7 @@ bool EventLoop::done() const
308309
std::tuple<ConnThread, bool> SetThread(ConnThreads& threads, std::mutex& mutex, Connection* connection, const std::function<Thread::Client()>& make_thread)
309310
{
310311
const std::unique_lock<std::mutex> lock(mutex);
312+
if (connection->m_loop->m_signal) connection->m_loop->m_signal(SignalCallSetup{});
311313
auto thread = threads.find(connection);
312314
if (thread != threads.end()) return {thread, false};
313315
thread = threads.emplace(

test/mp/test/foo-types.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,41 @@ inline void CustomPassMessage(InvokeContext& invoke_context,
9494
fn(mut);
9595
builder.setMessage(mut.message + " return");
9696
}
97+
98+
//! CustomBuildField for TestArg parameter. TestArg doesn't do anything special
99+
//! on the client side, so this does nothing. TestArg server-side behavior is
100+
//! implemented in CustomPassField below.
101+
template <typename Output>
102+
requires (std::is_same_v<decltype(std::declval<Output>().get()), test::messages::TestArg::Builder>)
103+
void CustomBuildField(TypeList<>,
104+
Priority<1>,
105+
ClientInvokeContext& invoke_context,
106+
Output&& output)
107+
{
108+
}
109+
110+
//! CustomPassField processing TestArg parameter by calling a start_hook()
111+
//! function which returns a bool promise, and continuing to execute the IPC if
112+
//! the promise value is true, aborting if it is false. It also calls an
113+
//! end_hook() function after the IPC call finishes, if it wasn't aborted.
114+
template <typename Accessor, typename ServerContext, typename Fn, typename... Args>
115+
requires (std::is_same_v<decltype(Accessor::get(std::declval<ServerContext>().call_context.getParams())), test::messages::TestArg::Reader>)
116+
auto CustomPassField(TypeList<>, ServerContext& server_context, const Fn& fn, Args&&... args)
117+
{
118+
const auto& start_hook = server_context.proxy_server.m_impl->m_start_hook;
119+
return start_hook().then([old=server_context, call_context=kj::mv(server_context.call_context), fn, args...](bool invoke_fn) mutable {
120+
// If start hook returns false, skip calling IPC function.
121+
if (!invoke_fn) return kj::Promise<typename ServerContext::CallContext>(kj::mv(call_context));
122+
// If start hook returns true, continue calling IPC function and
123+
// processing parameters/return values and calling end hook.
124+
ServerContext server_context{old.proxy_server, call_context, old.req};
125+
return fn.invoke(server_context, args...).then([old=server_context](ServerContext::CallContext call_context) {
126+
ServerContext server_context{old.proxy_server, call_context, old.req};
127+
const auto& end_hook = server_context.proxy_server.m_impl->m_end_hook;
128+
return end_hook().then([call_context = kj::mv(call_context)] { return call_context; });
129+
});
130+
});
131+
}
97132
} // namespace mp
98133

99134
#endif // MP_TEST_FOO_TYPES_H

test/mp/test/foo.capnp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ interface FooInterface $Proxy.wrap("mp::test::FooImplementation") {
3030
passEnum @15 (arg :Int32) -> (result :Int32);
3131
passFn @16 (context :Proxy.Context, fn :FooFn) -> (result :Int32);
3232
callFn @17 () -> ();
33-
callFnAsync @18 (context :Proxy.Context) -> ();
33+
callFnAsync @18 (testArg :TestArg, context :Proxy.Context) -> ();
3434
}
3535

3636
interface FooCallback $Proxy.wrap("mp::test::FooCallback") {
@@ -73,3 +73,7 @@ struct Pair(T1, T2) {
7373
first @0 :T1;
7474
second @1 :T2;
7575
}
76+
77+
# Special argument used in disconnect test to implement ASYNC_START / ASYNC_END hooks.
78+
struct TestArg $Proxy.count(0) {
79+
}

test/mp/test/foo.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ class FooImplementation
8282
void callFn() { assert(m_fn); m_fn(); }
8383
void callFnAsync() { assert(m_fn); m_fn(); }
8484
std::function<void()> m_fn;
85+
//! Hooks used by disconnect test ASYNC_START and ASYNC_END disconnects.
86+
//! Former returns false to able IPC call without executing it.
87+
std::function<kj::Promise<bool>()> m_start_hook;
88+
std::function<kj::Promise<void>()> m_end_hook;
8589
};
8690

8791
} // namespace test

0 commit comments

Comments
 (0)