Skip to content

Commit 253eba9

Browse files
committed
Replace dummy message with direct wake to trigger endpoint events
1 parent 91bc3a4 commit 253eba9

File tree

3 files changed

+47
-20
lines changed

3 files changed

+47
-20
lines changed

quinn/src/connection.rs

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use std::{
1010
};
1111

1212
use crate::runtime::{AsyncTimer, AsyncUdpSocket, Runtime};
13+
use atomic_waker::AtomicWaker;
1314
use bytes::{Bytes, BytesMut};
1415
use pin_project_lite::pin_project;
1516
use proto::{ConnectionError, ConnectionHandle, ConnectionStats, Dir, StreamEvent, StreamId};
@@ -40,6 +41,7 @@ impl Connecting {
4041
handle: ConnectionHandle,
4142
conn: proto::Connection,
4243
endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
44+
endpoint_driver: Arc<AtomicWaker>,
4345
conn_events: mpsc::UnboundedReceiver<ConnectionEvent>,
4446
socket: Arc<dyn AsyncUdpSocket>,
4547
runtime: Arc<dyn Runtime>,
@@ -50,6 +52,7 @@ impl Connecting {
5052
handle,
5153
conn,
5254
endpoint_events,
55+
endpoint_driver,
5356
conn_events,
5457
on_handshake_data_send,
5558
on_connected_send,
@@ -233,7 +236,7 @@ impl Future for ConnectionDriver {
233236
// If a timer expires, there might be more to transmit. When we transmit something, we
234237
// might need to reset a timer. Hence, we must loop until neither happens.
235238
keep_going |= conn.drive_timer(cx);
236-
conn.forward_endpoint_events();
239+
conn.forward_endpoint_events(&self.0.shared);
237240
conn.forward_app_events(&self.0.shared);
238241

239242
if !conn.inner.is_drained() {
@@ -759,6 +762,7 @@ impl ConnectionRef {
759762
handle: ConnectionHandle,
760763
conn: proto::Connection,
761764
endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
765+
endpoint_driver: Arc<AtomicWaker>,
762766
conn_events: mpsc::UnboundedReceiver<ConnectionEvent>,
763767
on_handshake_data: oneshot::Sender<()>,
764768
on_connected: oneshot::Sender<bool>,
@@ -786,7 +790,13 @@ impl ConnectionRef {
786790
socket,
787791
runtime,
788792
}),
789-
shared: Shared::default(),
793+
shared: Shared {
794+
endpoint_driver,
795+
stream_budget_available: Default::default(),
796+
stream_incoming: Default::default(),
797+
datagrams: Default::default(),
798+
closed: Default::default(),
799+
},
790800
}))
791801
}
792802

@@ -831,7 +841,7 @@ pub(crate) struct ConnectionInner {
831841
pub(crate) shared: Shared,
832842
}
833843

834-
#[derive(Debug, Default)]
844+
#[derive(Debug)]
835845
pub(crate) struct Shared {
836846
/// Notified when new streams may be locally initiated due to an increase in stream ID flow
837847
/// control budget
@@ -840,6 +850,7 @@ pub(crate) struct Shared {
840850
stream_incoming: [Notify; 2],
841851
datagrams: Notify,
842852
closed: Notify,
853+
endpoint_driver: Arc<AtomicWaker>,
843854
}
844855

845856
pub(crate) struct State {
@@ -898,18 +909,17 @@ impl State {
898909
false
899910
}
900911

901-
fn forward_endpoint_events(&mut self) {
912+
fn forward_endpoint_events(&mut self, shared: &Shared) {
902913
if !self.inner.poll_endpoint_events() {
903914
return;
904915
}
905-
// If the endpoint driver is gone, noop.
906-
let _ = self.endpoint_events.send((
907-
self.handle,
908-
match self.inner.is_drained() {
909-
false => EndpointEvent::Proto,
910-
true => EndpointEvent::Drained,
911-
},
912-
));
916+
shared.endpoint_driver.wake();
917+
if self.inner.is_drained() {
918+
// If the endpoint driver is gone, noop.
919+
let _ = self
920+
.endpoint_events
921+
.send((self.handle, EndpointEvent::Drained));
922+
}
913923
}
914924

915925
/// If this returns `Err`, the endpoint is dead, so the driver should exit immediately.

quinn/src/endpoint.rs

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,13 @@ impl Endpoint {
198198
.connect(Instant::now(), config, addr, server_name)?;
199199

200200
let socket = endpoint.socket.clone();
201-
Ok(endpoint
202-
.connections
203-
.insert(ch, conn, socket, self.runtime.clone()))
201+
Ok(endpoint.connections.insert(
202+
ch,
203+
conn,
204+
socket,
205+
self.runtime.clone(),
206+
self.inner.shared.driver.clone(),
207+
))
204208
}
205209

206210
/// Switch to a new UDP socket
@@ -325,7 +329,7 @@ impl Future for EndpointDriver {
325329

326330
let now = Instant::now();
327331
let mut keep_going = false;
328-
keep_going |= endpoint.drive_recv(cx, now)?;
332+
keep_going |= endpoint.drive_recv(cx, now, &self.0.shared)?;
329333
keep_going |= endpoint.handle_events(cx, &self.0.shared);
330334
keep_going |= endpoint.drive_send(cx)?;
331335

@@ -393,7 +397,12 @@ pub(crate) struct Shared {
393397
}
394398

395399
impl State {
396-
fn drive_recv<'a>(&'a mut self, cx: &mut Context, now: Instant) -> Result<bool, io::Error> {
400+
fn drive_recv<'a>(
401+
&'a mut self,
402+
cx: &mut Context,
403+
now: Instant,
404+
shared: &Shared,
405+
) -> Result<bool, io::Error> {
397406
self.recv_limiter.start_cycle();
398407
let mut metas = [RecvMeta::default(); BATCH_SIZE];
399408
let mut iovs = MaybeUninit::<[IoSliceMut<'a>; BATCH_SIZE]>::uninit();
@@ -431,6 +440,7 @@ impl State {
431440
conn,
432441
self.socket.clone(),
433442
self.runtime.clone(),
443+
shared.driver.clone(),
434444
);
435445
self.incoming.push_back(conn);
436446
}
@@ -530,7 +540,6 @@ impl State {
530540
for _ in 0..IO_LOOP_BOUND {
531541
match self.events.poll_recv(cx) {
532542
Poll::Ready(Some((ch, event))) => match event {
533-
Proto => {}
534543
Drained => {
535544
self.connections.senders.remove(&ch);
536545
if self.connections.is_empty() {
@@ -617,6 +626,7 @@ impl ConnectionSet {
617626
conn: proto::Connection,
618627
socket: Arc<dyn AsyncUdpSocket>,
619628
runtime: Arc<dyn Runtime>,
629+
driver: Arc<AtomicWaker>,
620630
) -> Connecting {
621631
let (send, recv) = mpsc::unbounded_channel();
622632
if let Some((error_code, ref reason)) = self.close {
@@ -627,7 +637,15 @@ impl ConnectionSet {
627637
.unwrap();
628638
}
629639
self.senders.insert(handle, send);
630-
Connecting::new(handle, conn, self.sender.clone(), recv, socket, runtime)
640+
Connecting::new(
641+
handle,
642+
conn,
643+
self.sender.clone(),
644+
driver,
645+
recv,
646+
socket,
647+
runtime,
648+
)
631649
}
632650

633651
fn is_empty(&self) -> bool {

quinn/src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ enum ConnectionEvent {
9898

9999
#[derive(Debug)]
100100
enum EndpointEvent {
101-
Proto,
102101
Drained,
103102
Transmit(proto::Transmit, Bytes),
104103
}

0 commit comments

Comments
 (0)