Skip to content
Merged
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "taskiq-faststream"
version = "0.2.2"
version = "0.2.3"
description = "FastStream - taskiq integration to schedule FastStream tasks"
readme = "README.md"
authors = [
Expand Down Expand Up @@ -78,8 +78,8 @@ test = [

dev = [
"taskiq-faststream[test]",
"mypy==1.15.0",
"ruff==0.11.10",
"mypy==1.16.0",
"ruff==0.11.12",
"pre-commit >=3.6.0,<5.0.0",
]

Expand Down
8 changes: 4 additions & 4 deletions taskiq_faststream/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from taskiq.decor import AsyncTaskiqDecoratedTask
from typing_extensions import TypeAlias

from taskiq_faststream.formatter import PatchedFormatter, PathcedMessage
from taskiq_faststream.formatter import PatchedFormatter, PatchedMessage
from taskiq_faststream.types import ScheduledTask
from taskiq_faststream.utils import resolve_msg

Expand Down Expand Up @@ -46,7 +46,7 @@ async def shutdown(self) -> None:
await self.broker.close()
await super().shutdown()

async def kick(self, message: PathcedMessage) -> None: # type: ignore[override]
async def kick(self, message: PatchedMessage) -> None: # type: ignore[override]
"""Call wrapped FastStream broker `publish` method."""
await _broker_publish(self.broker, message)

Expand Down Expand Up @@ -123,7 +123,7 @@ async def shutdown(self) -> None:
await self.app._shutdown() # noqa: SLF001
await super(BrokerWrapper, self).shutdown()

async def kick(self, message: PathcedMessage) -> None: # type: ignore[override]
async def kick(self, message: PatchedMessage) -> None: # type: ignore[override]
"""Call wrapped FastStream broker `publish` method."""
assert ( # noqa: S101
self.app.broker
Expand All @@ -133,7 +133,7 @@ async def kick(self, message: PathcedMessage) -> None: # type: ignore[override]

async def _broker_publish(
broker: Any,
message: PathcedMessage,
message: PatchedMessage,
) -> None:
async for msg in resolve_msg(message.body):
await broker.publish(msg, **message.labels)
8 changes: 4 additions & 4 deletions taskiq_faststream/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


@dataclass
class PathcedMessage:
class PatchedMessage:
"""DTO to transfer data to `broker.kick`."""

body: Any
Expand All @@ -19,18 +19,18 @@ class PatchedFormatter(TaskiqFormatter):
def dumps( # type: ignore[override]
self,
message: TaskiqMessage,
) -> PathcedMessage:
) -> PatchedMessage:
"""
Dumps taskiq message to some broker message format.

:param message: message to send.
:return: Dumped message.
"""
labels = message.labels
labels = message.labels.copy()
labels.pop("schedule", None)
labels.pop("schedule_id", None)

return PathcedMessage(
return PatchedMessage(
body=labels.pop("message", None),
labels=labels,
)
Expand Down
8 changes: 8 additions & 0 deletions taskiq_faststream/kicker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from typing import Any

from taskiq.kicker import AsyncKicker, _FuncParams, _ReturnType
from taskiq.message import TaskiqMessage


class LabelRespectKicker(AsyncKicker[_FuncParams, _ReturnType]):
"""Patched kicker doesn't cast labels to str."""

def _prepare_message(self, *args: Any, **kwargs: Any) -> TaskiqMessage:
msg = super()._prepare_message(*args, **kwargs)
msg.labels = self.labels
return msg
33 changes: 33 additions & 0 deletions tests/messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from collections.abc import AsyncIterator, Iterator

message = "Hi!"


def sync_callable_msg() -> str:
return message


async def async_callable_msg() -> str:
return message


async def async_generator_msg() -> AsyncIterator[str]:
yield message


def sync_generator_msg() -> Iterator[str]:
yield message


class _C:
def __call__(self) -> str:
return message


class _AC:
async def __call__(self) -> str:
return message


sync_callable_class_message = _C()
async_callable_class_message = _AC()
94 changes: 29 additions & 65 deletions tests/test_resolve_message.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,34 @@
from collections.abc import AsyncIterator, Iterator
import typing

import pytest
from faststream.types import SendableMessage

from taskiq_faststream.utils import resolve_msg


@pytest.mark.anyio
async def test_regular() -> None:
async for m in resolve_msg("msg"):
assert m == "msg"


@pytest.mark.anyio
async def test_sync_callable() -> None:
async for m in resolve_msg(lambda: "msg"):
assert m == "msg"


from tests import messages


@pytest.mark.parametrize(
"msg",
[
messages.message, # regular msg
messages.sync_callable_msg, # sync callable
messages.async_callable_msg, # async callable
messages.sync_generator_msg, # sync generator
messages.async_generator_msg, # async generator
messages.sync_callable_class_message, # sync callable class
messages.async_callable_class_message, # async callable class
],
)
@pytest.mark.anyio
async def test_async_callable() -> None:
async def gen_msg() -> str:
return "msg"

async for m in resolve_msg(gen_msg):
assert m == "msg"


@pytest.mark.anyio
async def test_sync_callable_class() -> None:
class C:
def __init__(self) -> None:
pass

def __call__(self) -> str:
return "msg"

async for m in resolve_msg(C()):
assert m == "msg"


@pytest.mark.anyio
async def test_async_callable_class() -> None:
class C:
def __init__(self) -> None:
pass

async def __call__(self) -> str:
return "msg"

async for m in resolve_msg(C()):
assert m == "msg"


@pytest.mark.anyio
async def test_async_generator() -> None:
async def get_msg() -> AsyncIterator[str]:
yield "msg"

async for m in resolve_msg(get_msg):
assert m == "msg"


@pytest.mark.anyio
async def test_sync_generator() -> None:
def get_msg() -> Iterator[str]:
yield "msg"

async for m in resolve_msg(get_msg):
assert m == "msg"
async def test_resolve_msg(
msg: typing.Union[
None,
SendableMessage,
typing.Callable[[], SendableMessage],
typing.Callable[[], typing.Awaitable[SendableMessage]],
typing.Callable[[], typing.Generator[SendableMessage, None, None]],
typing.Callable[[], typing.AsyncGenerator[SendableMessage, None]],
],
) -> None:
async for m in resolve_msg(msg):
assert m == messages.message
37 changes: 31 additions & 6 deletions tests/testcase.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import asyncio
import typing
from datetime import datetime, timedelta, timezone
from typing import Any
from unittest.mock import MagicMock

import pytest
from faststream.types import SendableMessage
from faststream.utils.functions import timeout_scope
from freezegun import freeze_time
from taskiq import AsyncBroker, TaskiqScheduler
from taskiq import AsyncBroker
from taskiq.cli.scheduler.args import SchedulerArgs
from taskiq.cli.scheduler.run import run_scheduler
from taskiq.schedule_sources import LabelScheduleSource

from taskiq_faststream import BrokerWrapper, StreamScheduler
from tests import messages


@pytest.mark.anyio
Expand Down Expand Up @@ -54,7 +57,7 @@ async def handler(msg: str) -> None:
task = asyncio.create_task(
run_scheduler(
SchedulerArgs(
scheduler=TaskiqScheduler(
scheduler=StreamScheduler(
broker=taskiq_broker,
sources=[LabelScheduleSource(taskiq_broker)],
),
Expand All @@ -69,24 +72,44 @@ async def handler(msg: str) -> None:
mock.assert_called_once_with("Hi!")
task.cancel()

@pytest.mark.parametrize(
"msg",
[
messages.message, # regular msg
messages.sync_callable_msg, # sync callable
messages.async_callable_msg, # async callable
messages.sync_generator_msg, # sync generator
messages.async_generator_msg, # async generator
messages.sync_callable_class_message, # sync callable class
messages.async_callable_class_message, # async callable class
],
)
async def test_task_multiple_schedules_by_cron(
self,
subject: str,
broker: Any,
event: asyncio.Event,
msg: typing.Union[
None,
SendableMessage,
typing.Callable[[], SendableMessage],
typing.Callable[[], typing.Awaitable[SendableMessage]],
typing.Callable[[], typing.Generator[SendableMessage, None, None]],
typing.Callable[[], typing.AsyncGenerator[SendableMessage, None]],
],
) -> None:
"""Test cron runs twice via StreamScheduler."""
received_message = []

@broker.subscriber(subject)
async def handler(msg: str) -> None:
received_message.append(msg)
async def handler(message: str) -> None:
received_message.append(message)
event.set()

taskiq_broker = self.build_taskiq_broker(broker)

taskiq_broker.task(
"Hi!",
msg,
**{self.subj_name: subject},
schedule=[
{
Expand Down Expand Up @@ -116,4 +139,6 @@ async def handler(msg: str) -> None:

task.cancel()

assert received_message == ["Hi!", "Hi!"], received_message
assert received_message == [messages.message, messages.message], (
received_message
)
Loading