Skip to content

Commit fda81ed

Browse files
committed
feat(PubSub): Add SQS PubSub wrappers
1 parent 0afe635 commit fda81ed

File tree

3 files changed

+240
-0
lines changed

3 files changed

+240
-0
lines changed

pkg/pubsub/sqs/publisher.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package sqs
2+
3+
import (
4+
"context"
5+
6+
"github.com/aws/aws-sdk-go-v2/service/sqs"
7+
)
8+
9+
type (
10+
Publisher struct {
11+
client *sqs.Client
12+
queueURL string
13+
}
14+
)
15+
16+
func NewPublisher(sqsClient *sqs.Client, queueURL string) *Publisher {
17+
return &Publisher{
18+
client: sqsClient,
19+
queueURL: queueURL,
20+
}
21+
}
22+
23+
func (p *Publisher) Publish(ctx context.Context, msg *sqs.SendMessageInput) (*sqs.SendMessageOutput, error) {
24+
if msg.QueueUrl == nil {
25+
msg.QueueUrl = &p.queueURL
26+
}
27+
28+
return p.client.SendMessage(ctx, msg)
29+
}

pkg/pubsub/sqs/subscriber.go

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
package sqs
2+
3+
import (
4+
"context"
5+
"sync"
6+
"time"
7+
8+
"github.com/aws/aws-sdk-go-v2/aws"
9+
"github.com/aws/aws-sdk-go-v2/service/sqs"
10+
"github.com/aws/aws-sdk-go-v2/service/sqs/types"
11+
12+
"github.com/scribd/go-sdk/pkg/pubsub"
13+
"github.com/scribd/go-sdk/pkg/pubsub/pool"
14+
)
15+
16+
type (
17+
SQSClient interface {
18+
ReceiveMessage(
19+
ctx context.Context,
20+
params *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error)
21+
}
22+
23+
Subscriber struct {
24+
client SQSClient
25+
queueURL string
26+
handler MsgHandler
27+
maxMessages int
28+
pool *pool.Pool
29+
waitTime time.Duration
30+
wg sync.WaitGroup // tracks active handlers
31+
stopCh chan struct{}
32+
}
33+
34+
SubscriberConfig struct {
35+
SQSClient *sqs.Client
36+
MsgHandler MsgHandler
37+
SQSConfig pubsub.SQS
38+
}
39+
40+
MsgHandler func(msg types.Message)
41+
)
42+
43+
const (
44+
defaultNumWorkers = 1
45+
maxWaitTime = 20 * time.Second
46+
)
47+
48+
func NewSubscriber(c SubscriberConfig) *Subscriber {
49+
workers := c.SQSConfig.Subscriber.Workers
50+
if workers == 0 {
51+
workers = defaultNumWorkers
52+
}
53+
54+
waitTime := c.SQSConfig.Subscriber.WaitTime
55+
if waitTime > maxWaitTime {
56+
waitTime = maxWaitTime
57+
}
58+
59+
return &Subscriber{
60+
client: c.SQSClient,
61+
handler: c.MsgHandler,
62+
maxMessages: c.SQSConfig.Subscriber.MaxMessages,
63+
queueURL: c.SQSConfig.Subscriber.QueueURL,
64+
pool: pool.New(workers),
65+
waitTime: waitTime,
66+
stopCh: make(chan struct{}),
67+
}
68+
}
69+
70+
func (s *Subscriber) Subscribe(ctx context.Context) chan error {
71+
ch := make(chan error)
72+
73+
req := &sqs.ReceiveMessageInput{
74+
QueueUrl: aws.String(s.queueURL),
75+
MaxNumberOfMessages: int32(s.maxMessages),
76+
MessageAttributeNames: []string{"All"},
77+
MessageSystemAttributeNames: []types.MessageSystemAttributeName{
78+
types.MessageSystemAttributeNameAll,
79+
},
80+
}
81+
if s.waitTime > 0 {
82+
req.WaitTimeSeconds = int32(s.waitTime.Seconds())
83+
}
84+
85+
go func() {
86+
defer close(ch)
87+
88+
for {
89+
select {
90+
case <-s.stopCh:
91+
return
92+
default:
93+
response, err := s.client.ReceiveMessage(ctx, req)
94+
if err != nil {
95+
ch <- err
96+
97+
return
98+
}
99+
100+
for _, message := range response.Messages {
101+
s.wg.Add(1)
102+
s.pool.Schedule(func() {
103+
s.handler(message)
104+
s.wg.Done()
105+
})
106+
}
107+
}
108+
}
109+
}()
110+
111+
return ch
112+
}
113+
114+
func (s *Subscriber) Unsubscribe() error {
115+
close(s.stopCh)
116+
s.wg.Wait()
117+
118+
return nil
119+
}

pkg/pubsub/sqs/subscriber_test.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package sqs
2+
3+
import (
4+
"context"
5+
"sync/atomic"
6+
"testing"
7+
"time"
8+
9+
"github.com/aws/aws-sdk-go-v2/aws"
10+
"github.com/aws/aws-sdk-go-v2/service/sqs"
11+
"github.com/aws/aws-sdk-go-v2/service/sqs/types"
12+
13+
"github.com/scribd/go-sdk/pkg/pubsub/pool"
14+
)
15+
16+
type (
17+
mockSQSClient struct {
18+
msgs []types.Message
19+
}
20+
)
21+
22+
func (m *mockSQSClient) ReceiveMessage(
23+
ctx context.Context,
24+
params *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) {
25+
return &sqs.ReceiveMessageOutput{
26+
Messages: m.msgs,
27+
}, nil
28+
}
29+
30+
func Test_Subscriber_Subscribe(t *testing.T) {
31+
t.Run("all subscribers finished", func(t *testing.T) {
32+
var nHandlers int64 // atomic
33+
var executedTimes int64
34+
35+
c := make(chan int, 6)
36+
37+
sub := &Subscriber{
38+
pool: pool.New(2),
39+
stopCh: make(chan struct{}),
40+
client: &mockSQSClient{
41+
msgs: []types.Message{
42+
{
43+
Body: aws.String("1"),
44+
},
45+
{
46+
Body: aws.String("2"),
47+
},
48+
{
49+
Body: aws.String("3"),
50+
},
51+
{
52+
Body: aws.String("4"),
53+
},
54+
{
55+
Body: aws.String("5"),
56+
},
57+
{
58+
Body: aws.String("6"),
59+
},
60+
},
61+
},
62+
handler: func(msg types.Message) {
63+
c <- 0
64+
65+
atomic.AddInt64(&nHandlers, 1)
66+
defer atomic.AddInt64(&nHandlers, -1)
67+
atomic.AddInt64(&executedTimes, 1)
68+
69+
time.Sleep(time.Millisecond * 10)
70+
},
71+
}
72+
73+
_ = sub.Subscribe(context.Background())
74+
// Make sure all goroutines have started.
75+
for i := 0; i < cap(c); i++ {
76+
<-c
77+
}
78+
79+
err := sub.Unsubscribe()
80+
if err != nil {
81+
t.Errorf("expected nil, got %v", err)
82+
}
83+
84+
if got := atomic.LoadInt64(&nHandlers); got != 0 {
85+
t.Errorf("expected 0, got %d", got)
86+
}
87+
88+
if got := atomic.LoadInt64(&executedTimes); got != 6 {
89+
t.Errorf("expected 6, got %d", got)
90+
}
91+
})
92+
}

0 commit comments

Comments
 (0)