50 changes: 46 additions & 4 deletions src/pb_stub.cc
@@ -665,6 +665,7 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
ScopedDefer _(
[this, &execute_response] { SendIPCMessage(execute_response); });
py::object execute_return;
py::object coroutine_return;
try {
if (!py::hasattr(model_instance_, "execute")) {
std::string message = "Python model " + model_context_.PythonModelPath() +
@@ -685,7 +686,7 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
// Do not wait for async decoupled execute to return.
RunCoroutine(execute_return, true /* in_background */);
} else {
py::object coroutine_return =
coroutine_return =
RunCoroutine(execute_return, false /* in_background */);
ProcessReturnedResponses(
py_request_list, coroutine_return, response_batch);
@@ -718,6 +719,24 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(
response_batch.value().data_.get() + sizeof(IPCMessageShm));

// The backend cleans up the response factory when there is an error in
// the response batch. Handle the cases where the response sender has
// already cleaned it up, so that the backend does not delete the response
// factory a second time during error handling.
if (err_message.find("Response sender has been closed") !=
std::string::npos) {
response_batch_shm_ptr->is_response_factory_deleted = true;
} else if (
err_message.find("is using the decoupled mode and the execute function "
"must return None") != std::string::npos) {
for (py::handle py_request : py_request_list) {
InferRequest* request = py_request.cast<InferRequest*>();
if (request->GetResponseSender()->IsClosed()) {
response_batch_shm_ptr->is_response_factory_deleted = true;
}
}
}

response_batch_shm_ptr->has_error = true;
error_string_shm = PbString::Create(shm_pool_, err_message);
response_batch_shm_ptr->error = error_string_shm->ShmHandle();
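The ResponseBatch header written here sits directly behind the IPCMessageShm header inside one shared-memory allocation, which is why both the stub and the backend recover it with the same `data_.get() + sizeof(IPCMessageShm)` arithmetic. A minimal, self-contained sketch of that addressing follows; the structs are simplified stand-ins for illustration, not the real IPCMessageShm/ResponseBatch definitions.

#include <cstdint>
#include <cstdio>
#include <new>
#include <vector>

// Simplified stand-ins; the real structs carry more fields.
struct FakeIpcMessageShm { uint64_t command; };
struct FakeResponseBatch {
  bool has_error;
  bool is_response_factory_deleted;
  uint32_t batch_size;
};

int main() {
  // One allocation holds the IPC header followed by the batch header,
  // mirroring `data_.get() + sizeof(IPCMessageShm)` above.
  std::vector<char> shm(sizeof(FakeIpcMessageShm) + sizeof(FakeResponseBatch));
  auto* batch = new (shm.data() + sizeof(FakeIpcMessageShm)) FakeResponseBatch{};

  batch->has_error = true;
  batch->is_response_factory_deleted = true;  // tell the peer not to free the factory again
  std::printf("batch header lives at offset %zu\n", sizeof(FakeIpcMessageShm));
  return 0;
}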
@@ -821,8 +840,32 @@ Stub::ProcessReturnedResponses(
}

InferResponse* response = py_responses[i].cast<InferResponse*>();
request->GetResponseSender()->UpdateStateAndCounters(
response, TRITONSERVER_RESPONSE_COMPLETE_FINAL);
try {
request->GetResponseSender()->UpdateStateAndCounters(
response, TRITONSERVER_RESPONSE_COMPLETE_FINAL);
}
catch (const PythonBackendException& pb_exception) {
// Catch the error thrown when `execute()` returns a response after one
// has already been sent (the error message below). In default
// (non-decoupled) mode, the response factory was already cleaned up when
// the response sender sent the previous response. If the model still
// attempts to return another response from `execute()`, notify the
// backend NOT to delete the response factory again during error handling.
std::string err_message = pb_exception.what();
if (err_message.find(
"Non-decoupled model cannot send more than one response") !=
std::string::npos) {
response_batch = std::move(shm_pool_->Construct<char>(
sizeof(ResponseBatch) + sizeof(IPCMessageShm)));
ResponseBatch* response_batch_shm_ptr =
reinterpret_cast<ResponseBatch*>(
response_batch.value().data_.get() + sizeof(IPCMessageShm));
response_batch_shm_ptr->batch_size = 0;
response_batch_shm_ptr->is_response_factory_deleted = true;
}
throw pb_exception;
}
}
}
// Return all the created responses using response_batch. The reason
@@ -839,7 +882,6 @@ Stub::ProcessReturnedResponses(
reinterpret_cast<bi::managed_external_buffer::handle_t*>(
response_batch.value().data_.get() + sizeof(ResponseBatch) +
sizeof(IPCMessageShm));

for (size_t i = 0; i < responses_size; i++) {
// Check the return type of execute function.
InferRequest* infer_request = py_requests[i].cast<InferRequest*>();
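The catch block added above recognizes one specific failure (a non-decoupled model returning a response from `execute()` after its response sender has already sent the final one), records in the batch header that the factory is gone, and rethrows so the normal error reporting still runs. A condensed, self-contained sketch of that catch, inspect, flag, and rethrow pattern, using a plain std::runtime_error and a hypothetical SendExtraResponse() in place of the real exception type and response sender:

#include <cstdio>
#include <stdexcept>
#include <string>

struct BatchHeader { bool is_response_factory_deleted = false; };

// Hypothetical stand-in for the response sender call that fails.
void SendExtraResponse() {
  throw std::runtime_error(
      "Non-decoupled model cannot send more than one response");
}

int main() {
  BatchHeader header;
  try {
    try {
      SendExtraResponse();
    } catch (const std::runtime_error& e) {
      // Inspect the message, record the state, then let the error propagate.
      if (std::string(e.what()).find("more than one response") != std::string::npos) {
        header.is_response_factory_deleted = true;
      }
      throw;
    }
  } catch (const std::runtime_error& e) {
    std::printf("error reported: %s (factory flag: %d)\n", e.what(),
                header.is_response_factory_deleted);
  }
  return 0;
}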
3 changes: 3 additions & 0 deletions src/pb_utils.h
@@ -167,6 +167,9 @@ struct ResponseBatch : SendMessageBase {
bool is_error_set;

uint32_t response_size;

// Indicates whether the response factory has been deleted or not.
bool is_response_factory_deleted = false;
};

enum LogLevel { kInfo = 0, kWarning, kError, kVerbose };
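A general C++ note on the `= false` default above: a default member initializer only takes effect when the object is actually constructed (aggregate initialization or placement-new). A bare reinterpret_cast over raw shared memory does not run it, so whichever side creates the batch still needs to construct the struct or assign the field explicitly, as the stub does in the error paths above. A small, self-contained illustration:

#include <cstdio>
#include <new>
#include <vector>

// Simplified stand-in for ResponseBatch; layout is illustrative only.
struct Batch {
  bool has_error;
  bool is_response_factory_deleted = false;  // default member initializer
};

int main() {
  std::vector<char> shm(sizeof(Batch), 0x7F);  // raw bytes standing in for shared memory

  // Placement-new runs the default member initializers, so the flag reads false.
  Batch* batch = new (shm.data()) Batch{};
  std::printf("flag after construction: %d\n", batch->is_response_factory_deleted);

  // A plain reinterpret_cast of the same bytes would not have initialized the
  // field, so an explicit assignment would still be required before reading it.
  return 0;
}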
132 changes: 103 additions & 29 deletions src/python_be.cc
@@ -153,6 +153,23 @@ ModelInstanceState::SetErrorForResponseSendMessage(
}
}

bool
ModelInstanceState::IsStubProcessAlive()
{
boost::posix_time::ptime timeout =
boost::get_system_time() + boost::posix_time::seconds(1);
bi::scoped_lock<bi::interprocess_mutex> lock(*Stub()->HealthMutex(), timeout);

// Check if the lock has been acquired.
if (lock) {
return Stub()->IpcControl()->stub_health;
} else {
// If the lock could not be obtained, the stub has either gotten stuck or
// exited while holding the health mutex lock.
return false;
}
}
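The liveness check hinges on a timed lock of the shared health mutex: if the stub exited or hung while holding it, the one-second attempt fails and the instance is treated as dead. Below is a standalone sketch of that timed-lock pattern with Boost.Interprocess; for simplicity the mutex and the health flag live in ordinary process memory here, whereas the real ones sit in shared memory behind Stub()->HealthMutex() and Stub()->IpcControl().

#include <boost/interprocess/sync/interprocess_mutex.hpp>
#include <boost/interprocess/sync/scoped_lock.hpp>
#include <boost/thread/thread_time.hpp>
#include <cstdio>

namespace bi = boost::interprocess;

int main() {
  bi::interprocess_mutex health_mutex;  // normally placed in shared memory
  bool stub_health = true;              // normally part of the shared IPC control block

  boost::posix_time::ptime timeout =
      boost::get_system_time() + boost::posix_time::seconds(1);
  bi::scoped_lock<bi::interprocess_mutex> lock(health_mutex, timeout);

  if (lock) {
    std::printf("stub healthy: %d\n", stub_health);
  } else {
    // Could not take the mutex within a second: assume the holder is stuck or gone.
    std::printf("health mutex is wedged; treating the stub as dead\n");
  }
  return 0;
}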

TRITONSERVER_Error*
ModelInstanceState::SaveRequestsToSharedMemory(
TRITONBACKEND_Request** requests, const uint32_t request_count,
@@ -1009,11 +1026,43 @@ ModelInstanceState::ProcessModelControlRequest(
});
}

void
TRITONSERVER_Error*
ModelInstanceState::SendMessageToStub(
bi::managed_external_buffer::handle_t message)
{
Stub()->StubMessageQueue()->Push(message);
bool success = false;
while (!success) {
uint64_t timeout_milliseconds = 1000;
{
boost::posix_time::ptime timeout =
boost::get_system_time() +
boost::posix_time::milliseconds(timeout_milliseconds);

bi::scoped_lock<bi::interprocess_mutex> lock(
*(Stub()->HealthMutex()), timeout);

// Check if the lock has been acquired.
if (lock) {
Stub()->IpcControl()->stub_health = false;
} else {
// If the lock could not be obtained, the stub has either gotten stuck or
// exited while holding the health mutex lock.
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL, "Failed to obtain the health mutex.");
}
}

Stub()->StubMessageQueue()->Push(
message, timeout_milliseconds /* duration ms */, success);

if (!success && !IsStubProcessAlive()) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL, "Stub process is not healthy.");
}
}

return nullptr; // success
}
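The loop above keeps retrying the timed push for as long as the stub still looks alive, and only gives up once both the push and the liveness check fail. A generic, self-contained sketch of that retry policy, with hypothetical callbacks standing in for the message-queue push and the health check:

#include <chrono>
#include <cstdio>
#include <functional>

// Generic sketch of the retry policy above; the callbacks are hypothetical
// stand-ins for the timed message-queue push and the stub liveness check.
bool SendWithRetry(
    const std::function<bool(std::chrono::milliseconds)>& try_push,
    const std::function<bool()>& peer_alive) {
  const std::chrono::milliseconds timeout{1000};
  bool success = false;
  while (!success) {
    success = try_push(timeout);     // timed push into the message queue
    if (!success && !peer_alive()) {
      return false;                  // give up only when the peer looks dead
    }
  }
  return true;
}

int main() {
  int attempts = 0;
  bool ok = SendWithRetry(
      [&](std::chrono::milliseconds) { return ++attempts >= 3; },  // succeeds on the 3rd try
      [] { return true; });                                        // peer stays alive
  std::printf("sent after %d attempts, ok=%d\n", attempts, ok);
  return 0;
}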

void
@@ -1023,10 +1072,29 @@ ModelInstanceState::SendMessageAndReceiveResponse(
std::shared_ptr<std::vector<TRITONBACKEND_Response*>>& responses,
TRITONBACKEND_Request** requests, const uint32_t request_count)
{
SendMessageToStub(message);

auto error = SendMessageToStub(message);
if (error != nullptr) {
RespondErrorToAllRequests(
TRITONSERVER_ErrorMessage(error), responses, requests, request_count);

return;
}

bi::managed_external_buffer::handle_t response_message;
Stub()->ReceiveMessageFromStub(response_message);
error = Stub()->ReceiveMessageFromStub(response_message);
if (error != nullptr) {
RespondErrorToAllRequests(
TRITONSERVER_ErrorMessage(error), responses, requests, request_count);

return;
}

response = response_message;
}
@@ -1059,6 +1127,7 @@ ModelInstanceState::RespondErrorToAllRequests(
}
}


void
ModelInstanceState::StartMonitor()
{
@@ -1278,13 +1347,12 @@ ModelInstanceState::ProcessRequests(
{
Stub()->StubMessageQueue()->Push(ipc_message->ShmHandle());
bi::managed_external_buffer::handle_t response_message;
Stub()->ReceiveMessageFromStub(response_message);
RETURN_IF_ERROR(Stub()->ReceiveMessageFromStub(response_message));
response =
IPCMessage::LoadFromSharedMemory(Stub()->ShmPool(), response_message);
}
char* ipc_message_shm =
reinterpret_cast<char*>(response->GetAllocatedSharedMemory().data_.get());
;
ResponseBatch* response_batch_shm_ptr =
reinterpret_cast<ResponseBatch*>(ipc_message_shm + sizeof(IPCMessageShm));

@@ -1294,7 +1362,10 @@ ModelInstanceState::ProcessRequests(
reporter.SetBatchStatistics(total_batch_size);

if (response_batch_shm_ptr->has_error) {
if (response_batch_shm_ptr->is_error_set) {
// Clean up the response factory if an error occurred. The
// `is_response_factory_deleted` flag indicates whether the response factory
// has already been deleted, which happens in some corner cases.
if (!response_batch_shm_ptr->is_response_factory_deleted) {
for (uint32_t r = 0; r < request_count; r++) {
TRITONBACKEND_ResponseFactory* response_factory =
reinterpret_cast<TRITONBACKEND_ResponseFactory*>(
Expand All @@ -1304,6 +1375,8 @@ ModelInstanceState::ProcessRequests(
lresponse_factory(reinterpret_cast<TRITONBACKEND_ResponseFactory*>(
response_factory));
}
}
if (response_batch_shm_ptr->is_error_set) {
auto error = PbString::LoadFromSharedMemory(
Stub()->ShmPool(), response_batch_shm_ptr->error);
return TRITONSERVER_ErrorNew(
@@ -1315,26 +1388,33 @@
}

if (response_batch_shm_ptr->batch_size > 0) {
bi::managed_external_buffer::handle_t* response_shm_handle =
reinterpret_cast<bi::managed_external_buffer::handle_t*>(
ipc_message_shm + sizeof(ResponseBatch) + sizeof(IPCMessageShm));

std::shared_ptr<std::vector<TRITONBACKEND_Response*>> responses(
new std::vector<TRITONBACKEND_Response*>());
responses->reserve(request_count);
for (size_t i = 0; i < request_count; i++) {
TRITONBACKEND_Response* response;
auto err = TRITONBACKEND_ResponseNew(&response, requests[i]);
if (err == nullptr) {
responses->emplace_back(response);
} else {
// Multiple responses can be batched together in a single response batch
// shm, and some of them may be None because the response sender was used
// instead, so create a TRITONBACKEND_Response object only for the valid
// responses.
if (response_shm_handle[i] == 0) {
responses->emplace_back(nullptr);
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "Fail to create response");
TRITONSERVER_ErrorDelete(err);
} else {
TRITONBACKEND_Response* response;
auto err = TRITONBACKEND_ResponseNew(&response, requests[i]);
if (err == nullptr) {
responses->emplace_back(response);
} else {
responses->emplace_back(nullptr);
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "Fail to create response");
TRITONSERVER_ErrorDelete(err);
}
}
}
bi::managed_external_buffer::handle_t* response_shm_handle =
reinterpret_cast<bi::managed_external_buffer::handle_t*>(
ipc_message_shm + sizeof(ResponseBatch) + sizeof(IPCMessageShm));

// If the output provided by the model is on the GPU, we will pass the list
// of buffers provided by Triton to the stub process.
std::vector<bool> requires_deferred_callback;

bool has_gpu_output = false;
@@ -1345,10 +1425,13 @@

for (uint32_t r = 0; r < request_count; ++r) {
NVTX_RANGE(nvtx_, "LoadingResponse " + Name());
requires_deferred_callback.push_back(false);
if (response_shm_handle[r] == 0) {
continue;
}
TRITONBACKEND_Response* response = (*responses)[r];
TRITONBACKEND_Request* request = requests[r];
uint32_t requested_output_count = 0;
requires_deferred_callback.push_back(false);

shm_responses.emplace_back(nullptr);
std::unique_ptr<InferResponse>& infer_response = shm_responses.back();
@@ -1362,14 +1445,6 @@
(*responses)[r] = nullptr;
continue;
}

if (response_shm_handle[r] == 0) {
LOG_IF_ERROR(
TRITONBACKEND_ResponseDelete((*responses)[r]),
"failed to delete response");
(*responses)[r] = nullptr;
continue;
}
{
TRITONBACKEND_ResponseFactory* response_factory =
reinterpret_cast<TRITONBACKEND_ResponseFactory*>(
@@ -1422,7 +1497,6 @@
GUARDED_RESPOND_IF_ERROR(
responses, r,
TRITONBACKEND_RequestOutputCount(request, &requested_output_count));

std::set<std::string> requested_output_names;
for (size_t j = 0; j < requested_output_count; ++j) {
const char* output_name;
7 changes: 6 additions & 1 deletion src/python_be.h
@@ -369,7 +369,12 @@ class ModelInstanceState : public BackendModelInstance {
std::shared_ptr<std::vector<TRITONBACKEND_Response*>>& responses,
TRITONBACKEND_Request** requests, const uint32_t request_count);

void SendMessageToStub(bi::managed_external_buffer::handle_t message);
TRITONSERVER_Error* SendMessageToStub(
bi::managed_external_buffer::handle_t message);

// Checks whether the stub process is alive.
bool IsStubProcessAlive();

// Model instance stub
std::unique_ptr<StubLauncher>& Stub() { return model_instance_stub_; }
8 changes: 8 additions & 0 deletions src/response_sender.cc
@@ -259,11 +259,19 @@ ResponseSender::IsCancelled()
return pb_cancel_->IsCancelled();
}

bool
ResponseSender::IsClosed()
{
std::lock_guard<std::mutex> lk(mu_);
return closed_;
}

void
ResponseSender::Close()
{
std::lock_guard<std::mutex> lk(mu_);
closed_ = true;
response_factory_deleted_.exchange(true);
}
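Close() now records both that the sender is closed and that the response factory is already gone, and the new IsClosed() accessor is what the stub consults in pb_stub.cc above when deciding whether to set `is_response_factory_deleted`. A condensed, self-contained sketch of this closed-state bookkeeping, with simplified names rather than the real class:

#include <atomic>
#include <cstdio>
#include <mutex>

// Simplified stand-in for the sender's closed-state bookkeeping.
class SenderState {
 public:
  bool IsClosed() {
    std::lock_guard<std::mutex> lk(mu_);
    return closed_;
  }
  void Close() {
    std::lock_guard<std::mutex> lk(mu_);
    closed_ = true;
    factory_deleted_.store(true);  // the factory is gone; never delete it again
  }

 private:
  std::mutex mu_;
  bool closed_ = false;
  std::atomic<bool> factory_deleted_{false};
};

int main() {
  SenderState sender;
  sender.Close();
  std::printf("closed: %d\n", sender.IsClosed());
  return 0;
}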

void
1 change: 1 addition & 0 deletions src/response_sender.h
@@ -51,6 +51,7 @@ class ResponseSender {

// Can be useful for stopping the model from sending any more responses.
void Close();
bool IsClosed();

private:
void DeleteResponseFactory();