2323import com .rabbitmq .client .Address ;
2424import com .rabbitmq .client .ConnectionFactory ;
2525import com .rabbitmq .client .MalformedFrameException ;
26+ import com .rabbitmq .client .ShutdownSignalException ;
2627import com .rabbitmq .client .SocketConfigurator ;
2728import io .netty .bootstrap .Bootstrap ;
2829import io .netty .buffer .ByteBuf ;
5152import java .net .InetSocketAddress ;
5253import java .net .SocketAddress ;
5354import java .time .Duration ;
54- import java .util .Queue ;
55- import java .util .concurrent .ConcurrentLinkedQueue ;
5655import java .util .concurrent .CountDownLatch ;
5756import java .util .concurrent .ExecutionException ;
5857import java .util .concurrent .TimeUnit ;
5958import java .util .concurrent .TimeoutException ;
6059import java .util .concurrent .atomic .AtomicBoolean ;
60+ import java .util .concurrent .atomic .AtomicInteger ;
6161import java .util .concurrent .atomic .AtomicReference ;
6262import java .util .function .Consumer ;
6363import java .util .function .Function ;
64+ import java .util .function .Predicate ;
6465import javax .net .ssl .SSLHandshakeException ;
6566import org .slf4j .Logger ;
6667import org .slf4j .LoggerFactory ;
@@ -73,6 +74,7 @@ public final class NettyFrameHandlerFactory extends AbstractFrameHandlerFactory
7374 private final Consumer <Channel > channelCustomizer ;
7475 private final Consumer <Bootstrap > bootstrapCustomizer ;
7576 private final Duration enqueuingTimeout ;
77+ private final Predicate <ShutdownSignalException > willRecover ;
7678
7779 public NettyFrameHandlerFactory (
7880 EventLoopGroup eventLoopGroup ,
@@ -82,14 +84,30 @@ public NettyFrameHandlerFactory(
8284 Duration enqueuingTimeout ,
8385 int connectionTimeout ,
8486 SocketConfigurator configurator ,
85- int maxInboundMessageBodySize ) {
87+ int maxInboundMessageBodySize ,
88+ boolean automaticRecovery ,
89+ Predicate <ShutdownSignalException > recoveryCondition ) {
8690 super (connectionTimeout , configurator , sslContextFactory != null , maxInboundMessageBodySize );
8791 this .eventLoopGroup = eventLoopGroup ;
8892 this .sslContextFactory = sslContextFactory == null ? connName -> null : sslContextFactory ;
8993 this .channelCustomizer = channelCustomizer == null ? Utils .noOpConsumer () : channelCustomizer ;
9094 this .bootstrapCustomizer =
9195 bootstrapCustomizer == null ? Utils .noOpConsumer () : bootstrapCustomizer ;
9296 this .enqueuingTimeout = enqueuingTimeout ;
97+ this .willRecover =
98+ sse -> {
99+ if (!automaticRecovery ) {
100+ return false ;
101+ } else {
102+ try {
103+ return recoveryCondition .test (sse );
104+ } catch (Exception e ) {
105+ // we assume it will recover, so we take the safe path to dispatch the closing
106+ // it avoids the risk of deadlock
107+ return true ;
108+ }
109+ }
110+ };
93111 }
94112
95113 private static void closeNettyState (Channel channel , EventLoopGroup eventLoopGroup ) {
@@ -133,6 +151,7 @@ public FrameHandler create(Address addr, String connectionName) throws IOExcepti
133151 sslContext ,
134152 this .eventLoopGroup ,
135153 this .enqueuingTimeout ,
154+ this .willRecover ,
136155 this .channelCustomizer ,
137156 this .bootstrapCustomizer );
138157 }
@@ -163,6 +182,7 @@ private NettyFrameHandler(
163182 SslContext sslContext ,
164183 EventLoopGroup elg ,
165184 Duration enqueuingTimeout ,
185+ Predicate <ShutdownSignalException > willRecover ,
166186 Consumer <Channel > channelCustomizer ,
167187 Consumer <Bootstrap > bootstrapCustomizer )
168188 throws IOException {
@@ -180,6 +200,14 @@ private NettyFrameHandler(
180200 } else {
181201 this .eventLoopGroup = null ;
182202 }
203+
204+ if (b .config ().group () == null ) {
205+ throw new IllegalStateException ("The event loop group is not set" );
206+ } else if (b .config ().group ().isShuttingDown ()) {
207+ LOGGER .warn ("The Netty loop group was shut down, it is not possible to connect or recover" );
208+ throw new IllegalStateException ("The event loop group was shut down" );
209+ }
210+
183211 if (b .config ().channelFactory () == null ) {
184212 b .channel (NioSocketChannel .class );
185213 }
@@ -195,7 +223,8 @@ private NettyFrameHandler(
195223 int lengthFieldOffset = 3 ;
196224 int lengthFieldLength = 4 ;
197225 int lengthAdjustement = 1 ;
198- AmqpHandler amqpHandler = new AmqpHandler (maxInboundMessageBodySize , this ::close );
226+ AmqpHandler amqpHandler =
227+ new AmqpHandler (maxInboundMessageBodySize , this ::close , willRecover );
199228 int port = ConnectionFactory .portOrDefault (addr .getPort (), sslContext != null );
200229 b .handler (
201230 new ChannelInitializer <SocketChannel >() {
@@ -296,6 +325,10 @@ public void sendHeader() {
296325
297326 @ Override
298327 public void initialize (AMQConnection connection ) {
328+ LOGGER .debug (
329+ "Setting connection {} to AMQP handler {}" ,
330+ connection .getClientProvidedName (),
331+ this .handler .id );
299332 this .handler .connection = connection ;
300333 }
301334
@@ -333,7 +366,6 @@ public void writeFrame(Frame frame) throws IOException {
333366 if (canWriteNow ) {
334367 this .doWriteFrame (frame );
335368 } else {
336- this .handler .logEvents ();
337369 throw new IOException ("Frame enqueuing failed" );
338370 }
339371 } catch (InterruptedException e ) {
@@ -404,14 +436,30 @@ private static class AmqpHandler extends ChannelInboundHandlerAdapter {
404436
405437 private final int maxPayloadSize ;
406438 private final Runnable closeSequence ;
439+ private final Predicate <ShutdownSignalException > willRecover ;
407440 private volatile AMQConnection connection ;
441+ private volatile Channel ch ;
408442 private final AtomicBoolean writable = new AtomicBoolean (true );
409443 private final AtomicReference <CountDownLatch > writableLatch =
410444 new AtomicReference <>(new CountDownLatch (1 ));
411-
412- private AmqpHandler (int maxPayloadSize , Runnable closeSequence ) {
445+ private final AtomicBoolean shutdownDispatched = new AtomicBoolean (false );
446+ private static final AtomicInteger SEQUENCE = new AtomicInteger (0 );
447+ private final String id ;
448+
449+ private AmqpHandler (
450+ int maxPayloadSize ,
451+ Runnable closeSequence ,
452+ Predicate <ShutdownSignalException > willRecover ) {
413453 this .maxPayloadSize = maxPayloadSize ;
414454 this .closeSequence = closeSequence ;
455+ this .willRecover = willRecover ;
456+ this .id = "amqp-handler-" + SEQUENCE .getAndIncrement ();
457+ }
458+
459+ @ Override
460+ public void channelActive (ChannelHandlerContext ctx ) throws Exception {
461+ this .ch = ctx .channel ();
462+ super .channelActive (ctx );
415463 }
416464
417465 @ Override
@@ -444,49 +492,16 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
444492 if (noProblem
445493 && (!this .connection .isRunning () || this .connection .hasBrokerInitiatedShutdown ())) {
446494 // looks like the frame was Close-Ok or Close
447- ctx . executor (). submit (() -> this .connection .doFinalShutdown ());
495+ this . dispatchShutdownToConnection (() -> this .connection .doFinalShutdown ());
448496 }
449497 } finally {
450498 m .release ();
451499 }
452500 }
453501
454- private static class Event {
455- private final long time ;
456- private final String label ;
457-
458- public Event (long time , String label ) {
459- this .time = time ;
460- this .label = label ;
461- }
462-
463- @ Override
464- public String toString () {
465- return this .label + " " + this .time ;
466- }
467- }
468-
469- private static final int MAX_EVENTS = 100 ;
470- private final Queue <Event > events = new ConcurrentLinkedQueue <>();
471-
472- private void logEvents () {
473- if (this .events .size () > 0 ) {
474- long start = this .events .peek ().time ;
475- LOGGER .info ("channel writability history:" );
476- events .forEach (e -> LOGGER .info ("{}: {}" , (e .time - start ) / 1_000_000 , e .label ));
477- }
478- }
479-
480502 @ Override
481503 public void channelWritabilityChanged (ChannelHandlerContext ctx ) throws Exception {
482504 boolean canWrite = ctx .channel ().isWritable ();
483- Event event = new Event (System .nanoTime (), Boolean .toString (canWrite ));
484- if (this .events .size () >= MAX_EVENTS ) {
485- this .events .poll ();
486- this .events .offer (event );
487- }
488- this .events .add (event );
489-
490505 if (this .writable .compareAndSet (!canWrite , canWrite )) {
491506 if (canWrite ) {
492507 CountDownLatch latch = writableLatch .getAndSet (new CountDownLatch (1 ));
@@ -502,12 +517,13 @@ public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exceptio
502517 public void channelInactive (ChannelHandlerContext ctx ) {
503518 if (needToDispatchIoError ()) {
504519 AMQConnection c = this .connection ;
520+ LOGGER .debug ("Dispatching shutdown when channel became inactive ({})" , this .id );
505521 if (c .isOpen ()) {
506522 // it is likely to be an IO exception
507- c .handleIoError (null );
523+ this . dispatchShutdownToConnection (() -> c .handleIoError (null ) );
508524 } else {
509525 // just in case, the call is idempotent anyway
510- c . doFinalShutdown ( );
526+ this . dispatchShutdownToConnection ( c :: doFinalShutdown );
511527 }
512528 }
513529 }
@@ -533,7 +549,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc
533549 this .connection .getAddress ().getHostName (),
534550 this .connection .getPort ());
535551 if (needToDispatchIoError ()) {
536- this .connection .handleHeartbeatFailure ();
552+ this .dispatchShutdownToConnection (() -> this . connection .handleHeartbeatFailure () );
537553 }
538554 } else if (e .state () == IdleState .WRITER_IDLE ) {
539555 this .connection .writeFrame (new Frame (AMQP .FRAME_HEARTBEAT , 0 ));
@@ -545,7 +561,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc
545561
546562 private void handleIoError (Throwable cause ) {
547563 if (needToDispatchIoError ()) {
548- this .connection .handleIoError (cause );
564+ this .dispatchShutdownToConnection (() -> this . connection .handleIoError (cause ) );
549565 } else {
550566 this .closeSequence .run ();
551567 }
@@ -563,6 +579,32 @@ private boolean isWritable() {
563579 private CountDownLatch writableLatch () {
564580 return this .writableLatch .get ();
565581 }
582+
583+ protected void dispatchShutdownToConnection (Runnable connectionShutdownRunnable ) {
584+ if (this .shutdownDispatched .compareAndSet (false , true )) {
585+ String name = "rabbitmq-connection-shutdown-" + this .id ;
586+ AMQConnection c = this .connection ;
587+ if (c == null || ch == null ) {
588+ // not enough information, we dispatch in separate thread
589+ Environment .newThread (connectionShutdownRunnable , name ).start ();
590+ } else {
591+ if (ch .eventLoop ().inEventLoop ()) {
592+ if (this .willRecover .test (c .getCloseReason ()) || ch .eventLoop ().isShuttingDown ()) {
593+ // the connection will recover, we don't want this to happen in the event loop,
594+ // it could cause a deadlock, so using a separate thread
595+ // name = name + "-" + c;
596+ Environment .newThread (connectionShutdownRunnable , name ).start ();
597+ } else {
598+ // no recovery, it is safe to dispatch in the event loop
599+ ch .eventLoop ().submit (connectionShutdownRunnable );
600+ }
601+ } else {
602+ // not in the event loop, we can run it in the same thread
603+ connectionShutdownRunnable .run ();
604+ }
605+ }
606+ }
607+ }
566608 }
567609
568610 private static final class ProtocolVersionMismatchHandler extends ChannelInboundHandlerAdapter {
0 commit comments