Skip to content

Commit ac4b8a4

Browse files
authored
Add lame duck mode handling (#466)
* Add lame duck mode handling * fix: add missing config for #439 to work * fix: bad formatting for nil errors
1 parent 4f5c74b commit ac4b8a4

File tree

8 files changed

+202
-31
lines changed

8 files changed

+202
-31
lines changed

cluster/nats_rpc_client.go

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,14 +243,36 @@ func (ns *NatsRPCClient) Call(
243243
return res, nil
244244
}
245245

246+
// replaceConnection replaces the NATS connection, draining the old one
247+
func (ns *NatsRPCClient) replaceConnection() error {
248+
return replaceNatsConnection(
249+
ns.conn,
250+
nil, // client doesn't have subscriptions
251+
func() error { return ns.initConnection(true) },
252+
"client",
253+
)
254+
}
255+
246256
// Init inits nats rpc client
247257
func (ns *NatsRPCClient) Init() error {
248-
ns.running = true
249-
logger.Log.Debugf("connecting to nats (client) with timeout of %s", ns.connectionTimeout)
258+
return ns.initConnection(false)
259+
}
260+
261+
// initConnection initializes or replaces the NATS connection
262+
func (ns *NatsRPCClient) initConnection(isReplacement bool) error {
263+
264+
if !isReplacement {
265+
ns.running = true
266+
logger.Log.Debugf("connecting to nats (client) with timeout of %s", ns.connectionTimeout)
267+
} else {
268+
logger.Log.Debugf("re-initializing nats client connection")
269+
}
270+
250271
conn, err := setupNatsConn(
251272
ns.connString,
252273
ns.appDieChan,
253-
nats.RetryOnFailedConnect(false),
274+
ns.replaceConnection,
275+
nats.RetryOnFailedConnect(true),
254276
nats.MaxReconnects(ns.maxReconnectionRetries),
255277
nats.Timeout(ns.connectionTimeout),
256278
nats.Compression(ns.websocketCompression),
@@ -263,6 +285,10 @@ func (ns *NatsRPCClient) Init() error {
263285
return err
264286
}
265287
ns.conn = conn
288+
289+
if isReplacement {
290+
logger.Log.Infof("successfully replaced nats client connection")
291+
}
266292
return nil
267293
}
268294

cluster/nats_rpc_client_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ func TestNatsRPCClientCall(t *testing.T) {
440440
for _, table := range tables {
441441
t.Run(table.name, func(t *testing.T) {
442442
ctrl := gomock.NewController(t)
443-
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil)
443+
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil, nil)
444444
defer conn.Close()
445445
assert.NoError(t, err)
446446

cluster/nats_rpc_common.go

Lines changed: 85 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,18 @@ func drainAndClose(nc *nats.Conn) error {
3737
if nc == nil {
3838
return nil
3939
}
40+
// If connection is already closed, just return
41+
if nc.IsClosed() {
42+
return nil
43+
}
4044
// Drain connection (this will flush any pending messages and prevent new ones)
4145
err := nc.Drain()
4246
if err != nil {
4347
logger.Log.Warnf("error draining nats connection: %v", err)
44-
// Even if drain fails, try to close
45-
nc.Close()
48+
// Even if drain fails, try to close (but only if not already closed)
49+
if !nc.IsClosed() {
50+
nc.Close()
51+
}
4652
return err
4753
}
4854

@@ -57,7 +63,9 @@ func drainAndClose(nc *nats.Conn) error {
5763
continue
5864
case <-timeout:
5965
logger.Log.Warn("drain timeout exceeded, forcing close")
60-
nc.Close()
66+
if !nc.IsClosed() {
67+
nc.Close()
68+
}
6169
return fmt.Errorf("drain timeout exceeded")
6270
}
6371
}
@@ -66,13 +74,49 @@ func drainAndClose(nc *nats.Conn) error {
6674
return nil
6775
}
6876

69-
func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.Option) (*nats.Conn, error) {
77+
// replaceNatsConnection handles the common logic for replacing NATS connections
78+
// It stores old connection/subscription references, calls initFunc to set up the new connection,
79+
// and then drains the old resources after the new connection is ready.
80+
func replaceNatsConnection(
81+
oldConn *nats.Conn,
82+
oldSub *nats.Subscription,
83+
initFunc func() error,
84+
componentName string,
85+
) error {
86+
logger.Log.Infof("replacing nats %s connection due to lame duck mode", componentName)
87+
88+
// Re-initialize connection (pass true to indicate this is a replacement)
89+
if err := initFunc(); err != nil {
90+
return err
91+
}
92+
93+
// Drain and close old connection and subscription after new one is set up
94+
if oldSub != nil {
95+
go func() {
96+
if err := oldSub.Drain(); err != nil {
97+
logger.Log.Warnf("error draining old %s subscription: %v", componentName, err)
98+
}
99+
}()
100+
}
101+
102+
if oldConn != nil {
103+
go func() {
104+
if err := drainAndClose(oldConn); err != nil {
105+
logger.Log.Warnf("error draining old nats %s connection: %v", componentName, err)
106+
}
107+
}()
108+
}
109+
110+
return nil
111+
}
112+
113+
func setupNatsConn(connectString string, appDieChan chan bool, lameDuckReplacement func() error, options ...nats.Option) (*nats.Conn, error) {
70114
connectedCh := make(chan bool)
71115
initialConnectErrorCh := make(chan error)
72116
natsOptions := append(
73117
options,
74118
nats.DisconnectErrHandler(func(nc *nats.Conn, err error) {
75-
logger.Log.Warnf("disconnected from nats (%s)! Reason: %q\n", nc.ConnectedAddr(), err)
119+
logger.Log.Warnf("disconnected from nats (%s)! Reason: %v", nc.ConnectedAddr(), err)
76120
}),
77121
nats.ReconnectHandler(func(nc *nats.Conn) {
78122
logger.Log.Warnf("reconnected to nats server %s with address %s in cluster %s!", nc.ConnectedServerName(), nc.ConnectedAddr(), nc.ConnectedClusterName())
@@ -85,12 +129,28 @@ func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.O
85129
}
86130

87131
logger.Log.Errorf("nats connection closed. reason: %q", nc.LastError())
132+
133+
// If connection was never successfully established, prioritize initialConnectErrorCh
134+
// to allow setupNatsConn to return quickly with an error
135+
wasConnected := nc.ConnectedAddr() != ""
136+
137+
if !wasConnected {
138+
// During initial connection, send error to initialConnectErrorCh first
139+
select {
140+
case initialConnectErrorCh <- nc.LastError():
141+
return
142+
default:
143+
// If channel is not ready, fall through to appDieChan handling
144+
}
145+
}
146+
88147
if appDieChan != nil {
89148
select {
90149
case appDieChan <- true:
91150
return
92151
case initialConnectErrorCh <- nc.LastError():
93152
logger.Log.Warnf("appDieChan not ready, sending error in initialConnectCh")
153+
return
94154
default:
95155
logger.Log.Warnf("no termination channel available, sending termination signal to app")
96156

@@ -108,6 +168,14 @@ func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.O
108168
os.Exit(1)
109169
}
110170
}
171+
} else if !wasConnected {
172+
// If no appDieChan and connection was never established, try initialConnectErrorCh again
173+
select {
174+
case initialConnectErrorCh <- nc.LastError():
175+
return
176+
default:
177+
// Channel not ready, but we've already logged the error
178+
}
111179
}
112180
}),
113181
nats.ErrorHandler(func(nc *nats.Conn, sub *nats.Subscription, err error) {
@@ -123,6 +191,18 @@ func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.O
123191
logger.Log.Infof("connected to nats on %s", nc.ConnectedAddr())
124192
connectedCh <- true
125193
}),
194+
nats.LameDuckModeHandler(func(nc *nats.Conn) {
195+
logger.Log.Warnf("nats connection entered lame duck mode")
196+
if lameDuckReplacement != nil {
197+
go func() {
198+
if err := lameDuckReplacement(); err != nil {
199+
logger.Log.Errorf("failed to replace connection: %v", err)
200+
// The old connection will eventually close (it's in lame duck mode),
201+
// which will trigger ClosedHandler and appDieChan
202+
}
203+
}()
204+
}
205+
}),
126206
)
127207

128208
nc, err := nats.Connect(connectString, natsOptions...)

cluster/nats_rpc_common_test.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,14 @@ func TestNatsRPCCommonSetupNatsConn(t *testing.T) {
5353
s.Shutdown()
5454
s.WaitForShutdown()
5555
}()
56-
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil)
56+
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil, nil)
5757
assert.NoError(t, err)
5858
assert.NotNil(t, conn)
5959
}
6060

6161
func TestNatsRPCCommonSetupNatsConnShouldError(t *testing.T) {
6262
t.Parallel()
63-
conn, err := setupNatsConn("nats://invalid:1234", nil)
63+
conn, err := setupNatsConn("nats://invalid:1234", nil, nil)
6464
assert.Error(t, err)
6565
assert.Nil(t, conn)
6666
}
@@ -83,7 +83,7 @@ func TestNatsRPCCommonCloseHandler(t *testing.T) {
8383
assert.True(t, value)
8484
}()
8585

86-
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), dieChan, nats.MaxReconnects(1),
86+
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), dieChan, nil, nats.MaxReconnects(1),
8787
nats.ReconnectWait(1*time.Millisecond))
8888
assert.NoError(t, err)
8989
assert.NotNil(t, conn)
@@ -108,6 +108,7 @@ func TestNatsRPCCommonWaitReconnections(t *testing.T) {
108108
conn, err := setupNatsConn(
109109
urls,
110110
appDieCh,
111+
nil,
111112
nats.ReconnectWait(10*time.Millisecond),
112113
nats.MaxReconnects(5),
113114
nats.RetryOnFailedConnect(true),
@@ -135,6 +136,7 @@ func TestNatsRPCCommonDoNotBlockOnConnectionFail(t *testing.T) {
135136
conn, err := setupNatsConn(
136137
invalidAddr,
137138
appDieCh,
139+
nil,
138140
nats.ReconnectWait(10*time.Millisecond),
139141
nats.MaxReconnects(2),
140142
nats.RetryOnFailedConnect(true),
@@ -168,7 +170,7 @@ func TestNatsRPCCommonFailWithoutAppDieChan(t *testing.T) {
168170
}()
169171

170172
go func() {
171-
conn, err := setupNatsConn(invalidAddr, appDieCh)
173+
conn, err := setupNatsConn(invalidAddr, appDieCh, nil)
172174
assert.Error(t, err)
173175
assert.Nil(t, conn)
174176
close(done)

cluster/nats_rpc_server.go

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -333,16 +333,37 @@ func (ns *NatsRPCServer) processKick() {
333333
}
334334
}
335335

336+
// replaceConnection replaces the NATS connection, draining the old one and re-subscribing
337+
func (ns *NatsRPCServer) replaceConnection() error {
338+
return replaceNatsConnection(
339+
ns.conn,
340+
ns.sub,
341+
func() error { return ns.initConnection(true) },
342+
"server",
343+
)
344+
}
345+
336346
// Init inits nats rpc server
337347
func (ns *NatsRPCServer) Init() error {
338-
// TODO should we have concurrency here? it feels like we should
339-
go ns.handleMessages()
348+
return ns.initConnection(false)
349+
}
350+
351+
// initConnection initializes or replaces the NATS connection
352+
func (ns *NatsRPCServer) initConnection(isReplacement bool) error {
353+
354+
if !isReplacement {
355+
// TODO should we have concurrency here? it feels like we should
356+
go ns.handleMessages()
357+
logger.Log.Debugf("connecting to nats (server) with timeout of %s", ns.connectionTimeout)
358+
} else {
359+
logger.Log.Debugf("re-initializing nats server connection")
360+
}
340361

341-
logger.Log.Debugf("connecting to nats (server) with timeout of %s", ns.connectionTimeout)
342362
conn, err := setupNatsConn(
343363
ns.connString,
344364
ns.appDieChan,
345-
nats.RetryOnFailedConnect(false),
365+
ns.replaceConnection,
366+
nats.RetryOnFailedConnect(true),
346367
nats.MaxReconnects(ns.maxReconnectionRetries),
347368
nats.Timeout(ns.connectionTimeout),
348369
nats.Compression(ns.websocketCompression),
@@ -363,17 +384,37 @@ func (ns *NatsRPCServer) Init() error {
363384
if err != nil {
364385
return err
365386
}
366-
// this handles remote messages
367-
for i := 0; i < ns.service; i++ {
368-
go ns.processMessages(i)
387+
388+
// Re-subscribe to all session subscriptions if this is a replacement
389+
// The onSessionBind callback is already set up, we just need to trigger it for existing sessions
390+
if isReplacement && ns.server.Frontend && ns.sessionPool != nil {
391+
ns.sessionPool.ForEachSession(func(s session.Session) {
392+
if s.GetIsFrontend() && s.UID() != "" {
393+
// Re-use the same subscription logic as onSessionBind
394+
if err := ns.onSessionBind(context.Background(), s); err != nil {
395+
logger.Log.Errorf("failed to re-subscribe session for user %s: %v", s.UID(), err)
396+
}
397+
}
398+
})
369399
}
370400

371-
ns.sessionPool.OnSessionBind(ns.onSessionBind)
401+
if !isReplacement {
402+
// this handles remote messages
403+
for i := 0; i < ns.service; i++ {
404+
go ns.processMessages(i)
405+
}
406+
407+
ns.sessionPool.OnSessionBind(ns.onSessionBind)
372408

373-
// this should be so fast that we shoudn't need concurrency
374-
go ns.processPushes()
375-
go ns.processSessionBindings()
376-
go ns.processKick()
409+
// this should be so fast that we shoudn't need concurrency
410+
go ns.processPushes()
411+
go ns.processSessionBindings()
412+
go ns.processKick()
413+
}
414+
415+
if isReplacement {
416+
logger.Log.Infof("successfully replaced nats server connection")
417+
}
377418

378419
return nil
379420
}

cluster/nats_rpc_server_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ func TestNatsRPCServerOnSessionBind(t *testing.T) {
160160
s.Shutdown()
161161
s.WaitForShutdown()
162162
}()
163-
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil)
163+
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil, nil)
164164
assert.NoError(t, err)
165165
rpcServer.conn = conn
166166
err = rpcServer.onSessionBind(context.Background(), mockSession)
@@ -179,7 +179,7 @@ func TestNatsRPCServerSubscribeToBindingsChannel(t *testing.T) {
179179
s.Shutdown()
180180
s.WaitForShutdown()
181181
}()
182-
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil)
182+
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil, nil)
183183
assert.NoError(t, err)
184184
rpcServer.conn = conn
185185
err = rpcServer.subscribeToBindingsChannel()
@@ -201,7 +201,7 @@ func TestNatsRPCServerSubscribeUserKickChannel(t *testing.T) {
201201
s.Shutdown()
202202
s.WaitForShutdown()
203203
}()
204-
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil)
204+
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil, nil)
205205
assert.NoError(t, err)
206206
rpcServer.conn = conn
207207
sub, err := rpcServer.subscribeToUserKickChannel("someuid", sv.Type)
@@ -244,7 +244,7 @@ func TestNatsRPCServerSubscribeToUserMessages(t *testing.T) {
244244
s.Shutdown()
245245
s.WaitForShutdown()
246246
}()
247-
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil)
247+
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil, nil)
248248
assert.NoError(t, err)
249249
rpcServer.conn = conn
250250
tables := []struct {
@@ -278,7 +278,7 @@ func TestNatsRPCServerSubscribe(t *testing.T) {
278278
s.Shutdown()
279279
s.WaitForShutdown()
280280
}()
281-
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil)
281+
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil, nil)
282282
assert.NoError(t, err)
283283
rpcServer.conn = conn
284284
tables := []struct {
@@ -317,7 +317,7 @@ func TestNatsRPCServerHandleMessages(t *testing.T) {
317317
s.Shutdown()
318318
s.WaitForShutdown()
319319
}()
320-
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil)
320+
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil, nil)
321321
assert.NoError(t, err)
322322
rpcServer.conn = conn
323323
tables := []struct {

0 commit comments

Comments
 (0)