diff --git a/store/rootmulti/proof.go b/store/rootmulti/proof.go index fc8925b7f20..247e3867b80 100644 --- a/store/rootmulti/proof.go +++ b/store/rootmulti/proof.go @@ -3,6 +3,7 @@ package rootmulti import ( "github.com/tendermint/tendermint/crypto/merkle" + "github.com/cosmos/cosmos-sdk/store/smt" storetypes "github.com/cosmos/cosmos-sdk/store/types" ) @@ -25,3 +26,9 @@ func DefaultProofRuntime() (prt *merkle.ProofRuntime) { prt.RegisterOpDecoder(storetypes.ProofOpSimpleMerkleCommitment, storetypes.CommitmentOpDecoder) return } + +func SMTProofRuntime() (prt *merkle.ProofRuntime) { + prt = merkle.NewProofRuntime() + prt.RegisterOpDecoder(smt.ProofType, smt.ProofDecoder) + return prt +} diff --git a/store/rootmulti/store.go b/store/rootmulti/store.go index 471a24efe2c..c8df42ce3ce 100644 --- a/store/rootmulti/store.go +++ b/store/rootmulti/store.go @@ -586,7 +586,7 @@ func (rs *Store) SetInitialVersion(version int64) error { // If the store is wrapped with an inter-block cache, we must first unwrap // it to get the underlying IAVL store. store = rs.GetCommitKVStore(key) - store.(*iavl.Store).SetInitialVersion(version) + store.(types.StoreWithInitialVersion).SetInitialVersion(version) } } diff --git a/store/smt/iterator.go b/store/smt/iterator.go new file mode 100644 index 00000000000..459460d7753 --- /dev/null +++ b/store/smt/iterator.go @@ -0,0 +1,101 @@ +package smt + +import ( + "bytes" + + dbm "github.com/tendermint/tm-db" +) + +type Iterator struct { + store *Store + iter dbm.Iterator +} + +func indexKey(key []byte) []byte { + return append(indexPrefix, key...) +} + +func plainKey(key []byte) []byte { + return key[prefixLen:] +} + +func startKey(key []byte) []byte { + if key == nil { + return dataPrefix + } + return dataKey(key) +} + +func endKey(key []byte) []byte { + if key == nil { + return indexPrefix + } + return dataKey(key) +} + +func newIterator(s *Store, start, end []byte, reverse bool) (*Iterator, error) { + start = startKey(start) + end = endKey(end) + var i dbm.Iterator + var err error + if reverse { + i, err = s.db.ReverseIterator(start, end) + } else { + i, err = s.db.Iterator(start, end) + } + if err != nil { + return nil, err + } + return &Iterator{store: s, iter: i}, nil +} + +// Domain returns the start (inclusive) and end (exclusive) limits of the iterator. +// CONTRACT: start, end readonly []byte +func (i *Iterator) Domain() (start []byte, end []byte) { + start, end = i.iter.Domain() + if bytes.Equal(start, dataPrefix) { + start = nil + } else { + start = plainKey(start) + } + if bytes.Equal(end, indexPrefix) { + end = nil + } else { + end = plainKey(end) + } + return start, end +} + +// Valid returns whether the current iterator is valid. Once invalid, the Iterator remains +// invalid forever. +func (i *Iterator) Valid() bool { + return i.iter.Valid() +} + +// Next moves the iterator to the next key in the database, as defined by order of iteration. +// If Valid returns false, this method will panic. +func (i *Iterator) Next() { + i.iter.Next() +} + +// Key returns the key at the current position. Panics if the iterator is invalid. +// CONTRACT: key readonly []byte +func (i *Iterator) Key() (key []byte) { + return plainKey(i.iter.Key()) +} + +// Value returns the value at the current position. Panics if the iterator is invalid. +// CONTRACT: value readonly []byte +func (i *Iterator) Value() (value []byte) { + return i.store.Get(i.Key()) +} + +// Error returns the last error encountered by the iterator, if any. +func (i *Iterator) Error() error { + return i.iter.Error() +} + +// Close closes the iterator, relasing any allocated resources. +func (i *Iterator) Close() error { + return i.iter.Close() +} diff --git a/store/smt/iterator_test.go b/store/smt/iterator_test.go new file mode 100644 index 00000000000..6a724e665a1 --- /dev/null +++ b/store/smt/iterator_test.go @@ -0,0 +1,113 @@ +package smt_test + +import ( + "bytes" + "sort" + "testing" + + "github.com/cosmos/cosmos-sdk/store/smt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + dbm "github.com/tendermint/tm-db" +) + +func TestIteration(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + pairs := []struct{ key, val []byte }{ + {[]byte("foo"), []byte("bar")}, + {[]byte("lorem"), []byte("ipsum")}, + {[]byte("alpha"), []byte("beta")}, + {[]byte("gamma"), []byte("delta")}, + {[]byte("epsilon"), []byte("zeta")}, + {[]byte("eta"), []byte("theta")}, + {[]byte("iota"), []byte("kappa")}, + } + + s := smt.NewStore(dbm.NewMemDB()) + + for _, p := range pairs { + s.Set(p.key, p.val) + } + + // sort test data by key, to get "expected" ordering + sort.Slice(pairs, func(i, j int) bool { + return bytes.Compare(pairs[i].key, pairs[j].key) < 0 + }) + + iter := s.Iterator([]byte("alpha"), []byte("omega")) + for _, p := range pairs { + require.True(iter.Valid()) + require.Equal(p.key, iter.Key()) + require.Equal(p.val, iter.Value()) + iter.Next() + } + assert.False(iter.Valid()) + assert.NoError(iter.Error()) + assert.NoError(iter.Close()) + + iter = s.Iterator(nil, nil) + for _, p := range pairs { + require.True(iter.Valid()) + require.Equal(p.key, iter.Key()) + require.Equal(p.val, iter.Value()) + iter.Next() + } + assert.False(iter.Valid()) + assert.NoError(iter.Error()) + assert.NoError(iter.Close()) + + iter = s.Iterator([]byte("epsilon"), []byte("gamma")) + for _, p := range pairs[1:4] { + require.True(iter.Valid()) + require.Equal(p.key, iter.Key()) + require.Equal(p.val, iter.Value()) + iter.Next() + } + assert.False(iter.Valid()) + assert.NoError(iter.Error()) + assert.NoError(iter.Close()) + + rIter := s.ReverseIterator(nil, nil) + for i := len(pairs) - 1; i >= 0; i-- { + require.True(rIter.Valid()) + require.Equal(pairs[i].key, rIter.Key()) + require.Equal(pairs[i].val, rIter.Value()) + rIter.Next() + } + assert.False(rIter.Valid()) + assert.NoError(rIter.Error()) + assert.NoError(rIter.Close()) + + // delete something, and ensure that iteration still works + s.Delete([]byte("eta")) + + iter = s.Iterator(nil, nil) + for _, p := range pairs { + if !bytes.Equal([]byte("eta"), p.key) { + require.True(iter.Valid()) + require.Equal(p.key, iter.Key()) + require.Equal(p.val, iter.Value()) + iter.Next() + } + } + assert.False(iter.Valid()) + assert.NoError(iter.Error()) + assert.NoError(iter.Close()) +} + +func TestDomain(t *testing.T) { + assert := assert.New(t) + s := smt.NewStore(dbm.NewMemDB()) + + iter := s.Iterator(nil, nil) + start, end := iter.Domain() + assert.Nil(start) + assert.Nil(end) + + iter = s.Iterator([]byte("foo"), []byte("bar")) + start, end = iter.Domain() + assert.Equal([]byte("foo"), start) + assert.Equal([]byte("bar"), end) +} diff --git a/store/smt/proof.go b/store/smt/proof.go new file mode 100644 index 00000000000..6663c27cec8 --- /dev/null +++ b/store/smt/proof.go @@ -0,0 +1,92 @@ +package smt + +import ( + "bytes" + "crypto/sha256" + "encoding/gob" + "hash" + + "github.com/cosmos/cosmos-sdk/store/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" + "github.com/lazyledger/smt" + "github.com/tendermint/tendermint/crypto/merkle" + tmmerkle "github.com/tendermint/tendermint/proto/tendermint/crypto" +) + +type HasherType byte + +const ( + SHA256 HasherType = iota +) + +const ( + ProofType = "smt" +) + +type ProofOp struct { + Root []byte + Key []byte + Hasher HasherType + Proof smt.SparseMerkleProof +} + +var _ merkle.ProofOperator = &ProofOp{} + +func NewProofOp(root, key []byte, hasher HasherType, proof smt.SparseMerkleProof) *ProofOp { + return &ProofOp{ + Root: root, + Key: key, + Hasher: hasher, + Proof: proof, + } +} + +func (p *ProofOp) Run(args [][]byte) ([][]byte, error) { + switch len(args) { + case 0: // non-membership proof + if !smt.VerifyProof(p.Proof, p.Root, p.Key, []byte{}, getHasher(p.Hasher)) { + return nil, sdkerrors.Wrapf(types.ErrInvalidProof, "proof did not verify absence of key: %s", p.Key) + } + case 1: // membership proof + if !smt.VerifyProof(p.Proof, p.Root, p.Key, args[0], getHasher(p.Hasher)) { + return nil, sdkerrors.Wrapf(types.ErrInvalidProof, "proof did not verify existence of key %s with given value %x", p.Key, args[0]) + } + default: + return nil, sdkerrors.Wrapf(types.ErrInvalidProof, "args must be length 0 or 1, got: %d", len(args)) + } + return [][]byte{p.Root}, nil +} + +func (p *ProofOp) GetKey() []byte { + return p.Key +} + +func (p *ProofOp) ProofOp() tmmerkle.ProofOp { + var data bytes.Buffer + enc := gob.NewEncoder(&data) + enc.Encode(p) + return tmmerkle.ProofOp{ + Type: "smt", + Key: p.Key, + Data: data.Bytes(), + } +} + +func ProofDecoder(pop tmmerkle.ProofOp) (merkle.ProofOperator, error) { + dec := gob.NewDecoder(bytes.NewBuffer(pop.Data)) + var proof ProofOp + err := dec.Decode(&proof) + if err != nil { + return nil, err + } + return &proof, nil +} + +func getHasher(hasher HasherType) hash.Hash { + switch hasher { + case SHA256: + return sha256.New() + default: + return nil + } +} diff --git a/store/smt/proof_test.go b/store/smt/proof_test.go new file mode 100644 index 00000000000..75e7974bc21 --- /dev/null +++ b/store/smt/proof_test.go @@ -0,0 +1,68 @@ +package smt_test + +import ( + "crypto/sha256" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + smtstore "github.com/cosmos/cosmos-sdk/store/smt" + "github.com/lazyledger/smt" + dbm "github.com/tendermint/tm-db" +) + +func TestProofOpInterface(t *testing.T) { + hasher := sha256.New() + tree := smt.NewSparseMerkleTree(dbm.NewMemDB(), hasher) + key := []byte("foo") + value := []byte("bar") + root, err := tree.Update(key, value) + require.NoError(t, err) + require.NotEmpty(t, root) + + proof, err := tree.Prove(key) + require.True(t, smt.VerifyProof(proof, root, key, value, hasher)) + + storeProofOp := smtstore.NewProofOp(root, key, smtstore.SHA256, proof) + require.NotNil(t, storeProofOp) + // inclusion proof + r, err := storeProofOp.Run([][]byte{value}) + assert.NoError(t, err) + assert.NotEmpty(t, r) + assert.Equal(t, root, r[0]) + + // inclusion proof - wrong value - should fail + r, err = storeProofOp.Run([][]byte{key}) + assert.Error(t, err) + assert.Empty(t, r) + + // exclusion proof - should fail + r, err = storeProofOp.Run([][]byte{}) + assert.Error(t, err) + assert.Empty(t, r) + + // invalid request - should fail + r, err = storeProofOp.Run([][]byte{key, key}) + assert.Error(t, err) + assert.Empty(t, r) + + // encode + tmProofOp := storeProofOp.ProofOp() + assert.NotNil(t, tmProofOp) + assert.Equal(t, smtstore.ProofType, tmProofOp.Type) + assert.Equal(t, key, tmProofOp.Key, key) + assert.NotEmpty(t, tmProofOp.Data) + + //decode + decoded, err := smtstore.ProofDecoder(tmProofOp) + assert.NoError(t, err) + assert.NotNil(t, decoded) + assert.Equal(t, key, decoded.GetKey()) + + // run proof after decoding + r, err = decoded.Run([][]byte{value}) + assert.NoError(t, err) + assert.NotEmpty(t, r) + assert.Equal(t, root, r[0]) +} diff --git a/store/smt/store.go b/store/smt/store.go new file mode 100644 index 00000000000..b70b097515d --- /dev/null +++ b/store/smt/store.go @@ -0,0 +1,211 @@ +package smt + +import ( + "crypto/sha256" + "encoding/binary" + "io" + "sync" + "time" + + "github.com/cosmos/cosmos-sdk/store/cachekv" + "github.com/cosmos/cosmos-sdk/store/tracekv" + "github.com/cosmos/cosmos-sdk/store/types" + "github.com/cosmos/cosmos-sdk/telemetry" + abci "github.com/tendermint/tendermint/abci/types" + dbm "github.com/tendermint/tm-db" + + "github.com/lazyledger/smt" +) + +var ( + _ types.KVStore = (*Store)(nil) + _ types.CommitStore = (*Store)(nil) + _ types.CommitKVStore = (*Store)(nil) + _ types.Queryable = (*Store)(nil) + _ types.StoreWithInitialVersion = (*Store)(nil) +) + +var ( + prefixLen = 1 + versionsPrefix = []byte{0} + dataPrefix = []byte{1} + indexPrefix = []byte{2} + afterIndex = []byte{3} +) + +// Store Implements types.KVStore and CommitKVStore. +type Store struct { + tree *smt.SparseMerkleTree + db dbm.DB + + version int64 + + opts struct { + initialVersion int64 + pruningOptions types.PruningOptions + } + + mtx sync.RWMutex +} + +func NewStore(underlyingDB dbm.DB) *Store { + return &Store{ + tree: smt.NewSparseMerkleTree(underlyingDB, sha256.New()), + db: underlyingDB, + } +} + +// KVStore interface below: + +func (s *Store) GetStoreType() types.StoreType { + return types.StoreTypeSMT +} + +// CacheWrap branches a store. +func (s *Store) CacheWrap() types.CacheWrap { + return cachekv.NewStore(s) +} + +// CacheWrapWithTrace branches a store with tracing enabled. +func (s *Store) CacheWrapWithTrace(w io.Writer, tc types.TraceContext) types.CacheWrap { + return cachekv.NewStore(tracekv.NewStore(s, w, tc)) +} + +// Get returns nil iff key doesn't exist. Panics on nil key. +func (s *Store) Get(key []byte) []byte { + defer telemetry.MeasureSince(time.Now(), "store", "smt", "get") + val, err := s.db.Get(dataKey(key)) + if err != nil { + panic(err) + } + return val +} + +// Has checks if a key exists. Panics on nil key. +func (s *Store) Has(key []byte) bool { + defer telemetry.MeasureSince(time.Now(), "store", "smt", "has") + has, err := s.db.Has(dataKey(key)) + return err == nil && has +} + +// Set sets the key. Panics on nil key or value. +func (s *Store) Set(key []byte, value []byte) { + kvHash := sha256.Sum256(append(key, value...)) + + s.mtx.Lock() + defer s.mtx.Unlock() + + err := s.db.Set(dataKey(key), value) + if err != nil { + panic(err.Error()) + } + err = s.db.Set(indexKey(kvHash[:]), key) + if err != nil { + panic(err.Error()) + } + _, err = s.tree.Update(key, kvHash[:]) + if err != nil { + panic(err.Error()) + } +} + +// Delete deletes the key. Panics on nil key. +func (s *Store) Delete(key []byte) { + defer telemetry.MeasureSince(time.Now(), "store", "smt", "delete") + + s.mtx.Lock() + defer s.mtx.Unlock() + + _, _ = s.tree.Delete(key) + + dKey := dataKey(key) + defer func() { + _ = s.db.Delete(dKey) + }() + + value, err := s.db.Get(dKey) + if err != nil { + panic(err.Error()) + } + kvHash := sha256.Sum256(append(key, value...)) + _ = s.db.Delete(indexKey(kvHash[:])) +} + +// Iterator over a domain of keys in ascending order. End is exclusive. +// Start must be less than end, or the Iterator is invalid. +// Iterator must be closed by caller. +// To iterate over entire domain, use store.Iterator(nil, nil) +// CONTRACT: No writes may happen within a domain while an iterator exists over it. +// Exceptionally allowed for cachekv.Store, safe to write in the modules. +func (s *Store) Iterator(start []byte, end []byte) types.Iterator { + iter, err := newIterator(s, start, end, false) + if err != nil { + panic(err.Error()) + } + return iter +} + +// Iterator over a domain of keys in descending order. End is exclusive. +// Start must be less than end, or the Iterator is invalid. +// Iterator must be closed by caller. +// CONTRACT: No writes may happen within a domain while an iterator exists over it. +// Exceptionally allowed for cachekv.Store, safe to write in the modules. +func (s *Store) ReverseIterator(start []byte, end []byte) types.Iterator { + iter, err := newIterator(s, start, end, true) + if err != nil { + panic(err.Error()) + } + return iter +} + +// CommitStore interface below: + +func (s *Store) Commit() types.CommitID { + defer telemetry.MeasureSince(time.Now(), "store", "smt", "commit") + version := s.version + 1 + + if version == 1 && s.opts.initialVersion != 0 { + version = s.opts.initialVersion + } + + s.version = version + + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, uint64(version)) + s.db.Set(append(versionsPrefix, b...), s.tree.Root()) + + return s.LastCommitID() +} + +func (s *Store) LastCommitID() types.CommitID { + return types.CommitID{ + Version: s.version, + Hash: s.tree.Root(), + } +} + +func (s *Store) SetPruning(p types.PruningOptions) { + s.opts.pruningOptions = p +} + +func (s *Store) GetPruning() types.PruningOptions { + return s.opts.pruningOptions +} + +// Queryable interface below: + +func (s *Store) Query(_ abci.RequestQuery) abci.ResponseQuery { + panic("not implemented") +} + +// StoreWithInitialVersion interface below: + +// SetInitialVersion sets the initial version of the SMT tree. It is used when +// starting a new chain at an arbitrary height. +func (s *Store) SetInitialVersion(version int64) { + s.opts.initialVersion = version +} + +func dataKey(key []byte) []byte { + return append(dataPrefix, key...) +} diff --git a/store/smt/store_test.go b/store/smt/store_test.go new file mode 100644 index 00000000000..a655e4b6d32 --- /dev/null +++ b/store/smt/store_test.go @@ -0,0 +1,43 @@ +package smt_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cosmos/cosmos-sdk/store/smt" + dbm "github.com/tendermint/tm-db" +) + +func TestVersioning(t *testing.T) { + s := smt.NewStore(dbm.NewMemDB()) + expectedVersion := int64(0) + + s.Set([]byte("foo"), []byte("bar")) + cid1 := s.Commit() + expectedVersion++ + + assert.Equal(t, expectedVersion, cid1.Version) + assert.NotEmpty(t, cid1.Hash) + + s.Set([]byte("foobar"), []byte("baz")) + cid2 := s.Commit() + expectedVersion++ + + assert.Equal(t, expectedVersion, cid2.Version) + assert.NotEmpty(t, cid2.Hash) + assert.NotEqual(t, cid1.Hash, cid2.Hash) +} + +func TestInitialVersion(t *testing.T) { + s := smt.NewStore(dbm.NewMemDB()) + expectedVersion := int64(42) + + s.SetInitialVersion(expectedVersion) + + s.Set([]byte("foo"), []byte("foobar")) + cid := s.Commit() + + assert.Equal(t, expectedVersion, cid.Version) + assert.NotEmpty(t, cid.Hash) +} diff --git a/store/types/store.go b/store/types/store.go index 630cd1d040b..a79b7cd5e67 100644 --- a/store/types/store.go +++ b/store/types/store.go @@ -296,6 +296,7 @@ const ( StoreTypeIAVL StoreTypeTransient StoreTypeMemory + StoreTypeSMT ) func (st StoreType) String() string { @@ -314,6 +315,9 @@ func (st StoreType) String() string { case StoreTypeMemory: return "StoreTypeMemory" + + case StoreTypeSMT: + return "StoreTypeSMT" } return "unknown store type"