Skip to content

Commit 190f651

Browse files
committed
Add scenario tests for connection interruptions
1 parent fd0b0d3 commit 190f651

File tree

10 files changed

+569
-0
lines changed

10 files changed

+569
-0
lines changed

dev_requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,5 @@ uvloop
1717
vulture>=2.3.0
1818
wheel>=0.30.0
1919
numpy>=1.24.0
20+
requests>=2.23.0
21+
aiohttp>=3.0.0

tests/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,13 @@ def pytest_addoption(parser):
134134
help="Name of the Redis master service that the sentinels are monitoring",
135135
)
136136

137+
parser.addoption(
138+
"--endpoints-config",
139+
action="store",
140+
default="endpoints.json",
141+
help="Path to the Redis endpoints configuration file",
142+
)
143+
137144

138145
def _get_info(redis_url):
139146
client = redis.Redis.from_url(redis_url)

tests/scenario/__init__.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import json
2+
import os.path
3+
import dataclasses
4+
5+
import pytest
6+
from urllib.parse import urlparse
7+
8+
9+
@dataclasses.dataclass
10+
class Endpoint:
11+
bdb_id: int
12+
username: str
13+
password: str
14+
tls: bool
15+
endpoints: list[str]
16+
17+
@property
18+
def url(self):
19+
parsed_url = urlparse(self.endpoints[0])
20+
21+
if self.tls:
22+
parsed_url = parsed_url._replace(scheme="rediss")
23+
24+
domain = parsed_url.netloc.split("@")[-1]
25+
domain = f"{self.username}:{self.password}@{domain}"
26+
27+
parsed_url = parsed_url._replace(netloc=domain)
28+
29+
return parsed_url.geturl()
30+
31+
@classmethod
32+
def from_dict(cls, data: dict):
33+
field_names = set(f.name for f in dataclasses.fields(cls))
34+
return cls(**{k: v for k, v in data.items() if k in field_names})
35+
36+
37+
def get_endpoint(request: pytest.FixtureRequest, endpoint_name: str) -> Endpoint:
38+
endpoints_config_path = request.config.getoption("--endpoints-config")
39+
40+
if not (endpoints_config_path and os.path.exists(endpoints_config_path)):
41+
raise ValueError(f"Endpoints config file not found: {endpoints_config_path}")
42+
43+
try:
44+
with open(endpoints_config_path, "r") as f:
45+
endpoints_config = json.load(f)
46+
except Exception as e:
47+
raise ValueError(
48+
f"Failed to load endpoints config file: {endpoints_config_path}"
49+
) from e
50+
51+
if not (isinstance(endpoints_config, dict) and endpoint_name in endpoints_config):
52+
raise ValueError(f"Endpoint not found in config: {endpoint_name}")
53+
54+
return Endpoint.from_dict(endpoints_config.get(endpoint_name))

tests/scenario/fake_app.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import multiprocessing
2+
import typing
3+
4+
from threading import Thread, Event
5+
from multiprocessing import Process, Event as PEvent
6+
from unittest.mock import patch
7+
8+
from redis import Redis
9+
10+
11+
class FakeApp:
12+
13+
def __init__(self, client: Redis, logic: typing.Callable[[Redis], None]):
14+
self.client = client
15+
self.logic = logic
16+
self.disconnects = 0
17+
18+
def run(self) -> (Event, Thread):
19+
e = Event()
20+
t = Thread(target=self._run_logic, args=(e,))
21+
t.start()
22+
return e, t
23+
24+
def _run_logic(self, e: Event):
25+
with patch.object(
26+
self.client, "_disconnect_raise", wraps=self.client._disconnect_raise
27+
) as spy:
28+
while not e.is_set():
29+
self.logic(self.client)
30+
31+
self.disconnects = spy.call_count
32+
33+
34+
class FakeSubscriber:
35+
36+
def __init__(self, client: Redis, logic: typing.Callable[[dict], None]):
37+
self.client = client
38+
self.logic = logic
39+
self.disconnects = multiprocessing.Value("i", 0)
40+
41+
def run(self, channel: str) -> (PEvent, Process):
42+
e, started = PEvent(), PEvent()
43+
p = Process(target=self._run_logic, args=(e, started, channel))
44+
p.start()
45+
return e, started, p
46+
47+
def _run_logic(self, should_stop: PEvent, started: PEvent, channel: str):
48+
pubsub = self.client.pubsub()
49+
50+
with patch.object(
51+
pubsub, "_disconnect_raise_connect", wraps=pubsub._disconnect_raise_connect
52+
) as spy_pubsub:
53+
pubsub.subscribe(channel)
54+
55+
started.set()
56+
57+
while not should_stop.is_set():
58+
message = pubsub.get_message(ignore_subscribe_messages=True, timeout=1)
59+
60+
if message:
61+
self.logic(message)
62+
63+
self.disconnects.value = spy_pubsub.call_count
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import requests
2+
3+
4+
class TriggeredAction:
5+
6+
def __init__(self, client: "FaultInjectionClient", data: dict):
7+
self.client = client
8+
self.action_id = data["action_id"]
9+
self.data = data
10+
11+
def refresh(self):
12+
self.data = self.client.get_action(self.action_id)
13+
14+
@property
15+
def status(self):
16+
if "status" not in self.data:
17+
return "pending"
18+
return self.data["status"]
19+
20+
def wait_until_complete(self):
21+
while self.status not in ("success", "failed"):
22+
self.refresh()
23+
return self.status
24+
25+
26+
class FaultInjectionClient:
27+
def __init__(self, base_url: str = "http://127.0.0.1:20324"):
28+
self.base_url = base_url
29+
30+
def trigger_action(self, action_type: str, parameters: dict):
31+
response = requests.post(
32+
f"{self.base_url}/action",
33+
json={"type": action_type, "parameters": parameters},
34+
)
35+
return TriggeredAction(self, response.json())
36+
37+
def get_action(self, action_id: str):
38+
response = requests.get(f"{self.base_url}/action/{action_id}")
39+
return response.json()
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import multiprocessing
2+
3+
import socket
4+
import time
5+
6+
import pytest
7+
8+
9+
from redis import Redis, BusyLoadingError
10+
from redis.backoff import ExponentialBackoff
11+
from redis.retry import Retry
12+
from redis.exceptions import (
13+
ConnectionError as RedisConnectionError,
14+
TimeoutError as RedisTimeoutError,
15+
)
16+
17+
from ..conftest import _get_client
18+
from . import get_endpoint, Endpoint
19+
from .fake_app import FakeApp, FakeSubscriber
20+
from .fault_injection_client import FaultInjectionClient, TriggeredAction
21+
22+
23+
@pytest.fixture
24+
def endpoint_name():
25+
return "re-standalone"
26+
27+
28+
@pytest.fixture
29+
def endpoint(request: pytest.FixtureRequest, endpoint_name: str):
30+
return get_endpoint(request, endpoint_name)
31+
32+
33+
@pytest.fixture
34+
def clients(request: pytest.FixtureRequest, endpoint: Endpoint):
35+
# Use Recommended settings
36+
retry = Retry(ExponentialBackoff(base=1), 3)
37+
38+
clients = []
39+
40+
for _ in range(2):
41+
r = _get_client(
42+
Redis,
43+
request,
44+
decode_responses=True,
45+
from_url=endpoint.url,
46+
retry=retry,
47+
retry_on_error=[
48+
BusyLoadingError,
49+
RedisConnectionError,
50+
RedisTimeoutError,
51+
# FIXME: This is a workaround for a bug in redis-py
52+
# https://github.com/redis/redis-py/issues/3203
53+
ConnectionError,
54+
TimeoutError,
55+
],
56+
)
57+
r.flushdb()
58+
clients.append(r)
59+
return clients
60+
61+
62+
@pytest.fixture
63+
def fault_injection_client(request: pytest.FixtureRequest):
64+
return FaultInjectionClient()
65+
66+
67+
@pytest.mark.parametrize("action", ("dmc_restart", "network_failure"))
68+
def test_connection_interruptions(
69+
clients: list[Redis],
70+
endpoint: Endpoint,
71+
fault_injection_client: FaultInjectionClient,
72+
action: str,
73+
):
74+
client = clients.pop()
75+
app = FakeApp(client, lambda c: c.set("foo", "bar"))
76+
77+
stop_app, thread = app.run()
78+
79+
triggered_action = fault_injection_client.trigger_action(
80+
action, {"bdb_id": endpoint.bdb_id}
81+
)
82+
83+
triggered_action.wait_until_complete()
84+
85+
stop_app.set()
86+
thread.join()
87+
88+
if triggered_action.status == "failed":
89+
pytest.fail(f"Action failed: {triggered_action.data['error']}")
90+
91+
assert app.disconnects > 0, "Client did not disconnect"
92+
93+
94+
@pytest.mark.parametrize("action", ("dmc_restart", "network_failure"))
95+
def test_pubsub_with_connection_interruptions(
96+
clients: list[Redis],
97+
endpoint: Endpoint,
98+
fault_injection_client: FaultInjectionClient,
99+
action: str,
100+
):
101+
channel = "test"
102+
103+
# Subscriber is executed in a separate process to ensure it reacts
104+
# to the disconnection at the same time as the publisher
105+
with multiprocessing.Manager() as manager:
106+
received_messages = manager.list()
107+
108+
def read_message(message):
109+
nonlocal received_messages
110+
if message and message["type"] == "message":
111+
received_messages.append(message["data"])
112+
113+
subscriber_client = clients.pop()
114+
subscriber = FakeSubscriber(subscriber_client, read_message)
115+
stop_subscriber, subscriber_started, subscriber_t = subscriber.run(channel)
116+
117+
# Allow subscriber subscribe to the channel
118+
subscriber_started.wait(timeout=5)
119+
120+
messages_sent = 0
121+
122+
def publish_message(c):
123+
nonlocal messages_sent, channel
124+
messages_sent += 1
125+
c.publish(channel, messages_sent)
126+
127+
publisher_client = clients.pop()
128+
publisher = FakeApp(publisher_client, publish_message)
129+
stop_publisher, publisher_t = publisher.run()
130+
131+
triggered_action = fault_injection_client.trigger_action(
132+
action, {"bdb_id": endpoint.bdb_id}
133+
)
134+
135+
triggered_action.wait_until_complete()
136+
last_message_sent_after_trigger = messages_sent
137+
138+
time.sleep(3) # Wait for the publisher to send more messages
139+
140+
stop_publisher.set()
141+
publisher_t.join()
142+
143+
stop_subscriber.set()
144+
subscriber_t.join()
145+
146+
assert publisher.disconnects > 0
147+
assert subscriber.disconnects.value > 0
148+
149+
if triggered_action.status == "failed":
150+
pytest.fail(f"Action failed: {triggered_action.data['error']}")
151+
152+
assert (
153+
last_message_sent_after_trigger < messages_sent
154+
), "No messages were sent after the failure"
155+
assert (
156+
int(received_messages[-1]) == messages_sent
157+
), "Not all messages were received"

tests/test_asyncio/scenario/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)