Skip to content

Commit a0ee953

Browse files
committed
better handling of asyncio.Lock and event loop, add docstrings
1 parent e3a397e commit a0ee953

File tree

5 files changed

+631
-97
lines changed

5 files changed

+631
-97
lines changed

rabbitmq_amqp_python_client/asyncio/connection.py

Lines changed: 106 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,25 @@
2828

2929

3030
class AsyncConnection:
31-
"""Asyncio-compatible facade around Connection."""
31+
"""
32+
Asyncio-compatible facade around Connection.
33+
34+
This class manages the connection to RabbitMQ and provides factory methods for
35+
creating publishers, consumers, and management interfaces. It supports both
36+
single-node and multi-node configurations, as well as SSL/TLS connections.
37+
38+
Note:
39+
The underlying Proton BlockingConnection is NOT thread-safe. A lock is used
40+
to serialize all operations.
41+
42+
Attributes:
43+
_connection (Connection): The underlying synchronous Connection
44+
_connection_lock (asyncio.Lock): Lock for coordinating access to the shared connection
45+
_async_publishers (list[AsyncPublisher]): List of active async publishers
46+
_async_consumers (list[AsyncConsumer]): List of active async consumers
47+
_async_managements (list[AsyncManagement]): List of active async management interfaces
48+
_remove_callback (Optional[Callable[[AsyncConnection], None]]): Callback on close
49+
"""
3250

3351
def __init__(
3452
self,
@@ -39,45 +57,70 @@ def __init__(
3957
] = None,
4058
oauth2_options: Optional[OAuth2Options] = None,
4159
recovery_configuration: RecoveryConfiguration = RecoveryConfiguration(),
42-
*,
43-
loop: Optional[asyncio.AbstractEventLoop] = None,
4460
):
61+
"""
62+
Initialize AsyncConnection.
63+
64+
Args:
65+
uri: Optional single-node connection URI
66+
uris: Optional multi-node connection URIs
67+
ssl_context: Optional SSL/TLS configuration
68+
oauth2_options: Optional OAuth2 configuration
69+
recovery_configuration: Configuration for automatic recovery
70+
71+
Raises:
72+
ValidationCodeException: If recovery configuration is invalid
73+
"""
4574
self._connection = Connection(
4675
uri=uri,
4776
uris=uris,
4877
ssl_context=ssl_context,
4978
oauth2_options=oauth2_options,
5079
recovery_configuration=recovery_configuration,
5180
)
52-
self._loop = loop
5381
self._connection_lock = asyncio.Lock()
5482
self._async_publishers: list[AsyncPublisher] = []
5583
self._async_consumers: list[AsyncConsumer] = []
5684
self._async_managements: list[AsyncManagement] = []
5785
self._remove_callback: Optional[Callable[[AsyncConnection], None]] = None
5886

5987
async def dial(self) -> None:
88+
"""
89+
Establish a connection to the AMQP server.
90+
91+
Configures SSL if specified and establishes the connection using the
92+
provided URI(s). Also initializes the management interface.
93+
"""
6094
async with self._connection_lock:
61-
await self._event_loop.run_in_executor(None, self._connection.dial)
95+
await asyncio.to_thread(self._connection.dial)
6296

6397
def _set_remove_callback(
6498
self, callback: Optional[Callable[[AsyncConnection], None]]
6599
) -> None:
100+
"""Set callback to be called when connection is closed."""
66101
self._remove_callback = callback
67102

68103
def _remove_publisher(self, publisher: AsyncPublisher) -> None:
104+
"""Remove a publisher from the active list."""
69105
if publisher in self._async_publishers:
70106
self._async_publishers.remove(publisher)
71107

72108
def _remove_consumer(self, consumer: AsyncConsumer) -> None:
109+
"""Remove a consumer from the active list."""
73110
if consumer in self._async_consumers:
74111
self._async_consumers.remove(consumer)
75112

76113
def _remove_management(self, management: AsyncManagement) -> None:
114+
"""Remove a management interface from the active list."""
77115
if management in self._async_managements:
78116
self._async_managements.remove(management)
79117

80118
async def close(self) -> None:
119+
"""
120+
Close the connection to the AMQP 1.0 server.
121+
122+
Closes the underlying connection and removes it from the connection list.
123+
"""
81124
logger.debug("Closing async connection")
82125
try:
83126
for async_publisher in self._async_publishers[:]:
@@ -88,7 +131,7 @@ async def close(self) -> None:
88131
await async_management.close()
89132

90133
async with self._connection_lock:
91-
await self._event_loop.run_in_executor(None, self._connection.close)
134+
await asyncio.to_thread(self._connection.close)
92135
except Exception as e:
93136
logger.error(f"Error closing async connections: {e}")
94137
raise e
@@ -103,6 +146,12 @@ def _set_connection_managements(self, management: Management) -> None:
103146
async def management(
104147
self,
105148
) -> AsyncManagement:
149+
"""
150+
Get the management interface for this connection.
151+
152+
Returns:
153+
AsyncManagement: The management interface for performing administrative tasks
154+
"""
106155
if len(self._async_managements) > 0:
107156
return self._async_managements[0]
108157

@@ -111,7 +160,6 @@ async def management(
111160
if len(self._async_managements) == 0:
112161
async_management = AsyncManagement(
113162
self._connection._conn,
114-
loop=self._event_loop,
115163
connection_lock=self._connection_lock,
116164
)
117165

@@ -128,12 +176,26 @@ async def management(
128176
return self._async_managements[0]
129177

130178
def _set_connection_publishers(self, publisher: Publisher) -> None:
179+
"""Set the list of publishers in the underlying connection."""
131180
publisher._set_publishers_list(
132181
[async_publisher._publisher for async_publisher in self._async_publishers]
133182
)
134183
self._connection._publishers.append(publisher)
135184

136185
async def publisher(self, destination: str = "") -> AsyncPublisher:
186+
"""
187+
Create a new publisher instance.
188+
189+
Args:
190+
destination: Optional default destination for published messages
191+
192+
Returns:
193+
AsyncPublisher: A new publisher instance
194+
195+
Raises:
196+
RuntimeError: If publisher creation fails
197+
ArgumentOutOfRangeException: If destination address format is invalid
198+
"""
137199
if destination != "":
138200
if not validate_address(destination):
139201
raise ArgumentOutOfRangeException(
@@ -144,9 +206,11 @@ async def publisher(self, destination: str = "") -> AsyncPublisher:
144206
async_publisher = AsyncPublisher(
145207
self._connection._conn,
146208
destination,
147-
loop=self._event_loop,
148209
connection_lock=self._connection_lock,
149210
)
211+
await async_publisher.open()
212+
if async_publisher._publisher is None:
213+
raise RuntimeError("Failed to create publisher")
150214
self._set_connection_publishers(
151215
async_publisher._publisher
152216
) # TODO: check this
@@ -157,6 +221,7 @@ async def publisher(self, destination: str = "") -> AsyncPublisher:
157221
return async_publisher
158222

159223
def _set_connection_consumers(self, consumer: Consumer) -> None:
224+
"""Set the list of consumers in the underlying connection."""
160225
self._connection._consumers.append(consumer)
161226

162227
async def consumer(
@@ -166,6 +231,22 @@ async def consumer(
166231
consumer_options: Optional[ConsumerOptions] = None,
167232
credit: Optional[int] = None,
168233
) -> AsyncConsumer:
234+
"""
235+
Create a new consumer instance.
236+
237+
Args:
238+
destination: The address to consume from
239+
message_handler: Optional handler for processing messages
240+
consumer_options: Optional configuration for queue consumption. Each queue has its own consumer options.
241+
credit: Optional credit value for flow control
242+
243+
Returns:
244+
AsyncConsumer: A new consumer instance
245+
246+
Raises:
247+
RuntimeError: If consumer creation fails
248+
ArgumentOutOfRangeException: If destination address format is invalid
249+
"""
169250
if not validate_address(destination):
170251
raise ArgumentOutOfRangeException(
171252
"destination address must start with /queues or /exchanges"
@@ -186,9 +267,11 @@ async def consumer(
186267
message_handler, # pyright: ignore[reportArgumentType]
187268
consumer_options,
188269
credit,
189-
loop=self._event_loop,
190270
connection_lock=self._connection_lock,
191271
)
272+
await async_consumer.open()
273+
if async_consumer._consumer is None:
274+
raise RuntimeError("Failed to create consumer")
192275
self._set_connection_consumers(async_consumer._consumer) # TODO: check this
193276
self._async_consumers.append(async_consumer)
194277

@@ -197,14 +280,20 @@ async def consumer(
197280
return async_consumer
198281

199282
async def refresh_token(self, token: str) -> None:
283+
"""
284+
Refresh the oauth token
285+
286+
Args:
287+
token: the oauth token to refresh
288+
289+
Raises:
290+
ValidationCodeException: If oauth is not enabled
291+
"""
200292
async with self._connection_lock:
201-
await self._event_loop.run_in_executor(
202-
None,
203-
self._connection.refresh_token,
204-
token,
205-
)
293+
await asyncio.to_thread(self._connection.refresh_token, token)
206294

207295
async def __aenter__(self) -> AsyncConnection:
296+
""" "Async context manager entry."""
208297
await self.dial()
209298
return self
210299

@@ -214,16 +303,15 @@ async def __aexit__(
214303
exc: Optional[BaseException],
215304
tb: Optional[object],
216305
) -> None:
306+
"""Async context manager exit."""
217307
await self.close()
218308

219-
@property
220-
def _event_loop(self) -> asyncio.AbstractEventLoop:
221-
return self._loop or asyncio.get_running_loop()
222-
223309
@property
224310
def active_producers(self) -> int:
311+
"""Get the number of active producers of the connection."""
225312
return len(self._async_publishers)
226313

227314
@property
228315
def active_consumers(self) -> int:
316+
"""Get the number of active consumers of the connection."""
229317
return len(self._async_consumers)

0 commit comments

Comments
 (0)