
Commit 0ef6067

auto-reload model_spec when Exception (#16)
1 parent 164ac37 commit 0ef6067

File tree

- tests/test_model_call.py
- tritony/tools.py
- tritony/version.py

3 files changed: +54 -9 lines changed


tests/test_model_call.py

Lines changed: 10 additions & 0 deletions

```diff
@@ -71,6 +71,16 @@ def test_with_optional(protocol_and_port):
     assert np.isclose(result[0], sample[0] - OPTIONAL_SUB_VALUE, rtol=EPSILON).all()
 
 
+def test_reload_model_spec(protocol_and_port):
+    client = get_client(*protocol_and_port, model_name="sample_autobatching")
+    # force to change max_batch_size
+    client.default_model_spec.max_batch_size = 4
+
+    sample = np.random.rand(8, 100).astype(np.float32)
+    result = client(sample)
+    assert np.isclose(result, sample).all()
+
+
 if __name__ == "__main__":
     test_with_parameters(("grpc", "8101"))
     test_with_optional(("grpc", "8101"))
```
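The new test deliberately desynchronizes the cached spec from the server: `sample_autobatching` actually accepts batches of 8, but the client-side `max_batch_size` is forced down to 4, so the server rejects the first request and the client must reload the spec and retry for the assert to pass. A minimal standalone sketch of the same scenario, assuming a running Triton server with this model on grpc port 8101 and tritony's `InferenceClient.create_with(name, "host:port", protocol=...)` constructor:

```python
import numpy as np

from tritony import InferenceClient

# Assumes a Triton server with the "sample_autobatching" model on localhost:8101.
client = InferenceClient.create_with("sample_autobatching", "localhost:8101", protocol="grpc")
client.default_model_spec.max_batch_size = 4  # pretend our cached spec went stale

sample = np.random.rand(8, 100).astype(np.float32)  # batch of 8 > stale limit of 4
result = client(sample)  # rejected once, spec reloaded, then retried transparently
assert np.isclose(result, sample).all()
```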

tritony/tools.py

Lines changed: 43 additions & 8 deletions

```diff
@@ -37,8 +37,8 @@
 
 logger = logging.getLogger(__name__)
 
-TRITON_LOAD_DELAY = float(os.environ.get("TRITON_LOAD_DELAY", 3))
-TRITON_BACKOFF_COEFF = float(os.environ.get("TRITON_BACKOFF_COEFF", 0.2))
+TRITON_LOAD_DELAY = float(os.environ.get("TRITON_LOAD_DELAY", 2))
+TRITON_BACKOFF_COEFF = float(os.environ.get("TRITON_BACKOFF_COEFF", 2))
 TRITON_RETRIES = int(os.environ.get("TRITON_RETRIES", 5))
 TRITON_CLIENT_TIMEOUT = int(os.environ.get("TRITON_CLIENT_TIMEOUT", 30))
 
```
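These knobs are plain environment variables read once, at import time, so they can be tuned per deployment without code changes. A sketch of overriding the new defaults (2 s initial delay, doubling per retry, 5 tries), set before tritony is imported:

```python
import os

# Must be set before tritony is imported; tools.py reads them at module load.
os.environ["TRITON_LOAD_DELAY"] = "1"       # first wait, in seconds
os.environ["TRITON_BACKOFF_COEFF"] = "1.5"  # multiplier applied after each failed try
os.environ["TRITON_RETRIES"] = "3"

from tritony import InferenceClient  # noqa: E402 -- picks up the overrides above
```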
```diff
@@ -93,21 +93,38 @@ async def send_request_async(
                 compression=inference_client.flag.compression_algorithm,
             )
         except InferenceServerException as triton_error:
-            handel_triton_error(triton_error)
+            handle_triton_error(triton_error)
+        except grpc.RpcError as grpc_error:
+            handle_grpc_error(grpc_error)
         ret.append((idx, a_pred))
         data_queue.task_done()
 
 
-def handel_triton_error(triton_error: InferenceServerException):
+def handle_triton_error(triton_error: InferenceServerException):
     """
     https://github.com/triton-inference-server/core/blob/0141e0651c4355bf8a9d1118aac45abda6569997/src/scheduler_utils.cc#L133
     ("Max_queue_size exceeded", "Server not ready", "failed to connect to all addresses")
     """
+    if triton_error.status() == "400" and "batch-size must be <=" in triton_error.message():
+        raise triton_error
     runtime_msg = f"{triton_error.status()} with {triton_error.message()}"
     raise RuntimeError(runtime_msg) from triton_error
 
 
-@retry((InferenceServerException, grpc.RpcError), tries=TRITON_RETRIES, delay=TRITON_BACKOFF_COEFF, backoff=2)
+def handle_grpc_error(grpc_error: grpc.RpcError):
+    if grpc_error.code() == grpc.StatusCode.INVALID_ARGUMENT:
+        raise grpc_error
+    else:
+        runtime_msg = f"{grpc_error.code()} with {grpc_error.details()}"
+        raise RuntimeError(runtime_msg) from grpc_error
+
+
+@retry(
+    (InferenceServerException, grpc.RpcError),
+    tries=TRITON_RETRIES,
+    delay=TRITON_LOAD_DELAY,
+    backoff=TRITON_BACKOFF_COEFF,
+)
 async def request_async(protocol: TritonProtocol, model_input: Dict, triton_client, timeout: int, compression: str):
     st = time.time()
 
```
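Two behavioral changes land in this hunk. First, invalid-argument failures (gRPC `INVALID_ARGUMENT`, or HTTP 400 with the server's "batch-size must be <=" message) are re-raised with their original type instead of being wrapped in `RuntimeError`, so callers can recognize them and trigger a spec reload. Second, the `@retry` decorator previously received `TRITON_BACKOFF_COEFF` (0.2) as its initial delay with a hard-coded backoff of 2; it now starts at `TRITON_LOAD_DELAY` and multiplies by `TRITON_BACKOFF_COEFF`. Assuming the `retry` package's delay-then-multiply semantics, the new defaults give this wait schedule (a quick check, not library code):

```python
TRITON_LOAD_DELAY = 2.0     # new default initial delay
TRITON_BACKOFF_COEFF = 2.0  # new default multiplier
TRITON_RETRIES = 5

# Waits between the 5 attempts: the decorator sleeps `delay`,
# then multiplies it by `backoff` after each failed try.
waits = [TRITON_LOAD_DELAY * TRITON_BACKOFF_COEFF**i for i in range(TRITON_RETRIES - 1)]
print(waits)  # [2.0, 4.0, 8.0, 16.0]
```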
```diff
@@ -300,10 +317,28 @@ def _call_async(
         model_spec: TritonModelSpec,
         parameters: dict | None = None,
     ) -> Optional[np.ndarray]:
-        async_result = asyncio.run(self._call_async_item(data=data, model_spec=model_spec, parameters=parameters))
+        for retry_idx in range(max(2, TRITON_RETRIES)):
+            async_result = asyncio.run(self._call_async_item(data=data, model_spec=model_spec, parameters=parameters))
+
+            is_invalid_argument_grpc = (
+                self.flag.protocol is TritonProtocol.grpc
+                and isinstance(async_result, grpc.RpcError)
+                and async_result.code() == grpc.StatusCode.INVALID_ARGUMENT
+            )
+            is_invalid_argument_http = (
+                self.flag.protocol is TritonProtocol.http
+                and isinstance(async_result, InferenceServerException)
+                and async_result.status() == "400"
+            )
 
-        if isinstance(async_result, Exception):
-            raise async_result
+            if is_invalid_argument_grpc or is_invalid_argument_http:
+                time.sleep(TRITON_LOAD_DELAY * TRITON_BACKOFF_COEFF**retry_idx)
+                self._renew_triton_client(self._triton_client)
+                model_spec = self.model_specs[(model_spec.name, model_spec.model_version)]
+                continue
+            elif isinstance(async_result, Exception):
+                raise async_result
+            break
 
         return async_result
 
```

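`_call_async` now wraps the single dispatch in a loop of at least two attempts (`max(2, TRITON_RETRIES)`): on an invalid-argument result it backs off by `TRITON_LOAD_DELAY * TRITON_BACKOFF_COEFF**retry_idx`, renews the Triton client so the model specs are re-fetched from the server, swaps in the fresh spec, and retries; any other exception is raised immediately. The generic shape of that pattern, as a sketch with hypothetical `call`/`refresh`/`is_stale_error` hooks rather than tritony's API:

```python
import time

def call_with_spec_reload(call, refresh, is_stale_error, tries=5, delay=2.0, backoff=2.0):
    """Retry `call` while `is_stale_error` says the failure came from stale metadata."""
    result = None
    for attempt in range(max(2, tries)):
        result = call()  # returns a value, or the exception it produced
        if isinstance(result, Exception):
            if is_stale_error(result) and attempt < tries - 1:
                time.sleep(delay * backoff**attempt)  # same schedule as _call_async
                refresh()  # re-fetch model metadata, like _renew_triton_client
                continue
            raise result
        break
    return result
```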
tritony/version.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1 +1 @@
-__version__ = "0.0.14"
+__version__ = "0.0.15"
```
