Skip to content

Commit 709e2ec

Browse files
committed
handle sigterm
1 parent cdb704c commit 709e2ec

File tree

8 files changed

+169
-56
lines changed

8 files changed

+169
-56
lines changed

internal/orchestrator/chain_tracker.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package orchestrator
22

33
import (
4+
"context"
45
"time"
56

67
"github.com/rs/zerolog/log"
@@ -22,13 +23,18 @@ func NewChainTracker(rpc rpc.IRPCClient) *ChainTracker {
2223
}
2324
}
2425

25-
func (ct *ChainTracker) Start() {
26+
func (ct *ChainTracker) Start(ctx context.Context) {
2627
interval := time.Duration(ct.triggerIntervalMs) * time.Millisecond
2728
ticker := time.NewTicker(interval)
29+
defer ticker.Stop()
2830

2931
log.Debug().Msgf("Chain tracker running")
30-
go func() {
31-
for range ticker.C {
32+
for {
33+
select {
34+
case <-ctx.Done():
35+
log.Info().Msg("Chain tracker shutting down")
36+
return
37+
case <-ticker.C:
3238
latestBlockNumber, err := ct.rpc.GetLatestBlockNumber()
3339
if err != nil {
3440
log.Error().Err(err).Msg("Error getting latest block number")
@@ -37,8 +43,5 @@ func (ct *ChainTracker) Start() {
3743
latestBlockNumberFloat, _ := latestBlockNumber.Float64()
3844
metrics.ChainHead.Set(latestBlockNumberFloat)
3945
}
40-
}()
41-
42-
// Keep the program running (otherwise it will exit)
43-
select {}
46+
}
4447
}

internal/orchestrator/committer.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package orchestrator
22

33
import (
4+
"context"
45
"fmt"
56
"math/big"
67
"sort"
@@ -44,13 +45,18 @@ func NewCommitter(rpc rpc.IRPCClient, storage storage.IStorage) *Committer {
4445
}
4546
}
4647

47-
func (c *Committer) Start() {
48+
func (c *Committer) Start(ctx context.Context) {
4849
interval := time.Duration(c.triggerIntervalMs) * time.Millisecond
4950
ticker := time.NewTicker(interval)
51+
defer ticker.Stop()
5052

5153
log.Debug().Msgf("Committer running")
52-
go func() {
53-
for range ticker.C {
54+
for {
55+
select {
56+
case <-ctx.Done():
57+
log.Info().Msg("Committer shutting down")
58+
return
59+
case <-ticker.C:
5460
blockDataToCommit, err := c.getSequentialBlockDataToCommit()
5561
if err != nil {
5662
log.Error().Err(err).Msg("Error getting block data to commit")
@@ -64,10 +70,7 @@ func (c *Committer) Start() {
6470
log.Error().Err(err).Msg("Error committing blocks")
6571
}
6672
}
67-
}()
68-
69-
// Keep the program running (otherwise it will exit)
70-
select {}
73+
}
7174
}
7275

7376
func (c *Committer) getBlockNumbersToCommit() ([]*big.Int, error) {

internal/orchestrator/committer_test.go

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package orchestrator
22

33
import (
4+
"context"
45
"math/big"
56
"testing"
67
"time"
@@ -207,12 +208,61 @@ func TestStartCommitter(t *testing.T) {
207208
mockStagingStorage.On("DeleteStagingData", &blockData).Return(nil)
208209

209210
// Start the committer in a goroutine
210-
go committer.Start()
211+
go committer.Start(context.Background())
211212

212213
// Wait for a short time to allow the committer to run
213214
time.Sleep(200 * time.Millisecond)
214215
}
215216

217+
func TestCommitterRespectsSIGTERM(t *testing.T) {
218+
mockRPC := mocks.NewMockIRPCClient(t)
219+
mockMainStorage := mocks.NewMockIMainStorage(t)
220+
mockStagingStorage := mocks.NewMockIStagingStorage(t)
221+
mockStorage := storage.IStorage{
222+
MainStorage: mockMainStorage,
223+
StagingStorage: mockStagingStorage,
224+
}
225+
226+
committer := NewCommitter(mockRPC, mockStorage)
227+
committer.triggerIntervalMs = 100 // Short interval for testing
228+
229+
chainID := big.NewInt(1)
230+
mockRPC.EXPECT().GetChainID().Return(chainID)
231+
mockMainStorage.EXPECT().GetMaxBlockNumber(chainID).Return(big.NewInt(100), nil)
232+
233+
blockData := []common.BlockData{
234+
{Block: common.Block{Number: big.NewInt(101)}},
235+
{Block: common.Block{Number: big.NewInt(102)}},
236+
}
237+
mockStagingStorage.On("GetStagingData", mock.Anything).Return(&blockData, nil)
238+
mockMainStorage.On("InsertBlockData", &blockData).Return(nil)
239+
mockStagingStorage.On("DeleteStagingData", &blockData).Return(nil)
240+
241+
// Create a context that we can cancel
242+
ctx, cancel := context.WithCancel(context.Background())
243+
244+
// Start the committer in a goroutine
245+
done := make(chan struct{})
246+
go func() {
247+
committer.Start(ctx)
248+
close(done)
249+
}()
250+
251+
// Wait a bit to ensure the committer is running
252+
time.Sleep(200 * time.Millisecond)
253+
254+
// Cancel the context (simulating SIGTERM)
255+
cancel()
256+
257+
// Wait for the committer to stop with a timeout
258+
select {
259+
case <-done:
260+
// Success - committer stopped
261+
case <-time.After(2 * time.Second):
262+
t.Fatal("Committer did not stop within timeout period after receiving cancel signal")
263+
}
264+
}
265+
216266
func TestHandleMissingStagingData(t *testing.T) {
217267
defer func() { config.Cfg = config.Config{} }()
218268
config.Cfg.Committer.BlocksPerCommit = 5

internal/orchestrator/failure_recoverer.go

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package orchestrator
22

33
import (
4+
"context"
45
"fmt"
56
"math/big"
67
"time"
@@ -41,13 +42,19 @@ func NewFailureRecoverer(rpc rpc.IRPCClient, storage storage.IStorage) *FailureR
4142
}
4243
}
4344

44-
func (fr *FailureRecoverer) Start() {
45+
func (fr *FailureRecoverer) Start(ctx context.Context) {
4546
interval := time.Duration(fr.triggerIntervalMs) * time.Millisecond
4647
ticker := time.NewTicker(interval)
48+
defer ticker.Stop()
4749

4850
log.Debug().Msgf("Failure Recovery running")
49-
go func() {
50-
for range ticker.C {
51+
52+
for {
53+
select {
54+
case <-ctx.Done():
55+
log.Info().Msg("Failure recoverer shutting down")
56+
return
57+
case <-ticker.C:
5158
blockFailures, err := fr.storage.OrchestratorStorage.GetBlockFailures(storage.QueryFilter{
5259
ChainId: fr.rpc.GetChainID(),
5360
Limit: fr.failuresPerPoll,
@@ -75,10 +82,7 @@ func (fr *FailureRecoverer) Start() {
7582
metrics.FailureRecovererLastTriggeredBlock.Set(float64(blockFailures[len(blockFailures)-1].BlockNumber.Int64()))
7683
metrics.FirstBlocknumberInFailureRecovererBatch.Set(float64(blockFailures[0].BlockNumber.Int64()))
7784
}
78-
}()
79-
80-
// Keep the program running (otherwise it will exit)
81-
select {}
85+
}
8286
}
8387

8488
func (fr *FailureRecoverer) handleWorkerResults(blockFailures []common.BlockFailure, results []rpc.GetFullBlockResult) {

internal/orchestrator/orchestrator.go

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
package orchestrator
22

33
import (
4+
"context"
5+
"os"
6+
"os/signal"
47
"sync"
8+
"syscall"
59

10+
"github.com/rs/zerolog/log"
611
config "github.com/thirdweb-dev/indexer/configs"
712
"github.com/thirdweb-dev/indexer/internal/rpc"
813
"github.com/thirdweb-dev/indexer/internal/storage"
@@ -15,6 +20,7 @@ type Orchestrator struct {
1520
failureRecovererEnabled bool
1621
committerEnabled bool
1722
reorgHandlerEnabled bool
23+
cancel context.CancelFunc
1824
}
1925

2026
func NewOrchestrator(rpc rpc.IRPCClient) (*Orchestrator, error) {
@@ -34,14 +40,26 @@ func NewOrchestrator(rpc rpc.IRPCClient) (*Orchestrator, error) {
3440
}
3541

3642
func (o *Orchestrator) Start() {
43+
ctx, cancel := context.WithCancel(context.Background())
44+
o.cancel = cancel
45+
3746
var wg sync.WaitGroup
3847

48+
sigChan := make(chan os.Signal, 1)
49+
signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT)
50+
51+
go func() {
52+
sig := <-sigChan
53+
log.Info().Msgf("Received signal %v, initiating graceful shutdown", sig)
54+
o.cancel()
55+
}()
56+
3957
if o.pollerEnabled {
4058
wg.Add(1)
4159
go func() {
4260
defer wg.Done()
4361
poller := NewPoller(o.rpc, o.storage)
44-
poller.Start()
62+
poller.Start(ctx)
4563
}()
4664
}
4765

@@ -50,7 +68,7 @@ func (o *Orchestrator) Start() {
5068
go func() {
5169
defer wg.Done()
5270
failureRecoverer := NewFailureRecoverer(o.rpc, o.storage)
53-
failureRecoverer.Start()
71+
failureRecoverer.Start(ctx)
5472
}()
5573
}
5674

@@ -59,7 +77,7 @@ func (o *Orchestrator) Start() {
5977
go func() {
6078
defer wg.Done()
6179
committer := NewCommitter(o.rpc, o.storage)
62-
committer.Start()
80+
committer.Start(ctx)
6381
}()
6482
}
6583

@@ -68,7 +86,7 @@ func (o *Orchestrator) Start() {
6886
go func() {
6987
defer wg.Done()
7088
reorgHandler := NewReorgHandler(o.rpc, o.storage)
71-
reorgHandler.Start()
89+
reorgHandler.Start(ctx)
7290
}()
7391
}
7492

@@ -77,8 +95,14 @@ func (o *Orchestrator) Start() {
7795
go func() {
7896
defer wg.Done()
7997
chainTracker := NewChainTracker(o.rpc)
80-
chainTracker.Start()
98+
chainTracker.Start(ctx)
8199
}()
82100

83101
wg.Wait()
84102
}
103+
104+
func (o *Orchestrator) Shutdown() {
105+
if o.cancel != nil {
106+
o.cancel()
107+
}
108+
}

internal/orchestrator/poller.go

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package orchestrator
22

33
import (
4+
"context"
45
"fmt"
56
"math/big"
67
"sync"
@@ -76,43 +77,67 @@ func NewPoller(rpc rpc.IRPCClient, storage storage.IStorage) *Poller {
7677
return poller
7778
}
7879

79-
func (p *Poller) Start() {
80+
func (p *Poller) Start(ctx context.Context) {
8081
interval := time.Duration(p.triggerIntervalMs) * time.Millisecond
8182
ticker := time.NewTicker(interval)
83+
defer ticker.Stop()
84+
log.Debug().Msgf("Poller running")
8285

8386
tasks := make(chan struct{}, p.parallelPollers)
8487
var blockRangeMutex sync.Mutex
88+
var wg sync.WaitGroup
89+
ctx, cancel := context.WithCancel(ctx)
90+
defer cancel()
8591

8692
for i := 0; i < p.parallelPollers; i++ {
93+
wg.Add(1)
8794
go func() {
88-
for range tasks {
89-
blockRangeMutex.Lock()
90-
blockNumbers, err := p.getNextBlockRange()
91-
blockRangeMutex.Unlock()
92-
93-
if err != nil {
94-
if err != ErrNoNewBlocks {
95-
log.Error().Err(err).Msg("Failed to get block range to poll")
95+
defer wg.Done()
96+
for {
97+
select {
98+
case <-ctx.Done():
99+
return
100+
case _, ok := <-tasks:
101+
if !ok {
102+
return
96103
}
97-
continue
98-
}
104+
blockRangeMutex.Lock()
105+
blockNumbers, err := p.getNextBlockRange()
106+
blockRangeMutex.Unlock()
99107

100-
lastPolledBlock := p.Poll(blockNumbers)
101-
if p.reachedPollLimit(lastPolledBlock) {
102-
log.Debug().Msg("Reached poll limit, exiting poller")
103-
ticker.Stop()
104-
return
108+
if err != nil {
109+
if err != ErrNoNewBlocks {
110+
log.Error().Err(err).Msg("Failed to get block range to poll")
111+
}
112+
continue
113+
}
114+
115+
lastPolledBlock := p.Poll(blockNumbers)
116+
if p.reachedPollLimit(lastPolledBlock) {
117+
log.Debug().Msg("Reached poll limit, exiting poller")
118+
cancel()
119+
return
120+
}
105121
}
106122
}
107123
}()
108124
}
109125

110-
for range ticker.C {
111-
tasks <- struct{}{}
126+
for {
127+
select {
128+
case <-ctx.Done():
129+
close(tasks)
130+
wg.Wait()
131+
log.Info().Msg("Poller shutting down")
132+
return
133+
case <-ticker.C:
134+
select {
135+
case tasks <- struct{}{}:
136+
default:
137+
// Channel full, skip this tick
138+
}
139+
}
112140
}
113-
114-
// Keep the program running (otherwise it will exit)
115-
select {}
116141
}
117142

118143
func (p *Poller) Poll(blockNumbers []*big.Int) (lastPolledBlock *big.Int) {

0 commit comments

Comments
 (0)