From dce699294ace5ff660ab04abe631ab2c3e055d28 Mon Sep 17 00:00:00 2001 From: iuwqyir Date: Tue, 20 May 2025 12:29:30 +0300 Subject: [PATCH] pass context to rpc and worker --- internal/handlers/search_handlers.go | 15 +-- internal/orchestrator/chain_tracker.go | 2 +- internal/orchestrator/committer.go | 20 ++-- internal/orchestrator/committer_test.go | 18 +-- internal/orchestrator/failure_recoverer.go | 2 +- internal/orchestrator/poller.go | 12 +- internal/orchestrator/reorg_handler.go | 22 ++-- internal/orchestrator/reorg_handler_test.go | 60 +++++----- internal/rpc/batcher.go | 10 +- internal/rpc/rpc.go | 42 +++---- internal/worker/worker.go | 29 +++-- test/mocks/MockIRPCClient.go | 123 +++++++++++--------- 12 files changed, 189 insertions(+), 166 deletions(-) diff --git a/internal/handlers/search_handlers.go b/internal/handlers/search_handlers.go index 1a4139b..59e4fe9 100644 --- a/internal/handlers/search_handlers.go +++ b/internal/handlers/search_handlers.go @@ -1,6 +1,7 @@ package handlers import ( + "context" "encoding/hex" "fmt" "math/big" @@ -76,7 +77,7 @@ func Search(c *gin.Context) { return } - result, err := executeSearch(mainStorage, chainId, searchInput) + result, err := executeSearch(c.Request.Context(), mainStorage, chainId, searchInput) if err != nil { log.Error().Err(err).Msg("Error executing search") api.InternalErrorHandler(c) @@ -124,7 +125,7 @@ func isValidHashWithLength(input string, length int) bool { return false } -func executeSearch(storage storage.IMainStorage, chainId *big.Int, input SearchInput) (SearchResultModel, error) { +func executeSearch(ctx context.Context, storage storage.IMainStorage, chainId *big.Int, input SearchInput) (SearchResultModel, error) { switch { case input.BlockNumber != nil: block, err := searchByBlockNumber(storage, chainId, input.BlockNumber) @@ -134,7 +135,7 @@ func executeSearch(storage storage.IMainStorage, chainId *big.Int, input SearchI return searchByHash(storage, chainId, input.Hash) case input.Address != "": - return searchByAddress(storage, chainId, input.Address) + return searchByAddress(ctx, storage, chainId, input.Address) case input.FunctionSignature != "": transactions, err := searchByFunctionSelectorOptimistically(storage, chainId, input.FunctionSignature) @@ -329,9 +330,9 @@ func searchByHash(mainStorage storage.IMainStorage, chainId *big.Int, hash strin } } -func searchByAddress(mainStorage storage.IMainStorage, chainId *big.Int, address string) (SearchResultModel, error) { +func searchByAddress(ctx context.Context, mainStorage storage.IMainStorage, chainId *big.Int, address string) (SearchResultModel, error) { searchResult := SearchResultModel{Type: SearchResultTypeAddress} - contractCode, err := checkIfContractHasCode(chainId, address) + contractCode, err := checkIfContractHasCode(ctx, chainId, address) if err != nil { return searchResult, err } @@ -437,14 +438,14 @@ const ( ContractCodeDoesNotExist ) -func checkIfContractHasCode(chainId *big.Int, address string) (ContractCodeState, error) { +func checkIfContractHasCode(ctx context.Context, chainId *big.Int, address string) (ContractCodeState, error) { if config.Cfg.API.Thirdweb.ClientId != "" { rpcUrl := fmt.Sprintf("https://%s.rpc.thirdweb.com/%s", chainId.String(), config.Cfg.API.Thirdweb.ClientId) r, err := rpc.InitializeSimpleRPCWithUrl(rpcUrl) if err != nil { return ContractCodeUnknown, err } - hasCode, err := r.HasCode(address) + hasCode, err := r.HasCode(ctx, address) if err != nil { return ContractCodeUnknown, err } diff --git a/internal/orchestrator/chain_tracker.go b/internal/orchestrator/chain_tracker.go index 8e94a50..686b6b5 100644 --- a/internal/orchestrator/chain_tracker.go +++ b/internal/orchestrator/chain_tracker.go @@ -35,7 +35,7 @@ func (ct *ChainTracker) Start(ctx context.Context) { log.Info().Msg("Chain tracker shutting down") return case <-ticker.C: - latestBlockNumber, err := ct.rpc.GetLatestBlockNumber() + latestBlockNumber, err := ct.rpc.GetLatestBlockNumber(ctx) if err != nil { log.Error().Err(err).Msg("Error getting latest block number") continue diff --git a/internal/orchestrator/committer.go b/internal/orchestrator/committer.go index 9c7c122..4ef4206 100644 --- a/internal/orchestrator/committer.go +++ b/internal/orchestrator/committer.go @@ -63,7 +63,7 @@ func (c *Committer) Start(ctx context.Context) { return default: time.Sleep(interval) - blockDataToCommit, err := c.getSequentialBlockDataToCommit() + blockDataToCommit, err := c.getSequentialBlockDataToCommit(ctx) if err != nil { log.Error().Err(err).Msg("Error getting block data to commit") continue @@ -72,7 +72,7 @@ func (c *Committer) Start(ctx context.Context) { log.Debug().Msg("No block data to commit") continue } - if err := c.commit(blockDataToCommit); err != nil { + if err := c.commit(ctx, blockDataToCommit); err != nil { log.Error().Err(err).Msg("Error committing blocks") } } @@ -108,7 +108,7 @@ func (c *Committer) getBlockNumbersToCommit() ([]*big.Int, error) { return blockNumbers, nil } -func (c *Committer) getSequentialBlockDataToCommit() ([]common.BlockData, error) { +func (c *Committer) getSequentialBlockDataToCommit(ctx context.Context) ([]common.BlockData, error) { blocksToCommit, err := c.getBlockNumbersToCommit() if err != nil { return nil, fmt.Errorf("error determining blocks to commit: %v", err) @@ -123,7 +123,7 @@ func (c *Committer) getSequentialBlockDataToCommit() ([]common.BlockData, error) } if len(blocksData) == 0 { log.Warn().Msgf("Committer didn't find the following range in staging: %v - %v", blocksToCommit[0].Int64(), blocksToCommit[len(blocksToCommit)-1].Int64()) - c.handleMissingStagingData(blocksToCommit) + c.handleMissingStagingData(ctx, blocksToCommit) return nil, nil } @@ -133,7 +133,7 @@ func (c *Committer) getSequentialBlockDataToCommit() ([]common.BlockData, error) }) if blocksData[0].Block.Number.Cmp(blocksToCommit[0]) != 0 { - return nil, c.handleGap(blocksToCommit[0], blocksData[0].Block) + return nil, c.handleGap(ctx, blocksToCommit[0], blocksData[0].Block) } var sequentialBlockData []common.BlockData @@ -161,7 +161,7 @@ func (c *Committer) getSequentialBlockDataToCommit() ([]common.BlockData, error) return sequentialBlockData, nil } -func (c *Committer) commit(blockData []common.BlockData) error { +func (c *Committer) commit(ctx context.Context, blockData []common.BlockData) error { blockNumbers := make([]*big.Int, len(blockData)) for i, block := range blockData { blockNumbers[i] = block.Block.Number @@ -199,7 +199,7 @@ func (c *Committer) commit(blockData []common.BlockData) error { return nil } -func (c *Committer) handleGap(expectedStartBlockNumber *big.Int, actualFirstBlock common.Block) error { +func (c *Committer) handleGap(ctx context.Context, expectedStartBlockNumber *big.Int, actualFirstBlock common.Block) error { // increment the gap counter in prometheus metrics.GapCounter.Inc() // record the first missed block number in prometheus @@ -220,11 +220,11 @@ func (c *Committer) handleGap(expectedStartBlockNumber *big.Int, actualFirstBloc } log.Debug().Msgf("Polling %d blocks while handling gap: %v", len(missingBlockNumbers), missingBlockNumbers) - poller.Poll(missingBlockNumbers) + poller.Poll(ctx, missingBlockNumbers) return fmt.Errorf("first block number (%s) in commit batch does not match expected (%s)", actualFirstBlock.Number.String(), expectedStartBlockNumber.String()) } -func (c *Committer) handleMissingStagingData(blocksToCommit []*big.Int) { +func (c *Committer) handleMissingStagingData(ctx context.Context, blocksToCommit []*big.Int) { // Checks if there are any blocks in staging after the current range end lastStagedBlockNumber, err := c.storage.StagingStorage.GetLastStagedBlockNumber(c.rpc.GetChainID(), blocksToCommit[len(blocksToCommit)-1], big.NewInt(0)) if err != nil { @@ -242,6 +242,6 @@ func (c *Committer) handleMissingStagingData(blocksToCommit []*big.Int) { if len(blocksToCommit) > int(poller.blocksPerPoll) { blocksToPoll = blocksToCommit[:int(poller.blocksPerPoll)] } - poller.Poll(blocksToPoll) + poller.Poll(ctx, blocksToPoll) log.Debug().Msgf("Polled %d blocks due to committer detecting them as missing. Range: %s - %s", len(blocksToPoll), blocksToPoll[0].String(), blocksToPoll[len(blocksToPoll)-1].String()) } diff --git a/internal/orchestrator/committer_test.go b/internal/orchestrator/committer_test.go index 5b66ef8..25b0ca3 100644 --- a/internal/orchestrator/committer_test.go +++ b/internal/orchestrator/committer_test.go @@ -254,7 +254,7 @@ func TestGetSequentialBlockDataToCommit(t *testing.T) { BlockNumbers: []*big.Int{big.NewInt(101), big.NewInt(102), big.NewInt(103)}, }).Return(blockData, nil) - result, err := committer.getSequentialBlockDataToCommit() + result, err := committer.getSequentialBlockDataToCommit(context.Background()) assert.NoError(t, err) assert.NotNil(t, result) @@ -290,7 +290,7 @@ func TestGetSequentialBlockDataToCommitWithDuplicateBlocks(t *testing.T) { BlockNumbers: []*big.Int{big.NewInt(101), big.NewInt(102), big.NewInt(103)}, }).Return(blockData, nil) - result, err := committer.getSequentialBlockDataToCommit() + result, err := committer.getSequentialBlockDataToCommit(context.Background()) assert.NoError(t, err) assert.NotNil(t, result) @@ -320,7 +320,7 @@ func TestCommit(t *testing.T) { mockMainStorage.EXPECT().InsertBlockData(blockData).Return(nil) mockStagingStorage.EXPECT().DeleteStagingData(blockData).Return(nil) - err := committer.commit(blockData) + err := committer.commit(context.Background(), blockData) assert.NoError(t, err) } @@ -343,7 +343,7 @@ func TestHandleGap(t *testing.T) { mockRPC.EXPECT().GetBlocksPerRequest().Return(rpc.BlocksPerRequestConfig{ Blocks: 5, }) - mockRPC.EXPECT().GetFullBlocks([]*big.Int{big.NewInt(100), big.NewInt(101), big.NewInt(102), big.NewInt(103), big.NewInt(104)}).Return([]rpc.GetFullBlockResult{ + mockRPC.EXPECT().GetFullBlocks(context.Background(), []*big.Int{big.NewInt(100), big.NewInt(101), big.NewInt(102), big.NewInt(103), big.NewInt(104)}).Return([]rpc.GetFullBlockResult{ {BlockNumber: big.NewInt(100), Data: common.BlockData{Block: common.Block{Number: big.NewInt(100)}}}, {BlockNumber: big.NewInt(101), Data: common.BlockData{Block: common.Block{Number: big.NewInt(101)}}}, {BlockNumber: big.NewInt(102), Data: common.BlockData{Block: common.Block{Number: big.NewInt(102)}}}, @@ -352,7 +352,7 @@ func TestHandleGap(t *testing.T) { }) mockStagingStorage.EXPECT().InsertStagingData(mock.Anything).Return(nil) - err := committer.handleGap(expectedStartBlockNumber, actualFirstBlock) + err := committer.handleGap(context.Background(), expectedStartBlockNumber, actualFirstBlock) assert.Error(t, err) assert.Contains(t, err.Error(), "first block number (105) in commit batch does not match expected (100)") @@ -463,7 +463,7 @@ func TestHandleMissingStagingData(t *testing.T) { mockRPC.EXPECT().GetBlocksPerRequest().Return(rpc.BlocksPerRequestConfig{ Blocks: 100, }) - mockRPC.EXPECT().GetFullBlocks([]*big.Int{big.NewInt(0), big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4)}).Return([]rpc.GetFullBlockResult{ + mockRPC.EXPECT().GetFullBlocks(context.Background(), []*big.Int{big.NewInt(0), big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4)}).Return([]rpc.GetFullBlockResult{ {BlockNumber: big.NewInt(0), Data: common.BlockData{Block: common.Block{Number: big.NewInt(0)}}}, {BlockNumber: big.NewInt(1), Data: common.BlockData{Block: common.Block{Number: big.NewInt(1)}}}, {BlockNumber: big.NewInt(2), Data: common.BlockData{Block: common.Block{Number: big.NewInt(2)}}}, @@ -482,7 +482,7 @@ func TestHandleMissingStagingData(t *testing.T) { BlockNumbers: []*big.Int{big.NewInt(0), big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4)}, }).Return(blockData, nil) - result, err := committer.getSequentialBlockDataToCommit() + result, err := committer.getSequentialBlockDataToCommit(context.Background()) assert.NoError(t, err) assert.Nil(t, result) @@ -509,7 +509,7 @@ func TestHandleMissingStagingDataIsPolledWithCorrectBatchSize(t *testing.T) { mockRPC.EXPECT().GetBlocksPerRequest().Return(rpc.BlocksPerRequestConfig{ Blocks: 3, }) - mockRPC.EXPECT().GetFullBlocks([]*big.Int{big.NewInt(0), big.NewInt(1), big.NewInt(2)}).Return([]rpc.GetFullBlockResult{ + mockRPC.EXPECT().GetFullBlocks(context.Background(), []*big.Int{big.NewInt(0), big.NewInt(1), big.NewInt(2)}).Return([]rpc.GetFullBlockResult{ {BlockNumber: big.NewInt(0), Data: common.BlockData{Block: common.Block{Number: big.NewInt(0)}}}, {BlockNumber: big.NewInt(1), Data: common.BlockData{Block: common.Block{Number: big.NewInt(1)}}}, {BlockNumber: big.NewInt(2), Data: common.BlockData{Block: common.Block{Number: big.NewInt(2)}}}, @@ -526,7 +526,7 @@ func TestHandleMissingStagingDataIsPolledWithCorrectBatchSize(t *testing.T) { BlockNumbers: []*big.Int{big.NewInt(0), big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4)}, }).Return(blockData, nil) - result, err := committer.getSequentialBlockDataToCommit() + result, err := committer.getSequentialBlockDataToCommit(context.Background()) assert.NoError(t, err) assert.Nil(t, result) diff --git a/internal/orchestrator/failure_recoverer.go b/internal/orchestrator/failure_recoverer.go index ab39164..da1ae91 100644 --- a/internal/orchestrator/failure_recoverer.go +++ b/internal/orchestrator/failure_recoverer.go @@ -75,7 +75,7 @@ func (fr *FailureRecoverer) Start(ctx context.Context) { // Trigger worker for recovery log.Debug().Msgf("Triggering Failure Recoverer for blocks: %v", blocksToTrigger) worker := worker.NewWorker(fr.rpc) - results := worker.Run(blocksToTrigger) + results := worker.Run(ctx, blocksToTrigger) fr.handleWorkerResults(blockFailures, results) // Track recovery activity diff --git a/internal/orchestrator/poller.go b/internal/orchestrator/poller.go index 4b2e3f1..a86f1c7 100644 --- a/internal/orchestrator/poller.go +++ b/internal/orchestrator/poller.go @@ -103,7 +103,7 @@ func (p *Poller) Start(ctx context.Context) { return } blockRangeMutex.Lock() - blockNumbers, err := p.getNextBlockRange() + blockNumbers, err := p.getNextBlockRange(pollCtx) blockRangeMutex.Unlock() if pollCtx.Err() != nil { @@ -117,7 +117,7 @@ func (p *Poller) Start(ctx context.Context) { continue } - lastPolledBlock := p.Poll(blockNumbers) + lastPolledBlock := p.Poll(pollCtx, blockNumbers) if p.reachedPollLimit(lastPolledBlock) { log.Debug().Msg("Reached poll limit, exiting poller") cancel() @@ -146,7 +146,7 @@ func (p *Poller) Start(ctx context.Context) { } } -func (p *Poller) Poll(blockNumbers []*big.Int) (lastPolledBlock *big.Int) { +func (p *Poller) Poll(ctx context.Context, blockNumbers []*big.Int) (lastPolledBlock *big.Int) { if len(blockNumbers) < 1 { log.Debug().Msg("No blocks to poll, skipping") return @@ -161,7 +161,7 @@ func (p *Poller) Poll(blockNumbers []*big.Int) (lastPolledBlock *big.Int) { metrics.PollerLastTriggeredBlock.Set(endBlockNumberFloat) worker := worker.NewWorker(p.rpc) - results := worker.Run(blockNumbers) + results := worker.Run(ctx, blockNumbers) p.handleWorkerResults(results) return endBlock } @@ -170,8 +170,8 @@ func (p *Poller) reachedPollLimit(blockNumber *big.Int) bool { return blockNumber == nil || (p.pollUntilBlock.Sign() > 0 && blockNumber.Cmp(p.pollUntilBlock) >= 0) } -func (p *Poller) getNextBlockRange() ([]*big.Int, error) { - latestBlock, err := p.rpc.GetLatestBlockNumber() +func (p *Poller) getNextBlockRange(ctx context.Context) ([]*big.Int, error) { + latestBlock, err := p.rpc.GetLatestBlockNumber(ctx) if err != nil { return nil, err } diff --git a/internal/orchestrator/reorg_handler.go b/internal/orchestrator/reorg_handler.go index 81c123f..2de8b95 100644 --- a/internal/orchestrator/reorg_handler.go +++ b/internal/orchestrator/reorg_handler.go @@ -83,7 +83,7 @@ func (rh *ReorgHandler) Start(ctx context.Context) { rh.publisher.Close() return case <-ticker.C: - mostRecentBlockChecked, err := rh.RunFromBlock(rh.lastCheckedBlock) + mostRecentBlockChecked, err := rh.RunFromBlock(ctx, rh.lastCheckedBlock) if err != nil { log.Error().Err(err).Msgf("Error during reorg handling: %s", err.Error()) continue @@ -99,7 +99,7 @@ func (rh *ReorgHandler) Start(ctx context.Context) { } } -func (rh *ReorgHandler) RunFromBlock(latestCheckedBlock *big.Int) (lastCheckedBlock *big.Int, err error) { +func (rh *ReorgHandler) RunFromBlock(ctx context.Context, latestCheckedBlock *big.Int) (lastCheckedBlock *big.Int, err error) { fromBlock, toBlock, err := rh.getReorgCheckRange(latestCheckedBlock) if err != nil { return nil, err @@ -130,7 +130,7 @@ func (rh *ReorgHandler) RunFromBlock(latestCheckedBlock *big.Int) (lastCheckedBl metrics.ReorgCounter.Inc() reorgedBlockNumbers := make([]*big.Int, 0) - err = rh.findReorgedBlockNumbers(blockHeaders[firstMismatchIndex:], &reorgedBlockNumbers) + err = rh.findReorgedBlockNumbers(ctx, blockHeaders[firstMismatchIndex:], &reorgedBlockNumbers) if err != nil { return nil, fmt.Errorf("error finding reorged block numbers: %w", err) } @@ -140,7 +140,7 @@ func (rh *ReorgHandler) RunFromBlock(latestCheckedBlock *big.Int) (lastCheckedBl return mostRecentBlockHeader.Number, nil } - err = rh.handleReorg(reorgedBlockNumbers) + err = rh.handleReorg(ctx, reorgedBlockNumbers) if err != nil { return nil, fmt.Errorf("error while handling reorg: %w", err) } @@ -190,8 +190,8 @@ func findIndexOfFirstHashMismatch(blockHeadersDescending []common.BlockHeader) ( return -1, nil } -func (rh *ReorgHandler) findReorgedBlockNumbers(blockHeadersDescending []common.BlockHeader, reorgedBlockNumbers *[]*big.Int) error { - newBlocksByNumber, err := rh.getNewBlocksByNumber(blockHeadersDescending) +func (rh *ReorgHandler) findReorgedBlockNumbers(ctx context.Context, blockHeadersDescending []common.BlockHeader, reorgedBlockNumbers *[]*big.Int) error { + newBlocksByNumber, err := rh.getNewBlocksByNumber(ctx, blockHeadersDescending) if err != nil { return err } @@ -219,12 +219,12 @@ func (rh *ReorgHandler) findReorgedBlockNumbers(blockHeadersDescending []common. sort.Slice(nextHeadersBatch, func(i, j int) bool { return nextHeadersBatch[i].Number.Cmp(nextHeadersBatch[j].Number) > 0 }) - return rh.findReorgedBlockNumbers(nextHeadersBatch, reorgedBlockNumbers) + return rh.findReorgedBlockNumbers(ctx, nextHeadersBatch, reorgedBlockNumbers) } return nil } -func (rh *ReorgHandler) getNewBlocksByNumber(blockHeaders []common.BlockHeader) (map[string]common.Block, error) { +func (rh *ReorgHandler) getNewBlocksByNumber(ctx context.Context, blockHeaders []common.BlockHeader) (map[string]common.Block, error) { blockNumbers := make([]*big.Int, 0, len(blockHeaders)) for _, header := range blockHeaders { blockNumbers = append(blockNumbers, header.Number) @@ -241,7 +241,7 @@ func (rh *ReorgHandler) getNewBlocksByNumber(blockHeaders []common.BlockHeader) wg.Add(1) go func(chunk []*big.Int) { defer wg.Done() - resultsCh <- rh.rpc.GetBlocks(chunk) + resultsCh <- rh.rpc.GetBlocks(ctx, chunk) if config.Cfg.RPC.Blocks.BatchDelay > 0 { time.Sleep(time.Duration(config.Cfg.RPC.Blocks.BatchDelay) * time.Millisecond) } @@ -264,9 +264,9 @@ func (rh *ReorgHandler) getNewBlocksByNumber(blockHeaders []common.BlockHeader) return fetchedBlocksByNumber, nil } -func (rh *ReorgHandler) handleReorg(reorgedBlockNumbers []*big.Int) error { +func (rh *ReorgHandler) handleReorg(ctx context.Context, reorgedBlockNumbers []*big.Int) error { log.Debug().Msgf("Handling reorg for blocks %v", reorgedBlockNumbers) - results := rh.worker.Run(reorgedBlockNumbers) + results := rh.worker.Run(ctx, reorgedBlockNumbers) data := make([]common.BlockData, 0, len(results)) blocksToDelete := make([]*big.Int, 0, len(results)) for _, result := range results { diff --git a/internal/orchestrator/reorg_handler_test.go b/internal/orchestrator/reorg_handler_test.go index 5872bec..195ba4a 100644 --- a/internal/orchestrator/reorg_handler_test.go +++ b/internal/orchestrator/reorg_handler_test.go @@ -240,14 +240,14 @@ func TestFindFirstReorgedBlockNumber(t *testing.T) { {Number: big.NewInt(1), Hash: "hash1", ParentHash: "hash0"}, } - mockRPC.EXPECT().GetBlocks([]*big.Int{big.NewInt(3), big.NewInt(2), big.NewInt(1)}).Return([]rpc.GetBlocksResult{ + mockRPC.EXPECT().GetBlocks(context.Background(), []*big.Int{big.NewInt(3), big.NewInt(2), big.NewInt(1)}).Return([]rpc.GetBlocksResult{ {BlockNumber: big.NewInt(3), Data: common.Block{Hash: "hash3", ParentHash: "hash2"}}, {BlockNumber: big.NewInt(2), Data: common.Block{Hash: "hash2", ParentHash: "hash1"}}, {BlockNumber: big.NewInt(1), Data: common.Block{Hash: "hash1", ParentHash: "hash0"}}, }) reorgedBlockNumbers := []*big.Int{} - err := handler.findReorgedBlockNumbers(reversedBlockHeaders, &reorgedBlockNumbers) + err := handler.findReorgedBlockNumbers(context.Background(), reversedBlockHeaders, &reorgedBlockNumbers) assert.NoError(t, err) assert.Equal(t, []*big.Int{big.NewInt(3)}, reorgedBlockNumbers) @@ -274,14 +274,14 @@ func TestFindAllReorgedBlockNumbersWithLastBlockInSliceAsValid(t *testing.T) { {Number: big.NewInt(1), Hash: "hash1", ParentHash: "hash0"}, } - mockRPC.EXPECT().GetBlocks([]*big.Int{big.NewInt(3), big.NewInt(2), big.NewInt(1)}).Return([]rpc.GetBlocksResult{ + mockRPC.EXPECT().GetBlocks(context.Background(), []*big.Int{big.NewInt(3), big.NewInt(2), big.NewInt(1)}).Return([]rpc.GetBlocksResult{ {BlockNumber: big.NewInt(3), Data: common.Block{Hash: "hash3", ParentHash: "hash2"}}, {BlockNumber: big.NewInt(2), Data: common.Block{Hash: "hash2", ParentHash: "hash1"}}, {BlockNumber: big.NewInt(1), Data: common.Block{Hash: "hash1", ParentHash: "hash0"}}, }) reorgedBlockNumbers := []*big.Int{} - err := handler.findReorgedBlockNumbers(reversedBlockHeaders, &reorgedBlockNumbers) + err := handler.findReorgedBlockNumbers(context.Background(), reversedBlockHeaders, &reorgedBlockNumbers) assert.NoError(t, err) assert.Equal(t, []*big.Int{big.NewInt(3), big.NewInt(2)}, reorgedBlockNumbers) @@ -305,7 +305,7 @@ func TestFindManyReorgsInOneScan(t *testing.T) { mockOrchestratorStorage.EXPECT().GetLastReorgCheckedBlockNumber(big.NewInt(1)).Return(big.NewInt(1), nil) handler := NewReorgHandler(mockRPC, mockStorage) - mockRPC.EXPECT().GetBlocks([]*big.Int{big.NewInt(9), big.NewInt(8), big.NewInt(7), big.NewInt(6), big.NewInt(5), big.NewInt(4), big.NewInt(3), big.NewInt(2), big.NewInt(1)}).Return([]rpc.GetBlocksResult{ + mockRPC.EXPECT().GetBlocks(context.Background(), []*big.Int{big.NewInt(9), big.NewInt(8), big.NewInt(7), big.NewInt(6), big.NewInt(5), big.NewInt(4), big.NewInt(3), big.NewInt(2), big.NewInt(1)}).Return([]rpc.GetBlocksResult{ {BlockNumber: big.NewInt(9), Data: common.Block{Hash: "hash9", ParentHash: "hash8"}}, {BlockNumber: big.NewInt(8), Data: common.Block{Hash: "hash8", ParentHash: "hash7"}}, {BlockNumber: big.NewInt(7), Data: common.Block{Hash: "hash7", ParentHash: "hash6"}}, @@ -330,7 +330,7 @@ func TestFindManyReorgsInOneScan(t *testing.T) { } reorgedBlockNumbers := []*big.Int{} - err := handler.findReorgedBlockNumbers(initialBlockHeaders, &reorgedBlockNumbers) + err := handler.findReorgedBlockNumbers(context.Background(), initialBlockHeaders, &reorgedBlockNumbers) assert.NoError(t, err) assert.Equal(t, []*big.Int{big.NewInt(9), big.NewInt(6), big.NewInt(4)}, reorgedBlockNumbers) @@ -354,14 +354,14 @@ func TestFindManyReorgsInOneScanRecursively(t *testing.T) { mockOrchestratorStorage.EXPECT().GetLastReorgCheckedBlockNumber(big.NewInt(1)).Return(big.NewInt(1), nil) handler := NewReorgHandler(mockRPC, mockStorage) - mockRPC.EXPECT().GetBlocks([]*big.Int{big.NewInt(9), big.NewInt(8), big.NewInt(7), big.NewInt(6)}).Return([]rpc.GetBlocksResult{ + mockRPC.EXPECT().GetBlocks(context.Background(), []*big.Int{big.NewInt(9), big.NewInt(8), big.NewInt(7), big.NewInt(6)}).Return([]rpc.GetBlocksResult{ {BlockNumber: big.NewInt(9), Data: common.Block{Hash: "hash9", ParentHash: "hash8"}}, {BlockNumber: big.NewInt(8), Data: common.Block{Hash: "hash8", ParentHash: "hash7"}}, {BlockNumber: big.NewInt(7), Data: common.Block{Hash: "hash7", ParentHash: "hash6"}}, {BlockNumber: big.NewInt(6), Data: common.Block{Hash: "hash6", ParentHash: "hash5"}}, }).Once() - mockRPC.EXPECT().GetBlocks([]*big.Int{big.NewInt(5), big.NewInt(4), big.NewInt(3), big.NewInt(2)}).Return([]rpc.GetBlocksResult{ + mockRPC.EXPECT().GetBlocks(context.Background(), []*big.Int{big.NewInt(5), big.NewInt(4), big.NewInt(3), big.NewInt(2)}).Return([]rpc.GetBlocksResult{ {BlockNumber: big.NewInt(5), Data: common.Block{Hash: "hash5", ParentHash: "hash4"}}, {BlockNumber: big.NewInt(4), Data: common.Block{Hash: "hash4", ParentHash: "hash3"}}, {BlockNumber: big.NewInt(3), Data: common.Block{Hash: "hash3", ParentHash: "hash2"}}, @@ -383,7 +383,7 @@ func TestFindManyReorgsInOneScanRecursively(t *testing.T) { }, nil) reorgedBlockNumbers := []*big.Int{} - err := handler.findReorgedBlockNumbers(initialBlockHeaders, &reorgedBlockNumbers) + err := handler.findReorgedBlockNumbers(context.Background(), initialBlockHeaders, &reorgedBlockNumbers) assert.NoError(t, err) assert.Equal(t, []*big.Int{big.NewInt(9), big.NewInt(6), big.NewInt(4), big.NewInt(3)}, reorgedBlockNumbers) @@ -407,13 +407,13 @@ func TestFindReorgedBlockNumbersRecursively(t *testing.T) { mockOrchestratorStorage.EXPECT().GetLastReorgCheckedBlockNumber(big.NewInt(1)).Return(big.NewInt(3), nil) handler := NewReorgHandler(mockRPC, mockStorage) - mockRPC.EXPECT().GetBlocks([]*big.Int{big.NewInt(6), big.NewInt(5), big.NewInt(4)}).Return([]rpc.GetBlocksResult{ + mockRPC.EXPECT().GetBlocks(context.Background(), []*big.Int{big.NewInt(6), big.NewInt(5), big.NewInt(4)}).Return([]rpc.GetBlocksResult{ {BlockNumber: big.NewInt(6), Data: common.Block{Hash: "hash6", ParentHash: "hash5"}}, {BlockNumber: big.NewInt(5), Data: common.Block{Hash: "hash5", ParentHash: "hash4"}}, {BlockNumber: big.NewInt(4), Data: common.Block{Hash: "hash4", ParentHash: "hash3"}}, }).Once() - mockRPC.EXPECT().GetBlocks([]*big.Int{big.NewInt(3), big.NewInt(2), big.NewInt(1)}).Return([]rpc.GetBlocksResult{ + mockRPC.EXPECT().GetBlocks(context.Background(), []*big.Int{big.NewInt(3), big.NewInt(2), big.NewInt(1)}).Return([]rpc.GetBlocksResult{ {BlockNumber: big.NewInt(3), Data: common.Block{Hash: "hash3", ParentHash: "hash2"}}, {BlockNumber: big.NewInt(2), Data: common.Block{Hash: "hash2", ParentHash: "hash1"}}, {BlockNumber: big.NewInt(1), Data: common.Block{Hash: "hash1", ParentHash: "hash0"}}, @@ -432,7 +432,7 @@ func TestFindReorgedBlockNumbersRecursively(t *testing.T) { }, nil) reorgedBlockNumbers := []*big.Int{} - err := handler.findReorgedBlockNumbers(initialBlockHeaders, &reorgedBlockNumbers) + err := handler.findReorgedBlockNumbers(context.Background(), initialBlockHeaders, &reorgedBlockNumbers) assert.NoError(t, err) assert.Equal(t, []*big.Int{big.NewInt(6), big.NewInt(5), big.NewInt(4), big.NewInt(3)}, reorgedBlockNumbers) @@ -456,17 +456,17 @@ func TestNewBlocksAreFetchedInBatches(t *testing.T) { mockOrchestratorStorage.EXPECT().GetLastReorgCheckedBlockNumber(big.NewInt(1)).Return(big.NewInt(3), nil) handler := NewReorgHandler(mockRPC, mockStorage) - mockRPC.EXPECT().GetBlocks([]*big.Int{big.NewInt(6), big.NewInt(5)}).Return([]rpc.GetBlocksResult{ + mockRPC.EXPECT().GetBlocks(context.Background(), []*big.Int{big.NewInt(6), big.NewInt(5)}).Return([]rpc.GetBlocksResult{ {BlockNumber: big.NewInt(6), Data: common.Block{Hash: "hash6", ParentHash: "hash5"}}, {BlockNumber: big.NewInt(5), Data: common.Block{Hash: "hash5", ParentHash: "hash4"}}, }).Once() - mockRPC.EXPECT().GetBlocks([]*big.Int{big.NewInt(4), big.NewInt(3)}).Return([]rpc.GetBlocksResult{ + mockRPC.EXPECT().GetBlocks(context.Background(), []*big.Int{big.NewInt(4), big.NewInt(3)}).Return([]rpc.GetBlocksResult{ {BlockNumber: big.NewInt(4), Data: common.Block{Hash: "hash4", ParentHash: "hash3"}}, {BlockNumber: big.NewInt(3), Data: common.Block{Hash: "hash3", ParentHash: "hash2"}}, }).Once() - mockRPC.EXPECT().GetBlocks([]*big.Int{big.NewInt(2)}).Return([]rpc.GetBlocksResult{ + mockRPC.EXPECT().GetBlocks(context.Background(), []*big.Int{big.NewInt(2)}).Return([]rpc.GetBlocksResult{ {BlockNumber: big.NewInt(2), Data: common.Block{Hash: "hash2", ParentHash: "hash1"}}, }).Once() @@ -479,7 +479,7 @@ func TestNewBlocksAreFetchedInBatches(t *testing.T) { } reorgedBlockNumbers := []*big.Int{} - err := handler.findReorgedBlockNumbers(initialBlockHeaders, &reorgedBlockNumbers) + err := handler.findReorgedBlockNumbers(context.Background(), initialBlockHeaders, &reorgedBlockNumbers) assert.NoError(t, err) assert.Equal(t, []*big.Int{}, reorgedBlockNumbers) @@ -497,7 +497,7 @@ func TestHandleReorg(t *testing.T) { mockRPC.EXPECT().GetChainID().Return(big.NewInt(1)) mockRPC.EXPECT().GetBlocksPerRequest().Return(rpc.BlocksPerRequestConfig{Blocks: 100}) - mockRPC.EXPECT().GetFullBlocks(mock.Anything).Return([]rpc.GetFullBlockResult{ + mockRPC.EXPECT().GetFullBlocks(context.Background(), mock.Anything).Return([]rpc.GetFullBlockResult{ {BlockNumber: big.NewInt(1), Data: common.BlockData{}}, {BlockNumber: big.NewInt(2), Data: common.BlockData{}}, {BlockNumber: big.NewInt(3), Data: common.BlockData{}}, @@ -507,7 +507,7 @@ func TestHandleReorg(t *testing.T) { mockMainStorage.EXPECT().ReplaceBlockData(mock.Anything).Return([]common.BlockData{}, nil) handler := NewReorgHandler(mockRPC, mockStorage) - err := handler.handleReorg([]*big.Int{big.NewInt(1), big.NewInt(2), big.NewInt(3)}) + err := handler.handleReorg(context.Background(), []*big.Int{big.NewInt(1), big.NewInt(2), big.NewInt(3)}) assert.NoError(t, err) } @@ -560,7 +560,7 @@ func TestReorgHandlingIsSkippedIfMostRecentAndLastCheckedBlockAreSame(t *testing mockMainStorage.EXPECT().GetMaxBlockNumber(big.NewInt(1)).Return(big.NewInt(100), nil) handler := NewReorgHandler(mockRPC, mockStorage) - mostRecentBlockChecked, err := handler.RunFromBlock(big.NewInt(100)) + mostRecentBlockChecked, err := handler.RunFromBlock(context.Background(), big.NewInt(100)) assert.NoError(t, err) assert.Nil(t, mostRecentBlockChecked) @@ -597,7 +597,7 @@ func TestHandleReorgWithSingleBlockReorg(t *testing.T) { {Number: big.NewInt(100), Hash: "hash100", ParentHash: "hash99"}, }, nil) - mockRPC.EXPECT().GetBlocks([]*big.Int{big.NewInt(105), big.NewInt(104), big.NewInt(103), big.NewInt(102), big.NewInt(101), big.NewInt(100)}).Return([]rpc.GetBlocksResult{ + mockRPC.EXPECT().GetBlocks(context.Background(), []*big.Int{big.NewInt(105), big.NewInt(104), big.NewInt(103), big.NewInt(102), big.NewInt(101), big.NewInt(100)}).Return([]rpc.GetBlocksResult{ {BlockNumber: big.NewInt(105), Data: common.Block{Hash: "hash105", ParentHash: "hash104"}}, {BlockNumber: big.NewInt(104), Data: common.Block{Hash: "hash104", ParentHash: "hash103"}}, {BlockNumber: big.NewInt(103), Data: common.Block{Hash: "hash103", ParentHash: "hash102"}}, @@ -606,7 +606,7 @@ func TestHandleReorgWithSingleBlockReorg(t *testing.T) { {BlockNumber: big.NewInt(100), Data: common.Block{Hash: "hash100", ParentHash: "hash99"}}, }) - mockRPC.EXPECT().GetFullBlocks([]*big.Int{big.NewInt(105)}).Return([]rpc.GetFullBlockResult{ + mockRPC.EXPECT().GetFullBlocks(context.Background(), []*big.Int{big.NewInt(105)}).Return([]rpc.GetFullBlockResult{ {BlockNumber: big.NewInt(105), Data: common.BlockData{}}, }) @@ -615,7 +615,7 @@ func TestHandleReorgWithSingleBlockReorg(t *testing.T) { })).Return([]common.BlockData{}, nil) handler := NewReorgHandler(mockRPC, mockStorage) - mostRecentBlockChecked, err := handler.RunFromBlock(big.NewInt(99)) + mostRecentBlockChecked, err := handler.RunFromBlock(context.Background(), big.NewInt(99)) assert.NoError(t, err) assert.Equal(t, big.NewInt(109), mostRecentBlockChecked) @@ -652,7 +652,7 @@ func TestHandleReorgWithLatestBlockReorged(t *testing.T) { {Number: big.NewInt(100), Hash: "hash100", ParentHash: "hash99"}, }, nil) - mockRPC.EXPECT().GetBlocks([]*big.Int{big.NewInt(108), big.NewInt(107), big.NewInt(106), big.NewInt(105), big.NewInt(104), big.NewInt(103), big.NewInt(102), big.NewInt(101), big.NewInt(100)}).Return([]rpc.GetBlocksResult{ + mockRPC.EXPECT().GetBlocks(context.Background(), []*big.Int{big.NewInt(108), big.NewInt(107), big.NewInt(106), big.NewInt(105), big.NewInt(104), big.NewInt(103), big.NewInt(102), big.NewInt(101), big.NewInt(100)}).Return([]rpc.GetBlocksResult{ {BlockNumber: big.NewInt(108), Data: common.Block{Hash: "hash108", ParentHash: "hash107"}}, {BlockNumber: big.NewInt(107), Data: common.Block{Hash: "hash107", ParentHash: "hash106"}}, {BlockNumber: big.NewInt(106), Data: common.Block{Hash: "hash106", ParentHash: "hash105"}}, @@ -664,7 +664,7 @@ func TestHandleReorgWithLatestBlockReorged(t *testing.T) { {BlockNumber: big.NewInt(100), Data: common.Block{Hash: "hash100", ParentHash: "hash99"}}, }) - mockRPC.EXPECT().GetFullBlocks([]*big.Int{big.NewInt(108), big.NewInt(107), big.NewInt(106), big.NewInt(105), big.NewInt(104), big.NewInt(103), big.NewInt(102), big.NewInt(101)}).Return([]rpc.GetFullBlockResult{ + mockRPC.EXPECT().GetFullBlocks(context.Background(), []*big.Int{big.NewInt(108), big.NewInt(107), big.NewInt(106), big.NewInt(105), big.NewInt(104), big.NewInt(103), big.NewInt(102), big.NewInt(101)}).Return([]rpc.GetFullBlockResult{ {BlockNumber: big.NewInt(101), Data: common.BlockData{}}, {BlockNumber: big.NewInt(102), Data: common.BlockData{}}, {BlockNumber: big.NewInt(103), Data: common.BlockData{}}, @@ -680,7 +680,7 @@ func TestHandleReorgWithLatestBlockReorged(t *testing.T) { })).Return([]common.BlockData{}, nil) handler := NewReorgHandler(mockRPC, mockStorage) - mostRecentBlockChecked, err := handler.RunFromBlock(big.NewInt(99)) + mostRecentBlockChecked, err := handler.RunFromBlock(context.Background(), big.NewInt(99)) assert.NoError(t, err) assert.Equal(t, big.NewInt(109), mostRecentBlockChecked) @@ -717,7 +717,7 @@ func TestHandleReorgWithManyBlocks(t *testing.T) { {Number: big.NewInt(100), Hash: "hash100", ParentHash: "hash99"}, }, nil) - mockRPC.EXPECT().GetBlocks([]*big.Int{big.NewInt(107), big.NewInt(106), big.NewInt(105), big.NewInt(104), big.NewInt(103), big.NewInt(102), big.NewInt(101), big.NewInt(100)}).Return([]rpc.GetBlocksResult{ + mockRPC.EXPECT().GetBlocks(context.Background(), []*big.Int{big.NewInt(107), big.NewInt(106), big.NewInt(105), big.NewInt(104), big.NewInt(103), big.NewInt(102), big.NewInt(101), big.NewInt(100)}).Return([]rpc.GetBlocksResult{ {BlockNumber: big.NewInt(107), Data: common.Block{Hash: "hash107", ParentHash: "hash106"}}, {BlockNumber: big.NewInt(106), Data: common.Block{Hash: "hash106", ParentHash: "hash105"}}, {BlockNumber: big.NewInt(105), Data: common.Block{Hash: "hash105", ParentHash: "hash104"}}, @@ -728,7 +728,7 @@ func TestHandleReorgWithManyBlocks(t *testing.T) { {BlockNumber: big.NewInt(100), Data: common.Block{Hash: "hash100", ParentHash: "hash99"}}, }) - mockRPC.EXPECT().GetFullBlocks([]*big.Int{big.NewInt(107), big.NewInt(106), big.NewInt(105), big.NewInt(104), big.NewInt(103)}).Return([]rpc.GetFullBlockResult{ + mockRPC.EXPECT().GetFullBlocks(context.Background(), []*big.Int{big.NewInt(107), big.NewInt(106), big.NewInt(105), big.NewInt(104), big.NewInt(103)}).Return([]rpc.GetFullBlockResult{ {BlockNumber: big.NewInt(107), Data: common.BlockData{}}, {BlockNumber: big.NewInt(106), Data: common.BlockData{}}, {BlockNumber: big.NewInt(105), Data: common.BlockData{}}, @@ -741,7 +741,7 @@ func TestHandleReorgWithManyBlocks(t *testing.T) { })).Return([]common.BlockData{}, nil) handler := NewReorgHandler(mockRPC, mockStorage) - mostRecentBlockChecked, err := handler.RunFromBlock(big.NewInt(99)) + mostRecentBlockChecked, err := handler.RunFromBlock(context.Background(), big.NewInt(99)) assert.NoError(t, err) assert.Equal(t, big.NewInt(109), mostRecentBlockChecked) @@ -778,7 +778,7 @@ func TestHandleReorgWithDuplicateBlocks(t *testing.T) { }, nil) handler := NewReorgHandler(mockRPC, mockStorage) - mostRecentBlockChecked, err := handler.RunFromBlock(big.NewInt(6268162)) + mostRecentBlockChecked, err := handler.RunFromBlock(context.Background(), big.NewInt(6268162)) assert.NoError(t, err) assert.Equal(t, big.NewInt(6268172), mostRecentBlockChecked) @@ -815,7 +815,7 @@ func TestNothingIsDoneForCorrectBlocks(t *testing.T) { }, nil) handler := NewReorgHandler(mockRPC, mockStorage) - mostRecentBlockChecked, err := handler.RunFromBlock(big.NewInt(6268163)) + mostRecentBlockChecked, err := handler.RunFromBlock(context.Background(), big.NewInt(6268163)) assert.NoError(t, err) assert.Equal(t, big.NewInt(6268173), mostRecentBlockChecked) diff --git a/internal/rpc/batcher.go b/internal/rpc/batcher.go index ea82d7b..c34fd13 100644 --- a/internal/rpc/batcher.go +++ b/internal/rpc/batcher.go @@ -16,9 +16,9 @@ type RPCFetchBatchResult[K any, T any] struct { Result T } -func RPCFetchInBatches[K any, T any](rpc *Client, keys []K, batchSize int, batchDelay int, method string, argsFunc func(K) []interface{}) []RPCFetchBatchResult[K, T] { +func RPCFetchInBatches[K any, T any](rpc *Client, ctx context.Context, keys []K, batchSize int, batchDelay int, method string, argsFunc func(K) []interface{}) []RPCFetchBatchResult[K, T] { if len(keys) <= batchSize { - return RPCFetchSingleBatch[K, T](rpc, keys, method, argsFunc) + return RPCFetchSingleBatch[K, T](rpc, ctx, keys, method, argsFunc) } chunks := common.SliceToChunks[K](keys, batchSize) @@ -31,7 +31,7 @@ func RPCFetchInBatches[K any, T any](rpc *Client, keys []K, batchSize int, batch wg.Add(1) go func(chunk []K) { defer wg.Done() - resultsCh <- RPCFetchSingleBatch[K, T](rpc, chunk, method, argsFunc) + resultsCh <- RPCFetchSingleBatch[K, T](rpc, ctx, chunk, method, argsFunc) if batchDelay > 0 { time.Sleep(time.Duration(batchDelay) * time.Millisecond) } @@ -50,7 +50,7 @@ func RPCFetchInBatches[K any, T any](rpc *Client, keys []K, batchSize int, batch return results } -func RPCFetchSingleBatch[K any, T any](rpc *Client, keys []K, method string, argsFunc func(K) []interface{}) []RPCFetchBatchResult[K, T] { +func RPCFetchSingleBatch[K any, T any](rpc *Client, ctx context.Context, keys []K, method string, argsFunc func(K) []interface{}) []RPCFetchBatchResult[K, T] { batch := make([]gethRpc.BatchElem, len(keys)) results := make([]RPCFetchBatchResult[K, T], len(keys)) @@ -63,7 +63,7 @@ func RPCFetchSingleBatch[K any, T any](rpc *Client, keys []K, method string, arg } } - err := rpc.RPCClient.BatchCallContext(context.Background(), batch) + err := rpc.RPCClient.BatchCallContext(ctx, batch) if err != nil { for i := range results { results[i].Error = err diff --git a/internal/rpc/rpc.go b/internal/rpc/rpc.go index c18bd7b..c9931bc 100644 --- a/internal/rpc/rpc.go +++ b/internal/rpc/rpc.go @@ -40,16 +40,16 @@ type BlocksPerRequestConfig struct { } type IRPCClient interface { - GetFullBlocks(blockNumbers []*big.Int) []GetFullBlockResult - GetBlocks(blockNumbers []*big.Int) []GetBlocksResult - GetTransactions(txHashes []string) []GetTransactionsResult - GetLatestBlockNumber() (*big.Int, error) + GetFullBlocks(ctx context.Context, blockNumbers []*big.Int) []GetFullBlockResult + GetBlocks(ctx context.Context, blockNumbers []*big.Int) []GetBlocksResult + GetTransactions(ctx context.Context, txHashes []string) []GetTransactionsResult + GetLatestBlockNumber(ctx context.Context) (*big.Int, error) GetChainID() *big.Int GetURL() string GetBlocksPerRequest() BlocksPerRequestConfig IsWebsocket() bool SupportsTraceBlock() bool - HasCode(address string) (bool, error) + HasCode(ctx context.Context, address string) (bool, error) Close() } @@ -89,7 +89,7 @@ func Initialize() (IRPCClient, error) { return nil, checkErr } - chainIdErr := rpc.setChainID() + chainIdErr := rpc.setChainID(context.Background()) if chainIdErr != nil { return nil, chainIdErr } @@ -208,8 +208,8 @@ func (rpc *Client) checkTraceBlockSupport() error { return nil } -func (rpc *Client) setChainID() error { - chainID, err := rpc.EthClient.ChainID(context.Background()) +func (rpc *Client) setChainID(ctx context.Context) error { + chainID, err := rpc.EthClient.ChainID(ctx) if err != nil { return fmt.Errorf("failed to get chain ID: %v", err) } @@ -218,7 +218,7 @@ func (rpc *Client) setChainID() error { return nil } -func (rpc *Client) GetFullBlocks(blockNumbers []*big.Int) []GetFullBlockResult { +func (rpc *Client) GetFullBlocks(ctx context.Context, blockNumbers []*big.Int) []GetFullBlockResult { var wg sync.WaitGroup var blocks []RPCFetchBatchResult[*big.Int, common.RawBlock] var logs []RPCFetchBatchResult[*big.Int, common.RawLogs] @@ -228,20 +228,20 @@ func (rpc *Client) GetFullBlocks(blockNumbers []*big.Int) []GetFullBlockResult { go func() { defer wg.Done() - result := RPCFetchSingleBatch[*big.Int, common.RawBlock](rpc, blockNumbers, "eth_getBlockByNumber", GetBlockWithTransactionsParams) + result := RPCFetchSingleBatch[*big.Int, common.RawBlock](rpc, ctx, blockNumbers, "eth_getBlockByNumber", GetBlockWithTransactionsParams) blocks = result }() if rpc.supportsBlockReceipts { go func() { defer wg.Done() - result := RPCFetchInBatches[*big.Int, common.RawReceipts](rpc, blockNumbers, rpc.blocksPerRequest.Receipts, config.Cfg.RPC.BlockReceipts.BatchDelay, "eth_getBlockReceipts", GetBlockReceiptsParams) + result := RPCFetchInBatches[*big.Int, common.RawReceipts](rpc, ctx, blockNumbers, rpc.blocksPerRequest.Receipts, config.Cfg.RPC.BlockReceipts.BatchDelay, "eth_getBlockReceipts", GetBlockReceiptsParams) receipts = result }() } else { go func() { defer wg.Done() - result := RPCFetchInBatches[*big.Int, common.RawLogs](rpc, blockNumbers, rpc.blocksPerRequest.Logs, config.Cfg.RPC.Logs.BatchDelay, "eth_getLogs", GetLogsParams) + result := RPCFetchInBatches[*big.Int, common.RawLogs](rpc, ctx, blockNumbers, rpc.blocksPerRequest.Logs, config.Cfg.RPC.Logs.BatchDelay, "eth_getLogs", GetLogsParams) logs = result }() } @@ -250,7 +250,7 @@ func (rpc *Client) GetFullBlocks(blockNumbers []*big.Int) []GetFullBlockResult { wg.Add(1) go func() { defer wg.Done() - result := RPCFetchInBatches[*big.Int, common.RawTraces](rpc, blockNumbers, rpc.blocksPerRequest.Traces, config.Cfg.RPC.Traces.BatchDelay, "trace_block", TraceBlockParams) + result := RPCFetchInBatches[*big.Int, common.RawTraces](rpc, ctx, blockNumbers, rpc.blocksPerRequest.Traces, config.Cfg.RPC.Traces.BatchDelay, "trace_block", TraceBlockParams) traces = result }() } @@ -260,7 +260,7 @@ func (rpc *Client) GetFullBlocks(blockNumbers []*big.Int) []GetFullBlockResult { return SerializeFullBlocks(rpc.chainID, blocks, logs, traces, receipts) } -func (rpc *Client) GetBlocks(blockNumbers []*big.Int) []GetBlocksResult { +func (rpc *Client) GetBlocks(ctx context.Context, blockNumbers []*big.Int) []GetBlocksResult { var wg sync.WaitGroup var blocks []RPCFetchBatchResult[*big.Int, common.RawBlock] @@ -268,14 +268,14 @@ func (rpc *Client) GetBlocks(blockNumbers []*big.Int) []GetBlocksResult { go func() { defer wg.Done() - blocks = RPCFetchSingleBatch[*big.Int, common.RawBlock](rpc, blockNumbers, "eth_getBlockByNumber", GetBlockWithoutTransactionsParams) + blocks = RPCFetchSingleBatch[*big.Int, common.RawBlock](rpc, ctx, blockNumbers, "eth_getBlockByNumber", GetBlockWithoutTransactionsParams) }() wg.Wait() return SerializeBlocks(rpc.chainID, blocks) } -func (rpc *Client) GetTransactions(txHashes []string) []GetTransactionsResult { +func (rpc *Client) GetTransactions(ctx context.Context, txHashes []string) []GetTransactionsResult { var wg sync.WaitGroup var transactions []RPCFetchBatchResult[string, common.RawTransaction] @@ -283,23 +283,23 @@ func (rpc *Client) GetTransactions(txHashes []string) []GetTransactionsResult { go func() { defer wg.Done() - transactions = RPCFetchSingleBatch[string, common.RawTransaction](rpc, txHashes, "eth_getTransactionByHash", GetTransactionParams) + transactions = RPCFetchSingleBatch[string, common.RawTransaction](rpc, ctx, txHashes, "eth_getTransactionByHash", GetTransactionParams) }() wg.Wait() return SerializeTransactions(rpc.chainID, transactions) } -func (rpc *Client) GetLatestBlockNumber() (*big.Int, error) { - blockNumber, err := rpc.EthClient.BlockNumber(context.Background()) +func (rpc *Client) GetLatestBlockNumber(ctx context.Context) (*big.Int, error) { + blockNumber, err := rpc.EthClient.BlockNumber(ctx) if err != nil { return nil, fmt.Errorf("failed to get latest block number: %v", err) } return new(big.Int).SetUint64(blockNumber), nil } -func (rpc *Client) HasCode(address string) (bool, error) { - code, err := rpc.EthClient.CodeAt(context.Background(), gethCommon.HexToAddress(address), nil) +func (rpc *Client) HasCode(ctx context.Context, address string) (bool, error) { + code, err := rpc.EthClient.CodeAt(ctx, gethCommon.HexToAddress(address), nil) if err != nil { return false, err } diff --git a/internal/worker/worker.go b/internal/worker/worker.go index 1fa949e..9d42e05 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -1,6 +1,7 @@ package worker import ( + "context" "math/big" "sort" "sync" @@ -23,13 +24,19 @@ func NewWorker(rpc rpc.IRPCClient) *Worker { } } -func (w *Worker) processChunkWithRetry(chunk []*big.Int, resultsCh chan<- []rpc.GetFullBlockResult) { +func (w *Worker) processChunkWithRetry(ctx context.Context, chunk []*big.Int, resultsCh chan<- []rpc.GetFullBlockResult) { + select { + case <-ctx.Done(): + return + default: + } + defer func() { time.Sleep(time.Duration(config.Cfg.RPC.Blocks.BatchDelay) * time.Millisecond) }() // Try with current chunk size - results := w.rpc.GetFullBlocks(chunk) + results := w.rpc.GetFullBlocks(ctx, chunk) if len(chunk) == 1 { // chunk size 1 is the minimum, so we return whatever we get @@ -61,7 +68,7 @@ func (w *Worker) processChunkWithRetry(chunk []*big.Int, resultsCh chan<- []rpc. // can't split any further, so try one last time if len(failedBlocks) == 1 { - w.processChunkWithRetry(failedBlocks, resultsCh) + w.processChunkWithRetry(ctx, failedBlocks, resultsCh) return } @@ -77,18 +84,18 @@ func (w *Worker) processChunkWithRetry(chunk []*big.Int, resultsCh chan<- []rpc. go func() { defer wg.Done() - w.processChunkWithRetry(leftChunk, resultsCh) + w.processChunkWithRetry(ctx, leftChunk, resultsCh) }() go func() { defer wg.Done() - w.processChunkWithRetry(rightChunk, resultsCh) + w.processChunkWithRetry(ctx, rightChunk, resultsCh) }() wg.Wait() } -func (w *Worker) Run(blockNumbers []*big.Int) []rpc.GetFullBlockResult { +func (w *Worker) Run(ctx context.Context, blockNumbers []*big.Int) []rpc.GetFullBlockResult { blockCount := len(blockNumbers) chunks := common.SliceToChunks(blockNumbers, w.rpc.GetBlocksPerRequest().Blocks) @@ -98,10 +105,18 @@ func (w *Worker) Run(blockNumbers []*big.Int) []rpc.GetFullBlockResult { log.Debug().Msgf("Worker Processing %d blocks in %d chunks of max %d blocks", blockCount, len(chunks), w.rpc.GetBlocksPerRequest().Blocks) for _, chunk := range chunks { + select { + case <-ctx.Done(): + log.Debug().Msg("Context canceled, stopping Worker") + return nil + default: + // continue processing + } + wg.Add(1) go func(chunk []*big.Int) { defer wg.Done() - w.processChunkWithRetry(chunk, resultsCh) + w.processChunkWithRetry(ctx, chunk, resultsCh) }(chunk) } diff --git a/test/mocks/MockIRPCClient.go b/test/mocks/MockIRPCClient.go index 5d64428..8ac63cf 100644 --- a/test/mocks/MockIRPCClient.go +++ b/test/mocks/MockIRPCClient.go @@ -5,9 +5,11 @@ package mocks import ( + context "context" big "math/big" mock "github.com/stretchr/testify/mock" + rpc "github.com/thirdweb-dev/indexer/internal/rpc" ) @@ -56,17 +58,17 @@ func (_c *MockIRPCClient_Close_Call) RunAndReturn(run func()) *MockIRPCClient_Cl return _c } -// GetBlocks provides a mock function with given fields: blockNumbers -func (_m *MockIRPCClient) GetBlocks(blockNumbers []*big.Int) []rpc.GetBlocksResult { - ret := _m.Called(blockNumbers) +// GetBlocks provides a mock function with given fields: ctx, blockNumbers +func (_m *MockIRPCClient) GetBlocks(ctx context.Context, blockNumbers []*big.Int) []rpc.GetBlocksResult { + ret := _m.Called(ctx, blockNumbers) if len(ret) == 0 { panic("no return value specified for GetBlocks") } var r0 []rpc.GetBlocksResult - if rf, ok := ret.Get(0).(func([]*big.Int) []rpc.GetBlocksResult); ok { - r0 = rf(blockNumbers) + if rf, ok := ret.Get(0).(func(context.Context, []*big.Int) []rpc.GetBlocksResult); ok { + r0 = rf(ctx, blockNumbers) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]rpc.GetBlocksResult) @@ -82,14 +84,15 @@ type MockIRPCClient_GetBlocks_Call struct { } // GetBlocks is a helper method to define mock.On call +// - ctx context.Context // - blockNumbers []*big.Int -func (_e *MockIRPCClient_Expecter) GetBlocks(blockNumbers interface{}) *MockIRPCClient_GetBlocks_Call { - return &MockIRPCClient_GetBlocks_Call{Call: _e.mock.On("GetBlocks", blockNumbers)} +func (_e *MockIRPCClient_Expecter) GetBlocks(ctx interface{}, blockNumbers interface{}) *MockIRPCClient_GetBlocks_Call { + return &MockIRPCClient_GetBlocks_Call{Call: _e.mock.On("GetBlocks", ctx, blockNumbers)} } -func (_c *MockIRPCClient_GetBlocks_Call) Run(run func(blockNumbers []*big.Int)) *MockIRPCClient_GetBlocks_Call { +func (_c *MockIRPCClient_GetBlocks_Call) Run(run func(ctx context.Context, blockNumbers []*big.Int)) *MockIRPCClient_GetBlocks_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]*big.Int)) + run(args[0].(context.Context), args[1].([]*big.Int)) }) return _c } @@ -99,7 +102,7 @@ func (_c *MockIRPCClient_GetBlocks_Call) Return(_a0 []rpc.GetBlocksResult) *Mock return _c } -func (_c *MockIRPCClient_GetBlocks_Call) RunAndReturn(run func([]*big.Int) []rpc.GetBlocksResult) *MockIRPCClient_GetBlocks_Call { +func (_c *MockIRPCClient_GetBlocks_Call) RunAndReturn(run func(context.Context, []*big.Int) []rpc.GetBlocksResult) *MockIRPCClient_GetBlocks_Call { _c.Call.Return(run) return _c } @@ -196,17 +199,17 @@ func (_c *MockIRPCClient_GetChainID_Call) RunAndReturn(run func() *big.Int) *Moc return _c } -// GetFullBlocks provides a mock function with given fields: blockNumbers -func (_m *MockIRPCClient) GetFullBlocks(blockNumbers []*big.Int) []rpc.GetFullBlockResult { - ret := _m.Called(blockNumbers) +// GetFullBlocks provides a mock function with given fields: ctx, blockNumbers +func (_m *MockIRPCClient) GetFullBlocks(ctx context.Context, blockNumbers []*big.Int) []rpc.GetFullBlockResult { + ret := _m.Called(ctx, blockNumbers) if len(ret) == 0 { panic("no return value specified for GetFullBlocks") } var r0 []rpc.GetFullBlockResult - if rf, ok := ret.Get(0).(func([]*big.Int) []rpc.GetFullBlockResult); ok { - r0 = rf(blockNumbers) + if rf, ok := ret.Get(0).(func(context.Context, []*big.Int) []rpc.GetFullBlockResult); ok { + r0 = rf(ctx, blockNumbers) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]rpc.GetFullBlockResult) @@ -222,14 +225,15 @@ type MockIRPCClient_GetFullBlocks_Call struct { } // GetFullBlocks is a helper method to define mock.On call +// - ctx context.Context // - blockNumbers []*big.Int -func (_e *MockIRPCClient_Expecter) GetFullBlocks(blockNumbers interface{}) *MockIRPCClient_GetFullBlocks_Call { - return &MockIRPCClient_GetFullBlocks_Call{Call: _e.mock.On("GetFullBlocks", blockNumbers)} +func (_e *MockIRPCClient_Expecter) GetFullBlocks(ctx interface{}, blockNumbers interface{}) *MockIRPCClient_GetFullBlocks_Call { + return &MockIRPCClient_GetFullBlocks_Call{Call: _e.mock.On("GetFullBlocks", ctx, blockNumbers)} } -func (_c *MockIRPCClient_GetFullBlocks_Call) Run(run func(blockNumbers []*big.Int)) *MockIRPCClient_GetFullBlocks_Call { +func (_c *MockIRPCClient_GetFullBlocks_Call) Run(run func(ctx context.Context, blockNumbers []*big.Int)) *MockIRPCClient_GetFullBlocks_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]*big.Int)) + run(args[0].(context.Context), args[1].([]*big.Int)) }) return _c } @@ -239,14 +243,14 @@ func (_c *MockIRPCClient_GetFullBlocks_Call) Return(_a0 []rpc.GetFullBlockResult return _c } -func (_c *MockIRPCClient_GetFullBlocks_Call) RunAndReturn(run func([]*big.Int) []rpc.GetFullBlockResult) *MockIRPCClient_GetFullBlocks_Call { +func (_c *MockIRPCClient_GetFullBlocks_Call) RunAndReturn(run func(context.Context, []*big.Int) []rpc.GetFullBlockResult) *MockIRPCClient_GetFullBlocks_Call { _c.Call.Return(run) return _c } -// GetLatestBlockNumber provides a mock function with no fields -func (_m *MockIRPCClient) GetLatestBlockNumber() (*big.Int, error) { - ret := _m.Called() +// GetLatestBlockNumber provides a mock function with given fields: ctx +func (_m *MockIRPCClient) GetLatestBlockNumber(ctx context.Context) (*big.Int, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for GetLatestBlockNumber") @@ -254,19 +258,19 @@ func (_m *MockIRPCClient) GetLatestBlockNumber() (*big.Int, error) { var r0 *big.Int var r1 error - if rf, ok := ret.Get(0).(func() (*big.Int, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) (*big.Int, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() *big.Int); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) *big.Int); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*big.Int) } } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -280,13 +284,14 @@ type MockIRPCClient_GetLatestBlockNumber_Call struct { } // GetLatestBlockNumber is a helper method to define mock.On call -func (_e *MockIRPCClient_Expecter) GetLatestBlockNumber() *MockIRPCClient_GetLatestBlockNumber_Call { - return &MockIRPCClient_GetLatestBlockNumber_Call{Call: _e.mock.On("GetLatestBlockNumber")} +// - ctx context.Context +func (_e *MockIRPCClient_Expecter) GetLatestBlockNumber(ctx interface{}) *MockIRPCClient_GetLatestBlockNumber_Call { + return &MockIRPCClient_GetLatestBlockNumber_Call{Call: _e.mock.On("GetLatestBlockNumber", ctx)} } -func (_c *MockIRPCClient_GetLatestBlockNumber_Call) Run(run func()) *MockIRPCClient_GetLatestBlockNumber_Call { +func (_c *MockIRPCClient_GetLatestBlockNumber_Call) Run(run func(ctx context.Context)) *MockIRPCClient_GetLatestBlockNumber_Call { _c.Call.Run(func(args mock.Arguments) { - run() + run(args[0].(context.Context)) }) return _c } @@ -296,22 +301,22 @@ func (_c *MockIRPCClient_GetLatestBlockNumber_Call) Return(_a0 *big.Int, _a1 err return _c } -func (_c *MockIRPCClient_GetLatestBlockNumber_Call) RunAndReturn(run func() (*big.Int, error)) *MockIRPCClient_GetLatestBlockNumber_Call { +func (_c *MockIRPCClient_GetLatestBlockNumber_Call) RunAndReturn(run func(context.Context) (*big.Int, error)) *MockIRPCClient_GetLatestBlockNumber_Call { _c.Call.Return(run) return _c } -// GetTransactions provides a mock function with given fields: txHashes -func (_m *MockIRPCClient) GetTransactions(txHashes []string) []rpc.GetTransactionsResult { - ret := _m.Called(txHashes) +// GetTransactions provides a mock function with given fields: ctx, txHashes +func (_m *MockIRPCClient) GetTransactions(ctx context.Context, txHashes []string) []rpc.GetTransactionsResult { + ret := _m.Called(ctx, txHashes) if len(ret) == 0 { panic("no return value specified for GetTransactions") } var r0 []rpc.GetTransactionsResult - if rf, ok := ret.Get(0).(func([]string) []rpc.GetTransactionsResult); ok { - r0 = rf(txHashes) + if rf, ok := ret.Get(0).(func(context.Context, []string) []rpc.GetTransactionsResult); ok { + r0 = rf(ctx, txHashes) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]rpc.GetTransactionsResult) @@ -327,14 +332,15 @@ type MockIRPCClient_GetTransactions_Call struct { } // GetTransactions is a helper method to define mock.On call +// - ctx context.Context // - txHashes []string -func (_e *MockIRPCClient_Expecter) GetTransactions(txHashes interface{}) *MockIRPCClient_GetTransactions_Call { - return &MockIRPCClient_GetTransactions_Call{Call: _e.mock.On("GetTransactions", txHashes)} +func (_e *MockIRPCClient_Expecter) GetTransactions(ctx interface{}, txHashes interface{}) *MockIRPCClient_GetTransactions_Call { + return &MockIRPCClient_GetTransactions_Call{Call: _e.mock.On("GetTransactions", ctx, txHashes)} } -func (_c *MockIRPCClient_GetTransactions_Call) Run(run func(txHashes []string)) *MockIRPCClient_GetTransactions_Call { +func (_c *MockIRPCClient_GetTransactions_Call) Run(run func(ctx context.Context, txHashes []string)) *MockIRPCClient_GetTransactions_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]string)) + run(args[0].(context.Context), args[1].([]string)) }) return _c } @@ -344,7 +350,7 @@ func (_c *MockIRPCClient_GetTransactions_Call) Return(_a0 []rpc.GetTransactionsR return _c } -func (_c *MockIRPCClient_GetTransactions_Call) RunAndReturn(run func([]string) []rpc.GetTransactionsResult) *MockIRPCClient_GetTransactions_Call { +func (_c *MockIRPCClient_GetTransactions_Call) RunAndReturn(run func(context.Context, []string) []rpc.GetTransactionsResult) *MockIRPCClient_GetTransactions_Call { _c.Call.Return(run) return _c } @@ -394,9 +400,9 @@ func (_c *MockIRPCClient_GetURL_Call) RunAndReturn(run func() string) *MockIRPCC return _c } -// HasCode provides a mock function with given fields: address -func (_m *MockIRPCClient) HasCode(address string) (bool, error) { - ret := _m.Called(address) +// HasCode provides a mock function with given fields: ctx, address +func (_m *MockIRPCClient) HasCode(ctx context.Context, address string) (bool, error) { + ret := _m.Called(ctx, address) if len(ret) == 0 { panic("no return value specified for HasCode") @@ -404,17 +410,17 @@ func (_m *MockIRPCClient) HasCode(address string) (bool, error) { var r0 bool var r1 error - if rf, ok := ret.Get(0).(func(string) (bool, error)); ok { - return rf(address) + if rf, ok := ret.Get(0).(func(context.Context, string) (bool, error)); ok { + return rf(ctx, address) } - if rf, ok := ret.Get(0).(func(string) bool); ok { - r0 = rf(address) + if rf, ok := ret.Get(0).(func(context.Context, string) bool); ok { + r0 = rf(ctx, address) } else { r0 = ret.Get(0).(bool) } - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(address) + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, address) } else { r1 = ret.Error(1) } @@ -428,14 +434,15 @@ type MockIRPCClient_HasCode_Call struct { } // HasCode is a helper method to define mock.On call +// - ctx context.Context // - address string -func (_e *MockIRPCClient_Expecter) HasCode(address interface{}) *MockIRPCClient_HasCode_Call { - return &MockIRPCClient_HasCode_Call{Call: _e.mock.On("HasCode", address)} +func (_e *MockIRPCClient_Expecter) HasCode(ctx interface{}, address interface{}) *MockIRPCClient_HasCode_Call { + return &MockIRPCClient_HasCode_Call{Call: _e.mock.On("HasCode", ctx, address)} } -func (_c *MockIRPCClient_HasCode_Call) Run(run func(address string)) *MockIRPCClient_HasCode_Call { +func (_c *MockIRPCClient_HasCode_Call) Run(run func(ctx context.Context, address string)) *MockIRPCClient_HasCode_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string)) + run(args[0].(context.Context), args[1].(string)) }) return _c } @@ -445,7 +452,7 @@ func (_c *MockIRPCClient_HasCode_Call) Return(_a0 bool, _a1 error) *MockIRPCClie return _c } -func (_c *MockIRPCClient_HasCode_Call) RunAndReturn(run func(string) (bool, error)) *MockIRPCClient_HasCode_Call { +func (_c *MockIRPCClient_HasCode_Call) RunAndReturn(run func(context.Context, string) (bool, error)) *MockIRPCClient_HasCode_Call { _c.Call.Return(run) return _c }