Skip to content

Commit

Permalink
Add SMT store type (cosmos#8507)
Browse files Browse the repository at this point in the history
* Initial SMT store type

SMT (Sparse Merkle Tree) is intended to replace IAVL. New type
implements same interfaces as iavl.Store.

* Add iteration support to SMT

Sparse Merkle Tree does not support iteration over keys in order.
To provide drop-in replacement for IAVL, Iterator and ReverseIterator
has to be implemented.
SMT Store implementation use the underlying KV store to:
 - maintain a list of keys (under a prefix)
 - iterate over a keys
Values are stored only in SMT.

* Migrate to smt v0.1.1

* Extra test for SMT iterator

* CommitStore implementation for SMT store

* Use interface instead of concrete type

* Add telemetry to SMT store

* SMT: version->root mapping, cleanup

* SMT proofs - initial code

* Tests for SMT store ProofOp implementation

* Fix linter errors

* Use simple 1 byte KV-store prefixes

* Improve assertions in tests

* Use mutex properly

* Store data in ADR-040-compatible way

SMT stores:
 * key -> hash(key, value)

KV store stores:
 * key->value in "bucket 1",
 * hash(key, value) -> key in "bucket 2".
  • Loading branch information
tzdybal authored and roysc committed Jun 23, 2021
1 parent 45265b1 commit 1105929
Show file tree
Hide file tree
Showing 9 changed files with 640 additions and 1 deletion.
7 changes: 7 additions & 0 deletions store/rootmulti/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package rootmulti
import (
"github.com/tendermint/tendermint/crypto/merkle"

"github.com/cosmos/cosmos-sdk/store/smt"
storetypes "github.com/cosmos/cosmos-sdk/store/types"
)

Expand All @@ -25,3 +26,9 @@ func DefaultProofRuntime() (prt *merkle.ProofRuntime) {
prt.RegisterOpDecoder(storetypes.ProofOpSimpleMerkleCommitment, storetypes.CommitmentOpDecoder)
return
}

func SMTProofRuntime() (prt *merkle.ProofRuntime) {
prt = merkle.NewProofRuntime()
prt.RegisterOpDecoder(smt.ProofType, smt.ProofDecoder)
return prt
}
2 changes: 1 addition & 1 deletion store/rootmulti/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ func (rs *Store) SetInitialVersion(version int64) error {
// If the store is wrapped with an inter-block cache, we must first unwrap
// it to get the underlying IAVL store.
store = rs.GetCommitKVStore(key)
store.(*iavl.Store).SetInitialVersion(version)
store.(types.StoreWithInitialVersion).SetInitialVersion(version)
}
}

Expand Down
101 changes: 101 additions & 0 deletions store/smt/iterator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package smt

import (
"bytes"

dbm "github.com/tendermint/tm-db"
)

type Iterator struct {
store *Store
iter dbm.Iterator
}

func indexKey(key []byte) []byte {
return append(indexPrefix, key...)
}

func plainKey(key []byte) []byte {
return key[prefixLen:]
}

func startKey(key []byte) []byte {
if key == nil {
return dataPrefix
}
return dataKey(key)
}

func endKey(key []byte) []byte {
if key == nil {
return indexPrefix
}
return dataKey(key)
}

func newIterator(s *Store, start, end []byte, reverse bool) (*Iterator, error) {
start = startKey(start)
end = endKey(end)
var i dbm.Iterator
var err error
if reverse {
i, err = s.db.ReverseIterator(start, end)
} else {
i, err = s.db.Iterator(start, end)
}
if err != nil {
return nil, err
}
return &Iterator{store: s, iter: i}, nil
}

// Domain returns the start (inclusive) and end (exclusive) limits of the iterator.
// CONTRACT: start, end readonly []byte
func (i *Iterator) Domain() (start []byte, end []byte) {
start, end = i.iter.Domain()
if bytes.Equal(start, dataPrefix) {
start = nil
} else {
start = plainKey(start)
}
if bytes.Equal(end, indexPrefix) {
end = nil
} else {
end = plainKey(end)
}
return start, end
}

// Valid returns whether the current iterator is valid. Once invalid, the Iterator remains
// invalid forever.
func (i *Iterator) Valid() bool {
return i.iter.Valid()
}

// Next moves the iterator to the next key in the database, as defined by order of iteration.
// If Valid returns false, this method will panic.
func (i *Iterator) Next() {
i.iter.Next()
}

// Key returns the key at the current position. Panics if the iterator is invalid.
// CONTRACT: key readonly []byte
func (i *Iterator) Key() (key []byte) {
return plainKey(i.iter.Key())
}

// Value returns the value at the current position. Panics if the iterator is invalid.
// CONTRACT: value readonly []byte
func (i *Iterator) Value() (value []byte) {
return i.store.Get(i.Key())
}

// Error returns the last error encountered by the iterator, if any.
func (i *Iterator) Error() error {
return i.iter.Error()
}

// Close closes the iterator, relasing any allocated resources.
func (i *Iterator) Close() error {
return i.iter.Close()
}
113 changes: 113 additions & 0 deletions store/smt/iterator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package smt_test

import (
"bytes"
"sort"
"testing"

"github.com/cosmos/cosmos-sdk/store/smt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
dbm "github.com/tendermint/tm-db"
)

func TestIteration(t *testing.T) {
assert := assert.New(t)
require := require.New(t)

pairs := []struct{ key, val []byte }{
{[]byte("foo"), []byte("bar")},
{[]byte("lorem"), []byte("ipsum")},
{[]byte("alpha"), []byte("beta")},
{[]byte("gamma"), []byte("delta")},
{[]byte("epsilon"), []byte("zeta")},
{[]byte("eta"), []byte("theta")},
{[]byte("iota"), []byte("kappa")},
}

s := smt.NewStore(dbm.NewMemDB())

for _, p := range pairs {
s.Set(p.key, p.val)
}

// sort test data by key, to get "expected" ordering
sort.Slice(pairs, func(i, j int) bool {
return bytes.Compare(pairs[i].key, pairs[j].key) < 0
})

iter := s.Iterator([]byte("alpha"), []byte("omega"))
for _, p := range pairs {
require.True(iter.Valid())
require.Equal(p.key, iter.Key())
require.Equal(p.val, iter.Value())
iter.Next()
}
assert.False(iter.Valid())
assert.NoError(iter.Error())
assert.NoError(iter.Close())

iter = s.Iterator(nil, nil)
for _, p := range pairs {
require.True(iter.Valid())
require.Equal(p.key, iter.Key())
require.Equal(p.val, iter.Value())
iter.Next()
}
assert.False(iter.Valid())
assert.NoError(iter.Error())
assert.NoError(iter.Close())

iter = s.Iterator([]byte("epsilon"), []byte("gamma"))
for _, p := range pairs[1:4] {
require.True(iter.Valid())
require.Equal(p.key, iter.Key())
require.Equal(p.val, iter.Value())
iter.Next()
}
assert.False(iter.Valid())
assert.NoError(iter.Error())
assert.NoError(iter.Close())

rIter := s.ReverseIterator(nil, nil)
for i := len(pairs) - 1; i >= 0; i-- {
require.True(rIter.Valid())
require.Equal(pairs[i].key, rIter.Key())
require.Equal(pairs[i].val, rIter.Value())
rIter.Next()
}
assert.False(rIter.Valid())
assert.NoError(rIter.Error())
assert.NoError(rIter.Close())

// delete something, and ensure that iteration still works
s.Delete([]byte("eta"))

iter = s.Iterator(nil, nil)
for _, p := range pairs {
if !bytes.Equal([]byte("eta"), p.key) {
require.True(iter.Valid())
require.Equal(p.key, iter.Key())
require.Equal(p.val, iter.Value())
iter.Next()
}
}
assert.False(iter.Valid())
assert.NoError(iter.Error())
assert.NoError(iter.Close())
}

func TestDomain(t *testing.T) {
assert := assert.New(t)
s := smt.NewStore(dbm.NewMemDB())

iter := s.Iterator(nil, nil)
start, end := iter.Domain()
assert.Nil(start)
assert.Nil(end)

iter = s.Iterator([]byte("foo"), []byte("bar"))
start, end = iter.Domain()
assert.Equal([]byte("foo"), start)
assert.Equal([]byte("bar"), end)
}
92 changes: 92 additions & 0 deletions store/smt/proof.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package smt

import (
"bytes"
"crypto/sha256"
"encoding/gob"
"hash"

"github.com/cosmos/cosmos-sdk/store/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/lazyledger/smt"
"github.com/tendermint/tendermint/crypto/merkle"
tmmerkle "github.com/tendermint/tendermint/proto/tendermint/crypto"
)

type HasherType byte

const (
SHA256 HasherType = iota
)

const (
ProofType = "smt"
)

type ProofOp struct {
Root []byte
Key []byte
Hasher HasherType
Proof smt.SparseMerkleProof
}

var _ merkle.ProofOperator = &ProofOp{}

func NewProofOp(root, key []byte, hasher HasherType, proof smt.SparseMerkleProof) *ProofOp {
return &ProofOp{
Root: root,
Key: key,
Hasher: hasher,
Proof: proof,
}
}

func (p *ProofOp) Run(args [][]byte) ([][]byte, error) {
switch len(args) {
case 0: // non-membership proof
if !smt.VerifyProof(p.Proof, p.Root, p.Key, []byte{}, getHasher(p.Hasher)) {
return nil, sdkerrors.Wrapf(types.ErrInvalidProof, "proof did not verify absence of key: %s", p.Key)
}
case 1: // membership proof
if !smt.VerifyProof(p.Proof, p.Root, p.Key, args[0], getHasher(p.Hasher)) {
return nil, sdkerrors.Wrapf(types.ErrInvalidProof, "proof did not verify existence of key %s with given value %x", p.Key, args[0])
}
default:
return nil, sdkerrors.Wrapf(types.ErrInvalidProof, "args must be length 0 or 1, got: %d", len(args))
}
return [][]byte{p.Root}, nil
}

func (p *ProofOp) GetKey() []byte {
return p.Key
}

func (p *ProofOp) ProofOp() tmmerkle.ProofOp {
var data bytes.Buffer
enc := gob.NewEncoder(&data)
enc.Encode(p)
return tmmerkle.ProofOp{
Type: "smt",
Key: p.Key,
Data: data.Bytes(),
}
}

func ProofDecoder(pop tmmerkle.ProofOp) (merkle.ProofOperator, error) {
dec := gob.NewDecoder(bytes.NewBuffer(pop.Data))
var proof ProofOp
err := dec.Decode(&proof)
if err != nil {
return nil, err
}
return &proof, nil
}

func getHasher(hasher HasherType) hash.Hash {
switch hasher {
case SHA256:
return sha256.New()
default:
return nil
}
}
Loading

0 comments on commit 1105929

Please sign in to comment.