From ef5422aed30ddd582c17b7b6966375ade5b4665f Mon Sep 17 00:00:00 2001 From: milen <94537774+taratorio@users.noreply.github.com> Date: Wed, 9 Oct 2024 14:36:41 +0100 Subject: [PATCH] polygon/sync: initialise canonical chain builder correctly (#12246) This PR captures several todos on my list of improving things related to Milestone hash mismatches that I noticed while testing Astrid on the tip: - after restarts, when we initialise the canonical chain builder, we need to set the root to the last finalised waypoint header and connect the latest known tip to it. Previously, we set the root to the latest known tip which may not be the finalised one. This caused unnecessary unwinds on startup (with potentially incorrect unwind num for the bridge.Unwind) and although we auto recover from it it is still better to fix it properly as it may cause unexpected behaviour further down the line - harden the milestone verifier to check for correct header connectivity via ParentHash and Number and also that the number of headers match with the number of headers in the Milestone because the Milestone is simply just the last header hash so a malicious peer may trick us easily - revert https://github.com/erigontech/erigon/pull/11929 which, after taking a deeper look, seems to be an inaccurate interpretation of the error (prune of TD happens only for the blocks T-150_000). I believe this error may have been caused by something else, which may have been fixed by now (after numerous fixes to Astrid related to the parent td errors) - but if not it will be better to troubleshoot more thoroughly if the error ever arises again. In addition, this change is dangerous as it means we may falsely think we have all blocks for a given waypoint inserted in the DB, however that may not be true on restarts in the middle of waypoints, because the blocks inserted before the start may not be the right ones that are part of the waypoint (due to the tip not yet being finalised) --- polygon/heimdall/scraper.go | 24 ++- polygon/heimdall/service.go | 10 +- polygon/heimdall/service_test.go | 6 +- polygon/sync/block_downloader.go | 14 +- polygon/sync/block_downloader_test.go | 38 +---- polygon/sync/sync.go | 145 ++++++++++++------ polygon/sync/waypoint_headers_verifier.go | 27 ++++ .../sync/waypoint_headers_verifier_test.go | 53 +++++-- 8 files changed, 206 insertions(+), 111 deletions(-) diff --git a/polygon/heimdall/scraper.go b/polygon/heimdall/scraper.go index 48a765bc565..77b20a85506 100644 --- a/polygon/heimdall/scraper.go +++ b/polygon/heimdall/scraper.go @@ -18,9 +18,11 @@ package heimdall import ( "context" + "errors" "time" - "github.com/erigontech/erigon-lib/common/errors" + commonerrors "github.com/erigontech/erigon-lib/common/errors" + "github.com/erigontech/erigon-lib/common/generics" "github.com/erigontech/erigon-lib/log/v3" libcommon "github.com/erigontech/erigon-lib/common" @@ -69,7 +71,7 @@ func (s *scraper[TEntity]) Run(ctx context.Context) error { idRange, err := s.fetcher.FetchEntityIdRange(ctx) if err != nil { - if errors.IsOneOf(err, s.transientErrors) { + if commonerrors.IsOneOf(err, s.transientErrors) { s.logger.Warn(heimdallLogPrefix("scraper transient err occurred when fetching id range"), "err", err) continue } @@ -90,7 +92,7 @@ func (s *scraper[TEntity]) Run(ctx context.Context) error { } else { entities, err := s.fetcher.FetchEntitiesRange(ctx, idRange) if err != nil { - if errors.IsOneOf(err, s.transientErrors) { + if commonerrors.IsOneOf(err, s.transientErrors) { // we do not break the scrapping loop when hitting a transient error // we persist the partially fetched range entities before it occurred // and continue scrapping again from there onwards @@ -122,6 +124,18 @@ func (s *scraper[TEntity]) RegisterObserver(observer func([]TEntity)) polygoncom return s.observers.Register(observer) } -func (s *scraper[TEntity]) Synchronize(ctx context.Context) error { - return s.syncEvent.Wait(ctx) +func (s *scraper[TEntity]) Synchronize(ctx context.Context) (TEntity, error) { + if err := s.syncEvent.Wait(ctx); err != nil { + return generics.Zero[TEntity](), err + } + + last, ok, err := s.store.LastEntity(ctx) + if err != nil { + return generics.Zero[TEntity](), err + } + if !ok { + return generics.Zero[TEntity](), errors.New("unexpected last entity not available") + } + + return last, nil } diff --git a/polygon/heimdall/service.go b/polygon/heimdall/service.go index f724e6e441c..060a6d499dd 100644 --- a/polygon/heimdall/service.go +++ b/polygon/heimdall/service.go @@ -46,8 +46,8 @@ type Service interface { Producers(ctx context.Context, blockNum uint64) (*valset.ValidatorSet, error) RegisterMilestoneObserver(callback func(*Milestone), opts ...ObserverOption) polygoncommon.UnregisterFunc Run(ctx context.Context) error - SynchronizeCheckpoints(ctx context.Context) error - SynchronizeMilestones(ctx context.Context) error + SynchronizeCheckpoints(ctx context.Context) (latest *Checkpoint, err error) + SynchronizeMilestones(ctx context.Context) (latest *Milestone, err error) SynchronizeSpans(ctx context.Context, blockNum uint64) error } @@ -182,12 +182,12 @@ func (s *service) Span(ctx context.Context, id uint64) (*Span, bool, error) { return s.reader.Span(ctx, id) } -func (s *service) SynchronizeCheckpoints(ctx context.Context) error { +func (s *service) SynchronizeCheckpoints(ctx context.Context) (*Checkpoint, error) { s.logger.Debug(heimdallLogPrefix("synchronizing checkpoints...")) return s.checkpointScraper.Synchronize(ctx) } -func (s *service) SynchronizeMilestones(ctx context.Context) error { +func (s *service) SynchronizeMilestones(ctx context.Context) (*Milestone, error) { s.logger.Debug(heimdallLogPrefix("synchronizing milestones...")) return s.milestoneScraper.Synchronize(ctx) } @@ -219,7 +219,7 @@ func (s *service) SynchronizeSpans(ctx context.Context, blockNum uint64) error { } func (s *service) synchronizeSpans(ctx context.Context) error { - if err := s.spanScraper.Synchronize(ctx); err != nil { + if _, err := s.spanScraper.Synchronize(ctx); err != nil { return err } diff --git a/polygon/heimdall/service_test.go b/polygon/heimdall/service_test.go index d8774154cf4..376f6eddd80 100644 --- a/polygon/heimdall/service_test.go +++ b/polygon/heimdall/service_test.go @@ -179,10 +179,12 @@ func (suite *ServiceTestSuite) SetupSuite() { return suite.service.Run(suite.ctx) }) - err = suite.service.SynchronizeMilestones(suite.ctx) + lastMilestone, err := suite.service.SynchronizeMilestones(suite.ctx) require.NoError(suite.T(), err) - err = suite.service.SynchronizeCheckpoints(suite.ctx) + require.Equal(suite.T(), suite.expectedLastMilestone, uint64(lastMilestone.Id)) + lastCheckpoint, err := suite.service.SynchronizeCheckpoints(suite.ctx) require.NoError(suite.T(), err) + require.Equal(suite.T(), suite.expectedLastCheckpoint, uint64(lastCheckpoint.Id)) err = suite.service.SynchronizeSpans(suite.ctx, math.MaxInt) require.NoError(suite.T(), err) } diff --git a/polygon/sync/block_downloader.go b/polygon/sync/block_downloader.go index 4b9ebf2b540..624b09e184a 100644 --- a/polygon/sync/block_downloader.go +++ b/polygon/sync/block_downloader.go @@ -118,7 +118,7 @@ func (d *blockDownloader) DownloadBlocksUsingCheckpoints(ctx context.Context, st return nil, err } - return d.downloadBlocksUsingWaypoints(ctx, checkpoints.Waypoints(), d.checkpointVerifier, start) + return d.downloadBlocksUsingWaypoints(ctx, checkpoints.Waypoints(), d.checkpointVerifier) } func (d *blockDownloader) DownloadBlocksUsingMilestones(ctx context.Context, start uint64) (*types.Header, error) { @@ -147,14 +147,13 @@ func (d *blockDownloader) DownloadBlocksUsingMilestones(ctx context.Context, sta milestones[0].Fields.StartBlock = new(big.Int).SetUint64(start) } - return d.downloadBlocksUsingWaypoints(ctx, milestones.Waypoints(), d.milestoneVerifier, start) + return d.downloadBlocksUsingWaypoints(ctx, milestones.Waypoints(), d.milestoneVerifier) } func (d *blockDownloader) downloadBlocksUsingWaypoints( ctx context.Context, waypoints heimdall.Waypoints, verifier WaypointHeadersVerifier, - startBlockNum uint64, ) (*types.Header, error) { if len(waypoints) == 0 { return nil, nil @@ -289,12 +288,9 @@ func (d *blockDownloader) downloadBlocksUsingWaypoints( break } - batchStart := blockBatch[0].Number().Uint64() - batchEnd := blockBatch[len(blockBatch)-1].Number().Uint64() - if batchStart <= startBlockNum && startBlockNum <= batchEnd { - // we do not want to re-insert blocks of the first waypoint if the start block - // falls in the middle of the waypoint range - blockBatch = blockBatch[startBlockNum-batchStart:] + if blockBatch[0].Number().Uint64() == 0 { + // we do not want to insert block 0 (genesis) + blockBatch = blockBatch[1:] } blocks = append(blocks, blockBatch...) diff --git a/polygon/sync/block_downloader_test.go b/polygon/sync/block_downloader_test.go index 83cb655f102..b5763e10892 100644 --- a/polygon/sync/block_downloader_test.go +++ b/polygon/sync/block_downloader_test.go @@ -29,8 +29,9 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" - "github.com/erigontech/erigon-lib/common" "github.com/erigontech/erigon-lib/log/v3" + + "github.com/erigontech/erigon-lib/common" "github.com/erigontech/erigon/core/types" "github.com/erigontech/erigon/polygon/heimdall" "github.com/erigontech/erigon/polygon/p2p" @@ -368,41 +369,6 @@ func TestBlockDownloaderDownloadBlocksUsingCheckpoints(t *testing.T) { require.Equal(t, blocks[len(blocks)-1].Header(), tip) } -func TestBlockDownloaderDownloadBlocksUsingCheckpointsWhenStartIsInMiddleOfCheckpointRange(t *testing.T) { - test := newBlockDownloaderTest(t) - test.waypointReader.EXPECT(). - CheckpointsFromBlock(gomock.Any(), gomock.Any()). - Return(test.fakeCheckpoints(2), nil). - Times(1) - test.p2pService.EXPECT(). - ListPeersMayHaveBlockNum(gomock.Any()). - Return(test.fakePeers(2)). - Times(1) - test.p2pService.EXPECT(). - FetchHeaders(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(test.defaultFetchHeadersMock()). - Times(2) - test.p2pService.EXPECT(). - FetchBodies(gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(test.defaultFetchBodiesMock()). - Times(2) - var blocks []*types.Block - test.store.EXPECT(). - InsertBlocks(gomock.Any(), gomock.Any()). - DoAndReturn(test.defaultInsertBlocksMock(&blocks)). - Times(1) - - tip, err := test.blockDownloader.DownloadBlocksUsingCheckpoints(context.Background(), 513) - require.NoError(t, err) - require.Len(t, blocks, 1536) // [513,1024] = 512 blocks + 1024 blocks from 2nd checkpoint - // check blocks are written in order - require.Equal(t, uint64(513), blocks[0].Header().Number.Uint64()) - require.Equal(t, uint64(1024), blocks[511].Header().Number.Uint64()) - require.Equal(t, uint64(1025), blocks[512].Header().Number.Uint64()) - require.Equal(t, uint64(2048), blocks[1535].Header().Number.Uint64()) - require.Equal(t, blocks[len(blocks)-1].Header(), tip) -} - func TestBlockDownloaderDownloadBlocksWhenInvalidHeadersThenPenalizePeerAndReDownload(t *testing.T) { var firstTimeInvalidReturned bool firstTimeInvalidReturnedPtr := &firstTimeInvalidReturned diff --git a/polygon/sync/sync.go b/polygon/sync/sync.go index 38ce264b848..20c53324cfa 100644 --- a/polygon/sync/sync.go +++ b/polygon/sync/sync.go @@ -24,14 +24,14 @@ import ( "github.com/erigontech/erigon-lib/common" "github.com/erigontech/erigon-lib/log/v3" - "github.com/erigontech/erigon/core/types" + "github.com/erigontech/erigon/polygon/heimdall" "github.com/erigontech/erigon/polygon/p2p" ) type heimdallSynchronizer interface { - SynchronizeCheckpoints(ctx context.Context) error - SynchronizeMilestones(ctx context.Context) error + SynchronizeCheckpoints(ctx context.Context) (latest *heimdall.Checkpoint, err error) + SynchronizeMilestones(ctx context.Context) (latest *heimdall.Milestone, err error) SynchronizeSpans(ctx context.Context, blockNum uint64) error } @@ -119,31 +119,30 @@ func (s *Sync) handleMilestoneTipMismatch( // the milestone doesn't correspond to the tip of the chain // unwind to the previous verified milestone // and download the blocks of the new milestone - oldTip := ccBuilder.Root() - oldTipNum := oldTip.Number.Uint64() + rootNum := ccBuilder.Root().Number.Uint64() s.logger.Debug( - syncLogPrefix("local chain tip does not match the milestone, unwinding to the previous verified milestone"), - "oldTipNum", oldTipNum, + syncLogPrefix("local chain tip does not match the milestone, unwinding to the previous verified root"), + "rootNum", rootNum, "milestoneId", milestone.Id, "milestoneStart", milestone.StartBlock(), "milestoneEnd", milestone.EndBlock(), "milestoneRootHash", milestone.RootHash(), ) - if err := s.bridgeSync.Unwind(ctx, oldTipNum); err != nil { + if err := s.bridgeSync.Unwind(ctx, rootNum); err != nil { return err } - newTip, err := s.blockDownloader.DownloadBlocksUsingMilestones(ctx, oldTipNum) + newTip, err := s.blockDownloader.DownloadBlocksUsingMilestones(ctx, rootNum+1) if err != nil { return err } if newTip == nil { err = errors.New("unexpected empty headers from p2p since new milestone") return fmt.Errorf( - "%w: oldTipNum=%d, milestoneId=%d, milestoneStart=%d, milestoneEnd=%d, milestoneRootHash=%s", - err, oldTipNum, milestone.Id, milestone.StartBlock(), milestone.EndBlock(), milestone.RootHash(), + "%w: rootNum=%d, milestoneId=%d, milestoneStart=%d, milestoneEnd=%d, milestoneRootHash=%s", + err, rootNum, milestone.Id, milestone.StartBlock(), milestone.EndBlock(), milestone.RootHash(), ) } @@ -173,13 +172,9 @@ func (s *Sync) applyNewMilestoneOnTip( ) milestoneHeaders := ccBuilder.HeadersInRange(milestone.StartBlock().Uint64(), milestone.Length()) - err := s.milestoneVerifier(milestone, milestoneHeaders) - if errors.Is(err, ErrBadHeadersRootHash) { + if err := s.milestoneVerifier(milestone, milestoneHeaders); err != nil { return s.handleMilestoneTipMismatch(ctx, ccBuilder, milestone) } - if err != nil { - return err - } return ccBuilder.Prune(milestone.EndBlock().Uint64()) } @@ -347,12 +342,15 @@ func (s *Sync) applyNewBlockHashesOnTip( func (s *Sync) Run(ctx context.Context) error { s.logger.Debug(syncLogPrefix("running sync component")) - tip, err := s.syncToTip(ctx) + result, err := s.syncToTip(ctx) if err != nil { return err } - ccBuilder := s.ccBuilderFactory(tip) + ccBuilder, err := s.initialiseCcb(ctx, result) + if err != nil { + return err + } for { select { @@ -377,24 +375,68 @@ func (s *Sync) Run(ctx context.Context) error { } } -func (s *Sync) syncToTip(ctx context.Context) (*types.Header, error) { - startTime := time.Now() - start, err := s.execution.CurrentHeader(ctx) +// initialiseCcb populates the canonical chain builder with the latest finalized root header and with latest known +// canonical chain tip. +func (s *Sync) initialiseCcb(ctx context.Context, result syncToTipResult) (CanonicalChainBuilder, error) { + tip := result.latestTip + tipNum := tip.Number.Uint64() + rootNum := result.latestWaypoint.EndBlock().Uint64() + if rootNum > tipNum { + return nil, fmt.Errorf("unexpected rootNum > tipNum: %d > %d", rootNum, tipNum) + } + + s.logger.Debug(syncLogPrefix("initialising canonical chain builder"), "rootNum", rootNum, "tipNum", tipNum) + + var root *types.Header + var err error + if rootNum == tipNum { + root = tip + } else { + root, err = s.execution.GetHeader(ctx, rootNum) + } if err != nil { return nil, err } - tip, err := s.syncToTipUsingCheckpoints(ctx, start) + ccb := s.ccBuilderFactory(root) + for blockNum := rootNum + 1; blockNum <= tipNum; blockNum++ { + header, err := s.execution.GetHeader(ctx, blockNum) + if err != nil { + return nil, err + } + + _, err = ccb.Connect(ctx, []*types.Header{header}) + if err != nil { + return nil, err + } + } + + return ccb, nil +} + +type syncToTipResult struct { + latestTip *types.Header + latestWaypoint heimdall.Waypoint +} + +func (s *Sync) syncToTip(ctx context.Context) (syncToTipResult, error) { + startTime := time.Now() + latestTipOnStart, err := s.execution.CurrentHeader(ctx) if err != nil { - return nil, err + return syncToTipResult{}, err } - tip, err = s.syncToTipUsingMilestones(ctx, tip) + result, err := s.syncToTipUsingCheckpoints(ctx, latestTipOnStart) if err != nil { - return nil, err + return syncToTipResult{}, err + } + + result, err = s.syncToTipUsingMilestones(ctx, result.latestTip) + if err != nil { + return syncToTipResult{}, err } - blocks := tip.Number.Uint64() - start.Number.Uint64() + blocks := result.latestTip.Number.Uint64() - latestTipOnStart.Number.Uint64() s.logger.Info( syncLogPrefix("sync to tip finished"), "time", common.PrettyAge(startTime), @@ -402,52 +444,65 @@ func (s *Sync) syncToTip(ctx context.Context) (*types.Header, error) { "blk/sec", uint64(float64(blocks)/time.Since(startTime).Seconds()), ) - return tip, nil + return result, nil } -func (s *Sync) syncToTipUsingCheckpoints(ctx context.Context, tip *types.Header) (*types.Header, error) { - return s.sync(ctx, tip, func(ctx context.Context, startBlockNum uint64) (*types.Header, error) { - err := s.heimdallSync.SynchronizeCheckpoints(ctx) +func (s *Sync) syncToTipUsingCheckpoints(ctx context.Context, tip *types.Header) (syncToTipResult, error) { + return s.sync(ctx, tip, func(ctx context.Context, startBlockNum uint64) (syncToTipResult, error) { + latestCheckpoint, err := s.heimdallSync.SynchronizeCheckpoints(ctx) if err != nil { - return nil, err + return syncToTipResult{}, err + } + + tip, err := s.blockDownloader.DownloadBlocksUsingCheckpoints(ctx, startBlockNum) + if err != nil { + return syncToTipResult{}, err } - return s.blockDownloader.DownloadBlocksUsingCheckpoints(ctx, startBlockNum) + return syncToTipResult{latestTip: tip, latestWaypoint: latestCheckpoint}, nil }) } -func (s *Sync) syncToTipUsingMilestones(ctx context.Context, tip *types.Header) (*types.Header, error) { - return s.sync(ctx, tip, func(ctx context.Context, startBlockNum uint64) (*types.Header, error) { - err := s.heimdallSync.SynchronizeMilestones(ctx) +func (s *Sync) syncToTipUsingMilestones(ctx context.Context, tip *types.Header) (syncToTipResult, error) { + return s.sync(ctx, tip, func(ctx context.Context, startBlockNum uint64) (syncToTipResult, error) { + latestMilestone, err := s.heimdallSync.SynchronizeMilestones(ctx) if err != nil { - return nil, err + return syncToTipResult{}, err } - return s.blockDownloader.DownloadBlocksUsingMilestones(ctx, startBlockNum) + tip, err := s.blockDownloader.DownloadBlocksUsingMilestones(ctx, startBlockNum) + if err != nil { + return syncToTipResult{}, err + } + + return syncToTipResult{latestTip: tip, latestWaypoint: latestMilestone}, nil }) } -type tipDownloaderFunc func(ctx context.Context, startBlockNum uint64) (*types.Header, error) +type tipDownloaderFunc func(ctx context.Context, startBlockNum uint64) (syncToTipResult, error) -func (s *Sync) sync(ctx context.Context, tip *types.Header, tipDownloader tipDownloaderFunc) (*types.Header, error) { +func (s *Sync) sync(ctx context.Context, tip *types.Header, tipDownloader tipDownloaderFunc) (syncToTipResult, error) { + var latestWaypoint heimdall.Waypoint for { - newTip, err := tipDownloader(ctx, tip.Number.Uint64()+1) + newResult, err := tipDownloader(ctx, tip.Number.Uint64()+1) if err != nil { - return nil, err + return syncToTipResult{}, err } - if newTip == nil { + latestWaypoint = newResult.latestWaypoint + + if newResult.latestTip == nil { // we've reached the tip break } - tip = newTip + tip = newResult.latestTip if err = s.commitExecution(ctx, tip, tip); err != nil { - return nil, err + return syncToTipResult{}, err } } - return tip, nil + return syncToTipResult{latestTip: tip, latestWaypoint: latestWaypoint}, nil } func (s *Sync) ignoreFetchBlocksErrOnTipEvent(err error) bool { diff --git a/polygon/sync/waypoint_headers_verifier.go b/polygon/sync/waypoint_headers_verifier.go index ef999ea8af0..6014b19c839 100644 --- a/polygon/sync/waypoint_headers_verifier.go +++ b/polygon/sync/waypoint_headers_verifier.go @@ -30,6 +30,8 @@ import ( var ( ErrFailedToComputeHeadersRootHash = errors.New("failed to compute headers root hash") ErrBadHeadersRootHash = errors.New("bad headers root hash") + ErrIncorrectHeadersLength = errors.New("incorrect headers length") + ErrDisconnectedHeaders = errors.New("disconnected headers") ) type WaypointHeadersVerifier func(waypoint heimdall.Waypoint, headers []*types.Header) error @@ -39,19 +41,44 @@ func VerifyCheckpointHeaders(waypoint heimdall.Waypoint, headers []*types.Header if err != nil { return fmt.Errorf("VerifyCheckpointHeaders: %w: %w", ErrFailedToComputeHeadersRootHash, err) } + if !bytes.Equal(rootHash, waypoint.RootHash().Bytes()) { return fmt.Errorf("VerifyCheckpointHeaders: %w", ErrBadHeadersRootHash) } + return nil } func VerifyMilestoneHeaders(waypoint heimdall.Waypoint, headers []*types.Header) error { + if uint64(len(headers)) != waypoint.Length() || len(headers) == 0 { + return fmt.Errorf( + "VerifyMilestoneHeaders: %w: headers=%d, waypoint=%d", + ErrIncorrectHeadersLength, len(headers), waypoint.Length(), + ) + } + + prevHeader := headers[0] + for _, header := range headers[1:] { + prevNum, prevHash := prevHeader.Number.Uint64(), prevHeader.Hash() + num, hash, parentHash := header.Number.Uint64(), header.Hash(), header.ParentHash + if num != prevNum+1 || parentHash != prevHash { + return fmt.Errorf( + "VerifyMilestoneHeaders: %w: prevNum=%d, prevHash=%s, num=%d, parentHash=%s, hash=%s", + ErrDisconnectedHeaders, prevNum, prevHash, num, parentHash, hash, + ) + } + + prevHeader = header + } + var hash common.Hash if len(headers) > 0 { hash = headers[len(headers)-1].Hash() } + if hash != waypoint.RootHash() { return fmt.Errorf("VerifyMilestoneHeaders: %w", ErrBadHeadersRootHash) } + return nil } diff --git a/polygon/sync/waypoint_headers_verifier_test.go b/polygon/sync/waypoint_headers_verifier_test.go index d6599464b88..9356f4a5f88 100644 --- a/polygon/sync/waypoint_headers_verifier_test.go +++ b/polygon/sync/waypoint_headers_verifier_test.go @@ -56,24 +56,59 @@ func TestVerifyCheckpointHeaders(t *testing.T) { } func TestVerifyMilestoneHeaders(t *testing.T) { - header := &types.Header{ - Root: common.HexToHash("0x01"), + header1 := &types.Header{ + Number: big.NewInt(1), + GasLimit: 123, + Root: common.HexToHash("0x01"), + } + header2 := &types.Header{ + Number: big.NewInt(2), + GasLimit: 456, + Root: common.HexToHash("0x02"), + ParentHash: header1.Hash(), } milestone := &heimdall.Milestone{ Fields: heimdall.WaypointFields{ - RootHash: header.Hash(), + RootHash: header2.Hash(), + StartBlock: big.NewInt(1), + EndBlock: big.NewInt(2), }, } - err := VerifyMilestoneHeaders(milestone, []*types.Header{header}) + err := VerifyMilestoneHeaders(milestone, []*types.Header{header1, header2}) require.NoError(t, err) - diffHeader := &types.Header{ - Root: common.HexToHash("0x02"), + header2DiffHash := &types.Header{ + Number: big.NewInt(2), + GasLimit: 999, + Root: common.HexToHash("0x02-diff"), + ParentHash: header1.Hash(), } - err = VerifyMilestoneHeaders(milestone, []*types.Header{diffHeader}) + err = VerifyMilestoneHeaders(milestone, []*types.Header{header1, header2DiffHash}) require.ErrorIs(t, err, ErrBadHeadersRootHash) - err = VerifyMilestoneHeaders(milestone, []*types.Header{}) - require.ErrorIs(t, err, ErrBadHeadersRootHash) + header3DisconnectedNums := &types.Header{ + Number: big.NewInt(3), + GasLimit: 456, + Root: common.HexToHash("0x02"), + ParentHash: header1.Hash(), + } + err = VerifyMilestoneHeaders(milestone, []*types.Header{header1, header3DisconnectedNums}) + require.ErrorIs(t, err, ErrDisconnectedHeaders) + + header0 := types.Header{Number: big.NewInt(0)} + header3DisconnectedHashes := &types.Header{ + Number: big.NewInt(2), + GasLimit: 456, + Root: common.HexToHash("0x02"), + ParentHash: header0.Hash(), + } + err = VerifyMilestoneHeaders(milestone, []*types.Header{header1, header3DisconnectedHashes}) + require.ErrorIs(t, err, ErrDisconnectedHeaders) + + err = VerifyMilestoneHeaders(milestone, []*types.Header{header1}) + require.ErrorIs(t, err, ErrIncorrectHeadersLength) + + err = VerifyMilestoneHeaders(milestone, nil) + require.ErrorIs(t, err, ErrIncorrectHeadersLength) }