Skip to content

Commit 9e69f33

Browse files
authored
Add lame duck mode handling (#467)
* Add lame duck mode handling
1 parent 3fefa22 commit 9e69f33

File tree

10 files changed

+172
-27
lines changed

10 files changed

+172
-27
lines changed

pkg/cluster/nats_rpc_client.go

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,13 +245,36 @@ func (ns *NatsRPCClient) Call(
245245
return res, nil
246246
}
247247

248+
// replaceConnection replaces the NATS connection, draining the old one
249+
func (ns *NatsRPCClient) replaceConnection() error {
250+
return replaceNatsConnection(
251+
ns.conn,
252+
nil, // client doesn't have subscriptions
253+
func() error { return ns.initConnection(true) },
254+
"client",
255+
)
256+
}
257+
248258
// Init inits nats rpc client
249259
func (ns *NatsRPCClient) Init() error {
250-
ns.running = true
251-
logger.Log.Debugf("connecting to nats (client) with timeout of %s", ns.connectionTimeout)
260+
return ns.initConnection(false)
261+
}
262+
263+
// initConnection initializes or replaces the NATS connection
264+
func (ns *NatsRPCClient) initConnection(isReplacement bool) error {
265+
266+
if !isReplacement {
267+
ns.running = true
268+
logger.Log.Debugf("connecting to nats (client) with timeout of %s", ns.connectionTimeout)
269+
} else {
270+
logger.Log.Debugf("re-initializing nats client connection")
271+
}
272+
252273
conn, err := setupNatsConn(
253274
ns.connString,
254275
ns.appDieChan,
276+
ns.replaceConnection,
277+
nats.RetryOnFailedConnect(true),
255278
nats.MaxReconnects(ns.maxReconnectionRetries),
256279
nats.Timeout(ns.connectionTimeout),
257280
nats.Compression(ns.websocketCompression),
@@ -264,6 +287,10 @@ func (ns *NatsRPCClient) Init() error {
264287
return err
265288
}
266289
ns.conn = conn
290+
291+
if isReplacement {
292+
logger.Log.Infof("successfully replaced nats client connection")
293+
}
267294
return nil
268295
}
269296

pkg/cluster/nats_rpc_client_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ func TestNatsRPCClientCall(t *testing.T) {
440440
for _, table := range tables {
441441
t.Run(table.name, func(t *testing.T) {
442442
ctrl := gomock.NewController(t)
443-
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil)
443+
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil, nil)
444444
assert.NoError(t, err)
445445

446446
sv2 := getServer()

pkg/cluster/nats_rpc_common.go

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,43 @@ func drainAndClose(nc *nats.Conn) error {
7474
return nil
7575
}
7676

77-
func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.Option) (*nats.Conn, error) {
77+
// replaceNatsConnection handles the common logic for replacing NATS connections
78+
// It stores old connection/subscription references, calls initFunc to set up the new connection,
79+
// and then drains the old resources after the new connection is ready.
80+
func replaceNatsConnection(
81+
oldConn *nats.Conn,
82+
oldSub *nats.Subscription,
83+
initFunc func() error,
84+
componentName string,
85+
) error {
86+
logger.Log.Infof("replacing nats %s connection due to lame duck mode", componentName)
87+
88+
// Re-initialize connection (pass true to indicate this is a replacement)
89+
if err := initFunc(); err != nil {
90+
return err
91+
}
92+
93+
// Drain and close old connection and subscription after new one is set up
94+
if oldSub != nil {
95+
go func() {
96+
if err := oldSub.Drain(); err != nil {
97+
logger.Log.Warnf("error draining old %s subscription: %v", componentName, err)
98+
}
99+
}()
100+
}
101+
102+
if oldConn != nil {
103+
go func() {
104+
if err := drainAndClose(oldConn); err != nil {
105+
logger.Log.Warnf("error draining old nats %s connection: %v", componentName, err)
106+
}
107+
}()
108+
}
109+
110+
return nil
111+
}
112+
113+
func setupNatsConn(connectString string, appDieChan chan bool, lameDuckReplacement func() error, options ...nats.Option) (*nats.Conn, error) {
78114
connectedCh := make(chan bool)
79115
initialConnectErrorCh := make(chan error)
80116
natsOptions := append(
@@ -165,6 +201,18 @@ func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.O
165201
logger.Log.Infof("connected to nats on %s", nc.ConnectedAddr())
166202
connectedCh <- true
167203
}),
204+
nats.LameDuckModeHandler(func(nc *nats.Conn) {
205+
logger.Log.Warnf("nats connection entered lame duck mode")
206+
if lameDuckReplacement != nil {
207+
go func() {
208+
if err := lameDuckReplacement(); err != nil {
209+
logger.Log.Errorf("failed to replace connection: %v", err)
210+
// The old connection will eventually close (it's in lame duck mode),
211+
// which will trigger ClosedHandler and appDieChan
212+
}
213+
}()
214+
}
215+
}),
168216
)
169217

170218
nc, err := nats.Connect(connectString, natsOptions...)

pkg/cluster/nats_rpc_common_test.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,14 @@ func TestNatsRPCCommonSetupNatsConn(t *testing.T) {
4949
t.Parallel()
5050
s := helpers.GetTestNatsServer(t)
5151
defer s.Shutdown()
52-
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil)
52+
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil, nil)
5353
assert.NoError(t, err)
5454
assert.NotNil(t, conn)
5555
}
5656

5757
func TestNatsRPCCommonSetupNatsConnShouldError(t *testing.T) {
5858
t.Parallel()
59-
conn, err := setupNatsConn("nats://localhost:1234", nil)
59+
conn, err := setupNatsConn("nats://localhost:1234", nil, nil)
6060
assert.Error(t, err)
6161
assert.Nil(t, conn)
6262
}
@@ -67,7 +67,7 @@ func TestNatsRPCCommonCloseHandler(t *testing.T) {
6767

6868
dieChan := make(chan bool)
6969

70-
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), dieChan, nats.MaxReconnects(1),
70+
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), dieChan, nil, nats.MaxReconnects(1),
7171
nats.ReconnectWait(1*time.Millisecond))
7272
assert.NoError(t, err)
7373
assert.NotNil(t, conn)
@@ -99,6 +99,7 @@ func TestSetupNatsConnReconnection(t *testing.T) {
9999
conn, err := setupNatsConn(
100100
urls,
101101
appDieCh,
102+
nil,
102103
nats.ReconnectWait(10*time.Millisecond),
103104
nats.MaxReconnects(5),
104105
nats.RetryOnFailedConnect(true),
@@ -121,6 +122,7 @@ func TestSetupNatsConnReconnection(t *testing.T) {
121122
conn, err := setupNatsConn(
122123
invalidAddr,
123124
appDieCh,
125+
nil,
124126
nats.ReconnectWait(10*time.Millisecond),
125127
nats.MaxReconnects(2),
126128
nats.RetryOnFailedConnect(true),
@@ -146,7 +148,7 @@ func TestSetupNatsConnReconnection(t *testing.T) {
146148
done := make(chan any)
147149

148150
go func() {
149-
conn, err := setupNatsConn(invalidAddr, appDieCh)
151+
conn, err := setupNatsConn(invalidAddr, appDieCh, nil)
150152
assert.Error(t, err)
151153
assert.Nil(t, conn)
152154
close(done)
@@ -182,6 +184,7 @@ func TestSetupNatsConnReconnection(t *testing.T) {
182184
conn, err := setupNatsConn(
183185
invalidAddr,
184186
appDieCh,
187+
nil,
185188
nats.Timeout(initialConnectionTimeout),
186189
nats.ReconnectWait(reconnectWait),
187190
nats.MaxReconnects(maxReconnectionAtetmpts),

pkg/cluster/nats_rpc_server.go

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -333,15 +333,37 @@ func (ns *NatsRPCServer) processKick() {
333333
}
334334
}
335335

336+
// replaceConnection replaces the NATS connection, draining the old one and re-subscribing
337+
func (ns *NatsRPCServer) replaceConnection() error {
338+
return replaceNatsConnection(
339+
ns.conn,
340+
ns.sub,
341+
func() error { return ns.initConnection(true) },
342+
"server",
343+
)
344+
}
345+
336346
// Init inits nats rpc server
337347
func (ns *NatsRPCServer) Init() error {
338-
// TODO should we have concurrency here? it feels like we should
339-
go ns.handleMessages()
348+
return ns.initConnection(false)
349+
}
350+
351+
// initConnection initializes or replaces the NATS connection
352+
func (ns *NatsRPCServer) initConnection(isReplacement bool) error {
353+
354+
if !isReplacement {
355+
// TODO should we have concurrency here? it feels like we should
356+
go ns.handleMessages()
357+
logger.Log.Debugf("connecting to nats (server) with timeout of %s", ns.connectionTimeout)
358+
} else {
359+
logger.Log.Debugf("re-initializing nats server connection")
360+
}
340361

341-
logger.Log.Debugf("connecting to nats (server) with timeout of %s", ns.connectionTimeout)
342362
conn, err := setupNatsConn(
343363
ns.connString,
344364
ns.appDieChan,
365+
ns.replaceConnection,
366+
nats.RetryOnFailedConnect(true),
345367
nats.MaxReconnects(ns.maxReconnectionRetries),
346368
nats.Timeout(ns.connectionTimeout),
347369
nats.Compression(ns.websocketCompression),
@@ -362,17 +384,37 @@ func (ns *NatsRPCServer) Init() error {
362384
if err != nil {
363385
return err
364386
}
365-
// this handles remote messages
366-
for i := 0; i < ns.service; i++ {
367-
go ns.processMessages(i)
387+
388+
// Re-subscribe to all session subscriptions if this is a replacement
389+
// The onSessionBind callback is already set up, we just need to trigger it for existing sessions
390+
if isReplacement && ns.server.Frontend && ns.sessionPool != nil {
391+
ns.sessionPool.ForEachSession(func(s session.Session) {
392+
if s.GetIsFrontend() && s.UID() != "" {
393+
// Re-use the same subscription logic as onSessionBind
394+
if err := ns.onSessionBind(context.Background(), s); err != nil {
395+
logger.Log.Errorf("failed to re-subscribe session for user %s: %v", s.UID(), err)
396+
}
397+
}
398+
})
368399
}
369400

370-
ns.sessionPool.OnSessionBind(ns.onSessionBind)
401+
if !isReplacement {
402+
// this handles remote messages
403+
for i := 0; i < ns.service; i++ {
404+
go ns.processMessages(i)
405+
}
406+
407+
ns.sessionPool.OnSessionBind(ns.onSessionBind)
371408

372-
// this should be so fast that we shoudn't need concurrency
373-
go ns.processPushes()
374-
go ns.processSessionBindings()
375-
go ns.processKick()
409+
// this should be so fast that we shoudn't need concurrency
410+
go ns.processPushes()
411+
go ns.processSessionBindings()
412+
go ns.processKick()
413+
}
414+
415+
if isReplacement {
416+
logger.Log.Infof("successfully replaced nats server connection")
417+
}
376418

377419
return nil
378420
}

pkg/cluster/nats_rpc_server_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ func TestNatsRPCServerOnSessionBind(t *testing.T) {
157157
rpcServer, _ := NewNatsRPCServer(cfg, sv, nil, nil, nil)
158158
s := helpers.GetTestNatsServer(t)
159159
defer s.Shutdown()
160-
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil)
160+
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil, nil)
161161
assert.NoError(t, err)
162162
rpcServer.conn = conn
163163
err = rpcServer.onSessionBind(context.Background(), mockSession)
@@ -172,7 +172,7 @@ func TestNatsRPCServerSubscribeToBindingsChannel(t *testing.T) {
172172
rpcServer, _ := NewNatsRPCServer(cfg, sv, nil, nil, nil)
173173
s := helpers.GetTestNatsServer(t)
174174
defer s.Shutdown()
175-
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil)
175+
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil, nil)
176176
assert.NoError(t, err)
177177
rpcServer.conn = conn
178178
err = rpcServer.subscribeToBindingsChannel()
@@ -190,7 +190,7 @@ func TestNatsRPCServerSubscribeUserKickChannel(t *testing.T) {
190190
rpcServer, _ := NewNatsRPCServer(cfg, sv, nil, nil, nil)
191191
s := helpers.GetTestNatsServer(t)
192192
defer s.Shutdown()
193-
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil)
193+
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil, nil)
194194
assert.NoError(t, err)
195195
rpcServer.conn = conn
196196
sub, err := rpcServer.subscribeToUserKickChannel("someuid", sv.Type)
@@ -229,7 +229,7 @@ func TestNatsRPCServerSubscribeToUserMessages(t *testing.T) {
229229
rpcServer, _ := NewNatsRPCServer(cfg, sv, nil, nil, nil)
230230
s := helpers.GetTestNatsServer(t)
231231
defer s.Shutdown()
232-
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil)
232+
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil, nil)
233233
assert.NoError(t, err)
234234
rpcServer.conn = conn
235235
tables := []struct {
@@ -259,7 +259,7 @@ func TestNatsRPCServerSubscribe(t *testing.T) {
259259
rpcServer, _ := NewNatsRPCServer(cfg, sv, nil, nil, nil)
260260
s := helpers.GetTestNatsServer(t)
261261
defer s.Shutdown()
262-
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil)
262+
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil, nil)
263263
assert.NoError(t, err)
264264
rpcServer.conn = conn
265265
tables := []struct {
@@ -294,7 +294,7 @@ func TestNatsRPCServerHandleMessages(t *testing.T) {
294294
rpcServer, _ := NewNatsRPCServer(cfg, sv, mockMetricsReporters, nil, nil)
295295
s := helpers.GetTestNatsServer(t)
296296
defer s.Shutdown()
297-
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil)
297+
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil, nil)
298298
assert.NoError(t, err)
299299
rpcServer.conn = conn
300300
tables := []struct {

pkg/groups/memory_group_service.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ var (
1515
memoryOnce sync.Once
1616
globalCtx context.Context
1717
globalCancel context.CancelFunc
18+
cleanupWG sync.WaitGroup
1819
)
1920

2021
// MemoryGroupService base in server memory solution
@@ -35,6 +36,7 @@ func NewMemoryGroupService(config config.MemoryGroupConfig) *MemoryGroupService
3536
memoryOnce.Do(func() {
3637
memoryGroups = make(map[string]*MemoryGroup)
3738
globalCtx, globalCancel = context.WithCancel(context.Background())
39+
cleanupWG.Add(1)
3840
go groupTTLCleanup(globalCtx, config.TickDuration)
3941
})
4042
// All services share the same cancel function
@@ -43,6 +45,7 @@ func NewMemoryGroupService(config config.MemoryGroupConfig) *MemoryGroupService
4345
}
4446

4547
func groupTTLCleanup(ctx context.Context, interval time.Duration) {
48+
defer cleanupWG.Done()
4649
ticker := time.NewTicker(interval)
4750
defer ticker.Stop()
4851

@@ -231,5 +234,7 @@ func (c *MemoryGroupService) Close() {
231234
// The goroutine will exit when the context is cancelled
232235
if globalCancel != nil {
233236
globalCancel()
237+
// Wait for the goroutine to exit
238+
cleanupWG.Wait()
234239
}
235240
}

pkg/groups/memory_group_service_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ func TestMain(m *testing.M) {
3636
memoryGroupService = NewMemoryGroupService(mconfig)
3737
exit := m.Run()
3838
memoryGroupService.Close()
39-
// Wait for the goroutine to exit (give it time to process the context cancellation)
40-
time.Sleep(2 * mconfig.TickDuration)
4139
os.Exit(exit)
4240
}
4341

pkg/session/mocks/session.go

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)