Skip to content

Commit f1fe183

Browse files
Frandogretchenfrage
authored andcommitted
feat: Make the future returned from SendStream::stopped 'static
Changes the implementation of `SendStream::stopped` such that the returned future is static and no longer lifetime-bound onto a mutable reference to the send stream. This allows to use the stopped future with combinators or in a separate task while still sending on the stream concurrently. Internally, this is done changing the implementation of the stopped notification to use a cloneable tokio::sync::Notify instead of storing a single waker.
1 parent 9f008ad commit f1fe183

File tree

3 files changed

+148
-40
lines changed

3 files changed

+148
-40
lines changed

quinn/src/connection.rs

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -963,7 +963,7 @@ pub(crate) struct State {
963963
endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
964964
pub(crate) blocked_writers: FxHashMap<StreamId, Waker>,
965965
pub(crate) blocked_readers: FxHashMap<StreamId, Waker>,
966-
pub(crate) stopped: FxHashMap<StreamId, Waker>,
966+
pub(crate) stopped: FxHashMap<StreamId, Arc<Notify>>,
967967
/// Always set to Some before the connection becomes drained
968968
pub(crate) error: Option<ConnectionError>,
969969
/// Number of live handles that can be used to initiate or handle I/O; excludes the driver
@@ -1105,7 +1105,7 @@ impl State {
11051105
// `ZeroRttRejected` errors.
11061106
wake_all(&mut self.blocked_writers);
11071107
wake_all(&mut self.blocked_readers);
1108-
wake_all(&mut self.stopped);
1108+
wake_all_notify(&mut self.stopped);
11091109
}
11101110
}
11111111
ConnectionLost { reason } => {
@@ -1129,9 +1129,9 @@ impl State {
11291129
// Might mean any number of streams are ready, so we wake up everyone
11301130
shared.stream_budget_available[dir as usize].notify_waiters();
11311131
}
1132-
Stream(StreamEvent::Finished { id }) => wake_stream(id, &mut self.stopped),
1132+
Stream(StreamEvent::Finished { id }) => wake_stream_notify(id, &mut self.stopped),
11331133
Stream(StreamEvent::Stopped { id, .. }) => {
1134-
wake_stream(id, &mut self.stopped);
1134+
wake_stream_notify(id, &mut self.stopped);
11351135
wake_stream(id, &mut self.blocked_writers);
11361136
}
11371137
}
@@ -1212,7 +1212,7 @@ impl State {
12121212
if let Some(x) = self.on_connected.take() {
12131213
let _ = x.send(false);
12141214
}
1215-
wake_all(&mut self.stopped);
1215+
wake_all_notify(&mut self.stopped);
12161216
shared.closed.notify_waiters();
12171217
}
12181218

@@ -1266,6 +1266,18 @@ fn wake_all(wakers: &mut FxHashMap<StreamId, Waker>) {
12661266
wakers.drain().for_each(|(_, waker)| waker.wake())
12671267
}
12681268

1269+
fn wake_stream_notify(stream_id: StreamId, wakers: &mut FxHashMap<StreamId, Arc<Notify>>) {
1270+
if let Some(notify) = wakers.remove(&stream_id) {
1271+
notify.notify_waiters()
1272+
}
1273+
}
1274+
1275+
fn wake_all_notify(wakers: &mut FxHashMap<StreamId, Arc<Notify>>) {
1276+
wakers
1277+
.drain()
1278+
.for_each(|(_, notify)| notify.notify_waiters())
1279+
}
1280+
12691281
/// Errors that can arise when sending a datagram
12701282
#[derive(Debug, Error, Clone, Eq, PartialEq)]
12711283
pub enum SendDatagramError {

quinn/src/send_stream.rs

Lines changed: 47 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ use bytes::Bytes;
99
use proto::{ClosedStream, ConnectionError, FinishError, StreamId, Written};
1010
use thiserror::Error;
1111

12-
use crate::{VarInt, connection::ConnectionRef};
12+
use crate::{
13+
VarInt,
14+
connection::{ConnectionRef, State},
15+
};
1316

1417
/// A stream that can only be used to send data
1518
///
@@ -199,27 +202,31 @@ impl SendStream {
199202
/// For a variety of reasons, the peer may not send acknowledgements immediately upon receiving
200203
/// data. As such, relying on `stopped` to know when the peer has read a stream to completion
201204
/// may introduce more latency than using an application-level response of some sort.
202-
pub async fn stopped(&mut self) -> Result<Option<VarInt>, StoppedError> {
203-
Stopped { stream: self }.await
204-
}
205-
206-
fn poll_stopped(&mut self, cx: &mut Context) -> Poll<Result<Option<VarInt>, StoppedError>> {
207-
let mut conn = self.conn.state.lock("SendStream::poll_stopped");
208-
209-
if self.is_0rtt {
210-
conn.check_0rtt()
211-
.map_err(|()| StoppedError::ZeroRttRejected)?;
212-
}
213-
214-
match conn.inner.send_stream(self.stream).stopped() {
215-
Err(_) => Poll::Ready(Ok(None)),
216-
Ok(Some(error_code)) => Poll::Ready(Ok(Some(error_code))),
217-
Ok(None) => {
218-
if let Some(e) = &conn.error {
219-
return Poll::Ready(Err(e.clone().into()));
205+
pub fn stopped(
206+
&self,
207+
) -> impl Future<Output = Result<Option<VarInt>, StoppedError>> + Send + Sync + 'static {
208+
let conn = self.conn.clone();
209+
let stream = self.stream;
210+
let is_0rtt = self.is_0rtt;
211+
async move {
212+
loop {
213+
// The `Notify::notified` future needs to be created while the lock is being held,
214+
// otherwise a wakeup could be missed if triggered inbetween releasing the lock
215+
// and creating the future.
216+
// The lock may only be held in a block without `await`s, otherwise the future
217+
// becomes `!Send`. `Notify::notified` is lifetime-bound to `Notify`, therefore
218+
// we need to declare `notify` outside of the block, and initialize it inside.
219+
let notify;
220+
{
221+
let mut conn = conn.state.lock("SendStream::stopped");
222+
if let Some(output) = send_stream_stopped(&mut conn, stream, is_0rtt) {
223+
return output;
224+
}
225+
226+
notify = conn.stopped.entry(stream).or_default().clone();
227+
notify.notified()
220228
}
221-
conn.stopped.insert(self.stream, cx.waker().clone());
222-
Poll::Pending
229+
.await
223230
}
224231
}
225232
}
@@ -245,6 +252,25 @@ impl SendStream {
245252
}
246253
}
247254

255+
/// Check if a send stream is stopped.
256+
///
257+
/// Returns `Some` if the stream is stopped or the connection is closed.
258+
/// Returns `None` if the stream is not stopped.
259+
fn send_stream_stopped(
260+
conn: &mut State,
261+
stream: StreamId,
262+
is_0rtt: bool,
263+
) -> Option<Result<Option<VarInt>, StoppedError>> {
264+
if is_0rtt && conn.check_0rtt().is_err() {
265+
return Some(Err(StoppedError::ZeroRttRejected));
266+
}
267+
match conn.inner.send_stream(stream).stopped() {
268+
Err(ClosedStream { .. }) => Some(Ok(None)),
269+
Ok(Some(error_code)) => Some(Ok(Some(error_code))),
270+
Ok(None) => conn.error.clone().map(|error| Err(error.into())),
271+
}
272+
}
273+
248274
#[cfg(feature = "futures-io")]
249275
impl futures_io::AsyncWrite for SendStream {
250276
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
@@ -283,7 +309,6 @@ impl Drop for SendStream {
283309
let mut conn = self.conn.state.lock("SendStream::drop");
284310

285311
// clean up any previously registered wakers
286-
conn.stopped.remove(&self.stream);
287312
conn.blocked_writers.remove(&self.stream);
288313

289314
if conn.error.is_some() || (self.is_0rtt && conn.check_0rtt().is_err()) {
@@ -302,19 +327,6 @@ impl Drop for SendStream {
302327
}
303328
}
304329

305-
/// Future produced by `SendStream::stopped`
306-
struct Stopped<'a> {
307-
stream: &'a mut SendStream,
308-
}
309-
310-
impl Future for Stopped<'_> {
311-
type Output = Result<Option<VarInt>, StoppedError>;
312-
313-
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
314-
self.get_mut().stream.poll_stopped(cx)
315-
}
316-
}
317-
318330
/// Errors that arise from writing to a stream
319331
#[derive(Debug, Error, Clone, PartialEq, Eq)]
320332
pub enum WriteError {

quinn/src/tests.rs

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -861,3 +861,87 @@ async fn multiple_conns_with_zero_length_cids() {
861861
.instrument(error_span!("server"));
862862
tokio::join!(client1, client2, server);
863863
}
864+
865+
#[tokio::test]
866+
async fn stream_stopped() {
867+
let _guard = subscribe();
868+
let factory = EndpointFactory::new();
869+
let server = {
870+
let _guard = error_span!("server").entered();
871+
factory.endpoint()
872+
};
873+
let server_addr = server.local_addr().unwrap();
874+
875+
let client = {
876+
let _guard = error_span!("client1").entered();
877+
factory.endpoint()
878+
};
879+
880+
let client = async move {
881+
let conn = client
882+
.connect(server_addr, "localhost")
883+
.unwrap()
884+
.await
885+
.unwrap();
886+
let mut stream = conn.open_uni().await.unwrap();
887+
let stopped1 = stream.stopped();
888+
let stopped2 = stream.stopped();
889+
let stopped3 = stream.stopped();
890+
891+
stream.write_all(b"hi").await.unwrap();
892+
// spawn one of the futures into a task
893+
let stopped1 = tokio::task::spawn(stopped1);
894+
// verify that both futures resolved
895+
let (stopped1, stopped2) = tokio::join!(stopped1, stopped2);
896+
assert!(matches!(stopped1, Ok(Ok(Some(val))) if val == 42u32.into()));
897+
assert!(matches!(stopped2, Ok(Some(val)) if val == 42u32.into()));
898+
// drop the stream
899+
drop(stream);
900+
// verify that a future also resolves after dropping the stream
901+
let stopped3 = stopped3.await;
902+
assert_eq!(stopped3, Ok(Some(42u32.into())));
903+
};
904+
let client =
905+
tokio::time::timeout(Duration::from_millis(100), client).instrument(error_span!("client"));
906+
let server = async move {
907+
let conn = server.accept().await.unwrap().await.unwrap();
908+
let mut stream = conn.accept_uni().await.unwrap();
909+
let mut buf = [0u8; 2];
910+
stream.read_exact(&mut buf).await.unwrap();
911+
stream.stop(42u32.into()).unwrap();
912+
conn
913+
}
914+
.instrument(error_span!("server"));
915+
let (client, conn) = tokio::join!(client, server);
916+
client.expect("timeout");
917+
drop(conn);
918+
}
919+
920+
#[tokio::test]
921+
async fn stream_stopped_2() {
922+
let _guard = subscribe();
923+
let endpoint = endpoint();
924+
925+
let (conn, _server_conn) = tokio::try_join!(
926+
endpoint
927+
.connect(endpoint.local_addr().unwrap(), "localhost")
928+
.unwrap(),
929+
async { endpoint.accept().await.unwrap().await }
930+
)
931+
.unwrap();
932+
let send_stream = conn.open_uni().await.unwrap();
933+
let stopped = tokio::time::timeout(Duration::from_millis(100), send_stream.stopped())
934+
.instrument(error_span!("stopped"));
935+
tokio::pin!(stopped);
936+
// poll the future once so that the waker is registered.
937+
tokio::select! {
938+
biased;
939+
_x = &mut stopped => {},
940+
_x = std::future::ready(()) => {}
941+
}
942+
// drop the send stream
943+
drop(send_stream);
944+
// make sure the stopped future still resolves
945+
let res = stopped.await;
946+
assert_eq!(res, Ok(Ok(None)));
947+
}

0 commit comments

Comments
 (0)