Skip to content

Commit d940ed2

Browse files
authored
Fix test flakiness from TCP port contention (payjoin#388)
Eliminate the lock contention on test service TCP sockets leading to random, frequent test failure.
2 parents e54e51f + 3c20429 commit d940ed2

File tree

7 files changed

+178
-97
lines changed

7 files changed

+178
-97
lines changed

Cargo-minimal.lock

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,9 +1450,9 @@ dependencies = [
14501450

14511451
[[package]]
14521452
name = "ohttp-relay"
1453-
version = "0.0.8"
1453+
version = "0.0.9"
14541454
source = "registry+https://github.com/rust-lang/crates.io-index"
1455-
checksum = "7850c40a0aebcba289d3252c0a45f93cba6ad4b0c46b88a5fc51dba6ddce8632"
1455+
checksum = "4f8e8aef13b8327b680aaaca807aa11ba5979fc5858203e7b77c68128ede61a2"
14561456
dependencies = [
14571457
"futures",
14581458
"http",

Cargo-recent.lock

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,9 +1450,9 @@ dependencies = [
14501450

14511451
[[package]]
14521452
name = "ohttp-relay"
1453-
version = "0.0.8"
1453+
version = "0.0.9"
14541454
source = "registry+https://github.com/rust-lang/crates.io-index"
1455-
checksum = "7850c40a0aebcba289d3252c0a45f93cba6ad4b0c46b88a5fc51dba6ddce8632"
1455+
checksum = "4f8e8aef13b8327b680aaaca807aa11ba5979fc5858203e7b77c68128ede61a2"
14561456
dependencies = [
14571457
"futures",
14581458
"http",

payjoin-cli/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ url = { version = "2.3.1", features = ["serde"] }
5050
[dev-dependencies]
5151
bitcoind = { version = "0.36.0", features = ["0_21_2"] }
5252
http = "1"
53-
ohttp-relay = "0.0.8"
53+
ohttp-relay = { version = "0.0.9", features = ["_test-util"] }
5454
once_cell = "1"
5555
payjoin-directory = { path = "../payjoin-directory", features = ["_danger-local-https"] }
5656
testcontainers = "0.15.0"

payjoin-cli/tests/e2e.rs

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,11 @@ mod e2e {
151151
payjoin_sent.unwrap().unwrap_or(Some(false)).unwrap(),
152152
"Payjoin send was not detected"
153153
);
154+
155+
fn find_free_port() -> u16 {
156+
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
157+
listener.local_addr().unwrap().port()
158+
}
154159
}
155160

156161
#[cfg(feature = "v2")]
@@ -170,6 +175,7 @@ mod e2e {
170175
use url::Url;
171176

172177
type Error = Box<dyn std::error::Error + 'static>;
178+
type BoxSendSyncError = Box<dyn std::error::Error + Send + Sync>;
173179
type Result<T> = std::result::Result<T, Error>;
174180

175181
static INIT_TRACING: OnceCell<()> = OnceCell::new();
@@ -178,18 +184,26 @@ mod e2e {
178184

179185
init_tracing();
180186
let (cert, key) = local_cert_key();
181-
let ohttp_relay_port = find_free_port();
182-
let ohttp_relay = Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();
183-
let directory_port = find_free_port();
184-
let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap();
187+
let docker: Cli = Cli::default();
188+
let db = docker.run(Redis);
189+
let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379));
190+
let (port, directory_handle) =
191+
init_directory(db_host, (cert.clone(), key)).await.expect("Failed to init directory");
192+
let directory = Url::parse(&format!("https://localhost:{}", port)).unwrap();
193+
185194
let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap();
195+
let (ohttp_relay_port, ohttp_relay_handle) =
196+
ohttp_relay::listen_tcp_on_free_port(gateway_origin)
197+
.await
198+
.expect("Failed to init ohttp relay");
199+
let ohttp_relay = Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();
186200

187201
let temp_dir = env::temp_dir();
188202
let receiver_db_path = temp_dir.join("receiver_db");
189203
let sender_db_path = temp_dir.join("sender_db");
190204
let result: Result<()> = tokio::select! {
191-
res = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => Err(format!("Ohttp relay is long running: {:?}", res).into()),
192-
res = init_directory(directory_port, (cert.clone(), key)) => Err(format!("Directory server is long running: {:?}", res).into()),
205+
res = ohttp_relay_handle => Err(format!("Ohttp relay is long running: {:?}", res).into()),
206+
res = directory_handle => Err(format!("Directory server is long running: {:?}", res).into()),
193207
res = send_receive_cli_async(ohttp_relay, directory, cert, receiver_db_path.clone(), sender_db_path.clone()) => res.map_err(|e| format!("send_receive failed: {:?}", e).into()),
194208
};
195209

@@ -479,13 +493,17 @@ mod e2e {
479493
Err("Timeout waiting for service to be ready".into())
480494
}
481495

482-
async fn init_directory(port: u16, local_cert_key: (Vec<u8>, Vec<u8>)) -> Result<()> {
483-
let docker: Cli = Cli::default();
496+
async fn init_directory(
497+
db_host: String,
498+
local_cert_key: (Vec<u8>, Vec<u8>),
499+
) -> std::result::Result<
500+
(u16, tokio::task::JoinHandle<std::result::Result<(), BoxSendSyncError>>),
501+
BoxSendSyncError,
502+
> {
503+
println!("Database running on {}", db_host);
484504
let timeout = Duration::from_secs(2);
485-
let db = docker.run(Redis);
486-
let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379));
487-
println!("Database running on {}", db.get_host_port_ipv4(6379));
488-
payjoin_directory::listen_tcp_with_tls(port, db_host, timeout, local_cert_key).await
505+
payjoin_directory::listen_tcp_with_tls_on_free_port(db_host, timeout, local_cert_key)
506+
.await
489507
}
490508

491509
// generates or gets a DER encoded localhost cert and key.
@@ -524,11 +542,6 @@ mod e2e {
524542
}
525543
}
526544

527-
fn find_free_port() -> u16 {
528-
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
529-
listener.local_addr().unwrap().port()
530-
}
531-
532545
async fn cleanup_temp_file(path: &std::path::Path) {
533546
if let Err(e) = fs::remove_dir_all(path).await {
534547
eprintln!("Failed to remove {:?}: {}", path, e);

payjoin-directory/src/lib.rs

Lines changed: 66 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,67 @@ const ID_LENGTH: usize = 13;
3636
mod db;
3737
use crate::db::DbPool;
3838

39+
#[cfg(feature = "_danger-local-https")]
40+
type BoxError = Box<dyn std::error::Error + Send + Sync>;
41+
42+
#[cfg(feature = "_danger-local-https")]
43+
pub async fn listen_tcp_with_tls_on_free_port(
44+
db_host: String,
45+
timeout: Duration,
46+
cert_key: (Vec<u8>, Vec<u8>),
47+
) -> Result<(u16, tokio::task::JoinHandle<Result<(), BoxError>>), BoxError> {
48+
let listener = tokio::net::TcpListener::bind("[::]:0").await?;
49+
let port = listener.local_addr()?.port();
50+
println!("Directory server binding to port {}", listener.local_addr()?);
51+
let handle = listen_tcp_with_tls_on_listener(listener, db_host, timeout, cert_key).await?;
52+
Ok((port, handle))
53+
}
54+
55+
// Helper function to avoid code duplication
56+
#[cfg(feature = "_danger-local-https")]
57+
async fn listen_tcp_with_tls_on_listener(
58+
listener: tokio::net::TcpListener,
59+
db_host: String,
60+
timeout: Duration,
61+
tls_config: (Vec<u8>, Vec<u8>),
62+
) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError> {
63+
let pool = DbPool::new(timeout, db_host).await?;
64+
let ohttp = Arc::new(Mutex::new(init_ohttp()?));
65+
let tls_acceptor = init_tls_acceptor(tls_config)?;
66+
// Spawn the connection handling loop in a separate task
67+
let handle = tokio::spawn(async move {
68+
while let Ok((stream, _)) = listener.accept().await {
69+
let pool = pool.clone();
70+
let ohttp = ohttp.clone();
71+
let tls_acceptor = tls_acceptor.clone();
72+
tokio::spawn(async move {
73+
let tls_stream = match tls_acceptor.accept(stream).await {
74+
Ok(tls_stream) => tls_stream,
75+
Err(e) => {
76+
error!("TLS accept error: {}", e);
77+
return;
78+
}
79+
};
80+
if let Err(err) = http1::Builder::new()
81+
.serve_connection(
82+
TokioIo::new(tls_stream),
83+
service_fn(move |req| {
84+
serve_payjoin_directory(req, pool.clone(), ohttp.clone())
85+
}),
86+
)
87+
.with_upgrades()
88+
.await
89+
{
90+
error!("Error serving connection: {:?}", err);
91+
}
92+
});
93+
}
94+
Ok(())
95+
});
96+
Ok(handle)
97+
}
98+
99+
// Modify existing listen_tcp_with_tls to use the new helper
39100
pub async fn listen_tcp(
40101
port: u16,
41102
db_host: String,
@@ -73,41 +134,11 @@ pub async fn listen_tcp_with_tls(
73134
port: u16,
74135
db_host: String,
75136
timeout: Duration,
76-
tls_config: (Vec<u8>, Vec<u8>),
77-
) -> Result<(), Box<dyn std::error::Error>> {
78-
let pool = DbPool::new(timeout, db_host).await?;
79-
let ohttp = Arc::new(Mutex::new(init_ohttp()?));
80-
let bind_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port);
81-
let tls_acceptor = init_tls_acceptor(tls_config)?;
82-
let listener = TcpListener::bind(bind_addr).await?;
83-
while let Ok((stream, _)) = listener.accept().await {
84-
let pool = pool.clone();
85-
let ohttp = ohttp.clone();
86-
let tls_acceptor = tls_acceptor.clone();
87-
tokio::spawn(async move {
88-
let tls_stream = match tls_acceptor.accept(stream).await {
89-
Ok(tls_stream) => tls_stream,
90-
Err(e) => {
91-
error!("TLS accept error: {}", e);
92-
return;
93-
}
94-
};
95-
if let Err(err) = http1::Builder::new()
96-
.serve_connection(
97-
TokioIo::new(tls_stream),
98-
service_fn(move |req| {
99-
serve_payjoin_directory(req, pool.clone(), ohttp.clone())
100-
}),
101-
)
102-
.with_upgrades()
103-
.await
104-
{
105-
error!("Error serving connection: {:?}", err);
106-
}
107-
});
108-
}
109-
110-
Ok(())
137+
cert_key: (Vec<u8>, Vec<u8>),
138+
) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError> {
139+
let addr = format!("0.0.0.0:{}", port);
140+
let listener = tokio::net::TcpListener::bind(&addr).await?;
141+
listen_tcp_with_tls_on_listener(listener, db_host, timeout, cert_key).await
111142
}
112143

113144
#[cfg(feature = "_danger-local-https")]

payjoin/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ serde_json = "1.0.108"
4242
bitcoind = { version = "0.36.0", features = ["0_21_2"] }
4343
http = "1"
4444
payjoin-directory = { path = "../payjoin-directory", features = ["_danger-local-https"] }
45-
ohttp-relay = "0.0.8"
45+
ohttp-relay = { version = "0.0.9", features = ["_test-util"] }
4646
once_cell = "1"
4747
rcgen = { version = "0.11" }
4848
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] }

0 commit comments

Comments
 (0)