|
| 1 | +import asyncio |
| 2 | +import logging |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import pytest |
| 6 | + |
| 7 | +from tritony import InferenceClient |
| 8 | + |
| 9 | +from .common_fixtures import MODEL_NAME, TRITON_HOST, async_config |
| 10 | + |
# Emit timestamped INFO-level log records while these tests run.
logging.basicConfig(
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    level=logging.INFO,
)

# NOTE(review): not referenced by the tests visible here -- confirm it is still needed.
EPSILON = 1e-8
# Lists the imported fixture name; presumably this keeps linters from
# flagging the `async_config` import as unused -- confirm.
__all__ = ["async_config"]
| 19 | + |
| 20 | + |
def get_client(protocol, port, model_name):
    """Create an asyncio-enabled ``InferenceClient`` for ``TRITON_HOST:port``.

    Args:
        protocol: Forwarded as the ``protocol`` keyword of
            ``InferenceClient.create_with_asyncio``; also echoed to stdout.
        port: Port appended to ``TRITON_HOST`` to form the endpoint string.
        model_name: Model name passed as the first positional argument.

    Returns:
        Whatever ``InferenceClient.create_with_asyncio`` returns.
    """
    print(f"Testing {protocol}", flush=True)
    endpoint = f"{TRITON_HOST}:{port}"
    return InferenceClient.create_with_asyncio(model_name, endpoint, protocol=protocol)
| 24 | + |
| 25 | + |
@pytest.mark.asyncio
async def test_basics(async_config):
    """Check that ``aio_infer`` returns values close to its input.

    The same random sample is sent twice: once as a bare array and once as a
    mapping keyed by ``"model_in"``. Both results must match the sample
    element-wise under ``np.isclose``.

    Args:
        async_config: Fixture value unpacked positionally into
            ``get_client`` as its ``protocol`` and ``port`` arguments.
    """
    client = get_client(*async_config, model_name=MODEL_NAME)
    sample = np.random.rand(1, 100).astype(np.float32)

    # Bare ndarray input.
    result = await client.aio_infer(sample)
    assert np.isclose(result, sample).all()

    # Same data as a mapping; presumably "model_in" is the model's input
    # tensor name -- confirm against the model configuration.
    result = await client.aio_infer({"model_in": sample})
    assert np.isclose(result, sample).all()
| 38 | + |
| 39 | + |
@pytest.mark.asyncio
async def test_multiple_tasks(async_config):
    """Check that several ``aio_infer`` calls awaited together finish quickly.

    Ten clients each issue one request for the ``"sample_sleep_1sec"`` model
    and the requests are awaited with ``asyncio.gather``. Every result must
    match the input sample, and the whole batch must complete in under two
    seconds of event-loop time.

    Args:
        async_config: Fixture value holding the ``(protocol, port)`` pair
            passed to ``get_client``.
    """
    n_multiple_tasks = 10
    # Presumably each request takes about 1 s (per the model name), so this
    # bound only holds if the requests overlap -- confirm the model behavior.
    max_elapsed_sec = 2
    protocol, port = async_config
    print(f"Testing {protocol}:{port}")

    client_list = [get_client(protocol, port, model_name="sample_sleep_1sec") for _ in range(n_multiple_tasks)]

    sample = np.random.rand(1, 100).astype(np.float32)
    tasks = [client.aio_infer(sample) for client in client_list]

    # Inside a coroutine the running loop is the right clock source;
    # get_running_loop() never creates a loop implicitly.
    loop = asyncio.get_running_loop()
    start_time = loop.time()
    results = await asyncio.gather(*tasks)
    elapsed = loop.time() - start_time

    for result in results:
        assert np.isclose(result, sample).all()

    assert elapsed < max_elapsed_sec, f"Time taken: {elapsed}"
0 commit comments