@@ -37,8 +37,8 @@
 
 logger = logging.getLogger(__name__)
 
-TRITON_LOAD_DELAY = float(os.environ.get("TRITON_LOAD_DELAY", 3))
-TRITON_BACKOFF_COEFF = float(os.environ.get("TRITON_BACKOFF_COEFF", 0.2))
+TRITON_LOAD_DELAY = float(os.environ.get("TRITON_LOAD_DELAY", 2))
+TRITON_BACKOFF_COEFF = float(os.environ.get("TRITON_BACKOFF_COEFF", 2))
 TRITON_RETRIES = int(os.environ.get("TRITON_RETRIES", 5))
 TRITON_CLIENT_TIMEOUT = int(os.environ.get("TRITON_CLIENT_TIMEOUT", 30))
 
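These defaults change together with the decorator rewiring below: previously `TRITON_BACKOFF_COEFF` (0.2) was passed to `@retry` as the initial `delay` with a hardcoded `backoff=2`; now `TRITON_LOAD_DELAY` is the initial delay and `TRITON_BACKOFF_COEFF` is the multiplier, so each variable matches its name. Assuming the `retry` package's usual semantics (the wait before attempt i+1 is `delay * backoff**i`), the new defaults give the schedule sketched here:

    # Wait schedule implied by the new defaults, assuming `retry`-package
    # semantics: the i-th wait is delay * backoff**i.
    TRITON_LOAD_DELAY = 2.0     # initial delay in seconds
    TRITON_BACKOFF_COEFF = 2.0  # multiplier applied after each failed attempt
    TRITON_RETRIES = 5          # total attempts, i.e. 4 waits

    waits = [TRITON_LOAD_DELAY * TRITON_BACKOFF_COEFF**i for i in range(TRITON_RETRIES - 1)]
    print(waits)  # [2.0, 4.0, 8.0, 16.0] -> ~30 s of backoff in total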
@@ -93,21 +93,38 @@ async def send_request_async(
             compression=inference_client.flag.compression_algorithm,
         )
     except InferenceServerException as triton_error:
-        handel_triton_error(triton_error)
+        handle_triton_error(triton_error)
+    except grpc.RpcError as grpc_error:
+        handle_grpc_error(grpc_error)
     ret.append((idx, a_pred))
     data_queue.task_done()
 
 
-def handel_triton_error(triton_error: InferenceServerException):
+def handle_triton_error(triton_error: InferenceServerException):
     """
     https://github.com/triton-inference-server/core/blob/0141e0651c4355bf8a9d1118aac45abda6569997/src/scheduler_utils.cc#L133
     ("Max_queue_size exceeded", "Server not ready", "failed to connect to all addresses")
     """
+    if triton_error.status() == "400" and "batch-size must be <=" in triton_error.message():
+        raise triton_error
     runtime_msg = f"{triton_error.status()} with {triton_error.message()}"
     raise RuntimeError(runtime_msg) from triton_error
 
 
-@retry((InferenceServerException, grpc.RpcError), tries=TRITON_RETRIES, delay=TRITON_BACKOFF_COEFF, backoff=2)
+def handle_grpc_error(grpc_error: grpc.RpcError):
+    if grpc_error.code() == grpc.StatusCode.INVALID_ARGUMENT:
+        raise grpc_error
+    else:
+        runtime_msg = f"{grpc_error.code()} with {grpc_error.details()}"
+        raise RuntimeError(runtime_msg) from grpc_error
+
+
+@retry(
+    (InferenceServerException, grpc.RpcError),
+    tries=TRITON_RETRIES,
+    delay=TRITON_LOAD_DELAY,
+    backoff=TRITON_BACKOFF_COEFF,
+)
 async def request_async(protocol: TritonProtocol, model_input: Dict, triton_client, timeout: int, compression: str):
     st = time.time()
 
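Both handlers implement one contract: an invalid-argument error (gRPC `StatusCode.INVALID_ARGUMENT`, or an HTTP 400 whose message mentions the batch-size limit) is re-raised unchanged so the caller in `_call_async` below can recognize it and renew the client, while every other error is wrapped in a `RuntimeError` chained to the original. A minimal sketch of that shared predicate; the helper name is illustrative, not part of the diff:

    # Illustrative helper (not in the diff): the condition under which the
    # handlers above re-raise the original error instead of wrapping it.
    import grpc
    from tritonclient.utils import InferenceServerException

    def is_renewable(err: Exception) -> bool:
        if isinstance(err, grpc.RpcError):
            return err.code() == grpc.StatusCode.INVALID_ARGUMENT
        if isinstance(err, InferenceServerException):
            return err.status() == "400" and "batch-size must be <=" in err.message()
        return False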
@@ -300,10 +317,28 @@ def _call_async(
         model_spec: TritonModelSpec,
         parameters: dict | None = None,
     ) -> Optional[np.ndarray]:
-        async_result = asyncio.run(self._call_async_item(data=data, model_spec=model_spec, parameters=parameters))
+        for retry_idx in range(max(2, TRITON_RETRIES)):
+            async_result = asyncio.run(self._call_async_item(data=data, model_spec=model_spec, parameters=parameters))
+
+            is_invalid_argument_grpc = (
+                self.flag.protocol is TritonProtocol.grpc
+                and isinstance(async_result, grpc.RpcError)
+                and async_result.code() == grpc.StatusCode.INVALID_ARGUMENT
+            )
+            is_invalid_argument_http = (
+                self.flag.protocol is TritonProtocol.http
+                and isinstance(async_result, InferenceServerException)
+                and async_result.status() == "400"
+            )
 
-        if isinstance(async_result, Exception):
-            raise async_result
+            if is_invalid_argument_grpc or is_invalid_argument_http:
+                time.sleep(TRITON_LOAD_DELAY * TRITON_BACKOFF_COEFF**retry_idx)
+                self._renew_triton_client(self._triton_client)
+                model_spec = self.model_specs[(model_spec.name, model_spec.model_version)]
+                continue
+            elif isinstance(async_result, Exception):
+                raise async_result
+            break
 
         return async_result
 
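This loop gives invalid-argument errors one more chance at a higher level: the handlers re-raise them precisely so they surface here unwrapped, where the code sleeps on the same exponential schedule as the decorator, rebuilds the Triton client via `_renew_triton_client`, and re-reads the `model_spec`, presumably to cover the case where the server-side model changed under the client (the batch-size check in `handle_triton_error` points the same way). `max(2, TRITON_RETRIES)` guarantees at least one renewal pass even if `TRITON_RETRIES` is set to 1. One subtlety: if the final attempt still fails with an invalid argument, the loop falls through and `return async_result` returns the exception object instead of raising it. A standalone sketch of the pattern, raising in that last case instead; `do_call`, `renew`, and `is_renewable` are illustrative stand-ins, not names from the diff:

    # Standalone sketch of the renew-and-retry pattern above; do_call, renew
    # and is_renewable are stand-ins for the client internals.
    import time

    def call_with_renewal(do_call, renew, is_renewable,
                          tries=5, delay=2.0, backoff=2.0):
        result = None
        for i in range(max(2, tries)):
            result = do_call()  # returns a value, or the exception it hit
            if isinstance(result, Exception) and is_renewable(result):
                time.sleep(delay * backoff**i)  # same schedule as @retry
                renew()  # rebuild the client and refresh model metadata
                continue
            if isinstance(result, Exception):
                raise result
            return result
        raise result  # renewal attempts exhausted: raise, don't return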