Skip to content

Commit 6fa227f

Browse files
committed
fix grpc: fix server outbound middleware run order
commit_hash:0b0bbaa4088d098e56acf04b1e9bd77b895e3fdf
1 parent f25ba49 commit 6fa227f

File tree

3 files changed

+33
-46
lines changed

3 files changed

+33
-46
lines changed

grpc/include/userver/ugrpc/server/impl/call_processor.hpp

Lines changed: 20 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class CallProcessor final {
8484
)
8585
: state_(std::move(params), CallTraits::kCallKind),
8686
responder_(state_, raw_responder),
87-
context_(utils::impl::InternalTag{}, state_),
87+
middleware_call_context_(utils::impl::InternalTag{}, state_),
8888
initial_request_(initial_request),
8989
service_(service),
9090
service_method_(service_method)
@@ -105,15 +105,15 @@ class CallProcessor final {
105105
// Don't keep the config snapshot for too long, especially for streaming RPCs.
106106
state_.config_snapshot.reset();
107107

108+
// Final response is the response sent to the client together with status in the final batch.
109+
std::optional<Response> final_response;
110+
108111
if (!Status().ok()) {
109-
RunOnCallFinish();
112+
RunOnCallFinish(final_response);
110113
impl::ReportFinish(responder_.FinishWithError(Status()), Status(), state_);
111114
return;
112115
}
113116

114-
// Final response is the response sent to the client together with status in the final batch.
115-
std::optional<Response> final_response{};
116-
117117
RunWithCatch([this, &final_response] {
118118
auto result = CallHandler();
119119
impl::UnpackResult(std::move(result), final_response, Status());
@@ -128,15 +128,12 @@ class CallProcessor final {
128128
}
129129

130130
if (!Status().ok()) {
131-
RunOnCallFinish();
131+
RunOnCallFinish(final_response);
132132
impl::ReportFinish(responder_.FinishWithError(Status()), Status(), state_);
133133
return;
134134
}
135135

136-
if (final_response) {
137-
RunPreSendMessage(*final_response);
138-
}
139-
RunOnCallFinish();
136+
RunOnCallFinish(final_response);
140137

141138
if (!Status().ok()) {
142139
impl::ReportFinish(responder_.FinishWithError(Status()), Status(), state_);
@@ -171,44 +168,38 @@ class CallProcessor final {
171168
void RunOnCallStart() {
172169
UASSERT(success_pre_hooks_count_ == 0);
173170
for (const auto& m : state_.middlewares) {
174-
RunWithCatch([this, &m] { m->OnCallStart(context_); });
171+
RunWithCatch([this, &m] { m->OnCallStart(middleware_call_context_); });
175172
if (!Status().ok()) {
176173
return;
177174
}
178175
// On fail, we must call OnRpcFinish only for middlewares for which OnRpcStart has been called successfully.
179176
// So, we watch to count of these middlewares.
180177
++success_pre_hooks_count_;
181178
if constexpr (std::is_base_of_v<google::protobuf::Message, InitialRequest>) {
182-
RunWithCatch([this, &m] { m->PostRecvMessage(context_, initial_request_); });
179+
RunWithCatch([this, &m] { m->PostRecvMessage(middleware_call_context_, initial_request_); });
183180
if (!Status().ok()) {
184181
return;
185182
}
186183
}
187184
}
188185
}
189186

190-
void RunOnCallFinish() {
187+
void RunOnCallFinish(std::optional<Response>& final_response) {
191188
const auto& mids = state_.middlewares;
192189
const auto rbegin = mids.rbegin() + (mids.size() - success_pre_hooks_count_);
193190
for (auto it = rbegin; it != mids.rend(); ++it) {
194191
const auto& middleware = *it;
195-
// We must call all OnRpcFinish despite the failures. So, don't check the status.
196-
RunWithCatch([this, &middleware] { middleware->OnCallFinish(context_, Status()); });
197-
}
198-
}
199192

200-
void RunPreSendMessage(Response& response) {
201-
if constexpr (std::is_base_of_v<google::protobuf::Message, Response>) {
202-
const auto& mids = state_.middlewares;
203-
// We don't want to include a heavy boost header for reverse view.
204-
// NOLINTNEXTLINE(modernize-loop-convert)
205-
for (auto it = mids.rbegin(); it != mids.rend(); ++it) {
206-
const auto& middleware = *it;
207-
RunWithCatch([this, &response, &middleware] { middleware->PreSendMessage(context_, response); });
208-
if (!Status().ok()) {
209-
return;
193+
if constexpr (std::is_base_of_v<google::protobuf::Message, Response>) {
194+
if (Status().ok() && final_response.has_value()) {
195+
RunWithCatch([this, &middleware, &final_response] {
196+
middleware->PreSendMessage(middleware_call_context_, *final_response);
197+
});
210198
}
211199
}
200+
201+
// We must call all OnRpcFinish despite the failures. So, don't check the status.
202+
RunWithCatch([this, &middleware] { middleware->OnCallFinish(middleware_call_context_, Status()); });
212203
}
213204
}
214205

@@ -230,11 +221,11 @@ class CallProcessor final {
230221
}
231222
}
232223

233-
grpc::Status& Status() { return context_.GetStatus(utils::impl::InternalTag{}); }
224+
grpc::Status& Status() { return middleware_call_context_.GetStatus(utils::impl::InternalTag{}); }
234225

235226
CallState state_;
236227
Responder responder_;
237-
MiddlewareCallContext context_;
228+
MiddlewareCallContext middleware_call_context_;
238229
// Initial request is the request which is sent to the service together with RPC initiation.
239230
// Unary-request RPCs have an initial request, client-streaming RPCs don't.
240231
InitialRequest& initial_request_;

grpc/src/ugrpc/server/middlewares/base.cpp

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,6 @@ USERVER_NAMESPACE_BEGIN
1111

1212
namespace ugrpc::server {
1313

14-
MiddlewareBase::MiddlewareBase() = default;
15-
16-
MiddlewareBase::~MiddlewareBase() = default;
17-
18-
void MiddlewareBase::OnCallStart(MiddlewareCallContext&) const {}
19-
20-
void MiddlewareBase::OnCallFinish(MiddlewareCallContext& context, const grpc::Status& status) const {
21-
if (!status.ok()) {
22-
return context.SetError(grpc::Status{status});
23-
}
24-
}
25-
26-
void MiddlewareBase::PostRecvMessage(MiddlewareCallContext&, google::protobuf::Message&) const {}
27-
28-
void MiddlewareBase::PreSendMessage(MiddlewareCallContext&, google::protobuf::Message&) const {}
29-
3014
MiddlewareCallContext::MiddlewareCallContext(utils::impl::InternalTag, impl::CallState& state)
3115
: CallContextBase(utils::impl::InternalTag{}, state)
3216
{}
@@ -56,6 +40,18 @@ ugrpc::impl::RpcStatisticsScope& MiddlewareCallContext::GetStatistics(utils::imp
5640
return GetCallState(utils::impl::InternalTag{}).statistics_scope;
5741
}
5842

43+
MiddlewareBase::MiddlewareBase() = default;
44+
45+
MiddlewareBase::~MiddlewareBase() = default;
46+
47+
void MiddlewareBase::OnCallStart(MiddlewareCallContext&) const {}
48+
49+
void MiddlewareBase::PostRecvMessage(MiddlewareCallContext&, google::protobuf::Message&) const {}
50+
51+
void MiddlewareBase::PreSendMessage(MiddlewareCallContext&, google::protobuf::Message&) const {}
52+
53+
void MiddlewareBase::OnCallFinish(MiddlewareCallContext&, const grpc::Status&) const {}
54+
5955
} // namespace ugrpc::server
6056

6157
USERVER_NAMESPACE_END

grpc/tests/server_middleware_hooks_unary_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ UTEST_P(ServerMiddlewareHooksUnaryTest, ApplyTheLastErrorStatus) {
229229

230230
EXPECT_CALL(Middleware(1), OnCallStart).Times(1);
231231
EXPECT_CALL(Middleware(1), PostRecvMessage).Times(1);
232-
EXPECT_CALL(Middleware(1), PreSendMessage).Times(1);
232+
EXPECT_CALL(Middleware(1), PreSendMessage).Times(0);
233233
// OnCallStart of M1 is successfully => OnCallFinish must be called.
234234
EXPECT_CALL(Middleware(1), OnCallFinish).Times(1);
235235

0 commit comments

Comments
 (0)