diff --git a/pkg/rabbitmqamqp/amqp_connection_test.go b/pkg/rabbitmqamqp/amqp_connection_test.go index 73bce49..b1274f1 100644 --- a/pkg/rabbitmqamqp/amqp_connection_test.go +++ b/pkg/rabbitmqamqp/amqp_connection_test.go @@ -5,12 +5,12 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "os" + "time" + "github.com/Azure/go-amqp" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - "os" - "sync" - "time" ) var _ = Describe("AMQP connection Test", func() { @@ -117,12 +117,10 @@ var _ = Describe("AMQP connection Test", func() { }) Describe("AMQP TLS connection should succeed with in different vhosts with Anonymous and External.", func() { - wg := &sync.WaitGroup{} - wg.Add(4) DescribeTable("TLS connection should success in different vhosts ", func(virtualHost string, sasl amqp.SASLType) { // Load CA cert caCert, err := os.ReadFile("../../.ci/certs/ca_certificate.pem") - Expect(err).To(BeNil()) + Expect(err).ToNot(HaveOccurred()) // Create a CA certificate pool and add the CA certificate to it caCertPool := x509.NewCertPool() @@ -131,7 +129,7 @@ var _ = Describe("AMQP connection Test", func() { // Load client cert clientCert, err := tls.LoadX509KeyPair("../../.ci/certs/client_localhost_certificate.pem", "../../.ci/certs/client_localhost_key.pem") - Expect(err).To(BeNil()) + Expect(err).ToNot(HaveOccurred()) // Create a TLS configuration tlsConfig := &tls.Config{ @@ -146,34 +144,32 @@ var _ = Describe("AMQP connection Test", func() { SASLType: sasl, TLSConfig: tlsConfig, }) - Expect(err).To(BeNil()) + Expect(err).ToNot(HaveOccurred()) Expect(connection).NotTo(BeNil()) // Close the connection err = connection.Close(context.Background()) - Expect(err).To(BeNil()) - wg.Done() + Expect(err).ToNot(HaveOccurred()) }, - Entry("with virtual host. External", "%2F", amqp.SASLTypeExternal("")), - Entry("with a not default virtual host. External", "tls", amqp.SASLTypeExternal("")), - Entry("with virtual host. Anonymous", "%2F", amqp.SASLTypeAnonymous()), - Entry("with a not default virtual host. Anonymous", "tls", amqp.SASLTypeAnonymous()), + Entry("default virtual host + External", "%2F", amqp.SASLTypeExternal("")), + Entry("non-default virtual host + External", "tls", amqp.SASLTypeExternal("")), + Entry("default virtual host + Anonymous", "%2F", amqp.SASLTypeAnonymous()), + Entry("non-default virtual host + Anonymous", "tls", amqp.SASLTypeAnonymous()), ) - go func() { - wg.Wait() - }() }) - Describe("AMQP TLS connection should fail with error.", func() { - tlsConfig := &tls.Config{} + Describe("AMQP TLS connection", func() { + It("should fail with error", func() { + tlsConfig := &tls.Config{} - // Dial the AMQP server with TLS configuration - connection, err := Dial(context.Background(), "amqps://does_not_exist:5671", &AmqpConnOptions{ - TLSConfig: tlsConfig, + // Dial the AMQP server with TLS configuration + connection, err := Dial(context.Background(), "amqps://does_not_exist:5671", &AmqpConnOptions{ + TLSConfig: tlsConfig, + }) + Expect(connection).To(BeNil()) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("failed to open TLS connection")) }) - Expect(connection).To(BeNil()) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("failed to open TLS connection")) }) }) diff --git a/pkg/rabbitmqamqp/amqp_consumer_stream_test.go b/pkg/rabbitmqamqp/amqp_consumer_stream_test.go index b846242..3984bba 100644 --- a/pkg/rabbitmqamqp/amqp_consumer_stream_test.go +++ b/pkg/rabbitmqamqp/amqp_consumer_stream_test.go @@ -3,7 +3,6 @@ package rabbitmqamqp import ( "context" "fmt" - "sync" "time" "github.com/Azure/go-amqp" @@ -320,33 +319,43 @@ var _ = Describe("Consumer stream test", func() { }) Describe("consumer should filter messages based on application properties", func() { - qName := generateName("consumer should filter messages based on application properties") - connection, err := Dial(context.Background(), "amqp://", nil) - Expect(err).To(BeNil()) - queueInfo, err := connection.Management().DeclareQueue(context.Background(), &StreamQueueSpecification{ - Name: qName, - }) - Expect(err).To(BeNil()) - Expect(queueInfo).NotTo(BeNil()) + var ( + qName string + connection *AmqpConnection + ) + BeforeEach(func() { + qName = generateName("consumer should filter messages based on application properties") + var err error + connection, err = Dial(context.Background(), "amqp://", nil) + Expect(err).ToNot(HaveOccurred()) + queueInfo, err := connection.Management().DeclareQueue(context.Background(), &StreamQueueSpecification{ + Name: qName, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(queueInfo).NotTo(BeNil()) - publishMessagesWithMessageLogic(qName, "ignoredKey", 7, func(msg *amqp.Message) { - msg.ApplicationProperties = map[string]interface{}{"ignoredKey": "ignoredValue"} - }) + publishMessagesWithMessageLogic(qName, "ignoredKey", 7, func(msg *amqp.Message) { + msg.ApplicationProperties = map[string]interface{}{"ignoredKey": "ignoredValue"} + }) - publishMessagesWithMessageLogic(qName, "key1", 10, func(msg *amqp.Message) { - msg.ApplicationProperties = map[string]interface{}{"key1": "value1", "constFilterKey": "constFilterValue"} - }) + publishMessagesWithMessageLogic(qName, "key1", 10, func(msg *amqp.Message) { + msg.ApplicationProperties = map[string]interface{}{"key1": "value1", "constFilterKey": "constFilterValue"} + }) + + publishMessagesWithMessageLogic(qName, "key2", 10, func(msg *amqp.Message) { + msg.ApplicationProperties = map[string]interface{}{"key2": "value2", "constFilterKey": "constFilterValue"} + }) - publishMessagesWithMessageLogic(qName, "key2", 10, func(msg *amqp.Message) { - msg.ApplicationProperties = map[string]interface{}{"key2": "value2", "constFilterKey": "constFilterValue"} + publishMessagesWithMessageLogic(qName, "key3", 10, func(msg *amqp.Message) { + msg.ApplicationProperties = map[string]interface{}{"key3": "value3", "constFilterKey": "constFilterValue"} + }) }) - publishMessagesWithMessageLogic(qName, "key3", 10, func(msg *amqp.Message) { - msg.ApplicationProperties = map[string]interface{}{"key3": "value3", "constFilterKey": "constFilterValue"} + AfterEach(func() { + Expect(connection.Management().DeleteQueue(context.Background(), qName)).To(Succeed()) + Expect(connection.Close(context.Background())).To(Succeed()) }) - var wg sync.WaitGroup - wg.Add(3) DescribeTable("consumer should filter messages based on application properties", func(key string, value any, label string) { consumer, err := connection.NewConsumer(context.Background(), qName, &StreamConsumerOptions{ @@ -375,93 +384,96 @@ var _ = Describe("Consumer stream test", func() { Expect(dc.Accept(context.Background())).To(BeNil()) } Expect(consumer.Close(context.Background())).To(BeNil()) - wg.Done() }, Entry("key1 value1", "key1", "value1", "key1"), Entry("key2 value2", "key2", "value2", "key2"), Entry("key3 value3", "key3", "value3", "key3"), ) - go func() { - wg.Wait() - Expect(connection.Management().DeleteQueue(context.Background(), qName)).To(BeNil()) - Expect(connection.Close(context.Background())).To(BeNil()) - }() - }) Describe("consumer should filter messages based on properties", func() { /* Test the consumer should filter messages based on properties */ - // TODO: defer cleanup to delete the stream queue - qName := generateName("consumer should filter messages based on properties") - qName += time.Now().String() - connection, err := Dial(context.Background(), "amqp://", nil) - Expect(err).To(BeNil()) - queueInfo, err := connection.Management().DeclareQueue(context.Background(), &StreamQueueSpecification{ - Name: qName, - }) - Expect(err).To(BeNil()) - Expect(queueInfo).NotTo(BeNil()) + var ( + qName string + connection *AmqpConnection + ) - publishMessagesWithMessageLogic(qName, "MessageID", 10, func(msg *amqp.Message) { - msg.Properties = &amqp.MessageProperties{MessageID: "MessageID"} - }) + BeforeEach(func() { + qName = generateName("consumer should filter messages based on properties") + qName += time.Now().String() + var err error + connection, err = Dial(context.Background(), "amqp://", nil) + Expect(err).ToNot(HaveOccurred()) + queueInfo, err := connection.Management().DeclareQueue(context.Background(), &StreamQueueSpecification{ + Name: qName, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(queueInfo).NotTo(BeNil()) - publishMessagesWithMessageLogic(qName, "Subject", 10, func(msg *amqp.Message) { - msg.Properties = &amqp.MessageProperties{Subject: ptr("Subject")} - }) + publishMessagesWithMessageLogic(qName, "MessageID", 10, func(msg *amqp.Message) { + msg.Properties = &amqp.MessageProperties{MessageID: "MessageID"} + }) - publishMessagesWithMessageLogic(qName, "ReplyTo", 10, func(msg *amqp.Message) { - msg.Properties = &amqp.MessageProperties{ReplyTo: ptr("ReplyTo")} - }) + publishMessagesWithMessageLogic(qName, "Subject", 10, func(msg *amqp.Message) { + msg.Properties = &amqp.MessageProperties{Subject: ptr("Subject")} + }) - publishMessagesWithMessageLogic(qName, "ContentType", 10, func(msg *amqp.Message) { - msg.Properties = &amqp.MessageProperties{ContentType: ptr("ContentType")} - }) + publishMessagesWithMessageLogic(qName, "ReplyTo", 10, func(msg *amqp.Message) { + msg.Properties = &amqp.MessageProperties{ReplyTo: ptr("ReplyTo")} + }) - publishMessagesWithMessageLogic(qName, "ContentEncoding", 10, func(msg *amqp.Message) { - msg.Properties = &amqp.MessageProperties{ContentEncoding: ptr("ContentEncoding")} - }) + publishMessagesWithMessageLogic(qName, "ContentType", 10, func(msg *amqp.Message) { + msg.Properties = &amqp.MessageProperties{ContentType: ptr("ContentType")} + }) - publishMessagesWithMessageLogic(qName, "GroupID", 10, func(msg *amqp.Message) { - msg.Properties = &amqp.MessageProperties{GroupID: ptr("GroupID")} - }) + publishMessagesWithMessageLogic(qName, "ContentEncoding", 10, func(msg *amqp.Message) { + msg.Properties = &amqp.MessageProperties{ContentEncoding: ptr("ContentEncoding")} + }) - publishMessagesWithMessageLogic(qName, "ReplyToGroupID", 10, func(msg *amqp.Message) { - msg.Properties = &amqp.MessageProperties{ReplyToGroupID: ptr("ReplyToGroupID")} - }) + publishMessagesWithMessageLogic(qName, "GroupID", 10, func(msg *amqp.Message) { + msg.Properties = &amqp.MessageProperties{GroupID: ptr("GroupID")} + }) - // GroupSequence - publishMessagesWithMessageLogic(qName, "GroupSequence", 10, func(msg *amqp.Message) { - msg.Properties = &amqp.MessageProperties{GroupSequence: ptr(uint32(137))} - }) + publishMessagesWithMessageLogic(qName, "ReplyToGroupID", 10, func(msg *amqp.Message) { + msg.Properties = &amqp.MessageProperties{ReplyToGroupID: ptr("ReplyToGroupID")} + }) - // ReplyToGroupID - publishMessagesWithMessageLogic(qName, "ReplyToGroupID", 10, func(msg *amqp.Message) { - msg.Properties = &amqp.MessageProperties{ReplyToGroupID: ptr("ReplyToGroupID")} - }) + // GroupSequence + publishMessagesWithMessageLogic(qName, "GroupSequence", 10, func(msg *amqp.Message) { + msg.Properties = &amqp.MessageProperties{GroupSequence: ptr(uint32(137))} + }) - // CreationTime + // ReplyToGroupID + publishMessagesWithMessageLogic(qName, "ReplyToGroupID", 10, func(msg *amqp.Message) { + msg.Properties = &amqp.MessageProperties{ReplyToGroupID: ptr("ReplyToGroupID")} + }) - publishMessagesWithMessageLogic(qName, "CreationTime", 10, func(msg *amqp.Message) { - msg.Properties = &amqp.MessageProperties{CreationTime: ptr(createDateTime())} - }) + // CreationTime - // AbsoluteExpiryTime + publishMessagesWithMessageLogic(qName, "CreationTime", 10, func(msg *amqp.Message) { + msg.Properties = &amqp.MessageProperties{CreationTime: ptr(createDateTime())} + }) - publishMessagesWithMessageLogic(qName, "AbsoluteExpiryTime", 10, func(msg *amqp.Message) { - msg.Properties = &amqp.MessageProperties{AbsoluteExpiryTime: ptr(createDateTime())} - }) + // AbsoluteExpiryTime - // CorrelationID + publishMessagesWithMessageLogic(qName, "AbsoluteExpiryTime", 10, func(msg *amqp.Message) { + msg.Properties = &amqp.MessageProperties{AbsoluteExpiryTime: ptr(createDateTime())} + }) + + // CorrelationID - publishMessagesWithMessageLogic(qName, "CorrelationID", 10, func(msg *amqp.Message) { - msg.Properties = &amqp.MessageProperties{CorrelationID: "CorrelationID"} + publishMessagesWithMessageLogic(qName, "CorrelationID", 10, func(msg *amqp.Message) { + msg.Properties = &amqp.MessageProperties{CorrelationID: "CorrelationID"} + }) + }) + + AfterEach(func() { + Expect(connection.Management().DeleteQueue(context.Background(), qName)).To(BeNil()) + Expect(connection.Close(context.Background())).To(BeNil()) }) - var wg sync.WaitGroup - wg.Add(12) DescribeTable("consumer should filter messages based on properties", func(properties *amqp.MessageProperties, label string) { consumer, err := connection.NewConsumer(context.Background(), qName, &StreamConsumerOptions{ @@ -533,7 +545,6 @@ var _ = Describe("Consumer stream test", func() { Expect(dc.Accept(context.Background())).To(BeNil()) } Expect(consumer.Close(context.Background())).To(BeNil()) - wg.Done() }, Entry("MessageID", &amqp.MessageProperties{MessageID: "MessageID"}, "MessageID"), Entry("Subject", &amqp.MessageProperties{Subject: ptr("Subject")}, "Subject"), @@ -548,11 +559,6 @@ var _ = Describe("Consumer stream test", func() { Entry("AbsoluteExpiryTime", &amqp.MessageProperties{AbsoluteExpiryTime: ptr(createDateTime())}, "AbsoluteExpiryTime"), Entry("CorrelationID", &amqp.MessageProperties{CorrelationID: "CorrelationID"}, "CorrelationID"), ) - go func() { - wg.Wait() - Expect(connection.Management().DeleteQueue(context.Background(), qName)).To(BeNil()) - Expect(connection.Close(context.Background())).To(BeNil()) - }() }) It("SQL filter consumer", func() { diff --git a/pkg/rabbitmqamqp/amqp_exchange.go b/pkg/rabbitmqamqp/amqp_exchange.go index 943b4cf..8ada9a6 100644 --- a/pkg/rabbitmqamqp/amqp_exchange.go +++ b/pkg/rabbitmqamqp/amqp_exchange.go @@ -3,6 +3,7 @@ package rabbitmqamqp import ( "context" "errors" + "github.com/Azure/go-amqp" ) @@ -23,14 +24,14 @@ type AmqpExchange struct { management *AmqpManagement arguments map[string]any isAutoDelete bool - exchangeType ExchangeType + exchangeType TExchangeType } func newAmqpExchange(management *AmqpManagement, name string) *AmqpExchange { return &AmqpExchange{management: management, name: name, arguments: make(map[string]any), - exchangeType: ExchangeType{Type: Direct}, + exchangeType: Direct, } } @@ -46,7 +47,7 @@ func (e *AmqpExchange) Declare(ctx context.Context) (*AmqpExchangeInfo, error) { kv := make(map[string]any) kv["auto_delete"] = e.isAutoDelete kv["durable"] = true - kv["type"] = e.exchangeType.String() + kv["type"] = string(e.exchangeType) if e.arguments != nil { kv["arguments"] = e.arguments } @@ -78,14 +79,14 @@ func (e *AmqpExchange) Delete(ctx context.Context) error { return err } -func (e *AmqpExchange) ExchangeType(exchangeType ExchangeType) { - if len(exchangeType.Type) > 0 { +func (e *AmqpExchange) ExchangeType(exchangeType TExchangeType) { + if len(exchangeType) > 0 { e.exchangeType = exchangeType } } func (e *AmqpExchange) GetExchangeType() TExchangeType { - return e.exchangeType.Type + return e.exchangeType } func (e *AmqpExchange) Name() string { diff --git a/pkg/rabbitmqamqp/amqp_queue.go b/pkg/rabbitmqamqp/amqp_queue.go index 38c554f..20a9ed3 100644 --- a/pkg/rabbitmqamqp/amqp_queue.go +++ b/pkg/rabbitmqamqp/amqp_queue.go @@ -92,9 +92,9 @@ func (a *AmqpQueue) Arguments(arguments map[string]any) { a.arguments = arguments } -func (a *AmqpQueue) QueueType(queueType QueueType) { - if len(queueType.String()) != 0 { - a.arguments["x-queue-type"] = queueType.String() +func (a *AmqpQueue) QueueType(queueType TQueueType) { + if len(string(queueType)) != 0 { + a.arguments["x-queue-type"] = string(queueType) } } diff --git a/pkg/rabbitmqamqp/entities.go b/pkg/rabbitmqamqp/entities.go index a3c06ab..020f1c9 100644 --- a/pkg/rabbitmqamqp/entities.go +++ b/pkg/rabbitmqamqp/entities.go @@ -12,14 +12,6 @@ const ( Stream TQueueType = "stream" ) -type QueueType struct { - Type TQueueType -} - -func (e QueueType) String() string { - return string(e.Type) -} - /* IQueueSpecification represents the specification of a queue */ @@ -27,7 +19,7 @@ type IQueueSpecification interface { name() string isAutoDelete() bool isExclusive() bool - queueType() QueueType + queueType() TQueueType buildArguments() map[string]any } @@ -107,8 +99,8 @@ func (q *QuorumQueueSpecification) isExclusive() bool { return false } -func (q *QuorumQueueSpecification) queueType() QueueType { - return QueueType{Type: Quorum} +func (q *QuorumQueueSpecification) queueType() TQueueType { + return Quorum } func (q *QuorumQueueSpecification) buildArguments() map[string]any { @@ -165,7 +157,7 @@ func (q *QuorumQueueSpecification) buildArguments() map[string]any { result["x-quorum-initial-group-size"] = q.QuorumInitialGroupSize } - result["x-queue-type"] = q.queueType().String() + result["x-queue-type"] = string(q.queueType()) return result } @@ -201,8 +193,8 @@ func (q *ClassicQueueSpecification) isExclusive() bool { return q.IsExclusive } -func (q *ClassicQueueSpecification) queueType() QueueType { - return QueueType{Type: Classic} +func (q *ClassicQueueSpecification) queueType() TQueueType { + return Classic } func (q *ClassicQueueSpecification) buildArguments() map[string]any { @@ -251,7 +243,7 @@ func (q *ClassicQueueSpecification) buildArguments() map[string]any { result["x-queue-leader-locator"] = q.LeaderLocator.leaderLocator() } - result["x-queue-type"] = q.queueType().String() + result["x-queue-type"] = string(q.queueType()) return result } @@ -281,8 +273,8 @@ func (a *AutoGeneratedQueueSpecification) isExclusive() bool { return a.IsExclusive } -func (a *AutoGeneratedQueueSpecification) queueType() QueueType { - return QueueType{Classic} +func (a *AutoGeneratedQueueSpecification) queueType() TQueueType { + return Classic } func (a *AutoGeneratedQueueSpecification) buildArguments() map[string]any { @@ -299,7 +291,7 @@ func (a *AutoGeneratedQueueSpecification) buildArguments() map[string]any { result["x-max-length"] = a.MaxLength } - result["x-queue-type"] = a.queueType().String() + result["x-queue-type"] = string(a.queueType()) return result } @@ -323,8 +315,8 @@ func (s *StreamQueueSpecification) isExclusive() bool { return false } -func (s *StreamQueueSpecification) queueType() QueueType { - return QueueType{Type: Stream} +func (s *StreamQueueSpecification) queueType() TQueueType { + return Stream } func (s *StreamQueueSpecification) buildArguments() map[string]any { @@ -341,7 +333,7 @@ func (s *StreamQueueSpecification) buildArguments() map[string]any { result["x-stream-initial-cluster-size"] = s.InitialClusterSize } - result["x-queue-type"] = s.queueType().String() + result["x-queue-type"] = string(s.queueType()) return result } @@ -358,19 +350,11 @@ const ( Headers TExchangeType = "headers" ) -type ExchangeType struct { - Type TExchangeType -} - -func (e ExchangeType) String() string { - return string(e.Type) -} - // IExchangeSpecification represents the specification of an exchange type IExchangeSpecification interface { name() string isAutoDelete() bool - exchangeType() ExchangeType + exchangeType() TExchangeType arguments() map[string]any } @@ -388,8 +372,8 @@ func (d *DirectExchangeSpecification) isAutoDelete() bool { return d.IsAutoDelete } -func (d *DirectExchangeSpecification) exchangeType() ExchangeType { - return ExchangeType{Type: Direct} +func (d *DirectExchangeSpecification) exchangeType() TExchangeType { + return Direct } func (d *DirectExchangeSpecification) arguments() map[string]any { @@ -410,8 +394,8 @@ func (t *TopicExchangeSpecification) isAutoDelete() bool { return t.IsAutoDelete } -func (t *TopicExchangeSpecification) exchangeType() ExchangeType { - return ExchangeType{Type: Topic} +func (t *TopicExchangeSpecification) exchangeType() TExchangeType { + return Topic } func (t *TopicExchangeSpecification) arguments() map[string]any { @@ -432,8 +416,8 @@ func (f *FanOutExchangeSpecification) isAutoDelete() bool { return f.IsAutoDelete } -func (f *FanOutExchangeSpecification) exchangeType() ExchangeType { - return ExchangeType{Type: FanOut} +func (f *FanOutExchangeSpecification) exchangeType() TExchangeType { + return FanOut } func (f *FanOutExchangeSpecification) arguments() map[string]any { @@ -454,8 +438,8 @@ func (h *HeadersExchangeSpecification) isAutoDelete() bool { return h.IsAutoDelete } -func (h *HeadersExchangeSpecification) exchangeType() ExchangeType { - return ExchangeType{Type: Headers} +func (h *HeadersExchangeSpecification) exchangeType() TExchangeType { + return Headers } func (h *HeadersExchangeSpecification) arguments() map[string]any { @@ -477,8 +461,8 @@ func (c *CustomExchangeSpecification) isAutoDelete() bool { return c.IsAutoDelete } -func (c *CustomExchangeSpecification) exchangeType() ExchangeType { - return ExchangeType{Type: TExchangeType(c.ExchangeTypeName)} +func (c *CustomExchangeSpecification) exchangeType() TExchangeType { + return TExchangeType(c.ExchangeTypeName) } func (c *CustomExchangeSpecification) arguments() map[string]any {