@@ -359,13 +359,38 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
359359 response->mutable_infer_response ()->Clear ();
360360 // repopulate the id so that client knows which request failed.
361361 response->mutable_infer_response ()->set_id (request.id ());
362- state->step_ = Steps::WRITEREADY;
363362 if (!state->is_decoupled_ ) {
363+ state->step_ = Steps::WRITEREADY;
364364 state->context_ ->WriteResponseIfReady (state);
365365 } else {
366- state->response_queue_ ->MarkNextResponseComplete ();
367- state->complete_ = true ;
368- state->context_ ->PutTaskBackToQueue (state);
366+ InferHandler::State* writing_state = nullptr ;
367+ std::lock_guard<std::recursive_mutex> lk1 (state->context_ ->mu_ );
368+ {
369+ std::lock_guard<std::recursive_mutex> lk2 (state->step_mtx_ );
370+ state->response_queue_ ->MarkNextResponseComplete ();
371+ state->context_ ->ready_to_write_states_ .push (state);
372+ if (!state->context_ ->ongoing_write_ ) {
373+ // Only one write is allowed per gRPC stream / context at any time.
374+ // If the stream is not currently writing, start writing the next
375+ // ready to write response from the next ready to write state from
376+ // 'ready_to_write_states_'. If there are other responses on the
377+ // state ready to be written after starting the write, the state
378+ // will be placed at the back of the 'ready_to_write_states_'. If
379+ // there are no other response, the state will be marked as 'ISSUED'
380+ // if complete final flag is not received yet from the backend or
381+ // completed if complete final flag is received.
382+ // The 'ongoing_write_' will reset once the completion queue returns
383+ // a written state and no additional response on the stream is ready
384+ // to be written.
385+ state->context_ ->ongoing_write_ = true ;
386+ writing_state = state->context_ ->ready_to_write_states_ .front ();
387+ state->context_ ->ready_to_write_states_ .pop ();
388+ }
389+ state->complete_ = true ;
390+ }
391+ if (writing_state != nullptr ) {
392+ StateWriteResponse (writing_state);
393+ }
369394 }
370395 }
371396
@@ -451,7 +476,6 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
451476 // Decoupled state transitions
452477 //
453478 if (state->step_ == Steps::WRITTEN) {
454- state->context_ ->ongoing_write_ = false ;
455479#ifdef TRITON_ENABLE_TRACING
456480 state->trace_timestamps_ .emplace_back (
457481 std::make_pair (" GRPC_SEND_END" , TraceManager::CaptureTimestamp ()));
@@ -469,61 +493,76 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
469493 state->context_ ->finish_ok_ = false ;
470494 }
471495
472- // Finish the state if all the transactions associated with
473- // the state have completed.
474- if (state->IsComplete ()) {
475- state->context_ ->DecrementRequestCounter ();
476- finished = Finish (state);
477- } else {
478- std::lock_guard<std::recursive_mutex> lock (state->step_mtx_ );
479-
480- // If there is an available response to be written
481- // to the stream, then transition directly to WRITEREADY
482- // state and enqueue itself to the completion queue to be
483- // taken up later. Otherwise, go to ISSUED state and wait
484- // for the callback to make a response available.
485- if (state->response_queue_ ->HasReadyResponse ()) {
486- state->step_ = Steps::WRITEREADY;
487- state->context_ ->PutTaskBackToQueue (state);
488- } else {
489- state->step_ = Steps::ISSUED;
496+ {
497+ InferHandler::State* writing_state = nullptr ;
498+ std::lock_guard<std::recursive_mutex> lk1 (state->context_ ->mu_ );
499+ {
500+ std::lock_guard<std::recursive_mutex> lk2 (state->step_mtx_ );
501+ if (!state->context_ ->ready_to_write_states_ .empty ()) {
502+ writing_state = state->context_ ->ready_to_write_states_ .front ();
503+ state->context_ ->ready_to_write_states_ .pop ();
504+ } else {
505+ state->context_ ->ongoing_write_ = false ;
506+ }
507+ // Finish the state if all the transactions associated with
508+ // the state have completed.
509+ if (state != writing_state) {
510+ if (state->IsComplete ()) {
511+ state->context_ ->DecrementRequestCounter ();
512+ finished = Finish (state);
513+ } else {
514+ state->step_ = Steps::ISSUED;
515+ }
516+ }
517+ }
518+ if (writing_state != nullptr ) {
519+ StateWriteResponse (writing_state);
490520 }
491521 }
492522 } else if (state->step_ == Steps::WRITEREADY) {
493- if (state->delay_response_ms_ != 0 ) {
494- // Will delay the write of the response by the specified time.
495- // This can be used to test the flow where there are other
496- // responses available to be written.
497- LOG_INFO << " Delaying the write of the response by "
498- << state->delay_response_ms_ << " ms..." ;
499- std::this_thread::sleep_for (
500- std::chrono::milliseconds (state->delay_response_ms_ ));
501- }
502-
503523 // Finish the state if all the transactions associated with
504524 // the state have completed.
505525 if (state->IsComplete ()) {
506526 state->context_ ->DecrementRequestCounter ();
507527 finished = Finish (state);
508528 } else {
509- // GRPC doesn't allow to issue another write till
510- // the notification from previous write has been
511- // delivered. If there is an ongoing write then
512- // defer writing and place the task at the back
513- // of the completion queue to be taken up later.
514- if (!state->context_ ->ongoing_write_ ) {
515- state->context_ ->ongoing_write_ = true ;
516- state->context_ ->DecoupledWriteResponse (state);
517- } else {
518- state->context_ ->PutTaskBackToQueue (state);
519- }
529+ LOG_ERROR << " Should not print this! Decoupled should NOT write via "
530+ " WRITEREADY!" ;
531+ // Remove the state from the completion queue
532+ std::lock_guard<std::recursive_mutex> lock (state->step_mtx_ );
533+ state->step_ = Steps::ISSUED;
520534 }
521535 }
522536 }
523537
524538 return !finished;
525539}
526540
541+ // For decoupled only. Issues the next response write for 'state' and marks it WRITTEN. Caller must ensure exclusive write on the stream (the callers visible in this change invoke it while holding context_->mu_).
542+ void
543+ ModelStreamInferHandler::StateWriteResponse (InferHandler::State* state)
544+ {
545+ if (state->delay_response_ms_ != 0 ) {
546+ // Will delay the write of the response by the specified time.
547+ // This can be used to test the flow where there are other
548+ // responses available to be written. NOTE(review): this sleep runs before step_mtx_ is taken, but the visible callers still hold context_->mu_ across this call - confirm that blocking the context for the delay is intended.
549+ LOG_INFO << " Delaying the write of the response by "
550+ << state->delay_response_ms_ << " ms..." ;
551+ std::this_thread::sleep_for (
552+ std::chrono::milliseconds (state->delay_response_ms_ ));
553+ }
554+ {
555+ std::lock_guard<std::recursive_mutex> lock (state->step_mtx_ );
556+ state->step_ = Steps::WRITTEN;
557+ // The step is set to WRITTEN before the write is issued. gRPC doesn't allow to issue another write till the notification from
558+ // previous write has been delivered, so if this state still has a ready response after the write is started it is pushed onto 'ready_to_write_states_' to be written on a later turn.
559+ state->context_ ->DecoupledWriteResponse (state);
560+ if (state->response_queue_ ->HasReadyResponse ()) {
561+ state->context_ ->ready_to_write_states_ .push (state);
562+ }
563+ }
564+ }
565+
527566bool
528567ModelStreamInferHandler::Finish (InferHandler::State* state)
529568{
@@ -701,45 +740,64 @@ ModelStreamInferHandler::StreamInferResponseComplete(
701740 }
702741 }
703742
704- // Update states to signal that response/error is ready to write to stream
705- {
743+ if (state->IsGrpcContextCancelled ()) {
706744 // Need to hold lock because the handler thread processing context
707745 // cancellation might have cancelled or marked the state for cancellation.
708746 std::lock_guard<std::recursive_mutex> lock (state->step_mtx_ );
709747
710- if (state->IsGrpcContextCancelled ()) {
711- LOG_VERBOSE (1 )
712- << " ModelStreamInferHandler::StreamInferResponseComplete, "
713- << state->unique_id_
714- << " , skipping writing response because of transaction was cancelled" ;
715-
716- // If this was the final callback for the state
717- // then cycle through the completion queue so
718- // that state object can be released.
719- if (is_complete) {
720- state->step_ = Steps::CANCELLED;
721- state->context_ ->PutTaskBackToQueue (state);
722- }
748+ LOG_VERBOSE (1 )
749+ << " ModelStreamInferHandler::StreamInferResponseComplete, "
750+ << state->unique_id_
751+ << " , skipping writing response because of transaction was cancelled" ;
723752
724- state->complete_ = is_complete;
725- return ;
753+ // If this was the final callback for the state
754+ // then cycle through the completion queue so
755+ // that state object can be released.
756+ if (is_complete) {
757+ state->step_ = Steps::CANCELLED;
758+ state->context_ ->PutTaskBackToQueue (state);
726759 }
727760
728- if (state->is_decoupled_ ) {
761+ state->complete_ = is_complete;
762+ return ;
763+ }
764+
765+ if (state->is_decoupled_ ) {
766+ InferHandler::State* writing_state = nullptr ;
767+ std::lock_guard<std::recursive_mutex> lk1 (state->context_ ->mu_ );
768+ {
769+ std::lock_guard<std::recursive_mutex> lk2 (state->step_mtx_ );
770+ bool has_prev_ready_response = state->response_queue_ ->HasReadyResponse ();
729771 if (response) {
730772 state->response_queue_ ->MarkNextResponseComplete ();
731773 }
732- if (state->step_ == Steps::ISSUED) {
774+ if (!has_prev_ready_response && response) {
775+ state->context_ ->ready_to_write_states_ .push (state);
776+ }
777+ if (!state->context_ ->ongoing_write_ &&
778+ !state->context_ ->ready_to_write_states_ .empty ()) {
779+ state->context_ ->ongoing_write_ = true ;
780+ writing_state = state->context_ ->ready_to_write_states_ .front ();
781+ state->context_ ->ready_to_write_states_ .pop ();
782+ }
783+ if (is_complete && state->response_queue_ ->IsEmpty () &&
784+ state->step_ == Steps::ISSUED) {
785+ // The response queue is empty and complete final flag is received, so
786+ // mark the state as 'WRITEREADY' so it can be cleaned up later.
733787 state->step_ = Steps::WRITEREADY;
734788 state->context_ ->PutTaskBackToQueue (state);
735789 }
736- } else {
737- state->step_ = Steps::WRITEREADY;
738- if (is_complete) {
739- state->context_ ->WriteResponseIfReady (state);
740- }
790+ state->complete_ = is_complete;
791+ }
792+ if (writing_state != nullptr ) {
793+ StateWriteResponse (writing_state);
794+ }
795+ } else { // non-decoupled
796+ std::lock_guard<std::recursive_mutex> lock (state->step_mtx_ );
797+ state->step_ = Steps::WRITEREADY;
798+ if (is_complete) {
799+ state->context_ ->WriteResponseIfReady (state);
741800 }
742-
743801 state->complete_ = is_complete;
744802 }
745803}
0 commit comments