Skip to content

Commit 88752c5

Browse files
feat: move message polling to the client and improve error handling (#24)
2 parents d506328 + 4aee6ae commit 88752c5

File tree

7 files changed

+180
-100
lines changed

7 files changed

+180
-100
lines changed

cmd/hook/hook.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,12 @@ func checkStatus(ghaRunnerName string) bool {
166166
}
167167

168168
func sendMessageToSqs(message string) {
169-
client, err := sqs.NewClient(cfg.AwsRegion, cfg.AwsRunnerSQSUrl)
169+
client, err := sqs.NewClient(
170+
&sqs.SQSConfig{
171+
QueueURL: cfg.AwsRunnerSQSUrl,
172+
AWSRegion: cfg.AwsRegion,
173+
},
174+
)
170175
if err != nil {
171176
log.Fatalf("Failed to create sqs client: %v", err)
172177
}

cmd/tarter/app/server.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"os"
2424
"time"
2525

26+
"github.com/caarlos0/env"
2627
"github.com/gorilla/mux"
2728
"github.com/spf13/cobra"
2829
coreApi "github.com/tomtom-international/macos-actions-runner-controller/pkg/core/api"
@@ -48,10 +49,12 @@ func NewTarterCommand() *cobra.Command {
4849
if cmd.Name() == "version" {
4950
return nil
5051
}
51-
err := config.LoadTarterConfiguration(&tarterConfig)
52-
if err != nil {
53-
return errors.New("error: failed to load Tarter configuration. " + err.Error())
52+
53+
// Load configuration from environment variables
54+
if err := env.Parse(&tarterConfig); err != nil {
55+
return errors.New("error: failed to load Tarter configuration from environment. " + err.Error())
5456
}
57+
5558
if configFile == "" {
5659
return errors.New("error: --config argument is required")
5760
}

pkg/clients/sqs/client.go

Lines changed: 135 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ package sqs
1818

1919
import (
2020
"context"
21+
"errors"
22+
"fmt"
2123
"math/rand"
2224
"time"
2325

@@ -27,29 +29,75 @@ import (
2729
"github.com/aws/aws-sdk-go-v2/service/sqs/types"
2830
)
2931

32+
// Default bounds for the polling backoff, applied when SQSConfig.MinBackoff or
// SQSConfig.MaxBackoff is left unset (zero or negative).
const (
	defaultMinBackoff = time.Second
	defaultMaxBackoff = time.Minute
)
36+
3037
type SQSClient struct {
31-
service *sqs.Client
32-
QueueURL string
38+
service *sqs.Client
39+
config *SQSConfig
40+
currentBackoff time.Duration
41+
errHandle func(error)
3342
}
3443

35-
func NewClient(region, queueURL string) (*SQSClient, error) {
36-
cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion(region))
44+
// SQSConfig describes the queue an SQSClient talks to and how the client
// backs off while polling it.
type SQSConfig struct {
	// QueueURL is the URL of the SQS queue. Required.
	QueueURL string
	// AWSRegion is the AWS region used to build the SDK configuration. Required.
	AWSRegion string
	// MaxBackoff is the upper bound of the polling backoff. Zero or negative
	// values are replaced with defaultMaxBackoff during validation.
	MaxBackoff time.Duration
	// MinBackoff is the starting point of the polling backoff. Zero or
	// negative values are replaced with defaultMinBackoff during validation.
	MinBackoff time.Duration
	// ErrHandle is invoked with errors raised while polling. A nil value is
	// replaced with a no-op during validation.
	ErrHandle func(error)
}
51+
52+
func NewClient(sqsCfg *SQSConfig) (*SQSClient, error) {
53+
err := validateSQSConfig(sqsCfg)
54+
if err != nil {
55+
return nil, err
56+
}
57+
cfg, err := config.LoadDefaultConfig(context.TODO(),
58+
config.WithRegion(sqsCfg.AWSRegion),
59+
config.WithRetryMaxAttempts(3),
60+
config.WithRetryMode(aws.RetryModeStandard),
61+
)
3762
if err != nil {
3863
return nil, err
3964
}
4065

4166
svc := sqs.NewFromConfig(cfg)
4267

4368
return &SQSClient{
44-
service: svc,
45-
QueueURL: queueURL,
69+
service: svc,
70+
config: sqsCfg,
71+
currentBackoff: sqsCfg.MinBackoff,
72+
errHandle: sqsCfg.ErrHandle,
4673
}, nil
4774
}
4875

76+
func validateSQSConfig(cfg *SQSConfig) error {
77+
if cfg.QueueURL == "" {
78+
return errors.New("queueURL is required")
79+
}
80+
if cfg.AWSRegion == "" {
81+
return errors.New("region is required")
82+
}
83+
if cfg.MaxBackoff <= 0 {
84+
cfg.MaxBackoff = defaultMaxBackoff
85+
}
86+
if cfg.MinBackoff <= 0 {
87+
cfg.MinBackoff = defaultMinBackoff
88+
}
89+
if cfg.ErrHandle == nil {
90+
cfg.ErrHandle = func(err error) {
91+
return
92+
}
93+
}
94+
return nil
95+
}
96+
4997
func (c *SQSClient) SendMessage(ctx context.Context, messageBody string) (*sqs.SendMessageOutput, error) {
5098
result, err := c.service.SendMessage(ctx, &sqs.SendMessageInput{
5199
MessageBody: aws.String(messageBody),
52-
QueueUrl: aws.String(c.QueueURL),
100+
QueueUrl: aws.String(c.config.QueueURL),
53101
})
54102

55103
if err != nil {
@@ -61,7 +109,7 @@ func (c *SQSClient) SendMessage(ctx context.Context, messageBody string) (*sqs.S
61109

62110
func (c *SQSClient) DeleteMessage(ctx context.Context, receiptHandle string) (*sqs.DeleteMessageOutput, error) {
63111
result, err := c.service.DeleteMessage(ctx, &sqs.DeleteMessageInput{
64-
QueueUrl: aws.String(c.QueueURL),
112+
QueueUrl: aws.String(c.config.QueueURL),
65113
ReceiptHandle: aws.String(receiptHandle),
66114
})
67115

@@ -74,7 +122,7 @@ func (c *SQSClient) DeleteMessage(ctx context.Context, receiptHandle string) (*s
74122

75123
func (c *SQSClient) ReceiveMessages(ctx context.Context, maxMessages int32, waitTimeSeconds int32) ([]types.Message, error) {
76124
result, err := c.service.ReceiveMessage(ctx, &sqs.ReceiveMessageInput{
77-
QueueUrl: aws.String(c.QueueURL),
125+
QueueUrl: aws.String(c.config.QueueURL),
78126
MaxNumberOfMessages: maxMessages,
79127
WaitTimeSeconds: waitTimeSeconds,
80128
})
@@ -89,10 +137,6 @@ func (c *SQSClient) ReceiveMessages(ctx context.Context, maxMessages int32, wait
89137
// calculateNextBackoff computes the next backoff duration using an exponential
90138
// strategy with jitter to prevent synchronized polling.
91139
//
92-
// Parameters:
93-
// - current: The current backoff duration.
94-
// - max: The maximum allowable backoff duration.
95-
//
96140
// Returns:
97141
//
98142
// A time.Duration representing the next backoff period.
@@ -101,18 +145,91 @@ func (c *SQSClient) ReceiveMessages(ctx context.Context, maxMessages int32, wait
101145
// (subtracting up to 25% of the doubled value) to prevent multiple instances
102146
// from synchronizing their polling cycles. The result is capped at the specified
103147
// maximum duration (the client's configured MaxBackoff).
104-
func (c *SQSClient) CalculateNextBackoff(current, maxDuration time.Duration) time.Duration {
148+
func (c *SQSClient) calculateNextBackoff() time.Duration {
105149
// Double the current backoff
106-
next := current * 2
150+
next := c.currentBackoff * 2
107151

108152
// Apply jitter (randomness) to prevent synchronized polling
109153
jitter := time.Duration(rand.Int63n(int64(next / 4)))
110154
next -= jitter
111155

112156
// Ensure we don't exceed the maximum
113-
if next > maxDuration {
114-
return maxDuration
157+
if next > c.config.MaxBackoff {
158+
c.currentBackoff = c.config.MaxBackoff
159+
} else {
160+
c.currentBackoff = next
115161
}
116162

117-
return next
163+
return c.currentBackoff
164+
}
165+
166+
func (c *SQSClient) resetBackoff() {
167+
c.currentBackoff = c.config.MinBackoff
168+
}
169+
170+
func (c *SQSClient) PollForMessages(ctx context.Context, handler func(message *types.Message) error) {
171+
errChan := make(chan error, 100)
172+
pollCtx, cancel := context.WithCancel(ctx)
173+
defer cancel()
174+
go func() {
175+
defer close(errChan)
176+
for {
177+
select {
178+
case <-pollCtx.Done():
179+
return
180+
case err := <-errChan:
181+
if err != nil && c.errHandle != nil {
182+
c.errHandle(err)
183+
}
184+
}
185+
}
186+
}()
187+
188+
var backoff time.Duration
189+
190+
for {
191+
select {
192+
case <-ctx.Done():
193+
// Context was canceled - exit gracefully without error
194+
return
195+
default:
196+
messages, err := c.ReceiveMessages(ctx, 10, 20)
197+
if err != nil {
198+
// Check if the error is due to context cancellation
199+
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
200+
// Context canceled while receiving messages, stopping polling
201+
return
202+
}
203+
errChan <- fmt.Errorf("failed to receive messages: %w", err)
204+
205+
// Increase backoff on errors
206+
backoff = c.calculateNextBackoff()
207+
time.Sleep(backoff)
208+
continue
209+
}
210+
211+
if len(messages) == 0 {
212+
// No messages found, increase the backoff
213+
backoff = c.calculateNextBackoff()
214+
time.Sleep(backoff)
215+
} else {
216+
// Messages found, reset backoff
217+
c.resetBackoff()
218+
219+
for _, msg := range messages {
220+
err = handler(&msg)
221+
if err != nil {
222+
continue
223+
}
224+
225+
_, err = c.DeleteMessage(ctx, *msg.ReceiptHandle)
226+
if err != nil {
227+
errChan <- fmt.Errorf("failed to delete message %s from queue %s: %w", *msg.MessageId,
228+
c.config.QueueURL, err)
229+
continue
230+
}
231+
}
232+
}
233+
}
234+
}
118235
}

pkg/controller/controller.go

Lines changed: 27 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ package controller
1919
import (
2020
"context"
2121
"encoding/json"
22-
"errors"
2322
"fmt"
2423
"net"
2524
"sync"
2625
"time"
2726

27+
sqsTypes "github.com/aws/aws-sdk-go-v2/service/sqs/types"
2828
"github.com/gorilla/websocket"
2929
"github.com/tomtom-international/macos-actions-runner-controller/pkg/clients/etcd"
3030
"github.com/tomtom-international/macos-actions-runner-controller/pkg/clients/sqs"
@@ -37,11 +37,6 @@ import (
3737
"github.com/tomtom-international/macos-actions-runner-controller/pkg/utils"
3838
)
3939

40-
const (
41-
minBackoff = 4 * time.Second
42-
maxBackoff = 60 * time.Second
43-
)
44-
4540
type Controller struct {
4641
EtcdClient *etcd.EtcdClient
4742
nodePoolManager *np.Manager
@@ -58,7 +53,15 @@ func NewController(configuration config.ControllerConfig) (*Controller, error) {
5853
if err != nil {
5954
return nil, err
6055
}
61-
sqsClient, err := sqs.NewClient(configuration.AwsRegion, configuration.AwsRunnerRequestSQSUrl)
56+
sqsClient, err := sqs.NewClient(
57+
&sqs.SQSConfig{
58+
QueueURL: configuration.AwsRunnerRequestSQSUrl,
59+
AWSRegion: configuration.AwsRegion,
60+
ErrHandle: func(err error) {
61+
logger.Errorf("Sqs error: %s", err.Error())
62+
},
63+
},
64+
)
6265
if err != nil {
6366
return nil, err
6467
}
@@ -87,7 +90,7 @@ func (c *Controller) Run(ctx context.Context, wg *sync.WaitGroup) {
8790
wg.Add(2)
8891
go func() {
8992
defer wg.Done()
90-
c.listenForNewRunners(ctx, c.createRunner)
93+
c.listenForNewRunners(ctx, c.readSQSMessages)
9194
}()
9295
go func() {
9396
defer wg.Done()
@@ -345,9 +348,9 @@ func (c *Controller) GetRunnerList() ([]types.Runner, error) {
345348
return runners, nil
346349
}
347350

348-
func (c *Controller) listenForNewRunners(ctx context.Context, handlerFunc func(runner *types.Runner) error) {
349-
logger.Debugf("Listening for new messages from SQS queue: %s", c.sqsRunnerRequestClient.QueueURL)
350-
pollForMessages(ctx, c.sqsRunnerRequestClient, handlerFunc)
351+
// listenForNewRunners hands handlerFunc to the runner-request SQS client's
// polling loop. The call is synchronous, so it returns only when
// PollForMessages does.
func (c *Controller) listenForNewRunners(ctx context.Context, handlerFunc func(message *sqsTypes.Message) error) {
	logger.Debugf("Listening for new messages from SQS queue")

	poller := c.sqsRunnerRequestClient
	poller.PollForMessages(ctx, handlerFunc)
}
352355

353356
func (c *Controller) checkRegisteredNodes() {
@@ -373,65 +376,21 @@ func (c *Controller) checkRegisteredNodes() {
373376
}
374377
}
375378

376-
func pollForMessages(ctx context.Context, sqsClient *sqs.SQSClient, handler func(runner *types.Runner) error) {
377-
currentBackoff := minBackoff
378-
379-
for {
380-
select {
381-
case <-ctx.Done():
382-
// Context was canceled - exit gracefully without error
383-
logger.Debugf("Context canceled, stopping message polling from %s", sqsClient.QueueURL)
384-
return
385-
default:
386-
messages, err := sqsClient.ReceiveMessages(ctx, 10, 20)
387-
if err != nil {
388-
// Check if the error is due to context cancellation
389-
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
390-
logger.Debugf("Context canceled while receiving messages, stopping polling")
391-
return
392-
}
393-
394-
logger.Errorf("Failed to receive messages: %s", err.Error())
395-
396-
// Increase backoff on errors
397-
currentBackoff = sqsClient.CalculateNextBackoff(currentBackoff, maxBackoff)
398-
time.Sleep(currentBackoff)
399-
continue
400-
}
379+
func (c *Controller) readSQSMessages(msg *sqsTypes.Message) error {
380+
runner := types.Runner{}
381+
logger.Debugf("Received message with id: %s from queue", *msg.MessageId)
382+
err := json.Unmarshal([]byte(*msg.Body), &runner)
383+
if err != nil {
384+
logger.Errorf("Failed to unmarshal message %s from queue: %s", *msg.MessageId, err.Error())
385+
return err
386+
}
387+
runner.Condition.CreateRequestID = *msg.MessageId
401388

402-
if len(messages) == 0 {
403-
// No messages found, increase the backoff
404-
currentBackoff = sqsClient.CalculateNextBackoff(currentBackoff, maxBackoff)
405-
time.Sleep(currentBackoff)
406-
} else {
407-
// Messages found, reset backoff
408-
currentBackoff = minBackoff
409-
410-
for _, msg := range messages {
411-
runner := types.Runner{}
412-
logger.Debugf("Received message with id: %s from queue %s", *msg.MessageId, sqsClient.QueueURL)
413-
err := json.Unmarshal([]byte(*msg.Body), &runner)
414-
if err != nil {
415-
// message will be sent to DLQ after maximum receives
416-
logger.Errorf("Failed to unmarshal message %s from queue %s: %s", *msg.MessageId, sqsClient.QueueURL, err.Error())
417-
continue
418-
}
419-
runner.Condition.CreateRequestID = *msg.MessageId
420-
421-
err = handler(&runner)
422-
if err != nil {
423-
continue
424-
}
425-
426-
_, err = sqsClient.DeleteMessage(ctx, *msg.ReceiptHandle)
427-
if err != nil {
428-
logger.Errorf("Failed to delete message %s from queue %s: %s", *msg.MessageId, sqsClient.QueueURL, err.Error())
429-
continue
430-
}
431-
}
432-
}
433-
}
389+
err = c.createRunner(&runner)
390+
if err != nil {
391+
return err
434392
}
393+
return nil
435394
}
436395

437396
func (c *Controller) AddRunnersWatcher(conn *websocket.Conn, notificationChan chan types.WatcherRunnersUpdate) {

0 commit comments

Comments
 (0)