diff --git a/Cargo.lock b/Cargo.lock
index d62c790c63..99025b5a8d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -883,6 +883,12 @@ dependencies = [
  "pin-project-lite",
 ]
 
+[[package]]
+name = "futures-sink"
+version = "0.3.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7"
+
 [[package]]
 name = "fuzz"
 version = "0.1.0"
@@ -1480,6 +1486,7 @@ dependencies = [
  "serde_json",
  "socket2 0.6.0",
  "tokio",
+ "tokio-util",
  "tracing",
  "tracing-subscriber",
 ]
@@ -2310,6 +2317,19 @@ dependencies = [
  "syn",
 ]
 
+[[package]]
+name = "tokio-util"
+version = "0.7.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df"
+dependencies = [
+ "bytes",
+ "futures-core",
+ "futures-sink",
+ "pin-project-lite",
+ "tokio",
+]
+
 [[package]]
 name = "tracing"
 version = "0.1.41"
diff --git a/perf/Cargo.toml b/perf/Cargo.toml
index ae2f5753ee..d40135fab6 100644
--- a/perf/Cargo.toml
+++ b/perf/Cargo.toml
@@ -32,5 +32,6 @@ serde = { workspace = true, optional = true }
 serde_json = { workspace = true, optional = true }
 socket2 = { workspace = true }
 tokio = { workspace = true, features = ["rt", "macros", "signal", "net"] }
+tokio-util = { version = "0.7.15" }
 tracing = { workspace = true }
 tracing-subscriber = { workspace = true }
diff --git a/perf/src/client.rs b/perf/src/client.rs
index aebf722e0f..85fff93a7f 100644
--- a/perf/src/client.rs
+++ b/perf/src/client.rs
@@ -12,6 +12,7 @@ use clap::Parser;
 use quinn::{TokioRuntime, crypto::rustls::QuicClientConfig};
 use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
 use tokio::sync::Semaphore;
+use tokio_util::sync::CancellationToken;
 use tracing::{debug, error, info};
 
 use crate::{
@@ -142,69 +143,148 @@ pub async fn run(opt: Opt) -> Result<()> {
 
     info!("established");
 
-    let drive_fut = async {
-        tokio::try_join!(
-            drive_uni(
-                connection.clone(),
-                stream_stats.clone(),
-                opt.uni_requests,
-                opt.upload_size,
-                opt.download_size
-            ),
-            drive_bi(
-                connection.clone(),
-                stream_stats.clone(),
-                opt.bi_requests,
-                opt.upload_size,
-                opt.download_size
-            )
+    // This will be used to cancel drive futures
+    let shutdown_drive = CancellationToken::new();
+
+    // This will be used to cancel stat future once the drive futures are finished
+    let shutdown_stats = CancellationToken::new();
+
+    let shutdown2 = shutdown_drive.clone();
+    let connection2 = connection.clone();
+    let stream_stats2 = stream_stats.clone();
+    let mut drive_uni_fut = tokio::spawn(async move {
+        drive_uni(
+            shutdown2,
+            connection2,
+            stream_stats2,
+            opt.uni_requests,
+            opt.upload_size,
+            opt.download_size,
         )
-    };
+        .await
+    });
+
+    let shutdown2 = shutdown_drive.clone();
+    let connection2 = connection.clone();
+    let stream_stats2 = stream_stats.clone();
+    let mut drive_bi_fut = tokio::spawn(async move {
+        drive_bi(
+            shutdown2,
+            connection2,
+            stream_stats2,
+            opt.bi_requests,
+            opt.upload_size,
+            opt.download_size,
+        )
+        .await
+    });
 
     let mut stats = Stats::default();
 
-    let stats_fut = async {
+    let shutdown2 = shutdown_stats.clone();
+    let connection2 = connection.clone();
+    let mut stats_fut = tokio::spawn(async move {
         let interval_duration = Duration::from_secs(opt.interval);
 
-        loop {
-            let start = Instant::now();
-            tokio::time::sleep(interval_duration).await;
-            {
-                stats.on_interval(start, &stream_stats);
+        let start = Instant::now();
 
-                stats.print();
-                if opt.common.conn_stats {
-                    println!("{:?}\n", connection.stats());
+        loop {
+            tokio::select! {
+                biased;
+                _ = shutdown2.cancelled() => {
+                    debug!("stats_fut: leaving");
+
+                    stats.on_interval(start, &stream_stats);
+
+                    stats.print();
+                    if opt.common.conn_stats {
+                        println!("{:?}\n", connection2.stats());
+                    }
+
+                    #[cfg(feature = "json-output")]
+                    if let Some(path) = opt.json {
+                        stats.print_json(path.as_path()).unwrap(); // FIXME handle ?
+                    }
+
+                    break;
+                },
+                _ = tokio::time::sleep(interval_duration) => {
+                    stats.on_interval(start, &stream_stats);
+
+                    stats.print();
+                    if opt.common.conn_stats {
+                        println!("{:?}\n", connection2.stats());
+                    }
                 }
             }
         }
-    };
+    });
 
-    tokio::select! {
-        _ = drive_fut => {}
-        _ = stats_fut => {}
-        _ = tokio::signal::ctrl_c() => {
-            info!("shutting down");
-            connection.close(0u32.into(), b"interrupted");
-        }
-        // Add a small duration so the final interval can be reported
-        _ = tokio::time::sleep(Duration::from_secs(opt.duration) + Duration::from_millis(200)) => {
-            info!("shutting down");
-            connection.close(0u32.into(), b"done");
+    let mut drive_uni_fut_exited = false;
+    let mut drive_bi_fut_exited = false;
+    let mut ctrlc_fut_exited = false;
+    let mut duration_fut_exited = false;
+    let mut remaining_drive_tasks = 2;
+    let mut reason = String::new();
+    loop {
+        tokio::select! {
+            res = &mut drive_uni_fut, if !drive_uni_fut_exited => {
+                if let Err(err) = res {
+                    error!("drive_uni left with error {err}");
+                }
+
+                drive_uni_fut_exited = true;
+                remaining_drive_tasks -= 1;
+
+                if remaining_drive_tasks == 0 {
+                    // we can cancel stats future as all drive futures have finished
+                    shutdown_stats.cancel();
+                }
+            }
+            res = &mut drive_bi_fut, if !drive_bi_fut_exited => {
+                if let Err(err) = res {
+                    error!("drive_bi left with error {err}");
+                }
+
+                drive_bi_fut_exited = true;
+                remaining_drive_tasks -= 1;
+
+                if remaining_drive_tasks == 0 {
+                    // we can cancel stats future as all drive futures have finished
+                    shutdown_stats.cancel();
+                }
+            }
+            _ = &mut stats_fut => {
+                break;
+            }
+            _ = tokio::signal::ctrl_c(), if !ctrlc_fut_exited => {
+                info!("shutting down (ctrl-c)");
+                ctrlc_fut_exited = true;
+
+                shutdown_drive.cancel();
+
+                reason = "interrupted".to_owned();
+            }
+            _ = tokio::time::sleep(Duration::from_secs(opt.duration)), if !duration_fut_exited => {
+                duration_fut_exited = true;
+                info!("shutting down (timeout)");
+
+                shutdown_drive.cancel();
+
+                reason = "done".to_owned();
+            }
         }
     }
 
-    endpoint.wait_idle().await;
+    connection.close(0u32.into(), &reason.into_bytes());
 
-    #[cfg(feature = "json-output")]
-    if let Some(path) = opt.json {
-        stats.print_json(path.as_path())?;
-    }
+    endpoint.wait_idle().await;
 
     Ok(())
 }
 
 async fn drain_stream(
+    shutdown: CancellationToken,
     mut stream: quinn::RecvStream,
     download: u64,
     stream_stats: OpenStreamStats,
@@ -224,30 +304,44 @@ async fn drain_stream(
         Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(),
         Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(),
     ];
-    let download_start = Instant::now();
     let recv_stream_stats = stream_stats.new_receiver(&stream, download);
 
     let mut first_byte = true;
-
-    while let Some(size) = stream.read_chunks(&mut bufs[..]).await? {
-        if first_byte {
-            recv_stream_stats.on_first_byte(download_start.elapsed());
-            first_byte = false;
+    let mut total_bytes_received = 0;
+    let download_start = Instant::now();
+    loop {
+        tokio::select! {
+            biased;
+            _ = shutdown.cancelled() => {
+                break;
+            },
+            res = stream.read_chunks(&mut bufs[..]) => {
+                if let Some(size) = res? {
+                    if first_byte {
+                        recv_stream_stats.on_first_byte(download_start.elapsed());
+                        first_byte = false;
+                    }
+                    let bytes_received = bufs[..size].iter().map(|b| b.len()).sum();
+                    recv_stream_stats.on_bytes(bytes_received);
+                    total_bytes_received += bytes_received as u64;
+                } else {
+                    break;
+                }
+            }
         }
-        let bytes_received = bufs[..size].iter().map(|b| b.len()).sum();
-        recv_stream_stats.on_bytes(bytes_received);
     }
 
     if first_byte {
         recv_stream_stats.on_first_byte(download_start.elapsed());
     }
-    recv_stream_stats.finish(download_start.elapsed());
+    recv_stream_stats.finish(download_start.elapsed(), total_bytes_received);
 
     debug!("response finished on {}", stream.id());
 
     Ok(())
 }
 
 async fn drive_uni(
+    shutdown: CancellationToken,
     connection: quinn::Connection,
     stream_stats: OpenStreamStats,
     concurrency: u64,
@@ -261,14 +355,22 @@ async fn drive_uni(
     let sem = Arc::new(Semaphore::new(concurrency as usize));
 
     loop {
+        if shutdown.is_cancelled() {
+            debug!("drive_uni: leaving");
+            return Ok(());
+        }
+
         let permit = sem.clone().acquire_owned().await.unwrap();
         let send = connection.open_uni().await?;
         let stream_stats = stream_stats.clone();
         debug!("sending request on {}", send.id());
         let connection = connection.clone();
+        let shutdown2 = shutdown.clone();
         tokio::spawn(async move {
-            if let Err(e) = request_uni(send, connection, upload, download, stream_stats).await {
+            if let Err(e) =
+                request_uni(shutdown2, send, connection, upload, download, stream_stats).await
+            {
                 error!("sending request failed: {:#}", e);
             }
 
@@ -278,25 +380,34 @@
 async fn request_uni(
+    shutdown: CancellationToken,
     send: quinn::SendStream,
     conn: quinn::Connection,
     upload: u64,
     download: u64,
     stream_stats: OpenStreamStats,
 ) -> Result<()> {
-    request(send, upload, download, stream_stats.clone()).await?;
-    let recv = conn.accept_uni().await?;
-    drain_stream(recv, download, stream_stats).await?;
+    request(
+        shutdown.clone(),
+        send,
+        upload,
+        download,
+        stream_stats.clone(),
+    )
+    .await?;
+    let recv = conn.accept_uni().await?; // FIXME select ?
+    drain_stream(shutdown, recv, download, stream_stats).await?;
     Ok(())
 }
 
 async fn request(
+    shutdown: CancellationToken,
     mut send: quinn::SendStream,
-    mut upload: u64,
+    upload: u64,
     download: u64,
     stream_stats: OpenStreamStats,
 ) -> Result<()> {
-    let upload_start = Instant::now();
+    // FIXME select ?
     send.write_all(&download.to_be_bytes()).await?;
     if upload == 0 {
         send.finish().unwrap();
         return Ok(());
     }
@@ -305,25 +416,121 @@
     let send_stream_stats = stream_stats.new_sender(&send, upload);
 
-    static DATA: [u8; 1024 * 1024] = [42; 1024 * 1024];
-    while upload > 0 {
-        let chunk_len = upload.min(DATA.len() as u64);
-        send.write_chunk(Bytes::from_static(&DATA[..chunk_len as usize]))
-            .await
-            .context("sending response")?;
-        send_stream_stats.on_bytes(chunk_len as usize);
-        upload -= chunk_len;
+    static DATA: [u8; 1024 * 1024] = [42; 1024 * 1024]; // 1MB of data
+    let unrolled = true;
+
+    let mut remaining = upload;
+    let upload_start = Instant::now();
+    while remaining > 0 {
+        if unrolled {
+            #[rustfmt::skip]
+            let mut data_chunks = [ // 32 MB of data
+                Bytes::from_static(&DATA), Bytes::from_static(&DATA),
+                Bytes::from_static(&DATA), Bytes::from_static(&DATA),
+                Bytes::from_static(&DATA), Bytes::from_static(&DATA),
+                Bytes::from_static(&DATA), Bytes::from_static(&DATA),
+                Bytes::from_static(&DATA), Bytes::from_static(&DATA),
+                Bytes::from_static(&DATA), Bytes::from_static(&DATA),
+                Bytes::from_static(&DATA), Bytes::from_static(&DATA),
+                Bytes::from_static(&DATA), Bytes::from_static(&DATA),
+                Bytes::from_static(&DATA), Bytes::from_static(&DATA),
+                Bytes::from_static(&DATA), Bytes::from_static(&DATA),
+                Bytes::from_static(&DATA), Bytes::from_static(&DATA),
+                Bytes::from_static(&DATA), Bytes::from_static(&DATA),
+                Bytes::from_static(&DATA), Bytes::from_static(&DATA),
+                Bytes::from_static(&DATA), Bytes::from_static(&DATA),
+                Bytes::from_static(&DATA), Bytes::from_static(&DATA),
+                Bytes::from_static(&DATA), Bytes::from_static(&DATA),
+            ];
+            let one_data_chunks_len = data_chunks[0].len() as u64;
+            let all_data_chunks_len = data_chunks[..data_chunks.len()]
+                .iter()
+                .map(|b| b.len() as u64)
+                .sum::<u64>();
+
+            if remaining > all_data_chunks_len {
+                // send all chunks at the same time
+                tokio::select! {
+                    biased;
+                    _ = shutdown.cancelled() => {
+                        break;
+                    },
+                    res = send.write_chunks(&mut data_chunks) => {
+                        let res = res.context("sending all chunks")?;
+
+                        info!("sent {} chunks for {} bytes remaining {remaining}", res.chunks, res.bytes);
+
+                        send_stream_stats.on_bytes(res.bytes);
+                        remaining -= res.bytes as u64;
+                    }
+                }
+            } else if remaining <= one_data_chunks_len {
+                // manually send remaining data
+                let chunk_len = remaining.min(DATA.len() as u64);
+
+                tokio::select! {
+                    biased;
+                    _ = shutdown.cancelled() => {
+                        break;
+                    },
+                    res = send.write_chunk(Bytes::from_static(&DATA[..chunk_len as usize])) => {
+                        res.context("sending response")?;
+
+                        info!("sent {chunk_len} bytes remaining {remaining}");
+
+                        send_stream_stats.on_bytes(chunk_len as usize);
+                        remaining -= chunk_len;
+                    }
+                }
+            } else {
+                // send a bunch of chunks but not all
+                let chunk_count = remaining / one_data_chunks_len;
+                tokio::select! {
+                    biased;
+                    _ = shutdown.cancelled() => {
+                        break;
+                    },
+                    res = send.write_chunks(&mut data_chunks[..chunk_count as usize]) => {
+                        let res = res.context("sending some chunks")?;
+
+                        info!("sent {} chunks for {} bytes remaining {remaining}", res.chunks, res.bytes);
+
+                        send_stream_stats.on_bytes(res.bytes);
+                        remaining -= res.bytes as u64;
+                    }
+                }
+            }
+        } else {
+            let chunk_len = remaining.min(DATA.len() as u64);
+
+            tokio::select! {
+                biased;
+                _ = shutdown.cancelled() => {
+                    break;
+                },
+                res = send.write_chunk(Bytes::from_static(&DATA[..chunk_len as usize])) => {
+                    res.context("sending response")?;
+
+                    send_stream_stats.on_bytes(chunk_len as usize);
+                    remaining -= chunk_len;
+                }
+            }
+        }
     }
+
     send.finish().unwrap();
     // Wait for stream to close
-    _ = send.stopped().await;
-    send_stream_stats.finish(upload_start.elapsed());
+    let _ = send.stopped().await;
+
+    let elapsed = upload_start.elapsed();
+    send_stream_stats.finish(elapsed, upload - remaining);
 
     debug!("upload finished on {}", send.id());
     Ok(())
 }
 
 async fn drive_bi(
+    shutdown: CancellationToken,
     connection: quinn::Connection,
     stream_stats: OpenStreamStats,
     concurrency: u64,
@@ -337,13 +544,21 @@
     let sem = Arc::new(Semaphore::new(concurrency as usize));
 
     loop {
+        if shutdown.is_cancelled() {
+            debug!("drive_bi: leaving");
+            return Ok(());
+        }
+
         let permit = sem.clone().acquire_owned().await.unwrap();
         let (send, recv) = connection.open_bi().await?;
         let stream_stats = stream_stats.clone();
         debug!("sending request on {}", send.id());
+        let shutdown2 = shutdown.clone();
+        // FIXME store handle and wait for everyone to get cancelled before leaving the function
        tokio::spawn(async move {
-            if let Err(e) = request_bi(send, recv, upload, download, stream_stats).await {
+            if let Err(e) = request_bi(shutdown2, send, recv, upload, download, stream_stats).await
+            {
                 error!("request failed: {:#}", e);
             }
 
@@ -353,14 +568,22 @@
 }
 
 async fn request_bi(
+    shutdown: CancellationToken,
     send: quinn::SendStream,
     recv: quinn::RecvStream,
     upload: u64,
     download: u64,
     stream_stats: OpenStreamStats,
 ) -> Result<()> {
-    request(send, upload, download, stream_stats.clone()).await?;
-    drain_stream(recv, download, stream_stats).await?;
+    request(
+        shutdown.clone(),
+        send,
+        upload,
+        download,
+        stream_stats.clone(),
+    )
+    .await?;
+    drain_stream(shutdown, recv, download, stream_stats).await?;
 
     Ok(())
 }
diff --git a/perf/src/stats.rs b/perf/src/stats.rs
index e042cbfb15..041be4a69a 100644
--- a/perf/src/stats.rs
+++ b/perf/src/stats.rs
@@ -61,7 +61,10 @@ impl Stats {
     fn record(&mut self, stream_stats: Arc<StreamStats>) {
         if stream_stats.finished.load(Ordering::SeqCst) {
             let duration = stream_stats.duration.load(Ordering::SeqCst);
-            let bps = throughput_bytes_per_second(duration, stream_stats.request_size);
+            let bps = throughput_bytes_per_second(
+                duration,
+                stream_stats.request_size.load(Ordering::SeqCst),
+            );
 
             if stream_stats.sender {
                 self.upload_throughput.record(bps as u64).unwrap();
@@ -139,7 +142,7 @@ impl OpenStreamStats {
     pub fn new_sender(&self, stream: &quinn::SendStream, upload_size: u64) -> Arc<StreamStats> {
         let send_stream_stats = StreamStats {
             id: stream.id(),
-            request_size: upload_size,
+            request_size: AtomicU64::new(upload_size),
             bytes: Default::default(),
             sender: true,
             finished: Default::default(),
@@ -154,7 +157,7 @@ impl OpenStreamStats {
     pub fn new_receiver(&self, stream: &quinn::RecvStream, download_size: u64) -> Arc<StreamStats> {
         let recv_stream_stats = StreamStats {
             id: stream.id(),
-            request_size: download_size,
+            request_size: AtomicU64::new(download_size),
             bytes: Default::default(),
             sender: false,
             finished: Default::default(),
@@ -173,7 +176,7 @@
 
 pub struct StreamStats {
     id: StreamId,
-    request_size: u64,
+    request_size: AtomicU64,
     bytes: AtomicUsize,
     sender: bool,
     finished: AtomicBool,
@@ -191,7 +194,9 @@ impl StreamStats {
         self.bytes.fetch_add(bytes, Ordering::SeqCst);
     }
 
-    pub fn finish(&self, duration: Duration) {
+    pub fn finish(&self, duration: Duration, real_size: u64) {
+        // correct request size with what was really uploaded/downloaded
+        self.request_size.store(real_size, Ordering::SeqCst);
         self.duration
             .store(duration.as_micros() as u64, Ordering::SeqCst);
         self.finished.store(true, Ordering::SeqCst);
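
For reviewers unfamiliar with tokio-util's CancellationToken, the pattern this patch threads through client.rs (clone a token into each task, poll cancelled() first via a biased tokio::select!, then break out and run cleanup) boils down to the standalone sketch below. It is not part of the patch; names are illustrative and it only demonstrates the cancellation behavior the diff relies on.

// Minimal sketch of the cooperative-shutdown pattern used above (illustrative,
// not part of the patch). Assumes tokio = { features = ["full"] } and
// tokio-util = "0.7" as dependencies.
use std::time::Duration;

use tokio_util::sync::CancellationToken;

async fn worker(shutdown: CancellationToken, id: u32) {
    loop {
        tokio::select! {
            // `biased` makes select! poll branches top to bottom, so a
            // pending cancellation wins over doing more work, mirroring the
            // drive/stats tasks above.
            biased;
            _ = shutdown.cancelled() => {
                println!("worker {id}: leaving");
                break;
            }
            _ = tokio::time::sleep(Duration::from_millis(100)) => {
                println!("worker {id}: did one unit of work");
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let shutdown = CancellationToken::new();

    // Clones share one cancellation state: cancelling the original token
    // wakes every task awaiting `cancelled()` on any clone.
    let a = tokio::spawn(worker(shutdown.clone(), 1));
    let b = tokio::spawn(worker(shutdown.clone(), 2));

    tokio::time::sleep(Duration::from_millis(350)).await;
    shutdown.cancel();

    // Unlike JoinHandle::abort(), the workers observe the cancellation and
    // can flush state first (the final stats interval in the real code).
    let _ = tokio::join!(a, b);
}

This cooperative shutdown is also why the patch tracks `remaining` in request() and passes a real size to finish(): a cancelled transfer stops early, so the stats record the bytes actually sent or received rather than the requested size.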