@@ -653,27 +653,20 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
 {
   py::list py_request_list =
       LoadRequestsFromSharedMemory(request_batch_shm_ptr);
-  std::unique_ptr<IPCMessage> execute_response =
-      IPCMessage::Create(shm_pool_, false /* Inline response */);
-  execute_response->Command() = PYTHONSTUB_ExecuteResponse;
+  std::unique_ptr<IPCMessage> execute_response;
 
-  AllocatedSharedMemory<ResponseBatch> response_batch =
-      shm_pool_->Construct<ResponseBatch>();
-  ResponseBatch* response_batch_shm_ptr =
-      reinterpret_cast<ResponseBatch*>(response_batch.data_.get());
-  execute_response->Args() = response_batch.handle_;
+  std::optional<AllocatedSharedMemory<char>> response_batch;
   bool has_exception = false;
   std::string error_string;
   std::unique_ptr<PbString> error_string_shm;
+  std::string err_message;
 
   ScopedDefer execute_finalize([this] { stub_message_queue_->Pop(); });
   ScopedDefer _(
       [this, &execute_response] { SendIPCMessage(execute_response); });
-
+  py::object execute_return;
+  py::object coroutine_return;
   try {
-    response_batch_shm_ptr->has_error = false;
-    response_batch_shm_ptr->is_error_set = false;
-
     if (!py::hasattr(model_instance_, "execute")) {
       std::string message = "Python model " + model_context_.PythonModelPath() +
                             " does not implement `execute` method.";
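Net effect of this first hunk: the stub no longer allocates the response batch (or the IPCMessage wrapping it) before the model runs; both are carved lazily out of a single Construct<char>() region once the outcome of `execute` is known. A minimal sketch of that carving, assuming only the `data_`/`handle_` members the patch itself uses:

    // Sketch, not part of the patch: one contiguous allocation holds the IPC
    // header immediately followed by the batch metadata, so a single
    // shared-memory handle can describe both.
    AllocatedSharedMemory<char> raw = shm_pool_->Construct<char>(
        sizeof(IPCMessageShm) + sizeof(ResponseBatch));
    auto* ipc_header = reinterpret_cast<IPCMessageShm*>(raw.data_.get());
    auto* batch = reinterpret_cast<ResponseBatch*>(
        raw.data_.get() + sizeof(IPCMessageShm));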
@@ -683,8 +676,7 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
     {
       NVTX_RANGE(nvtx_, "PyExecute " + name_);
 
-      py::object execute_return =
-          model_instance_.attr("execute")(py_request_list);
+      execute_return = model_instance_.attr("execute")(py_request_list);
 
       bool is_coroutine = py::module::import("asyncio")
                               .attr("iscoroutine")(execute_return)
@@ -694,12 +686,14 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
           // Do not wait for async decoupled execute to return.
           RunCoroutine(execute_return, true /* in_background */);
         } else {
-          py::object coroutine_return =
+          coroutine_return =
               RunCoroutine(execute_return, false /* in_background */);
-          ProcessReturnedResponses(py_request_list, coroutine_return);
+          ProcessReturnedResponses(
+              py_request_list, coroutine_return, response_batch);
         }
       } else {
-        ProcessReturnedResponses(py_request_list, execute_return);
+        ProcessReturnedResponses(
+            py_request_list, execute_return, response_batch);
       }
     }
   }
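The `iscoroutine` check above is what allows a model to declare `execute` as `async def`: calling such a method returns a coroutine object rather than a list of responses, which the stub must then drive through RunCoroutine (fire-and-forget when decoupled, awaited otherwise). A standalone illustration of the check, assuming pybind11's embed module; the model body is a hypothetical stand-in:

    #include <pybind11/embed.h>
    namespace py = pybind11;

    int main()
    {
      py::scoped_interpreter guard{};
      // Calling an `async def` yields a coroutine object, not its result.
      py::exec("async def execute(requests):\n    return []");
      py::object coro = py::globals()["execute"](py::none());
      bool is_coroutine =
          py::module::import("asyncio").attr("iscoroutine")(coro).cast<bool>();
      // true here, so the stub would hand the object to RunCoroutine instead
      // of reading responses from it directly.
      coro.attr("close")();  // silence "coroutine was never awaited"
      return is_coroutine ? 0 : 1;
    }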
@@ -713,16 +707,36 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
   }
 
   if (has_exception) {
-    std::string err_message =
-        std::string(
-            "Failed to process the request(s) for model '" + name_ +
-            "', message: ") +
-        error_string;
+    err_message = std::string(
+                      "Failed to process the request(s) for model '" + name_ +
+                      "', message: ") +
+                  error_string;
     LOG_ERROR << err_message.c_str();
+    if (!response_batch) {
+      response_batch = shm_pool_->Construct<char>(
+          sizeof(ResponseBatch) + sizeof(IPCMessageShm));
+    }
+    ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(
+        response_batch.value().data_.get() + sizeof(IPCMessageShm));
+
+    // The backend will clean up the response factory if there is an error in
+    // the response batch. For decoupled mode, it is necessary to handle cases
+    // where the response sender should have already cleaned up, ensuring the
+    // backend does not delete the response factory again during error handling.
+    if (IsDecoupled()) {
+      for (py::handle py_request : py_request_list) {
+        InferRequest* request = py_request.cast<InferRequest*>();
+        if (request->GetResponseSender()->IsClosed()) {
+          response_batch_shm_ptr->is_response_factory_deleted = true;
+        }
+      }
+    }
+
     response_batch_shm_ptr->has_error = true;
     error_string_shm = PbString::Create(shm_pool_, err_message);
     response_batch_shm_ptr->error = error_string_shm->ShmHandle();
     response_batch_shm_ptr->is_error_set = true;
+    response_batch_shm_ptr->batch_size = 0;
     // Once the error is sent to the backend, the backend is supposed to close
     // all response factories if not already closed, so closing all response
     // senders if not already closed to prevent the model from sending more
@@ -731,12 +745,47 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
       InferRequest* request = py_request.cast<InferRequest*>();
       request->GetResponseSender()->Close();
     }
+  } else {
+    if (!response_batch) {
+      response_batch = shm_pool_->Construct<char>(
+          sizeof(ResponseBatch) + sizeof(IPCMessageShm));
+      ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(
+          response_batch.value().data_.get() + sizeof(IPCMessageShm));
+      response_batch_shm_ptr->batch_size = 0;
+    }
+    ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(
+        response_batch.value().data_.get() + sizeof(IPCMessageShm));
+    response_batch_shm_ptr->has_error = false;
+    response_batch_shm_ptr->is_error_set = false;
+  }
+
+  execute_response = IPCMessage::Create(
+      reinterpret_cast<IPCMessageShm*>(response_batch.value().data_.get()),
+      response_batch.value().handle_);
+  execute_response->Args() =
+      response_batch.value().handle_ + sizeof(IPCMessageShm);
+  execute_response->InlineResponse() = false;
+  execute_response->Command() = PYTHONSTUB_ExecuteResponse;
+  _.Complete();
+  execute_finalize.Complete();
+}
+
+void
+Stub::ProcessResponse(InferResponse* response)
+{
+  response->SaveToSharedMemory(shm_pool_, false /* copy_gpu */);
+
+  for (auto& output_tensor : response->OutputTensors()) {
+    if (!output_tensor->IsCPU()) {
+      gpu_tensors_.push_back(output_tensor);
+    }
   }
 }
 
 void
 Stub::ProcessReturnedResponses(
-    py::list py_requests, py::object py_responses_obj)
+    py::list py_requests, py::object py_responses_obj,
+    std::optional<AllocatedSharedMemory<char>>& response_batch)
 {
   // Return if there is nothing to process.
   if (py::isinstance<py::none>(py_responses_obj)) {
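For orientation: across the error path above and the success path in the next hunk, the patch reads or writes five ResponseBatch fields. A shape sketch reconstructed from those accesses; the real definition lives elsewhere in the backend, and the types and ordering here are assumptions:

    // Assumed field sketch, inferred from the reads/writes in this diff.
    struct ResponseBatch {
      uint32_t batch_size;  // number of response handles that follow
      bi::managed_external_buffer::handle_t error;  // PbString with err_message
      bool has_error;       // processing failed; no responses to read
      bool is_error_set;    // whether `error` holds a valid handle
      bool is_response_factory_deleted;  // decoupled: sender already closed the
                                         // factory; backend must not free it twice
    };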
@@ -784,12 +833,55 @@ Stub::ProcessReturnedResponses(
             "return list, found type '" +
             std::string(py::str(py_responses[i].get_type())) + "'.");
       }
-      std::shared_ptr<InferResponse> response =
-          py_responses[i].cast<std::shared_ptr<InferResponse>>();
-      request->GetResponseSender()->Send(
-          response, TRITONSERVER_RESPONSE_COMPLETE_FINAL);
+
+      InferResponse* response = py_responses[i].cast<InferResponse*>();
+      try {
+        request->GetResponseSender()->UpdateStateAndCounters(
+            response, TRITONSERVER_RESPONSE_COMPLETE_FINAL);
+      }
+      catch (const PythonBackendException& pb_exception) {
+        // Handle the exception here to catch the error when there's a response
+        // returned from `execute()`.
+        if (request->GetResponseSender()->IsClosed()) {
+          response_batch = std::move(shm_pool_->Construct<char>(
+              sizeof(ResponseBatch) + sizeof(IPCMessageShm)));
+          ResponseBatch* response_batch_shm_ptr =
+              reinterpret_cast<ResponseBatch*>(
+                  response_batch.value().data_.get() + sizeof(IPCMessageShm));
+          response_batch_shm_ptr->batch_size = 0;
+          response_batch_shm_ptr->is_response_factory_deleted = true;
+        }
+        throw pb_exception;
+      }
+    }
+  }
+  // Return all the created responses using response_batch. The reason
+  // that both of the paths are available is that sending the responses
+  // using response_batch is faster than using `response_sender`.
+  response_batch = std::move(shm_pool_->Construct<char>(
+      sizeof(IPCMessageShm) +
+      requests_size * sizeof(bi::managed_external_buffer::handle_t) +
+      sizeof(ResponseBatch)));
+  ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(
+      response_batch.value().data_.get() + sizeof(IPCMessageShm));
+
+  bi::managed_external_buffer::handle_t* responses_shm_handle =
+      reinterpret_cast<bi::managed_external_buffer::handle_t*>(
+          response_batch.value().data_.get() + sizeof(ResponseBatch) +
+          sizeof(IPCMessageShm));
+  for (size_t i = 0; i < responses_size; i++) {
+    // Check the return type of execute function.
+    InferRequest* infer_request = py_requests[i].cast<InferRequest*>();
+    InferResponse* infer_response = py_responses[i].cast<InferResponse*>();
+    if (!py::isinstance<py::none>(py_responses[i])) {
+      infer_response->PruneOutputTensors(infer_request->RequestedOutputNames());
+      ProcessResponse(infer_response);
+      responses_shm_handle[i] = infer_response->ShmHandle();
+    } else {
+      responses_shm_handle[i] = 0;
     }
   }
+  response_batch_shm_ptr->batch_size = requests_size;
 }
 
 py::object
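Taken together, the new fast path sends every response through one allocation whose layout the pointer arithmetic above implies. An offset sketch, mirroring the casts in the hunk, with base = response_batch.value().data_.get():

    // |-- IPCMessageShm --|-- ResponseBatch --|-- handle_t[0..n-1] --|
    //
    // header  = base
    // batch   = base + sizeof(IPCMessageShm)
    // handles = base + sizeof(IPCMessageShm) + sizeof(ResponseBatch)
    //
    // A handle of 0 marks a request whose slot in the returned list was None,
    // e.g. one that was presumably answered through its response sender.

A null handle is how the backend can distinguish "response delivered elsewhere" from a real response stored in the batch.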