33import asyncio
44import functools
55import operator
6- from collections .abc import Iterable , Mapping
6+ from collections .abc import AsyncIterable , Iterable , Mapping
77from enum import Enum
88from itertools import starmap
99from typing import (
@@ -50,10 +50,15 @@ def product(tup: ChunkCoords) -> int:
5050
5151
5252async def concurrent_map (
53- items : Iterable [T ], func : Callable [..., Awaitable [V ]], limit : int | None = None
53+ items : Iterable [T ] | AsyncIterable [T ],
54+ func : Callable [..., Awaitable [V ]],
55+ limit : int | None = None ,
5456) -> list [V ]:
5557 if limit is None :
56- return await asyncio .gather (* list (starmap (func , items )))
58+ if isinstance (items , AsyncIterable ):
59+ return await asyncio .gather (* list (starmap (func , [x async for x in items ])))
60+ else :
61+ return await asyncio .gather (* list (starmap (func , items )))
5762
5863 else :
5964 sem = asyncio .Semaphore (limit )
@@ -62,7 +67,10 @@ async def run(item: tuple[Any]) -> V:
6267 async with sem :
6368 return await func (* item )
6469
65- return await asyncio .gather (* [asyncio .ensure_future (run (item )) for item in items ])
70+ if isinstance (items , AsyncIterable ):
71+ return await asyncio .gather (* [asyncio .ensure_future (run (item )) async for item in items ])
72+ else :
73+ return await asyncio .gather (* [asyncio .ensure_future (run (item )) for item in items ])
6674
6775
6876E = TypeVar ("E" , bound = Enum )
0 commit comments