Skip to content

Commit b0e264c

Browse files
committed
refactor consumer option
add an abstract class to define the consumer options. Each Queue type can define the consumer options type: Signed-off-by: Gabriele Santomaggio <[email protected]>
1 parent 42048b0 commit b0e264c

File tree

10 files changed

+136
-43
lines changed

10 files changed

+136
-43
lines changed

examples/streams/example_with_streams.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def main() -> None:
104104
message_handler=MyMessageHandler(),
105105
# can be first, last, next or an offset long
106106
# you can also specify stream filters with methods: apply_filters and filter_match_unfiltered
107-
stream_consumer_options=StreamConsumerOptions(
107+
consumer_options=StreamConsumerOptions(
108108
offset_specification=OffsetSpecification.first
109109
),
110110
)

examples/streams_with_filters/example_streams_with_filters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def main() -> None:
9191
message_handler=MyMessageHandler(),
9292
# the consumer will only receive messages with filter value banana and subject yellow
9393
# and application property from = italy
94-
stream_consumer_options=StreamConsumerOptions(
94+
consumer_options=StreamConsumerOptions(
9595
offset_specification=OffsetSpecification.first,
9696
filter_options=StreamFilterOptions(
9797
values=["banana"],

examples/streams_with_sql_filters/example_streams_with_sql_filters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def main() -> None:
8888
consumer = consumer_connection.consumer(
8989
addr_queue,
9090
message_handler=MyMessageHandler(),
91-
stream_consumer_options=StreamConsumerOptions(
91+
consumer_options=StreamConsumerOptions(
9292
offset_specification=OffsetSpecification.first,
9393
filter_options=StreamFilterOptions(sql=sql),
9494
),

rabbitmq_amqp_python_client/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .connection import Connection
77
from .consumer import Consumer
88
from .entities import (
9+
ConsumerOptions,
910
ExchangeCustomSpecification,
1011
ExchangeSpecification,
1112
ExchangeToExchangeBindingSpecification,
@@ -89,6 +90,7 @@
8990
"ConnectionClosed",
9091
"StreamConsumerOptions",
9192
"StreamFilterOptions",
93+
"ConsumerOptions",
9294
"MessageProperties",
9395
"OffsetSpecification",
9496
"OutcomeState",

rabbitmq_amqp_python_client/connection.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
from .address_helper import validate_address
1717
from .consumer import Consumer
1818
from .entities import (
19+
ConsumerOptions,
1920
OAuth2Options,
2021
RecoveryConfiguration,
21-
StreamConsumerOptions,
2222
)
2323
from .exceptions import (
2424
ArgumentOutOfRangeException,
@@ -211,15 +211,15 @@ def _validate_server_properties(self) -> None:
211211

212212
logger.debug(f"Connected to RabbitMQ server version {server_version}")
213213

214-
def _is_server_version_gte_4_2_0(self) -> bool:
214+
def _is_server_version_gte(self, target_version: str) -> bool:
215215
"""
216-
Check if the server version is greater than or equal to 4.2.0.
216+
Check if the server version is greater than or equal to version.
217217
218218
This is an internal method that can be used to conditionally enable
219-
features that require RabbitMQ 4.2.0 or higher.
219+
features that require RabbitMQ version or higher.
220220
221221
Returns:
222-
bool: True if server version >= 4.2.0, False otherwise
222+
bool: True if server version >= version, False otherwise
223223
224224
Raises:
225225
ValidationCodeException: If connection is not established or
@@ -237,7 +237,12 @@ def _is_server_version_gte_4_2_0(self) -> bool:
237237
raise ValidationCodeException("Server version not provided")
238238

239239
try:
240-
return version.parse(str(server_version)) >= version.parse("4.2.0")
240+
srv = version.parse(str(server_version))
241+
trg = version.parse(target_version)
242+
# compare the version even if it contains pre-release or build metadata
243+
return (
244+
version.parse("{}.{}.{}".format(srv.major, srv.minor, srv.micro)) >= trg
245+
)
241246
except Exception as e:
242247
raise ValidationCodeException(
243248
f"Failed to parse server version '{server_version}': {e}"
@@ -376,7 +381,7 @@ def consumer(
376381
self,
377382
destination: str,
378383
message_handler: Optional[MessagingHandler] = None,
379-
stream_consumer_options: Optional[StreamConsumerOptions] = None,
384+
consumer_options: Optional[ConsumerOptions] = None,
380385
credit: Optional[int] = None,
381386
) -> Consumer:
382387
"""
@@ -385,7 +390,7 @@ def consumer(
385390
Args:
386391
destination: The address to consume from
387392
message_handler: Optional handler for processing messages
388-
stream_consumer_options: Optional configuration for stream consumption
393+
consumer_options: Optional configuration for queue consumption. Each queue has its own consumer options.co
389394
credit: Optional credit value for flow control
390395
391396
Returns:
@@ -398,8 +403,16 @@ def consumer(
398403
raise ArgumentOutOfRangeException(
399404
"destination address must start with /queues or /exchanges"
400405
)
406+
if consumer_options is not None:
407+
consumer_options.validate(
408+
{
409+
"4.0.0": self._is_server_version_gte("4.0.0"),
410+
"4.1.0": self._is_server_version_gte("4.1.0"),
411+
"4.2.0": self._is_server_version_gte("4.2.0"),
412+
}
413+
)
401414
consumer = Consumer(
402-
self._conn, destination, message_handler, stream_consumer_options, credit
415+
self._conn, destination, message_handler, consumer_options, credit
403416
)
404417
self._consumers.append(consumer)
405418
return consumer

rabbitmq_amqp_python_client/consumer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Literal, Optional, Union, cast
33

44
from .amqp_consumer_handler import AMQPMessagingHandler
5-
from .entities import StreamConsumerOptions
5+
from .entities import ConsumerOptions
66
from .options import (
77
ReceiverOptionUnsettled,
88
ReceiverOptionUnsettledWithFilters,
@@ -38,7 +38,7 @@ def __init__(
3838
conn: BlockingConnection,
3939
addr: str,
4040
handler: Optional[AMQPMessagingHandler] = None,
41-
stream_options: Optional[StreamConsumerOptions] = None,
41+
stream_options: Optional[ConsumerOptions] = None,
4242
credit: Optional[int] = None,
4343
):
4444
"""

rabbitmq_amqp_python_client/entities.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,14 @@ class ExchangeToExchangeBindingSpecification:
153153
binding_key: Optional[str] = None
154154

155155

156+
class ConsumerOptions:
157+
def validate(self, versions: Dict[str, bool]) -> None:
158+
raise NotImplementedError("Subclasses should implement this method")
159+
160+
def filter_set(self) -> Dict[symbol, Described]:
161+
raise NotImplementedError("Subclasses should implement this method")
162+
163+
156164
@dataclass
157165
class MessageProperties:
158166
"""
@@ -215,7 +223,7 @@ def __init__(
215223
self.sql = sql
216224

217225

218-
class StreamConsumerOptions:
226+
class StreamConsumerOptions(ConsumerOptions):
219227
"""
220228
Configuration options for stream queues.
221229
@@ -237,6 +245,7 @@ def __init__(
237245
):
238246

239247
self._filter_set: Dict[symbol, Described] = {}
248+
self._filter_option = filter_options
240249

241250
if offset_specification is None and filter_options is None:
242251
raise ValidationCodeException(
@@ -329,7 +338,6 @@ def _filter_message_properties(
329338
def _filter_application_properties(
330339
self, application_properties: Optional[dict[str, Any]]
331340
) -> None:
332-
app_prop = {}
333341
if application_properties is not None:
334342
app_prop = application_properties.copy()
335343

@@ -356,6 +364,41 @@ def filter_set(self) -> Dict[symbol, Described]:
356364
"""
357365
return self._filter_set
358366

367+
def validate(self, versions: Dict[str, bool]) -> None:
368+
"""
369+
Validates stream filter options against supported RabbitMQ server versions.
370+
371+
Args:
372+
versions: Dictionary mapping version strings to boolean indicating support.
373+
374+
Raises:
375+
ValidationCodeException: If a filter option requires a higher RabbitMQ version.
376+
"""
377+
if self._filter_option is None:
378+
return
379+
if self._filter_option.values and not versions.get("4.1.0", False):
380+
raise ValidationCodeException(
381+
"Stream filter by values requires RabbitMQ 4.1.0 or higher"
382+
)
383+
if self._filter_option.match_unfiltered and not versions.get("4.1.0", False):
384+
raise ValidationCodeException(
385+
"Stream filter by match_unfiltered requires RabbitMQ 4.1.0 or higher"
386+
)
387+
if self._filter_option.sql and not versions.get("4.2.0", False):
388+
raise ValidationCodeException(
389+
"Stream filter by SQL requires RabbitMQ 4.2.0 or higher"
390+
)
391+
if self._filter_option.message_properties and not versions.get("4.1.0", False):
392+
raise ValidationCodeException(
393+
"Stream filter by SQL requires RabbitMQ 4.1.0 or higher"
394+
)
395+
if self._filter_option.application_properties and not versions.get(
396+
"4.1.0", False
397+
):
398+
raise ValidationCodeException(
399+
"Stream filter by SQL requires RabbitMQ 4.1.0 or higher"
400+
)
401+
359402

360403
@dataclass
361404
class RecoveryConfiguration:

rabbitmq_amqp_python_client/options.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .entities import StreamConsumerOptions
1+
from .entities import ConsumerOptions
22
from .qpid.proton._data import ( # noqa: E402
33
PropertyDict,
44
symbol,
@@ -68,8 +68,8 @@ def test(self, link: Link) -> bool:
6868

6969

7070
class ReceiverOptionUnsettledWithFilters(Filter): # type: ignore
71-
def __init__(self, addr: str, filter_options: StreamConsumerOptions):
72-
super().__init__(filter_options.filter_set())
71+
def __init__(self, addr: str, consumer_options: ConsumerOptions):
72+
super().__init__(consumer_options.filter_set())
7373
self._addr = addr
7474

7575
def apply(self, link: Link) -> None:

tests/test_server_validation.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def test_is_server_version_gte_4_2_0_exact_version(self):
306306
mock_blocking_conn.conn = mock_proton_conn
307307
self.connection._conn = mock_blocking_conn
308308

309-
result = self.connection._is_server_version_gte_4_2_0()
309+
result = self.connection._is_server_version_gte("4.2.0")
310310
assert result is True
311311

312312
def test_is_server_version_gte_4_2_0_higher_versions(self):
@@ -322,7 +322,7 @@ def test_is_server_version_gte_4_2_0_higher_versions(self):
322322
mock_blocking_conn.conn = mock_proton_conn
323323
self.connection._conn = mock_blocking_conn
324324

325-
result = self.connection._is_server_version_gte_4_2_0()
325+
result = self.connection._is_server_version_gte("4.2.0")
326326
assert result is True, f"Version {version_str} should return True"
327327

328328
def test_is_server_version_gte_4_2_0_lower_versions(self):
@@ -338,15 +338,15 @@ def test_is_server_version_gte_4_2_0_lower_versions(self):
338338
mock_blocking_conn.conn = mock_proton_conn
339339
self.connection._conn = mock_blocking_conn
340340

341-
result = self.connection._is_server_version_gte_4_2_0()
341+
result = self.connection._is_server_version_gte("4.2.0")
342342
assert result is False, f"Version {version_str} should return False"
343343

344344
def test_is_server_version_gte_4_2_0_no_connection(self):
345345
"""Test when connection is None."""
346346
self.connection._conn = None
347347

348348
with pytest.raises(ValidationCodeException) as exc_info:
349-
self.connection._is_server_version_gte_4_2_0()
349+
self.connection._is_server_version_gte("4.2.0")
350350

351351
assert "Connection not established" in str(exc_info.value)
352352

@@ -357,7 +357,7 @@ def test_is_server_version_gte_4_2_0_no_proton_connection(self):
357357
self.connection._conn = mock_blocking_conn
358358

359359
with pytest.raises(ValidationCodeException) as exc_info:
360-
self.connection._is_server_version_gte_4_2_0()
360+
self.connection._is_server_version_gte("4.2.0")
361361

362362
assert "Connection not established" in str(exc_info.value)
363363

@@ -370,7 +370,7 @@ def test_is_server_version_gte_4_2_0_no_remote_properties(self):
370370
self.connection._conn = mock_blocking_conn
371371

372372
with pytest.raises(ValidationCodeException) as exc_info:
373-
self.connection._is_server_version_gte_4_2_0()
373+
self.connection._is_server_version_gte("4.2.0")
374374

375375
assert "No remote properties received from server" in str(exc_info.value)
376376

@@ -388,7 +388,7 @@ def test_is_server_version_gte_4_2_0_missing_version(self):
388388
self.connection._conn = mock_blocking_conn
389389

390390
with pytest.raises(ValidationCodeException) as exc_info:
391-
self.connection._is_server_version_gte_4_2_0()
391+
self.connection._is_server_version_gte("4.2.0")
392392

393393
assert "Server version not provided" in str(exc_info.value)
394394

@@ -406,7 +406,7 @@ def test_is_server_version_gte_4_2_0_invalid_version_format(self):
406406
self.connection._conn = mock_blocking_conn
407407

408408
with pytest.raises(ValidationCodeException) as exc_info:
409-
self.connection._is_server_version_gte_4_2_0()
409+
self.connection._is_server_version_gte("4.2.0")
410410

411411
error_msg = str(exc_info.value)
412412
assert "Failed to parse server version" in error_msg
@@ -419,7 +419,10 @@ def test_is_server_version_gte_4_2_0_edge_cases(self):
419419
("4.2.0", True), # Exact match
420420
("4.2.0.0", True), # With extra zeroes
421421
("v4.2.0", True), # With v prefix
422-
("4.2.0-rc1", False), # Pre-release should be less than 4.2.0
422+
(
423+
"4.2.0-rc1",
424+
True,
425+
), # Pre-release should be less than 4.2.0 but accepted it equal
423426
]
424427

425428
for version_str, expected in test_cases:
@@ -433,12 +436,12 @@ def test_is_server_version_gte_4_2_0_edge_cases(self):
433436

434437
if version_str == "4.2.0-rc1":
435438
# Pre-release versions should be handled correctly
436-
result = self.connection._is_server_version_gte_4_2_0()
439+
result = self.connection._is_server_version_gte("4.2.0")
437440
assert (
438441
result == expected
439442
), f"Version {version_str} should return {expected}"
440443
else:
441-
result = self.connection._is_server_version_gte_4_2_0()
444+
result = self.connection._is_server_version_gte("4.2.0")
442445
assert (
443446
result == expected
444447
), f"Version {version_str} should return {expected}"

0 commit comments

Comments
 (0)