Skip to content

Commit 977cd9a

Browse files
committed
refactor grpc: refactor ugrpc::impl::AsyncMethodInvocation
commit_hash:3c4d7444deb7d7fc547142eedf37a42f0c0719f4
1 parent 6d78710 commit 977cd9a

File tree

10 files changed

+76
-86
lines changed

10 files changed

+76
-86
lines changed

grpc/include/userver/ugrpc/client/impl/async_stream_methods.hpp

Lines changed: 40 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,14 @@ namespace ugrpc::client::impl {
1818

1919
ugrpc::impl::AsyncMethodInvocation::WaitStatus WaitAndTryCancelIfNeeded(
2020
ugrpc::impl::AsyncMethodInvocation& invocation,
21-
grpc::ClientContext& context
21+
grpc::ClientContext& client_context
2222
) noexcept;
2323

24-
void CheckOk(StreamingCallState& state, ugrpc::impl::AsyncMethodInvocation::WaitStatus status, std::string_view stage);
24+
void CheckOk(
25+
StreamingCallState& state,
26+
ugrpc::impl::AsyncMethodInvocation::WaitStatus wait_status,
27+
std::string_view stage
28+
);
2529

2630
void CheckFinishStatus(CallState& state);
2731

@@ -33,13 +37,13 @@ void ProcessCancelled(CallState& state, std::string_view stage) noexcept;
3337

3438
void ProcessNetworkError(CallState& state, std::string_view stage) noexcept;
3539

36-
void ThrowIfDeadlineIsExceeded(grpc::ClientContext& context, std::string_view call_name);
40+
void ThrowIfDeadlineIsExceeded(grpc::ClientContext& client_context, std::string_view call_name);
3741

3842
template <typename GrpcStream>
3943
void StartCall(GrpcStream& stream, StreamingCallState& state) {
40-
ugrpc::impl::AsyncMethodInvocation start_call;
41-
stream.StartCall(start_call.GetCompletionTag());
42-
CheckOk(state, WaitAndTryCancelIfNeeded(start_call, state.GetClientContext()), "StartCall");
44+
ugrpc::impl::AsyncMethodInvocation invocation;
45+
stream.StartCall(invocation.GetCompletionTag());
46+
CheckOk(state, WaitAndTryCancelIfNeeded(invocation, state.GetClientContext()), "StartCall");
4347
}
4448

4549
template <typename GrpcStream>
@@ -53,14 +57,13 @@ void Finish(
5357

5458
state.SetFinished();
5559

56-
FinishAsyncMethodInvocation finish;
57-
auto& status = state.GetStatus();
58-
stream.Finish(&status, finish.GetCompletionTag());
60+
FinishAsyncMethodInvocation invocation;
61+
stream.Finish(&state.GetStatus(), invocation.GetCompletionTag());
62+
const auto wait_status = WaitAndTryCancelIfNeeded(invocation, state.GetClientContext());
5963

60-
const auto wait_status = WaitAndTryCancelIfNeeded(finish, state.GetClientContext());
6164
switch (wait_status) {
6265
case ugrpc::impl::AsyncMethodInvocation::WaitStatus::kOk:
63-
state.GetStatsScope().SetFinishTime(finish.GetFinishTime());
66+
state.GetStatsScope().SetFinishTime(invocation.GetFinishTime());
6467
try {
6568
ProcessFinish(state, final_response);
6669
} catch (const std::exception& ex) {
@@ -76,7 +79,7 @@ void Finish(
7679
break;
7780

7881
case ugrpc::impl::AsyncMethodInvocation::WaitStatus::kError:
79-
state.GetStatsScope().SetFinishTime(finish.GetFinishTime());
82+
state.GetStatsScope().SetFinishTime(invocation.GetFinishTime());
8083
ProcessNetworkError(state, "Finish");
8184
if (throw_on_error) {
8285
ThrowIfDeadlineIsExceeded(state.GetClientContext(), state.GetCallName());
@@ -107,24 +110,16 @@ void FinishAbandoned(GrpcStream& stream, StreamingCallState& state) noexcept try
107110

108111
state.GetClientContext().TryCancel();
109112

110-
FinishAsyncMethodInvocation finish;
111-
stream.Finish(&state.GetStatus(), finish.GetCompletionTag());
112-
113-
const engine::TaskCancellationBlocker cancel_blocker;
114-
const auto wait_status = finish.Wait();
113+
FinishAsyncMethodInvocation invocation;
114+
stream.Finish(&state.GetStatus(), invocation.GetCompletionTag());
115+
const auto ok = invocation.WaitNonCancellable();
115116

116-
state.GetStatsScope().SetFinishTime(finish.GetFinishTime());
117+
state.GetStatsScope().SetFinishTime(invocation.GetFinishTime());
117118

118-
switch (wait_status) {
119-
case ugrpc::impl::AsyncMethodInvocation::WaitStatus::kOk:
120-
ProcessFinishAbandoned(state);
121-
break;
122-
case ugrpc::impl::AsyncMethodInvocation::WaitStatus::kError:
123-
ProcessNetworkError(state, "Finish");
124-
break;
125-
case ugrpc::impl::AsyncMethodInvocation::WaitStatus::kCancelled:
126-
case ugrpc::impl::AsyncMethodInvocation::WaitStatus::kDeadline:
127-
UINVARIANT(false, "unreachable");
119+
if (ok) {
120+
ProcessFinishAbandoned(state);
121+
} else {
122+
ProcessNetworkError(state, "Finish");
128123
}
129124
} catch (const std::exception& ex) {
130125
LOG_WARNING() << "There is a caught exception in 'FinishAbandoned': " << ex;
@@ -133,9 +128,9 @@ void FinishAbandoned(GrpcStream& stream, StreamingCallState& state) noexcept try
133128
template <typename GrpcStream, typename Response>
134129
[[nodiscard]] bool Read(GrpcStream& stream, Response& response, StreamingCallState& state) {
135130
UINVARIANT(IsReadAvailable(state), "'impl::Read' called on a finished call");
136-
ugrpc::impl::AsyncMethodInvocation read;
137-
stream.Read(&response, read.GetCompletionTag());
138-
const auto wait_status = WaitAndTryCancelIfNeeded(read, state.GetClientContext());
131+
ugrpc::impl::AsyncMethodInvocation invocation;
132+
stream.Read(&response, invocation.GetCompletionTag());
133+
const auto wait_status = WaitAndTryCancelIfNeeded(invocation, state.GetClientContext());
139134
if (wait_status == ugrpc::impl::AsyncMethodInvocation::WaitStatus::kCancelled) {
140135
state.GetStatsScope().OnCancelled();
141136
}
@@ -146,8 +141,8 @@ template <typename GrpcStream, typename Response>
146141
void ReadAsync(GrpcStream& stream, Response& response, StreamingCallState& state) {
147142
UINVARIANT(IsReadAvailable(state), "'impl::Read' called on a finished call");
148143
state.EmplaceAsyncMethodInvocation();
149-
auto& read = state.GetAsyncMethodInvocation();
150-
stream.Read(&response, read.GetCompletionTag());
144+
auto& invocation = state.GetAsyncMethodInvocation();
145+
stream.Read(&response, invocation.GetCompletionTag());
151146
}
152147

153148
template <typename GrpcStream, typename Request>
@@ -160,16 +155,16 @@ bool Write(GrpcStream& stream, const Request& request, grpc::WriteOptions option
160155

161156
UINVARIANT(IsWriteAvailable(state), "'impl::Write' called on a stream that is closed for writes");
162157

163-
ugrpc::impl::AsyncMethodInvocation write;
164-
stream.Write(request, options, write.GetCompletionTag());
165-
const auto result = WaitAndTryCancelIfNeeded(write, state.GetClientContext());
166-
if (result == ugrpc::impl::AsyncMethodInvocation::WaitStatus::kCancelled) {
158+
ugrpc::impl::AsyncMethodInvocation invocation;
159+
stream.Write(request, options, invocation.GetCompletionTag());
160+
const auto wait_status = WaitAndTryCancelIfNeeded(invocation, state.GetClientContext());
161+
if (wait_status == ugrpc::impl::AsyncMethodInvocation::WaitStatus::kCancelled) {
167162
state.GetStatsScope().OnCancelled();
168163
}
169-
if (result != ugrpc::impl::AsyncMethodInvocation::WaitStatus::kOk) {
164+
if (wait_status != ugrpc::impl::AsyncMethodInvocation::WaitStatus::kOk) {
170165
state.SetWritesFinished();
171166
}
172-
return result == ugrpc::impl::AsyncMethodInvocation::WaitStatus::kOk;
167+
return wait_status == ugrpc::impl::AsyncMethodInvocation::WaitStatus::kOk;
173168
}
174169

175170
template <typename GrpcStream, typename Request>
@@ -182,9 +177,9 @@ void WriteAndCheck(GrpcStream& stream, const Request& request, grpc::WriteOption
182177

183178
UINVARIANT(IsWriteAndCheckAvailable(state), "'impl::WriteAndCheck' called on a finished or closed stream");
184179

185-
ugrpc::impl::AsyncMethodInvocation write;
186-
stream.Write(request, options, write.GetCompletionTag());
187-
CheckOk(state, WaitAndTryCancelIfNeeded(write, state.GetClientContext()), "WriteAndCheck");
180+
ugrpc::impl::AsyncMethodInvocation invocation;
181+
stream.Write(request, options, invocation.GetCompletionTag());
182+
CheckOk(state, WaitAndTryCancelIfNeeded(invocation, state.GetClientContext()), "WriteAndCheck");
188183
}
189184

190185
template <typename GrpcStream>
@@ -197,9 +192,9 @@ bool WritesDone(GrpcStream& stream, StreamingCallState& state) {
197192

198193
UINVARIANT(IsWriteAvailable(state), "'impl::WritesDone' called on a stream that is closed for writes");
199194
state.SetWritesFinished();
200-
ugrpc::impl::AsyncMethodInvocation writes_done;
201-
stream.WritesDone(writes_done.GetCompletionTag());
202-
const auto wait_status = WaitAndTryCancelIfNeeded(writes_done, state.GetClientContext());
195+
ugrpc::impl::AsyncMethodInvocation invocation;
196+
stream.WritesDone(invocation.GetCompletionTag());
197+
const auto wait_status = WaitAndTryCancelIfNeeded(invocation, state.GetClientContext());
203198
if (wait_status == ugrpc::impl::AsyncMethodInvocation::WaitStatus::kCancelled) {
204199
state.GetStatsScope().OnCancelled();
205200
}

grpc/include/userver/ugrpc/client/impl/prepare_call.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ class PrepareUnaryCallProxy<grpc::GenericStub, grpc::ByteBuffer, grpc::ByteBuffe
6767

6868
decltype(auto) operator()(
6969
StubHandle& stub_handle,
70-
grpc::ClientContext* context,
70+
grpc::ClientContext* client_context,
7171
const grpc::ByteBuffer& request,
7272
grpc::CompletionQueue* cq
7373
) const {
74-
return impl::PrepareCall(prepare_async_method_, stub_handle, context, method_name_, request, cq);
74+
return impl::PrepareCall(prepare_async_method_, stub_handle, client_context, method_name_, request, cq);
7575
}
7676

7777
private:

grpc/include/userver/ugrpc/client/stream_read_future.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,12 @@ bool StreamReadFuture<RawStream>::Get() {
9797
UINVARIANT(state_, "'Get' must be called only once");
9898
const impl::StreamingCallState::AsyncMethodInvocationGuard guard(*state_);
9999
auto* const state = std::exchange(state_, nullptr);
100-
const auto result = impl::WaitAndTryCancelIfNeeded(state->GetAsyncMethodInvocation(), state->GetClientContext());
101-
if (result == ugrpc::impl::AsyncMethodInvocation::WaitStatus::kCancelled) {
100+
const auto
101+
wait_status = impl::WaitAndTryCancelIfNeeded(state->GetAsyncMethodInvocation(), state->GetClientContext());
102+
if (wait_status == ugrpc::impl::AsyncMethodInvocation::WaitStatus::kCancelled) {
102103
state->GetStatsScope().OnCancelled();
103104
state->GetStatsScope().Flush();
104-
} else if (result == ugrpc::impl::AsyncMethodInvocation::WaitStatus::kError) {
105+
} else if (wait_status == ugrpc::impl::AsyncMethodInvocation::WaitStatus::kError) {
105106
// Finish can only be called once all the data is read, otherwise the
106107
// underlying gRPC driver hangs.
107108
impl::Finish(*stream_, *state, /*final_response=*/nullptr, /*throw_on_error=*/true);
@@ -110,7 +111,7 @@ bool StreamReadFuture<RawStream>::Get() {
110111
RunMiddlewarePipeline(*state, impl::MiddlewareHooks::RecvMessageHooks(*recv_message_));
111112
}
112113
}
113-
return result == ugrpc::impl::AsyncMethodInvocation::WaitStatus::kOk;
114+
return wait_status == ugrpc::impl::AsyncMethodInvocation::WaitStatus::kOk;
114115
}
115116

116117
template <typename RawStream>

grpc/include/userver/ugrpc/impl/async_method_invocation.hpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ class AsyncMethodInvocation : public EventBase {
2020
/// @see EventBase::Notify
2121
void Notify(bool ok) noexcept override;
2222

23-
bool IsBusy() const noexcept;
24-
2523
enum class WaitStatus {
2624
kOk,
2725
kError,
@@ -52,12 +50,10 @@ class AsyncMethodInvocation : public EventBase {
5250
// For internal use only.
5351
engine::impl::ContextAccessor* TryGetContextAccessor() noexcept;
5452
/// @endcond
55-
protected:
56-
void WaitWhileBusy() noexcept;
5753

5854
private:
55+
bool enqueued_{false};
5956
bool ok_{false};
60-
bool busy_{false};
6157
engine::SingleUseEvent event_;
6258
};
6359

grpc/include/userver/ugrpc/server/impl/service_worker_impl.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ class CallData final {
107107

108108
auto& queue = method_data_.service_data.internals.completion_queues.GetQueue(method_data_.queue_id);
109109

110+
ugrpc::impl::AsyncMethodInvocation request_call_invocation;
110111
// the request for an incoming RPC must be performed synchronously
111112
method_data_.service_data.async_service.template RequestCall<CallTraits>(
112113
method_data_.method_id,
@@ -115,14 +116,14 @@ class CallData final {
115116
raw_responder_,
116117
queue,
117118
queue,
118-
request_call_.GetCompletionTag()
119+
request_call_invocation.GetCompletionTag()
119120
);
120121

121122
// Note: we ignore task cancellations here. Even if notify_when_done has
122123
// already cancelled this RPC, we want to:
123124
// 1. listen to further RPCs for the same method
124125
// 2. handle this RPC correctly, including metrics, logs, etc.
125-
if (!request_call_.WaitNonCancellable()) {
126+
if (!request_call_invocation.WaitNonCancellable()) {
126127
// the CompletionQueue is shutting down
127128

128129
// Do not wait for notify_when_done. When queue is shutting down, it will
@@ -197,7 +198,6 @@ class CallData final {
197198
typename CallTraits::RawContext context_{};
198199
InitialRequest initial_request_{};
199200
RawResponder raw_responder_{&context_};
200-
ugrpc::impl::AsyncMethodInvocation request_call_;
201201
std::optional<tracing::InPlaceSpan> span_storage_{};
202202
};
203203

grpc/src/ugrpc/client/channels.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ namespace {
3636
return false;
3737
}
3838

39-
ugrpc::impl::AsyncMethodInvocation operation;
40-
channel.NotifyOnStateChange(state, deadline, &queue, operation.GetCompletionTag());
41-
if (operation.Wait() != ugrpc::impl::AsyncMethodInvocation::WaitStatus::kOk) {
39+
ugrpc::impl::AsyncMethodInvocation invocation;
40+
channel.NotifyOnStateChange(state, deadline, &queue, invocation.GetCompletionTag());
41+
if (invocation.Wait() != ugrpc::impl::AsyncMethodInvocation::WaitStatus::kOk) {
4242
return false;
4343
}
4444
}

grpc/src/ugrpc/client/impl/async_stream_methods.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ void SetErrorAndResetSpan(CallState& state, std::string_view error_message) noex
2626

2727
} // namespace
2828

29-
void ThrowIfDeadlineIsExceeded(grpc::ClientContext& context, std::string_view call_name) {
30-
const auto raw_deadline = context.raw_deadline();
29+
void ThrowIfDeadlineIsExceeded(grpc::ClientContext& client_context, std::string_view call_name) {
30+
const auto raw_deadline = client_context.raw_deadline();
3131
const auto deadline = ugrpc::TimespecToDeadline(raw_deadline);
3232
if (deadline.IsReached()) {
3333
grpc::Status deadline_status(grpc::StatusCode::DEADLINE_EXCEEDED, "Deadline exceeded");
@@ -37,22 +37,26 @@ void ThrowIfDeadlineIsExceeded(grpc::ClientContext& context, std::string_view ca
3737

3838
ugrpc::impl::AsyncMethodInvocation::WaitStatus WaitAndTryCancelIfNeeded(
3939
ugrpc::impl::AsyncMethodInvocation& invocation,
40-
grpc::ClientContext& context
40+
grpc::ClientContext& client_context
4141
) noexcept {
4242
const auto wait_status = invocation.Wait();
4343
if (ugrpc::impl::AsyncMethodInvocation::WaitStatus::kCancelled == wait_status) {
44-
context.TryCancel();
44+
client_context.TryCancel();
4545
}
4646
return wait_status;
4747
}
4848

49-
void CheckOk(StreamingCallState& state, ugrpc::impl::AsyncMethodInvocation::WaitStatus status, std::string_view stage) {
50-
if (status == ugrpc::impl::AsyncMethodInvocation::WaitStatus::kError) {
49+
void CheckOk(
50+
StreamingCallState& state,
51+
ugrpc::impl::AsyncMethodInvocation::WaitStatus wait_status,
52+
std::string_view stage
53+
) {
54+
if (wait_status == ugrpc::impl::AsyncMethodInvocation::WaitStatus::kError) {
5155
state.SetFinished();
5256
ThrowIfDeadlineIsExceeded(state.GetClientContext(), state.GetCallName());
5357
ProcessNetworkError(state, stage);
5458
throw RpcInterruptedError(state.GetCallName(), stage);
55-
} else if (status == ugrpc::impl::AsyncMethodInvocation::WaitStatus::kCancelled) {
59+
} else if (wait_status == ugrpc::impl::AsyncMethodInvocation::WaitStatus::kCancelled) {
5660
state.SetFinished();
5761
ProcessCancelled(state, stage);
5862
throw RpcCancelledError(state.GetCallName(), stage);

grpc/src/ugrpc/impl/async_method_invocation.cpp

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,20 @@ USERVER_NAMESPACE_BEGIN
88

99
namespace ugrpc::impl {
1010

11-
AsyncMethodInvocation::~AsyncMethodInvocation() { WaitWhileBusy(); }
11+
AsyncMethodInvocation::~AsyncMethodInvocation() {
12+
if (enqueued_) {
13+
event_.WaitNonCancellable();
14+
}
15+
}
1216

1317
void AsyncMethodInvocation::Notify(bool ok) noexcept {
1418
ok_ = ok;
1519
event_.Send();
1620
}
1721

18-
bool AsyncMethodInvocation::IsBusy() const noexcept { return busy_; }
19-
2022
void* AsyncMethodInvocation::GetCompletionTag() noexcept {
21-
UASSERT(!busy_);
22-
busy_ = true;
23+
UASSERT(!enqueued_);
24+
enqueued_ = true;
2325
return static_cast<EventBase*>(this);
2426
}
2527

@@ -40,7 +42,6 @@ AsyncMethodInvocation::WaitStatus AsyncMethodInvocation::WaitUntil(engine::Deadl
4042
return WaitStatus::kDeadline;
4143
}
4244
case engine::FutureStatus::kReady: {
43-
busy_ = false;
4445
return ok_ ? WaitStatus::kOk : WaitStatus::kError;
4546
}
4647
}
@@ -59,13 +60,6 @@ engine::impl::ContextAccessor* AsyncMethodInvocation::TryGetContextAccessor() no
5960

6061
bool AsyncMethodInvocation::IsReady() const noexcept { return event_.IsReady(); }
6162

62-
void AsyncMethodInvocation::WaitWhileBusy() noexcept {
63-
if (busy_) {
64-
event_.WaitNonCancellable();
65-
}
66-
busy_ = false;
67-
}
68-
6963
} // namespace ugrpc::impl
7064

7165
USERVER_NAMESPACE_END

grpc/tests/base_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ ugrpc::client::CallOptions PrepareCallOptions() {
115115
return call_options;
116116
}
117117

118-
void CheckClientContext(const grpc::ClientContext& context) {
119-
const auto& metadata = context.GetServerTrailingMetadata();
118+
void CheckClientContext(const grpc::ClientContext& client_context) {
119+
const auto& metadata = client_context.GetServerTrailingMetadata();
120120
const auto iter = metadata.find("resp_header");
121121
ASSERT_NE(iter, metadata.end());
122122
EXPECT_EQ(iter->second, "value");

grpc/tests/tracing_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ class GrpcTracing : public ugrpc::tests::ServiceFixture<UnitTestServiceWithTraci
7777
logging::DefaultLoggerLevelScope log_level_scope_{logging::Level::kInfo};
7878
};
7979

80-
void CheckMetadata(const grpc::ClientContext& context) {
81-
const auto& metadata = context.GetServerInitialMetadata();
80+
void CheckMetadata(const grpc::ClientContext& client_context) {
81+
const auto& metadata = client_context.GetServerInitialMetadata();
8282
const auto& span = tracing::Span::CurrentSpan();
8383

8484
// - TraceId should propagate both to sub-spans within a single service,

0 commit comments

Comments
 (0)