Skip to content

Commit b6908b5

Browse files
committed
split setting task name during function declaration stage
1 parent 2082cb8 commit b6908b5

File tree

4 files changed

+60
-25
lines changed

4 files changed

+60
-25
lines changed

ydb/_topic_common/common.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,13 @@ def wrapper(rpc_state, response_pb, driver=None):
2929

3030
return wrapper
3131

32-
33-
def wrap_create_asyncio_task(func: typing.Callable, task_name: str, *args, **kwargs):
34-
if sys.hexversion < 0x03080000:
35-
return asyncio.create_task(func(*args, **kwargs))
36-
return asyncio.create_task(func(*args, **kwargs), name=task_name)
32+
if sys.hexversion < 0x03080000:
33+
def wrap_set_name_for_asyncio_task(task: asyncio.Task, task_name: str) -> asyncio.Task:
34+
task.set_name(task_name)
35+
return task
36+
else:
37+
def wrap_set_name_for_asyncio_task(task: asyncio.Task, task_name: str) -> asyncio.Task:
38+
return task
3739

3840

3941
_shared_event_loop_lock = threading.Lock()

ydb/_topic_common/common_test.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import asyncio
2+
import sys
23
import threading
34
import time
45
import typing
56

67
import grpc
78
import pytest
89

9-
from .common import CallFromSyncToAsync
10+
from .common import CallFromSyncToAsync, wrap_set_name_for_asyncio_task
1011
from .._grpc.grpcwrapper.common_utils import (
1112
GrpcWrapperAsyncIO,
1213
ServerStatus,
@@ -75,6 +76,21 @@ async def async_failed():
7576
with pytest.raises(TestError):
7677
await callback_from_asyncio(async_failed)
7778

79+
async def test_task_name_on_asyncio_task(self):
80+
task_name = "asyncio task"
81+
loop = asyncio.get_running_loop()
82+
83+
async def some_async_task():
84+
await asyncio.sleep(0)
85+
return 1
86+
87+
asyncio_task = loop.create_task(some_async_task())
88+
wrap_set_name_for_asyncio_task(asyncio_task, task_name=task_name)
89+
90+
if sys.hexversion >= 0x03080000:
91+
assert asyncio_task.get_name() == task_name
92+
93+
7894

7995
@pytest.mark.asyncio
8096
class TestGrpcWrapperAsyncIO:

ydb/_topic_reader/topic_reader_asyncio.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import asyncio
44
import concurrent.futures
55
import gzip
6-
import sys
76
import typing
87
from asyncio import Task
98
from collections import deque
@@ -89,10 +88,8 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
8988

9089
def __del__(self):
9190
if not self._closed:
92-
if sys.hexversion < 0x03080000:
93-
self._loop.create_task(self.close(flush=False))
94-
else:
95-
self._loop.create_task(self.close(flush=False), name="close reader")
91+
task = self._loop.create_task(self.close(flush=False))
92+
topic_common.wrap_set_name_for_asyncio_task(task, task_name="close reader")
9693

9794
async def wait_message(self):
9895
"""
@@ -343,17 +340,29 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMess
343340
self._update_token_event.set()
344341

345342
self._background_tasks.add(
346-
topic_common.wrap_create_asyncio_task(self._read_messages_loop, "read_messages_loop"),
343+
topic_common.wrap_set_name_for_asyncio_task(
344+
asyncio.create_task(self._read_messages_loop()),
345+
task_name="read_messages_loop",
346+
),
347347
)
348348
self._background_tasks.add(
349-
topic_common.wrap_create_asyncio_task(self._decode_batches_loop, "decode_batches"),
349+
topic_common.wrap_set_name_for_asyncio_task(
350+
asyncio.create_task(self._decode_batches_loop()),
351+
task_name="decode_batches",
352+
),
350353
)
351354
if self._get_token_function:
352355
self._background_tasks.add(
353-
topic_common.wrap_create_asyncio_task(self._update_token_loop, "update_token_loop"),
356+
topic_common.wrap_set_name_for_asyncio_task(
357+
asyncio.create_task(self._update_token_loop()),
358+
task_name="update_token_loop",
359+
),
354360
)
355361
self._background_tasks.add(
356-
topic_common.wrap_create_asyncio_task(self._handle_background_errors, "handle_background_errors"),
362+
topic_common.wrap_set_name_for_asyncio_task(
363+
asyncio.create_task(self._handle_background_errors()),
364+
task_name="handle_background_errors",
365+
),
357366
)
358367

359368
async def wait_error(self):

ydb/_topic_writer/topic_writer_asyncio.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,14 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings):
232232
self._new_messages = asyncio.Queue()
233233
self._stop_reason = self._loop.create_future()
234234
self._background_tasks = [
235-
topic_common.wrap_create_asyncio_task(self._connection_loop, "connection_loop"),
236-
topic_common.wrap_create_asyncio_task(self._encode_loop, "encode_loop"),
235+
topic_common.wrap_set_name_for_asyncio_task(
236+
asyncio.create_task(self._connection_loop()),
237+
task_name="connection_loop",
238+
),
239+
topic_common.wrap_set_name_for_asyncio_task(
240+
asyncio.create_task(self._encode_loop()),
241+
task_name="encode_loop",
242+
),
237243
]
238244

239245
self._state_changed = asyncio.Event()
@@ -367,11 +373,13 @@ async def _connection_loop(self):
367373

368374
self._stream_connected.set()
369375

370-
send_loop = topic_common.wrap_create_asyncio_task(self._send_loop, "writer send loop", stream_writer)
371-
receive_loop = topic_common.wrap_create_asyncio_task(
372-
self._read_loop,
373-
"writer receive loop",
374-
stream_writer,
376+
send_loop = topic_common.wrap_set_name_for_asyncio_task(
377+
asyncio.create_task(self._send_loop(stream_writer)),
378+
task_name="writer send loop",
379+
)
380+
receive_loop = topic_common.wrap_set_name_for_asyncio_task(
381+
asyncio.create_task(self._read_loop(stream_writer)),
382+
task_name="writer receive loop",
375383
)
376384

377385
tasks = [send_loop, receive_loop]
@@ -658,9 +666,9 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMes
658666

659667
if self._update_token_interval is not None:
660668
self._update_token_event.set()
661-
self._update_token_task = topic_common.wrap_create_asyncio_task(
662-
self._update_token_loop,
663-
"update_token_loop",
669+
self._update_token_task = topic_common.wrap_set_name_for_asyncio_task(
670+
asyncio.create_task(self._update_token_loop()),
671+
task_name="update_token_loop",
664672
)
665673

666674
@staticmethod

0 commit comments

Comments
 (0)