22
33import abc
44import asyncio
5+ import concurrent .futures
56import contextvars
67import datetime
78import functools
@@ -112,15 +113,16 @@ def __next__(self):
112113
113114
114115class SyncIteratorToAsyncIterator :
115- def __init__ (self , sync_iterator : Iterator ):
116+ def __init__ (self , sync_iterator : Iterator , executor : concurrent . futures . Executor ):
116117 self ._sync_iterator = sync_iterator
118+ self ._executor = executor
117119
118120 def __aiter__ (self ):
119121 return self
120122
121123 async def __anext__ (self ):
122124 try :
123- res = await to_thread (self ._sync_iterator .__next__ )
125+ res = await to_thread (self ._sync_iterator .__next__ , executor = self . _executor )
124126 return res
125127 except StopAsyncIteration :
126128 raise StopIteration ()
@@ -149,12 +151,17 @@ class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO):
149151 convert_server_grpc_to_wrapper : Callable [[Any ], Any ]
150152 _connection_state : str
151153 _stream_call : Optional [Union [grpc .aio .StreamStreamCall , "grpc._channel._MultiThreadedRendezvous" ]]
154+ _wait_executor : Optional [concurrent .futures .ThreadPoolExecutor ]
152155
153156 def __init__ (self , convert_server_grpc_to_wrapper ):
154157 self .from_client_grpc = asyncio .Queue ()
155158 self .convert_server_grpc_to_wrapper = convert_server_grpc_to_wrapper
156159 self ._connection_state = "new"
157160 self ._stream_call = None
161+ self ._wait_executor = None
162+
163+ def __del__ (self ):
164+ self ._clean_executor (wait = False )
158165
159166 async def start (self , driver : SupportedDriverType , stub , method ):
160167 if asyncio .iscoroutinefunction (driver .__call__ ):
@@ -168,6 +175,12 @@ def close(self):
168175 if self ._stream_call :
169176 self ._stream_call .cancel ()
170177
178+ self ._clean_executor (wait = True )
179+
180+ def _clean_executor (self , wait : bool ):
181+ if self ._wait_executor :
182+ self ._wait_executor .shutdown (wait )
183+
171184 async def _start_asyncio_driver (self , driver : ydb .aio .Driver , stub , method ):
172185 requests_iterator = QueueToIteratorAsyncIO (self .from_client_grpc )
173186 stream_call = await driver (
@@ -180,14 +193,11 @@ async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method):
180193
181194 async def _start_sync_driver (self , driver : ydb .Driver , stub , method ):
182195 requests_iterator = AsyncQueueToSyncIteratorAsyncIO (self .from_client_grpc )
183- stream_call = await to_thread (
184- driver ,
185- requests_iterator ,
186- stub ,
187- method ,
188- )
196+ self ._wait_executor = concurrent .futures .ThreadPoolExecutor (max_workers = 1 )
197+
198+ stream_call = await to_thread (driver , requests_iterator , stub , method , executor = self ._wait_executor )
189199 self ._stream_call = stream_call
190- self .from_server_grpc = SyncIteratorToAsyncIterator (stream_call .__iter__ ())
200+ self .from_server_grpc = SyncIteratorToAsyncIterator (stream_call .__iter__ (), self . _wait_executor )
191201
192202 async def receive (self ) -> Any :
193203 # todo handle grpc exceptions and convert it to internal exceptions
@@ -255,7 +265,7 @@ def callback_from_asyncio(callback: Union[Callable, Coroutine]) -> [asyncio.Futu
255265 return loop .run_in_executor (None , callback )
256266
257267
258- async def to_thread (func , / , * args , ** kwargs ):
268+ async def to_thread (func , * args , executor : Optional [ concurrent . futures . Executor ] , ** kwargs ):
259269 """Asynchronously run function *func* in a separate thread.
260270
261271 Any *args and **kwargs supplied for this function are directly passed
@@ -271,7 +281,7 @@ async def to_thread(func, /, *args, **kwargs):
271281 loop = asyncio .get_running_loop ()
272282 ctx = contextvars .copy_context ()
273283 func_call = functools .partial (ctx .run , func , * args , ** kwargs )
274- return await loop .run_in_executor (None , func_call )
284+ return await loop .run_in_executor (executor , func_call )
275285
276286
277287def proto_duration_from_timedelta (t : Optional [datetime .timedelta ]) -> Optional [ProtoDuration ]:
0 commit comments