Skip to content

Commit 09a0d00

Browse files
authored
Merge pull request #3338 from jarhodes314/bug/flaky-unit-tests
refactor: fix flaky unit tests
2 parents 4487507 + 2483809 commit 09a0d00

File tree

12 files changed

+224
-204
lines changed

12 files changed

+224
-204
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/common/axum_tls/src/files.rs

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ mod tests {
114114
use assert_matches::assert_matches;
115115
use axum::routing::get;
116116
use axum::Router;
117+
use camino::Utf8PathBuf;
117118
use std::io::Cursor;
118119

119120
mod read_trust_store {
@@ -199,27 +200,28 @@ mod tests {
199200
}
200201

201202
fn copy_test_file_to(test_file: &str, path: impl AsRef<Path>) -> io::Result<u64> {
202-
std::fs::copy(format!("./test_data/{test_file}"), path)
203+
let dir = env!("CARGO_MANIFEST_DIR");
204+
std::fs::copy(format!("{dir}/test_data/{test_file}"), path)
203205
}
204206
}
205207

206208
#[test]
207209
fn load_pkey_fails_when_given_x509_certificate() {
210+
let dir = env!("CARGO_MANIFEST_DIR");
211+
let path = Utf8PathBuf::from(format!("{dir}/test_data/ec.crt"));
208212
assert_eq!(
209-
load_pkey(Utf8Path::new("./test_data/ec.crt"))
210-
.unwrap_err()
211-
.to_string(),
212-
"expected private key in \"./test_data/ec.crt\", found an X509 certificate"
213+
load_pkey(&path).unwrap_err().to_string(),
214+
format!("expected private key in {path:?}, found an X509 certificate")
213215
);
214216
}
215217

216218
#[test]
217219
fn load_pkey_fails_when_given_certificate_revocation_list() {
220+
let dir = env!("CARGO_MANIFEST_DIR");
221+
let path = Utf8PathBuf::from(format!("{dir}/test_data/demo.crl"));
218222
assert_eq!(
219-
load_pkey(Utf8Path::new("./test_data/demo.crl"))
220-
.unwrap_err()
221-
.to_string(),
222-
"expected private key in \"./test_data/demo.crl\", found a CRL"
223+
load_pkey(&path).unwrap_err().to_string(),
224+
format!("expected private key in {path:?}, found a CRL")
223225
);
224226
}
225227

@@ -288,7 +290,8 @@ mod tests {
288290
}
289291

290292
fn test_data(file_name: &str) -> String {
291-
std::fs::read_to_string(format!("./test_data/{file_name}"))
293+
let dir = env!("CARGO_MANIFEST_DIR");
294+
std::fs::read_to_string(format!("{dir}/test_data/{file_name}"))
292295
.with_context(|| format!("opening file {file_name} from test_data"))
293296
.unwrap()
294297
}

crates/common/mqtt_channel/Cargo.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,13 @@ log = { workspace = true }
1717
rumqttc = { workspace = true }
1818
serde = { workspace = true }
1919
thiserror = { workspace = true }
20-
tokio = { workspace = true, features = ["rt", "time"] }
20+
tokio = { workspace = true, features = ["rt", "time", "rt-multi-thread"] }
2121
zeroize = { workspace = true }
2222

2323
[dev-dependencies]
2424
anyhow = { workspace = true }
2525
mqtt_tests = { workspace = true }
2626
serde_json = { workspace = true }
27-
serial_test = { workspace = true }
2827

2928
[lints]
3029
workspace = true

crates/common/mqtt_channel/src/connection.rs

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ use rumqttc::EventLoop;
1616
use rumqttc::Incoming;
1717
use rumqttc::Outgoing;
1818
use rumqttc::Packet;
19+
use std::sync::Arc;
1920
use std::time::Duration;
21+
use tokio::sync::OwnedSemaphorePermit;
22+
use tokio::sync::Semaphore;
2023
use tokio::time::sleep;
2124

2225
/// A connection to some MQTT server
@@ -88,19 +91,23 @@ impl Connection {
8891

8992
let (mqtt_client, event_loop) =
9093
Connection::open(config, received_sender.clone(), error_sender.clone()).await?;
94+
let permits = Arc::new(Semaphore::new(1));
95+
let permit = permits.clone().acquire_owned().await.unwrap();
9196
tokio::spawn(Connection::receiver_loop(
9297
mqtt_client.clone(),
9398
config.clone(),
9499
event_loop,
95100
received_sender,
96101
error_sender.clone(),
102+
pub_done_sender,
103+
permits,
97104
));
98105
tokio::spawn(Connection::sender_loop(
99106
mqtt_client,
100107
published_receiver,
101108
error_sender,
102109
config.last_will_message.clone(),
103-
pub_done_sender,
110+
permit,
104111
));
105112

106113
Ok(Connection {
@@ -200,9 +207,41 @@ impl Connection {
200207
mut event_loop: EventLoop,
201208
mut message_sender: mpsc::UnboundedSender<MqttMessage>,
202209
mut error_sender: mpsc::UnboundedSender<MqttError>,
210+
done: oneshot::Sender<()>,
211+
permits: Arc<Semaphore>,
203212
) -> Result<(), MqttError> {
213+
let mut triggered_disconnect = false;
214+
let mut disconnect_permit = None;
215+
204216
loop {
205-
match event_loop.poll().await {
217+
// Check if we are ready to disconnect. Due to ownership of the
218+
// event loop, this needs to be done before we call
219+
// `event_loop.poll()`
220+
let remaining_events_empty = event_loop.state.inflight() == 0;
221+
if disconnect_permit.is_some() && !triggered_disconnect && remaining_events_empty {
222+
// `sender_loop` is not running and we have no remaining
223+
// publishes to process
224+
let client = mqtt_client.clone();
225+
tokio::spawn(async move { client.disconnect().await });
226+
triggered_disconnect = true;
227+
}
228+
229+
let event = tokio::select! {
230+
// If there is an event, we need to process that first
231+
// Otherwise we risk shutting down early
232+
// e.g. a `Publish` request from the sender is not "inflight"
233+
// but will immediately be returned by `event_loop.poll()`
234+
biased;
235+
236+
event = event_loop.poll() => event,
237+
permit = permits.clone().acquire_owned() => {
238+
// The `sender_loop` has now concluded
239+
disconnect_permit = Some(permit.unwrap());
240+
continue;
241+
}
242+
};
243+
244+
match event {
206245
Ok(Event::Incoming(Packet::Publish(msg))) => {
207246
if msg.payload.len() > config.max_packet_size {
208247
error!("Dropping message received on topic {} with payload size {} that exceeds the maximum packet size of {}",
@@ -266,6 +305,7 @@ impl Connection {
266305
// No more messages will be forwarded to the client
267306
let _ = message_sender.close().await;
268307
let _ = error_sender.close().await;
308+
let _ = done.send(());
269309
Ok(())
270310
}
271311

@@ -274,24 +314,15 @@ impl Connection {
274314
mut messages_receiver: mpsc::UnboundedReceiver<MqttMessage>,
275315
mut error_sender: mpsc::UnboundedSender<MqttError>,
276316
last_will: Option<MqttMessage>,
277-
done: oneshot::Sender<()>,
317+
_disconnect_permit: OwnedSemaphorePermit,
278318
) {
279-
loop {
280-
match messages_receiver.next().await {
281-
None => {
282-
// The sender channel has been closed by the client
283-
// No more messages will be published by the client
284-
break;
285-
}
286-
Some(message) => {
287-
let payload = Vec::from(message.payload_bytes());
288-
if let Err(err) = mqtt_client
289-
.publish(message.topic, message.qos, message.retain, payload)
290-
.await
291-
{
292-
let _ = error_sender.send(err.into()).await;
293-
}
294-
}
319+
while let Some(message) = messages_receiver.next().await {
320+
let payload = Vec::from(message.payload_bytes());
321+
if let Err(err) = mqtt_client
322+
.publish(message.topic, message.qos, message.retain, payload)
323+
.await
324+
{
325+
let _ = error_sender.send(err.into()).await;
295326
}
296327
}
297328

@@ -303,8 +334,9 @@ impl Connection {
303334
.publish(last_will.topic, last_will.qos, last_will.retain, payload)
304335
.await;
305336
}
306-
let _ = mqtt_client.disconnect().await;
307-
let _ = done.send(());
337+
338+
// At this point, `_disconnect_permit` is dropped
339+
// This allows `receiver_loop` acquire a permit and commence the shutdown process
308340
}
309341

310342
pub(crate) async fn do_pause() {

0 commit comments

Comments
 (0)