Skip to content

Commit

Permalink
chore: Backport IAVL Concurrency fix for v0.20 (#828)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattverse authored Sep 4, 2023
1 parent b35e4ff commit df3db2d
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 70 deletions.
21 changes: 15 additions & 6 deletions iterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package iavl
import (
"math/rand"
"sort"
"sync"
"testing"

dbm "github.com/cometbft/cometbft-db"
"github.com/cosmos/iavl/fastnode"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -36,7 +36,7 @@ func TestIterator_NewIterator_NilTree_Failure(t *testing.T) {
})

t.Run("Unsaved Fast Iterator", func(t *testing.T) {
itr := NewUnsavedFastIterator(start, end, ascending, nil, map[string]*fastnode.Node{}, map[string]interface{}{})
itr := NewUnsavedFastIterator(start, end, ascending, nil, &sync.Map{}, &sync.Map{})
performTest(t, itr)
require.ErrorIs(t, errFastIteratorNilNdbGiven, itr.Error())
})
Expand Down Expand Up @@ -297,14 +297,14 @@ func setupUnsavedFastIterator(t *testing.T, config *iteratorTestConfig) (dbm.Ite
require.NoError(t, err)

// No unsaved additions or removals should be present after saving
require.Equal(t, 0, len(tree.unsavedFastNodeAdditions))
require.Equal(t, 0, len(tree.unsavedFastNodeRemovals))
require.Equal(t, 0, syncMapCount(tree.unsavedFastNodeAdditions))
require.Equal(t, 0, syncMapCount(tree.unsavedFastNodeRemovals))

// Ensure that there are unsaved additions and removals present
secondHalfMirror := setupMirrorForIterator(t, &secondHalfConfig, tree)

require.True(t, len(tree.unsavedFastNodeAdditions) >= len(secondHalfMirror))
require.Equal(t, 0, len(tree.unsavedFastNodeRemovals))
require.True(t, syncMapCount(tree.unsavedFastNodeAdditions) >= len(secondHalfMirror))
require.Equal(t, 0, syncMapCount(tree.unsavedFastNodeRemovals))

// Merge the two halves
if config.ascending {
Expand All @@ -331,6 +331,15 @@ func setupUnsavedFastIterator(t *testing.T, config *iteratorTestConfig) (dbm.Ite
return itr, mirror
}

// syncMapCount reports how many entries m holds at the time of the call.
// sync.Map exposes no length method, so the entries are tallied by visiting
// each one with Range.
func syncMapCount(m *sync.Map) int {
	var total int
	tally := func(interface{}, interface{}) bool {
		total++
		// Returning true tells Range to keep visiting the remaining entries.
		return true
	}
	m.Range(tally)
	return total
}

func TestNodeIterator_WithEmptyRoot(t *testing.T) {
itr, err := NewNodeIterator(nil, newNodeDB(dbm.NewMemDB(), 0, nil))
require.NoError(t, err)
Expand Down
78 changes: 47 additions & 31 deletions mutable_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ var ErrVersionDoesNotExist = errors.New("version does not exist")
//
// The inner ImmutableTree should not be used directly by callers.
type MutableTree struct {
*ImmutableTree // The current, working tree.
lastSaved *ImmutableTree // The most recently saved tree.
orphans map[string]int64 // Nodes removed by changes to working tree.
versions map[int64]bool // The previous, saved versions of the tree.
allRootLoaded bool // Whether all roots are loaded or not(by LazyLoadVersion)
unsavedFastNodeAdditions map[string]*fastnode.Node // FastNodes that have not yet been saved to disk
unsavedFastNodeRemovals map[string]interface{} // FastNodes that have not yet been removed from disk
*ImmutableTree // The current, working tree.
lastSaved *ImmutableTree // The most recently saved tree.
orphans map[string]int64 // Nodes removed by changes to working tree.
versions map[int64]bool // The previous, saved versions of the tree.
allRootLoaded bool // Whether all roots are loaded or not(by LazyLoadVersion)
unsavedFastNodeAdditions *sync.Map // map[string]*FastNode FastNodes that have not yet been saved to disk
unsavedFastNodeRemovals *sync.Map // map[string]interface{} FastNodes that have not yet been removed from disk
ndb *nodeDB
skipFastStorageUpgrade bool // If true, the tree will work like no fast storage and always not upgrade fast storage

Expand All @@ -59,8 +59,8 @@ func NewMutableTreeWithOpts(db dbm.DB, cacheSize int, opts *Options, skipFastSto
orphans: map[string]int64{},
versions: map[int64]bool{},
allRootLoaded: false,
unsavedFastNodeAdditions: make(map[string]*fastnode.Node),
unsavedFastNodeRemovals: make(map[string]interface{}),
unsavedFastNodeAdditions: &sync.Map{},
unsavedFastNodeRemovals: &sync.Map{},
ndb: ndb,
skipFastStorageUpgrade: skipFastStorageUpgrade,
}, nil
Expand Down Expand Up @@ -152,11 +152,11 @@ func (tree *MutableTree) Get(key []byte) ([]byte, error) {
}

if !tree.skipFastStorageUpgrade {
if fastNode, ok := tree.unsavedFastNodeAdditions[ibytes.UnsafeBytesToStr(key)]; ok {
return fastNode.GetValue(), nil
if fastNode, ok := tree.unsavedFastNodeAdditions.Load(ibytes.UnsafeBytesToStr(key)); ok {
return fastNode.(*fastnode.Node).GetValue(), nil
}
// check if node was deleted
if _, ok := tree.unsavedFastNodeRemovals[string(key)]; ok {
if _, ok := tree.unsavedFastNodeRemovals.Load(string(key)); ok {
return nil, nil
}
}
Expand Down Expand Up @@ -816,8 +816,8 @@ func (tree *MutableTree) Rollback() {
}
tree.orphans = map[string]int64{}
if !tree.skipFastStorageUpgrade {
tree.unsavedFastNodeAdditions = map[string]*fastnode.Node{}
tree.unsavedFastNodeRemovals = map[string]interface{}{}
tree.unsavedFastNodeAdditions = &sync.Map{}
tree.unsavedFastNodeRemovals = &sync.Map{}
}
}

Expand Down Expand Up @@ -936,8 +936,8 @@ func (tree *MutableTree) SaveVersion() ([]byte, int64, error) {
tree.lastSaved = tree.ImmutableTree.clone()
tree.orphans = map[string]int64{}
if !tree.skipFastStorageUpgrade {
tree.unsavedFastNodeAdditions = make(map[string]*fastnode.Node)
tree.unsavedFastNodeRemovals = make(map[string]interface{})
tree.unsavedFastNodeAdditions = &sync.Map{}
tree.unsavedFastNodeRemovals = &sync.Map{}
}

hash, err := tree.Hash()
Expand All @@ -958,48 +958,64 @@ func (tree *MutableTree) saveFastNodeVersion() error {
return tree.ndb.setFastStorageVersionToBatch()
}

// nolint: unused
func (tree *MutableTree) getUnsavedFastNodeAdditions() map[string]*fastnode.Node {
return tree.unsavedFastNodeAdditions
additions := make(map[string]*fastnode.Node)
tree.unsavedFastNodeAdditions.Range(func(key, value interface{}) bool {
additions[key.(string)] = value.(*fastnode.Node)
return true
})
return additions
}

// getUnsavedFastNodeRemovals returns unsaved FastNodes to remove

func (tree *MutableTree) getUnsavedFastNodeRemovals() map[string]interface{} {
return tree.unsavedFastNodeRemovals
removals := make(map[string]interface{})
tree.unsavedFastNodeRemovals.Range(func(key, value interface{}) bool {
removals[key.(string)] = value
return true
})
return removals
}

// addUnsavedAddition stores an addition into the unsaved additions map
func (tree *MutableTree) addUnsavedAddition(key []byte, node *fastnode.Node) {
skey := ibytes.UnsafeBytesToStr(key)
delete(tree.unsavedFastNodeRemovals, skey)
tree.unsavedFastNodeAdditions[skey] = node
tree.unsavedFastNodeRemovals.Delete(skey)
tree.unsavedFastNodeAdditions.Store(skey, node)
}

func (tree *MutableTree) saveFastNodeAdditions() error {
keysToSort := make([]string, 0, len(tree.unsavedFastNodeAdditions))
for key := range tree.unsavedFastNodeAdditions {
keysToSort = append(keysToSort, key)
}
keysToSort := make([]string, 0)
tree.unsavedFastNodeAdditions.Range(func(k, v interface{}) bool {
keysToSort = append(keysToSort, k.(string))
return true
})
sort.Strings(keysToSort)

for _, key := range keysToSort {
if err := tree.ndb.SaveFastNode(tree.unsavedFastNodeAdditions[key]); err != nil {
val, _ := tree.unsavedFastNodeAdditions.Load(key)
if err := tree.ndb.SaveFastNode(val.(*fastnode.Node)); err != nil {
return err
}
}
return nil
}

// addUnsavedRemoval adds a removal to the unsaved removals map
func (tree *MutableTree) addUnsavedRemoval(key []byte) {
skey := ibytes.UnsafeBytesToStr(key)
delete(tree.unsavedFastNodeAdditions, skey)
tree.unsavedFastNodeRemovals[skey] = true
tree.unsavedFastNodeAdditions.Delete(skey)
tree.unsavedFastNodeRemovals.Store(skey, true)
}

func (tree *MutableTree) saveFastNodeRemovals() error {
keysToSort := make([]string, 0, len(tree.unsavedFastNodeRemovals))
for key := range tree.unsavedFastNodeRemovals {
keysToSort = append(keysToSort, key)
}
keysToSort := make([]string, 0)
tree.unsavedFastNodeRemovals.Range(func(k, v interface{}) bool {
keysToSort = append(keysToSort, k.(string))
return true
})
sort.Strings(keysToSort)

for _, key := range keysToSort {
Expand Down
74 changes: 41 additions & 33 deletions unsaved_fast_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import (
"bytes"
"errors"
"sort"
"sync"

dbm "github.com/cometbft/cometbft-db"
"github.com/cosmos/iavl/fastnode"
ibytes "github.com/cosmos/iavl/internal/bytes"

"github.com/cosmos/iavl/fastnode"
)

var (
Expand All @@ -30,14 +32,14 @@ type UnsavedFastIterator struct {
fastIterator dbm.Iterator

nextUnsavedNodeIdx int
unsavedFastNodeAdditions map[string]*fastnode.Node
unsavedFastNodeRemovals map[string]interface{}
unsavedFastNodeAdditions *sync.Map // map[string]*FastNode
unsavedFastNodeRemovals *sync.Map // map[string]interface{}
unsavedFastNodesToSort []string
}

var _ dbm.Iterator = (*UnsavedFastIterator)(nil)

func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsavedFastNodeAdditions map[string]*fastnode.Node, unsavedFastNodeRemovals map[string]interface{}) *UnsavedFastIterator {
func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsavedFastNodeAdditions, unsavedFastNodeRemovals *sync.Map) *UnsavedFastIterator {
iter := &UnsavedFastIterator{
start: start,
end: end,
Expand All @@ -51,28 +53,6 @@ func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsa
fastIterator: NewFastIterator(start, end, ascending, ndb),
}

// We need to ensure that we iterate over saved and unsaved state in order.
// The strategy is to sort unsaved nodes, the fast node on disk are already sorted.
// Then, we keep a pointer to both the unsaved and saved nodes, and iterate over them in order efficiently.
for _, fastNode := range unsavedFastNodeAdditions {
if start != nil && bytes.Compare(fastNode.GetKey(), start) < 0 {
continue
}

if end != nil && bytes.Compare(fastNode.GetKey(), end) >= 0 {
continue
}

iter.unsavedFastNodesToSort = append(iter.unsavedFastNodesToSort, ibytes.UnsafeBytesToStr(fastNode.GetKey()))
}

sort.Slice(iter.unsavedFastNodesToSort, func(i, j int) bool {
if ascending {
return iter.unsavedFastNodesToSort[i] < iter.unsavedFastNodesToSort[j]
}
return iter.unsavedFastNodesToSort[i] > iter.unsavedFastNodesToSort[j]
})

if iter.ndb == nil {
iter.err = errFastIteratorNilNdbGiven
iter.valid = false
Expand All @@ -91,7 +71,33 @@ func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsa
return iter
}

// Move to the first elemenet
// We need to ensure that we iterate over saved and unsaved state in order.
// The strategy is to sort unsaved nodes, the fast node on disk are already sorted.
// Then, we keep a pointer to both the unsaved and saved nodes, and iterate over them in order efficiently.
unsavedFastNodeAdditions.Range(func(k, v interface{}) bool {
fastNode := v.(*fastnode.Node)

if start != nil && bytes.Compare(fastNode.GetKey(), start) < 0 {
return true
}

if end != nil && bytes.Compare(fastNode.GetKey(), end) >= 0 {
return true
}

iter.unsavedFastNodesToSort = append(iter.unsavedFastNodesToSort, k.(string))

return true
})

sort.Slice(iter.unsavedFastNodesToSort, func(i, j int) bool {
if ascending {
return iter.unsavedFastNodesToSort[i] < iter.unsavedFastNodesToSort[j]
}
return iter.unsavedFastNodesToSort[i] > iter.unsavedFastNodesToSort[j]
})

// Move to the first element
iter.Next()

return iter
Expand Down Expand Up @@ -136,16 +142,17 @@ func (iter *UnsavedFastIterator) Next() {

diskKeyStr := ibytes.UnsafeBytesToStr(iter.fastIterator.Key())
if iter.fastIterator.Valid() && iter.nextUnsavedNodeIdx < len(iter.unsavedFastNodesToSort) {

if iter.unsavedFastNodeRemovals[diskKeyStr] != nil {
value, ok := iter.unsavedFastNodeRemovals.Load(diskKeyStr)
if ok && value != nil {
// If next fast node from disk is to be removed, skip it.
iter.fastIterator.Next()
iter.Next()
return
}

nextUnsavedKey := iter.unsavedFastNodesToSort[iter.nextUnsavedNodeIdx]
nextUnsavedNode := iter.unsavedFastNodeAdditions[nextUnsavedKey]
nextUnsavedNodeVal, _ := iter.unsavedFastNodeAdditions.Load(nextUnsavedKey)
nextUnsavedNode := nextUnsavedNodeVal.(*fastnode.Node)

var isUnsavedNext bool
if iter.ascending {
Expand All @@ -156,7 +163,6 @@ func (iter *UnsavedFastIterator) Next() {

if isUnsavedNext {
// Unsaved node is next

if diskKeyStr == nextUnsavedKey {
// Unsaved update prevails over saved copy so we skip the copy from disk
iter.fastIterator.Next()
Expand All @@ -178,7 +184,8 @@ func (iter *UnsavedFastIterator) Next() {

// if only nodes on disk are left, we return them
if iter.fastIterator.Valid() {
if iter.unsavedFastNodeRemovals[diskKeyStr] != nil {
value, ok := iter.unsavedFastNodeRemovals.Load(diskKeyStr)
if ok && value != nil {
// If next fast node from disk is to be removed, skip it.
iter.fastIterator.Next()
iter.Next()
Expand All @@ -195,7 +202,8 @@ func (iter *UnsavedFastIterator) Next() {
// if only unsaved nodes are left, we can just iterate
if iter.nextUnsavedNodeIdx < len(iter.unsavedFastNodesToSort) {
nextUnsavedKey := iter.unsavedFastNodesToSort[iter.nextUnsavedNodeIdx]
nextUnsavedNode := iter.unsavedFastNodeAdditions[nextUnsavedKey]
nextUnsavedNodeVal, _ := iter.unsavedFastNodeAdditions.Load(nextUnsavedKey)
nextUnsavedNode := nextUnsavedNodeVal.(*fastnode.Node)

iter.nextKey = nextUnsavedNode.GetKey()
iter.nextVal = nextUnsavedNode.GetValue()
Expand Down

0 comments on commit df3db2d

Please sign in to comment.