Skip to content

Commit c4e4c9c

Browse files
committed
feat: support multiple push
1 parent 2c4f136 commit c4e4c9c

File tree

3 files changed

+61
-14
lines changed

3 files changed

+61
-14
lines changed

taskiq_faststream/broker.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,5 +133,7 @@ async def _broker_publish(
133133
) -> None:
134134
labels = message.labels
135135
labels.pop("schedule", None)
136-
msg = await resolve_msg(labels.pop("message", message.message))
137-
await broker.publish(msg, **labels)
136+
async for msg in resolve_msg(
137+
msg=labels.pop("message", message.message),
138+
):
139+
await broker.publish(msg, **labels)

taskiq_faststream/utils.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import typing
22

3+
from fast_depends.utils import is_async_gen_callable, is_gen_callable
34
from faststream.types import SendableMessage
45
from faststream.utils.functions import to_async
56

@@ -10,8 +11,10 @@ async def resolve_msg(
1011
SendableMessage,
1112
typing.Callable[[], SendableMessage],
1213
typing.Callable[[], typing.Awaitable[SendableMessage]],
14+
typing.Callable[[], typing.Iterator[SendableMessage]],
15+
typing.Callable[[], typing.AsyncIterator[SendableMessage]],
1316
],
14-
) -> SendableMessage:
17+
) -> typing.AsyncIterator[SendableMessage]:
1518
"""Resolve message generation callback.
1619
1720
Args:
@@ -21,9 +24,26 @@ async def resolve_msg(
2124
The message to send
2225
"""
2326
if callable(msg):
24-
get_msg = typing.cast(
25-
typing.Callable[[], typing.Awaitable[SendableMessage]],
26-
to_async(msg),
27-
)
28-
msg = await get_msg()
29-
return msg
27+
if is_async_gen_callable(msg):
28+
async for i in typing.cast(
29+
typing.Callable[[], typing.AsyncIterator[SendableMessage]],
30+
msg,
31+
)():
32+
yield i
33+
34+
elif is_gen_callable(msg):
35+
for i in typing.cast(
36+
typing.Callable[[], typing.Iterator[SendableMessage]],
37+
msg,
38+
)():
39+
yield i
40+
41+
else:
42+
get_msg = typing.cast(
43+
typing.Callable[[], typing.Awaitable[SendableMessage]],
44+
to_async(msg),
45+
)
46+
yield await get_msg()
47+
48+
else:
49+
yield msg

tests/test_resolve_message.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,29 @@
1+
from typing import AsyncIterator, Iterator
2+
13
import pytest
24

35
from taskiq_faststream.utils import resolve_msg
46

57

68
@pytest.mark.anyio
79
async def test_regular() -> None:
8-
assert await resolve_msg("msg") == "msg"
10+
async for m in resolve_msg("msg"):
11+
assert m == "msg"
912

1013

1114
@pytest.mark.anyio
1215
async def test_sync_callable() -> None:
13-
assert await resolve_msg(lambda: "msg") == "msg"
16+
async for m in resolve_msg(lambda: "msg"):
17+
assert m == "msg"
1418

1519

1620
@pytest.mark.anyio
1721
async def test_async_callable() -> None:
1822
async def gen_msg() -> str:
1923
return "msg"
2024

21-
assert await resolve_msg(gen_msg) == "msg"
25+
async for m in resolve_msg(gen_msg):
26+
assert m == "msg"
2227

2328

2429
@pytest.mark.anyio
@@ -30,7 +35,8 @@ def __init__(self) -> None:
3035
def __call__(self) -> str:
3136
return "msg"
3237

33-
assert await resolve_msg(C()) == "msg"
38+
async for m in resolve_msg(C()):
39+
assert m == "msg"
3440

3541

3642
@pytest.mark.anyio
@@ -42,4 +48,23 @@ def __init__(self) -> None:
4248
async def __call__(self) -> str:
4349
return "msg"
4450

45-
assert await resolve_msg(C()) == "msg"
51+
async for m in resolve_msg(C()):
52+
assert m == "msg"
53+
54+
55+
@pytest.mark.anyio
56+
async def test_async_generator() -> None:
57+
async def get_msg() -> AsyncIterator[str]:
58+
yield "msg"
59+
60+
async for m in resolve_msg(get_msg):
61+
assert m == "msg"
62+
63+
64+
@pytest.mark.anyio
65+
async def test_sync_generator() -> None:
66+
def get_msg() -> Iterator[str]:
67+
yield "msg"
68+
69+
async for m in resolve_msg(get_msg):
70+
assert m == "msg"

0 commit comments

Comments
 (0)