1414package io .streamnative .pulsar .handlers .mqtt .common ;
1515
1616import static io .streamnative .pulsar .handlers .mqtt .common .systemtopic .EventType .CONNECT ;
17+ import static io .streamnative .pulsar .handlers .mqtt .common .systemtopic .EventType .DISCONNECT ;
1718import io .netty .util .HashedWheelTimer ;
1819import io .netty .util .Timeout ;
1920import io .netty .util .concurrent .DefaultThreadFactory ;
2021import io .streamnative .pulsar .handlers .mqtt .common .systemtopic .ConnectEvent ;
2122import io .streamnative .pulsar .handlers .mqtt .common .systemtopic .EventListener ;
2223import io .streamnative .pulsar .handlers .mqtt .common .systemtopic .MqttEvent ;
24+ import java .util .ArrayList ;
25+ import java .util .Collection ;
2326import java .util .concurrent .ConcurrentHashMap ;
2427import java .util .concurrent .ConcurrentMap ;
2528import java .util .concurrent .TimeUnit ;
3336@ Slf4j
3437public class MQTTConnectionManager {
3538
36- private final ConcurrentMap <String , Connection > connections ;
39+ private final ConcurrentMap <String , Connection > localConnections ;
40+
41+ private final ConcurrentMap <String , Connection > eventConnections ;
3742
3843 @ Getter
3944 private static final HashedWheelTimer sessionExpireInterval =
4045 new HashedWheelTimer (
4146 new DefaultThreadFactory ("session-expire-interval" ), 1 , TimeUnit .SECONDS );
4247
4348 @ Getter
44- private final EventListener eventListener ;
49+ private final EventListener connectListener ;
50+
51+ @ Getter
52+ private final EventListener disconnectListener ;
4553
4654 private final String advertisedAddress ;
4755
4856 public MQTTConnectionManager (String advertisedAddress ) {
4957 this .advertisedAddress = advertisedAddress ;
50- this .connections = new ConcurrentHashMap <>(2048 );
51- this .eventListener = new ConnectEventListener ();
58+ this .localConnections = new ConcurrentHashMap <>(2048 );
59+ this .eventConnections = new ConcurrentHashMap <>(2048 );
60+ this .connectListener = new ConnectEventListener ();
61+ this .disconnectListener = new DisconnectEventListener ();
5262 }
5363
5464 public void addConnection (Connection connection ) {
55- Connection existing = connections .put (connection .getClientId (), connection );
65+ Connection existing = localConnections .put (connection .getClientId (), connection );
5666 if (existing != null ) {
5767 log .warn ("The clientId is existed. Close existing connection. CId={}" , existing .getClientId ());
5868 existing .disconnect ();
@@ -68,7 +78,7 @@ public void addConnection(Connection connection) {
6878 */
6979 public void newSessionExpireInterval (Consumer <Timeout > task , String clientId , int interval ) {
7080 sessionExpireInterval .newTimeout (timeout -> {
71- Connection connection = connections .get (clientId );
81+ Connection connection = localConnections .get (clientId );
7282 if (connection != null
7383 && connection .getState () != Connection .ConnectionState .DISCONNECTED ) {
7484 return ;
@@ -80,16 +90,28 @@ public void newSessionExpireInterval(Consumer<Timeout> task, String clientId, in
8090 // Must use connections.remove(key, value).
8191 public void removeConnection (Connection connection ) {
8292 if (connection != null ) {
83- connections .remove (connection .getClientId (), connection );
93+ localConnections .remove (connection .getClientId (), connection );
8494 }
8595 }
8696
8797 public Connection getConnection (String clientId ) {
88- return connections .get (clientId );
98+ return localConnections .get (clientId );
99+ }
100+
101+ public Collection <Connection > getLocalConnections () {
102+ return this .localConnections .values ();
103+ }
104+
105+ public Collection <Connection > getAllConnections () {
106+ Collection <Connection > connections = new ArrayList <>(this .localConnections .values ().size ()
107+ + this .eventConnections .values ().size ());
108+ connections .addAll (this .localConnections .values ());
109+ connections .addAll (eventConnections .values ());
110+ return connections ;
89111 }
90112
91113 public void close () {
92- connections .values ().forEach (connection -> connection .getChannel ().close ());
114+ localConnections .values ().forEach (connection -> connection .getChannel ().close ());
93115 }
94116
95117 class ConnectEventListener implements EventListener {
@@ -103,9 +125,25 @@ public void onChange(MqttEvent event) {
103125 if (connection != null ) {
104126 log .warn ("[ConnectEvent] close existing connection : {}" , connection );
105127 connection .disconnect ();
128+ } else {
129+ eventConnections .put (connectEvent .getClientId (), connection );
106130 }
107131 }
108132 }
109133 }
110134 }
135+
136+ //TODO
137+ class DisconnectEventListener implements EventListener {
138+
139+ @ Override
140+ public void onChange (MqttEvent event ) {
141+ if (event .getEventType () == DISCONNECT ) {
142+ ConnectEvent connectEvent = (ConnectEvent ) event .getSourceEvent ();
143+ if (!connectEvent .getAddress ().equals (advertisedAddress )) {
144+ eventConnections .remove (connectEvent .getClientId ());
145+ }
146+ }
147+ }
148+ }
111149}
0 commit comments