Skip to content

Commit

Permalink
fix(state): node hashes vs merkle values
Browse files Browse the repository at this point in the history
- Pruners only care about node hashes
- Rename variables, functions and methods only dealing with node hashes
- Do not write or read inlined nodes with a non-hash Merkle value
- Clarify error wrappings and comments
  • Loading branch information
qdm12 committed Jan 26, 2023
1 parent 998c88a commit acf44f4
Show file tree
Hide file tree
Showing 11 changed files with 110 additions and 88 deletions.
10 changes: 5 additions & 5 deletions dot/state/offline_pruner.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func (p *OfflinePruner) SetBloomFilter() (err error) {
}

latestBlockNum := header.Number
merkleValues := make(map[string]struct{})
nodeHashes := make(map[common.Hash]struct{})

logger.Infof("Latest block number is %d", latestBlockNum)

Expand All @@ -121,7 +121,7 @@ func (p *OfflinePruner) SetBloomFilter() (err error) {
return err
}

trie.PopulateNodeHashes(tr.RootNode(), merkleValues)
trie.PopulateNodeHashes(tr.RootNode(), nodeHashes)

// get parent header of current block
header, err = p.blockState.GetHeader(header.ParentHash)
Expand All @@ -131,14 +131,14 @@ func (p *OfflinePruner) SetBloomFilter() (err error) {
blockNum = header.Number
}

for key := range merkleValues {
err = p.bloom.put([]byte(key))
for key := range nodeHashes {
err = p.bloom.put(key.ToBytes())
if err != nil {
return err
}
}

logger.Infof("Total keys added in bloom filter: %d", len(merkleValues))
logger.Infof("Total keys added in bloom filter: %d", len(nodeHashes))
return nil
}

Expand Down
4 changes: 2 additions & 2 deletions dot/state/pruner/pruner.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ type Config struct {

// Pruner is implemented by FullNode and ArchiveNode.
type Pruner interface {
StoreJournalRecord(deletedMerkleValues, insertedMerkleValues map[string]struct{},
StoreJournalRecord(deletedNodeHashes, insertedNodeHashes map[common.Hash]struct{},
blockHash common.Hash, blockNum int64) error
}

// ArchiveNode is a no-op since we don't prune nodes in archive mode.
type ArchiveNode struct{}

// StoreJournalRecord for archive node doesn't do anything.
func (*ArchiveNode) StoreJournalRecord(_, _ map[string]struct{},
func (*ArchiveNode) StoreJournalRecord(_, _ map[common.Hash]struct{},
_ common.Hash, _ int64) error {
return nil
}
6 changes: 3 additions & 3 deletions dot/state/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ func (s *StorageState) StoreTrie(ts *rtstorage.TrieState, header *types.Header)
s.tries.softSet(root, ts.Trie())

if header != nil {
insertedMerkleValues, deletedMerkleValues, err := ts.GetChangedNodeHashes()
insertedNodeHashes, deletedNodeHashes, err := ts.GetChangedNodeHashes()
if err != nil {
return fmt.Errorf("failed to get state trie inserted keys: block %s %w", header.Hash(), err)
return fmt.Errorf("getting trie changed node hashes for block hash %s: %w", header.Hash(), err)
}

err = s.pruner.StoreJournalRecord(deletedMerkleValues, insertedMerkleValues, header.Hash(), int64(header.Number))
err = s.pruner.StoreJournalRecord(deletedNodeHashes, insertedNodeHashes, header.Hash(), int64(header.Number))
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion lib/runtime/storage/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ func (s *TrieState) LoadCodeHash() (common.Hash, error) {

// GetChangedNodeHashes returns the two sets of hashes for all nodes
// inserted and deleted in the state trie since the last block produced (trie snapshot).
func (s *TrieState) GetChangedNodeHashes() (inserted, deleted map[string]struct{}, err error) {
func (s *TrieState) GetChangedNodeHashes() (inserted, deleted map[common.Hash]struct{}, err error) {
s.lock.RLock()
defer s.lock.RUnlock()
return s.t.GetChangedNodeHashes()
Expand Down
70 changes: 41 additions & 29 deletions lib/trie/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,33 +68,34 @@ func (t *Trie) loadNode(db Getter, n *Node) error {

merkleValue := child.MerkleValue

if len(merkleValue) == 0 {
if len(merkleValue) < 32 {
// node has already been loaded inline
// just set encoding + hash digest
// just set its encoding
_, err := child.CalculateMerkleValue()
if err != nil {
return fmt.Errorf("merkle value: %w", err)
}
continue
}

encodedNode, err := db.Get(merkleValue)
nodeHash := merkleValue
encodedNode, err := db.Get(nodeHash)
if err != nil {
return fmt.Errorf("cannot find child node key 0x%x in database: %w", merkleValue, err)
return fmt.Errorf("cannot find child node key 0x%x in database: %w", nodeHash, err)
}

reader := bytes.NewReader(encodedNode)
decodedNode, err := node.Decode(reader)
if err != nil {
return fmt.Errorf("decoding node with Merkle value 0x%x: %w", merkleValue, err)
return fmt.Errorf("decoding node with hash 0x%x: %w", nodeHash, err)
}

decodedNode.MerkleValue = merkleValue
decodedNode.MerkleValue = nodeHash
branch.Children[i] = decodedNode

err = t.loadNode(db, decodedNode)
if err != nil {
return fmt.Errorf("loading child at index %d with Merkle value 0x%x: %w", i, merkleValue, err)
return fmt.Errorf("loading child at index %d with node hash 0x%x: %w", i, nodeHash, err)
}

if decodedNode.Kind() == node.Branch {
Expand Down Expand Up @@ -132,7 +133,7 @@ func (t *Trie) loadNode(db Getter, n *Node) error {
// all its descendant nodes as keys to the nodeHashes map.
// It is assumed the node and its descendant nodes have their Merkle value already
// computed.
func PopulateNodeHashes(n *Node, nodeHashes map[string]struct{}) {
func PopulateNodeHashes(n *Node, nodeHashes map[common.Hash]struct{}) {
if n == nil {
return
}
Expand All @@ -148,7 +149,8 @@ func PopulateNodeHashes(n *Node, nodeHashes map[string]struct{}) {
return
}

nodeHashes[string(n.MerkleValue)] = struct{}{}
nodeHash := common.NewHash(n.MerkleValue)
nodeHashes[nodeHash] = struct{}{}

if n.Kind() == node.Leaf {
return
Expand Down Expand Up @@ -260,15 +262,15 @@ func getFromDBAtNode(db Getter, n *Node, key []byte) (
encodedChild, err := db.Get(childMerkleValue)
if err != nil {
return nil, fmt.Errorf(
"finding child node with Merkle value 0x%x in database: %w",
"finding child node with hash 0x%x in database: %w",
childMerkleValue, err)
}

reader := bytes.NewReader(encodedChild)
decodedChild, err := node.Decode(reader)
if err != nil {
return nil, fmt.Errorf(
"decoding child node with Merkle value 0x%x: %w",
"decoding child node with hash 0x%x: %w",
childMerkleValue, err)
}

Expand Down Expand Up @@ -305,11 +307,19 @@ func (t *Trie) writeDirtyNode(db Putter, n *Node) (err error) {
n.MerkleValue, err)
}

err = db.Put(merkleValue, encoding)
if len(merkleValue) < 32 {
// Inlined node, there is no need to write it to database.
n.SetClean()
return nil
}

nodeHash := merkleValue

err = db.Put(nodeHash, encoding)
if err != nil {
return fmt.Errorf(
"putting encoding of node with Merkle value 0x%x in database: %w",
merkleValue, err)
"putting encoding of node with node hash 0x%x in database: %w",
nodeHash, err)
}

if n.Kind() != node.Branch {
Expand Down Expand Up @@ -342,25 +352,20 @@ func (t *Trie) writeDirtyNode(db Putter, n *Node) (err error) {

// GetChangedNodeHashes returns the two sets of hashes for all nodes
// inserted and deleted in the state trie since the last snapshot.
// Returned maps are safe for mutation.
func (t *Trie) GetChangedNodeHashes() (inserted, deleted map[string]struct{}, err error) {
inserted = make(map[string]struct{})
// The returned inserted map is freshly allocated and safe for mutation, but the
// deleted map is returned as-is from the trie deltas and must not be mutated.
func (t *Trie) GetChangedNodeHashes() (inserted, deleted map[common.Hash]struct{}, err error) {
inserted = make(map[common.Hash]struct{})
err = t.getInsertedNodeHashesAtNode(t.root, inserted)
if err != nil {
return nil, nil, fmt.Errorf("getting inserted node hashes: %w", err)
}

deletedNodeHashes := t.deltas.Deleted()
// TODO return deletedNodeHashes directly after changing MerkleValue -> NodeHash
deleted = make(map[string]struct{}, len(deletedNodeHashes))
for nodeHash := range deletedNodeHashes {
deleted[string(nodeHash[:])] = struct{}{}
}
deleted = t.deltas.Deleted()

return inserted, deleted, nil
}

func (t *Trie) getInsertedNodeHashesAtNode(n *Node, merkleValues map[string]struct{}) (err error) {
func (t *Trie) getInsertedNodeHashesAtNode(n *Node, nodeHashes map[common.Hash]struct{}) (err error) {
if n == nil || !n.Dirty {
return nil
}
Expand All @@ -372,12 +377,19 @@ func (t *Trie) getInsertedNodeHashesAtNode(n *Node, merkleValues map[string]stru
merkleValue, err = n.CalculateMerkleValue()
}
if err != nil {
return fmt.Errorf(
"encoding and hashing node with Merkle value 0x%x: %w",
n.MerkleValue, err)
return fmt.Errorf("calculating Merkle value: %w", err)
}

if len(merkleValue) < 32 {
// This is an inlined node and is encoded as part of its parent node.
// Therefore it is not written to disk and the online pruner does not
// need to track it. If the node encodes to less than 32B, it cannot have
// non-inlined children, since a reference to a non-inlined child is itself
// a 32B hash, so it's safe to stop here and not recurse further.
return nil
}

merkleValues[string(merkleValue)] = struct{}{}
nodeHash := common.NewHash(merkleValue)
nodeHashes[nodeHash] = struct{}{}

if n.Kind() != node.Branch {
return nil
Expand All @@ -388,7 +400,7 @@ func (t *Trie) getInsertedNodeHashesAtNode(n *Node, merkleValues map[string]stru
continue
}

err := t.getInsertedNodeHashesAtNode(child, merkleValues)
err := t.getInsertedNodeHashesAtNode(child, nodeHashes)
if err != nil {
// Note: do not wrap error since this is called recursively.
return err
Expand Down
47 changes: 27 additions & 20 deletions lib/trie/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"

"github.com/ChainSafe/chaindb"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -161,28 +162,34 @@ func Test_Trie_WriteDirty_ClearPrefix(t *testing.T) {
func Test_PopulateNodeHashes(t *testing.T) {
t.Parallel()

const (
merkleValue32Zeroes = "00000000000000000000000000000000"
merkleValue32Ones = "11111111111111111111111111111111"
merkleValue32Twos = "22222222222222222222222222222222"
merkleValue32Threes = "33333333333333333333333333333333"
var (
merkleValue32Zeroes = common.Hash{}
merkleValue32Ones = common.Hash{
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
merkleValue32Twos = common.Hash{
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}
merkleValue32Threes = common.Hash{
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}
)

testCases := map[string]struct {
node *Node
nodeHashes map[string]struct{}
nodeHashes map[common.Hash]struct{}
panicValue interface{}
}{
"nil_node": {
nodeHashes: map[string]struct{}{},
nodeHashes: map[common.Hash]struct{}{},
},
"inlined_leaf_node": {
node: &Node{MerkleValue: []byte("a")},
nodeHashes: map[string]struct{}{},
nodeHashes: map[common.Hash]struct{}{},
},
"leaf_node": {
node: &Node{MerkleValue: []byte(merkleValue32Zeroes)},
nodeHashes: map[string]struct{}{
node: &Node{MerkleValue: merkleValue32Zeroes.ToBytes()},
nodeHashes: map[common.Hash]struct{}{
merkleValue32Zeroes: {},
},
},
Expand All @@ -197,34 +204,34 @@ func Test_PopulateNodeHashes(t *testing.T) {
{MerkleValue: []byte("b")},
}),
},
nodeHashes: map[string]struct{}{},
nodeHashes: map[common.Hash]struct{}{},
},
"branch_node": {
node: &Node{
MerkleValue: []byte(merkleValue32Zeroes),
MerkleValue: merkleValue32Zeroes.ToBytes(),
Children: padRightChildren([]*Node{
{MerkleValue: []byte(merkleValue32Ones)},
{MerkleValue: merkleValue32Ones.ToBytes()},
}),
},
nodeHashes: map[string]struct{}{
nodeHashes: map[common.Hash]struct{}{
merkleValue32Zeroes: {},
merkleValue32Ones: {},
},
},
"nested_branch_node": {
node: &Node{
MerkleValue: []byte(merkleValue32Zeroes),
MerkleValue: merkleValue32Zeroes.ToBytes(),
Children: padRightChildren([]*Node{
{MerkleValue: []byte(merkleValue32Ones)},
{MerkleValue: merkleValue32Ones.ToBytes()},
{
MerkleValue: []byte(merkleValue32Twos),
MerkleValue: merkleValue32Twos.ToBytes(),
Children: padRightChildren([]*Node{
{MerkleValue: []byte(merkleValue32Threes)},
{MerkleValue: merkleValue32Threes.ToBytes()},
}),
},
}),
},
nodeHashes: map[string]struct{}{
nodeHashes: map[common.Hash]struct{}{
merkleValue32Zeroes: {},
merkleValue32Ones: {},
merkleValue32Twos: {},
Expand All @@ -238,7 +245,7 @@ func Test_PopulateNodeHashes(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()

nodeHashes := make(map[string]struct{})
nodeHashes := make(map[common.Hash]struct{})

if testCase.panicValue != nil {
assert.PanicsWithValue(t, testCase.panicValue, func() {
Expand Down
11 changes: 7 additions & 4 deletions lib/trie/proof/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func Generate(rootHash []byte, fullKeys [][]byte, database Database) (
buffer := pools.DigestBuffers.Get().(*bytes.Buffer)
defer pools.DigestBuffers.Put(buffer)

merkleValuesSeen := make(map[string]struct{})
nodeHashesSeen := make(map[common.Hash]struct{})
for _, fullKey := range fullKeys {
fullKeyNibbles := codec.KeyLEToNibbles(fullKey)
newEncodedProofNodes, err := walkRoot(rootNode, fullKeyNibbles)
Expand All @@ -56,13 +56,16 @@ func Generate(rootHash []byte, fullKeys [][]byte, database Database) (
if err != nil {
return nil, fmt.Errorf("blake2b hash: %w", err)
}
merkleValueString := buffer.String()
// Note: all encoded proof nodes are larger than 32B so their
// merkle value is the encoding hash digest (32B) and never the
// encoding itself.
nodeHash := common.NewHash(buffer.Bytes())

_, seen := merkleValuesSeen[merkleValueString]
_, seen := nodeHashesSeen[nodeHash]
if seen {
continue
}
merkleValuesSeen[merkleValueString] = struct{}{}
nodeHashesSeen[nodeHash] = struct{}{}

encodedProofNodes = append(encodedProofNodes, encodedProofNode)
}
Expand Down
2 changes: 1 addition & 1 deletion lib/trie/proof/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func buildTrie(encodedProofNodes [][]byte, rootHash []byte) (t *trie.Trie, err e
buffer.Reset()
err = node.MerkleValueRoot(encodedProofNode, buffer)
if err != nil {
return nil, fmt.Errorf("calculating Merkle value: %w", err)
return nil, fmt.Errorf("calculating node hash: %w", err)
}
digest := buffer.Bytes()

Expand Down
Loading

0 comments on commit acf44f4

Please sign in to comment.