diff --git a/awaitlet/base.py b/awaitlet/base.py index 1accf9c..70eaebd 100644 --- a/awaitlet/base.py +++ b/awaitlet/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -import asyncio +import inspect import sys from typing import Any from typing import Awaitable @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING from typing import TypeVar +from anyio import get_cancelled_exc_class from greenlet import greenlet from .util.typing import TypeGuard @@ -20,7 +21,7 @@ def is_exit_exception(e: BaseException) -> bool: # note asyncio.CancelledError is already BaseException # so was an exit exception in any case return not isinstance(e, Exception) or isinstance( - e, (asyncio.TimeoutError, asyncio.CancelledError) + e, (TimeoutError, get_cancelled_exc_class()) ) @@ -50,7 +51,7 @@ def iscoroutine( ) -> TypeGuard[Coroutine[Any, Any, _T_co]]: ... else: - iscoroutine = asyncio.iscoroutine + iscoroutine = inspect.isawaitable def _safe_cancel_awaitable(awaitable: Awaitable[Any]) -> None: diff --git a/pyproject.toml b/pyproject.toml index 79c7ac9..4338bb7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ classifiers = [ ] dependencies = [ + "anyio >= 4, < 5", "greenlet >= 1", "typing-extensions >= 4.6.0", ] diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index ab3862a..b792127 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -1,7 +1,7 @@ -import asyncio import contextvars import random +import anyio from awaitlet import async_def from awaitlet import awaitlet from awaitlet import NoAwaitletContext @@ -42,7 +42,7 @@ async def test_propagate_cancelled(self): cleanup = [] async def async_meth_raise(): - raise asyncio.CancelledError() + raise anyio.get_cancelled_exc_class() def sync_meth(): try: @@ -54,7 +54,7 @@ def sync_meth(): async def run_w_cancel(): await async_def(sync_meth) - with expect_raises(asyncio.CancelledError, check_context=False): + with expect_raises(anyio.get_cancelled_exc_class(), check_context=False): await run_w_cancel() assert cleanup @@ -112,12 +112,12 @@ async def test_contextvars(self): # NOTE: sleep here is not necessary. It's used to simulate IO # ensuring that task are not run sequentially async def async_inner(val): - await asyncio.sleep(random.uniform(0.005, 0.015)) + await anyio.sleep(random.uniform(0.005, 0.015)) eq_(val, var.get()) return var.get() async def async_set(val): - await asyncio.sleep(random.uniform(0.005, 0.015)) + await anyio.sleep(random.uniform(0.005, 0.015)) var.set(val) def inner(val): @@ -140,18 +140,17 @@ def inner(val): return retval + values = set() + async def task(val): - await asyncio.sleep(random.uniform(0.005, 0.015)) + await anyio.sleep(random.uniform(0.005, 0.015)) var.set(val) - await asyncio.sleep(random.uniform(0.005, 0.015)) - return await async_def(inner, val) - - values = { - await coro - for coro in asyncio.as_completed( - [task(i) for i in range(concurrency)] - ) - } + await anyio.sleep(random.uniform(0.005, 0.015)) + values.add(await async_def(inner, val)) + + async with anyio.create_task_group() as tg: + for i in range(concurrency): + tg.start_soon(task, i) eq_(values, set(range(concurrency * 2, concurrency * 3))) @async_test