Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 39 additions & 25 deletions taskiq_nats/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,15 @@
from nats.errors import TimeoutError as NatsTimeoutError
from nats.js import JetStreamContext
from nats.js.api import ConsumerConfig, StreamConfig
from nats.js.errors import NotFoundError
from taskiq import AckableMessage, AsyncBroker, AsyncResultBackend, BrokerMessage

_T = typing.TypeVar("_T") # (Too short)


JetStreamConsumerType = typing.TypeVar(
"JetStreamConsumerType",
)


logger = getLogger("taskiq_nats")


Expand All @@ -36,13 +35,13 @@ class NatsBroker(AsyncBroker):
"""

def __init__(
self,
servers: typing.Union[str, typing.List[str]],
subject: str = "taskiq_tasks",
queue: typing.Optional[str] = None,
result_backend: "typing.Optional[AsyncResultBackend[_T]]" = None,
task_id_generator: typing.Optional[typing.Callable[[], str]] = None,
**connection_kwargs: typing.Any,
self,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove this unnessesary chanes with additional spaces

(I wonder why pre-commit hooks or CI checks didn't catch this...)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I've corrected it.

servers: typing.Union[str, typing.List[str]],
subject: str = "taskiq_tasks",
queue: typing.Optional[str] = None,
result_backend: "typing.Optional[AsyncResultBackend[_T]]" = None,
task_id_generator: typing.Optional[typing.Callable[[], str]] = None,
**connection_kwargs: typing.Any,
) -> None:
super().__init__(result_backend, task_id_generator)
self.servers = servers
Expand Down Expand Up @@ -106,17 +105,17 @@ class BaseJetStreamBroker(
"""

def __init__(
self,
servers: typing.Union[str, typing.List[str]],
subject: str = "taskiq_tasks",
stream_name: str = "taskiq_jetstream",
queue: typing.Optional[str] = None,
durable: str = "taskiq_durable",
stream_config: typing.Optional[StreamConfig] = None,
consumer_config: typing.Optional[ConsumerConfig] = None,
pull_consume_batch: int = 1,
pull_consume_timeout: typing.Optional[float] = None,
**connection_kwargs: typing.Any,
self,
servers: typing.Union[str, typing.List[str]],
subject: str = "taskiq_tasks",
stream_name: str = "taskiq_jetstream",
queue: typing.Optional[str] = None,
durable: str = "taskiq_durable",
stream_config: typing.Optional[StreamConfig] = None,
consumer_config: typing.Optional[ConsumerConfig] = None,
pull_consume_batch: int = 1,
pull_consume_timeout: typing.Optional[float] = None,
**connection_kwargs: typing.Any,
) -> None:
super().__init__()
self.servers: typing.Final = servers
Expand All @@ -138,6 +137,23 @@ def __init__(

self.consumer: JetStreamConsumerType

async def _ensure_stream_exists(self) -> None:
"""Ensure stream exists, create if it doesn't."""
if self.stream_config.name is None:
self.stream_config.name = self.stream_name
if not self.stream_config.subjects:
self.stream_config.subjects = [self.subject]

try:
# Check if stream already exists
await self.js.stream_info(self.stream_config.name)
logger.debug("Stream %s already exists", self.stream_config.name)
except NotFoundError:
logger.debug("stream %s does not exist", self.stream_config.name)
# Stream doesn't exist, create it
await self.js.add_stream(config=self.stream_config)
logger.info("Created stream %s", self.stream_config.name)

async def startup(self) -> None:
"""
Startup event handler.
Expand All @@ -148,11 +164,9 @@ async def startup(self) -> None:
await super().startup()
await self.client.connect(self.servers, **self.connection_kwargs)
self.js = self.client.jetstream()
if self.stream_config.name is None:
self.stream_config.name = self.stream_name
if not self.stream_config.subjects:
self.stream_config.subjects = [self.subject]
await self.js.add_stream(config=self.stream_config)

# Ensure stream exists (won't recreate if it exists)
await self._ensure_stream_exists()
await self._startup_consumer()

async def shutdown(self) -> None:
Expand Down