Skip to content

Commit ceee5af

Browse files
authored
Merge pull request #43 from opendatalab/dev
feat: http_client recreate aio_client after event loop changed.
2 parents 2c41db4 + a963aa8 commit ceee5af

File tree

1 file changed

+43
-18
lines changed

1 file changed

+43
-18
lines changed

mineru_vl_utils/vlm_client/http_client.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -70,38 +70,62 @@ def __init__(
7070
server_url = self._get_base_url(server_url)
7171

7272
self.server_url = server_url
73+
self.server_headers = server_headers
74+
self.http_timeout = http_timeout
75+
self.max_retries = max_retries
76+
self.retry_backoff_factor = retry_backoff_factor
7377

74-
self._client = httpx.Client(
75-
headers=server_headers,
76-
timeout=httpx.Timeout(connect=10.0, read=http_timeout, write=http_timeout, pool=None),
78+
self._client = self._new_client()
79+
self._aio_client_sem = asyncio.Semaphore(1)
80+
self._aio_client_cache: dict[asyncio.AbstractEventLoop, httpx.AsyncClient] = {}
81+
82+
if model_name:
83+
self._check_model_name(self.server_url, model_name)
84+
self.model_name = model_name
85+
else:
86+
self.model_name = self._get_model_name(self.server_url)
87+
88+
@property
89+
def chat_url(self) -> str:
90+
return f"{self.server_url}/v1/chat/completions"
91+
92+
def _new_client(self) -> httpx.Client:
93+
return httpx.Client(
94+
headers=self.server_headers,
95+
timeout=httpx.Timeout(connect=10.0, read=self.http_timeout, write=self.http_timeout, pool=None),
7796
transport=RetryTransport(
78-
retry=Retry(total=max_retries, backoff_factor=retry_backoff_factor),
97+
retry=Retry(total=self.max_retries, backoff_factor=self.retry_backoff_factor),
7998
transport=httpx.HTTPTransport(
8099
limits=httpx.Limits(max_connections=None, max_keepalive_connections=20),
81100
),
82101
),
83102
)
84103

85-
self._aio_client = httpx.AsyncClient(
86-
headers=server_headers,
87-
timeout=httpx.Timeout(connect=10.0, read=http_timeout, write=http_timeout, pool=None),
104+
async def _new_aio_client(self) -> httpx.AsyncClient:
105+
return httpx.AsyncClient(
106+
headers=self.server_headers,
107+
timeout=httpx.Timeout(connect=10.0, read=self.http_timeout, write=self.http_timeout, pool=None),
88108
transport=RetryTransport(
89-
retry=Retry(total=max_retries, backoff_factor=retry_backoff_factor),
109+
retry=Retry(total=self.max_retries, backoff_factor=self.retry_backoff_factor),
90110
transport=httpx.AsyncHTTPTransport(
91111
limits=httpx.Limits(max_connections=None, max_keepalive_connections=20),
92112
),
93113
),
94114
)
95115

96-
if model_name:
97-
self._check_model_name(self.server_url, model_name)
98-
self.model_name = model_name
99-
else:
100-
self.model_name = self._get_model_name(self.server_url)
101-
102-
@property
103-
def chat_url(self) -> str:
104-
return f"{self.server_url}/v1/chat/completions"
116+
async def _aio_client(self) -> httpx.AsyncClient:
117+
loop = asyncio.get_running_loop()
118+
aio_client = self._aio_client_cache.get(loop)
119+
if aio_client is not None:
120+
return aio_client
121+
async with self._aio_client_sem:
122+
aio_client = self._aio_client_cache.get(loop)
123+
if aio_client is not None:
124+
return aio_client
125+
aio_client = await self._new_aio_client()
126+
self._aio_client_cache.clear()
127+
self._aio_client_cache[loop] = aio_client
128+
return aio_client
105129

106130
def _get_base_url(self, server_url: str) -> str:
107131
matched = re.match(r"^(https?://[^/]+)", server_url)
@@ -409,7 +433,8 @@ async def aio_predict(
409433
request_text = request_text[:2048] + "...(truncated)..." + request_text[-2048:]
410434
print(f"Request body: {request_text}")
411435

412-
response = await self._aio_client.post(self.chat_url, json=request_body)
436+
client = await self._aio_client()
437+
response = await client.post(self.chat_url, json=request_body)
413438
response_data = self.get_response_data(response)
414439

415440
if self.debug:

0 commit comments

Comments
 (0)