Skip to content

Commit 8ff7ad7

Browse files
artembilangaryrussell
authored andcommitted
GH-3627: Fix race condition NPE in MqttPahoMDCA
Fixes #3627 The `destroy()`, and therefore `stop()` could be called from the `MqttConnectionFailedEvent` handling in the same thread resetting the `client` property to `null`. * Check for `this.client != null` in the next block of the `connectAndSubscribe()` to avoid NPE * Check for `isActive()` in the `scheduleReconnect()` to be sure do not reconnect if channel adapter has been stopped already **Cherry-pick to `5.4.x`**
1 parent 9f42650 commit 8ff7ad7

File tree

2 files changed

+80
-33
lines changed

2 files changed

+80
-33
lines changed

spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/inbound/MqttPahoMessageDrivenChannelAdapter.java

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ public class MqttPahoMessageDrivenChannelAdapter extends AbstractMqttMessageDriv
8484

8585
private boolean manualAcks;
8686

87+
private ApplicationEventPublisher applicationEventPublisher;
88+
8789
private volatile IMqttClient client;
8890

8991
private volatile ScheduledFuture<?> reconnectFuture;
@@ -94,8 +96,6 @@ public class MqttPahoMessageDrivenChannelAdapter extends AbstractMqttMessageDriv
9496

9597
private volatile ConsumerStopAction consumerStopAction;
9698

97-
private ApplicationEventPublisher applicationEventPublisher;
98-
9999
/**
100100
* Use this constructor for a single url (although it may be overridden if the server
101101
* URI(s) are provided by the {@link MqttConnectOptions#getServerURIs()} provided by
@@ -311,15 +311,17 @@ private synchronized void connectAndSubscribe() throws MqttException {
311311
this.applicationEventPublisher.publishEvent(new MqttConnectionFailedEvent(this, ex));
312312
}
313313
logger.error(ex, () -> "Error connecting or subscribing to " + Arrays.toString(topics));
314-
this.client.disconnectForcibly(this.disconnectCompletionTimeout);
315-
try {
316-
this.client.setCallback(null);
317-
this.client.close();
318-
}
319-
catch (MqttException e1) {
320-
// NOSONAR
314+
if (this.client != null) { // Could be reset during event handling before
315+
this.client.disconnectForcibly(this.disconnectCompletionTimeout);
316+
try {
317+
this.client.setCallback(null);
318+
this.client.close();
319+
}
320+
catch (MqttException e1) {
321+
// NOSONAR
322+
}
323+
this.client = null;
321324
}
322-
this.client = null;
323325
throw ex;
324326
}
325327
finally {
@@ -355,25 +357,27 @@ private synchronized void cancelReconnect() {
355357

356358
private synchronized void scheduleReconnect() {
357359
cancelReconnect();
358-
try {
359-
this.reconnectFuture = getTaskScheduler().schedule(() -> {
360-
try {
361-
logger.debug("Attempting reconnect");
362-
synchronized (MqttPahoMessageDrivenChannelAdapter.this) {
363-
if (!MqttPahoMessageDrivenChannelAdapter.this.connected) {
364-
connectAndSubscribe();
365-
MqttPahoMessageDrivenChannelAdapter.this.reconnectFuture = null;
360+
if (isActive()) {
361+
try {
362+
this.reconnectFuture = getTaskScheduler().schedule(() -> {
363+
try {
364+
logger.debug("Attempting reconnect");
365+
synchronized (MqttPahoMessageDrivenChannelAdapter.this) {
366+
if (!MqttPahoMessageDrivenChannelAdapter.this.connected) {
367+
connectAndSubscribe();
368+
MqttPahoMessageDrivenChannelAdapter.this.reconnectFuture = null;
369+
}
366370
}
367371
}
368-
}
369-
catch (MqttException ex) {
370-
logger.error(ex, "Exception while connecting and subscribing");
371-
scheduleReconnect();
372-
}
373-
}, new Date(System.currentTimeMillis() + this.recoveryInterval));
374-
}
375-
catch (Exception ex) {
376-
logger.error(ex, "Failed to schedule reconnect");
372+
catch (MqttException ex) {
373+
logger.error(ex, "Exception while connecting and subscribing");
374+
scheduleReconnect();
375+
}
376+
}, new Date(System.currentTimeMillis() + this.recoveryInterval));
377+
}
378+
catch (Exception ex) {
379+
logger.error(ex, "Failed to schedule reconnect");
380+
}
377381
}
378382
}
379383

@@ -412,7 +416,7 @@ public void messageArrived(String topic, MqttMessage mqttMessage) {
412416
sendMessage(message);
413417
}
414418
catch (RuntimeException ex) {
415-
logger.error(ex, () -> "Unhandled exception for " + message.toString());
419+
logger.error(ex, () -> "Unhandled exception for " + message);
416420
throw ex;
417421
}
418422
}

spring-integration-mqtt/src/test/java/org/springframework/integration/mqtt/MqttAdapterTests.java

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import static org.mockito.ArgumentMatchers.any;
2323
import static org.mockito.ArgumentMatchers.anyLong;
2424
import static org.mockito.ArgumentMatchers.anyString;
25+
import static org.mockito.ArgumentMatchers.eq;
2526
import static org.mockito.ArgumentMatchers.isNull;
2627
import static org.mockito.BDDMockito.given;
2728
import static org.mockito.BDDMockito.willAnswer;
@@ -53,7 +54,6 @@
5354
import org.assertj.core.api.Condition;
5455
import org.eclipse.paho.client.mqttv3.IMqttAsyncClient;
5556
import org.eclipse.paho.client.mqttv3.IMqttClient;
56-
import org.eclipse.paho.client.mqttv3.IMqttMessageListener;
5757
import org.eclipse.paho.client.mqttv3.IMqttToken;
5858
import org.eclipse.paho.client.mqttv3.MqttAsyncClient;
5959
import org.eclipse.paho.client.mqttv3.MqttCallback;
@@ -65,6 +65,7 @@
6565
import org.eclipse.paho.client.mqttv3.MqttToken;
6666
import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence;
6767
import org.junit.jupiter.api.Test;
68+
import org.mockito.ArgumentCaptor;
6869
import org.mockito.ArgumentMatchers;
6970
import org.mockito.internal.stubbing.answers.CallsRealMethods;
7071

@@ -540,8 +541,7 @@ public void testDifferentQos() throws Exception {
540541
new DirectFieldAccessor(client).setPropertyValue("aClient", aClient);
541542
willAnswer(new CallsRealMethods()).given(client).connect(any(MqttConnectOptions.class));
542543
willAnswer(new CallsRealMethods()).given(client).subscribe(any(String[].class), any(int[].class));
543-
willAnswer(new CallsRealMethods()).given(client).subscribe(any(String[].class), any(int[].class),
544-
(IMqttMessageListener[]) isNull());
544+
willAnswer(new CallsRealMethods()).given(client).subscribe(any(String[].class), any(int[].class), isNull());
545545
willReturn(alwaysComplete).given(aClient).connect(any(MqttConnectOptions.class), any(), any());
546546

547547
IMqttToken token = mock(IMqttToken.class);
@@ -572,8 +572,51 @@ public void testDifferentQos() throws Exception {
572572
verify(client).disconnectForcibly(5_000L);
573573
}
574574

575+
@Test
576+
public void testNoNPEOnReconnectAndStopRaceCondition() throws Exception {
577+
final IMqttClient client = mock(IMqttClient.class);
578+
MqttPahoMessageDrivenChannelAdapter adapter = buildAdapterIn(client, null, ConsumerStopAction.UNSUBSCRIBE_NEVER);
579+
adapter.setRecoveryInterval(10);
580+
581+
MqttException mqttException = new MqttException(MqttException.REASON_CODE_SUBSCRIBE_FAILED);
582+
583+
willThrow(mqttException)
584+
.given(client)
585+
.subscribe(any(), ArgumentMatchers.<int[]>any());
586+
587+
LogAccessor logger = spy(TestUtils.getPropertyValue(adapter, "logger", LogAccessor.class));
588+
new DirectFieldAccessor(adapter).setPropertyValue("logger", logger);
589+
CountDownLatch exceptionLatch = new CountDownLatch(1);
590+
ArgumentCaptor<MqttException> mqttExceptionArgumentCaptor = ArgumentCaptor.forClass(MqttException.class);
591+
willAnswer(i -> {
592+
exceptionLatch.countDown();
593+
return null;
594+
})
595+
.given(logger)
596+
.error(mqttExceptionArgumentCaptor.capture(), eq("Exception while connecting and subscribing"));
597+
598+
ThreadPoolTaskScheduler taskScheduler = new ThreadPoolTaskScheduler();
599+
taskScheduler.initialize();
600+
adapter.setTaskScheduler(taskScheduler);
601+
602+
adapter.setApplicationEventPublisher(event -> {
603+
if (event instanceof MqttConnectionFailedEvent) {
604+
adapter.destroy();
605+
}
606+
});
607+
adapter.start();
608+
609+
assertThat(exceptionLatch.await(10, TimeUnit.SECONDS)).isTrue();
610+
assertThat(mqttExceptionArgumentCaptor.getValue())
611+
.isNotNull()
612+
.isSameAs(mqttException);
613+
614+
taskScheduler.destroy();
615+
}
616+
575617
private MqttPahoMessageDrivenChannelAdapter buildAdapterIn(final IMqttClient client, Boolean cleanSession,
576-
ConsumerStopAction action) throws MqttException {
618+
ConsumerStopAction action) {
619+
577620
DefaultMqttPahoClientFactory factory = new DefaultMqttPahoClientFactory() {
578621

579622
@Override
@@ -604,7 +647,7 @@ private MqttPahoMessageHandler buildAdapterOut(final IMqttAsyncClient client) {
604647
DefaultMqttPahoClientFactory factory = new DefaultMqttPahoClientFactory() {
605648

606649
@Override
607-
public IMqttAsyncClient getAsyncClientInstance(String uri, String clientId) throws MqttException {
650+
public IMqttAsyncClient getAsyncClientInstance(String uri, String clientId) {
608651
return client;
609652
}
610653

0 commit comments

Comments
 (0)