@@ -654,7 +654,7 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
654654 py::list py_request_list =
655655 LoadRequestsFromSharedMemory (request_batch_shm_ptr);
656656 std::unique_ptr<IPCMessage> execute_response;
657- // IPCMessage::Create(shm_pool_, false /* Inline response */);
657+ // IPCMessage::Create(shm_pool_, false /* Inline response */);
658658
659659 std::optional<AllocatedSharedMemory<char >> response_batch;
660660 bool has_exception = false ;
@@ -675,8 +675,7 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
675675 {
676676 NVTX_RANGE (nvtx_, " PyExecute " + name_);
677677
678- execute_return =
679- model_instance_.attr (" execute" )(py_request_list);
678+ execute_return = model_instance_.attr (" execute" )(py_request_list);
680679
681680 bool is_coroutine = py::module::import (" asyncio" )
682681 .attr (" iscoroutine" )(execute_return)
@@ -688,10 +687,12 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
688687 } else {
689688 py::object coroutine_return =
690689 RunCoroutine (execute_return, false /* in_background */ );
691- ProcessReturnedResponses (py_request_list, coroutine_return, response_batch);
690+ ProcessReturnedResponses (
691+ py_request_list, coroutine_return, response_batch);
692692 }
693693 } else {
694- ProcessReturnedResponses (py_request_list, execute_return, response_batch);
694+ ProcessReturnedResponses (
695+ py_request_list, execute_return, response_batch);
695696 }
696697 }
697698 }
@@ -712,11 +713,14 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
712713 error_string;
713714 LOG_ERROR << err_message.c_str ();
714715 if (!response_batch) {
715- response_batch = shm_pool_->Construct <char >(sizeof (ResponseBatch) + sizeof (IPCMessageShm));
716- }
717- ResponseBatch* response_batch_shm_ptr = reinterpret_cast <ResponseBatch*>(response_batch.value ().data_ .get () + sizeof (IPCMessageShm));
716+ response_batch = shm_pool_->Construct <char >(
717+ sizeof (ResponseBatch) + sizeof (IPCMessageShm));
718+ }
719+ ResponseBatch* response_batch_shm_ptr = reinterpret_cast <ResponseBatch*>(
720+ response_batch.value ().data_ .get () + sizeof (IPCMessageShm));
718721
719- response_batch_shm_ptr = reinterpret_cast <ResponseBatch*>(response_batch.value ().data_ .get ());
722+ response_batch_shm_ptr =
723+ reinterpret_cast <ResponseBatch*>(response_batch.value ().data_ .get ());
720724 response_batch_shm_ptr->has_error = true ;
721725 error_string_shm = PbString::Create (shm_pool_, err_message);
722726 response_batch_shm_ptr->error = error_string_shm->ShmHandle ();
@@ -732,14 +736,19 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
732736 }
733737
734738 if (!response_batch) {
735- response_batch = shm_pool_->Construct <char >(sizeof (ResponseBatch) + sizeof (IPCMessageShm));
736- ResponseBatch* response_batch_shm_ptr =reinterpret_cast <ResponseBatch*>(response_batch.value ().data_ .get () + sizeof (IPCMessageShm));
737- response_batch_shm_ptr->batch_size = 0 ;
738- }
739- ResponseBatch* response_batch_shm_ptr = reinterpret_cast <ResponseBatch*>(response_batch.value ().data_ .get () + sizeof (IPCMessageShm));
739+ response_batch = shm_pool_->Construct <char >(
740+ sizeof (ResponseBatch) + sizeof (IPCMessageShm));
741+ ResponseBatch* response_batch_shm_ptr = reinterpret_cast <ResponseBatch*>(
742+ response_batch.value ().data_ .get () + sizeof (IPCMessageShm));
743+ response_batch_shm_ptr->batch_size = 0 ;
744+ }
745+ ResponseBatch* response_batch_shm_ptr = reinterpret_cast <ResponseBatch*>(
746+ response_batch.value ().data_ .get () + sizeof (IPCMessageShm));
740747 response_batch_shm_ptr->has_error = false ;
741748 response_batch_shm_ptr->is_error_set = false ;
742- execute_response = IPCMessage::Create (reinterpret_cast <IPCMessageShm*>(response_batch.value ().data_ .get ()), response_batch.value ().handle_ );
749+ execute_response = IPCMessage::Create (
750+ reinterpret_cast <IPCMessageShm*>(response_batch.value ().data_ .get ()),
751+ response_batch.value ().handle_ );
743752 execute_response->Args () = response_batch.value ().handle_ ;
744753 execute_response->InlineResponse () = false ;
745754 execute_response->Command () = PYTHONSTUB_ExecuteResponse;
@@ -761,7 +770,8 @@ Stub::ProcessResponse(InferResponse* response)
761770
762771void
763772Stub::ProcessReturnedResponses (
764- py::list py_requests, py::object py_responses_obj, std::optional<AllocatedSharedMemory<char >>& response_batch)
773+ py::list py_requests, py::object py_responses_obj,
774+ std::optional<AllocatedSharedMemory<char >>& response_batch)
765775{
766776 // Return if there is nothing to process.
767777 if (py::isinstance<py::none>(py_responses_obj)) {
@@ -812,29 +822,34 @@ Stub::ProcessReturnedResponses(
812822
813823 std::shared_ptr<InferResponse> response =
814824 py_responses[i].cast <std::shared_ptr<InferResponse>>();
815- request->GetResponseSender ()->UpdateStateAndCounters (response, TRITONSERVER_RESPONSE_COMPLETE_FINAL);
825+ request->GetResponseSender ()->UpdateStateAndCounters (
826+ response, TRITONSERVER_RESPONSE_COMPLETE_FINAL);
816827 }
817828 }
818- response_batch = std::move (shm_pool_->Construct <char >(sizeof (IPCMessageShm) +
829+ // Return all the created responses using response_batch. The reason
830+ // that both of the paths are available is that sending the responses
831+ // using response_batch is faster than using `response_sender`.
832+ response_batch = std::move (shm_pool_->Construct <char >(
833+ sizeof (IPCMessageShm) +
819834 requests_size * sizeof (bi::managed_external_buffer::handle_t ) +
820835 sizeof (ResponseBatch)));
821- ResponseBatch* response_batch_shm_ptr =
822- reinterpret_cast <ResponseBatch*>( response_batch.value ().data_ .get () + sizeof (IPCMessageShm));
836+ ResponseBatch* response_batch_shm_ptr = reinterpret_cast <ResponseBatch*>(
837+ response_batch.value ().data_ .get () + sizeof (IPCMessageShm));
823838
824839 bi::managed_external_buffer::handle_t * responses_shm_handle =
825840 reinterpret_cast <bi::managed_external_buffer::handle_t *>(
826- response_batch.value ().data_ .get () + sizeof (ResponseBatch) + sizeof (IPCMessageShm));
827-
828- for ( size_t i = 0 ; i < responses_size; i++) {
829- // Check the return type of execute function.
830- InferRequest* infer_request = py_requests[i]. cast <InferRequest*>();
831- InferResponse* infer_response = py_responses [i].cast <InferResponse *>();
832- infer_response-> PruneOutputTensors (
833- infer_request->RequestedOutputNames ());
834- ProcessResponse (infer_response);
835- responses_shm_handle[i] = infer_response->ShmHandle ();
836- }
837- response_batch_shm_ptr->batch_size = requests_size;
841+ response_batch.value ().data_ .get () + sizeof (ResponseBatch) +
842+ sizeof (IPCMessageShm));
843+
844+ for ( size_t i = 0 ; i < responses_size; i++) {
845+ // Check the return type of execute function.
846+ InferRequest* infer_request = py_requests [i].cast <InferRequest *>();
847+ InferResponse* infer_response = py_responses[i]. cast <InferResponse*>();
848+ infer_response-> PruneOutputTensors ( infer_request->RequestedOutputNames ());
849+ ProcessResponse (infer_response);
850+ responses_shm_handle[i] = infer_response->ShmHandle ();
851+ }
852+ response_batch_shm_ptr->batch_size = requests_size;
838853}
839854
840855py::object
0 commit comments