Skip to content

Commit 3dbf09e

Browse files
authored
perf: Improve response throughput of a single gRPC stream (#7404)
1 parent d1780d1 commit 3dbf09e

File tree

3 files changed

+133
-71
lines changed

3 files changed

+133
-71
lines changed

src/grpc/infer_handler.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -979,6 +979,9 @@ class InferHandlerState {
979979
// Tracks all the states that have been created on this context.
980980
std::set<InferHandlerStateType*> all_states_;
981981

982+
// Ready to write queue for decoupled
983+
std::queue<InferHandlerStateType*> ready_to_write_states_;
984+
982985
// The step of the entire context.
983986
Steps step_;
984987

src/grpc/stream_infer_handler.cc

Lines changed: 127 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -359,13 +359,38 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
359359
response->mutable_infer_response()->Clear();
360360
// repopulate the id so that client knows which request failed.
361361
response->mutable_infer_response()->set_id(request.id());
362-
state->step_ = Steps::WRITEREADY;
363362
if (!state->is_decoupled_) {
363+
state->step_ = Steps::WRITEREADY;
364364
state->context_->WriteResponseIfReady(state);
365365
} else {
366-
state->response_queue_->MarkNextResponseComplete();
367-
state->complete_ = true;
368-
state->context_->PutTaskBackToQueue(state);
366+
InferHandler::State* writing_state = nullptr;
367+
std::lock_guard<std::recursive_mutex> lk1(state->context_->mu_);
368+
{
369+
std::lock_guard<std::recursive_mutex> lk2(state->step_mtx_);
370+
state->response_queue_->MarkNextResponseComplete();
371+
state->context_->ready_to_write_states_.push(state);
372+
if (!state->context_->ongoing_write_) {
373+
// Only one write is allowed per gRPC stream / context at any time.
374+
// If the stream is not currently writing, start writing the next
375+
// ready to write response from the next ready to write state from
376+
// 'ready_to_write_states_'. If there are other responses on the
377+
// state ready to be written after starting the write, the state
378+
// will be placed at the back of the 'ready_to_write_states_'. If
379+
// there are no other responses, the state will be marked as 'ISSUED'
380+
// if complete final flag is not received yet from the backend or
381+
// completed if complete final flag is received.
382+
// The 'ongoing_write_' will reset once the completion queue returns
383+
// a written state and no additional response on the stream is ready
384+
// to be written.
385+
state->context_->ongoing_write_ = true;
386+
writing_state = state->context_->ready_to_write_states_.front();
387+
state->context_->ready_to_write_states_.pop();
388+
}
389+
state->complete_ = true;
390+
}
391+
if (writing_state != nullptr) {
392+
StateWriteResponse(writing_state);
393+
}
369394
}
370395
}
371396

@@ -451,7 +476,6 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
451476
// Decoupled state transitions
452477
//
453478
if (state->step_ == Steps::WRITTEN) {
454-
state->context_->ongoing_write_ = false;
455479
#ifdef TRITON_ENABLE_TRACING
456480
state->trace_timestamps_.emplace_back(
457481
std::make_pair("GRPC_SEND_END", TraceManager::CaptureTimestamp()));
@@ -469,61 +493,76 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
469493
state->context_->finish_ok_ = false;
470494
}
471495

472-
// Finish the state if all the transactions associated with
473-
// the state have completed.
474-
if (state->IsComplete()) {
475-
state->context_->DecrementRequestCounter();
476-
finished = Finish(state);
477-
} else {
478-
std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);
479-
480-
// If there is an available response to be written
481-
// to the stream, then transition directly to WRITEREADY
482-
// state and enqueue itself to the completion queue to be
483-
// taken up later. Otherwise, go to ISSUED state and wait
484-
// for the callback to make a response available.
485-
if (state->response_queue_->HasReadyResponse()) {
486-
state->step_ = Steps::WRITEREADY;
487-
state->context_->PutTaskBackToQueue(state);
488-
} else {
489-
state->step_ = Steps::ISSUED;
496+
{
497+
InferHandler::State* writing_state = nullptr;
498+
std::lock_guard<std::recursive_mutex> lk1(state->context_->mu_);
499+
{
500+
std::lock_guard<std::recursive_mutex> lk2(state->step_mtx_);
501+
if (!state->context_->ready_to_write_states_.empty()) {
502+
writing_state = state->context_->ready_to_write_states_.front();
503+
state->context_->ready_to_write_states_.pop();
504+
} else {
505+
state->context_->ongoing_write_ = false;
506+
}
507+
// Finish the state if all the transactions associated with
508+
// the state have completed.
509+
if (state != writing_state) {
510+
if (state->IsComplete()) {
511+
state->context_->DecrementRequestCounter();
512+
finished = Finish(state);
513+
} else {
514+
state->step_ = Steps::ISSUED;
515+
}
516+
}
517+
}
518+
if (writing_state != nullptr) {
519+
StateWriteResponse(writing_state);
490520
}
491521
}
492522
} else if (state->step_ == Steps::WRITEREADY) {
493-
if (state->delay_response_ms_ != 0) {
494-
// Will delay the write of the response by the specified time.
495-
// This can be used to test the flow where there are other
496-
// responses available to be written.
497-
LOG_INFO << "Delaying the write of the response by "
498-
<< state->delay_response_ms_ << " ms...";
499-
std::this_thread::sleep_for(
500-
std::chrono::milliseconds(state->delay_response_ms_));
501-
}
502-
503523
// Finish the state if all the transactions associated with
504524
// the state have completed.
505525
if (state->IsComplete()) {
506526
state->context_->DecrementRequestCounter();
507527
finished = Finish(state);
508528
} else {
509-
// GRPC doesn't allow to issue another write till
510-
// the notification from previous write has been
511-
// delivered. If there is an ongoing write then
512-
// defer writing and place the task at the back
513-
// of the completion queue to be taken up later.
514-
if (!state->context_->ongoing_write_) {
515-
state->context_->ongoing_write_ = true;
516-
state->context_->DecoupledWriteResponse(state);
517-
} else {
518-
state->context_->PutTaskBackToQueue(state);
519-
}
529+
LOG_ERROR << "Should not print this! Decoupled should NOT write via "
530+
"WRITEREADY!";
531+
// Remove the state from the completion queue
532+
std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);
533+
state->step_ = Steps::ISSUED;
520534
}
521535
}
522536
}
523537

524538
return !finished;
525539
}
526540

541+
// For decoupled only. Caller must ensure exclusive write.
542+
void
543+
ModelStreamInferHandler::StateWriteResponse(InferHandler::State* state)
544+
{
545+
if (state->delay_response_ms_ != 0) {
546+
// Will delay the write of the response by the specified time.
547+
// This can be used to test the flow where there are other
548+
// responses available to be written.
549+
LOG_INFO << "Delaying the write of the response by "
550+
<< state->delay_response_ms_ << " ms...";
551+
std::this_thread::sleep_for(
552+
std::chrono::milliseconds(state->delay_response_ms_));
553+
}
554+
{
555+
std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);
556+
state->step_ = Steps::WRITTEN;
557+
// gRPC doesn't allow to issue another write till the notification from
558+
// previous write has been delivered.
559+
state->context_->DecoupledWriteResponse(state);
560+
if (state->response_queue_->HasReadyResponse()) {
561+
state->context_->ready_to_write_states_.push(state);
562+
}
563+
}
564+
}
565+
527566
bool
528567
ModelStreamInferHandler::Finish(InferHandler::State* state)
529568
{
@@ -701,45 +740,64 @@ ModelStreamInferHandler::StreamInferResponseComplete(
701740
}
702741
}
703742

704-
// Update states to signal that response/error is ready to write to stream
705-
{
743+
if (state->IsGrpcContextCancelled()) {
706744
// Need to hold lock because the handler thread processing context
707745
// cancellation might have cancelled or marked the state for cancellation.
708746
std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);
709747

710-
if (state->IsGrpcContextCancelled()) {
711-
LOG_VERBOSE(1)
712-
<< "ModelStreamInferHandler::StreamInferResponseComplete, "
713-
<< state->unique_id_
714-
<< ", skipping writing response because of transaction was cancelled";
715-
716-
// If this was the final callback for the state
717-
// then cycle through the completion queue so
718-
// that state object can be released.
719-
if (is_complete) {
720-
state->step_ = Steps::CANCELLED;
721-
state->context_->PutTaskBackToQueue(state);
722-
}
748+
LOG_VERBOSE(1)
749+
<< "ModelStreamInferHandler::StreamInferResponseComplete, "
750+
<< state->unique_id_
751+
<< ", skipping writing response because the transaction was cancelled";
723752

724-
state->complete_ = is_complete;
725-
return;
753+
// If this was the final callback for the state
754+
// then cycle through the completion queue so
755+
// that state object can be released.
756+
if (is_complete) {
757+
state->step_ = Steps::CANCELLED;
758+
state->context_->PutTaskBackToQueue(state);
726759
}
727760

728-
if (state->is_decoupled_) {
761+
state->complete_ = is_complete;
762+
return;
763+
}
764+
765+
if (state->is_decoupled_) {
766+
InferHandler::State* writing_state = nullptr;
767+
std::lock_guard<std::recursive_mutex> lk1(state->context_->mu_);
768+
{
769+
std::lock_guard<std::recursive_mutex> lk2(state->step_mtx_);
770+
bool has_prev_ready_response = state->response_queue_->HasReadyResponse();
729771
if (response) {
730772
state->response_queue_->MarkNextResponseComplete();
731773
}
732-
if (state->step_ == Steps::ISSUED) {
774+
if (!has_prev_ready_response && response) {
775+
state->context_->ready_to_write_states_.push(state);
776+
}
777+
if (!state->context_->ongoing_write_ &&
778+
!state->context_->ready_to_write_states_.empty()) {
779+
state->context_->ongoing_write_ = true;
780+
writing_state = state->context_->ready_to_write_states_.front();
781+
state->context_->ready_to_write_states_.pop();
782+
}
783+
if (is_complete && state->response_queue_->IsEmpty() &&
784+
state->step_ == Steps::ISSUED) {
785+
// The response queue is empty and the complete final flag has been received, so
786+
// mark the state as 'WRITEREADY' so it can be cleaned up later.
733787
state->step_ = Steps::WRITEREADY;
734788
state->context_->PutTaskBackToQueue(state);
735789
}
736-
} else {
737-
state->step_ = Steps::WRITEREADY;
738-
if (is_complete) {
739-
state->context_->WriteResponseIfReady(state);
740-
}
790+
state->complete_ = is_complete;
791+
}
792+
if (writing_state != nullptr) {
793+
StateWriteResponse(writing_state);
794+
}
795+
} else { // non-decoupled
796+
std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);
797+
state->step_ = Steps::WRITEREADY;
798+
if (is_complete) {
799+
state->context_->WriteResponseIfReady(state);
741800
}
742-
743801
state->complete_ = is_complete;
744802
}
745803
}

src/grpc/stream_infer_handler.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -112,6 +112,7 @@ class ModelStreamInferHandler
112112
static void StreamInferResponseComplete(
113113
TRITONSERVER_InferenceResponse* response, const uint32_t flags,
114114
void* userp);
115+
static void StateWriteResponse(InferHandler::State* state);
115116
bool Finish(State* state);
116117

117118
TraceManager* trace_manager_;

0 commit comments

Comments
 (0)