Skip to content

Commit b7f382f

Browse files
committed
fix grpc: fix client outbound middleware run order
commit_hash:c24fdd3c5dab22e80aaaabc07f7fc023615c6c71
1 parent 89295ee commit b7f382f

File tree

7 files changed

+93
-75
lines changed

7 files changed

+93
-75
lines changed

grpc/include/userver/ugrpc/client/impl/middleware_hooks.hpp

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#pragma once
22

3+
#include <variant>
4+
35
#include <google/protobuf/message.h>
46
#include <grpcpp/support/status.h>
57

@@ -14,24 +16,36 @@ namespace ugrpc::client::impl {
1416

1517
class MiddlewareHooks {
1618
public:
17-
void SetStartCall() noexcept;
18-
void SetSendMessage(const google::protobuf::Message& send_message) noexcept;
19-
void SetRecvMessage(const google::protobuf::Message& recv_message) noexcept;
20-
void SetStatus(const grpc::Status& status) noexcept;
19+
static MiddlewareHooks StartCallHooks(const google::protobuf::Message* request = nullptr) noexcept;
20+
21+
static MiddlewareHooks SendMessageHooks(const google::protobuf::Message& send_message) noexcept;
22+
23+
static MiddlewareHooks RecvMessageHooks(const google::protobuf::Message& recv_message) noexcept;
24+
25+
static MiddlewareHooks FinishHooks(const grpc::Status& status, const google::protobuf::Message* response = nullptr)
26+
noexcept;
2127

2228
void Run(const MiddlewareBase& middleware, MiddlewareCallContext& context) const;
2329

30+
bool Reverse() const noexcept;
31+
2432
private:
25-
bool start_call_{false};
26-
const google::protobuf::Message* send_message_{};
27-
const google::protobuf::Message* recv_message_{};
28-
const grpc::Status* status_{};
29-
};
33+
struct Inbound {
34+
bool start_call{false};
35+
const google::protobuf::Message* send_message{};
36+
};
37+
38+
struct Outbound {
39+
const grpc::Status* status{};
40+
const google::protobuf::Message* recv_message{};
41+
};
42+
43+
using Params = std::variant<Inbound, Outbound>;
3044

31-
MiddlewareHooks StartCallHooks(const google::protobuf::Message* request = nullptr) noexcept;
32-
MiddlewareHooks SendMessageHooks(const google::protobuf::Message& send_message) noexcept;
33-
MiddlewareHooks RecvMessageHooks(const google::protobuf::Message& recv_message) noexcept;
34-
MiddlewareHooks FinishHooks(const grpc::Status& status, const google::protobuf::Message* response = nullptr) noexcept;
45+
explicit MiddlewareHooks(Params&& params) noexcept;
46+
47+
Params params_;
48+
};
3549

3650
} // namespace ugrpc::client::impl
3751

grpc/include/userver/ugrpc/client/impl/rpc.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ InputStream<Response>::InputStream(
257257
: state_{std::move(params), CallKind::kInputStream},
258258
context_{utils::impl::InternalTag{}, state_}
259259
{
260-
RunMiddlewarePipeline(state_, StartCallHooks(ToBaseMessage(&request)));
260+
RunMiddlewarePipeline(state_, MiddlewareHooks::StartCallHooks(ToBaseMessage(&request)));
261261

262262
// NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
263263
stream_ = impl::PrepareCall(
@@ -286,7 +286,7 @@ bool InputStream<Response>::Read(Response& response) {
286286
}
287287

288288
if (impl::Read(*stream_, response, state_)) {
289-
RunMiddlewarePipeline(state_, RecvMessageHooks(response));
289+
RunMiddlewarePipeline(state_, MiddlewareHooks::RecvMessageHooks(response));
290290
return true;
291291
} else {
292292
// Finish can only be called once all the data is read, otherwise the
@@ -305,7 +305,7 @@ OutputStream<Request, Response>::OutputStream(
305305
: state_{std::move(params), CallKind::kOutputStream},
306306
context_{utils::impl::InternalTag{}, state_}
307307
{
308-
RunMiddlewarePipeline(state_, StartCallHooks());
308+
RunMiddlewarePipeline(state_, MiddlewareHooks::StartCallHooks());
309309

310310
// 'response_' will be filled upon successful 'Finish' async call
311311
// NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
@@ -332,7 +332,7 @@ bool OutputStream<Request, Response>::Write(const Request& request) {
332332
return false;
333333
}
334334

335-
RunMiddlewarePipeline(state_, SendMessageHooks(request));
335+
RunMiddlewarePipeline(state_, MiddlewareHooks::SendMessageHooks(request));
336336

337337
// Don't buffer writes, otherwise in an event subscription scenario, events
338338
// may never actually be delivered
@@ -348,7 +348,7 @@ void OutputStream<Request, Response>::WriteAndCheck(const Request& request) {
348348
throw RpcError(state_.GetCallName(), "'WriteAndCheck' called on a finished or closed stream");
349349
}
350350

351-
RunMiddlewarePipeline(state_, SendMessageHooks(request));
351+
RunMiddlewarePipeline(state_, MiddlewareHooks::SendMessageHooks(request));
352352

353353
// Don't buffer writes, otherwise in an event subscription scenario, events
354354
// may never actually be delivered
@@ -381,7 +381,7 @@ BidirectionalStream<Request, Response>::BidirectionalStream(
381381
: state_{std::move(params), CallKind::kBidirectionalStream},
382382
context_{utils::impl::InternalTag{}, state_}
383383
{
384-
RunMiddlewarePipeline(state_, StartCallHooks());
384+
RunMiddlewarePipeline(state_, MiddlewareHooks::StartCallHooks());
385385

386386
// NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
387387
stream_ = impl::PrepareCall(prepare_async_method, state_.GetStub(), &state_.GetClientContext(), &state_.GetQueue());
@@ -427,7 +427,7 @@ bool BidirectionalStream<Request, Response>::Write(const Request& request) {
427427

428428
{
429429
const auto lock = state_.TakeMutexIfBidirectional();
430-
RunMiddlewarePipeline(state_, SendMessageHooks(request));
430+
RunMiddlewarePipeline(state_, MiddlewareHooks::SendMessageHooks(request));
431431
}
432432

433433
// Don't buffer writes, optimize for ping-pong-style interaction
@@ -445,7 +445,7 @@ void BidirectionalStream<Request, Response>::WriteAndCheck(const Request& reques
445445

446446
{
447447
const auto lock = state_.TakeMutexIfBidirectional();
448-
RunMiddlewarePipeline(state_, SendMessageHooks(request));
448+
RunMiddlewarePipeline(state_, MiddlewareHooks::SendMessageHooks(request));
449449
}
450450

451451
// Don't buffer writes, optimize for ping-pong-style interaction

grpc/include/userver/ugrpc/client/impl/unary_call.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,15 @@ class UnaryCall final {
163163
return AttemptCompletionStatus::kOk;
164164
}
165165

166-
void RunStartCallHooks() { impl::RunMiddlewarePipeline(state_, StartCallHooks(ToBaseMessage(&request_))); }
166+
void RunStartCallHooks() {
167+
impl::RunMiddlewarePipeline(state_, MiddlewareHooks::StartCallHooks(ToBaseMessage(&request_)));
168+
}
167169

168170
void RunFinishHooks(const grpc::Status& status) {
169-
impl::RunMiddlewarePipeline(state_, FinishHooks(status, ToBaseMessage(&response_)));
171+
impl::RunMiddlewarePipeline(
172+
state_,
173+
MiddlewareHooks::FinishHooks(status, status.ok() ? ToBaseMessage(&response_) : nullptr)
174+
);
170175
}
171176

172177
void OnDone(const grpc::Status& status) {

grpc/include/userver/ugrpc/client/stream_read_future.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ bool StreamReadFuture<RawStream>::Get() {
107107
impl::Finish(*stream_, *state, /*final_response=*/nullptr, /*throw_on_error=*/true);
108108
} else {
109109
if (recv_message_) {
110-
RunMiddlewarePipeline(*state, impl::RecvMessageHooks(*recv_message_));
110+
RunMiddlewarePipeline(*state, impl::MiddlewareHooks::RecvMessageHooks(*recv_message_));
111111
}
112112
}
113113
return result == ugrpc::impl::AsyncMethodInvocation::WaitStatus::kOk;

grpc/src/ugrpc/client/impl/async_stream_methods.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ void ProcessFinish(CallState& state, const google::protobuf::Message* final_resp
7171

7272
HandleCallStatistics(state, status);
7373

74-
RunMiddlewarePipeline(state, FinishHooks(status, final_response));
74+
RunMiddlewarePipeline(state, MiddlewareHooks::FinishHooks(status, status.ok() ? final_response : nullptr));
7575

7676
SetStatusAndResetSpan(state, status);
7777
}

grpc/src/ugrpc/client/impl/middleware_hooks.cpp

Lines changed: 39 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,61 @@
11
#include <userver/ugrpc/client/impl/middleware_hooks.hpp>
22

3+
#include <userver/utils/assert.hpp>
4+
#include <userver/utils/overloaded.hpp>
5+
36
#include <userver/ugrpc/client/middlewares/base.hpp>
47

58
USERVER_NAMESPACE_BEGIN
69

710
namespace ugrpc::client::impl {
811

9-
void MiddlewareHooks::SetStartCall() noexcept { start_call_ = true; }
10-
11-
void MiddlewareHooks::SetSendMessage(const google::protobuf::Message& send_message) noexcept {
12-
send_message_ = &send_message;
12+
MiddlewareHooks MiddlewareHooks::StartCallHooks(const google::protobuf::Message* request) noexcept {
13+
return MiddlewareHooks{Inbound{true, request}};
1314
}
1415

15-
void MiddlewareHooks::SetRecvMessage(const google::protobuf::Message& recv_message) noexcept {
16-
recv_message_ = &recv_message;
16+
MiddlewareHooks MiddlewareHooks::SendMessageHooks(const google::protobuf::Message& send_message) noexcept {
17+
return MiddlewareHooks{Inbound{false, &send_message}};
1718
}
1819

19-
void MiddlewareHooks::SetStatus(const grpc::Status& status) noexcept { status_ = &status; }
20-
21-
void MiddlewareHooks::Run(const MiddlewareBase& middleware, MiddlewareCallContext& context) const {
22-
if (start_call_) {
23-
middleware.PreStartCall(context);
24-
}
25-
26-
if (send_message_) {
27-
middleware.PreSendMessage(context, *send_message_);
28-
}
29-
30-
if (recv_message_) {
31-
middleware.PostRecvMessage(context, *recv_message_);
32-
}
33-
34-
if (status_) {
35-
middleware.PostFinish(context, *status_);
36-
}
20+
MiddlewareHooks MiddlewareHooks::FinishHooks(const grpc::Status& status, const google::protobuf::Message* response)
21+
noexcept {
22+
UASSERT(status.ok() || nullptr == response);
23+
return MiddlewareHooks{Outbound{&status, response}};
3724
}
3825

39-
MiddlewareHooks StartCallHooks(const google::protobuf::Message* request) noexcept {
40-
MiddlewareHooks hooks;
41-
hooks.SetStartCall();
42-
if (request) {
43-
hooks.SetSendMessage(*request);
44-
}
45-
return hooks;
26+
MiddlewareHooks MiddlewareHooks::RecvMessageHooks(const google::protobuf::Message& recv_message) noexcept {
27+
return MiddlewareHooks{Outbound{nullptr, &recv_message}};
4628
}
4729

48-
MiddlewareHooks SendMessageHooks(const google::protobuf::Message& send_message) noexcept {
49-
MiddlewareHooks hooks;
50-
hooks.SetSendMessage(send_message);
51-
return hooks;
30+
void MiddlewareHooks::Run(const MiddlewareBase& middleware, MiddlewareCallContext& context) const {
31+
std::visit(
32+
utils::Overloaded{
33+
[&middleware, &context](Inbound params) {
34+
if (params.start_call) {
35+
middleware.PreStartCall(context);
36+
}
37+
if (params.send_message) {
38+
middleware.PreSendMessage(context, *params.send_message);
39+
}
40+
},
41+
[&middleware, &context](Outbound params) {
42+
if (params.recv_message) {
43+
middleware.PostRecvMessage(context, *params.recv_message);
44+
}
45+
if (params.status) {
46+
middleware.PostFinish(context, *params.status);
47+
}
48+
},
49+
},
50+
params_
51+
);
5252
}
5353

54-
MiddlewareHooks RecvMessageHooks(const google::protobuf::Message& recv_message) noexcept {
55-
MiddlewareHooks hooks;
56-
hooks.SetRecvMessage(recv_message);
57-
return hooks;
58-
}
54+
bool MiddlewareHooks::Reverse() const noexcept { return std::holds_alternative<Outbound>(params_); }
5955

60-
MiddlewareHooks FinishHooks(const grpc::Status& status, const google::protobuf::Message* response) noexcept {
61-
MiddlewareHooks hooks;
62-
if (status.ok() && response) {
63-
hooks.SetRecvMessage(*response);
64-
}
65-
hooks.SetStatus(status);
66-
return hooks;
67-
}
56+
MiddlewareHooks::MiddlewareHooks(Params&& params) noexcept
57+
: params_{std::move(params)}
58+
{}
6859

6960
} // namespace ugrpc::client::impl
7061

grpc/src/ugrpc/client/impl/middleware_pipeline.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include <userver/ugrpc/client/impl/middleware_pipeline.hpp>
22

3+
#include <boost/range/adaptor/reversed.hpp>
4+
35
#include <userver/logging/log.hpp>
46

57
USERVER_NAMESPACE_BEGIN
@@ -8,8 +10,14 @@ namespace ugrpc::client::impl {
810

911
void MiddlewarePipeline::Run(const MiddlewareHooks& hooks, MiddlewareCallContext& context) const {
1012
try {
11-
for (const auto& m : middlewares_) {
12-
hooks.Run(*m, context);
13+
if (!hooks.Reverse()) {
14+
for (const auto& m : middlewares_) {
15+
hooks.Run(*m, context);
16+
}
17+
} else {
18+
for (const auto& m : boost::adaptors::reverse(middlewares_)) {
19+
hooks.Run(*m, context);
20+
}
1321
}
1422
} catch (const std::exception& ex) {
1523
LOG_WARNING() << "Run middlewares failed: " << ex;

0 commit comments

Comments
 (0)