diff --git a/README.md b/README.md index 038796b..37f84b8 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,21 @@ To run TLS you need to: Read more about the issue [here](https://stackoverflow.com/questions/44979947/python-qpid-proton-for-mac-using-amqps) +### SSL Problems in local enviroment + +If when running tests, this exceptions is raised by the proton library: `SSLUnavailable`: +``` bash +pip uninstall python-qpid-proton -y + +sudo apt-get update +sudo apt-get install -y swig cmake build-essential libssl-dev pkg-config + +export PKG_CONFIG_PATH=/usr/lib/x86_64-linux-gnu/pkgconfig +export CFLAGS="-I/usr/include/openssl" +export LDFLAGS="-L/usr/lib/x86_64-linux-gnu" + +pip install "python-qpid-proton>=0.39.0,<0.40.0" --no-binary python-qpid-proton --verbose --no-cache-dir +``` diff --git a/examples/getting_started/getting_started_async.py b/examples/getting_started/getting_started_async.py new file mode 100644 index 0000000..bf752b8 --- /dev/null +++ b/examples/getting_started/getting_started_async.py @@ -0,0 +1,123 @@ +# type: ignore + +import asyncio + +from rabbitmq_amqp_python_client import ( + AddressHelper, + AMQPMessagingHandler, + AsyncEnvironment, + Converter, + Event, + ExchangeSpecification, + ExchangeToQueueBindingSpecification, + Message, + OutcomeState, + QuorumQueueSpecification, +) + +MESSAGES_TO_PUBLISH = 100 + + +class StopConsumerException(Exception): + """Exception to signal consumer should stop""" + pass + + +class MyMessageHandler(AMQPMessagingHandler): + + def __init__(self): + super().__init__() + self._count = 0 + + def on_amqp_message(self, event: Event): + print( + "received message: {} ".format( + Converter.bytes_to_string(event.message.body) + ) + ) + + self.delivery_context.accept(event) + self._count = self._count + 1 + print("count " + str(self._count)) + + if self._count == MESSAGES_TO_PUBLISH: + print("received all messages") + # Stop the consumer by raising an exception + raise StopConsumerException("All messages consumed") + + def on_connection_closed(self, event: Event): + print("connection closed") + + def on_link_closed(self, event: Event) -> None: + print("link closed") + + +async def main(): + exchange_name = "test-exchange" + queue_name = "example-queue" + routing_key = "routing-key" + + print("connection to amqp server") + async with AsyncEnvironment( + uri="amqp://guest:guest@localhost:5672/" + ) as environment: + async with await environment.connection() as connection: + async with await connection.management() as management: + print("declaring exchange and queue") + await management.declare_exchange(ExchangeSpecification(name=exchange_name)) + await management.declare_queue( + QuorumQueueSpecification(name=queue_name) + ) + + print("binding queue to exchange") + bind_name = await management.bind( + ExchangeToQueueBindingSpecification( + source_exchange=exchange_name, + destination_queue=queue_name, + binding_key=routing_key, + ) + ) + + addr = AddressHelper.exchange_address(exchange_name, routing_key) + addr_queue = AddressHelper.queue_address(queue_name) + + print("create a publisher and publish a test message") + async with await connection.publisher(addr) as publisher: + print("purging the queue") + messages_purged = await management.purge_queue(queue_name) + print("messages purged: " + str(messages_purged)) + + # publish messages + for i in range(MESSAGES_TO_PUBLISH): + status = await publisher.publish( + Message(body=Converter.string_to_bytes("test message {} ".format(i))) + ) + if status.remote_state == OutcomeState.ACCEPTED: + print("message accepted") + + print("create a consumer and consume the test message - press control + c to terminate to consume") + handler = MyMessageHandler() + async with await connection.consumer(addr_queue, message_handler=handler) as consumer: + # Run the consumer in a background task + consumer_task = asyncio.create_task(consumer.run()) + + try: + # Wait for the consumer to finish (e.g., by raising the exception) + await consumer_task + except StopConsumerException as e: + print(f"Consumer stopped: {e}") + except KeyboardInterrupt: + print("consumption interrupted by user, stopping consumer...") + await consumer.stop() + + print("unbind") + await management.unbind(bind_name) + + print("delete queue") + await management.delete_queue(queue_name) + + print("delete exchange") + await management.delete_exchange(exchange_name) + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/poetry.lock b/poetry.lock index f33cdda..b3b615c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,17 @@ -# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. + +[[package]] +name = "backports-asyncio-runner" +version = "1.2.0" +description = "Backport of asyncio.Runner, a context manager that controls event loop life cycle." +optional = false +python-versions = "<3.11,>=3.8" +groups = ["dev"] +markers = "python_version < \"3.11\"" +files = [ + {file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"}, + {file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"}, +] [[package]] name = "black" @@ -566,6 +579,27 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "1.2.0" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "pytest_asyncio-1.2.0-py3-none-any.whl", hash = "sha256:8e17ae5e46d8e7efe51ab6494dd2010f4ca8dae51652aa3c8d55acf50bfb2e99"}, + {file = "pytest_asyncio-1.2.0.tar.gz", hash = "sha256:c609a64a2a8768462d0c99811ddb8bd2583c33fd33cf7f21af1c142e824ffb57"}, +] + +[package.dependencies] +backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} +pytest = ">=8.2,<9" +typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "python-qpid-proton" version = "0.39.0" @@ -695,4 +729,4 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.1" python-versions = "^3.9" -content-hash = "674d5be8f30ba05fcd618702e4b8c662604a4568181ef335fd22cebe483935ad" +content-hash = "6855640542dddf03775cf0ecc647aa2e277b471618471e31a382012117ea76ce" diff --git a/pyproject.toml b/pyproject.toml index 9ccea55..9f6d308 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,11 @@ pytest = "^8.3.4" black = "^24.3.0" python-qpid-proton = "^0.39.0" requests = "^2.31.0" +pytest-asyncio = "^1.2.0" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.pytest.ini_options] +asyncio_mode = "auto" \ No newline at end of file diff --git a/rabbitmq_amqp_python_client/__init__.py b/rabbitmq_amqp_python_client/__init__.py index e787f23..0ec9cb4 100644 --- a/rabbitmq_amqp_python_client/__init__.py +++ b/rabbitmq_amqp_python_client/__init__.py @@ -2,6 +2,13 @@ from .address_helper import AddressHelper from .amqp_consumer_handler import AMQPMessagingHandler +from .asyncio import ( + AsyncConnection, + AsyncConsumer, + AsyncEnvironment, + AsyncManagement, + AsyncPublisher, +) from .common import ExchangeType, QueueType from .connection import Connection from .consumer import Consumer @@ -99,4 +106,9 @@ "RecoveryConfiguration", "OAuth2Options", "Converter", + "AsyncConnection", + "AsyncConsumer", + "AsyncPublisher", + "AsyncManagement", + "AsyncEnvironment", ] diff --git a/rabbitmq_amqp_python_client/asyncio/__init__.py b/rabbitmq_amqp_python_client/asyncio/__init__.py new file mode 100644 index 0000000..8ff79c3 --- /dev/null +++ b/rabbitmq_amqp_python_client/asyncio/__init__.py @@ -0,0 +1,13 @@ +from .connection import AsyncConnection +from .consumer import AsyncConsumer +from .enviroment import AsyncEnvironment +from .management import AsyncManagement +from .publisher import AsyncPublisher + +__all__ = [ + "AsyncConnection", + "AsyncConsumer", + "AsyncManagement", + "AsyncPublisher", + "AsyncEnvironment", +] diff --git a/rabbitmq_amqp_python_client/asyncio/connection.py b/rabbitmq_amqp_python_client/asyncio/connection.py new file mode 100644 index 0000000..56b7395 --- /dev/null +++ b/rabbitmq_amqp_python_client/asyncio/connection.py @@ -0,0 +1,317 @@ +from __future__ import annotations + +import asyncio +import logging +from typing import Callable, Optional, Union + +from ..address_helper import validate_address +from ..connection import Connection +from ..consumer import Consumer +from ..entities import ( + ConsumerOptions, + OAuth2Options, + RecoveryConfiguration, +) +from ..exceptions import ArgumentOutOfRangeException +from ..management import Management +from ..publisher import Publisher +from ..qpid.proton._handlers import MessagingHandler +from ..ssl_configuration import ( + PosixSslConfigurationContext, + WinSslConfigurationContext, +) +from .consumer import AsyncConsumer +from .management import AsyncManagement +from .publisher import AsyncPublisher + +logger = logging.getLogger(__name__) + + +class AsyncConnection: + """ + Asyncio-compatible facade around Connection. + + This class manages the connection to RabbitMQ and provides factory methods for + creating publishers, consumers, and management interfaces. It supports both + single-node and multi-node configurations, as well as SSL/TLS connections. + + Note: + The underlying Proton BlockingConnection is NOT thread-safe. A lock is used + to serialize all operations. + + Attributes: + _connection (Connection): The underlying synchronous Connection + _connection_lock (asyncio.Lock): Lock for coordinating access to the shared connection + _async_publishers (list[AsyncPublisher]): List of active async publishers + _async_consumers (list[AsyncConsumer]): List of active async consumers + _async_managements (list[AsyncManagement]): List of active async management interfaces + _remove_callback (Optional[Callable[[AsyncConnection], None]]): Callback on close + """ + + def __init__( + self, + uri: Optional[str] = None, + uris: Optional[list[str]] = None, + ssl_context: Union[ + PosixSslConfigurationContext, WinSslConfigurationContext, None + ] = None, + oauth2_options: Optional[OAuth2Options] = None, + recovery_configuration: RecoveryConfiguration = RecoveryConfiguration(), + ): + """ + Initialize AsyncConnection. + + Args: + uri: Optional single-node connection URI + uris: Optional multi-node connection URIs + ssl_context: Optional SSL/TLS configuration + oauth2_options: Optional OAuth2 configuration + recovery_configuration: Configuration for automatic recovery + + Raises: + ValidationCodeException: If recovery configuration is invalid + """ + self._connection = Connection( + uri=uri, + uris=uris, + ssl_context=ssl_context, + oauth2_options=oauth2_options, + recovery_configuration=recovery_configuration, + ) + self._connection_lock = asyncio.Lock() + self._async_publishers: list[AsyncPublisher] = [] + self._async_consumers: list[AsyncConsumer] = [] + self._async_managements: list[AsyncManagement] = [] + self._remove_callback: Optional[Callable[[AsyncConnection], None]] = None + + async def dial(self) -> None: + """ + Establish a connection to the AMQP server. + + Configures SSL if specified and establishes the connection using the + provided URI(s). Also initializes the management interface. + """ + async with self._connection_lock: + await asyncio.to_thread(self._connection.dial) + + def _set_remove_callback( + self, callback: Optional[Callable[[AsyncConnection], None]] + ) -> None: + """Set callback to be called when connection is closed.""" + self._remove_callback = callback + + def _remove_publisher(self, publisher: AsyncPublisher) -> None: + """Remove a publisher from the active list.""" + if publisher in self._async_publishers: + self._async_publishers.remove(publisher) + + def _remove_consumer(self, consumer: AsyncConsumer) -> None: + """Remove a consumer from the active list.""" + if consumer in self._async_consumers: + self._async_consumers.remove(consumer) + + def _remove_management(self, management: AsyncManagement) -> None: + """Remove a management interface from the active list.""" + if management in self._async_managements: + self._async_managements.remove(management) + + async def close(self) -> None: + """ + Close the connection to the AMQP 1.0 server. + + Closes the underlying connection and removes it from the connection list. + """ + logger.debug("Closing async connection") + try: + for async_publisher in self._async_publishers[:]: + await async_publisher.close() + for async_consumer in self._async_consumers[:]: + await async_consumer.close() + for async_management in self._async_managements[:]: + await async_management.close() + + async with self._connection_lock: + await asyncio.to_thread(self._connection.close) + except Exception as e: + logger.error(f"Error closing async connections: {e}") + raise e + finally: + if self._remove_callback is not None: + self._remove_callback(self) + + def _set_connection_managements(self, management: Management) -> None: + if len(self._connection._managements) == 0: + self._connection._managements = [management] + + async def management( + self, + ) -> AsyncManagement: + """ + Get the management interface for this connection. + + Returns: + AsyncManagement: The management interface for performing administrative tasks + """ + if len(self._async_managements) > 0: + return self._async_managements[0] + + async_management: Optional[AsyncManagement] = None + async with self._connection_lock: + if len(self._async_managements) == 0: + async_management = AsyncManagement( + self._connection._conn, + connection_lock=self._connection_lock, + ) + + self._set_connection_managements( + async_management._management + ) # TODO: check this + self._async_managements.append(async_management) + + async_management._set_remove_callback(self._remove_management) + + if async_management is not None: + await async_management.open() + + return self._async_managements[0] + + def _set_connection_publishers(self, publisher: Publisher) -> None: + """Set the list of publishers in the underlying connection.""" + publisher._set_publishers_list( + [async_publisher._publisher for async_publisher in self._async_publishers] + ) + self._connection._publishers.append(publisher) + + async def publisher(self, destination: str = "") -> AsyncPublisher: + """ + Create a new publisher instance. + + Args: + destination: Optional default destination for published messages + + Returns: + AsyncPublisher: A new publisher instance + + Raises: + RuntimeError: If publisher creation fails + ArgumentOutOfRangeException: If destination address format is invalid + """ + if destination != "": + if not validate_address(destination): + raise ArgumentOutOfRangeException( + "destination address must start with /queues or /exchanges" + ) + + async with self._connection_lock: + async_publisher = AsyncPublisher( + self._connection._conn, + destination, + connection_lock=self._connection_lock, + ) + await async_publisher.open() + if async_publisher._publisher is None: + raise RuntimeError("Failed to create publisher") + self._set_connection_publishers( + async_publisher._publisher + ) # TODO: check this + self._async_publishers.append(async_publisher) + + async_publisher._set_remove_callback(self._remove_publisher) + + return async_publisher + + def _set_connection_consumers(self, consumer: Consumer) -> None: + """Set the list of consumers in the underlying connection.""" + self._connection._consumers.append(consumer) + + async def consumer( + self, + destination: str, + message_handler: Optional[MessagingHandler] = None, + consumer_options: Optional[ConsumerOptions] = None, + credit: Optional[int] = None, + ) -> AsyncConsumer: + """ + Create a new consumer instance. + + Args: + destination: The address to consume from + message_handler: Optional handler for processing messages + consumer_options: Optional configuration for queue consumption. Each queue has its own consumer options. + credit: Optional credit value for flow control + + Returns: + AsyncConsumer: A new consumer instance + + Raises: + RuntimeError: If consumer creation fails + ArgumentOutOfRangeException: If destination address format is invalid + """ + if not validate_address(destination): + raise ArgumentOutOfRangeException( + "destination address must start with /queues or /exchanges" + ) + if consumer_options is not None: + consumer_options.validate( + { + "4.0.0": self._connection._is_server_version_gte("4.0.0"), + "4.1.0": self._connection._is_server_version_gte("4.1.0"), + "4.2.0": self._connection._is_server_version_gte("4.2.0"), + } + ) + + async with self._connection_lock: + async_consumer = AsyncConsumer( + self._connection._conn, + destination, + message_handler, # pyright: ignore[reportArgumentType] + consumer_options, + credit, + connection_lock=self._connection_lock, + ) + await async_consumer.open() + if async_consumer._consumer is None: + raise RuntimeError("Failed to create consumer") + self._set_connection_consumers(async_consumer._consumer) # TODO: check this + self._async_consumers.append(async_consumer) + + async_consumer._set_remove_callback(self._remove_consumer) + + return async_consumer + + async def refresh_token(self, token: str) -> None: + """ + Refresh the oauth token + + Args: + token: the oauth token to refresh + + Raises: + ValidationCodeException: If oauth is not enabled + """ + async with self._connection_lock: + await asyncio.to_thread(self._connection.refresh_token, token) + + async def __aenter__(self) -> AsyncConnection: + """ "Async context manager entry.""" + await self.dial() + return self + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc: Optional[BaseException], + tb: Optional[object], + ) -> None: + """Async context manager exit.""" + await self.close() + + @property + def active_producers(self) -> int: + """Get the number of active producers of the connection.""" + return len(self._async_publishers) + + @property + def active_consumers(self) -> int: + """Get the number of active consumers of the connection.""" + return len(self._async_consumers) diff --git a/rabbitmq_amqp_python_client/asyncio/consumer.py b/rabbitmq_amqp_python_client/asyncio/consumer.py new file mode 100644 index 0000000..62cd257 --- /dev/null +++ b/rabbitmq_amqp_python_client/asyncio/consumer.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +import asyncio +import logging +from types import TracebackType +from typing import Callable, Literal, Optional, Type, Union + +from ..amqp_consumer_handler import AMQPMessagingHandler +from ..consumer import Consumer +from ..entities import ConsumerOptions +from ..qpid.proton._message import Message +from ..qpid.proton.utils import BlockingConnection + +logger = logging.getLogger(__name__) + + +class AsyncConsumer: + """ + Asyncio-compatible facade around Consumer. + + This class wraps the synchronous Consumer to provide an async interface. + All blocking operations are executed in threads to avoid blocking the event loop. + + Note: + The underlying Proton BlockingConnection is NOT thread-safe. A lock must + be provided by the caller (AsyncConnection) to serialize all operations. + + Attributes: + _consumer (Optional[Consumer]): The underlying synchronous Consumer + _conn (BlockingConnection): The shared blocking connection + _addr (str): The address to consume from + _handler (Optional[AMQPMessagingHandler]): Optional message handling callback + _stream_options (Optional[ConsumerOptions]): Configuration for stream consumption + _credit (Optional[int]): Flow control credit value + _connection_lock (asyncio.Lock): Lock for coordinating access to the shared connection + _remove_callback (Optional[Callable[[AsyncConsumer], None]]): Callback on close + _opened (bool): Indicates if the consumer is opened + """ + + def __init__( + self, + conn: BlockingConnection, + addr: str, + handler: Optional[AMQPMessagingHandler] = None, + stream_options: Optional[ConsumerOptions] = None, + credit: Optional[int] = None, + *, + connection_lock: asyncio.Lock, + ): + """ + Initialize AsyncConsumer. + + Args: + conn: The blocking connection to use + addr: The address to consume from + handler: Optional message handler for processing received messages + stream_options: Optional configuration for stream-based consumption + credit: Optional credit value for flow control + connection_lock: Lock for coordinating access to the shared connection. + Must be created by the caller (AsyncConnection). + """ + self._conn = conn + self._addr = addr + self._handler = handler + self._stream_options = stream_options + self._credit = credit + self._consumer: Optional[Consumer] = None + self._connection_lock = connection_lock + self._remove_callback: Optional[Callable[[AsyncConsumer], None]] = None + self._opened = False + + def _set_remove_callback( + self, callback: Optional[Callable[["AsyncConsumer"], None]] + ) -> None: + """Set callback to be called when consumer is closed.""" + self._remove_callback = callback + + async def open(self) -> None: + """ + Open the consumer in an async context. + + Creates the underlying Consumer instance. This should be called + before using the consumer, either explicitly or via async context manager. + """ + if self._opened: + return + + # Create consumer in thread to avoid blocking event loop + self._consumer = await asyncio.to_thread( + Consumer, + self._conn, + self._addr, + self._handler, + self._stream_options, + self._credit, + ) + self._opened = True + logger.debug(f"AsyncConsumer opened for address: {self._addr}") + + async def consume( + self, timeout: Union[None, Literal[False], float] = False + ) -> Message: + """ + Consume a message from the queue. + + Args: + timeout: The time to wait for a message. + None: Defaults to 60s + float: Wait for specified number of seconds + + Returns: + Message: The received message + + Raises: + RuntimeError: If consumer is not opened + + Note: + The return type might be None if no message is available and timeout occurs, + but this is handled by the cast to Message. + """ + if not self._opened or self._consumer is None: + raise RuntimeError( + "Consumer is not opened. Call open() or use async context manager." + ) + + async with self._connection_lock: + return await asyncio.to_thread(self._consumer.consume, timeout) + + async def close(self) -> None: + """ + Close the consumer connection. + + Closes the receiver if it exists and cleans up resources. + """ + if not self._opened or self._consumer is None: + return + + try: + async with self._connection_lock: + await asyncio.to_thread(self._consumer.close) + + logger.debug(f"AsyncConsumer closed for address: {self._addr}") + except Exception as e: + logger.error(f"Error closing consumer: {e}", exc_info=True) + raise + finally: + self._opened = False + self._consumer = None + if self._remove_callback is not None: + callback = self._remove_callback + self._remove_callback = None # Prevent double-call + callback(self) + + async def run(self) -> None: + """ + Run the consumer in continuous mode. + + Starts the consumer's container to process messages continuously. + + Raises: + RuntimeError: If consumer is not opened + """ + if not self._opened or self._consumer is None: + raise RuntimeError( + "Consumer is not opened. Call open() or use async context manager." + ) + + async with self._connection_lock: + await asyncio.to_thread(self._consumer.run) + + async def stop(self) -> None: + """ + Stop the consumer's continuous processing. + + Stops the consumer's container, halting message processing. + This should be called to cleanly stop a consumer that was started with run(). + + Raises: + RuntimeError: If consumer is not opened + """ + if not self._opened or self._consumer is None: + raise RuntimeError( + "Consumer is not opened. Call open() or use async context manager." + ) + + async with self._connection_lock: + await asyncio.to_thread(self._consumer.stop) + + async def __aenter__(self) -> AsyncConsumer: + """Async context manager entry.""" + await self.open() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + """Async context manager exit.""" + await self.close() + + @property + def address(self) -> str: + """Get the current consumer address.""" + return self._addr + + @property + def handler(self) -> Optional[AMQPMessagingHandler]: + """Get the current message handler.""" + return self._handler diff --git a/rabbitmq_amqp_python_client/asyncio/enviroment.py b/rabbitmq_amqp_python_client/asyncio/enviroment.py new file mode 100644 index 0000000..51816a8 --- /dev/null +++ b/rabbitmq_amqp_python_client/asyncio/enviroment.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import asyncio +import logging +from typing import Optional, Union + +from ..entities import OAuth2Options, RecoveryConfiguration +from ..ssl_configuration import ( + PosixSslConfigurationContext, + WinSslConfigurationContext, +) +from .connection import AsyncConnection + +logger = logging.getLogger(__name__) + + +class AsyncEnvironment: + """ + Asyncio-compatible facade around Environment. + + This class serves as a connection pooler to maintain compatibility with other clients. + It manages a collection of connections and provides methods for creating and managing + these connections. + + Attributes: + _connections (list[AsyncConnection]): List of active async connections managed by this environment + _connections_lock (asyncio.Lock): Lock for coordinating access to the connections list + _uri (Optional[str]): Single node connection URI + _uris (Optional[list[str]]): List of URIs for multi-node setup + _ssl_context (Union[PosixSslConfigurationContext, WinSslConfigurationContext, None]): + SSL configuration for secure connections + _oauth2_options (Optional[OAuth2Options]): OAuth2 options for authentication + _recovery_configuration (RecoveryConfiguration): Configuration for connection recovery + """ + + def __init__( + self, # single-node mode + uri: Optional[str] = None, + # multi-node mode + uris: Optional[list[str]] = None, + ssl_context: Union[ + PosixSslConfigurationContext, WinSslConfigurationContext, None + ] = None, + oauth2_options: Optional[OAuth2Options] = None, + recovery_configuration: RecoveryConfiguration = RecoveryConfiguration(), + ): + """ + Initialize AsyncEnvironment. + + Args: + uri: Single node connection URI + uris: List of URIs for multi-node setup + ssl_context: SSL configuration for secure connections + oauth2_options: OAuth2 options for authentication + recovery_configuration: Configuration for connection recovery + + Raises: + ValueError: If both 'uri' and 'uris' are specified or if neither is specified. + """ + if uri is not None and uris is not None: + raise ValueError( + "Cannot specify both 'uri' and 'uris'. Choose one connection mode." + ) + if uri is None and uris is None: + raise ValueError("Must specify either 'uri' or 'uris' for connection.") + + self._uri = uri + self._uris = uris + self._ssl_context = ssl_context + self._oauth2_options = oauth2_options + self._recovery_configuration = recovery_configuration + self._connections: list[AsyncConnection] = [] + self._connections_lock = asyncio.Lock() + + def _remove_connection(self, connection: AsyncConnection) -> None: + """Remove a connection from the environment's tracking list.""" + if connection in self._connections: + self._connections.remove(connection) + + async def connection(self) -> AsyncConnection: + """ + Create and return a new connection. + + This method supports both single-node and multi-node configurations, with optional + SSL/TLS security and disconnection handling. + + Returns: + AsyncConnection: A new connection instance + + Raises: + ValueError: If neither uri nor uris is provided + """ + async with self._connections_lock: + connection = AsyncConnection( + uri=self._uri, + uris=self._uris, + ssl_context=self._ssl_context, + oauth2_options=self._oauth2_options, + recovery_configuration=self._recovery_configuration, + ) + logger.debug("AsyncEnvironment: Creating new async connection") + self._connections.append(connection) + + connection._set_remove_callback(self._remove_connection) + + return connection + + async def close(self) -> None: + """ + Close all active connections. + + Iterates through all connections managed by this environment and closes them. + This method should be called when shutting down the application to ensure + proper cleanup of resources. + """ + errors = [] + + async with self._connections_lock: + connections_to_close = self._connections[:] + + for connection in connections_to_close: + try: + await connection.close() + except Exception as e: + errors.append(e) + logger.error(f"Exception closing async connection: {e}") + + if errors: + raise RuntimeError( + f"Errors closing async connections: {'; '.join([str(e) for e in errors])}" + ) + + async def connections(self) -> list[AsyncConnection]: + """ + Get the list of active connections. + + Returns: + list[AsyncConnection]: List of all active connections managed by this environment + """ + return self._connections + + async def __aenter__(self) -> AsyncEnvironment: + """Async context manager entry.""" + return self + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc: Optional[BaseException], + tb: Optional[object], + ) -> None: + """Async context manager exit.""" + await self.close() + + @property + def active_connections(self) -> int: + """Returns the number of active connections""" + return len(self._connections) diff --git a/rabbitmq_amqp_python_client/asyncio/management.py b/rabbitmq_amqp_python_client/asyncio/management.py new file mode 100644 index 0000000..1a7d50e --- /dev/null +++ b/rabbitmq_amqp_python_client/asyncio/management.py @@ -0,0 +1,375 @@ +from __future__ import annotations + +import asyncio +import logging +from typing import Any, Callable, Optional, Union + +from ..entities import ( + ExchangeCustomSpecification, + ExchangeSpecification, + ExchangeToExchangeBindingSpecification, + ExchangeToQueueBindingSpecification, + QueueInfo, +) +from ..management import Management +from ..qpid.proton._message import Message +from ..qpid.proton.utils import BlockingConnection +from ..queues import ( + ClassicQueueSpecification, + QuorumQueueSpecification, + StreamSpecification, +) + +logger = logging.getLogger(__name__) + + +class AsyncManagement: + """ + Asyncio-compatible facade around the Management class + + This class provides methods for declaring and managing exchanges, queues, + and bindings in RabbitMQ. It uses a blocking connection to communicate + with the RabbitMQ management interface. + + Note: + The underlying Proton BlockingConnection is NOT thread-safe. A lock must + be provided by the caller (AsyncConnection) to serialize all operations. + + Attributes: + _management (Management): The underlying synchronous Management instance + _conn (BlockingConnection): The shared blocking connection + _connection_lock (asyncio.Lock): Lock for coordinating access to the shared connection + _remove_callback (Optional[Callable[[AsyncManagement], None]]): Callback on close + _opened (bool): Indicates if the management interface is open + """ + + def __init__( + self, + conn: BlockingConnection, + *, + connection_lock: asyncio.Lock, + ): + """ + Initialize AsyncManagement. + + Args: + conn: The blocking connection to use + connection_lock: Lock for coordinating access to the shared connection. + Must be created by the caller (AsyncConnection). + """ + self._conn = conn + self._management = Management(conn) + self._connection_lock = connection_lock + self._remove_callback: Optional[Callable[[AsyncManagement], None]] = None + self._opened = False + + def _check_is_open(self) -> None: + """ + Check if the management interface is open. + + Raises: + RuntimeError: If the management interface is not open + """ + if not self._opened: + raise RuntimeError("Management interface is not open") + + def _set_remove_callback( + self, callback: Optional[Callable[[AsyncManagement], None]] + ) -> None: + """Set callback to be called when management is closed.""" + self._remove_callback = callback + + async def open(self) -> None: + """ + Open the management connection by creating sender and receiver. + + Creates sender and receiver if they don't exist, using the management + node address defined in CommonValues. + """ + if self._opened: + return + + async with self._connection_lock: + await asyncio.to_thread(self._management.open) + + self._opened = True + logger.debug("AsyncManagement opened") + + async def close(self) -> None: + """ + Close the management connection. + + Closes both sender and receiver if they exist. + """ + if not self._opened: + return + + try: + async with self._connection_lock: + await asyncio.to_thread(self._management.close) + + logger.debug("AsyncManagement closed") + except Exception as e: + logger.error(f"Error closing management: {e}", exc_info=True) + raise + finally: + self._opened = False + if self._remove_callback is not None: + callback = self._remove_callback + self._remove_callback = None # Prevent multiple calls + callback(self) + + async def request( + self, + body: Any, + path: str, + method: str, + expected_response_codes: list[int], + ) -> Message: + """ + Send a management request with a new UUID. + + Args: + body: The request body to send + path: The management API path + method: The HTTP method to use + expected_response_codes: List of acceptable response codes + + Returns: + Message: The response message from the server + + Raises: + RuntimeError: If management interface is not open + ValidationCodeException: If response code is not in expected_response_codes + """ + self._check_is_open() + async with self._connection_lock: + return await asyncio.to_thread( + self._management.request, + body, + path, + method, + expected_response_codes, + ) + + async def declare_exchange( + self, + exchange_specification: Union[ + ExchangeSpecification, ExchangeCustomSpecification + ], + ) -> Union[ExchangeSpecification, ExchangeCustomSpecification]: + """ + Declare a new exchange in RabbitMQ. + + Args: + exchange_specification: The specification for the exchange to create + + Returns: + The same specification object that was passed in + + Raises: + RuntimeError: If management interface is not open + ValidationCodeException: If exchange already exists or other validation fails + """ + self._check_is_open() + async with self._connection_lock: + return await asyncio.to_thread( + self._management.declare_exchange, + exchange_specification, + ) + + async def declare_queue( + self, + queue_specification: Union[ + ClassicQueueSpecification, QuorumQueueSpecification, StreamSpecification + ], + ) -> Union[ + ClassicQueueSpecification, QuorumQueueSpecification, StreamSpecification + ]: + """ + Declare a new queue in RabbitMQ. + + Supports declaration of classic queues, quorum queues, and streams. + + Args: + queue_specification: The specification for the queue to create + + Returns: + The same specification object that was passed in + + Raises: + RuntimeError: If management interface is not open + ValidationCodeException: If queue already exists or other validation fails + """ + self._check_is_open() + async with self._connection_lock: + return await asyncio.to_thread( + self._management.declare_queue, + queue_specification, + ) + + async def delete_exchange(self, name: str) -> None: + """ + Delete an exchange. + + Args: + name: The name of the exchange to delete + + Raises: + RuntimeError: If management interface is not open + ValidationCodeException: If exchange doesn't exist or deletion fails + """ + self._check_is_open() + async with self._connection_lock: + await asyncio.to_thread( + self._management.delete_exchange, + name, + ) + + async def delete_queue(self, name: str) -> None: + """ + Delete a queue. + + Args: + name: The name of the queue to delete + + Raises: + RuntimeError: If management interface is not open + ValidationCodeException: If queue doesn't exist or deletion fails + """ + self._check_is_open() + async with self._connection_lock: + await asyncio.to_thread( + self._management.delete_queue, + name, + ) + + async def bind( + self, + bind_specification: Union[ + ExchangeToQueueBindingSpecification, ExchangeToExchangeBindingSpecification + ], + ) -> str: + """ + Create a binding between exchanges or between an exchange and a queue. + + Args: + bind_specification: The specification for the binding to create + + Returns: + str: The binding path created + + Raises: + RuntimeError: If management interface is not open + ValidationCodeException: If binding creation fails + """ + self._check_is_open() + async with self._connection_lock: + return await asyncio.to_thread( + self._management.bind, + bind_specification, + ) + + async def unbind( + self, + bind_specification: Union[ + str, + ExchangeToQueueBindingSpecification, + ExchangeToExchangeBindingSpecification, + ], + ) -> None: + """ + Remove a binding between exchanges or between an exchange and a queue. + + Args: + bind_specification: Either a binding path string or a binding specification + + Raises: + RuntimeError: If management interface is not open + ValidationCodeException: If unbinding fails + """ + self._check_is_open() + async with self._connection_lock: + await asyncio.to_thread( + self._management.unbind, + bind_specification, + ) + + async def purge_queue(self, name: str) -> int: + """ + Purge all messages from a queue. + + Args: + name: The name of the queue to purge + + Returns: + int: The number of messages that were purged + + Raises: + RuntimeError: If management interface is not open + ValidationCodeException: If queue doesn't exist or purge fails + """ + self._check_is_open() + async with self._connection_lock: + return await asyncio.to_thread( + self._management.purge_queue, + name, + ) + + async def queue_info(self, name: str) -> QueueInfo: + """ + Get information about a queue. + + Args: + name: The name of the queue to get information about + + Returns: + QueueInfo: Object containing queue information + + Raises: + RuntimeError: If management interface is not open + ValidationCodeException: If queue doesn't exist or other errors occur + """ + self._check_is_open() + async with self._connection_lock: + return await asyncio.to_thread( + self._management.queue_info, + name, + ) + + async def refresh_token(self, token: str) -> None: + """ + Refresh the oauth token + + Args: + token: the oauth token to refresh + + Raises: + RuntimeError: If management interface is not open + ValidationCodeException: If oauth is not enabled + """ + self._check_is_open() + async with self._connection_lock: + await asyncio.to_thread( + self._management.refresh_token, + token, + ) + + async def __aenter__(self) -> AsyncManagement: + """Async context manager entry.""" + await self.open() + return self + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc: Optional[BaseException], + tb: Optional[object], + ) -> None: + """Async context manager exit.""" + await self.close() + + @property + def is_open(self) -> bool: + """Check if the management interface is open.""" + return self._opened diff --git a/rabbitmq_amqp_python_client/asyncio/publisher.py b/rabbitmq_amqp_python_client/asyncio/publisher.py new file mode 100644 index 0000000..9008ed7 --- /dev/null +++ b/rabbitmq_amqp_python_client/asyncio/publisher.py @@ -0,0 +1,157 @@ +import asyncio +import logging +from types import TracebackType +from typing import Callable, Optional, Type + +from ..publisher import Publisher +from ..qpid.proton._delivery import Delivery +from ..qpid.proton._message import Message +from ..qpid.proton.utils import BlockingConnection + +logger = logging.getLogger(__name__) + + +class AsyncPublisher: + """ + Asyncio-compatible facade for the Publisher class. + + This class wraps the synchronous Publisher to provide an async interface. + All blocking operations are executed in threads to avoid blocking the event loop. + + Note: + The underlying Proton BlockingConnection is NOT thread-safe. A lock must + be provided by the caller (AsyncConnection) to serialize all operations. + + Attributes: + _publisher (Optional[Publisher]): The underlying synchronous Publisher + _conn (BlockingConnection): The shared blocking connection + _addr (str): The default address for publishing + _connection_lock (asyncio.Lock): Lock for coordinating access to the shared connection + _remove_callback (Optional[Callable[[AsyncPublisher], None]]): Callback on close + _opened (bool): Indicates if the publisher is opened + """ + + def __init__( + self, + conn: BlockingConnection, + addr: str = "", + *, + connection_lock: asyncio.Lock, + ) -> None: + """ + Initialize AsyncPublisher. + + Args: + conn: The blocking connection to use + addr: Optional default address for publishing + connection_lock: Lock for coordinating access to the shared connection. + Must be created by the caller (AsyncConnection). + + Note: + The underlying Publisher is NOT created here. Call open() explicitly + or use the async context manager. + """ + self._conn = conn + self._addr = addr + self._publisher: Optional[Publisher] = None + self._connection_lock = connection_lock + self._remove_callback: Optional[Callable[["AsyncPublisher"], None]] = None + self._opened = False + + def _set_remove_callback( + self, callback: Optional[Callable[["AsyncPublisher"], None]] + ) -> None: + """Set callback to be called when publisher is closed.""" + self._remove_callback = callback + + async def open(self) -> None: + """ + Open the publisher in an async context. + + Creates the underlying Publisher instance. This should be called + before using the publisher, either explicitly or via async context manager. + """ + if self._opened: + return + + # Create publisher in thread to avoid blocking event loop + # Note: We don't need the lock here because Publisher.__init__ doesn't + # send any network traffic, it just initializes local state + self._publisher = await asyncio.to_thread(Publisher, self._conn, self._addr) + self._opened = True + logger.debug(f"AsyncPublisher opened for address: {self._addr}") + + async def publish(self, message: Message) -> Delivery: + """ + Publish a message to RabbitMQ. + + The message can be sent to either the publisher's default address or + to an address specified in the message itself, but not both. + + Args: + message: The message to publish + + Returns: + Delivery: The delivery confirmation from RabbitMQ + + Raises: + RuntimeError: If publisher is not opened + ValidationCodeException: If address is specified in both message and publisher + ArgumentOutOfRangeException: If message address format is invalid + """ + if not self._opened or self._publisher is None: + raise RuntimeError( + "Publisher is not opened. Call open() or use async context manager." + ) + + async with self._connection_lock: + return await asyncio.to_thread(self._publisher.publish, message) + + async def close(self) -> None: + """ + Close the publisher connection. + + Closes the sender if it exists and cleans up resources. + """ + if not self._opened or self._publisher is None: + return + + try: + async with self._connection_lock: + await asyncio.to_thread(self._publisher.close) + + logger.debug(f"AsyncPublisher closed for address: {self._addr}") + except Exception as e: + logger.error(f"Error closing publisher: {e}", exc_info=True) + raise + finally: + self._opened = False + self._publisher = None + if self._remove_callback is not None: + callback = self._remove_callback + self._remove_callback = None # Prevent double-call + callback(self) + + async def __aenter__(self) -> "AsyncPublisher": + """Async context manager entry.""" + await self.open() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + """Async context manager exit.""" + await self.close() + + @property + def is_open(self) -> bool: + """Check if publisher is open and ready to send messages.""" + return self._opened and self._publisher is not None and self._publisher.is_open + + @property + def address(self) -> str: + """Get the current publisher address.""" + return self._addr diff --git a/rabbitmq_amqp_python_client/management.py b/rabbitmq_amqp_python_client/management.py index 7a64f2f..68ba568 100644 --- a/rabbitmq_amqp_python_client/management.py +++ b/rabbitmq_amqp_python_client/management.py @@ -18,6 +18,7 @@ BlockingConnection, BlockingReceiver, BlockingSender, + LinkDetached, ) from .queues import ( ClassicQueueSpecification, @@ -92,10 +93,22 @@ def close(self) -> None: Closes both sender and receiver if they exist. """ logger.debug("Closing Sender and Receiver") - if self._sender is not None: - self._sender.close() - if self._receiver is not None: - self._receiver.close() + + if self._sender: + try: + self._sender.close() + except LinkDetached as e: + # avoid raising exception if the queue is deleted before closing the link + if e.condition and e.condition != "amqp:resource-deleted": + raise + + if self._receiver: + try: + self._receiver.close() + except LinkDetached as e: + if e.condition and e.condition != "amqp:resource-deleted": + raise + pass def request( self, diff --git a/rabbitmq_amqp_python_client/publisher.py b/rabbitmq_amqp_python_client/publisher.py index 824170c..0ae4bab 100644 --- a/rabbitmq_amqp_python_client/publisher.py +++ b/rabbitmq_amqp_python_client/publisher.py @@ -8,6 +8,7 @@ ) from .options import SenderOptionUnseattle from .qpid.proton._delivery import Delivery +from .qpid.proton._endpoints import Endpoint from .qpid.proton._message import Message from .qpid.proton.utils import ( BlockingConnection, @@ -110,16 +111,24 @@ def close(self) -> None: logger.debug("Closing Sender") if self.is_open: self._sender.close() # type: ignore + self._sender = None if self in self._publishers: self._publishers.remove(self) def _create_sender(self, addr: str) -> BlockingSender: return self._conn.create_sender(addr, options=SenderOptionUnseattle(addr)) + def _is_sender_closed(self) -> bool: + if self._sender is None: + return True + return bool( + self._sender.link.state & (Endpoint.LOCAL_CLOSED | Endpoint.REMOTE_CLOSED) + ) + @property def is_open(self) -> bool: """Check if publisher is open and ready to send messages.""" - return self._sender is not None + return self._sender is not None and not self._is_sender_closed() @property def address(self) -> str: diff --git a/tests/asyncio/__init__.py b/tests/asyncio/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/asyncio/fixtures.py b/tests/asyncio/fixtures.py new file mode 100644 index 0000000..58597ee --- /dev/null +++ b/tests/asyncio/fixtures.py @@ -0,0 +1,81 @@ +from datetime import datetime, timedelta +from typing import AsyncGenerator, Union + +import pytest_asyncio + +from rabbitmq_amqp_python_client import ( + AsyncConnection, + AsyncEnvironment, + AsyncManagement, + OAuth2Options, + PosixSslConfigurationContext, + RecoveryConfiguration, + WinSslConfigurationContext, +) + +from ..utils import token + + +@pytest_asyncio.fixture +async def async_environment(): + environment = AsyncEnvironment(uri="amqp://guest:guest@localhost:5672/") + yield environment + await environment.close() + + +@pytest_asyncio.fixture +async def async_environment_auth() -> AsyncGenerator[AsyncEnvironment, None]: + token_string = token(datetime.now() + timedelta(milliseconds=2500)) + environment = AsyncEnvironment( + uri="amqp://localhost:5672", + oauth2_options=OAuth2Options(token=token_string), + ) + yield environment + await environment.close() + + +@pytest_asyncio.fixture +async def async_connection() -> AsyncGenerator[AsyncConnection, None]: + environment = AsyncEnvironment( + uri="amqp://guest:guest@localhost:5672/", + ) + connection = await environment.connection() + await connection.dial() + yield connection + await connection.close() + + +@pytest_asyncio.fixture +async def async_connection_with_reconnect() -> AsyncGenerator[AsyncConnection, None]: + environment = AsyncEnvironment( + uri="amqp://guest:guest@localhost:5672/", + recovery_configuration=RecoveryConfiguration(active_recovery=True), + ) + connection = await environment.connection() + await connection.dial() + yield connection + await connection.close() + + +@pytest_asyncio.fixture +async def async_connection_ssl( + ssl_context: Union[PosixSslConfigurationContext, WinSslConfigurationContext], +) -> AsyncGenerator[AsyncConnection, None]: + environment = AsyncEnvironment( + uri="amqps://guest:guest@localhost:5671/", + ssl_context=ssl_context, + ) + connection = await environment.connection() + await connection.dial() + yield connection + await connection.close() + + +@pytest_asyncio.fixture +async def async_management() -> AsyncGenerator[AsyncManagement, None]: + environment = AsyncEnvironment(uri="amqp://guest:guest@localhost:5672/") + connection = await environment.connection() + await connection.dial() + management = await connection.management() + yield management + await management.close() diff --git a/tests/asyncio/test_amqp_091.py b/tests/asyncio/test_amqp_091.py new file mode 100644 index 0000000..52089a8 --- /dev/null +++ b/tests/asyncio/test_amqp_091.py @@ -0,0 +1,72 @@ +import functools + +import pika +import pytest + +from rabbitmq_amqp_python_client import ( + AddressHelper, + AsyncConnection, + Converter, + OutcomeState, + QuorumQueueSpecification, +) +from rabbitmq_amqp_python_client.qpid.proton import Message + +from .fixtures import * # noqa: F401, F403 + + +@pytest.mark.asyncio +async def test_publish_queue(async_connection: AsyncConnection) -> None: + queue_name = "amqp091-queue" + management = await async_connection.management() + + await management.declare_queue(QuorumQueueSpecification(name=queue_name)) + + raised = False + + publisher = None + accepted = False + + try: + publisher = await async_connection.publisher( + destination=AddressHelper.queue_address(queue_name) + ) + status = await publisher.publish( + Message(body=Converter.string_to_bytes("my_test_string_for_amqp")) + ) + if status.remote_state == OutcomeState.ACCEPTED: + accepted = True + except Exception: + raised = True + + if publisher is not None: + await publisher.close() + + assert accepted is True + assert raised is False + + credentials = pika.PlainCredentials("guest", "guest") + parameters = pika.ConnectionParameters("localhost", credentials=credentials) + connection = pika.BlockingConnection(parameters) + channel = connection.channel() + + def on_message(chan, method_frame, header_frame, body, userdata=None): + """Called when a message is received. Log message and ack it.""" + chan.basic_ack(delivery_tag=method_frame.delivery_tag) + assert body is not None + body_text = Converter.bytes_to_string(body) + assert body_text is not None + assert body_text == "my_test_string_for_amqp" + channel.stop_consuming() + + on_message_callback = functools.partial(on_message, userdata="on_message_userdata") + channel.basic_qos( + prefetch_count=1, + ) + channel.basic_consume(queue_name, on_message_callback) + + channel.start_consuming() + connection.close() + + await management.delete_queue(queue_name) + await management.close() diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py new file mode 100644 index 0000000..c55f2d5 --- /dev/null +++ b/tests/asyncio/test_connection.py @@ -0,0 +1,262 @@ +import time +from datetime import datetime, timedelta +from pathlib import Path + +import pytest + +from rabbitmq_amqp_python_client import ( + AsyncEnvironment, + ConnectionClosed, + PKCS12Store, + PosixSslConfigurationContext, + QuorumQueueSpecification, + RecoveryConfiguration, + StreamSpecification, + ValidationCodeException, + WinSslConfigurationContext, +) +from rabbitmq_amqp_python_client.qpid.proton import ( + ConnectionException, +) + +from ..http_requests import ( + create_vhost, + delete_all_connections, + delete_vhost, +) +from ..utils import token +from .fixtures import * # noqa: F401, F403 + + +def on_disconnected(): + global disconnected + disconnected = True + + +@pytest.mark.asyncio +async def test_async_connection() -> None: + environment = AsyncEnvironment(uri="amqp://guest:guest@localhost:5672/") + connection = await environment.connection() + await connection.dial() + await environment.close() + + +@pytest.mark.asyncio +async def test_async_environment_context_manager() -> None: + async with AsyncEnvironment( + uri="amqp://guest:guest@localhost:5672/" + ) as environment: + connection = await environment.connection() + await connection.dial() + + +@pytest.mark.asyncio +async def test_async_connection_ssl(ssl_context) -> None: + environment = AsyncEnvironment( + "amqps://guest:guest@localhost:5671/", + ssl_context=ssl_context, + ) + if isinstance(ssl_context, PosixSslConfigurationContext): + path = Path(ssl_context.ca_cert) + assert path.is_file() is True + assert path.exists() is True + + path = Path(ssl_context.client_cert.client_cert) # type: ignore + assert path.is_file() is True + assert path.exists() is True + elif isinstance(ssl_context, WinSslConfigurationContext): + assert isinstance(ssl_context.ca_store, PKCS12Store) + path = Path(ssl_context.ca_store.path) + assert path.is_file() is True + assert path.exists() is True + + assert isinstance(ssl_context.client_cert.store, PKCS12Store) # type: ignore + path = Path(ssl_context.client_cert.store.path) # type: ignore + assert path.is_file() is True + assert path.exists() is True + else: + pytest.fail("Unsupported ssl context") + + connection = await environment.connection() + await connection.dial() + + await environment.close() + + +@pytest.mark.asyncio +async def test_async_connection_oauth(async_environment_auth: AsyncEnvironment) -> None: + connection = await async_environment_auth.connection() + await connection.dial() + management = await connection.management() + await management.declare_queue(QuorumQueueSpecification(name="test-queue")) + await management.close() + await connection.close() + + +@pytest.mark.asyncio +async def test_async_connection_oauth_with_timeout( + async_environment_auth: AsyncEnvironment, +) -> None: + connection = await async_environment_auth.connection() + await connection.dial() + + # let the token expire + time.sleep(3) + # token expired + + with pytest.raises(Exception): + management = await connection.management() + await management.declare_queue(QuorumQueueSpecification(name="test-queue")) + await management.close() + + await connection.close() + + +@pytest.mark.asyncio +async def test_async_connection_oauth_refresh_token( + async_environment_auth: AsyncEnvironment, +) -> None: + connection = await async_environment_auth.connection() + await connection.dial() + + # let the token expire + time.sleep(1) + # # token expired, refresh + + await connection.refresh_token(token(datetime.now() + timedelta(milliseconds=5000))) + time.sleep(3) + + with pytest.raises(Exception): + management = await connection.management() + await management.declare_queue(QuorumQueueSpecification(name="test-queue")) + await management.close() + + await connection.close() + + +@pytest.mark.asyncio +async def test_async_connection_oauth_refresh_token_with_disconnection( + async_environment_auth: AsyncEnvironment, +) -> None: + connection = await async_environment_auth.connection() + await connection.dial() + + # let the token expire + time.sleep(1) + # # token expired, refresh + + await connection.refresh_token(token(datetime.now() + timedelta(milliseconds=5000))) + delete_all_connections() + time.sleep(3) + + with pytest.raises(Exception): + management = await connection.management() + await management.declare_queue(QuorumQueueSpecification(name="test-queue")) + await management.close() + + await connection.close() + + +@pytest.mark.asyncio +async def test_async_environment_connections_management() -> None: + enviroment = AsyncEnvironment(uri="amqp://guest:guest@localhost:5672/") + + connection1 = await enviroment.connection() + await connection1.dial() + connection2 = await enviroment.connection() + await connection2.dial() + connection3 = await enviroment.connection() + await connection3.dial() + + assert enviroment.active_connections == 3 + + # this shouldn't happen but we test it anyway + await connection1.close() + assert enviroment.active_connections == 2 + + await connection2.close() + assert enviroment.active_connections == 1 + + await connection3.close() + assert enviroment.active_connections == 0 + + await enviroment.close() + + +@pytest.mark.asyncio +async def test_async_connection_reconnection() -> None: + disconnected = False + enviroment = AsyncEnvironment( + uri="amqp://guest:guest@localhost:5672/", + recovery_configuration=RecoveryConfiguration(active_recovery=True), + ) + + connection = await enviroment.connection() + await connection.dial() + + # delay + time.sleep(5) + # simulate a disconnection + # raise a reconnection + management = await connection.management() + + delete_all_connections() + + stream_name = "test_stream_info_with_validation" + queue_specification = StreamSpecification( + name=stream_name, + ) + + try: + await management.declare_queue(queue_specification) + except ConnectionClosed: + disconnected = True + + # check that we reconnected + await management.declare_queue(queue_specification) + await management.delete_queue(stream_name) + await management.close() + await enviroment.close() + + assert disconnected is True + + +@pytest.mark.asyncio +async def test_async_reconnection_parameters() -> None: + enviroment = AsyncEnvironment( + uri="amqp://guest:guest@localhost:5672/", + recovery_configuration=RecoveryConfiguration( + active_recovery=True, + back_off_reconnect_interval=timedelta(milliseconds=100), + ), + ) + + with pytest.raises(ValidationCodeException): + await enviroment.connection() + + +@pytest.mark.asyncio +async def test_async_connection_vhost() -> None: + vhost = "tmpVhost" + str(time.time()) + create_vhost(vhost) + uri = "amqp://guest:guest@localhost:5672/{}".format(vhost) + environment = AsyncEnvironment(uri=uri) + connection = await environment.connection() + await connection.dial() + is_correct_vhost = connection._connection._conn.conn.hostname == "vhost:{}".format(vhost) # type: ignore + await environment.close() + delete_vhost(vhost) + + assert is_correct_vhost is True + + +@pytest.mark.asyncio +async def test_async_connection_vhost_not_exists() -> None: + vhost = "tmpVhost" + str(time.time()) + uri = "amqp://guest:guest@localhost:5672/{}".format(vhost) + + environment = AsyncEnvironment(uri=uri) + + with pytest.raises(ConnectionException): + connection = await environment.connection() + await connection.dial() diff --git a/tests/asyncio/test_consumer.py b/tests/asyncio/test_consumer.py new file mode 100644 index 0000000..ad7f315 --- /dev/null +++ b/tests/asyncio/test_consumer.py @@ -0,0 +1,383 @@ +import pytest + +from rabbitmq_amqp_python_client import ( + AddressHelper, + ArgumentOutOfRangeException, + AsyncConnection, + AsyncEnvironment, + QuorumQueueSpecification, +) +from rabbitmq_amqp_python_client.utils import Converter + +from ..conftest import ( + ConsumerTestException, + MyMessageHandlerAccept, + MyMessageHandlerDiscard, + MyMessageHandlerDiscardWithAnnotations, + MyMessageHandlerNoack, + MyMessageHandlerRequeue, + MyMessageHandlerRequeueWithAnnotations, + MyMessageHandlerRequeueWithInvalidAnnotations, +) +from .fixtures import * # noqa: F401, F403 +from .utils import ( + async_cleanup_dead_lettering, + async_publish_messages, + async_setup_dead_lettering, +) + + +@pytest.mark.asyncio +async def test_async_consumer_sync_queue_accept( + async_connection: AsyncConnection, +) -> None: + queue_name = "test-queue-sync-accept" + messages_to_send = 100 + management = await async_connection.management() + + await management.declare_queue(QuorumQueueSpecification(name=queue_name)) + + consumer = await async_connection.consumer( + destination=AddressHelper.queue_address(queue_name) + ) + + consumed = 0 + + # publish messages_to_send messages + await async_publish_messages(async_connection, messages_to_send, queue_name) + + # consumer synchronously without handler + for i in range(messages_to_send): + message = await consumer.consume() + if Converter.bytes_to_string(message.body) == "test{}".format(i): # type: ignore + consumed += 1 + + await consumer.close() + await management.delete_queue(queue_name) + await management.close() + + assert consumed == messages_to_send + + +@pytest.mark.asyncio +async def test_async_consumer_invalid_destination( + async_connection: AsyncConnection, +) -> None: + queue_name = "test-queue-sync-invalid-accept" + consumer = None + + with pytest.raises(ArgumentOutOfRangeException): + consumer = await async_connection.consumer(destination="/invalid/" + queue_name) + + if consumer is not None: + await consumer.close() + + +@pytest.mark.asyncio +async def test_async_consumer_async_queue_accept( + async_connection: AsyncConnection, async_environment: AsyncEnvironment +) -> None: + messages_to_send = 1000 + queue_name = "test-queue-async-accept" + + management = await async_connection.management() + await management.declare_queue(QuorumQueueSpecification(name=queue_name)) + + await async_publish_messages(async_connection, messages_to_send, queue_name) + + # we closed the connection so we need to open a new one + connection_consumer = await async_environment.connection() + await connection_consumer.dial() + consumer = await connection_consumer.consumer( + destination=AddressHelper.queue_address(queue_name), + message_handler=MyMessageHandlerAccept(), + ) + + try: + await consumer.run() + # ack to terminate the consumer + except ConsumerTestException: + pass + + await consumer.close() + + message_count = await management.purge_queue(queue_name) + await management.delete_queue(queue_name) + await management.close() + + assert message_count == 0 + + +@pytest.mark.asyncio +async def test_async_consumer_async_queue_no_ack( + async_connection: AsyncConnection, async_environment: AsyncEnvironment +) -> None: + messages_to_send = 1000 + queue_name = "test-queue-async-no-ack" + + management = await async_connection.management() + await management.declare_queue(QuorumQueueSpecification(name=queue_name)) + + await async_publish_messages(async_connection, messages_to_send, queue_name) + + # we closed the connection so we need to open a new one + connection_consumer = await async_environment.connection() + await connection_consumer.dial() + consumer = await connection_consumer.consumer( + destination=AddressHelper.queue_address(queue_name), + message_handler=MyMessageHandlerNoack(), + ) + + try: + await consumer.run() + # ack to terminate the consumer + except ConsumerTestException: + pass + + await consumer.close() + message_count = await management.purge_queue(queue_name) + await management.delete_queue(queue_name) + await management.close() + + assert message_count == messages_to_send + + +@pytest.mark.asyncio +async def test_async_consumer_async_queue_with_discard( + async_connection: AsyncConnection, async_environment: AsyncEnvironment +) -> None: + messages_to_send = 1000 + + queue_dead_lettering = "queue-dead-letter" + queue_name = "test-queue-async-discard" + exchange_dead_lettering = "exchange-dead-letter" + binding_key = "key-dead-letter" + + management = await async_connection.management() + + # configuring dead lettering + bind_path = await async_setup_dead_lettering(management) + addr_queue = AddressHelper.queue_address(queue_name) + + await management.declare_queue( + QuorumQueueSpecification( + name=queue_name, + dead_letter_exchange=exchange_dead_lettering, + dead_letter_routing_key=binding_key, + ) + ) + + await async_publish_messages(async_connection, messages_to_send, queue_name) + + # we closed the connection so we need to open a new one + connection_consumer = await async_environment.connection() + await connection_consumer.dial() + consumer = await connection_consumer.consumer( + destination=addr_queue, + message_handler=MyMessageHandlerDiscard(), + ) + + try: + await consumer.run() + # ack to terminate the consumer + except ConsumerTestException: + pass + + await consumer.close() + + message_count = await management.purge_queue(queue_name) + await management.delete_queue(queue_name) + + message_count_dead_lettering = await management.purge_queue(queue_dead_lettering) + await async_cleanup_dead_lettering(management, bind_path) + + await management.close() + + assert message_count == 0 + # check dead letter queue + assert message_count_dead_lettering == messages_to_send + + +@pytest.mark.asyncio +async def test__async_consumer_async_queue_with_discard_with_annotations( + async_connection: AsyncConnection, async_environment: AsyncEnvironment +) -> None: + messages_to_send = 1000 + + queue_dead_lettering = "queue-dead-letter" + queue_name = "test-queue-async-discard-annotations" + exchange_dead_lettering = "exchange-dead-letter" + binding_key = "key-dead-letter" + + management = await async_connection.management() + + await management.declare_queue( + QuorumQueueSpecification( + name=queue_name, + dead_letter_exchange=exchange_dead_lettering, + dead_letter_routing_key=binding_key, + ) + ) + + await async_publish_messages(async_connection, messages_to_send, queue_name) + + bind_path = await async_setup_dead_lettering(management) + addr_queue = AddressHelper.queue_address(queue_name) + addr_queue_dl = AddressHelper.queue_address(queue_dead_lettering) + + # we closed the connection so we need to open a new one + connection_consumer = await async_environment.connection() + await connection_consumer.dial() + consumer = await connection_consumer.consumer( + destination=addr_queue, + message_handler=MyMessageHandlerDiscardWithAnnotations(), + ) + + try: + await consumer.run() + # ack to terminate the consumer + except ConsumerTestException: + pass + + await consumer.close() + + # check for added annotation + new_consumer = await async_connection.consumer(addr_queue_dl) + message = await new_consumer.consume() + await new_consumer.close() + + message_count = await management.purge_queue(queue_name) + await management.delete_queue(queue_name) + + message_count_dead_lettering = await management.purge_queue(queue_dead_lettering) + await async_cleanup_dead_lettering(management, bind_path) + + await management.close() + + assert "x-opt-string" in message.annotations # type: ignore + assert message_count == 0 + # check dead letter queue + assert message_count_dead_lettering == messages_to_send + + +@pytest.mark.asyncio +async def test_async_consumer_async_queue_with_requeue( + async_connection: AsyncConnection, async_environment: AsyncEnvironment +) -> None: + messages_to_send = 1000 + queue_name = "test-queue-async-requeue" + + management = await async_connection.management() + await management.declare_queue(QuorumQueueSpecification(name=queue_name)) + + addr_queue = AddressHelper.queue_address(queue_name) + await async_publish_messages(async_connection, messages_to_send, queue_name) + + # we closed the connection so we need to open a new one + connection_consumer = await async_environment.connection() + await connection_consumer.dial() + + consumer = await connection_consumer.consumer( + destination=addr_queue, + message_handler=MyMessageHandlerRequeue(), + ) + + try: + await consumer.run() + # ack to terminate the consumer + except ConsumerTestException: + pass + + await consumer.close() + + message_count = await management.purge_queue(queue_name) + + await management.delete_queue(queue_name) + await management.close() + + assert message_count > 0 + + +@pytest.mark.asyncio +async def test_async_consumer_async_queue_with_requeue_with_annotations( + async_connection: AsyncConnection, async_environment: AsyncEnvironment +) -> None: + messages_to_send = 1000 + queue_name = "test-queue-async-requeue" + + management = await async_connection.management() + await management.declare_queue(QuorumQueueSpecification(name=queue_name)) + + addr_queue = AddressHelper.queue_address(queue_name) + await async_publish_messages(async_connection, messages_to_send, queue_name) + + # we closed the connection so we need to open a new one + connection_consumer = await async_environment.connection() + await connection_consumer.dial() + + consumer = await connection_consumer.consumer( + destination=addr_queue, + message_handler=MyMessageHandlerRequeueWithAnnotations(), + ) + + try: + await consumer.run() + # ack to terminate the consumer + except ConsumerTestException: + pass + + await consumer.close() + + # check for added annotation + new_consumer = await async_connection.consumer(addr_queue) + message = await new_consumer.consume() + await new_consumer.close() + + message_count = await management.purge_queue(queue_name) + + await management.delete_queue(queue_name) + await management.close() + + assert "x-opt-string" in message.annotations # type: ignore + assert message_count > 0 + + +@pytest.mark.asyncio +async def test_async_consumer_async_queue_with_requeue_with_invalid_annotations( + async_connection: AsyncConnection, async_environment: AsyncEnvironment +) -> None: + messages_to_send = 1000 + test_failure = True + queue_name = "test-queue-async-requeue" + + management = await async_connection.management() + await management.declare_queue(QuorumQueueSpecification(name=queue_name)) + + addr_queue = AddressHelper.queue_address(queue_name) + await async_publish_messages(async_connection, messages_to_send, queue_name) + + # we closed the connection so we need to open a new one + connection_consumer = await async_environment.connection() + await connection_consumer.dial() + + consumer = None + try: + consumer = await connection_consumer.consumer( + destination=addr_queue, + message_handler=MyMessageHandlerRequeueWithInvalidAnnotations(), + ) + + await consumer.run() + # ack to terminate the consumer + except ConsumerTestException: + pass + except ArgumentOutOfRangeException: + test_failure = False + + if consumer is not None: + await consumer.close() + + await management.delete_queue(queue_name) + await management.close() + + assert test_failure is False diff --git a/tests/asyncio/test_management.py b/tests/asyncio/test_management.py new file mode 100644 index 0000000..52668a6 --- /dev/null +++ b/tests/asyncio/test_management.py @@ -0,0 +1,570 @@ +from datetime import timedelta + +import pytest + +from rabbitmq_amqp_python_client import ( + AsyncManagement, + ClassicQueueSpecification, + ExchangeCustomSpecification, + ExchangeSpecification, + ExchangeToExchangeBindingSpecification, + ExchangeToQueueBindingSpecification, + ExchangeType, + QueueType, + QuorumQueueSpecification, + StreamSpecification, +) +from rabbitmq_amqp_python_client.exceptions import ( + ValidationCodeException, +) + +from .fixtures import * # noqa: F401, F403 + + +@pytest.mark.asyncio +async def test_async_declare_delete_exchange(async_management: AsyncManagement) -> None: + exchange_name = "test-exchange" + exchange_info = await async_management.declare_exchange( + ExchangeSpecification(name=exchange_name) + ) + + assert exchange_info.name == exchange_name + + await async_management.delete_exchange(name=exchange_name) + + +@pytest.mark.asyncio +async def test_async_declare_delete_exchange_headers( + async_management: AsyncManagement, +) -> None: + exchange_name = "test-exchange" + exchange_info = await async_management.declare_exchange( + ExchangeSpecification(name=exchange_name, exchange_type=ExchangeType.headers) + ) + + assert exchange_info.name == exchange_name + + await async_management.delete_exchange(name=exchange_name) + + +@pytest.mark.asyncio +async def test_async_declare_delete_exchange_custom( + async_management: AsyncManagement, +) -> None: + exchange_name = "test-exchange-custom" + exchange_arguments = {} + exchange_arguments["x-delayed-type"] = "direct" + + exchange_info = await async_management.declare_exchange( + ExchangeCustomSpecification( + name=exchange_name, + exchange_type="x-local-random", + arguments=exchange_arguments, + ) + ) + + assert exchange_info.name == exchange_name + + await async_management.delete_exchange(name=exchange_name) + + +@pytest.mark.asyncio +async def test_async_declare_delete_exchange_with_args( + async_management: AsyncManagement, +) -> None: + exchange_name = "test-exchange-with-args" + exchange_arguments = {} + exchange_arguments["test"] = "test" + + exchange_info = await async_management.declare_exchange( + ExchangeSpecification( + name=exchange_name, + exchange_type=ExchangeType.topic, + arguments=exchange_arguments, + ) + ) + + assert exchange_info.name == exchange_name + assert exchange_info.exchange_type == ExchangeType.topic + assert exchange_info.arguments == exchange_arguments + + await async_management.delete_exchange(name=exchange_name) + + +@pytest.mark.asyncio +async def test_async_declare_purge_delete_queue( + async_management: AsyncManagement, +) -> None: + queue_name = "my_queue" + + queue_info = await async_management.declare_queue( + ClassicQueueSpecification(name=queue_name) + ) + + assert queue_info.name == queue_name + + await async_management.purge_queue(name=queue_name) + await async_management.delete_queue(name=queue_name) + + +@pytest.mark.asyncio +async def test_async_bind_exchange_to_queue(async_management: AsyncManagement) -> None: + exchange_name = "test-bind-exchange-to-queue-exchange" + queue_name = "test-bind-exchange-to-queue-queue" + routing_key = "routing-key" + + await async_management.declare_exchange(ExchangeSpecification(name=exchange_name)) + await async_management.declare_queue(ClassicQueueSpecification(name=queue_name)) + binding_exchange_queue_path = await async_management.bind( + ExchangeToQueueBindingSpecification( + source_exchange=exchange_name, + destination_queue=queue_name, + binding_key=routing_key, + ) + ) + + assert ( + binding_exchange_queue_path + == "/bindings/src=" + + exchange_name + + ";dstq=" + + queue_name + + ";key=" + + routing_key + + ";args=" + ) + + await async_management.delete_exchange(name=exchange_name) + await async_management.delete_queue(name=queue_name) + await async_management.unbind(binding_exchange_queue_path) + + +@pytest.mark.asyncio +async def test_async_bind_exchange_to_queue_without_key( + async_management: AsyncManagement, +) -> None: + exchange_name = "test-bind-exchange-to-queue-no-key-exchange" + queue_name = "test-bind-exchange-to-queue-no-key-queue" + + await async_management.declare_exchange(ExchangeSpecification(name=exchange_name)) + await async_management.declare_queue(ClassicQueueSpecification(name=queue_name)) + binding_exchange_queue_path = await async_management.bind( + ExchangeToQueueBindingSpecification( + source_exchange=exchange_name, + destination_queue=queue_name, + ) + ) + + assert ( + binding_exchange_queue_path + == "/bindings/src=" + exchange_name + ";dstq=" + queue_name + ";key=" + ";args=" + ) + + await async_management.delete_exchange(name=exchange_name) + await async_management.delete_queue(name=queue_name) + await async_management.unbind(binding_exchange_queue_path) + + +@pytest.mark.asyncio +async def test_async_bind_exchange_to_exchange_without_key( + async_management: AsyncManagement, +) -> None: + source_exchange_name = "test-bind-exchange-to-queue-exchange" + destination_exchange_name = "test-bind-exchange-to-queue-queue" + + await async_management.declare_exchange( + ExchangeSpecification(name=source_exchange_name) + ) + await async_management.declare_exchange( + ExchangeSpecification(name=destination_exchange_name) + ) + binding_exchange_queue_path = await async_management.bind( + ExchangeToExchangeBindingSpecification( + source_exchange=source_exchange_name, + destination_exchange=destination_exchange_name, + ) + ) + + assert ( + binding_exchange_queue_path + == "/bindings/src=" + + source_exchange_name + + ";dstq=" + + destination_exchange_name + + ";key=" + + ";args=" + ) + + await async_management.unbind(binding_exchange_queue_path) + await async_management.delete_exchange(name=source_exchange_name) + await async_management.delete_exchange(name=destination_exchange_name) + + +@pytest.mark.asyncio +async def test_async_bind_unbind_by_binding_spec( + async_management: AsyncManagement, +) -> None: + exchange_name = "test-bind-exchange-to-queue-exchange" + queue_name = "test-bind-exchange-to-queue-queue" + + await async_management.declare_exchange(ExchangeSpecification(name=exchange_name)) + await async_management.declare_queue(ClassicQueueSpecification(name=queue_name)) + + await async_management.bind( + ExchangeToQueueBindingSpecification( + source_exchange=exchange_name, + destination_queue=queue_name, + ) + ) + await async_management.unbind( + ExchangeToQueueBindingSpecification( + source_exchange=exchange_name, + destination_queue=queue_name, + ) + ) + + await async_management.delete_exchange(name=exchange_name) + await async_management.delete_queue(name=queue_name) + + +@pytest.mark.asyncio +async def test_async_bind_unbind_exchange_by_exchange_spec( + async_management: AsyncManagement, +) -> None: + source_exchange_name = "test-bind-exchange-to-queue-exchange" + destination_exchange_name = "test-bind-exchange-to-queue-queue" + + await async_management.declare_exchange( + ExchangeSpecification(name=source_exchange_name) + ) + await async_management.declare_exchange( + ExchangeSpecification(name=destination_exchange_name) + ) + + binding_exchange_queue_path = await async_management.bind( + ExchangeToExchangeBindingSpecification( + source_exchange=source_exchange_name, + destination_exchange=destination_exchange_name, + ) + ) + + assert ( + binding_exchange_queue_path + == "/bindings/src=" + + source_exchange_name + + ";dstq=" + + destination_exchange_name + + ";key=" + + ";args=" + ) + + await async_management.unbind( + ExchangeToExchangeBindingSpecification( + source_exchange=source_exchange_name, + destination_exchange=destination_exchange_name, + ) + ) + + await async_management.delete_exchange(name=source_exchange_name) + await async_management.delete_exchange(name=destination_exchange_name) + + +@pytest.mark.asyncio +async def test_async_bind_exchange_to_exchange( + async_management: AsyncManagement, +) -> None: + source_exchange_name = "source_exchange" + destination_exchange_name = "destination_exchange" + routing_key = "routing-key" + + await async_management.declare_exchange( + ExchangeSpecification(name=source_exchange_name) + ) + await async_management.declare_exchange( + ExchangeSpecification(name=destination_exchange_name) + ) + + binding_exchange_exchange_path = await async_management.bind( + ExchangeToExchangeBindingSpecification( + source_exchange=source_exchange_name, + destination_exchange=destination_exchange_name, + binding_key=routing_key, + ) + ) + + assert ( + binding_exchange_exchange_path + == "/bindings/src=" + + source_exchange_name + + ";dstq=" + + destination_exchange_name + + ";key=" + + routing_key + + ";args=" + ) + + await async_management.unbind(binding_exchange_exchange_path) + await async_management.delete_exchange(name=source_exchange_name) + await async_management.delete_exchange(name=destination_exchange_name) + + +@pytest.mark.asyncio +async def test_queue_info_with_validations(async_management: AsyncManagement) -> None: + queue_name = "test_queue_info_with_validation" + + queue_specification = QuorumQueueSpecification(name=queue_name) + await async_management.declare_queue(queue_specification) + + queue_info = await async_management.queue_info(name=queue_name) + + await async_management.delete_queue(name=queue_name) + + assert queue_info.name == queue_name + assert queue_info.queue_type == QueueType.quorum + assert queue_info.is_durable is True + assert queue_info.message_count == 0 + + +@pytest.mark.asyncio +async def test_async_queue_info_for_stream_with_validations( + async_management: AsyncManagement, +) -> None: + stream_name = "test_stream_info_with_validation" + queue_specification = StreamSpecification( + name=stream_name, + ) + + await async_management.declare_queue(queue_specification) + + stream_info = await async_management.queue_info(name=stream_name) + + await async_management.delete_queue(name=stream_name) + + assert stream_info.name == stream_name + assert stream_info.queue_type == QueueType.stream + assert stream_info.message_count == 0 + + +@pytest.mark.asyncio +async def test_async_queue_precondition_failure( + async_management: AsyncManagement, +) -> None: + queue_name = "test-queue_precondition_fail" + + queue_specification = QuorumQueueSpecification(name=queue_name, max_len_bytes=100) + + await async_management.declare_queue(queue_specification) + + conflicting_queue_specification = QuorumQueueSpecification( + name=queue_name, max_len_bytes=200 + ) + + with pytest.raises(ValidationCodeException): + await async_management.declare_queue(conflicting_queue_specification) + + await async_management.delete_queue(name=queue_name) + + +@pytest.mark.asyncio +async def test_async_declare_classic_queue(async_management: AsyncManagement) -> None: + queue_name = "test-declare_classic_queue" + + queue_specification = ClassicQueueSpecification( + name=queue_name, is_auto_delete=True + ) + + queue_info = await async_management.declare_queue(queue_specification) + + assert queue_info.name == queue_specification.name + + await async_management.delete_queue(name=queue_name) + + +@pytest.mark.asyncio +async def test_async_declare_classic_queue_with_args( + async_management: AsyncManagement, +) -> None: + queue_name = "test-queue_with_args-2" + queue_specification = ClassicQueueSpecification( + name=queue_name, + is_auto_delete=False, + is_exclusive=False, + is_durable=True, + dead_letter_exchange="my_exchange", + dead_letter_routing_key="my_key", + max_len=500000, + max_len_bytes=1000000000, + message_ttl=timedelta(seconds=2), + overflow_behaviour="reject-publish", + auto_expires=timedelta(seconds=10), + single_active_consumer=True, + max_priority=100, + ) + + await async_management.declare_queue(queue_specification) + queue_info = await async_management.queue_info(name=queue_name) + + assert queue_specification.name == queue_info.name + assert queue_specification.is_auto_delete == queue_info.is_auto_delete + assert queue_specification.is_exclusive == queue_info.is_exclusive + assert queue_specification.is_durable == queue_info.is_durable + assert ( + queue_specification.message_ttl.total_seconds() * 1000 # type: ignore + ) == queue_info.arguments["x-message-ttl"] + assert queue_specification.overflow_behaviour == queue_info.arguments["x-overflow"] + assert ( + queue_specification.auto_expires.total_seconds() * 1000 # type: ignore + ) == queue_info.arguments["x-expires"] + assert queue_specification.max_priority == queue_info.arguments["x-max-priority"] + + assert ( + queue_specification.dead_letter_exchange + == queue_info.arguments["x-dead-letter-exchange"] + ) + assert ( + queue_specification.dead_letter_routing_key + == queue_info.arguments["x-dead-letter-routing-key"] + ) + assert queue_specification.max_len == queue_info.arguments["x-max-length"] + assert ( + queue_specification.max_len_bytes == queue_info.arguments["x-max-length-bytes"] + ) + + assert ( + queue_specification.single_active_consumer + == queue_info.arguments["x-single-active-consumer"] + ) + + await async_management.delete_queue(name=queue_name) + + +@pytest.mark.asyncio +async def test_async_declare_quorum_queue_with_args( + async_management: AsyncManagement, +) -> None: + queue_name = "test-queue_with_args" + queue_specification = QuorumQueueSpecification( + name=queue_name, + dead_letter_exchange="my_exchange", + dead_letter_routing_key="my_key", + max_len=500000, + max_len_bytes=1000000000, + message_ttl=timedelta(seconds=2), + overflow_behaviour="reject-publish", + auto_expires=timedelta(seconds=2), + single_active_consumer=True, + deliver_limit=10, + dead_letter_strategy="at-least-once", + quorum_initial_group_size=5, + cluster_target_group_size=5, + ) + + await async_management.declare_queue(queue_specification) + queue_info = await async_management.queue_info(name=queue_name) + + assert queue_specification.name == queue_info.name + assert queue_info.is_auto_delete is False + assert queue_info.is_exclusive is False + assert queue_info.is_durable is True + assert ( + queue_specification.message_ttl.total_seconds() * 1000 # type: ignore + ) == queue_info.arguments["x-message-ttl"] + assert queue_specification.overflow_behaviour == queue_info.arguments["x-overflow"] + assert ( + queue_specification.auto_expires.total_seconds() * 1000 # type: ignore + ) == queue_info.arguments["x-expires"] + + assert ( + queue_specification.dead_letter_exchange + == queue_info.arguments["x-dead-letter-exchange"] + ) + assert ( + queue_specification.dead_letter_routing_key + == queue_info.arguments["x-dead-letter-routing-key"] + ) + assert queue_specification.max_len == queue_info.arguments["x-max-length"] + assert ( + queue_specification.max_len_bytes == queue_info.arguments["x-max-length-bytes"] + ) + + assert ( + queue_specification.single_active_consumer + == queue_info.arguments["x-single-active-consumer"] + ) + + assert queue_specification.deliver_limit == queue_info.arguments["x-deliver-limit"] + assert ( + queue_specification.dead_letter_strategy + == queue_info.arguments["x-dead-letter-strategy"] + ) + assert ( + queue_specification.quorum_initial_group_size + == queue_info.arguments["x-quorum-initial-group-size"] + ) + assert ( + queue_specification.cluster_target_group_size + == queue_info.arguments["x-quorum-target-group-size"] + ) + + await async_management.delete_queue(name=queue_name) + + +@pytest.mark.asyncio +async def test_async_declare_stream_with_args( + async_management: AsyncManagement, +) -> None: + stream_name = "test-stream_with_args" + stream_specification = StreamSpecification( + name=stream_name, + max_len_bytes=1000, + max_age=timedelta(seconds=200000), + stream_max_segment_size_bytes=200, + stream_filter_size_bytes=100, + initial_group_size=5, + ) + + await async_management.declare_queue(stream_specification) + stream_info = await async_management.queue_info(name=stream_name) + + assert stream_specification.name == stream_info.name + assert stream_info.is_auto_delete is False + assert stream_info.is_exclusive is False + assert stream_info.is_durable is True + assert ( + stream_specification.max_len_bytes + == stream_info.arguments["x-max-length-bytes"] + ) + assert ( + str(int(stream_specification.max_age.total_seconds())) + "s" # type: ignore + == stream_info.arguments["x-max-age"] + ) + assert ( + stream_specification.stream_max_segment_size_bytes + == stream_info.arguments["x-stream-max-segment-size-bytes"] + ) + assert ( + stream_specification.stream_filter_size_bytes + == stream_info.arguments["x-stream-filter-size-bytes"] + ) + assert ( + stream_specification.initial_group_size + == stream_info.arguments["x-initial-group-size"] + ) + + await async_management.delete_queue(name=stream_name) + + +@pytest.mark.asyncio +async def test_async_declare_classic_queue_with_invalid_args( + async_management: AsyncManagement, +) -> None: + queue_name = "test-queue_with_args" + queue_specification = ClassicQueueSpecification( + name=queue_name, + max_len=-5, + ) + + with pytest.raises(ValidationCodeException): + await async_management.declare_queue(queue_specification) + + await async_management.delete_queue(name=queue_name) diff --git a/tests/asyncio/test_publisher.py b/tests/asyncio/test_publisher.py new file mode 100644 index 0000000..9212f79 --- /dev/null +++ b/tests/asyncio/test_publisher.py @@ -0,0 +1,707 @@ +import asyncio +import time + +import pytest + +from rabbitmq_amqp_python_client import ( + AddressHelper, + ArgumentOutOfRangeException, + AsyncConnection, + AsyncEnvironment, + ConnectionClosed, + Message, + OutcomeState, + QuorumQueueSpecification, + RecoveryConfiguration, + StreamSpecification, + ValidationCodeException, +) +from rabbitmq_amqp_python_client.asyncio.publisher import ( + AsyncPublisher, +) +from rabbitmq_amqp_python_client.utils import Converter + +from ..http_requests import delete_all_connections +from ..utils import create_binding +from .fixtures import * # noqa: F401, F403 +from .utils import async_publish_per_message + + +@pytest.mark.asyncio +async def test_validate_message_for_publishing_async( + async_connection: AsyncConnection, +) -> None: + queue_name = "validate-publishing-async" + management = await async_connection.management() + + await management.declare_queue(QuorumQueueSpecification(name=queue_name)) + + publisher = await async_connection.publisher( + destination=AddressHelper.queue_address(queue_name) + ) + + with pytest.raises( + ArgumentOutOfRangeException, match="Message inferred must be True" + ): + await publisher.publish( + Message(body=Converter.string_to_bytes("test"), inferred=False) + ) + + with pytest.raises( + ArgumentOutOfRangeException, match="Message body must be of type bytes or None" + ): + await publisher.publish(Message(body="test")) # type: ignore + + with pytest.raises( + ArgumentOutOfRangeException, match="Message body must be of type bytes or None" + ): + await publisher.publish(Message(body={"key": "value"})) # type: ignore + + await publisher.close() + await management.delete_queue(queue_name) + await management.close() + + +@pytest.mark.asyncio +async def test_publish_queue_async(async_connection: AsyncConnection) -> None: + queue_name = "test-queue-async" + management = await async_connection.management() + + await management.declare_queue(QuorumQueueSpecification(name=queue_name)) + + raised = False + publisher = None + accepted = False + + try: + publisher = await async_connection.publisher( + destination=AddressHelper.queue_address(queue_name) + ) + status = await publisher.publish( + Message(body=Converter.string_to_bytes("test")) + ) + if status.remote_state == OutcomeState.ACCEPTED: + accepted = True + except Exception: + raised = True + + if publisher is not None: + await publisher.close() + + await management.delete_queue(queue_name) + await management.close() + + assert accepted is True + assert raised is False + + +@pytest.mark.asyncio +async def test_publish_per_message_async(async_connection: AsyncConnection) -> None: + queue_name = "test-queue-1-async" + queue_name_2 = "test-queue-2-async" + management = await async_connection.management() + + await management.declare_queue(QuorumQueueSpecification(name=queue_name)) + await management.declare_queue(QuorumQueueSpecification(name=queue_name_2)) + + raised = False + publisher = None + accepted = False + accepted_2 = False + + try: + publisher = await async_connection.publisher() + status = await async_publish_per_message( + publisher, addr=AddressHelper.queue_address(queue_name) + ) + if status.remote_state == OutcomeState.ACCEPTED: + accepted = True + status = await async_publish_per_message( + publisher, addr=AddressHelper.queue_address(queue_name_2) + ) + if status.remote_state == OutcomeState.ACCEPTED: + accepted_2 = True + except Exception: + raised = True + + if publisher is not None: + await publisher.close() + + purged_messages_queue_1 = await management.purge_queue(queue_name) + purged_messages_queue_2 = await management.purge_queue(queue_name_2) + await management.delete_queue(queue_name) + await management.delete_queue(queue_name_2) + await management.close() + + assert accepted is True + assert accepted_2 is True + assert purged_messages_queue_1 == 1 + assert purged_messages_queue_2 == 1 + assert raised is False + + +@pytest.mark.asyncio +async def test_publish_ssl(async_connection_ssl: AsyncConnection) -> None: + queue_name = "test-queue" + management = await async_connection_ssl.management() + + await management.declare_queue(QuorumQueueSpecification(name=queue_name)) + + raised = False + publisher = None + + try: + publisher = await async_connection_ssl.publisher( + destination=AddressHelper.queue_address(queue_name) + ) + await publisher.publish(Message(body=Converter.string_to_bytes("test"))) + except Exception: + raised = True + + if publisher is not None: + await publisher.close() + + await management.delete_queue(queue_name) + await management.close() + + assert raised is False + + +@pytest.mark.asyncio +async def test_publish_to_invalid_destination_async( + async_connection: AsyncConnection, +) -> None: + queue_name = "test-queue-async" + raised = False + publisher = None + + try: + publisher = await async_connection.publisher( + "/invalid-destination/" + queue_name + ) + await publisher.publish(Message(body=Converter.string_to_bytes("test"))) + except ArgumentOutOfRangeException: + raised = True + except Exception: + raised = False + + if publisher is not None: + await publisher.close() + + assert raised is True + + +@pytest.mark.asyncio +async def test_publish_per_message_to_invalid_destination_async( + async_connection: AsyncConnection, +) -> None: + queue_name = "test-queue-1-async" + raised = False + + message = Message(body=Converter.string_to_bytes("test")) + message = AddressHelper.message_to_address_helper( + message, "/invalid_destination/" + queue_name + ) + publisher = await async_connection.publisher() + + try: + await publisher.publish(message) + except ArgumentOutOfRangeException: + raised = True + except Exception: + raised = False + + if publisher is not None: + await publisher.close() + + assert raised is True + + +@pytest.mark.asyncio +async def test_publish_per_message_both_address_async( + async_connection: AsyncConnection, +) -> None: + queue_name = "test-queue-1-async" + raised = False + + management = await async_connection.management() + await management.declare_queue(QuorumQueueSpecification(name=queue_name)) + + publisher = await async_connection.publisher( + destination=AddressHelper.queue_address(queue_name) + ) + + try: + message = Message(body=Converter.string_to_bytes("test")) + message = AddressHelper.message_to_address_helper( + message, AddressHelper.queue_address(queue_name) + ) + await publisher.publish(message) + except ValidationCodeException: + raised = True + + if publisher is not None: + await publisher.close() + + await management.delete_queue(queue_name) + await management.close() + + assert raised is True + + +@pytest.mark.asyncio +async def test_publish_exchange_async(async_connection: AsyncConnection) -> None: + exchange_name = "test-exchange-async" + queue_name = "test-queue-async" + management = await async_connection.management() + routing_key = "routing-key" + + bind_name = create_binding( + management._management, exchange_name, queue_name, routing_key + ) + + addr = AddressHelper.exchange_address(exchange_name, routing_key) + + raised = False + accepted = False + publisher = None + + try: + publisher = await async_connection.publisher(addr) + status = await publisher.publish( + Message(body=Converter.string_to_bytes("test")) + ) + if status.remote_state == OutcomeState.ACCEPTED: + accepted = True + except Exception: + raised = True + + if publisher is not None: + await publisher.close() + + await management.unbind(bind_name) + await management.delete_exchange(exchange_name) + await management.delete_queue(queue_name) + await management.close() + + assert accepted is True + assert raised is False + + +@pytest.mark.asyncio +async def test_publish_purge_async(async_connection: AsyncConnection) -> None: + messages_to_publish = 20 + + queue_name = "test-queue-async" + management = await async_connection.management() + + await management.declare_queue(QuorumQueueSpecification(name=queue_name)) + + raised = False + publisher = None + + try: + publisher = await async_connection.publisher( + destination=AddressHelper.queue_address(queue_name) + ) + for _ in range(messages_to_publish): + await publisher.publish(Message(body=Converter.string_to_bytes("test"))) + except Exception: + raised = True + + if publisher is not None: + await publisher.close() + + message_purged = await management.purge_queue(queue_name) + + await management.delete_queue(queue_name) + await management.close() + + assert raised is False + assert message_purged == 20 + + +@pytest.mark.asyncio +async def test_disconnection_reconnection_async( + async_connection: AsyncConnection, +) -> None: + # disconnected = False + generic_exception_raised = False + + environment = AsyncEnvironment( + uri="amqp://guest:guest@localhost:5672/", + recovery_configuration=RecoveryConfiguration(active_recovery=True), + ) + + connection_test = await environment.connection() + + await connection_test.dial() + # delay + time.sleep(5) + messages_to_publish = 10000 + queue_name = "test-queue-reconnection" + management = await connection_test.management() + + await management.declare_queue(QuorumQueueSpecification(name=queue_name)) + + publisher = await connection_test.publisher( + destination=AddressHelper.queue_address(queue_name) + ) + while True: + for i in range(messages_to_publish): + if i == 5: + # simulate a disconnection + delete_all_connections() + try: + await publisher.publish(Message(body=Converter.string_to_bytes("test"))) + except ConnectionClosed: + # disconnected = True + + # TODO: check if this behavior is correct + # The underlying sync Connection handles all recovery automatically, + # hence the async wrapper transparently benefits from it. + # so the exception should is not raised + continue + except Exception: + generic_exception_raised = True + + break + + await publisher.close() + + # purge the queue and check number of published messages + message_purged = await management.purge_queue(queue_name) + + await management.delete_queue(queue_name) + await management.close() + + assert generic_exception_raised is False + # assert disconnected is True + assert message_purged == messages_to_publish - 1 + + +@pytest.mark.asyncio +async def test_queue_info_for_stream_with_validations_async( + async_connection: AsyncConnection, +) -> None: + stream_name = "test_stream_info_async" + messages_to_send = 200 + + queue_specification = StreamSpecification(name=stream_name) + management = await async_connection.management() + await management.declare_queue(queue_specification) + + publisher = await async_connection.publisher( + destination=AddressHelper.queue_address(stream_name) + ) + + for _ in range(messages_to_send): + await publisher.publish(Message(body=Converter.string_to_bytes("test"))) + + +@pytest.mark.asyncio +async def test_publish_per_message_exchange_async( + async_connection: AsyncConnection, +) -> None: + exchange_name = "test-exchange-per-message-async" + queue_name = "test-queue-per-message-async" + management = await async_connection.management() + routing_key = "routing-key-per-message" + + bind_name = create_binding( + management._management, exchange_name, queue_name, routing_key + ) + + raised = False + publisher = None + accepted = False + accepted_2 = False + + try: + publisher = await async_connection.publisher() + status = await async_publish_per_message( + publisher, addr=AddressHelper.exchange_address(exchange_name, routing_key) + ) + if status.remote_state == OutcomeState.ACCEPTED: + accepted = True + status = await async_publish_per_message( + publisher, addr=AddressHelper.queue_address(queue_name) + ) + if status.remote_state == OutcomeState.ACCEPTED: + accepted_2 = True + except Exception: + raised = True + + if publisher is not None: + await publisher.close() + + purged_messages_queue = await management.purge_queue(queue_name) + await management.unbind(bind_name) + await management.delete_exchange(exchange_name) + await management.delete_queue(queue_name) + await management.close() + + assert accepted is True + assert accepted_2 is True + assert purged_messages_queue == 2 + assert raised is False + + +@pytest.mark.asyncio +async def test_multiple_publishers_async(async_environment: AsyncEnvironment) -> None: + stream_name = "test_multiple_publisher_1_async" + stream_name_2 = "test_multiple_publisher_2_async" + connection = await async_environment.connection() + await connection.dial() + + stream_specification = StreamSpecification(name=stream_name) + management = await connection.management() + await management.declare_queue(stream_specification) + + stream_specification_2 = StreamSpecification(name=stream_name_2) + await management.declare_queue(stream_specification_2) + + destination = AddressHelper.queue_address(stream_name) + destination_2 = AddressHelper.queue_address(stream_name_2) + + await connection.publisher(destination) + assert connection.active_producers == 1 + + publisher_2 = await connection.publisher(destination_2) + assert connection.active_producers == 2 + + await publisher_2.close() + assert connection.active_producers == 1 + + await connection.publisher(destination_2) + assert connection.active_producers == 2 + + await connection.close() + assert connection.active_producers == 0 + + # cleanup + connection = await async_environment.connection() + await connection.dial() + management = await connection.management() + + await management.delete_queue(stream_name) + await management.delete_queue(stream_name_2) + await management.close() + + +@pytest.mark.asyncio +async def test_durable_message_async(async_connection: AsyncConnection) -> None: + queue_name = "test_durable_message_async" + + management = await async_connection.management() + await management.declare_queue(QuorumQueueSpecification(name=queue_name)) + + destination = AddressHelper.queue_address(queue_name) + publisher = await async_connection.publisher(destination) + + # Message should be durable by default + status = await publisher.publish(Message(body=Converter.string_to_bytes("durable"))) + assert status.remote_state == OutcomeState.ACCEPTED + + consumer = await async_connection.consumer(destination) + should_be_durable = await consumer.consume() + assert should_be_durable.durable is True + + await consumer.close() + await publisher.close() + await management.purge_queue(queue_name) + await management.delete_queue(queue_name) + await management.close() + + +@pytest.mark.asyncio +async def test_concurrent_publishing_async(async_connection: AsyncConnection) -> None: + queue_name = "test-concurrent-async" + messages_to_publish = 100 + + management = await async_connection.management() + await management.declare_queue(QuorumQueueSpecification(name=queue_name)) + + publisher = await async_connection.publisher( + destination=AddressHelper.queue_address(queue_name) + ) + + async def publish_message(i: int): + message = Message(body=Converter.string_to_bytes(f"message-{i}")) + status = await publisher.publish(message) + return status.remote_state == OutcomeState.ACCEPTED + + # Run concurrent publishes + results = await asyncio.gather( + *[publish_message(i) for i in range(messages_to_publish)] + ) + + assert all(results) + + await publisher.close() + + message_count = await management.purge_queue(queue_name) + assert message_count == messages_to_publish + + await management.delete_queue(queue_name) + await management.close() + + +@pytest.mark.asyncio +async def test_concurrent_publishing_stream_async( + async_connection: AsyncConnection, +) -> None: + stream_name = "test-concurrent-stream-async" + messages_to_publish = 100 + + management = await async_connection.management() + await management.declare_queue(StreamSpecification(name=stream_name)) + + publisher = await async_connection.publisher( + destination=AddressHelper.queue_address(stream_name) + ) + + async def publish_message(i: int): + message = Message(body=Converter.string_to_bytes(f"message-{i}")) + status = await publisher.publish(message) + return status.remote_state == OutcomeState.ACCEPTED + + # Run concurrent publishes + results = await asyncio.gather( + *[publish_message(i) for i in range(messages_to_publish)] + ) + + assert all(results) + + await publisher.close() + await management.delete_queue(stream_name) + await management.close() + + +@pytest.mark.asyncio +async def test_publisher_context_manager_async( + async_connection: AsyncConnection, +) -> None: + queue_name = "test-context-manager-async" + + management = await async_connection.management() + await management.declare_queue(QuorumQueueSpecification(name=queue_name)) + + async with await async_connection.publisher( + destination=AddressHelper.queue_address(queue_name) + ) as publisher: + status = await publisher.publish( + Message(body=Converter.string_to_bytes("test")) + ) + assert status.remote_state == OutcomeState.ACCEPTED + assert publisher.is_open is True + + # Publisher should be closed after context manager exits + assert publisher.is_open is False + + message_count = await management.purge_queue(queue_name) + assert message_count == 1 + + await management.delete_queue(queue_name) + await management.close() + + +@pytest.mark.asyncio +async def test_connection_context_manager_async( + async_environment: AsyncEnvironment, +) -> None: + queue_name = "test-connection-context-async" + + async with await async_environment.connection() as connection: + management = await connection.management() + await management.declare_queue(QuorumQueueSpecification(name=queue_name)) + + publisher = await connection.publisher( + destination=AddressHelper.queue_address(queue_name) + ) + status = await publisher.publish( + Message(body=Converter.string_to_bytes("test")) + ) + assert status.remote_state == OutcomeState.ACCEPTED + + await management.delete_queue(queue_name) + + +@pytest.mark.asyncio +async def test_nested_context_manager_async( + async_environment: AsyncEnvironment, +) -> None: + queue_name = "test-nested-context-async" + + async with await async_environment.connection() as connection: + management = await connection.management() + await management.declare_queue(QuorumQueueSpecification(name=queue_name)) + + async with await connection.publisher( + destination=AddressHelper.queue_address(queue_name) + ) as publisher: + status = await publisher.publish( + Message(body=Converter.string_to_bytes("test")) + ) + assert status.remote_state == OutcomeState.ACCEPTED + assert publisher.is_open is True + + assert publisher.is_open is False + + message_count = await management.purge_queue(queue_name) + assert message_count == 1 + + await management.delete_queue(queue_name) + + +@pytest.mark.asyncio +async def test_multiple_publishers_concurrent_async( + async_connection: AsyncConnection, +) -> None: + queue_name_1 = "test-multi-pub-1-async" + queue_name_2 = "test-multi-pub-2-async" + messages_per_publisher = 100 + + management = await async_connection.management() + await management.declare_queue(QuorumQueueSpecification(name=queue_name_1)) + await management.declare_queue(QuorumQueueSpecification(name=queue_name_2)) + + publisher1 = await async_connection.publisher( + destination=AddressHelper.queue_address(queue_name_1) + ) + publisher2 = await async_connection.publisher( + destination=AddressHelper.queue_address(queue_name_2) + ) + + async def publish_message(publisher: AsyncPublisher, i: int) -> bool: + message = Message(body=Converter.string_to_bytes(f"message-{i}")) + status = await publisher.publish(message) + return status.remote_state == OutcomeState.ACCEPTED + + async def publish_to_queue(publisher: AsyncPublisher, count: int) -> list[bool]: + return await asyncio.gather( + *[publish_message(publisher, i) for i in range(count)] + ) + + # Publish concurrently to both queues + results1, results2 = await asyncio.gather( + publish_to_queue(publisher1, messages_per_publisher), + publish_to_queue(publisher2, messages_per_publisher), + ) + + assert all(results1) + assert all(results2) + + await publisher1.close() + await publisher2.close() + + # Verify message counts + count1 = await management.purge_queue(queue_name_1) + count2 = await management.purge_queue(queue_name_2) + + assert count1 == messages_per_publisher + assert count2 == messages_per_publisher + + await management.delete_queue(queue_name_1) + await management.delete_queue(queue_name_2) + await management.close() diff --git a/tests/asyncio/test_streams.py b/tests/asyncio/test_streams.py new file mode 100644 index 0000000..0e30c35 --- /dev/null +++ b/tests/asyncio/test_streams.py @@ -0,0 +1,725 @@ +import pytest + +from rabbitmq_amqp_python_client import ( + AddressHelper, + AMQPMessagingHandler, + AsyncConnection, + AsyncEnvironment, + Converter, + Message, + OffsetSpecification, + StreamConsumerOptions, + StreamSpecification, + ValidationCodeException, +) +from rabbitmq_amqp_python_client.entities import ( + MessageProperties, + StreamFilterOptions, +) +from rabbitmq_amqp_python_client.qpid.proton import Event + +from ..conftest import ( + ConsumerTestException, + MyMessageHandlerAcceptStreamOffset, + MyMessageHandlerAcceptStreamOffsetReconnect, +) +from .fixtures import * # noqa: F401, F403 +from .utils import async_publish_messages + + +@pytest.mark.asyncio +async def test_stream_read_from_last_default_async( + async_connection: AsyncConnection, async_environment: AsyncEnvironment +) -> None: + consumer = None + stream_name = "test_stream_info_with_validation_async" + messages_to_send = 10 + + queue_specification = StreamSpecification(name=stream_name) + management = await async_connection.management() + await management.declare_queue(queue_specification) + + addr_queue = AddressHelper.queue_address(stream_name) + + # consume and then publish + try: + connection_consumer = await async_environment.connection() + await connection_consumer.dial() + consumer = await connection_consumer.consumer( + addr_queue, message_handler=MyMessageHandlerAcceptStreamOffset() + ) + await async_publish_messages(async_connection, messages_to_send, stream_name) + await consumer.run() + # ack to terminate the consumer + except ConsumerTestException: + pass + finally: + if consumer is not None: + await consumer.close() + await management.delete_queue(stream_name) + await management.close() + + +@pytest.mark.asyncio +async def test_stream_read_from_last_async( + async_connection: AsyncConnection, async_environment: AsyncEnvironment +) -> None: + consumer = None + stream_name = "test_stream_read_from_last_async" + messages_to_send = 10 + + queue_specification = StreamSpecification(name=stream_name) + management = await async_connection.management() + await management.declare_queue(queue_specification) + + addr_queue = AddressHelper.queue_address(stream_name) + + # consume and then publish + try: + connection_consumer = await async_environment.connection() + await connection_consumer.dial() + consumer = await connection_consumer.consumer( + addr_queue, + message_handler=MyMessageHandlerAcceptStreamOffset(), + consumer_options=StreamConsumerOptions( + offset_specification=OffsetSpecification.last + ), + ) + await async_publish_messages(async_connection, messages_to_send, stream_name) + await consumer.run() + # ack to terminate the consumer + except ConsumerTestException: + pass + finally: + if consumer is not None: + await consumer.close() + await management.delete_queue(stream_name) + await management.close() + + +@pytest.mark.asyncio +async def test_stream_read_from_offset_zero_async( + async_connection: AsyncConnection, async_environment: AsyncEnvironment +) -> None: + consumer = None + stream_name = "test_stream_read_from_offset_zero_async" + messages_to_send = 10 + + queue_specification = StreamSpecification(name=stream_name) + management = await async_connection.management() + await management.declare_queue(queue_specification) + + addr_queue = AddressHelper.queue_address(stream_name) + + # publish and then consume + await async_publish_messages(async_connection, messages_to_send, stream_name) + + try: + connection_consumer = await async_environment.connection() + await connection_consumer.dial() + consumer = await connection_consumer.consumer( + addr_queue, + message_handler=MyMessageHandlerAcceptStreamOffset(0), + consumer_options=StreamConsumerOptions(offset_specification=0), + ) + await consumer.run() + # ack to terminate the consumer + except ConsumerTestException: + pass + finally: + if consumer is not None: + await consumer.close() + await management.delete_queue(stream_name) + await management.close() + + +@pytest.mark.asyncio +async def test_stream_read_from_offset_first_async( + async_connection: AsyncConnection, async_environment: AsyncEnvironment +) -> None: + consumer = None + stream_name = "test_stream_read_from_offset_first_async" + messages_to_send = 10 + + queue_specification = StreamSpecification(name=stream_name) + management = await async_connection.management() + await management.declare_queue(queue_specification) + + addr_queue = AddressHelper.queue_address(stream_name) + + # publish and then consume + await async_publish_messages(async_connection, messages_to_send, stream_name) + + try: + connection_consumer = await async_environment.connection() + await connection_consumer.dial() + consumer = await connection_consumer.consumer( + addr_queue, + message_handler=MyMessageHandlerAcceptStreamOffset(0), + consumer_options=StreamConsumerOptions(OffsetSpecification.first), + ) + await consumer.run() + # ack to terminate the consumer + except ConsumerTestException: + pass + finally: + if consumer is not None: + await consumer.close() + await management.delete_queue(stream_name) + await management.close() + + +@pytest.mark.asyncio +async def test_stream_read_from_offset_ten_async( + async_connection: AsyncConnection, async_environment: AsyncEnvironment +) -> None: + consumer = None + stream_name = "test_stream_read_from_offset_ten_async" + messages_to_send = 20 + + queue_specification = StreamSpecification(name=stream_name) + management = await async_connection.management() + await management.declare_queue(queue_specification) + + addr_queue = AddressHelper.queue_address(stream_name) + + # publish and then consume + await async_publish_messages(async_connection, messages_to_send, stream_name) + + try: + connection_consumer = await async_environment.connection() + await connection_consumer.dial() + consumer = await connection_consumer.consumer( + addr_queue, + message_handler=MyMessageHandlerAcceptStreamOffset(10), + consumer_options=StreamConsumerOptions(offset_specification=10), + ) + await consumer.run() + # ack to terminate the consumer + # this will finish after 10 messages read + except ConsumerTestException: + pass + finally: + if consumer is not None: + await consumer.close() + await management.delete_queue(stream_name) + await management.close() + + +@pytest.mark.asyncio +async def test_stream_filtering_async( + async_connection: AsyncConnection, async_environment: AsyncEnvironment +) -> None: + consumer = None + stream_name = "test_stream_info_with_filtering_async" + messages_to_send = 10 + + queue_specification = StreamSpecification(name=stream_name) + management = await async_connection.management() + await management.declare_queue(queue_specification) + + addr_queue = AddressHelper.queue_address(stream_name) + + # consume and then publish + try: + connection_consumer = await async_environment.connection() + await connection_consumer.dial() + consumer = await connection_consumer.consumer( + addr_queue, + message_handler=MyMessageHandlerAcceptStreamOffset(), + consumer_options=StreamConsumerOptions( + filter_options=StreamFilterOptions(values=["banana"]) + ), + ) + + # send with annotations filter banana + await async_publish_messages( + async_connection, + messages_to_send, + stream_name, + ["banana"], + ) + + await consumer.run() + # ack to terminate the consumer + except ConsumerTestException: + pass + finally: + if consumer is not None: + await consumer.close() + await management.delete_queue(stream_name) + await management.close() + + +@pytest.mark.asyncio +async def test_stream_filtering_mixed_async( + async_connection: AsyncConnection, async_environment: AsyncEnvironment +) -> None: + consumer = None + stream_name = "test_stream_info_with_filtering_mixed_async" + messages_to_send = 10 + + queue_specification = StreamSpecification(name=stream_name) + management = await async_connection.management() + await management.declare_queue(queue_specification) + + addr_queue = AddressHelper.queue_address(stream_name) + + # consume and then publish + try: + connection_consumer = await async_environment.connection() + await connection_consumer.dial() + consumer = await connection_consumer.consumer( + addr_queue, + # check we are reading just from offset 10 as just banana filtering applies + message_handler=MyMessageHandlerAcceptStreamOffset(10), + consumer_options=StreamConsumerOptions( + filter_options=StreamFilterOptions(values=["banana"]) + ), + ) + + # consume and then publish + await async_publish_messages( + async_connection, + messages_to_send, + stream_name, + ["apple"], + ) + await async_publish_messages( + async_connection, + messages_to_send, + stream_name, + ["banana"], + ) + + await consumer.run() + # ack to terminate the consumer + except ConsumerTestException: + pass + finally: + if consumer is not None: + await consumer.close() + await management.delete_queue(stream_name) + await management.close() + + +@pytest.mark.asyncio +async def test_stream_filtering_not_present_async( + async_connection: AsyncConnection, async_environment: AsyncEnvironment +) -> None: + stream_name = "test_stream_filtering_not_present_async" + messages_to_send = 10 + + queue_specification = StreamSpecification(name=stream_name) + management = await async_connection.management() + await management.declare_queue(queue_specification) + + addr_queue = AddressHelper.queue_address(stream_name) + + connection_consumer = await async_environment.connection() + await connection_consumer.dial() + + consumer = await connection_consumer.consumer( + addr_queue, + consumer_options=StreamConsumerOptions( + filter_options=StreamFilterOptions(values=["apple"]) + ), + ) + + # send with annotations filter banana + await async_publish_messages( + async_connection, + messages_to_send, + stream_name, + ["banana"], + ) + + with pytest.raises(Exception): + await consumer.consume(timeout=1) + + await consumer.close() + await management.delete_queue(stream_name) + await management.close() + + +@pytest.mark.asyncio +async def test_stream_match_unfiltered_async( + async_connection: AsyncConnection, async_environment: AsyncEnvironment +) -> None: + consumer = None + stream_name = "test_stream_match_unfiltered_async" + messages_to_send = 10 + + queue_specification = StreamSpecification(name=stream_name) + management = await async_connection.management() + await management.declare_queue(queue_specification) + + addr_queue = AddressHelper.queue_address(stream_name) + + # consume and then publish + try: + connection_consumer = await async_environment.connection() + await connection_consumer.dial() + consumer = await connection_consumer.consumer( + addr_queue, + message_handler=MyMessageHandlerAcceptStreamOffset(), + consumer_options=StreamConsumerOptions( + filter_options=StreamFilterOptions( + values=["banana"], match_unfiltered=True + ) + ), + ) + + # unfiltered messages + await async_publish_messages(async_connection, messages_to_send, stream_name) + await consumer.run() + # ack to terminate the consumer + except ConsumerTestException: + pass + finally: + if consumer is not None: + await consumer.close() + await management.delete_queue(stream_name) + await management.close() + + +@pytest.mark.asyncio +async def test_stream_reconnection_async( + async_connection_with_reconnect: AsyncConnection, + async_environment: AsyncEnvironment, +) -> None: + consumer = None + stream_name = "test_stream_reconnection_async" + messages_to_send = 10 + + queue_specification = StreamSpecification(name=stream_name) + management = await async_connection_with_reconnect.management() + await management.declare_queue(queue_specification) + + addr_queue = AddressHelper.queue_address(stream_name) + + # consume and then publish + try: + connection_consumer = await async_environment.connection() + await connection_consumer.dial() + consumer = await connection_consumer.consumer( + addr_queue, + # disconnection and check happens here + message_handler=MyMessageHandlerAcceptStreamOffsetReconnect(), + consumer_options=StreamConsumerOptions( + filter_options=StreamFilterOptions( + values=["banana"], match_unfiltered=True + ) + ), + ) + + await async_publish_messages( + async_connection_with_reconnect, messages_to_send, stream_name + ) + + await consumer.run() + # ack to terminate the consumer + except ConsumerTestException: + pass + finally: + if consumer is not None: + await consumer.close() + await management.delete_queue(stream_name) + await management.close() + + +class MyMessageHandlerMessagePropertiesFilter(AMQPMessagingHandler): + def __init__(self): + super().__init__() + + def on_message(self, event: Event): + self.delivery_context.accept(event) + assert event.message.subject == "important_15" + assert event.message.group_id == "group_15" + assert event.message.body == Converter.string_to_bytes("hello_15") + raise ConsumerTestException("consumed") + + +@pytest.mark.asyncio +async def test_stream_filter_message_properties_async( + async_connection: AsyncConnection, async_environment: AsyncEnvironment +) -> None: + consumer = None + stream_name = "test_stream_filter_message_properties_async" + messages_to_send = 30 + + queue_specification = StreamSpecification(name=stream_name) + management = await async_connection.management() + await management.declare_queue(queue_specification) + + addr_queue = AddressHelper.queue_address(stream_name) + + try: + connection_consumer = await async_environment.connection() + await connection_consumer.dial() + consumer = await connection_consumer.consumer( + addr_queue, + message_handler=MyMessageHandlerMessagePropertiesFilter(), + consumer_options=StreamConsumerOptions( + filter_options=StreamFilterOptions( + message_properties=MessageProperties( + subject="important_15", group_id="group_15" + ) + ) + ), + ) + + publisher = await async_connection.publisher(addr_queue) + for i in range(messages_to_send): + msg = Message( + body=Converter.string_to_bytes(f"hello_{i}"), + subject=f"important_{i}", + group_id=f"group_{i}", + ) + await publisher.publish(msg) + await publisher.close() + + await consumer.run() + except ConsumerTestException: + pass + finally: + if consumer is not None: + await consumer.close() + await management.delete_queue(stream_name) + await management.close() + + +class MyMessageHandlerApplicationPropertiesFilter(AMQPMessagingHandler): + def __init__(self): + super().__init__() + + def on_message(self, event: Event): + self.delivery_context.accept(event) + assert event.message.application_properties == {"key": "value_17"} + raise ConsumerTestException("consumed") + + +@pytest.mark.asyncio +async def test_stream_filter_application_properties_async( + async_connection: AsyncConnection, async_environment: AsyncEnvironment +) -> None: + consumer = None + stream_name = "test_stream_application_message_properties_async" + messages_to_send = 30 + + queue_specification = StreamSpecification(name=stream_name) + management = await async_connection.management() + await management.declare_queue(queue_specification) + + addr_queue = AddressHelper.queue_address(stream_name) + + try: + connection_consumer = await async_environment.connection() + await connection_consumer.dial() + consumer = await connection_consumer.consumer( + addr_queue, + message_handler=MyMessageHandlerApplicationPropertiesFilter(), + consumer_options=StreamConsumerOptions( + filter_options=StreamFilterOptions( + application_properties={"key": "value_17"}, + ) + ), + ) + + publisher = await async_connection.publisher(addr_queue) + for i in range(messages_to_send): + msg = Message( + body=Converter.string_to_bytes(f"hello_{i}"), + application_properties={"key": f"value_{i}"}, + ) + await publisher.publish(msg) + await publisher.close() + + await consumer.run() + except ConsumerTestException: + pass + finally: + if consumer is not None: + await consumer.close() + await management.delete_queue(stream_name) + await management.close() + + +class MyMessageHandlerSQLFilter(AMQPMessagingHandler): + def __init__(self): + super().__init__() + + def on_message(self, event: Event): + self.delivery_context.accept(event) + assert event.message.body == Converter.string_to_bytes("the_right_one_sql") + assert event.message.subject == "something_in_the_filter" + assert event.message.reply_to == "the_reply_to" + assert ( + event.message.application_properties["a_in_the_filter_key"] + == "a_in_the_filter_value" + ) + raise ConsumerTestException("consumed") + + +@pytest.mark.asyncio +async def test_stream_filter_sql_async( + async_connection: AsyncConnection, async_environment: AsyncEnvironment +) -> None: + consumer = None + stream_name = "test_stream_filter_sql_async" + messages_to_send = 30 + + queue_specification = StreamSpecification(name=stream_name) + management = await async_connection.management() + await management.delete_queue(stream_name) + await management.declare_queue(queue_specification) + + addr_queue = AddressHelper.queue_address(stream_name) + sql = ( + "properties.subject LIKE '%in_the_filter%' AND properties.reply_to = 'the_reply_to' " + "AND a_in_the_filter_key = 'a_in_the_filter_value'" + ) + + try: + connection_consumer = await async_environment.connection() + await connection_consumer.dial() + consumer = await connection_consumer.consumer( + addr_queue, + message_handler=MyMessageHandlerSQLFilter(), + consumer_options=StreamConsumerOptions( + filter_options=StreamFilterOptions(sql=sql) + ), + ) + + publisher = await async_connection.publisher(addr_queue) + + # Won't match + for i in range(messages_to_send): + msg = Message(body=Converter.string_to_bytes(f"hello_{i}")) + await publisher.publish(msg) + + # The only one that will match + msg_match = Message( + body=Converter.string_to_bytes("the_right_one_sql"), + subject="something_in_the_filter", + reply_to="the_reply_to", + application_properties={"a_in_the_filter_key": "a_in_the_filter_value"}, + ) + await publisher.publish(msg_match) + await publisher.close() + + await consumer.run() + except ConsumerTestException: + pass + finally: + if consumer is not None: + await consumer.close() + await management.delete_queue(stream_name) + await management.close() + + +class MyMessageHandlerMixingDifferentFilters(AMQPMessagingHandler): + def __init__(self): + super().__init__() + + def on_message(self, event: Event): + self.delivery_context.accept(event) + assert event.message.annotations["x-stream-filter-value"] == "the_value_filter" + assert event.message.application_properties == {"key": "app_value_9999"} + assert event.message.subject == "important_9999" + assert event.message.body == Converter.string_to_bytes("the_right_one_9999") + raise ConsumerTestException("consumed") + + +@pytest.mark.asyncio +async def test_stream_filter_mixing_different_async( + async_connection: AsyncConnection, async_environment: AsyncEnvironment +) -> None: + consumer = None + stream_name = "test_stream_filter_mixing_different_async" + messages_to_send = 30 + + queue_specification = StreamSpecification(name=stream_name) + management = await async_connection.management() + await management.delete_queue(stream_name) + await management.declare_queue(queue_specification) + + addr_queue = AddressHelper.queue_address(stream_name) + + try: + connection_consumer = await async_environment.connection() + await connection_consumer.dial() + consumer = await connection_consumer.consumer( + addr_queue, + message_handler=MyMessageHandlerMixingDifferentFilters(), + consumer_options=StreamConsumerOptions( + filter_options=StreamFilterOptions( + values=["the_value_filter"], + application_properties={"key": "app_value_9999"}, + message_properties=MessageProperties(subject="important_9999"), + ) + ), + ) + + publisher = await async_connection.publisher(addr_queue) + + # All these messages will be filtered out + for i in range(messages_to_send): + msg = Message(body=Converter.string_to_bytes(f"hello_{i}")) + await publisher.publish(msg) + + import asyncio + + await asyncio.sleep( + 1 + ) # Wait to ensure messages are published in different chunks + + msg = Message( + body=Converter.string_to_bytes("the_right_one_9999"), + annotations={"x-stream-filter-value": "the_value_filter"}, + application_properties={"key": "app_value_9999"}, + subject="important_9999", + ) + await publisher.publish(msg) + await publisher.close() + + await consumer.run() + except ConsumerTestException: + pass + finally: + if consumer is not None: + await consumer.close() + await management.delete_queue(stream_name) + await management.close() + + +@pytest.mark.asyncio +async def test_consumer_options_validation_async() -> None: + try: + x = StreamConsumerOptions(filter_options=StreamFilterOptions(sql="test")) + x.validate({"4.0.0": True, "4.1.0": False, "4.2.0": False}) + assert False + except ValidationCodeException: + assert True + + try: + x = StreamConsumerOptions( + filter_options=StreamFilterOptions( + message_properties=MessageProperties(subject="important_9999") + ) + ) + x.validate({"4.0.0": True, "4.1.0": True, "4.2.0": False}) + assert True + except ValidationCodeException: + assert False + + try: + x = StreamConsumerOptions( + filter_options=StreamFilterOptions( + application_properties={"key": "app_value_9999"} + ) + ) + x.validate({"4.0.0": True, "4.1.0": True, "4.2.0": False}) + assert True + except ValidationCodeException: + assert False diff --git a/tests/asyncio/utils.py b/tests/asyncio/utils.py new file mode 100644 index 0000000..467685f --- /dev/null +++ b/tests/asyncio/utils.py @@ -0,0 +1,81 @@ +from typing import Optional + +from rabbitmq_amqp_python_client import ( + AddressHelper, + AsyncConnection, + AsyncManagement, + AsyncPublisher, + Delivery, + ExchangeSpecification, + ExchangeToQueueBindingSpecification, + ExchangeType, + Message, + QuorumQueueSpecification, +) +from rabbitmq_amqp_python_client.utils import Converter + + +async def async_publish_per_message(publisher: AsyncPublisher, addr: str) -> Delivery: + message = Message(body=Converter.string_to_bytes("test")) + message = AddressHelper.message_to_address_helper(message, addr) + status = await publisher.publish(message) + return status + + +async def async_publish_messages( + connection: AsyncConnection, + messages_to_send: int, + queue_name: str, + filters: Optional[list[str]] = None, +) -> None: + annotations = {} + if filters is not None: + for filterItem in filters: + annotations = {"x-stream-filter-value": filterItem} + + publisher = await connection.publisher("/queues/" + queue_name) + # publish messages_to_send messages + for i in range(messages_to_send): + await publisher.publish( + Message( + body=Converter.string_to_bytes("test{}".format(i)), + annotations=annotations, + ) + ) + await publisher.close() + + +async def async_setup_dead_lettering(management: AsyncManagement) -> str: + exchange_dead_lettering = "exchange-dead-letter" + queue_dead_lettering = "queue-dead-letter" + binding_key = "key_dead_letter" + + # configuring dead lettering + await management.declare_exchange( + ExchangeSpecification( + name=exchange_dead_lettering, + exchange_type=ExchangeType.fanout, + arguments={}, + ) + ) + await management.declare_queue(QuorumQueueSpecification(name=queue_dead_lettering)) + bind_path = await management.bind( + ExchangeToQueueBindingSpecification( + source_exchange=exchange_dead_lettering, + destination_queue=queue_dead_lettering, + binding_key=binding_key, + ) + ) + + return bind_path + + +async def async_cleanup_dead_lettering( + management: AsyncManagement, bind_path: str +) -> None: + exchange_dead_lettering = "exchange-dead-letter" + queue_dead_lettering = "queue-dead-letter" + + await management.unbind(bind_path) + await management.delete_exchange(exchange_dead_lettering) + await management.delete_queue(queue_dead_lettering)