Skip to content

Commit

Permalink
Mock trie db
Browse files Browse the repository at this point in the history
  • Loading branch information
dimartiro committed Jul 16, 2023
1 parent d1441b7 commit 9fb4d9c
Show file tree
Hide file tree
Showing 14 changed files with 100 additions and 45 deletions.
49 changes: 49 additions & 0 deletions dot/state/db_getter_mocks_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions dot/state/mocks_generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ package state
//go:generate mockgen -destination=mocks_runtime_test.go -package $GOPACKAGE github.com/ChainSafe/gossamer/lib/runtime Instance
//go:generate mockgen -destination=mock_gauge_test.go -package $GOPACKAGE github.com/prometheus/client_golang/prometheus Gauge
//go:generate mockgen -destination=mock_counter_test.go -package $GOPACKAGE github.com/prometheus/client_golang/prometheus Counter
//go:generate mockgen -destination=db_getter_mocks_test.go -package=$GOPACKAGE github.com/ChainSafe/gossamer/lib/trie DBGetter
4 changes: 3 additions & 1 deletion dot/state/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,15 @@ func TestGetStorageChildAndGetStorageFromChild(t *testing.T) {
genHeader.Hash(),
"0",
))
dbGetter := NewMockDBGetter(ctrl)
dbGetter.EXPECT().Get(gomock.Any()).Times(0)

trieRoot := &node.Node{
PartialKey: []byte{1, 2},
StorageValue: []byte{3, 4},
Dirty: true,
}
testChildTrie := trie.NewTrie(trieRoot, nil)
testChildTrie := trie.NewTrie(trieRoot, dbGetter)

testChildTrie.Put([]byte("keyInsidechild"), []byte("voila"))

Expand Down
12 changes: 9 additions & 3 deletions dot/state/tries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,11 @@ func Test_Tries_SetEmptyTrie(t *testing.T) {

func Test_Tries_SetTrie(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
dbGetter := NewMockDBGetter(ctrl)
dbGetter.EXPECT().Get(gomock.Any()).Times(0)

tr := trie.NewTrie(&node.Node{PartialKey: []byte{1}}, nil)
tr := trie.NewTrie(&node.Node{PartialKey: []byte{1}}, dbGetter)

tries := NewTries()
tries.SetTrie(tr)
Expand Down Expand Up @@ -188,6 +191,9 @@ func Test_Tries_delete(t *testing.T) {
}
func Test_Tries_get(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
dbGetter := NewMockDBGetter(ctrl)
dbGetter.EXPECT().Get(gomock.Any()).Times(0)

testCases := map[string]struct {
tries *Tries
Expand All @@ -200,14 +206,14 @@ func Test_Tries_get(t *testing.T) {
{1, 2, 3}: trie.NewTrie(&node.Node{
PartialKey: []byte{1, 2, 3},
StorageValue: []byte{1},
}, nil),
}, dbGetter),
},
},
root: common.Hash{1, 2, 3},
trie: trie.NewTrie(&node.Node{
PartialKey: []byte{1, 2, 3},
StorageValue: []byte{1},
}, nil),
}, dbGetter),
},
"not_found_in_map": {
// similar to not found in database
Expand Down
18 changes: 9 additions & 9 deletions lib/trie/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ import (
"github.com/ChainSafe/chaindb"
)

// Getter gets a value corresponding to the given key.
type Getter interface {
// DBGetter gets a value corresponding to the given key.
type DBGetter interface {
Get(key []byte) (value []byte, err error)
}

// Putter puts a value at the given key and returns an error.
type Putter interface {
// DBPutter puts a value at the given key and returns an error.
type DBPutter interface {
Put(key []byte, value []byte) error
}

Expand All @@ -31,7 +31,7 @@ type NewBatcher interface {

// Load reconstructs the trie from the database from the given root hash.
// It is used when restarting the node to load the current state trie.
func (t *Trie) Load(db Getter, rootHash common.Hash) error {
func (t *Trie) Load(db DBGetter, rootHash common.Hash) error {
if rootHash == EmptyHash {
t.root = nil
return nil
Expand All @@ -55,7 +55,7 @@ func (t *Trie) Load(db Getter, rootHash common.Hash) error {
return t.loadNode(db, t.root)
}

func (t *Trie) loadNode(db Getter, n *Node) error {
func (t *Trie) loadNode(db DBGetter, n *Node) error {
if n.Kind() != node.Branch {
return nil
}
Expand Down Expand Up @@ -197,7 +197,7 @@ func recordAllDeleted(n *Node, recorder DeltaRecorder) {
// It recursively descends into the trie using the database starting
// from the root node until it reaches the node with the given key.
// It then reads the value from the database.
func GetFromDB(db Getter, rootHash common.Hash, key []byte) (
func GetFromDB(db DBGetter, rootHash common.Hash, key []byte) (
value []byte, err error) {
if rootHash == EmptyHash {
return nil, nil
Expand All @@ -223,7 +223,7 @@ func GetFromDB(db Getter, rootHash common.Hash, key []byte) (
// for the value corresponding to a key.
// Note it does not copy the value so modifying the value bytes
// slice will modify the value of the node in the trie.
func getFromDBAtNode(db Getter, n *Node, key []byte) (
func getFromDBAtNode(db DBGetter, n *Node, key []byte) (
value []byte, err error) {
if n.Kind() == node.Leaf {
if bytes.Equal(n.PartialKey, key) {
Expand Down Expand Up @@ -290,7 +290,7 @@ func (t *Trie) WriteDirty(db NewBatcher) error {
return batch.Flush()
}

func (t *Trie) writeDirtyNode(db Putter, n *Node) (err error) {
func (t *Trie) writeDirtyNode(db DBPutter, n *Node) (err error) {
if n == nil || !n.Dirty {
return nil
}
Expand Down
6 changes: 0 additions & 6 deletions lib/trie/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,6 @@ import (
"github.com/ChainSafe/gossamer/lib/common"
)

// Database defines a key value Get method used
// for proof generation.
type Database interface {
Get(key []byte) (value []byte, err error)
}

type MemoryDB struct {
data map[common.Hash][]byte
}
Expand Down
6 changes: 0 additions & 6 deletions lib/trie/db/mocks_generate_test.go

This file was deleted.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 7 additions & 2 deletions lib/trie/proof/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,23 @@ import (
"github.com/ChainSafe/gossamer/internal/trie/pools"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/trie"
"github.com/ChainSafe/gossamer/lib/trie/db"
)

var (
ErrKeyNotFound = errors.New("key not found")
)

// Database defines a key value Get method used
// for proof generation.
type Database interface {
Get(key []byte) (value []byte, err error)
}

// Generate generates and deduplicates the encoded proof nodes
// for the trie corresponding to the root hash given, and for
// the slice of (Little Endian) full keys given. The database given
// is used to load the trie using the root hash given.
func Generate(rootHash []byte, fullKeys [][]byte, database db.Database) (
func Generate(rootHash []byte, fullKeys [][]byte, database Database) (
encodedProofNodes [][]byte, err error) {
trie := trie.NewEmptyTrie()
if err := trie.Load(database, common.BytesToHash(rootHash)); err != nil {
Expand Down
15 changes: 7 additions & 8 deletions lib/trie/proof/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"github.com/ChainSafe/gossamer/internal/trie/codec"
"github.com/ChainSafe/gossamer/internal/trie/node"
"github.com/ChainSafe/gossamer/lib/trie"
"github.com/ChainSafe/gossamer/lib/trie/db"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -32,14 +31,14 @@ func Test_Generate(t *testing.T) {
testCases := map[string]struct {
rootHash []byte
fullKeysNibbles [][]byte
databaseBuilder func(ctrl *gomock.Controller) db.Database
databaseBuilder func(ctrl *gomock.Controller) Database
encodedProofNodes [][]byte
errWrapped error
errMessage string
}{
"failed_loading_trie": {
rootHash: someHash,
databaseBuilder: func(ctrl *gomock.Controller) db.Database {
databaseBuilder: func(ctrl *gomock.Controller) Database {
mockDatabase := NewMockDatabase(ctrl)
mockDatabase.EXPECT().Get(someHash).
Return(nil, errTest)
Expand All @@ -54,7 +53,7 @@ func Test_Generate(t *testing.T) {
"walk_error": {
rootHash: someHash,
fullKeysNibbles: [][]byte{{1}},
databaseBuilder: func(ctrl *gomock.Controller) db.Database {
databaseBuilder: func(ctrl *gomock.Controller) Database {
mockDatabase := NewMockDatabase(ctrl)
encodedRoot := encodeNode(t, node.Node{
PartialKey: []byte{1},
Expand All @@ -70,7 +69,7 @@ func Test_Generate(t *testing.T) {
"leaf_root": {
rootHash: someHash,
fullKeysNibbles: [][]byte{{}},
databaseBuilder: func(ctrl *gomock.Controller) db.Database {
databaseBuilder: func(ctrl *gomock.Controller) Database {
mockDatabase := NewMockDatabase(ctrl)
encodedRoot := encodeNode(t, node.Node{
PartialKey: []byte{1},
Expand All @@ -90,7 +89,7 @@ func Test_Generate(t *testing.T) {
"branch_root": {
rootHash: someHash,
fullKeysNibbles: [][]byte{{}},
databaseBuilder: func(ctrl *gomock.Controller) db.Database {
databaseBuilder: func(ctrl *gomock.Controller) Database {
mockDatabase := NewMockDatabase(ctrl)
encodedRoot := encodeNode(t, node.Node{
PartialKey: []byte{1},
Expand Down Expand Up @@ -126,7 +125,7 @@ func Test_Generate(t *testing.T) {
fullKeysNibbles: [][]byte{
{1, 2, 3, 4},
},
databaseBuilder: func(ctrl *gomock.Controller) db.Database {
databaseBuilder: func(ctrl *gomock.Controller) Database {
mockDatabase := NewMockDatabase(ctrl)

rootNode := node.Node{
Expand Down Expand Up @@ -175,7 +174,7 @@ func Test_Generate(t *testing.T) {
{1, 2, 4, 4},
{1, 2, 5, 5},
},
databaseBuilder: func(ctrl *gomock.Controller) db.Database {
databaseBuilder: func(ctrl *gomock.Controller) Database {
mockDatabase := NewMockDatabase(ctrl)

rootNode := node.Node{
Expand Down
6 changes: 6 additions & 0 deletions lib/trie/proof/mocks_generate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// Copyright 2022 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package proof

//go:generate mockgen -destination=database_mocks_test.go -package=$GOPACKAGE . Database
2 changes: 1 addition & 1 deletion lib/trie/proof/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ var (
)

// buildTrie sets a partial trie based on the proof slice of encoded nodes.
func buildTrie(encodedProofNodes [][]byte, rootHash []byte, db db.Database) (t *trie.Trie, err error) {
func buildTrie(encodedProofNodes [][]byte, rootHash []byte, db Database) (t *trie.Trie, err error) {
if len(encodedProofNodes) == 0 {
return nil, fmt.Errorf("%w: for Merkle root hash 0x%x",
ErrEmptyProof, rootHash)
Expand Down
2 changes: 1 addition & 1 deletion lib/trie/proof/verify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func Test_buildTrie(t *testing.T) {
encodedProofNodes [][]byte
rootHash []byte
expectedTrie *trie.Trie
db db.Database
db Database
errWrapped error
errMessage string
}
Expand Down
11 changes: 5 additions & 6 deletions lib/trie/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"github.com/ChainSafe/gossamer/internal/trie/node"
"github.com/ChainSafe/gossamer/internal/trie/tracking"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/trie/db"
)

// EmptyHash is the empty trie hash.
Expand All @@ -22,7 +21,7 @@ type Trie struct {
generation uint64
root *Node
childTries map[common.Hash]*Trie
db db.Database
db DBGetter
// deltas stores trie deltas since the last trie snapshot.
// For example node hashes that were deleted since
// the last snapshot. These are used by the online
Expand All @@ -37,7 +36,7 @@ func NewEmptyTrie() *Trie {
}

// NewTrie creates a trie with an existing root node
func NewTrie(root *Node, db db.Database) *Trie {
func NewTrie(root *Node, db DBGetter) *Trie {
return &Trie{
root: root,
childTries: make(map[common.Hash]*Trie),
Expand Down Expand Up @@ -721,7 +720,7 @@ func (t *Trie) Get(keyLE []byte) (value []byte) {
return retrieve(t.db, t.root, keyNibbles)
}

func retrieve(db db.Database, parent *Node, key []byte) (value []byte) {
func retrieve(db DBGetter, parent *Node, key []byte) (value []byte) {
if parent == nil {
return nil
}
Expand All @@ -732,7 +731,7 @@ func retrieve(db db.Database, parent *Node, key []byte) (value []byte) {
return retrieveFromBranch(db, parent, key)
}

func retrieveFromLeaf(db db.Database, leaf *Node, key []byte) (value []byte) {
func retrieveFromLeaf(db DBGetter, leaf *Node, key []byte) (value []byte) {
if bytes.Equal(leaf.PartialKey, key) {
if leaf.HashedValue {
// We get the node
Expand All @@ -747,7 +746,7 @@ func retrieveFromLeaf(db db.Database, leaf *Node, key []byte) (value []byte) {
return nil
}

func retrieveFromBranch(db db.Database, branch *Node, key []byte) (value []byte) {
func retrieveFromBranch(db DBGetter, branch *Node, key []byte) (value []byte) {
if len(key) == 0 || bytes.Equal(branch.PartialKey, key) {
return branch.StorageValue
}
Expand Down

0 comments on commit 9fb4d9c

Please sign in to comment.