Skip to content

Commit d96670f

Browse files
committed
Fix race condition in AmqpConnection
The automatic reconnection was racy: it could try to open a connection at the same time that `Close()` was called on `AmqpConnection`. Fixing this required introducing an internal field that records when the connection is "closed", and a specific error that tells the reconnection logic the connection is "closed" after `Close()` has been called.
1 parent 85930d6 commit d96670f

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

pkg/rabbitmqamqp/amqp_connection.go

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package rabbitmqamqp
33
import (
44
"context"
55
"crypto/tls"
6+
"errors"
67
"fmt"
78
"math/rand"
89
"sync"
@@ -13,6 +14,8 @@ import (
1314
"github.com/google/uuid"
1415
)
1516

17+
// ErrConnectionClosed is returned by open when the connection has already
// been marked as closed. maybeReconnect checks for it with errors.Is and
// aborts the reconnection attempts when it is seen.
var ErrConnectionClosed = errors.New("connection is closed")
18+
1619
type AmqpAddress struct {
1720
// the address of the AMQP server
1821
// it is in the form of amqp://<host>:<port>
@@ -121,6 +124,8 @@ type AmqpConnection struct {
121124
session *amqp.Session
122125
refMap *sync.Map
123126
entitiesTracker *entitiesTracker
127+
mutex sync.RWMutex
128+
closed bool
124129
}
125130

126131
func (a *AmqpConnection) Properties() map[string]any {
@@ -379,6 +384,12 @@ func validateOptions(connOptions *AmqpConnOptions) (*AmqpConnOptions, error) {
379384
// using the provided connectionSettings and the AMQPLite library.
380385
// Sets up the connection and the management interface.
381386
func (a *AmqpConnection) open(ctx context.Context, address string, connOptions *AmqpConnOptions) error {
387+
a.mutex.Lock()
388+
defer a.mutex.Unlock()
389+
390+
if a.closed {
391+
return ErrConnectionClosed
392+
}
382393

383394
// random pick and extract one address to use for connection
384395
var azureConnection *amqp.Conn
@@ -456,7 +467,6 @@ func (a *AmqpConnection) open(ctx context.Context, address string, connOptions *
456467
return nil
457468
}
458469
func (a *AmqpConnection) maybeReconnect() {
459-
460470
if !a.amqpConnOptions.RecoveryConfiguration.ActiveRecovery {
461471
Info("Recovery is disabled, closing connection", "ID", a.Id())
462472
return
@@ -467,7 +477,6 @@ func (a *AmqpConnection) maybeReconnect() {
467477
maxDelay := 1 * time.Minute
468478

469479
for attempt := 1; attempt <= a.amqpConnOptions.RecoveryConfiguration.MaxReconnectAttempts; attempt++ {
470-
471480
// wait before reconnecting
472481
// add some random milliseconds to the wait time to avoid thundering herd
473482
// the random time is between 0 and 500 milliseconds
@@ -491,6 +500,12 @@ func (a *AmqpConnection) maybeReconnect() {
491500
a.lifeCycle.SetState(&StateOpen{})
492501
return
493502
}
503+
504+
if errors.Is(err, ErrConnectionClosed) {
505+
Info("Connection was closed during reconnect, aborting.", "ID", a.Id())
506+
return
507+
}
508+
494509
baseDelay *= 2
495510
Error("Reconnection attempt failed", "attempt", attempt, "error", err, "ID", a.Id())
496511
}
@@ -548,6 +563,13 @@ Close closes the connection to the AMQP 1.0 server and the management interface.
548563
All the publishers and consumers are closed as well.
549564
*/
550565
func (a *AmqpConnection) Close(ctx context.Context) error {
566+
a.mutex.Lock()
567+
if a.closed {
568+
a.mutex.Unlock()
569+
return nil
570+
}
571+
a.closed = true
572+
defer a.mutex.Unlock()
551573
// the status closed (lifeCycle.SetState(&StateClosed{error: nil})) is not set here
552574
// it is set in the connection.Done() channel
553575
// the channel is called anyway

0 commit comments

Comments
 (0)