@@ -70,38 +70,62 @@ def __init__(
7070 server_url = self ._get_base_url (server_url )
7171
7272 self .server_url = server_url
73+ self .server_headers = server_headers
74+ self .http_timeout = http_timeout
75+ self .max_retries = max_retries
76+ self .retry_backoff_factor = retry_backoff_factor
7377
74- self ._client = httpx .Client (
75- headers = server_headers ,
76- timeout = httpx .Timeout (connect = 10.0 , read = http_timeout , write = http_timeout , pool = None ),
78+ self ._client = self ._new_client ()
79+ self ._aio_client_sem = asyncio .Semaphore (1 )
80+ self ._aio_client_cache : dict [asyncio .AbstractEventLoop , httpx .AsyncClient ] = {}
81+
82+ if model_name :
83+ self ._check_model_name (self .server_url , model_name )
84+ self .model_name = model_name
85+ else :
86+ self .model_name = self ._get_model_name (self .server_url )
87+
88+ @property
89+ def chat_url (self ) -> str :
90+ return f"{ self .server_url } /v1/chat/completions"
91+
92+ def _new_client (self ) -> httpx .Client :
93+ return httpx .Client (
94+ headers = self .server_headers ,
95+ timeout = httpx .Timeout (connect = 10.0 , read = self .http_timeout , write = self .http_timeout , pool = None ),
7796 transport = RetryTransport (
78- retry = Retry (total = max_retries , backoff_factor = retry_backoff_factor ),
97+ retry = Retry (total = self . max_retries , backoff_factor = self . retry_backoff_factor ),
7998 transport = httpx .HTTPTransport (
8099 limits = httpx .Limits (max_connections = None , max_keepalive_connections = 20 ),
81100 ),
82101 ),
83102 )
84103
85- self ._aio_client = httpx .AsyncClient (
86- headers = server_headers ,
87- timeout = httpx .Timeout (connect = 10.0 , read = http_timeout , write = http_timeout , pool = None ),
104+ async def _new_aio_client (self ) -> httpx .AsyncClient :
105+ return httpx .AsyncClient (
106+ headers = self .server_headers ,
107+ timeout = httpx .Timeout (connect = 10.0 , read = self .http_timeout , write = self .http_timeout , pool = None ),
88108 transport = RetryTransport (
89- retry = Retry (total = max_retries , backoff_factor = retry_backoff_factor ),
109+ retry = Retry (total = self . max_retries , backoff_factor = self . retry_backoff_factor ),
90110 transport = httpx .AsyncHTTPTransport (
91111 limits = httpx .Limits (max_connections = None , max_keepalive_connections = 20 ),
92112 ),
93113 ),
94114 )
95115
96- if model_name :
97- self ._check_model_name (self .server_url , model_name )
98- self .model_name = model_name
99- else :
100- self .model_name = self ._get_model_name (self .server_url )
101-
102- @property
103- def chat_url (self ) -> str :
104- return f"{ self .server_url } /v1/chat/completions"
116+ async def _aio_client (self ) -> httpx .AsyncClient :
117+ loop = asyncio .get_running_loop ()
118+ aio_client = self ._aio_client_cache .get (loop )
119+ if aio_client is not None :
120+ return aio_client
121+ async with self ._aio_client_sem :
122+ aio_client = self ._aio_client_cache .get (loop )
123+ if aio_client is not None :
124+ return aio_client
125+ aio_client = await self ._new_aio_client ()
126+ self ._aio_client_cache .clear ()
127+ self ._aio_client_cache [loop ] = aio_client
128+ return aio_client
105129
106130 def _get_base_url (self , server_url : str ) -> str :
107131 matched = re .match (r"^(https?://[^/]+)" , server_url )
@@ -409,7 +433,8 @@ async def aio_predict(
409433 request_text = request_text [:2048 ] + "...(truncated)..." + request_text [- 2048 :]
410434 print (f"Request body: { request_text } " )
411435
412- response = await self ._aio_client .post (self .chat_url , json = request_body )
436+ client = await self ._aio_client ()
437+ response = await client .post (self .chat_url , json = request_body )
413438 response_data = self .get_response_data (response )
414439
415440 if self .debug :
0 commit comments