diff --git a/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/transport/triple/ReferenceCountManagedChannel.java b/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/transport/triple/ReferenceCountManagedChannel.java index 73ec416bf..1ab859022 100644 --- a/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/transport/triple/ReferenceCountManagedChannel.java +++ b/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/transport/triple/ReferenceCountManagedChannel.java @@ -16,6 +16,8 @@ */ package com.alipay.sofa.rpc.transport.triple; +import com.alipay.sofa.rpc.log.Logger; +import com.alipay.sofa.rpc.log.LoggerFactory; import io.grpc.CallOptions; import io.grpc.ClientCall; import io.grpc.ConnectivityState; @@ -30,9 +32,11 @@ */ public class ReferenceCountManagedChannel extends ManagedChannel { - private final AtomicInteger referenceCount = new AtomicInteger(0); + private final static Logger LOGGER = LoggerFactory.getLogger(ReferenceCountManagedChannel.class); - private ManagedChannel grpcChannel; + private final AtomicInteger referenceCount = new AtomicInteger(0); + + private final ManagedChannel grpcChannel; public ReferenceCountManagedChannel(ManagedChannel delegated) { this.grpcChannel = delegated; @@ -47,8 +51,18 @@ public void incrementAndGetCount() { @Override public ManagedChannel shutdown() { - if (referenceCount.decrementAndGet() <= 0) { - return grpcChannel.shutdown(); + int remainReferenceCount = referenceCount.decrementAndGet(); + try { + if (remainReferenceCount <= 0) { + ManagedChannel shutdown = grpcChannel.shutdown(); + shutdown.awaitTermination(5, TimeUnit.SECONDS); + return shutdown; + } + } catch (InterruptedException e) { + LOGGER.warn("Triple channel shut down interrupted."); + } finally { + LOGGER.info("ReferenceCountManagedChannel {} shutdown remain referenceCount: {}", this, + remainReferenceCount); } return grpcChannel; } diff --git a/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/transport/triple/TripleClientTransport.java b/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/transport/triple/TripleClientTransport.java index c74fc9d76..02e7f9555 100644 --- a/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/transport/triple/TripleClientTransport.java +++ b/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/transport/triple/TripleClientTransport.java @@ -35,8 +35,6 @@ import com.alipay.sofa.rpc.event.EventBus; import com.alipay.sofa.rpc.ext.Extension; import com.alipay.sofa.rpc.interceptor.ClientHeaderClientInterceptor; -import com.alipay.sofa.rpc.log.Logger; -import com.alipay.sofa.rpc.log.LoggerFactory; import com.alipay.sofa.rpc.message.ResponseFuture; import com.alipay.sofa.rpc.server.triple.TripleContants; import com.alipay.sofa.rpc.transport.AbstractChannel; @@ -63,8 +61,6 @@ @Extension("tri") public class TripleClientTransport extends ClientTransport { - private final static Logger LOGGER = LoggerFactory.getLogger(TripleClientTransport.class); - protected ProviderInfo providerInfo; protected ManagedChannel channel; @@ -78,7 +74,7 @@ public class TripleClientTransport extends ClientTransport { /* */ protected final static ConcurrentMap channelMap = new ConcurrentHashMap<>(); - protected final Object lock = new Object(); + protected final static Object lock = new Object(); protected static int KEEP_ALIVE_INTERVAL = SofaConfigs.getOrCustomDefault( RpcConfigKeys.TRIPLE_CLIENT_KEEP_ALIVE_INTERVAL, @@ -114,17 +110,11 @@ protected TripleClientInvoker buildClientInvoker() { @Override public void disconnect() { if (channel != null) { - try { - channel.shutdown().awaitTermination(5, TimeUnit.SECONDS); - } catch (InterruptedException e) { - LOGGER.warn("Triple channel shut down interrupted."); - } + channel.shutdown(); if (channel.isShutdown()) { - channel = null; - channelMap.remove(providerInfo.toString()); + channelMap.remove(providerInfo.toString(), (ReferenceCountManagedChannel) channel); } - } else { - channelMap.remove(providerInfo.toString()); + channel = null; } } @@ -269,6 +259,7 @@ private ReferenceCountManagedChannel getSharedChannel(ProviderInfo url) { channel.incrementAndGetCount(); } else { channel = new ReferenceCountManagedChannel(initChannel(url)); + channel.incrementAndGetCount(); channelMap.put(key, channel); } } diff --git a/remoting/remoting-triple/src/test/java/com/alipay/sofa/rpc/transport/triple/TripleClientTransportTest.java b/remoting/remoting-triple/src/test/java/com/alipay/sofa/rpc/transport/triple/TripleClientTransportTest.java index 23ca132d9..e3cbdf0d8 100644 --- a/remoting/remoting-triple/src/test/java/com/alipay/sofa/rpc/transport/triple/TripleClientTransportTest.java +++ b/remoting/remoting-triple/src/test/java/com/alipay/sofa/rpc/transport/triple/TripleClientTransportTest.java @@ -16,6 +16,11 @@ */ package com.alipay.sofa.rpc.transport.triple; +import com.alipay.sofa.rpc.client.ProviderInfo; +import com.alipay.sofa.rpc.config.ConsumerConfig; +import com.alipay.sofa.rpc.server.triple.HelloService; +import com.alipay.sofa.rpc.transport.ClientTransportConfig; +import com.alipay.sofa.rpc.transport.ClientTransportFactory; import org.junit.Assert; import org.junit.Test; @@ -32,4 +37,63 @@ public void testInit() { Assert.assertEquals(TripleClientTransport.KEEP_ALIVE_INTERVAL, 0); } + + @Test + public void test() { + //模拟两个 reference 去消费同一份推送数据 + ProviderInfo providerInfo = new ProviderInfo(); + providerInfo.setHost("127.0.0.1"); + providerInfo.setPort(55555); + + ConsumerConfig consumerConfig1 = new ConsumerConfig<>(); + consumerConfig1.setProtocol("tri"); + consumerConfig1.setInterfaceId(HelloService.class.getName()); + ClientTransportConfig config1 = providerToClientConfig(consumerConfig1, providerInfo); + TripleClientTransport clientTransport1 = (TripleClientTransport) ClientTransportFactory.getClientTransport(config1); + + ConsumerConfig consumerConfig2 = new ConsumerConfig<>(); + consumerConfig2.setProtocol("tri"); + consumerConfig2.setInterfaceId(HelloService.class.getName()); + ClientTransportConfig config2 = providerToClientConfig(consumerConfig2, providerInfo); + TripleClientTransport clientTransport2 = (TripleClientTransport) ClientTransportFactory.getClientTransport(config2); + + Assert.assertNotNull(TripleClientTransport.channelMap.get("127.0.0.1:55555")); + Assert.assertTrue(clientTransport1.isAvailable()); + Assert.assertTrue(clientTransport2.isAvailable()); + Assert.assertEquals(clientTransport1.channel, clientTransport2.channel); + + clientTransport1.destroy(); + Assert.assertNull(clientTransport1.channel); + Assert.assertFalse(clientTransport1.isAvailable()); + Assert.assertTrue(clientTransport2.isAvailable()); + Assert.assertNotNull(TripleClientTransport.channelMap.get("127.0.0.1:55555")); + + clientTransport1.connect(); + Assert.assertTrue(clientTransport1.isAvailable()); + Assert.assertEquals(clientTransport1.channel, clientTransport2.channel); + + clientTransport2.destroy(); + Assert.assertNull(clientTransport2.channel); + Assert.assertTrue(clientTransport1.isAvailable()); + Assert.assertFalse(clientTransport2.isAvailable()); + Assert.assertNotNull(TripleClientTransport.channelMap.get("127.0.0.1:55555")); + + clientTransport1.destroy(); + Assert.assertNull(clientTransport1.channel); + Assert.assertFalse(clientTransport1.isAvailable()); + Assert.assertFalse(clientTransport2.isAvailable()); + Assert.assertNull(TripleClientTransport.channelMap.get("127.0.0.1:55555")); + } + + private ClientTransportConfig providerToClientConfig(ConsumerConfig consumerConfig, ProviderInfo providerInfo) { + return new ClientTransportConfig() + .setConsumerConfig(consumerConfig) + .setProviderInfo(providerInfo) + .setContainer(consumerConfig.getProtocol()) + .setConnectTimeout(consumerConfig.getConnectTimeout()) + .setInvokeTimeout(consumerConfig.getTimeout()) + .setDisconnectTimeout(consumerConfig.getDisconnectTimeout()) + .setConnectionNum(consumerConfig.getConnectionNum()) + .setChannelListeners(consumerConfig.getOnConnect()); + } } \ No newline at end of file