@@ -10,6 +10,7 @@ use std::{
10
10
} ;
11
11
12
12
use crate :: runtime:: { AsyncTimer , AsyncUdpSocket , Runtime } ;
13
+ use atomic_waker:: AtomicWaker ;
13
14
use bytes:: { Bytes , BytesMut } ;
14
15
use pin_project_lite:: pin_project;
15
16
use proto:: { ConnectionError , ConnectionHandle , ConnectionStats , Dir , StreamEvent , StreamId } ;
@@ -40,6 +41,7 @@ impl Connecting {
40
41
handle : ConnectionHandle ,
41
42
conn : proto:: Connection ,
42
43
endpoint_events : mpsc:: UnboundedSender < ( ConnectionHandle , EndpointEvent ) > ,
44
+ endpoint_driver : Arc < AtomicWaker > ,
43
45
conn_events : mpsc:: UnboundedReceiver < ConnectionEvent > ,
44
46
socket : Arc < dyn AsyncUdpSocket > ,
45
47
runtime : Arc < dyn Runtime > ,
@@ -50,6 +52,7 @@ impl Connecting {
50
52
handle,
51
53
conn,
52
54
endpoint_events,
55
+ endpoint_driver,
53
56
conn_events,
54
57
on_handshake_data_send,
55
58
on_connected_send,
@@ -233,7 +236,7 @@ impl Future for ConnectionDriver {
233
236
// If a timer expires, there might be more to transmit. When we transmit something, we
234
237
// might need to reset a timer. Hence, we must loop until neither happens.
235
238
keep_going |= conn. drive_timer ( cx) ;
236
- conn. forward_endpoint_events ( ) ;
239
+ conn. forward_endpoint_events ( & self . 0 . shared ) ;
237
240
conn. forward_app_events ( & self . 0 . shared ) ;
238
241
239
242
if !conn. inner . is_drained ( ) {
@@ -759,6 +762,7 @@ impl ConnectionRef {
759
762
handle : ConnectionHandle ,
760
763
conn : proto:: Connection ,
761
764
endpoint_events : mpsc:: UnboundedSender < ( ConnectionHandle , EndpointEvent ) > ,
765
+ endpoint_driver : Arc < AtomicWaker > ,
762
766
conn_events : mpsc:: UnboundedReceiver < ConnectionEvent > ,
763
767
on_handshake_data : oneshot:: Sender < ( ) > ,
764
768
on_connected : oneshot:: Sender < bool > ,
@@ -786,7 +790,13 @@ impl ConnectionRef {
786
790
socket,
787
791
runtime,
788
792
} ) ,
789
- shared : Shared :: default ( ) ,
793
+ shared : Shared {
794
+ endpoint_driver,
795
+ stream_budget_available : Default :: default ( ) ,
796
+ stream_incoming : Default :: default ( ) ,
797
+ datagrams : Default :: default ( ) ,
798
+ closed : Default :: default ( ) ,
799
+ } ,
790
800
} ) )
791
801
}
792
802
@@ -831,7 +841,7 @@ pub(crate) struct ConnectionInner {
831
841
pub ( crate ) shared : Shared ,
832
842
}
833
843
834
- #[ derive( Debug , Default ) ]
844
+ #[ derive( Debug ) ]
835
845
pub ( crate ) struct Shared {
836
846
/// Notified when new streams may be locally initiated due to an increase in stream ID flow
837
847
/// control budget
@@ -840,6 +850,7 @@ pub(crate) struct Shared {
840
850
stream_incoming : [ Notify ; 2 ] ,
841
851
datagrams : Notify ,
842
852
closed : Notify ,
853
+ endpoint_driver : Arc < AtomicWaker > ,
843
854
}
844
855
845
856
pub ( crate ) struct State {
@@ -898,18 +909,17 @@ impl State {
898
909
false
899
910
}
900
911
901
- fn forward_endpoint_events ( & mut self ) {
912
+ fn forward_endpoint_events ( & mut self , shared : & Shared ) {
902
913
if !self . inner . poll_endpoint_events ( ) {
903
914
return ;
904
915
}
905
- // If the endpoint driver is gone, noop.
906
- let _ = self . endpoint_events . send ( (
907
- self . handle ,
908
- match self . inner . is_drained ( ) {
909
- false => EndpointEvent :: Proto ,
910
- true => EndpointEvent :: Drained ,
911
- } ,
912
- ) ) ;
916
+ shared. endpoint_driver . wake ( ) ;
917
+ if self . inner . is_drained ( ) {
918
+ // If the endpoint driver is gone, noop.
919
+ let _ = self
920
+ . endpoint_events
921
+ . send ( ( self . handle , EndpointEvent :: Drained ) ) ;
922
+ }
913
923
}
914
924
915
925
/// If this returns `Err`, the endpoint is dead, so the driver should exit immediately.
0 commit comments