@@ -69,12 +69,15 @@ InferRequestComplete(
6969 TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp)
7070{
7171 if (request != nullptr) {
72- auto request_executor = reinterpret_cast<RequestExecutor*>(userp);
73- request_executor->EraseRequestAddress(reinterpret_cast<intptr_t>(request));
72+ RequestCompletionUserp* completion_userp =
73+ reinterpret_cast<RequestCompletionUserp*>(userp);
74+ completion_userp->infer_payload->SetRequestDeleted();
7475
7576 LOG_IF_ERROR(
7677 TRITONSERVER_InferenceRequestDelete(request),
7778 "Failed to delete inference request.");
79+
80+ delete completion_userp;
7881 }
7982}
8083
@@ -322,6 +325,18 @@ ResponseAlloc(
322325 return nullptr ; // Success
323326}
324327
328+ void
329+ InferRequestCancel(intptr_t request_address)
330+ {
331+ if (request_address == 0L) {
332+ return;
333+ }
334+
335+ TRITONSERVER_InferenceRequest* irequest =
336+ reinterpret_cast<TRITONSERVER_InferenceRequest*>(request_address);
337+ THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestCancel(irequest));
338+ }
339+
325340TRITONSERVER_Error*
326341OutputBufferQuery (
327342 TRITONSERVER_ResponseAllocator* allocator, void * userp,
@@ -364,6 +379,7 @@ RequestExecutor::Infer(
364379 bool is_ready = false;
365380 const char* model_name = infer_request->ModelName().c_str();
366381 TRITONSERVER_InferenceRequest* irequest = nullptr;
382+ RequestCompletionUserp* completion_userp = nullptr;
367383
368384 try {
369385 int64_t model_version = infer_request->ModelVersion();
@@ -415,8 +431,10 @@ RequestExecutor::Infer(
415431 THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetTimeoutMicroseconds(
416432 irequest, infer_request->Timeout()));
417433
434+ completion_userp = new RequestCompletionUserp(infer_payload);
418435 THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetReleaseCallback(
419- irequest, InferRequestComplete, reinterpret_cast<void*>(this)));
436+ irequest, InferRequestComplete,
437+ reinterpret_cast<void*>(completion_userp)));
420438
421439 TRITONSERVER_InferenceTrace* trace = nullptr ;
422440 if (infer_request->GetTrace ().TritonTrace () != nullptr ) {
@@ -485,22 +503,20 @@ RequestExecutor::Infer(
485503 reinterpret_cast<void*>(infer_payload->ResponseAllocUserp().get()),
486504 InferResponseComplete, reinterpret_cast<void*>(infer_payload.get())));
487505
488- {
489- std::lock_guard<std::mutex> lk(on_going_request_addresses_mu_);
490- on_going_request_addresses_.insert(
491- reinterpret_cast<intptr_t>(irequest));
492- }
493506 // Store the inference request address submitted to the Triton server for
494507 // retrieval
495508 infer_payload->SetRequestAddress(reinterpret_cast<intptr_t>(irequest));
509+ infer_payload->SetRequestCancellationFunc(InferRequestCancel);
496510
497511 THROW_IF_TRITON_ERROR(
498512 TRITONSERVER_ServerInferAsync(server_, irequest, trace));
499513 }
500514 }
501515 catch (const PythonBackendException& pb_exception) {
502- EraseRequestAddress(reinterpret_cast<intptr_t>(irequest));
503516 infer_payload->SetRequestAddress(0L);
517+ if (completion_userp != nullptr) {
518+ delete completion_userp;
519+ }
504520
505521 LOG_IF_ERROR(
506522 TRITONSERVER_InferenceRequestDelete(irequest),
@@ -514,34 +530,6 @@ RequestExecutor::Infer(
514530 return response_future;
515531}
516532
517- void
518- RequestExecutor::Cancel(std::shared_ptr<InferPayload>& infer_payload)
519- {
520- intptr_t request_address = infer_payload->GetRequestAddress();
521- if (request_address == 0L) {
522- return;
523- }
524-
525- {
526- std::lock_guard<std::mutex> lk(on_going_request_addresses_mu_);
527- if (on_going_request_addresses_.find(request_address) !=
528- on_going_request_addresses_.end()) {
529- TRITONSERVER_InferenceRequest* irequest =
530- reinterpret_cast<TRITONSERVER_InferenceRequest*>(request_address);
531- THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestCancel(irequest));
532- }
533- }
534- }
535-
536- void
537- RequestExecutor::EraseRequestAddress(intptr_t request_address)
538- {
539- if (request_address != 0L) {
540- std::unique_lock<std::mutex> lk(on_going_request_addresses_mu_);
541- on_going_request_addresses_.erase(request_address);
542- }
543- }
544-
545533RequestExecutor::~RequestExecutor ()
546534{
547535 if (response_allocator_ != nullptr) {
0 commit comments