@@ -16,7 +16,10 @@ use rumqttc::EventLoop;
1616use rumqttc:: Incoming ;
1717use rumqttc:: Outgoing ;
1818use rumqttc:: Packet ;
19+ use std:: sync:: Arc ;
1920use std:: time:: Duration ;
21+ use tokio:: sync:: OwnedSemaphorePermit ;
22+ use tokio:: sync:: Semaphore ;
2023use tokio:: time:: sleep;
2124
2225/// A connection to some MQTT server
@@ -88,19 +91,23 @@ impl Connection {
8891
8992 let ( mqtt_client, event_loop) =
9093 Connection :: open ( config, received_sender. clone ( ) , error_sender. clone ( ) ) . await ?;
94+ let permits = Arc :: new ( Semaphore :: new ( 1 ) ) ;
95+ let permit = permits. clone ( ) . acquire_owned ( ) . await . unwrap ( ) ;
9196 tokio:: spawn ( Connection :: receiver_loop (
9297 mqtt_client. clone ( ) ,
9398 config. clone ( ) ,
9499 event_loop,
95100 received_sender,
96101 error_sender. clone ( ) ,
102+ pub_done_sender,
103+ permits,
97104 ) ) ;
98105 tokio:: spawn ( Connection :: sender_loop (
99106 mqtt_client,
100107 published_receiver,
101108 error_sender,
102109 config. last_will_message . clone ( ) ,
103- pub_done_sender ,
110+ permit ,
104111 ) ) ;
105112
106113 Ok ( Connection {
@@ -200,9 +207,41 @@ impl Connection {
200207 mut event_loop : EventLoop ,
201208 mut message_sender : mpsc:: UnboundedSender < MqttMessage > ,
202209 mut error_sender : mpsc:: UnboundedSender < MqttError > ,
210+ done : oneshot:: Sender < ( ) > ,
211+ permits : Arc < Semaphore > ,
203212 ) -> Result < ( ) , MqttError > {
213+ let mut triggered_disconnect = false ;
214+ let mut disconnect_permit = None ;
215+
204216 loop {
205- match event_loop. poll ( ) . await {
217+ // Check if we are ready to disconnect. Due to ownership of the
218+ // event loop, this needs to be done before we call
219+ // `event_loop.poll()`
220+ let remaining_events_empty = event_loop. state . inflight ( ) == 0 ;
221+ if disconnect_permit. is_some ( ) && !triggered_disconnect && remaining_events_empty {
222+ // `sender_loop` is not running and we have no remaining
223+ // publishes to process
224+ let client = mqtt_client. clone ( ) ;
225+ tokio:: spawn ( async move { client. disconnect ( ) . await } ) ;
226+ triggered_disconnect = true ;
227+ }
228+
229+ let event = tokio:: select! {
230+ // If there is an event, we need to process that first
231+ // Otherwise we risk shutting down early
232+ // e.g. a `Publish` request from the sender is not "inflight"
233+ // but will immediately be returned by `event_loop.poll()`
234+ biased;
235+
236+ event = event_loop. poll( ) => event,
237+ permit = permits. clone( ) . acquire_owned( ) => {
238+ // The `sender_loop` has now concluded
239+ disconnect_permit = Some ( permit. unwrap( ) ) ;
240+ continue ;
241+ }
242+ } ;
243+
244+ match event {
206245 Ok ( Event :: Incoming ( Packet :: Publish ( msg) ) ) => {
207246 if msg. payload . len ( ) > config. max_packet_size {
208247 error ! ( "Dropping message received on topic {} with payload size {} that exceeds the maximum packet size of {}" ,
@@ -266,6 +305,7 @@ impl Connection {
266305 // No more messages will be forwarded to the client
267306 let _ = message_sender. close ( ) . await ;
268307 let _ = error_sender. close ( ) . await ;
308+ let _ = done. send ( ( ) ) ;
269309 Ok ( ( ) )
270310 }
271311
@@ -274,24 +314,15 @@ impl Connection {
274314 mut messages_receiver : mpsc:: UnboundedReceiver < MqttMessage > ,
275315 mut error_sender : mpsc:: UnboundedSender < MqttError > ,
276316 last_will : Option < MqttMessage > ,
277- done : oneshot :: Sender < ( ) > ,
317+ _disconnect_permit : OwnedSemaphorePermit ,
278318 ) {
279- loop {
280- match messages_receiver. next ( ) . await {
281- None => {
282- // The sender channel has been closed by the client
283- // No more messages will be published by the client
284- break ;
285- }
286- Some ( message) => {
287- let payload = Vec :: from ( message. payload_bytes ( ) ) ;
288- if let Err ( err) = mqtt_client
289- . publish ( message. topic , message. qos , message. retain , payload)
290- . await
291- {
292- let _ = error_sender. send ( err. into ( ) ) . await ;
293- }
294- }
319+ while let Some ( message) = messages_receiver. next ( ) . await {
320+ let payload = Vec :: from ( message. payload_bytes ( ) ) ;
321+ if let Err ( err) = mqtt_client
322+ . publish ( message. topic , message. qos , message. retain , payload)
323+ . await
324+ {
325+ let _ = error_sender. send ( err. into ( ) ) . await ;
295326 }
296327 }
297328
@@ -303,8 +334,9 @@ impl Connection {
303334 . publish ( last_will. topic , last_will. qos , last_will. retain , payload)
304335 . await ;
305336 }
306- let _ = mqtt_client. disconnect ( ) . await ;
307- let _ = done. send ( ( ) ) ;
337+
338+ // At this point, `_disconnect_permit` is dropped
339+ // This allows `receiver_loop` acquire a permit and commence the shutdown process
308340 }
309341
310342 pub ( crate ) async fn do_pause ( ) {
0 commit comments