diff --git a/lib/grandpa/message_tracker.go b/lib/grandpa/message_tracker.go index 474a175b6d..26bb1b42c2 100644 --- a/lib/grandpa/message_tracker.go +++ b/lib/grandpa/message_tracker.go @@ -129,8 +129,9 @@ func (t *tracker) handleTick() { t.mapLock.Lock() defer t.mapLock.Unlock() - var blockHashesDone []common.Hash - t.votes.forEach(func(peerID peer.ID, message *VoteMessage) { + for _, networkVoteMessage := range t.votes.networkVoteMessages() { + peerID := networkVoteMessage.from + message := networkVoteMessage.msg _, err := t.handler.handleMessage(peerID, message) if err != nil { // handleMessage would never error for vote message @@ -138,11 +139,8 @@ func (t *tracker) handleTick() { } if message.Round < t.handler.grandpa.state.round && message.SetID == t.handler.grandpa.state.setID { - blockHashesDone = append(blockHashesDone, message.Message.BlockHash) + t.votes.delete(message.Message.BlockHash) } - }) - for _, blockHashDone := range blockHashesDone { - t.votes.delete(blockHashDone) } for _, cm := range t.commitMessages { diff --git a/lib/grandpa/votes_tracker.go b/lib/grandpa/votes_tracker.go index de380ac7bc..a90ef0968e 100644 --- a/lib/grandpa/votes_tracker.go +++ b/lib/grandpa/votes_tracker.go @@ -158,13 +158,20 @@ func (vt *votesTracker) messages(blockHash common.Hash) ( return messages } -// forEach runs the function `f` on each -// peer id + message stored in the tracker. -func (vt *votesTracker) forEach( - f func(peerID peer.ID, message *VoteMessage)) { +// networkVoteMessages returns all pairs of +// peer id + message stored in the tracker +// as a slice of networkVoteMessages. +func (vt *votesTracker) networkVoteMessages() ( + messages []networkVoteMessage) { + messages = make([]networkVoteMessage, 0, vt.linkedList.Len()) for _, authorityIDToData := range vt.mapping { for _, data := range authorityIDToData { - f(data.peerID, data.message) + message := networkVoteMessage{ + from: data.peerID, + msg: data.message, + } + messages = append(messages, message) } } + return messages } diff --git a/lib/grandpa/votes_tracker_test.go b/lib/grandpa/votes_tracker_test.go index 086cacb492..eb1e2e4734 100644 --- a/lib/grandpa/votes_tracker_test.go +++ b/lib/grandpa/votes_tracker_test.go @@ -4,7 +4,6 @@ package grandpa import ( - "bytes" "container/list" "sort" "testing" @@ -320,7 +319,7 @@ func Test_votesTracker_messages(t *testing.T) { } } -func Test_votesTracker_forEach(t *testing.T) { +func Test_votesTracker_networkVoteMessages(t *testing.T) { t.Parallel() const capacity = 10 @@ -340,40 +339,13 @@ func Test_votesTracker_forEach(t *testing.T) { vt.add("b", messageBlockAAuthB) vt.add("b", messageBlockBAuthA) - type result struct { - peerID peer.ID - message *VoteMessage - } - var results []result - - vt.forEach(func(peerID peer.ID, message *VoteMessage) { - results = append(results, result{ - peerID: peerID, - message: message, - }) - }) - - // Predictable messages order for assertion. - // Sort by block hash then authority id then peer ID. - sort.Slice(results, func(i, j int) bool { - blockHashFirst := results[i].message.Message.BlockHash - blockHashSecond := results[j].message.Message.BlockHash - if blockHashFirst == blockHashSecond { - authIDFirst := results[i].message.Message.AuthorityID - authIDSecond := results[j].message.Message.AuthorityID - if authIDFirst == authIDSecond { - return results[i].peerID < results[j].peerID - } - return bytes.Compare(authIDFirst[:], authIDSecond[:]) < 0 - } - return bytes.Compare(blockHashFirst[:], blockHashSecond[:]) < 0 - }) + networkVoteMessages := vt.networkVoteMessages() - expectedResults := []result{ - {peerID: "a", message: messageBlockAAuthA}, - {peerID: "b", message: messageBlockAAuthB}, - {peerID: "b", message: messageBlockBAuthA}, + expectedNetworkVoteMessages := []networkVoteMessage{ + {from: "a", msg: messageBlockAAuthA}, + {from: "b", msg: messageBlockAAuthB}, + {from: "b", msg: messageBlockBAuthA}, } - assert.Equal(t, expectedResults, results) + assert.ElementsMatch(t, expectedNetworkVoteMessages, networkVoteMessages) }