Skip to content

Commit

Permalink
feat: add ProveClosest method to SMT and SMST (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
h5law authored Sep 25, 2023
1 parent a8abd2d commit c5998f1
Show file tree
Hide file tree
Showing 8 changed files with 400 additions and 1 deletion.
16 changes: 16 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package smt

import "hash"

// Option is a function that configures SparseMerkleTree.
type Option func(*TreeSpec)

Expand All @@ -12,3 +14,17 @@ func WithPathHasher(ph PathHasher) Option {
func WithValueHasher(vh ValueHasher) Option {
return func(ts *TreeSpec) { ts.vh = vh }
}

// NoPrehashSpec returns a new TreeSpec that has a nil Value Hasher and a nil
// Path Hasher
// NOTE: This should only be used when values are already hashed and a path is
// used instead of a key during proof verification, otherwise these will be
// double hashed and produce an incorrect leaf digest invalidating the proof.
func NoPrehashSpec(hasher hash.Hash, sumTree bool) *TreeSpec {
spec := newTreeSpec(hasher, sumTree)
opt := WithPathHasher(newNilPathHasher(hasher.Size()))
opt(&spec)
opt = WithValueHasher(nil)
opt(&spec)
return &spec
}
1 change: 1 addition & 0 deletions proofs.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ func init() {
var ErrBadProof = errors.New("bad proof")

// SparseMerkleProof is a Merkle proof for an element in a SparseMerkleTree.
// TODO: Research whether the SiblingData is required and remove it if not
type SparseMerkleProof struct {
// SideNodes is an array of the sibling nodes leading up to the leaf of the proof.
SideNodes [][]byte
Expand Down
21 changes: 21 additions & 0 deletions smst.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,27 @@ func (smst *SMST) Prove(key []byte) (*SparseMerkleProof, error) {
return smst.SMT.Prove(key)
}

// ProveClosest generates a SparseMerkleProof of inclusion for the key
// with the most common bits as the path provided
func (smst *SMST) ProveClosest(path []byte) (
closestPath, closestValueHash []byte,
closestSum uint64,
proof *SparseMerkleProof,
err error,
) {
closestPath, valueHash, proof, err := smst.SMT.ProveClosest(path)
if err != nil {
return nil, nil, 0, nil, err
}
if valueHash == nil {
return closestPath, nil, 0, proof, nil
}
closestValueHash = valueHash[:len(valueHash)-sumSize]
sumBz := valueHash[len(valueHash)-sumSize:]
closestSum = binary.BigEndian.Uint64(sumBz)
return closestPath, closestValueHash, closestSum, proof, nil
}

// Commit persists all dirty nodes in the tree, deletes all orphaned
// nodes from the database and then computes and saves the root hash
func (smst *SMST) Commit() error {
Expand Down
102 changes: 102 additions & 0 deletions smst_proofs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,105 @@ func TestSMST_ProofsSanityCheck(t *testing.T) {
require.NoError(t, smn.Stop())
require.NoError(t, smv.Stop())
}

// ProveClosest test against a visual representation of the tree
// See: https://github.com/pokt-network/smt/assets/53987565/2a2f33e0-f81f-41c5-bd76-af0cd1cd8f15
func TestSMST_ProveClosest(t *testing.T) {
var smn KVStore
var smst *SMST
var proof *SparseMerkleProof
var result bool
var root, closestKey, closestValueHash []byte
var closestSum uint64
var err error

smn, err = NewKVStore("")
require.NoError(t, err)
smst = NewSparseMerkleSumTree(smn, sha256.New(), WithValueHasher(nil))

// insert some unrelated values to populate the tree
require.NoError(t, smst.Update([]byte("foo"), []byte("oof"), 3))
require.NoError(t, smst.Update([]byte("bar"), []byte("rab"), 6))
require.NoError(t, smst.Update([]byte("baz"), []byte("zab"), 9))
require.NoError(t, smst.Update([]byte("bin"), []byte("nib"), 12))
require.NoError(t, smst.Update([]byte("fiz"), []byte("zif"), 15))
require.NoError(t, smst.Update([]byte("fob"), []byte("bof"), 18))
require.NoError(t, smst.Update([]byte("testKey"), []byte("testValue"), 21))
require.NoError(t, smst.Update([]byte("testKey2"), []byte("testValue2"), 24))
require.NoError(t, smst.Update([]byte("testKey3"), []byte("testValue3"), 27))
require.NoError(t, smst.Update([]byte("testKey4"), []byte("testValue4"), 30))

root = smst.Root()

// `testKey2` is the child of an inner node, which is the child of an extension node.
// The extension node has the path bounds of [3, 7]. This means any bits between
// 3-6 can be flipped, and the resulting path would still traverse through the same
// extension node and lead to testKey2 - the closest key. However, flipping bit 7
// will lead to testKey4.
path := sha256.Sum256([]byte("testKey2"))
flipPathBit(path[:], 3)
flipPathBit(path[:], 6)
closestKey, closestValueHash, closestSum, proof, err = smst.ProveClosest(path[:])
require.NoError(t, err)
require.NotEqual(t, proof, &SparseMerkleProof{})

result = VerifySumProof(proof, root, closestKey, closestValueHash, closestSum, NoPrehashSpec(sha256.New(), true))
require.True(t, result)
closestPath := sha256.Sum256([]byte("testKey2"))
require.Equal(t, closestPath[:], closestKey)
require.Equal(t, []byte("testValue2"), closestValueHash)
require.Equal(t, closestSum, uint64(24))

// testKey4 is the neighbour of testKey2, by flipping the final bit of the
// extension node we change the longest common prefix to that of testKey4
path2 := sha256.Sum256([]byte("testKey2"))
flipPathBit(path2[:], 3)
flipPathBit(path2[:], 7)
closestKey, closestValueHash, closestSum, proof, err = smst.ProveClosest(path2[:])
require.NoError(t, err)
require.NotEqual(t, proof, &SparseMerkleProof{})

result = VerifySumProof(proof, root, closestKey, closestValueHash, closestSum, NoPrehashSpec(sha256.New(), true))
require.True(t, result)
closestPath = sha256.Sum256([]byte("testKey4"))
require.Equal(t, closestPath[:], closestKey)
require.Equal(t, []byte("testValue4"), closestValueHash)
require.Equal(t, closestSum, uint64(30))

require.NoError(t, smn.Stop())
}

func TestSMST_ProveClosestEmptyAndOneNode(t *testing.T) {
var smn KVStore
var smst *SMST
var proof *SparseMerkleProof
var err error
var closestPath, closestValueHash []byte
var closestSum uint64

smn, err = NewKVStore("")
require.NoError(t, err)
smst = NewSparseMerkleSumTree(smn, sha256.New(), WithValueHasher(nil))

path := sha256.Sum256([]byte("testKey2"))
flipPathBit(path[:], 3)
flipPathBit(path[:], 6)
closestPath, closestValueHash, closestSum, proof, err = smst.ProveClosest(path[:])
require.NoError(t, err)
require.Equal(t, proof, &SparseMerkleProof{})

result := VerifySumProof(proof, smst.Root(), closestPath, closestValueHash, closestSum, NoPrehashSpec(sha256.New(), true))
require.True(t, result)

require.NoError(t, smst.Update([]byte("foo"), []byte("bar"), 5))
closestPath, closestValueHash, closestSum, proof, err = smst.ProveClosest(path[:])
require.NoError(t, err)
require.Equal(t, proof, &SparseMerkleProof{})
require.Equal(t, closestValueHash, []byte("bar"))
require.Equal(t, closestSum, uint64(5))

result = VerifySumProof(proof, smst.Root(), closestPath, closestValueHash, closestSum, NoPrehashSpec(sha256.New(), true))
require.True(t, result)

require.NoError(t, smn.Stop())
}
133 changes: 133 additions & 0 deletions smt.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,137 @@ func (smt *SMT) Prove(key []byte) (proof *SparseMerkleProof, err error) {
return proof, nil
}

// ProveClosest generates a SparseMerkleProof of inclusion for the first
// key with the most common bits as the path provided.
//
// This method will follow the path provided until it hits a leaf node and then
// exit. If the leaf is along the path it will produce an inclusion proof for
// the key (and return the key-value internal pair) as they share a common
// prefix. If however, during the tree traversal according to the path, a nil
// node is encountered, the traversal backsteps and flips the path bit for that
// depth (ie tries left if it tried right and vice versa). This guarentees that
// a proof of inclusion is found that has the most common bits with the path
// provided, biased to the longest common prefix
func (smt *SMT) ProveClosest(path []byte) (
closestPath, closestValueHash []byte, // the closest leaf info for the key provided
proof *SparseMerkleProof, // proof of the key-value pair found
err error, // the error value encountered
) {
workingPath := make([]byte, len(path))
copy(workingPath, path)
var siblings []treeNode
var sib treeNode
var parent treeNode
// depthDelta is used to track the depth increase when traversing down the tree
// it is used when back-stepping to go back to the correct depth in the path
// if we hit a nil node during tree traversal
var depthDelta int

node := smt.tree
depth := 0
// continuously traverse the tree until we hit a leaf node
for depth < smt.depth() {
// save current node information as "parent" info
if node != nil {
parent = node
}
// resolve current node
node, err = smt.resolveLazy(node)
if err != nil {
return nil, nil, nil, err
}
if node != nil {
// reset depthDelta if node is non nil
depthDelta = 0
} else {
// if we hit a nil node we backstep to the parent node and flip the
// path bit at the parent depth and select the other child
node, err = smt.resolveLazy(parent)
if err != nil {
return nil, nil, nil, err
}
// trim the last sibling node added as it is no longer relevant
// due to back-stepping we are now going to traverse to the
// most recent sibling, including it here would result in an
// incorrect root hash when calculated
if len(siblings) > 0 {
siblings = siblings[:len(siblings)-1]
}
depth -= depthDelta
// flip the path bit at the parent depth
flipPathBit(workingPath, depth)
}
// end traversal when we hit a leaf node
if _, ok := node.(*leafNode); ok {
break
}
if ext, ok := node.(*extensionNode); ok {
length, match := ext.match(workingPath, depth)
// workingPath from depth to end of extension node's path bounds
// is a perfect match
if !match {
node = ext.expand()
} else {
// extension nodes represent a singly linked list of inner nodes
// add nil siblings to represent the empty neighbours
for i := 0; i < length; i++ {
siblings = append(siblings, nil)
}
depth += length
depthDelta += length
node = ext.child
node, err = smt.resolveLazy(node)
if err != nil {
return nil, nil, nil, err
}
}
}
inner, ok := node.(*innerNode)
if !ok { // this can only happen for an empty tree
break
}
if GetPathBit(workingPath, depth) == left {
node, sib = inner.leftChild, inner.rightChild
} else {
node, sib = inner.rightChild, inner.leftChild
}
siblings = append(siblings, sib)
depth += 1
depthDelta += 1
}

// Retrieve the closest path and value hash if found
if node == nil { // tree was empty
return placeholder(smt.Spec()), nil, &SparseMerkleProof{}, nil
}
leaf, ok := node.(*leafNode)
if !ok {
// if no leaf was found and the tree is not empty something went wrong
panic("expected leaf node")
}
closestPath, closestValueHash = leaf.path, leaf.valueHash
// Hash siblings from bottom up.
var sideNodes [][]byte
for i := range siblings {
var sideNode []byte
sibling := siblings[len(siblings)-i-1]
sideNode = hashNode(smt.Spec(), sibling)
sideNodes = append(sideNodes, sideNode)
}
proof = &SparseMerkleProof{
SideNodes: sideNodes,
}
if sib != nil {
sib, err = smt.resolveLazy(sib)
if err != nil {
return nil, nil, nil, err
}
proof.SiblingData = serialize(smt.Spec(), sib)
}

return closestPath, closestValueHash, proof, nil
}

//nolint:unused
func (smt *SMT) recursiveLoad(hash []byte) (treeNode, error) {
return smt.resolve(hash, smt.recursiveLoad)
Expand Down Expand Up @@ -636,6 +767,8 @@ func (ext *extensionNode) split(path []byte, depth int) (treeNode, *treeNode, in
return head, &b, index
}

// expand returns the inner node that represents the start of the singly
// linked list that this extension node represents
func (ext *extensionNode) expand() treeNode {
last := ext.child
for i := ext.pathEnd() - 1; i >= ext.pathStart(); i-- {
Expand Down
Loading

0 comments on commit c5998f1

Please sign in to comment.