diff --git a/pyproject.toml b/pyproject.toml index 1bfe704..a09a813 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "taskiq-faststream" -version = "0.2.2" +version = "0.2.3" description = "FastStream - taskiq integration to schedule FastStream tasks" readme = "README.md" authors = [ @@ -78,8 +78,8 @@ test = [ dev = [ "taskiq-faststream[test]", - "mypy==1.15.0", - "ruff==0.11.10", + "mypy==1.16.0", + "ruff==0.11.12", "pre-commit >=3.6.0,<5.0.0", ] diff --git a/taskiq_faststream/broker.py b/taskiq_faststream/broker.py index 34b6b2a..36b55cc 100644 --- a/taskiq_faststream/broker.py +++ b/taskiq_faststream/broker.py @@ -10,7 +10,7 @@ from taskiq.decor import AsyncTaskiqDecoratedTask from typing_extensions import TypeAlias -from taskiq_faststream.formatter import PatchedFormatter, PathcedMessage +from taskiq_faststream.formatter import PatchedFormatter, PatchedMessage from taskiq_faststream.types import ScheduledTask from taskiq_faststream.utils import resolve_msg @@ -46,7 +46,7 @@ async def shutdown(self) -> None: await self.broker.close() await super().shutdown() - async def kick(self, message: PathcedMessage) -> None: # type: ignore[override] + async def kick(self, message: PatchedMessage) -> None: # type: ignore[override] """Call wrapped FastStream broker `publish` method.""" await _broker_publish(self.broker, message) @@ -123,7 +123,7 @@ async def shutdown(self) -> None: await self.app._shutdown() # noqa: SLF001 await super(BrokerWrapper, self).shutdown() - async def kick(self, message: PathcedMessage) -> None: # type: ignore[override] + async def kick(self, message: PatchedMessage) -> None: # type: ignore[override] """Call wrapped FastStream broker `publish` method.""" assert ( # noqa: S101 self.app.broker @@ -133,7 +133,7 @@ async def kick(self, message: PathcedMessage) -> None: # type: ignore[override] async def _broker_publish( broker: Any, - message: PathcedMessage, + message: PatchedMessage, ) -> None: async for msg in resolve_msg(message.body): await broker.publish(msg, **message.labels) diff --git a/taskiq_faststream/formatter.py b/taskiq_faststream/formatter.py index 5ed1cc2..425f56b 100644 --- a/taskiq_faststream/formatter.py +++ b/taskiq_faststream/formatter.py @@ -6,7 +6,7 @@ @dataclass -class PathcedMessage: +class PatchedMessage: """DTO to transfer data to `broker.kick`.""" body: Any @@ -19,18 +19,18 @@ class PatchedFormatter(TaskiqFormatter): def dumps( # type: ignore[override] self, message: TaskiqMessage, - ) -> PathcedMessage: + ) -> PatchedMessage: """ Dumps taskiq message to some broker message format. :param message: message to send. :return: Dumped message. """ - labels = message.labels + labels = message.labels.copy() labels.pop("schedule", None) labels.pop("schedule_id", None) - return PathcedMessage( + return PatchedMessage( body=labels.pop("message", None), labels=labels, ) diff --git a/taskiq_faststream/kicker.py b/taskiq_faststream/kicker.py index 2226e5d..f2a96eb 100644 --- a/taskiq_faststream/kicker.py +++ b/taskiq_faststream/kicker.py @@ -1,5 +1,13 @@ +from typing import Any + from taskiq.kicker import AsyncKicker, _FuncParams, _ReturnType +from taskiq.message import TaskiqMessage class LabelRespectKicker(AsyncKicker[_FuncParams, _ReturnType]): """Patched kicker doesn't cast labels to str.""" + + def _prepare_message(self, *args: Any, **kwargs: Any) -> TaskiqMessage: + msg = super()._prepare_message(*args, **kwargs) + msg.labels = self.labels + return msg diff --git a/tests/messages.py b/tests/messages.py new file mode 100644 index 0000000..c5d69fe --- /dev/null +++ b/tests/messages.py @@ -0,0 +1,33 @@ +from collections.abc import AsyncIterator, Iterator + +message = "Hi!" + + +def sync_callable_msg() -> str: + return message + + +async def async_callable_msg() -> str: + return message + + +async def async_generator_msg() -> AsyncIterator[str]: + yield message + + +def sync_generator_msg() -> Iterator[str]: + yield message + + +class _C: + def __call__(self) -> str: + return message + + +class _AC: + async def __call__(self) -> str: + return message + + +sync_callable_class_message = _C() +async_callable_class_message = _AC() diff --git a/tests/test_resolve_message.py b/tests/test_resolve_message.py index 3bcdfef..78d24e2 100644 --- a/tests/test_resolve_message.py +++ b/tests/test_resolve_message.py @@ -1,70 +1,34 @@ -from collections.abc import AsyncIterator, Iterator +import typing import pytest +from faststream.types import SendableMessage from taskiq_faststream.utils import resolve_msg - - -@pytest.mark.anyio -async def test_regular() -> None: - async for m in resolve_msg("msg"): - assert m == "msg" - - -@pytest.mark.anyio -async def test_sync_callable() -> None: - async for m in resolve_msg(lambda: "msg"): - assert m == "msg" - - +from tests import messages + + +@pytest.mark.parametrize( + "msg", + [ + messages.message, # regular msg + messages.sync_callable_msg, # sync callable + messages.async_callable_msg, # async callable + messages.sync_generator_msg, # sync generator + messages.async_generator_msg, # async generator + messages.sync_callable_class_message, # sync callable class + messages.async_callable_class_message, # async callable class + ], +) @pytest.mark.anyio -async def test_async_callable() -> None: - async def gen_msg() -> str: - return "msg" - - async for m in resolve_msg(gen_msg): - assert m == "msg" - - -@pytest.mark.anyio -async def test_sync_callable_class() -> None: - class C: - def __init__(self) -> None: - pass - - def __call__(self) -> str: - return "msg" - - async for m in resolve_msg(C()): - assert m == "msg" - - -@pytest.mark.anyio -async def test_async_callable_class() -> None: - class C: - def __init__(self) -> None: - pass - - async def __call__(self) -> str: - return "msg" - - async for m in resolve_msg(C()): - assert m == "msg" - - -@pytest.mark.anyio -async def test_async_generator() -> None: - async def get_msg() -> AsyncIterator[str]: - yield "msg" - - async for m in resolve_msg(get_msg): - assert m == "msg" - - -@pytest.mark.anyio -async def test_sync_generator() -> None: - def get_msg() -> Iterator[str]: - yield "msg" - - async for m in resolve_msg(get_msg): - assert m == "msg" +async def test_resolve_msg( + msg: typing.Union[ + None, + SendableMessage, + typing.Callable[[], SendableMessage], + typing.Callable[[], typing.Awaitable[SendableMessage]], + typing.Callable[[], typing.Generator[SendableMessage, None, None]], + typing.Callable[[], typing.AsyncGenerator[SendableMessage, None]], + ], +) -> None: + async for m in resolve_msg(msg): + assert m == messages.message diff --git a/tests/testcase.py b/tests/testcase.py index d1e62a0..738a9fa 100644 --- a/tests/testcase.py +++ b/tests/testcase.py @@ -1,17 +1,20 @@ import asyncio +import typing from datetime import datetime, timedelta, timezone from typing import Any from unittest.mock import MagicMock import pytest +from faststream.types import SendableMessage from faststream.utils.functions import timeout_scope from freezegun import freeze_time -from taskiq import AsyncBroker, TaskiqScheduler +from taskiq import AsyncBroker from taskiq.cli.scheduler.args import SchedulerArgs from taskiq.cli.scheduler.run import run_scheduler from taskiq.schedule_sources import LabelScheduleSource from taskiq_faststream import BrokerWrapper, StreamScheduler +from tests import messages @pytest.mark.anyio @@ -54,7 +57,7 @@ async def handler(msg: str) -> None: task = asyncio.create_task( run_scheduler( SchedulerArgs( - scheduler=TaskiqScheduler( + scheduler=StreamScheduler( broker=taskiq_broker, sources=[LabelScheduleSource(taskiq_broker)], ), @@ -69,24 +72,44 @@ async def handler(msg: str) -> None: mock.assert_called_once_with("Hi!") task.cancel() + @pytest.mark.parametrize( + "msg", + [ + messages.message, # regular msg + messages.sync_callable_msg, # sync callable + messages.async_callable_msg, # async callable + messages.sync_generator_msg, # sync generator + messages.async_generator_msg, # async generator + messages.sync_callable_class_message, # sync callable class + messages.async_callable_class_message, # async callable class + ], + ) async def test_task_multiple_schedules_by_cron( self, subject: str, broker: Any, event: asyncio.Event, + msg: typing.Union[ + None, + SendableMessage, + typing.Callable[[], SendableMessage], + typing.Callable[[], typing.Awaitable[SendableMessage]], + typing.Callable[[], typing.Generator[SendableMessage, None, None]], + typing.Callable[[], typing.AsyncGenerator[SendableMessage, None]], + ], ) -> None: """Test cron runs twice via StreamScheduler.""" received_message = [] @broker.subscriber(subject) - async def handler(msg: str) -> None: - received_message.append(msg) + async def handler(message: str) -> None: + received_message.append(message) event.set() taskiq_broker = self.build_taskiq_broker(broker) taskiq_broker.task( - "Hi!", + msg, **{self.subj_name: subject}, schedule=[ { @@ -116,4 +139,6 @@ async def handler(msg: str) -> None: task.cancel() - assert received_message == ["Hi!", "Hi!"], received_message + assert received_message == [messages.message, messages.message], ( + received_message + )