Skip to content

Commit 27cece9

Browse files
authored
fix: Improve cancellation handling for gRPC non-decoupled inference (#8220)
1 parent 550f64b commit 27cece9

File tree

3 files changed, with 221 additions and 58 deletions

qa/L0_decoupled/decoupled_test.py

Lines changed: 200 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python3
22

3-
# Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# Copyright 2020-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
44
#
55
# Redistribution and use in source and binary forms, with or without
66
# modification, are permitted provided that the following conditions
@@ -32,6 +32,7 @@
3232

3333
import os
3434
import queue
35+
import threading
3536
import time
3637
import unittest
3738
from functools import partial
@@ -606,53 +607,212 @@ def test_wrong_shape(self):
606607
class NonDecoupledTest(tu.TestResultCollector):
607608
def setUp(self):
608609
self.model_name_ = "repeat_int32"
609-
self.input_data = {
610-
"IN": np.array([1], dtype=np.int32),
611-
"DELAY": np.array([0], dtype=np.uint32),
612-
"WAIT": np.array([0], dtype=np.uint32),
610+
self.data_matrix = [
611+
# ("IN", "DELAY", "WAIT")
612+
([1], [0], [0]),
613+
([1], [4000], [2000]),
614+
([1], [2000], [4000]),
615+
]
616+
617+
# For grpc async infer test
618+
self.callback_error = None
619+
self.callback_result = None
620+
self.callback_invoked_event = threading.Event()
621+
622+
def _input_data(self, in_value, delay_value, wait_value):
623+
return {
624+
"IN": np.array(in_value, dtype=np.int32),
625+
"DELAY": np.array(delay_value, dtype=np.uint32),
626+
"WAIT": np.array(wait_value, dtype=np.uint32),
613627
}
614628

629+
def _async_callback(self, result, error):
630+
"""Callback for async_infer."""
631+
self.callback_error = error
632+
self.callback_result = result
633+
self.callback_invoked_event.set()
634+
615635
def test_grpc(self):
616-
inputs = [
617-
grpcclient.InferInput("IN", [1], "INT32").set_data_from_numpy(
618-
self.input_data["IN"]
619-
),
620-
grpcclient.InferInput("DELAY", [1], "UINT32").set_data_from_numpy(
621-
self.input_data["DELAY"]
622-
),
623-
grpcclient.InferInput("WAIT", [1], "UINT32").set_data_from_numpy(
624-
self.input_data["WAIT"]
625-
),
626-
]
636+
for in_value, delay_value, wait_value in self.data_matrix:
637+
with self.subTest(IN=in_value, DELAY=delay_value, WAIT=wait_value):
638+
input_data = self._input_data(in_value, delay_value, wait_value)
639+
inputs = [
640+
grpcclient.InferInput("IN", [1], "INT32").set_data_from_numpy(
641+
input_data["IN"]
642+
),
643+
grpcclient.InferInput("DELAY", [1], "UINT32").set_data_from_numpy(
644+
input_data["DELAY"]
645+
),
646+
grpcclient.InferInput("WAIT", [1], "UINT32").set_data_from_numpy(
647+
input_data["WAIT"]
648+
),
649+
]
650+
651+
triton_client = grpcclient.InferenceServerClient(
652+
url="localhost:8001", verbose=True
653+
)
627654

628-
triton_client = grpcclient.InferenceServerClient(
629-
url="localhost:8001", verbose=True
630-
)
631-
# Expect the inference is successful
632-
res = triton_client.infer(model_name=self.model_name_, inputs=inputs)
633-
self.assertEqual(1, res.as_numpy("OUT")[0])
634-
self.assertEqual(0, res.as_numpy("IDX")[0])
655+
# Expect the inference is successful
656+
res = triton_client.infer(model_name=self.model_name_, inputs=inputs)
657+
self.assertEqual(1, res.as_numpy("OUT")[0])
658+
self.assertEqual(0, res.as_numpy("IDX")[0])
635659

636660
def test_http(self):
637-
inputs = [
638-
httpclient.InferInput("IN", [1], "INT32").set_data_from_numpy(
639-
self.input_data["IN"]
640-
),
641-
httpclient.InferInput("DELAY", [1], "UINT32").set_data_from_numpy(
642-
self.input_data["DELAY"]
643-
),
644-
httpclient.InferInput("WAIT", [1], "UINT32").set_data_from_numpy(
645-
self.input_data["WAIT"]
646-
),
661+
for in_value, delay_value, wait_value in self.data_matrix:
662+
with self.subTest(IN=in_value, DELAY=delay_value, WAIT=wait_value):
663+
input_data = self._input_data(in_value, delay_value, wait_value)
664+
inputs = [
665+
httpclient.InferInput("IN", [1], "INT32").set_data_from_numpy(
666+
input_data["IN"]
667+
),
668+
httpclient.InferInput("DELAY", [1], "UINT32").set_data_from_numpy(
669+
input_data["DELAY"]
670+
),
671+
httpclient.InferInput("WAIT", [1], "UINT32").set_data_from_numpy(
672+
input_data["WAIT"]
673+
),
674+
]
675+
676+
triton_client = httpclient.InferenceServerClient(
677+
url="localhost:8000", verbose=True
678+
)
679+
680+
# Expect the inference is successful
681+
res = triton_client.infer(model_name=self.model_name_, inputs=inputs)
682+
self.assertEqual(1, res.as_numpy("OUT")[0])
683+
self.assertEqual(0, res.as_numpy("IDX")[0])
684+
685+
def test_grpc_async(self):
686+
for in_value, delay_value, wait_value in self.data_matrix:
687+
with self.subTest(IN=in_value, DELAY=delay_value, WAIT=wait_value):
688+
input_data = self._input_data(in_value, delay_value, wait_value)
689+
inputs = [
690+
grpcclient.InferInput("IN", [1], "INT32").set_data_from_numpy(
691+
input_data["IN"]
692+
),
693+
grpcclient.InferInput("DELAY", [1], "UINT32").set_data_from_numpy(
694+
input_data["DELAY"]
695+
),
696+
grpcclient.InferInput("WAIT", [1], "UINT32").set_data_from_numpy(
697+
input_data["WAIT"]
698+
),
699+
]
700+
701+
triton_client = grpcclient.InferenceServerClient(
702+
url="localhost:8001",
703+
verbose=True,
704+
)
705+
706+
# Clear previous results
707+
self.callback_error = None
708+
self.callback_result = None
709+
self.callback_invoked_event.clear()
710+
711+
try:
712+
triton_client.async_infer(
713+
model_name=self.model_name_,
714+
inputs=inputs,
715+
callback=self._async_callback,
716+
)
717+
except Exception as e:
718+
self.fail(f"Failed to initiate async_infer: {e}")
719+
continue
720+
721+
# Wait for the callback to be invoked, with a timeout
722+
self.assertTrue(
723+
self.callback_invoked_event.wait(timeout=10),
724+
"Callback not invoked within timeout.",
725+
)
726+
727+
# Expect the inference is successful
728+
self.assertIsNone(
729+
self.callback_error, f"Inference failed: {self.callback_error}"
730+
)
731+
self.assertIsNotNone(self.callback_result, "Inference result is None.")
732+
self.assertEqual(1, self.callback_result.as_numpy("OUT")[0])
733+
self.assertEqual(0, self.callback_result.as_numpy("IDX")[0])
734+
735+
# Wait and check server/model health
736+
time.sleep(5)
737+
self.assertTrue(triton_client.is_model_ready(self.model_name_))
738+
739+
def test_grpc_async_cancel(self):
740+
data_matrix = [
741+
# ("IN", "DELAY", "WAIT")
742+
([1], [4000], [2000]),
743+
([1], [2000], [4000]),
647744
]
648745

649-
triton_client = httpclient.InferenceServerClient(
650-
url="localhost:8000", verbose=True
651-
)
652-
# Expect the inference is successful
653-
res = triton_client.infer(model_name=self.model_name_, inputs=inputs)
654-
self.assertEqual(1, res.as_numpy("OUT")[0])
655-
self.assertEqual(0, res.as_numpy("IDX")[0])
746+
for in_value, delay_value, wait_value in data_matrix:
747+
with self.subTest(IN=in_value, DELAY=delay_value, WAIT=wait_value):
748+
input_data = self._input_data(in_value, delay_value, wait_value)
749+
inputs = [
750+
grpcclient.InferInput("IN", [1], "INT32").set_data_from_numpy(
751+
input_data["IN"]
752+
),
753+
grpcclient.InferInput("DELAY", [1], "UINT32").set_data_from_numpy(
754+
input_data["DELAY"]
755+
),
756+
grpcclient.InferInput("WAIT", [1], "UINT32").set_data_from_numpy(
757+
input_data["WAIT"]
758+
),
759+
]
760+
761+
triton_client = grpcclient.InferenceServerClient(
762+
url="localhost:8001",
763+
verbose=True,
764+
)
765+
766+
# Clear previous results
767+
self.callback_error = None
768+
self.callback_result = None
769+
self.callback_invoked_event.clear()
770+
771+
request_handle = None
772+
try:
773+
request_handle = triton_client.async_infer(
774+
model_name=self.model_name_,
775+
inputs=inputs,
776+
callback=self._async_callback,
777+
)
778+
except Exception as e:
779+
self.fail(f"Failed to initiate async_infer: {e}")
780+
continue
781+
782+
# Allow request to be fully initiated
783+
time.sleep(0.5)
784+
785+
# Attempt to cancel the request
786+
if request_handle:
787+
try:
788+
request_handle.cancel()
789+
except Exception as e:
790+
self.fail(f"Error calling request_handle.cancel(): {e}")
791+
continue
792+
else:
793+
self.fail("Invalid request_handle, cannot cancel.")
794+
continue
795+
796+
# Wait for the callback to be invoked
797+
self.assertTrue(
798+
self.callback_invoked_event.wait(timeout=10),
799+
"Callback not invoked within timeout after cancellation.",
800+
)
801+
802+
# Expect the inference is failed
803+
self.assertIsInstance(
804+
self.callback_error,
805+
InferenceServerException,
806+
f"Unexpected error type: {type(self.callback_error)}",
807+
)
808+
self.assertIn(
809+
"StatusCode.CANCELLED",
810+
self.callback_error.status(),
811+
)
812+
813+
# Wait and check server/model health
814+
time.sleep(5)
815+
self.assertTrue(triton_client.is_model_ready(self.model_name_))
656816

657817

658818
if __name__ == "__main__":

qa/L0_decoupled/test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ if [ $? -ne 0 ]; then
196196
echo -e "\n***\n*** Test NonDecoupledTest Failed\n***"
197197
RET=1
198198
else
199-
check_test_results $TEST_RESULT_FILE 2
199+
check_test_results $TEST_RESULT_FILE 4
200200
if [ $? -ne 0 ]; then
201201
cat $CLIENT_LOG
202202
echo -e "\n***\n*** Test Result Verification Failed\n***"

src/grpc/infer_handler.cc

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,12 +1068,14 @@ ModelInferHandler::InferResponseComplete(
10681068
state->cb_count_++;
10691069
}
10701070

1071+
bool is_final_response = (flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) != 0;
1072+
10711073
LOG_VERBOSE(1) << "ModelInferHandler::InferResponseComplete, "
10721074
<< state->unique_id_ << " step " << state->step_;
10731075

10741076
// Allow sending 1 response and final flag separately, only mark
10751077
// non-inflight when seeing final flag
1076-
if (flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) {
1078+
if (is_final_response) {
10771079
state->context_->EraseInflightState(state);
10781080
}
10791081

@@ -1093,22 +1095,23 @@ ModelInferHandler::InferResponseComplete(
10931095
<< ", skipping response generation as grpc transaction was "
10941096
"cancelled... ";
10951097

1096-
if (state->delay_enqueue_ms_ != 0) {
1097-
// Will delay PutTaskBackToQueue by the specified time.
1098-
// This can be used to test the flow when cancellation request
1099-
// issued for the request during InferResponseComplete
1100-
// callback right before Process in the notification thread.
1101-
LOG_INFO << "Delaying PutTaskBackToQueue by " << state->delay_enqueue_ms_
1102-
<< " ms...";
1103-
std::this_thread::sleep_for(
1104-
std::chrono::milliseconds(state->delay_enqueue_ms_));
1105-
}
1106-
1107-
// Send state back to the queue so that state can be released
1108-
// in the next cycle.
1109-
state->context_->PutTaskBackToQueue(state);
1098+
if (is_final_response) {
1099+
if (state->delay_enqueue_ms_ != 0) {
1100+
// Will delay PutTaskBackToQueue by the specified time.
1101+
// This can be used to test the flow when cancellation request
1102+
// issued for the request during InferResponseComplete
1103+
// callback right before Process in the notification thread.
1104+
LOG_INFO << "Delaying PutTaskBackToQueue by "
1105+
<< state->delay_enqueue_ms_ << " ms...";
1106+
std::this_thread::sleep_for(
1107+
std::chrono::milliseconds(state->delay_enqueue_ms_));
1108+
}
11101109

1111-
delete response_release_payload;
1110+
// Send state back to the queue so that state can be released
1111+
// in the next cycle.
1112+
state->context_->PutTaskBackToQueue(state);
1113+
delete response_release_payload;
1114+
}
11121115
return;
11131116
}
11141117

@@ -1156,7 +1159,7 @@ ModelInferHandler::InferResponseComplete(
11561159

11571160
// Defer sending the response until FINAL flag is seen or
11581161
// there is error
1159-
if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
1162+
if (!is_final_response) {
11601163
return;
11611164
}
11621165

0 commit comments

Comments
 (0)