diff --git a/Makefile b/Makefile index d655c81..130afc8 100644 --- a/Makefile +++ b/Makefile @@ -1,13 +1,13 @@ -all: format vet test +all: test format: go fmt ./... vet: - go vet ./rabbitmq_amqp + go vet ./pkg/rabbitmq_amqp -test: - cd rabbitmq_amqp && go run -mod=mod github.com/onsi/ginkgo/v2/ginkgo \ +test: format vet + cd ./pkg/rabbitmq_amqp && go run -mod=mod github.com/onsi/ginkgo/v2/ginkgo \ --randomize-all --randomize-suites \ --cover --coverprofile=coverage.txt --covermode=atomic \ --race diff --git a/docs/examples/getting_started/main.go b/docs/examples/getting_started/main.go index a99fa9c..1a9954b 100644 --- a/docs/examples/getting_started/main.go +++ b/docs/examples/getting_started/main.go @@ -5,7 +5,7 @@ import ( "errors" "fmt" "github.com/Azure/go-amqp" - "github.com/rabbitmq/rabbitmq-amqp-go-client/rabbitmq_amqp" + "github.com/rabbitmq/rabbitmq-amqp-go-client/pkg/rabbitmq_amqp" "time" ) @@ -20,7 +20,7 @@ func main() { stateChanged := make(chan *rabbitmq_amqp.StateChanged, 1) go func(ch chan *rabbitmq_amqp.StateChanged) { for statusChanged := range ch { - rabbitmq_amqp.Info("[Connection]", "Status changed", statusChanged) + rabbitmq_amqp.Info("[connection]", "Status changed", statusChanged) } }(stateChanged) @@ -33,7 +33,7 @@ func main() { // Register the channel to receive status change notifications amqpConnection.NotifyStatusChange(stateChanged) - fmt.Printf("AMQP Connection opened.\n") + fmt.Printf("AMQP connection opened.\n") // Create the management interface for the connection // so we can declare exchanges, queues, and bindings management := amqpConnection.Management() @@ -86,16 +86,16 @@ func main() { deliveryContext, err := consumer.Receive(ctx) if errors.Is(err, context.Canceled) { // The consumer was closed correctly - rabbitmq_amqp.Info("[Consumer]", "consumer closed. Context", err) + rabbitmq_amqp.Info("[NewConsumer]", "consumer closed. Context", err) return } if err != nil { // An error occurred receiving the message - rabbitmq_amqp.Error("[Consumer]", "Error receiving message", err) + rabbitmq_amqp.Error("[NewConsumer]", "Error receiving message", err) return } - rabbitmq_amqp.Info("[Consumer]", "Received message", + rabbitmq_amqp.Info("[NewConsumer]", "Received message", fmt.Sprintf("%s", deliveryContext.Message().Data)) err = deliveryContext.Accept(context.Background()) @@ -115,26 +115,26 @@ func main() { return } - for i := 0; i < 10; i++ { - + for i := 0; i < 1_000; i++ { // Publish a message to the exchange publishResult, err := publisher.Publish(context.Background(), amqp.NewMessage([]byte("Hello, World!"+fmt.Sprintf("%d", i)))) if err != nil { - rabbitmq_amqp.Error("Error publishing message", err) - return + rabbitmq_amqp.Error("Error publishing message", "error", err) + time.Sleep(1 * time.Second) + continue } switch publishResult.Outcome.(type) { case *amqp.StateAccepted: - rabbitmq_amqp.Info("[Publisher]", "Message accepted", publishResult.Message.Data[0]) + rabbitmq_amqp.Info("[NewPublisher]", "Message accepted", publishResult.Message.Data[0]) break case *amqp.StateReleased: - rabbitmq_amqp.Warn("[Publisher]", "Message was not routed", publishResult.Message.Data[0]) + rabbitmq_amqp.Warn("[NewPublisher]", "Message was not routed", publishResult.Message.Data[0]) break case *amqp.StateRejected: - rabbitmq_amqp.Warn("[Publisher]", "Message rejected", publishResult.Message.Data[0]) + rabbitmq_amqp.Warn("[NewPublisher]", "Message rejected", publishResult.Message.Data[0]) stateType := publishResult.Outcome.(*amqp.StateRejected) if stateType.Error != nil { - rabbitmq_amqp.Warn("[Publisher]", "Message rejected with error: %v", stateType.Error) + rabbitmq_amqp.Warn("[NewPublisher]", "Message rejected with error: %v", stateType.Error) } break default: @@ -153,13 +153,13 @@ func main() { //Close the consumer err = consumer.Close(context.Background()) if err != nil { - rabbitmq_amqp.Error("[Consumer]", err) + rabbitmq_amqp.Error("[NewConsumer]", err) return } // Close the publisher err = publisher.Close(context.Background()) if err != nil { - rabbitmq_amqp.Error("[Publisher]", err) + rabbitmq_amqp.Error("[NewPublisher]", err) return } @@ -197,7 +197,7 @@ func main() { return } - fmt.Printf("AMQP Connection closed.\n") + fmt.Printf("AMQP connection closed.\n") // not necessary. It waits for the status change to be printed time.Sleep(100 * time.Millisecond) close(stateChanged) diff --git a/docs/examples/reliable/reliable.go b/docs/examples/reliable/reliable.go new file mode 100644 index 0000000..6710a4b --- /dev/null +++ b/docs/examples/reliable/reliable.go @@ -0,0 +1,213 @@ +package main + +import ( + "context" + "errors" + "fmt" + "github.com/Azure/go-amqp" + "github.com/rabbitmq/rabbitmq-amqp-go-client/pkg/rabbitmq_amqp" + "sync" + "sync/atomic" + "time" +) + +func main() { + queueName := "reliable-amqp10-go-queue" + var stateAccepted int32 + var stateReleased int32 + var stateRejected int32 + + var received int32 + var failed int32 + + startTime := time.Now() + go func() { + for { + time.Sleep(5 * time.Second) + total := stateAccepted + stateReleased + stateRejected + messagesPerSecond := float64(total) / time.Since(startTime).Seconds() + rabbitmq_amqp.Info("[Stats]", "sent", total, "received", received, "failed", failed, "messagesPerSecond", messagesPerSecond) + + } + }() + + rabbitmq_amqp.Info("How to deal with network disconnections") + signalBlock := sync.Cond{L: &sync.Mutex{}} + /// Create a channel to receive state change notifications + stateChanged := make(chan *rabbitmq_amqp.StateChanged, 1) + go func(ch chan *rabbitmq_amqp.StateChanged) { + for statusChanged := range ch { + rabbitmq_amqp.Info("[connection]", "Status changed", statusChanged) + switch statusChanged.To.(type) { + case *rabbitmq_amqp.StateOpen: + signalBlock.Broadcast() + } + } + }(stateChanged) + + // Open a connection to the AMQP 1.0 server + amqpConnection, err := rabbitmq_amqp.Dial(context.Background(), []string{"amqp://"}, &rabbitmq_amqp.AmqpConnOptions{ + SASLType: amqp.SASLTypeAnonymous(), + ContainerID: "reliable-amqp10-go", + RecoveryConfiguration: &rabbitmq_amqp.RecoveryConfiguration{ + ActiveRecovery: true, + BackOffReconnectInterval: 2 * time.Second, // we reduce the reconnect interval to speed up the test. The default is 5 seconds + // In production, you should avoid BackOffReconnectInterval with low values since it can cause a high number of reconnection attempts + MaxReconnectAttempts: 5, + }, + }) + if err != nil { + rabbitmq_amqp.Error("Error opening connection", err) + return + } + // Register the channel to receive status change notifications + amqpConnection.NotifyStatusChange(stateChanged) + + fmt.Printf("AMQP connection opened.\n") + // Create the management interface for the connection + // so we can declare exchanges, queues, and bindings + management := amqpConnection.Management() + + // Declare a Quorum queue + queueInfo, err := management.DeclareQueue(context.TODO(), &rabbitmq_amqp.QuorumQueueSpecification{ + Name: queueName, + }) + if err != nil { + rabbitmq_amqp.Error("Error declaring queue", err) + return + } + + consumer, err := amqpConnection.NewConsumer(context.Background(), &rabbitmq_amqp.QueueAddress{ + Queue: queueName, + }, "reliable-consumer") + if err != nil { + rabbitmq_amqp.Error("Error creating consumer", err) + return + } + + consumerContext, cancel := context.WithCancel(context.Background()) + + // Consume messages from the queue + go func(ctx context.Context) { + for { + deliveryContext, err := consumer.Receive(ctx) + if errors.Is(err, context.Canceled) { + // The consumer was closed correctly + return + } + if err != nil { + // An error occurred receiving the message + // here the consumer could be disconnected from the server due to a network error + signalBlock.L.Lock() + rabbitmq_amqp.Info("[Consumer]", "Consumer is blocked, queue", queueName, "error", err) + signalBlock.Wait() + rabbitmq_amqp.Info("[Consumer]", "Consumer is unblocked, queue", queueName) + + signalBlock.L.Unlock() + continue + } + + atomic.AddInt32(&received, 1) + err = deliveryContext.Accept(context.Background()) + if err != nil { + // same here the delivery could not be accepted due to a network error + // we wait for 2_500 ms and try again + time.Sleep(2500 * time.Millisecond) + continue + } + } + }(consumerContext) + + publisher, err := amqpConnection.NewPublisher(context.Background(), &rabbitmq_amqp.QueueAddress{ + Queue: queueName, + }, "reliable-publisher") + if err != nil { + rabbitmq_amqp.Error("Error creating publisher", err) + return + } + + wg := &sync.WaitGroup{} + for i := 0; i < 1; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 500_000; i++ { + publishResult, err := publisher.Publish(context.Background(), amqp.NewMessage([]byte("Hello, World!"+fmt.Sprintf("%d", i)))) + if err != nil { + // here you need to deal with the error. You can store the message in a local in memory/persistent storage + // then retry to send the message as soon as the connection is reestablished + + atomic.AddInt32(&failed, 1) + // block signalBlock until the connection is reestablished + signalBlock.L.Lock() + rabbitmq_amqp.Info("[Publisher]", "Publisher is blocked, queue", queueName, "error", err) + signalBlock.Wait() + rabbitmq_amqp.Info("[Publisher]", "Publisher is unblocked, queue", queueName) + signalBlock.L.Unlock() + + } else { + switch publishResult.Outcome.(type) { + case *amqp.StateAccepted: + atomic.AddInt32(&stateAccepted, 1) + break + case *amqp.StateReleased: + atomic.AddInt32(&stateReleased, 1) + break + case *amqp.StateRejected: + atomic.AddInt32(&stateRejected, 1) + break + default: + // these status are not supported. Leave it for AMQP 1.0 compatibility + // see: https://www.rabbitmq.com/docs/next/amqp#outcomes + rabbitmq_amqp.Warn("Message state: %v", publishResult.Outcome) + } + } + } + }() + } + wg.Wait() + + println("press any key to close the connection") + + var input string + _, _ = fmt.Scanln(&input) + + cancel() + //Close the consumer + err = consumer.Close(context.Background()) + if err != nil { + rabbitmq_amqp.Error("[NewConsumer]", err) + return + } + // Close the publisher + err = publisher.Close(context.Background()) + if err != nil { + rabbitmq_amqp.Error("[NewPublisher]", err) + return + } + + // Purge the queue + purged, err := management.PurgeQueue(context.TODO(), queueInfo.Name()) + if err != nil { + fmt.Printf("Error purging queue: %v\n", err) + return + } + fmt.Printf("Purged %d messages from the queue.\n", purged) + + err = management.DeleteQueue(context.TODO(), queueInfo.Name()) + if err != nil { + fmt.Printf("Error deleting queue: %v\n", err) + return + } + + err = amqpConnection.Close(context.Background()) + if err != nil { + fmt.Printf("Error closing connection: %v\n", err) + return + } + + fmt.Printf("AMQP connection closed.\n") + // not necessary. It waits for the status change to be printed + time.Sleep(100 * time.Millisecond) + close(stateChanged) +} diff --git a/rabbitmq_amqp/address.go b/pkg/rabbitmq_amqp/address.go similarity index 100% rename from rabbitmq_amqp/address.go rename to pkg/rabbitmq_amqp/address.go diff --git a/rabbitmq_amqp/address_test.go b/pkg/rabbitmq_amqp/address_test.go similarity index 100% rename from rabbitmq_amqp/address_test.go rename to pkg/rabbitmq_amqp/address_test.go diff --git a/rabbitmq_amqp/amqp_binding.go b/pkg/rabbitmq_amqp/amqp_binding.go similarity index 100% rename from rabbitmq_amqp/amqp_binding.go rename to pkg/rabbitmq_amqp/amqp_binding.go diff --git a/rabbitmq_amqp/amqp_binding_test.go b/pkg/rabbitmq_amqp/amqp_binding_test.go similarity index 100% rename from rabbitmq_amqp/amqp_binding_test.go rename to pkg/rabbitmq_amqp/amqp_binding_test.go diff --git a/pkg/rabbitmq_amqp/amqp_connection.go b/pkg/rabbitmq_amqp/amqp_connection.go new file mode 100644 index 0000000..3b7cd79 --- /dev/null +++ b/pkg/rabbitmq_amqp/amqp_connection.go @@ -0,0 +1,349 @@ +package rabbitmq_amqp + +import ( + "context" + "crypto/tls" + "fmt" + "github.com/Azure/go-amqp" + "github.com/google/uuid" + "math/rand" + "sync" + "sync/atomic" + "time" +) + +//func (c *ConnUrlHelper) UseSsl(value bool) { +// c.UseSsl = value +// if value { +// c.Scheme = "amqps" +// } else { +// c.Scheme = "amqp" +// } +//} + +type AmqpConnOptions struct { + // wrapper for amqp.ConnOptions + ContainerID string + // wrapper for amqp.ConnOptions + HostName string + // wrapper for amqp.ConnOptions + IdleTimeout time.Duration + + // wrapper for amqp.ConnOptions + MaxFrameSize uint32 + + // wrapper for amqp.ConnOptions + MaxSessions uint16 + + // wrapper for amqp.ConnOptions + Properties map[string]any + + // wrapper for amqp.ConnOptions + SASLType amqp.SASLType + + // wrapper for amqp.ConnOptions + TLSConfig *tls.Config + + // wrapper for amqp.ConnOptions + WriteTimeout time.Duration + + // RecoveryConfiguration is used to configure the recovery behavior of the connection. + // when the connection is closed unexpectedly. + RecoveryConfiguration *RecoveryConfiguration + + // copy the addresses for reconnection + addresses []string +} + +type AmqpConnection struct { + azureConnection *amqp.Conn + id string + management *AmqpManagement + lifeCycle *LifeCycle + amqpConnOptions *AmqpConnOptions + session *amqp.Session + refMap *sync.Map + entitiesTracker *entitiesTracker +} + +// NewPublisher creates a new Publisher that sends messages to the provided destination. +// The destination is a TargetAddress that can be a Queue or an Exchange with a routing key. +// See QueueAddress and ExchangeAddress for more information. +func (a *AmqpConnection) NewPublisher(ctx context.Context, destination TargetAddress, linkName string) (*Publisher, error) { + destinationAdd := "" + err := error(nil) + if destination != nil { + destinationAdd, err = destination.toAddress() + if err != nil { + return nil, err + } + err = validateAddress(destinationAdd) + if err != nil { + return nil, err + } + } + + return newPublisher(ctx, a, destinationAdd, linkName) +} + +// NewConsumer creates a new Consumer that listens to the provided destination. Destination is a QueueAddress. +func (a *AmqpConnection) NewConsumer(ctx context.Context, destination *QueueAddress, linkName string) (*Consumer, error) { + destinationAdd, err := destination.toAddress() + if err != nil { + return nil, err + } + err = validateAddress(destinationAdd) + + return newConsumer(ctx, a, destinationAdd, linkName) +} + +// Dial connect to the AMQP 1.0 server using the provided connectionSettings +// Returns a pointer to the new AmqpConnection if successful else an error. +// addresses is a list of addresses to connect to. It picks one randomly. +// It is enough that one of the addresses is reachable. +func Dial(ctx context.Context, addresses []string, connOptions *AmqpConnOptions, args ...string) (*AmqpConnection, error) { + if connOptions == nil { + connOptions = &AmqpConnOptions{ + // RabbitMQ requires SASL security layer + // to be enabled for AMQP 1.0 connections. + // So this is mandatory and default in case not defined. + SASLType: amqp.SASLTypeAnonymous(), + } + } + + if connOptions.RecoveryConfiguration == nil { + connOptions.RecoveryConfiguration = NewRecoveryConfiguration() + } + + // validate the RecoveryConfiguration options + if connOptions.RecoveryConfiguration.MaxReconnectAttempts <= 0 && connOptions.RecoveryConfiguration.ActiveRecovery { + return nil, fmt.Errorf("MaxReconnectAttempts should be greater than 0") + } + if connOptions.RecoveryConfiguration.BackOffReconnectInterval <= 1*time.Second && connOptions.RecoveryConfiguration.ActiveRecovery { + return nil, fmt.Errorf("BackOffReconnectInterval should be greater than 1 second") + } + + // create the connection + + conn := &AmqpConnection{ + management: NewAmqpManagement(), + lifeCycle: NewLifeCycle(), + amqpConnOptions: connOptions, + entitiesTracker: newEntitiesTracker(), + } + tmp := make([]string, len(addresses)) + copy(tmp, addresses) + + err := conn.open(ctx, addresses, connOptions, args...) + if err != nil { + return nil, err + } + conn.amqpConnOptions = connOptions + conn.amqpConnOptions.addresses = addresses + conn.lifeCycle.SetState(&StateOpen{}) + return conn, nil + +} + +// Open opens a connection to the AMQP 1.0 server. +// using the provided connectionSettings and the AMQPLite library. +// Setups the connection and the management interface. +func (a *AmqpConnection) open(ctx context.Context, addresses []string, connOptions *AmqpConnOptions, args ...string) error { + + amqpLiteConnOptions := &amqp.ConnOptions{ + ContainerID: connOptions.ContainerID, + HostName: connOptions.HostName, + IdleTimeout: connOptions.IdleTimeout, + MaxFrameSize: connOptions.MaxFrameSize, + MaxSessions: connOptions.MaxSessions, + Properties: connOptions.Properties, + SASLType: connOptions.SASLType, + TLSConfig: connOptions.TLSConfig, + WriteTimeout: connOptions.WriteTimeout, + } + tmp := make([]string, len(addresses)) + copy(tmp, addresses) + + // random pick and extract one address to use for connection + var azureConnection *amqp.Conn + for len(tmp) > 0 { + idx := random(len(tmp)) + addr := tmp[idx] + //connOptions.HostName is the way to set the virtual host + // so we need to pre-parse the URI to get the virtual host + // the PARSE is copied from go-amqp091 library + // the URI will be parsed is parsed again in the amqp lite library + uri, err := ParseURI(addr) + if err != nil { + return err + } + connOptions.HostName = fmt.Sprintf("vhost:%s", uri.Vhost) + // remove the index from the tmp list + tmp = append(tmp[:idx], tmp[idx+1:]...) + azureConnection, err = amqp.Dial(ctx, addr, amqpLiteConnOptions) + if err != nil { + Error("Failed to open connection", ExtractWithoutPassword(addr), err) + continue + } + Debug("Connected to", ExtractWithoutPassword(addr)) + break + } + if azureConnection == nil { + return fmt.Errorf("failed to connect to any of the provided addresses") + } + + if len(args) > 0 { + a.id = args[0] + } else { + a.id = uuid.New().String() + } + + a.azureConnection = azureConnection + var err error + a.session, err = a.azureConnection.NewSession(ctx, nil) + go func() { + select { + case <-azureConnection.Done(): + { + a.lifeCycle.SetState(&StateClosed{error: azureConnection.Err()}) + if azureConnection.Err() != nil { + Error("connection closed unexpectedly", "error", azureConnection.Err()) + a.maybeReconnect() + + return + } + Debug("connection closed successfully") + } + } + }() + + if err != nil { + return err + } + err = a.management.Open(ctx, a) + if err != nil { + // TODO close connection? + return err + } + + return nil +} +func (a *AmqpConnection) maybeReconnect() { + + if !a.amqpConnOptions.RecoveryConfiguration.ActiveRecovery { + Info("Recovery is disabled, closing connection") + return + } + a.lifeCycle.SetState(&StateReconnecting{}) + numberOfAttempts := 1 + waitTime := a.amqpConnOptions.RecoveryConfiguration.BackOffReconnectInterval + reconnected := false + for numberOfAttempts <= a.amqpConnOptions.RecoveryConfiguration.MaxReconnectAttempts { + ///wait for before reconnecting + // add some random milliseconds to the wait time to avoid thundering herd + // the random time is between 0 and 500 milliseconds + waitTime = waitTime + time.Duration(rand.Intn(500))*time.Millisecond + + Info("Waiting before reconnecting", "in", waitTime, "attempt", numberOfAttempts) + time.Sleep(waitTime) + // context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + // try to createSender + err := a.open(ctx, a.amqpConnOptions.addresses, a.amqpConnOptions) + cancel() + + if err != nil { + numberOfAttempts++ + waitTime = waitTime * 2 + Error("Failed to connection. ", "id", a.Id(), "error", err) + } else { + reconnected = true + break + } + } + + if reconnected { + var fails int32 + Info("Reconnected successfully, restarting publishers and consumers") + a.entitiesTracker.publishers.Range(func(key, value any) bool { + publisher := value.(*Publisher) + // try to createSender + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + err := publisher.createSender(ctx) + if err != nil { + atomic.AddInt32(&fails, 1) + Error("Failed to createSender publisher", "ID", publisher.Id(), "error", err) + } + cancel() + return true + }) + Info("Restarted publishers", "number of fails", fails) + fails = 0 + a.entitiesTracker.consumers.Range(func(key, value any) bool { + consumer := value.(*Consumer) + // try to createSender + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + err := consumer.createReceiver(ctx) + if err != nil { + atomic.AddInt32(&fails, 1) + Error("Failed to createReceiver consumer", "ID", consumer.Id(), "error", err) + } + cancel() + return true + }) + Info("Restarted consumers", "number of fails", fails) + + a.lifeCycle.SetState(&StateOpen{}) + } + +} + +func (a *AmqpConnection) close() { + if a.refMap != nil { + a.refMap.Delete(a.Id()) + } + a.entitiesTracker.CleanUp() +} + +/* +Close closes the connection to the AMQP 1.0 server and the management interface. +All the publishers and consumers are closed as well. +*/ +func (a *AmqpConnection) Close(ctx context.Context) error { + // the status closed (lifeCycle.SetState(&StateClosed{error: nil})) is not set here + // it is set in the connection.Done() channel + // the channel is called anyway + // see the open(...) function with a.lifeCycle.SetState(&StateClosed{error: connection.Err()}) + + err := a.management.Close(ctx) + if err != nil { + Error("Failed to close management", "error:", err) + } + err = a.azureConnection.Close() + a.close() + return err +} + +// NotifyStatusChange registers a channel to receive getState change notifications +// from the connection. +func (a *AmqpConnection) NotifyStatusChange(channel chan *StateChanged) { + a.lifeCycle.chStatusChanged = channel +} + +func (a *AmqpConnection) State() LifeCycleState { + return a.lifeCycle.State() +} + +func (a *AmqpConnection) Id() string { + return a.id +} + +// *** management section *** + +// Management returns the management interface for the connection. +// The management interface is used to declare and delete exchanges, queues, and bindings. +func (a *AmqpConnection) Management() *AmqpManagement { + return a.management +} + +//*** end management section *** diff --git a/pkg/rabbitmq_amqp/amqp_connection_recovery.go b/pkg/rabbitmq_amqp/amqp_connection_recovery.go new file mode 100644 index 0000000..3869f12 --- /dev/null +++ b/pkg/rabbitmq_amqp/amqp_connection_recovery.go @@ -0,0 +1,93 @@ +package rabbitmq_amqp + +import ( + "sync" + "time" +) + +type RecoveryConfiguration struct { + /* + ActiveRecovery Define if the recovery is activated. + If is not activated the connection will not try to createSender. + */ + ActiveRecovery bool + + /* + BackOffReconnectInterval The time to wait before trying to createSender after a connection is closed. + time will be increased exponentially with each attempt. + Default is 5 seconds, each attempt will double the time. + The minimum value is 1 second. Avoid setting a value low values since it can cause a high + number of reconnection attempts. + */ + BackOffReconnectInterval time.Duration + + /* + MaxReconnectAttempts The maximum number of reconnection attempts. + Default is 5. + The minimum value is 1. + */ + MaxReconnectAttempts int +} + +func NewRecoveryConfiguration() *RecoveryConfiguration { + return &RecoveryConfiguration{ + ActiveRecovery: true, + BackOffReconnectInterval: 5 * time.Second, + MaxReconnectAttempts: 5, + } +} + +type entitiesTracker struct { + publishers sync.Map + consumers sync.Map +} + +func newEntitiesTracker() *entitiesTracker { + return &entitiesTracker{ + publishers: sync.Map{}, + consumers: sync.Map{}, + } +} + +func (e *entitiesTracker) storeOrReplaceProducer(entity entityIdentifier) { + e.publishers.Store(entity.Id(), entity) +} + +func (e *entitiesTracker) getProducer(id string) (*Publisher, bool) { + producer, ok := e.publishers.Load(id) + if !ok { + return nil, false + } + return producer.(*Publisher), true +} + +func (e *entitiesTracker) removeProducer(entity entityIdentifier) { + e.publishers.Delete(entity.Id()) +} + +func (e *entitiesTracker) storeOrReplaceConsumer(entity entityIdentifier) { + e.consumers.Store(entity.Id(), entity) +} + +func (e *entitiesTracker) getConsumer(id string) (*Consumer, bool) { + consumer, ok := e.consumers.Load(id) + if !ok { + return nil, false + } + return consumer.(*Consumer), true +} + +func (e *entitiesTracker) removeConsumer(entity entityIdentifier) { + e.consumers.Delete(entity.Id()) +} + +func (e *entitiesTracker) CleanUp() { + e.publishers.Range(func(key, value interface{}) bool { + e.publishers.Delete(key) + return true + }) + e.consumers.Range(func(key, value interface{}) bool { + e.consumers.Delete(key) + return true + }) +} diff --git a/pkg/rabbitmq_amqp/amqp_connection_recovery_test.go b/pkg/rabbitmq_amqp/amqp_connection_recovery_test.go new file mode 100644 index 0000000..b5126f6 --- /dev/null +++ b/pkg/rabbitmq_amqp/amqp_connection_recovery_test.go @@ -0,0 +1,183 @@ +package rabbitmq_amqp + +import ( + "context" + "github.com/Azure/go-amqp" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + testhelper "github.com/rabbitmq/rabbitmq-amqp-go-client/pkg/test-helper" + "time" +) + +var _ = Describe("Recovery connection test", func() { + It("connection should reconnect producers and consumers if dropped by via REST API", func() { + /* + The test is a bit complex since it requires to drop the connection by REST API + Then wait for the connection to be reconnected. + The scope of the test is to verify that the connection is reconnected and the + producers and consumers are able to send and receive messages. + It is more like an integration test. + This kind of the tests requires time in terms of execution it has to wait for the + connection to be reconnected, so to speed up the test I aggregated the tests in one. + */ + + name := "connection should reconnect producers and consumers if dropped by via REST API" + connection, err := Dial(context.Background(), []string{"amqp://"}, &AmqpConnOptions{ + SASLType: amqp.SASLTypeAnonymous(), + ContainerID: name, + // reduced the reconnect interval to speed up the test + RecoveryConfiguration: &RecoveryConfiguration{ + ActiveRecovery: true, + BackOffReconnectInterval: 2 * time.Second, + MaxReconnectAttempts: 5, + }, + }) + Expect(err).To(BeNil()) + ch := make(chan *StateChanged, 1) + connection.NotifyStatusChange(ch) + + qName := generateName(name) + queueInfo, err := connection.Management().DeclareQueue(context.Background(), &QuorumQueueSpecification{ + Name: qName, + }) + Expect(err).To(BeNil()) + Expect(queueInfo).NotTo(BeNil()) + + consumer, err := connection.NewConsumer(context.Background(), &QueueAddress{ + Queue: qName, + }, "test") + + publisher, err := connection.NewPublisher(context.Background(), &QueueAddress{ + Queue: qName, + }, "test") + + Expect(err).To(BeNil()) + Expect(publisher).NotTo(BeNil()) + for i := 0; i < 5; i++ { + publishResult, err := publisher.Publish(context.Background(), amqp.NewMessage([]byte("Hello"))) + Expect(err).To(BeNil()) + Expect(publishResult).NotTo(BeNil()) + Expect(publishResult.Outcome).To(Equal(&amqp.StateAccepted{})) + } + + Eventually(func() bool { + err := testhelper.DropConnectionContainerID(name) + return err == nil + }).WithTimeout(5 * time.Second).WithPolling(400 * time.Millisecond).Should(BeTrue()) + st1 := <-ch + Expect(st1.From).To(Equal(&StateOpen{})) + Expect(st1.To).To(BeAssignableToTypeOf(&StateClosed{})) + /// Closed state should have an error + // Since it is forced closed by the REST API + err = st1.To.(*StateClosed).GetError() + Expect(err).NotTo(BeNil()) + Expect(err.Error()).To(ContainSubstring("Connection forced")) + + time.Sleep(1 * time.Second) + Eventually(func() bool { + conn, err := testhelper.GetConnectionByContainerID(name) + return err == nil && conn != nil + }).WithTimeout(5 * time.Second).WithPolling(400 * time.Millisecond).Should(BeTrue()) + st2 := <-ch + Expect(st2.From).To(BeAssignableToTypeOf(&StateClosed{})) + Expect(st2.To).To(Equal(&StateReconnecting{})) + + st3 := <-ch + Expect(st3.From).To(BeAssignableToTypeOf(&StateReconnecting{})) + Expect(st3.To).To(Equal(&StateOpen{})) + + for i := 0; i < 5; i++ { + publishResult, err := publisher.Publish(context.Background(), amqp.NewMessage([]byte("Hello"))) + Expect(err).To(BeNil()) + Expect(publishResult).NotTo(BeNil()) + Expect(publishResult.Outcome).To(Equal(&amqp.StateAccepted{})) + } + + /// after the connection is reconnected the consumer should be able to receive the messages + for i := 0; i < 10; i++ { + deliveryContext, err := consumer.Receive(context.Background()) + Expect(err).To(BeNil()) + Expect(deliveryContext).NotTo(BeNil()) + } + + Expect(connection.Management().DeleteQueue(context.Background(), qName)).To(BeNil()) + + err = connection.Close(context.Background()) + Expect(err).To(BeNil()) + st4 := <-ch + Expect(st4.From).To(Equal(&StateOpen{})) + Expect(st4.To).To(BeAssignableToTypeOf(&StateClosed{})) + err = st4.To.(*StateClosed).GetError() + // the flow status should be: + // from open to closed (with error) + // from closed to reconnecting + // from reconnecting to open + // from open to closed (without error) + Expect(err).To(BeNil()) + }) + + It("connection should not reconnect producers and consumers if the auto-recovery is disabled", func() { + name := "connection should reconnect producers and consumers if dropped by via REST API" + connection, err := Dial(context.Background(), []string{"amqp://"}, &AmqpConnOptions{ + SASLType: amqp.SASLTypeAnonymous(), + ContainerID: name, + // reduced the reconnect interval to speed up the test + RecoveryConfiguration: &RecoveryConfiguration{ + ActiveRecovery: false, // disabled + }, + }) + Expect(err).To(BeNil()) + ch := make(chan *StateChanged, 1) + connection.NotifyStatusChange(ch) + + Eventually(func() bool { + err := testhelper.DropConnectionContainerID(name) + return err == nil + }).WithTimeout(5 * time.Second).WithPolling(400 * time.Millisecond).Should(BeTrue()) + st1 := <-ch + Expect(st1.From).To(Equal(&StateOpen{})) + Expect(st1.To).To(BeAssignableToTypeOf(&StateClosed{})) + + err = st1.To.(*StateClosed).GetError() + Expect(err).NotTo(BeNil()) + Expect(err.Error()).To(ContainSubstring("Connection forced")) + + time.Sleep(1 * time.Second) + + // the connection should not be reconnected + Consistently(func() bool { + conn, err := testhelper.GetConnectionByContainerID(name) + return err == nil && conn != nil + }).WithTimeout(5 * time.Second).WithPolling(400 * time.Millisecond).Should(BeFalse()) + + err = connection.Close(context.Background()) + Expect(err).NotTo(BeNil()) + }) + + It("validate the Recovery connection parameters", func() { + + _, err := Dial(context.Background(), []string{"amqp://"}, &AmqpConnOptions{ + SASLType: amqp.SASLTypeAnonymous(), + // reduced the reconnect interval to speed up the test + RecoveryConfiguration: &RecoveryConfiguration{ + ActiveRecovery: true, + BackOffReconnectInterval: 500 * time.Millisecond, + MaxReconnectAttempts: 5, + }, + }) + Expect(err).NotTo(BeNil()) + Expect(err.Error()).To(ContainSubstring("BackOffReconnectInterval should be greater than")) + + _, err = Dial(context.Background(), []string{"amqp://"}, &AmqpConnOptions{ + SASLType: amqp.SASLTypeAnonymous(), + RecoveryConfiguration: &RecoveryConfiguration{ + ActiveRecovery: true, + MaxReconnectAttempts: 0, + }, + }) + + Expect(err).NotTo(BeNil()) + Expect(err.Error()).To(ContainSubstring("MaxReconnectAttempts should be greater than")) + }) + +}) diff --git a/rabbitmq_amqp/amqp_connection_test.go b/pkg/rabbitmq_amqp/amqp_connection_test.go similarity index 77% rename from rabbitmq_amqp/amqp_connection_test.go rename to pkg/rabbitmq_amqp/amqp_connection_test.go index fc24327..acfd54d 100644 --- a/rabbitmq_amqp/amqp_connection_test.go +++ b/pkg/rabbitmq_amqp/amqp_connection_test.go @@ -8,19 +8,19 @@ import ( "time" ) -var _ = Describe("AMQP Connection Test", func() { - It("AMQP SASLTypeAnonymous Connection should succeed", func() { +var _ = Describe("AMQP connection Test", func() { + It("AMQP SASLTypeAnonymous connection should succeed", func() { - connection, err := Dial(context.Background(), []string{"amqp://"}, &amqp.ConnOptions{ + connection, err := Dial(context.Background(), []string{"amqp://"}, &AmqpConnOptions{ SASLType: amqp.SASLTypeAnonymous()}) Expect(err).To(BeNil()) err = connection.Close(context.Background()) Expect(err).To(BeNil()) }) - It("AMQP SASLTypePlain Connection should succeed", func() { + It("AMQP SASLTypePlain connection should succeed", func() { - connection, err := Dial(context.Background(), []string{"amqp://"}, &amqp.ConnOptions{ + connection, err := Dial(context.Background(), []string{"amqp://"}, &AmqpConnOptions{ SASLType: amqp.SASLTypePlain("guest", "guest")}) Expect(err).To(BeNil()) @@ -28,35 +28,35 @@ var _ = Describe("AMQP Connection Test", func() { Expect(err).To(BeNil()) }) - It("AMQP Connection connect to the one correct uri and fails the others", func() { + It("AMQP connection connect to the one correct uri and fails the others", func() { conn, err := Dial(context.Background(), []string{"amqp://localhost:1234", "amqp://nohost:555", "amqp://"}, nil) Expect(err).To(BeNil()) Expect(conn.Close(context.Background())) }) - It("AMQP Connection should fail due of wrong Port", func() { + It("AMQP connection should fail due of wrong Port", func() { _, err := Dial(context.Background(), []string{"amqp://localhost:1234"}, nil) Expect(err).NotTo(BeNil()) }) - It("AMQP Connection should fail due of wrong Host", func() { + It("AMQP connection should fail due of wrong Host", func() { _, err := Dial(context.Background(), []string{"amqp://wrong_host:5672"}, nil) Expect(err).NotTo(BeNil()) }) - It("AMQP Connection should fails with all the wrong uris", func() { + It("AMQP connection should fails with all the wrong uris", func() { _, err := Dial(context.Background(), []string{"amqp://localhost:1234", "amqp://nohost:555", "amqp://nono"}, nil) Expect(err).NotTo(BeNil()) }) - It("AMQP Connection should fail due to context cancellation", func() { + It("AMQP connection should fail due to context cancellation", func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) cancel() _, err := Dial(ctx, []string{"amqp://"}, nil) Expect(err).NotTo(BeNil()) }) - It("AMQP Connection should receive events", func() { + It("AMQP connection should receive events", func() { ch := make(chan *StateChanged, 1) connection, err := Dial(context.Background(), []string{"amqp://"}, nil) Expect(err).To(BeNil()) @@ -70,7 +70,7 @@ var _ = Describe("AMQP Connection Test", func() { Expect(recv.To).To(Equal(&StateClosed{})) }) - //It("AMQP TLS Connection should success with SASLTypeAnonymous ", func() { + //It("AMQP TLS connection should success with SASLTypeAnonymous ", func() { // amqpConnection := NewAmqpConnection() // Expect(amqpConnection).NotTo(BeNil()) // Expect(amqpConnection).To(BeAssignableToTypeOf(&AmqpConnection{})) diff --git a/rabbitmq_amqp/amqp_consumer.go b/pkg/rabbitmq_amqp/amqp_consumer.go similarity index 59% rename from rabbitmq_amqp/amqp_consumer.go rename to pkg/rabbitmq_amqp/amqp_consumer.go index 6311919..22f702c 100644 --- a/rabbitmq_amqp/amqp_consumer.go +++ b/pkg/rabbitmq_amqp/amqp_consumer.go @@ -2,7 +2,10 @@ package rabbitmq_amqp import ( "context" + "fmt" "github.com/Azure/go-amqp" + "github.com/google/uuid" + "sync/atomic" ) type DeliveryContext struct { @@ -28,8 +31,8 @@ func (dc *DeliveryContext) DiscardWithAnnotations(ctx context.Context, annotatio } // copy the rabbitmq annotations to amqp annotations destination := make(amqp.Annotations) - for key, value := range annotations { - destination[key] = value + for keyA, value := range annotations { + destination[keyA] = value } @@ -62,21 +65,49 @@ func (dc *DeliveryContext) RequeueWithAnnotations(ctx context.Context, annotatio } type Consumer struct { - receiver *amqp.Receiver + receiver atomic.Pointer[amqp.Receiver] + connection *AmqpConnection + linkName string + destinationAdd string + id string +} + +func (c *Consumer) Id() string { + return c.id } -func newConsumer(receiver *amqp.Receiver) *Consumer { - return &Consumer{receiver: receiver} +func newConsumer(ctx context.Context, connection *AmqpConnection, destinationAdd string, linkName string, args ...string) (*Consumer, error) { + id := fmt.Sprintf("consumer-%s", uuid.New().String()) + if len(args) > 0 { + id = args[0] + } + r := &Consumer{connection: connection, linkName: linkName, destinationAdd: destinationAdd, id: id} + connection.entitiesTracker.storeOrReplaceConsumer(r) + err := r.createReceiver(ctx) + if err != nil { + return nil, err + } + return r, nil +} + +func (c *Consumer) createReceiver(ctx context.Context) error { + receiver, err := c.connection.session.NewReceiver(ctx, c.destinationAdd, createReceiverLinkOptions(c.destinationAdd, c.linkName, AtLeastOnce)) + if err != nil { + return err + } + + c.receiver.Swap(receiver) + return nil } func (c *Consumer) Receive(ctx context.Context) (*DeliveryContext, error) { - msg, err := c.receiver.Receive(ctx, nil) + msg, err := c.receiver.Load().Receive(ctx, nil) if err != nil { return nil, err } - return &DeliveryContext{receiver: c.receiver, message: msg}, nil + return &DeliveryContext{receiver: c.receiver.Load(), message: msg}, nil } func (c *Consumer) Close(ctx context.Context) error { - return c.receiver.Close(ctx) + return c.receiver.Load().Close(ctx) } diff --git a/rabbitmq_amqp/amqp_consumer_test.go b/pkg/rabbitmq_amqp/amqp_consumer_test.go similarity index 100% rename from rabbitmq_amqp/amqp_consumer_test.go rename to pkg/rabbitmq_amqp/amqp_consumer_test.go diff --git a/pkg/rabbitmq_amqp/amqp_environment.go b/pkg/rabbitmq_amqp/amqp_environment.go new file mode 100644 index 0000000..aac04b7 --- /dev/null +++ b/pkg/rabbitmq_amqp/amqp_environment.go @@ -0,0 +1,67 @@ +package rabbitmq_amqp + +import ( + "context" + "fmt" + "sync" +) + +type Environment struct { + connections sync.Map + addresses []string + connOptions *AmqpConnOptions +} + +func NewEnvironment(addresses []string, connOptions *AmqpConnOptions) *Environment { + return &Environment{ + connections: sync.Map{}, + addresses: addresses, + connOptions: connOptions, + } +} + +// NewConnection get a new connection from the environment. +// If the connection id is provided, it will be used as the connection id. +// If the connection id is not provided, a new connection id will be generated. +// The connection id is unique in the environment. +// The Environment will keep track of the connection and close it when the environment is closed. +func (e *Environment) NewConnection(ctx context.Context, args ...string) (*AmqpConnection, error) { + if len(args) > 0 && len(args[0]) > 0 { + // check if connection already exists + if _, ok := e.connections.Load(args[0]); ok { + return nil, fmt.Errorf("connection with id %s already exists", args[0]) + } + } + + connection, err := Dial(ctx, e.addresses, e.connOptions, args...) + if err != nil { + return nil, err + } + e.connections.Store(connection.Id(), connection) + connection.refMap = &e.connections + return connection, nil +} + +// Connections gets the active connections in the environment + +func (e *Environment) Connections() []*AmqpConnection { + connections := make([]*AmqpConnection, 0) + e.connections.Range(func(key, value interface{}) bool { + connections = append(connections, value.(*AmqpConnection)) + return true + }) + return connections +} + +// CloseConnections closes all the connections in the environment with all the publishers and consumers. +func (e *Environment) CloseConnections(ctx context.Context) error { + var err error + e.connections.Range(func(key, value any) bool { + connection := value.(*AmqpConnection) + if cerr := connection.Close(ctx); cerr != nil { + err = cerr + } + return true + }) + return err +} diff --git a/pkg/rabbitmq_amqp/amqp_environment_test.go b/pkg/rabbitmq_amqp/amqp_environment_test.go new file mode 100644 index 0000000..0e102a8 --- /dev/null +++ b/pkg/rabbitmq_amqp/amqp_environment_test.go @@ -0,0 +1,57 @@ +package rabbitmq_amqp + +import ( + "context" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("AMQP Environment Test", func() { + It("AMQP Environment connection should succeed", func() { + env := NewEnvironment([]string{"amqp://"}, nil) + Expect(env).NotTo(BeNil()) + Expect(env.Connections()).NotTo(BeNil()) + Expect(len(env.Connections())).To(Equal(0)) + + connection, err := env.NewConnection(context.Background()) + Expect(err).To(BeNil()) + Expect(connection).NotTo(BeNil()) + Expect(len(env.Connections())).To(Equal(1)) + Expect(connection.Close(context.Background())).To(BeNil()) + Expect(len(env.Connections())).To(Equal(0)) + }) + + It("AMQP Environment CloseConnections should remove all the elements form the list", func() { + env := NewEnvironment([]string{"amqp://"}, nil) + Expect(env).NotTo(BeNil()) + Expect(env.Connections()).NotTo(BeNil()) + Expect(len(env.Connections())).To(Equal(0)) + + connection, err := env.NewConnection(context.Background()) + Expect(err).To(BeNil()) + Expect(connection).NotTo(BeNil()) + Expect(len(env.Connections())).To(Equal(1)) + + Expect(env.CloseConnections(context.Background())).To(BeNil()) + Expect(len(env.Connections())).To(Equal(0)) + }) + + It("AMQP Environment connection ID should be unique", func() { + env := NewEnvironment([]string{"amqp://"}, nil) + Expect(env).NotTo(BeNil()) + Expect(env.Connections()).NotTo(BeNil()) + Expect(len(env.Connections())).To(Equal(0)) + connection, err := env.NewConnection(context.Background(), "myConnectionId") + Expect(err).To(BeNil()) + Expect(connection).NotTo(BeNil()) + Expect(len(env.Connections())).To(Equal(1)) + connectionShouldBeNil, err := env.NewConnection(context.Background(), "myConnectionId") + Expect(err).NotTo(BeNil()) + Expect(err.Error()).To(ContainSubstring("connection with id myConnectionId already exists")) + Expect(connectionShouldBeNil).To(BeNil()) + Expect(len(env.Connections())).To(Equal(1)) + Expect(connection.Close(context.Background())).To(BeNil()) + Expect(len(env.Connections())).To(Equal(0)) + + }) +}) diff --git a/rabbitmq_amqp/amqp_exchange.go b/pkg/rabbitmq_amqp/amqp_exchange.go similarity index 100% rename from rabbitmq_amqp/amqp_exchange.go rename to pkg/rabbitmq_amqp/amqp_exchange.go diff --git a/rabbitmq_amqp/amqp_exchange_test.go b/pkg/rabbitmq_amqp/amqp_exchange_test.go similarity index 100% rename from rabbitmq_amqp/amqp_exchange_test.go rename to pkg/rabbitmq_amqp/amqp_exchange_test.go diff --git a/rabbitmq_amqp/amqp_management.go b/pkg/rabbitmq_amqp/amqp_management.go similarity index 87% rename from rabbitmq_amqp/amqp_management.go rename to pkg/rabbitmq_amqp/amqp_management.go index 8696fec..39aa42d 100644 --- a/rabbitmq_amqp/amqp_management.go +++ b/pkg/rabbitmq_amqp/amqp_management.go @@ -13,6 +13,10 @@ import ( var ErrPreconditionFailed = errors.New("precondition Failed") var ErrDoesNotExist = errors.New("does not exist") +/* +AmqpManagement is the interface to the RabbitMQ /management endpoint +The management interface is used to declare/delete exchanges, queues, and bindings +*/ type AmqpManagement struct { session *amqp.Session sender *amqp.Sender @@ -28,34 +32,28 @@ func NewAmqpManagement() *AmqpManagement { } func (a *AmqpManagement) ensureReceiverLink(ctx context.Context) error { - if a.receiver == nil { - opts := createReceiverLinkOptions(managementNodeAddress, linkPairName, AtMostOnce) - receiver, err := a.session.NewReceiver(ctx, managementNodeAddress, opts) - if err != nil { - return err - } - a.receiver = receiver - return nil + opts := createReceiverLinkOptions(managementNodeAddress, linkPairName, AtMostOnce) + receiver, err := a.session.NewReceiver(ctx, managementNodeAddress, opts) + if err != nil { + return err } + a.receiver = receiver return nil } func (a *AmqpManagement) ensureSenderLink(ctx context.Context) error { - if a.sender == nil { - sender, err := a.session.NewSender(ctx, managementNodeAddress, - createSenderLinkOptions(managementNodeAddress, linkPairName, AtMostOnce)) - if err != nil { - return err - } - - a.sender = sender - return nil + sender, err := a.session.NewSender(ctx, managementNodeAddress, + createSenderLinkOptions(managementNodeAddress, linkPairName, AtMostOnce)) + if err != nil { + return err } + + a.sender = sender return nil } func (a *AmqpManagement) Open(ctx context.Context, connection *AmqpConnection) error { - session, err := connection.Connection.NewSession(ctx, nil) + session, err := connection.azureConnection.NewSession(ctx, nil) if err != nil { return err } @@ -89,6 +87,11 @@ func (a *AmqpManagement) Close(ctx context.Context) error { return err } +/* +Request sends a request to the /management endpoint. +It is a generic method that can be used to send any request to the management endpoint. +In most of the cases you don't need to use this method directly, instead use the standard methods +*/ func (a *AmqpManagement) Request(ctx context.Context, body any, path string, method string, expectedResponseCodes []int) (map[string]any, error) { return a.request(ctx, uuid.New().String(), body, path, method, expectedResponseCodes) diff --git a/rabbitmq_amqp/amqp_management_test.go b/pkg/rabbitmq_amqp/amqp_management_test.go similarity index 92% rename from rabbitmq_amqp/amqp_management_test.go rename to pkg/rabbitmq_amqp/amqp_management_test.go index a00c430..b05ccde 100644 --- a/rabbitmq_amqp/amqp_management_test.go +++ b/pkg/rabbitmq_amqp/amqp_management_test.go @@ -22,8 +22,13 @@ var _ = Describe("Management tests", func() { }) It("AMQP Management should receive events", func() { - ch := make(chan *StateChanged, 1) - connection, err := Dial(context.Background(), []string{"amqp://"}, nil) + ch := make(chan *StateChanged, 2) + connection, err := Dial(context.Background(), []string{"amqp://"}, &AmqpConnOptions{ + SASLType: amqp.SASLTypeAnonymous(), + RecoveryConfiguration: &RecoveryConfiguration{ + ActiveRecovery: false, + }, + }) Expect(err).To(BeNil()) connection.NotifyStatusChange(ch) err = connection.Close(context.Background()) diff --git a/rabbitmq_amqp/amqp_publisher.go b/pkg/rabbitmq_amqp/amqp_publisher.go similarity index 65% rename from rabbitmq_amqp/amqp_publisher.go rename to pkg/rabbitmq_amqp/amqp_publisher.go index 86c7e96..de4b38c 100644 --- a/rabbitmq_amqp/amqp_publisher.go +++ b/pkg/rabbitmq_amqp/amqp_publisher.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "github.com/Azure/go-amqp" + "github.com/google/uuid" + "sync/atomic" ) type PublishResult struct { @@ -13,12 +15,39 @@ type PublishResult struct { // Publisher is a publisher that sends messages to a specific destination address. type Publisher struct { - sender *amqp.Sender - staticTargetAddress bool + sender atomic.Pointer[amqp.Sender] + connection *AmqpConnection + linkName string + destinationAdd string + id string } -func newPublisher(sender *amqp.Sender, staticTargetAddress bool) *Publisher { - return &Publisher{sender: sender, staticTargetAddress: staticTargetAddress} +func (m *Publisher) Id() string { + return m.id +} + +func newPublisher(ctx context.Context, connection *AmqpConnection, destinationAdd string, linkName string, args ...string) (*Publisher, error) { + id := fmt.Sprintf("publisher-%s", uuid.New().String()) + if len(args) > 0 { + id = args[0] + } + + r := &Publisher{connection: connection, linkName: linkName, destinationAdd: destinationAdd, id: id} + connection.entitiesTracker.storeOrReplaceProducer(r) + err := r.createSender(ctx) + if err != nil { + return nil, err + } + return r, nil +} + +func (m *Publisher) createSender(ctx context.Context) error { + sender, err := m.connection.session.NewSender(ctx, m.destinationAdd, createSenderLinkOptions(m.destinationAdd, m.linkName, AtLeastOnce)) + if err != nil { + return err + } + m.sender.Swap(sender) + return nil } /* @@ -58,7 +87,7 @@ Create a new publisher that sends messages based on message destination address: */ func (m *Publisher) Publish(ctx context.Context, message *amqp.Message) (*PublishResult, error) { - if !m.staticTargetAddress { + if m.destinationAdd == "" { if message.Properties == nil || message.Properties.To == nil { return nil, fmt.Errorf("message properties TO is required to send a message to a dynamic target address") } @@ -68,7 +97,7 @@ func (m *Publisher) Publish(ctx context.Context, message *amqp.Message) (*Publis return nil, err } } - r, err := m.sender.SendWithReceipt(ctx, message, nil) + r, err := m.sender.Load().SendWithReceipt(ctx, message, nil) if err != nil { return nil, err } @@ -76,14 +105,14 @@ func (m *Publisher) Publish(ctx context.Context, message *amqp.Message) (*Publis if err != nil { return nil, err } - publishResult := &PublishResult{ + return &PublishResult{ Message: message, Outcome: state, - } - return publishResult, err + }, err } // Close closes the publisher. func (m *Publisher) Close(ctx context.Context) error { - return m.sender.Close(ctx) + m.connection.entitiesTracker.removeProducer(m) + return m.sender.Load().Close(ctx) } diff --git a/rabbitmq_amqp/amqp_publisher_test.go b/pkg/rabbitmq_amqp/amqp_publisher_test.go similarity index 93% rename from rabbitmq_amqp/amqp_publisher_test.go rename to pkg/rabbitmq_amqp/amqp_publisher_test.go index fc02455..d45b6d5 100644 --- a/rabbitmq_amqp/amqp_publisher_test.go +++ b/pkg/rabbitmq_amqp/amqp_publisher_test.go @@ -95,14 +95,14 @@ var _ = Describe("AMQP publisher ", func() { Expect(connection.Close(context.Background())) }) - It("Multi Targets Publisher should fail with StateReleased when the destination does not exist", func() { + It("Multi Targets NewPublisher should fail with StateReleased when the destination does not exist", func() { connection, err := Dial(context.Background(), []string{"amqp://"}, nil) Expect(err).To(BeNil()) Expect(connection).NotTo(BeNil()) publisher, err := connection.NewPublisher(context.Background(), nil, "test") Expect(err).To(BeNil()) Expect(publisher).NotTo(BeNil()) - qName := generateNameWithDateTime("Targets Publisher should fail when the destination does not exist") + qName := generateNameWithDateTime("Targets NewPublisher should fail when the destination does not exist") msg := amqp.NewMessage([]byte("hello")) Expect(MessageToAddressHelper(msg, &QueueAddress{Queue: qName})).To(BeNil()) @@ -113,7 +113,7 @@ var _ = Describe("AMQP publisher ", func() { Expect(connection.Close(context.Background())).To(BeNil()) }) - It("Multi Targets Publisher should success with StateReceived when the destination exists", func() { + It("Multi Targets NewPublisher should success with StateReceived when the destination exists", func() { connection, err := Dial(context.Background(), []string{"amqp://"}, nil) Expect(err).To(BeNil()) Expect(connection).NotTo(BeNil()) @@ -121,7 +121,7 @@ var _ = Describe("AMQP publisher ", func() { publisher, err := connection.NewPublisher(context.Background(), nil, "test") Expect(err).To(BeNil()) Expect(publisher).NotTo(BeNil()) - name := generateNameWithDateTime("Targets Publisher should success with StateReceived when the destination exists") + name := generateNameWithDateTime("Targets NewPublisher should success with StateReceived when the destination exists") _, err = connection.Management().DeclareQueue(context.Background(), &QuorumQueueSpecification{ Name: name, }) @@ -167,7 +167,7 @@ var _ = Describe("AMQP publisher ", func() { Expect(connection.Close(context.Background())).To(BeNil()) }) - It("Multi Targets Publisher should fail it TO is not set or not valid", func() { + It("Multi Targets NewPublisher should fail it TO is not set or not valid", func() { connection, err := Dial(context.Background(), []string{"amqp://"}, nil) Expect(err).To(BeNil()) Expect(connection).NotTo(BeNil()) diff --git a/rabbitmq_amqp/amqp_queue.go b/pkg/rabbitmq_amqp/amqp_queue.go similarity index 100% rename from rabbitmq_amqp/amqp_queue.go rename to pkg/rabbitmq_amqp/amqp_queue.go diff --git a/rabbitmq_amqp/amqp_queue_test.go b/pkg/rabbitmq_amqp/amqp_queue_test.go similarity index 100% rename from rabbitmq_amqp/amqp_queue_test.go rename to pkg/rabbitmq_amqp/amqp_queue_test.go diff --git a/rabbitmq_amqp/amqp_utils.go b/pkg/rabbitmq_amqp/amqp_utils.go similarity index 100% rename from rabbitmq_amqp/amqp_utils.go rename to pkg/rabbitmq_amqp/amqp_utils.go diff --git a/rabbitmq_amqp/common.go b/pkg/rabbitmq_amqp/common.go similarity index 100% rename from rabbitmq_amqp/common.go rename to pkg/rabbitmq_amqp/common.go diff --git a/rabbitmq_amqp/converters.go b/pkg/rabbitmq_amqp/converters.go similarity index 100% rename from rabbitmq_amqp/converters.go rename to pkg/rabbitmq_amqp/converters.go diff --git a/rabbitmq_amqp/converters_test.go b/pkg/rabbitmq_amqp/converters_test.go similarity index 100% rename from rabbitmq_amqp/converters_test.go rename to pkg/rabbitmq_amqp/converters_test.go diff --git a/rabbitmq_amqp/entities.go b/pkg/rabbitmq_amqp/entities.go similarity index 94% rename from rabbitmq_amqp/entities.go rename to pkg/rabbitmq_amqp/entities.go index 197d0b2..b7dfde8 100644 --- a/rabbitmq_amqp/entities.go +++ b/pkg/rabbitmq_amqp/entities.go @@ -1,5 +1,9 @@ package rabbitmq_amqp +type entityIdentifier interface { + Id() string +} + type TQueueType string const ( @@ -16,6 +20,9 @@ func (e QueueType) String() string { return string(e.Type) } +/* +QueueSpecification represents the specification of a queue +*/ type QueueSpecification interface { name() string isAutoDelete() bool @@ -24,8 +31,6 @@ type QueueSpecification interface { buildArguments() map[string]any } -// QuorumQueueSpecification represents the specification of the quorum queue - type OverflowStrategy interface { overflowStrategy() string } @@ -69,6 +74,10 @@ func (r *ClientLocalLeaderLocator) leaderLocator() string { return "client-local" } +/* +QuorumQueueSpecification represents the specification of the quorum queue +*/ + type QuorumQueueSpecification struct { Name string AutoExpire int64 @@ -150,7 +159,9 @@ func (q *QuorumQueueSpecification) buildArguments() map[string]any { return result } -// ClassicQueueSpecification represents the specification of the classic queue +/* +ClassicQueueSpecification represents the specification of the classic queue +*/ type ClassicQueueSpecification struct { Name string IsAutoDelete bool @@ -231,6 +242,11 @@ func (q *ClassicQueueSpecification) buildArguments() map[string]any { return result } +/* +AutoGeneratedQueueSpecification represents the specification of the auto-generated queue. +It is a classic queue with auto-generated name. +It is useful in context like RPC or when you need a temporary queue. +*/ type AutoGeneratedQueueSpecification struct { IsAutoDelete bool IsExclusive bool diff --git a/rabbitmq_amqp/life_cycle.go b/pkg/rabbitmq_amqp/life_cycle.go similarity index 84% rename from rabbitmq_amqp/life_cycle.go rename to pkg/rabbitmq_amqp/life_cycle.go index 6de8c51..3c47e9d 100644 --- a/rabbitmq_amqp/life_cycle.go +++ b/pkg/rabbitmq_amqp/life_cycle.go @@ -31,6 +31,11 @@ func (c *StateClosing) getState() int { } type StateClosed struct { + error error +} + +func (c *StateClosed) GetError() error { + return c.error } func (c *StateClosed) getState() int { @@ -65,7 +70,18 @@ type StateChanged struct { } func (s StateChanged) String() string { + switch s.From.(type) { + case *StateClosed: + + } + + switch s.To.(type) { + case *StateClosed: + return fmt.Sprintf("From: %s, To: %s, Error: %s", statusToString(s.From), statusToString(s.To), s.To.(*StateClosed).error) + + } return fmt.Sprintf("From: %s, To: %s", statusToString(s.From), statusToString(s.To)) + } type LifeCycle struct { @@ -100,6 +116,7 @@ func (l *LifeCycle) SetState(value LifeCycleState) { if l.chStatusChanged == nil { return } + l.chStatusChanged <- &StateChanged{ From: oldState, To: value, diff --git a/rabbitmq_amqp/log.go b/pkg/rabbitmq_amqp/log.go similarity index 100% rename from rabbitmq_amqp/log.go rename to pkg/rabbitmq_amqp/log.go diff --git a/rabbitmq_amqp/pkg_suite_test.go b/pkg/rabbitmq_amqp/pkg_suite_test.go similarity index 100% rename from rabbitmq_amqp/pkg_suite_test.go rename to pkg/rabbitmq_amqp/pkg_suite_test.go diff --git a/rabbitmq_amqp/test_utils.go b/pkg/rabbitmq_amqp/test_utils.go similarity index 100% rename from rabbitmq_amqp/test_utils.go rename to pkg/rabbitmq_amqp/test_utils.go diff --git a/rabbitmq_amqp/uri.go b/pkg/rabbitmq_amqp/uri.go similarity index 100% rename from rabbitmq_amqp/uri.go rename to pkg/rabbitmq_amqp/uri.go diff --git a/rabbitmq_amqp/uri_test.go b/pkg/rabbitmq_amqp/uri_test.go similarity index 100% rename from rabbitmq_amqp/uri_test.go rename to pkg/rabbitmq_amqp/uri_test.go diff --git a/pkg/test-helper/http_utils.go b/pkg/test-helper/http_utils.go new file mode 100644 index 0000000..2d8b5aa --- /dev/null +++ b/pkg/test-helper/http_utils.go @@ -0,0 +1,115 @@ +package test_helper + +import ( + "encoding/json" + "errors" + "io" + "net/http" + "strconv" +) + +type Connection struct { + Name string `json:"name"` + ContainerId string `json:"container_id"` +} + +func Connections() ([]Connection, error) { + bodyString, err := httpGet("http://localhost:15672/api/connections/", "guest", "guest") + if err != nil { + return nil, err + } + + var data []Connection + err = json.Unmarshal([]byte(bodyString), &data) + if err != nil { + return nil, err + } + return data, nil +} + +func GetConnectionByContainerID(Id string) (*Connection, error) { + connections, err := Connections() + if err != nil { + return nil, err + } + for _, conn := range connections { + if conn.ContainerId == Id { + return &conn, nil + } + } + + return nil, errors.New("connection not found") +} + +func DropConnectionContainerID(Id string) error { + connections, err := Connections() + if err != nil { + return err + } + connectionToDrop := "" + for _, conn := range connections { + if conn.ContainerId == Id { + connectionToDrop = conn.Name + break + } + } + + if connectionToDrop == "" { + return errors.New("connection not found") + } + + err = DropConnection(connectionToDrop, "15672") + if err != nil { + return err + } + + return nil +} + +func DropConnection(name string, port string) error { + _, err := httpDelete("http://localhost:"+port+"/api/connections/"+name, "guest", "guest") + if err != nil { + return err + } + + return nil +} +func httpGet(url, username, password string) (string, error) { + return baseCall(url, username, password, "GET") +} + +func httpDelete(url, username, password string) (string, error) { + return baseCall(url, username, password, "DELETE") +} + +func baseCall(url, username, password string, method string) (string, error) { + var client http.Client + req, err := http.NewRequest(method, url, nil) + if err != nil { + return "", err + } + req.SetBasicAuth(username, password) + + resp, err3 := client.Do(req) + + if err3 != nil { + return "", err3 + } + + defer resp.Body.Close() + + if resp.StatusCode == 200 { // OK + bodyBytes, err2 := io.ReadAll(resp.Body) + if err2 != nil { + return "", err2 + } + return string(bodyBytes), nil + } + + if resp.StatusCode == 204 { // No Content + return "", nil + } + + return "", errors.New(strconv.Itoa(resp.StatusCode)) + +} diff --git a/rabbitmq_amqp/amqp_connection.go b/rabbitmq_amqp/amqp_connection.go deleted file mode 100644 index a7e0ff3..0000000 --- a/rabbitmq_amqp/amqp_connection.go +++ /dev/null @@ -1,168 +0,0 @@ -package rabbitmq_amqp - -import ( - "context" - "fmt" - "github.com/Azure/go-amqp" -) - -//func (c *ConnUrlHelper) UseSsl(value bool) { -// c.UseSsl = value -// if value { -// c.Scheme = "amqps" -// } else { -// c.Scheme = "amqp" -// } -//} - -type AmqpConnection struct { - Connection *amqp.Conn - management *AmqpManagement - lifeCycle *LifeCycle - session *amqp.Session -} - -// NewPublisher creates a new Publisher that sends messages to the provided destination. -// The destination is a TargetAddress that can be a Queue or an Exchange with a routing key. -// See QueueAddress and ExchangeAddress for more information. -func (a *AmqpConnection) NewPublisher(ctx context.Context, destination TargetAddress, linkName string) (*Publisher, error) { - destinationAdd := "" - err := error(nil) - if destination != nil { - destinationAdd, err = destination.toAddress() - if err != nil { - return nil, err - } - err = validateAddress(destinationAdd) - if err != nil { - return nil, err - } - } - - sender, err := a.session.NewSender(ctx, destinationAdd, createSenderLinkOptions(destinationAdd, linkName, AtLeastOnce)) - if err != nil { - return nil, err - } - return newPublisher(sender, destinationAdd != ""), nil -} - -// NewConsumer creates a new Consumer that listens to the provided destination. Destination is a QueueAddress. -func (a *AmqpConnection) NewConsumer(ctx context.Context, destination *QueueAddress, linkName string) (*Consumer, error) { - destinationAdd, err := destination.toAddress() - if err != nil { - return nil, err - } - err = validateAddress(destinationAdd) - - if err != nil { - return nil, err - } - receiver, err := a.session.NewReceiver(ctx, destinationAdd, createReceiverLinkOptions(destinationAdd, linkName, AtLeastOnce)) - if err != nil { - return nil, err - } - return newConsumer(receiver), nil -} - -// Dial connect to the AMQP 1.0 server using the provided connectionSettings -// Returns a pointer to the new AmqpConnection if successful else an error. -// addresses is a list of addresses to connect to. It picks one randomly. -// It is enough that one of the addresses is reachable. -func Dial(ctx context.Context, addresses []string, connOptions *amqp.ConnOptions) (*AmqpConnection, error) { - conn := &AmqpConnection{ - management: NewAmqpManagement(), - lifeCycle: NewLifeCycle(), - } - tmp := make([]string, len(addresses)) - copy(tmp, addresses) - - // random pick and extract one address to use for connection - for len(tmp) > 0 { - idx := random(len(tmp)) - addr := tmp[idx] - // remove the index from the tmp list - tmp = append(tmp[:idx], tmp[idx+1:]...) - err := conn.open(ctx, addr, connOptions) - if err != nil { - Error("Failed to open connection", ExtractWithoutPassword(addr), err) - continue - } - Debug("Connected to", ExtractWithoutPassword(addr)) - return conn, nil - } - return nil, fmt.Errorf("no address to connect to") -} - -// Open opens a connection to the AMQP 1.0 server. -// using the provided connectionSettings and the AMQPLite library. -// Setups the connection and the management interface. -func (a *AmqpConnection) open(ctx context.Context, addr string, connOptions *amqp.ConnOptions) error { - - if connOptions == nil { - connOptions = &amqp.ConnOptions{ - // RabbitMQ requires SASL security layer - // to be enabled for AMQP 1.0 connections. - // So this is mandatory and default in case not defined. - SASLType: amqp.SASLTypeAnonymous(), - } - } - - //connOptions.HostName is the way to set the virtual host - // so we need to pre-parse the URI to get the virtual host - // the PARSE is copied from go-amqp091 library - // the URI will be parsed is parsed again in the amqp lite library - uri, err := ParseURI(addr) - if err != nil { - return err - } - connOptions.HostName = fmt.Sprintf("vhost:%s", uri.Vhost) - - conn, err := amqp.Dial(ctx, addr, connOptions) - if err != nil { - return err - } - a.Connection = conn - a.session, err = a.Connection.NewSession(ctx, nil) - - if err != nil { - return err - } - err = a.management.Open(ctx, a) - if err != nil { - // TODO close connection? - return err - } - - a.lifeCycle.SetState(&StateOpen{}) - return nil -} - -func (a *AmqpConnection) Close(ctx context.Context) error { - err := a.management.Close(ctx) - if err != nil { - return err - } - err = a.Connection.Close() - a.lifeCycle.SetState(&StateClosed{}) - return err -} - -// NotifyStatusChange registers a channel to receive getState change notifications -// from the connection. -func (a *AmqpConnection) NotifyStatusChange(channel chan *StateChanged) { - a.lifeCycle.chStatusChanged = channel -} - -func (a *AmqpConnection) State() LifeCycleState { - return a.lifeCycle.State() -} - -// *** management section *** - -// Management returns the management interface for the connection. -// The management interface is used to declare and delete exchanges, queues, and bindings. -func (a *AmqpConnection) Management() *AmqpManagement { - return a.management -} - -//*** end management section ***