Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ProveClosest method to SMT and SMST #18

Merged
merged 15 commits into from
Sep 25, 2023
16 changes: 16 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package smt

import "hash"

// Option is a function that configures SparseMerkleTree.
type Option func(*TreeSpec)

Expand All @@ -12,3 +14,17 @@ func WithPathHasher(ph PathHasher) Option {
func WithValueHasher(vh ValueHasher) Option {
return func(ts *TreeSpec) { ts.vh = vh }
}

// NoPrehashSpec returns a new TreeSpec that has a nil Value Hasher and a nil
// Path Hasher
// NOTE: This should only be used when values are already hashed and a path is
Olshansk marked this conversation as resolved.
Show resolved Hide resolved
// used instead of a key during proof verification, otherwise these will be
// double hashed and produce an incorrect leaf digest invalidating the proof.
func NoPrehashSpec(hasher hash.Hash, sumTree bool) *TreeSpec {
spec := newTreeSpec(hasher, sumTree)
opt := WithPathHasher(newNilPathHasher(hasher.Size()))
opt(&spec)
opt = WithValueHasher(nil)
opt(&spec)
return &spec
}
1 change: 1 addition & 0 deletions proofs.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ func init() {
var ErrBadProof = errors.New("bad proof")

// SparseMerkleProof is a Merkle proof for an element in a SparseMerkleTree.
// TODO: Research whether the SiblingData is required and remove it if not
type SparseMerkleProof struct {
// SideNodes is an array of the sibling nodes leading up to the leaf of the proof.
SideNodes [][]byte
Expand Down
21 changes: 21 additions & 0 deletions smst.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,27 @@
return smst.SMT.Prove(key)
}

// ProveClosest generates a SparseMerkleProof of inclusion for the key
// with the most common bits as the path provided
func (smst *SMST) ProveClosest(path []byte) (
closestPath, closestValueHash []byte,
closestSum uint64,
proof *SparseMerkleProof,
err error,
) {
closestPath, valueHash, proof, err := smst.SMT.ProveClosest(path)
if err != nil {
return nil, nil, 0, nil, err
}

Check warning on line 98 in smst.go

View check run for this annotation

Codecov / codecov/patch

smst.go#L97-L98

Added lines #L97 - L98 were not covered by tests
if valueHash == nil {
return closestPath, nil, 0, proof, nil
}
closestValueHash = valueHash[:len(valueHash)-sumSize]
sumBz := valueHash[len(valueHash)-sumSize:]
closestSum = binary.BigEndian.Uint64(sumBz)
return closestPath, closestValueHash, closestSum, proof, nil
}

// Commit persists all dirty nodes in the tree, deletes all orphaned
// nodes from the database and then computes and saves the root hash
func (smst *SMST) Commit() error {
Expand Down
101 changes: 101 additions & 0 deletions smst_proofs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,104 @@ func TestSMST_ProofsSanityCheck(t *testing.T) {
require.NoError(t, smn.Stop())
require.NoError(t, smv.Stop())
}

// ProveClosest test against a visual representation of the tree
// See: https://github.com/pokt-network/smt/assets/53987565/2a2f33e0-f81f-41c5-bd76-af0cd1cd8f15
func TestSMST_ProveClosest(t *testing.T) {
var smn KVStore
var smst *SMST
var proof *SparseMerkleProof
var result bool
var root, closestKey, closestValueHash []byte
var closestSum uint64
var err error

smn, err = NewKVStore("")
require.NoError(t, err)
smst = NewSparseMerkleSumTree(smn, sha256.New(), WithValueHasher(nil))

// insert some unrelated values to populate the tree
require.NoError(t, smst.Update([]byte("foo"), []byte("oof"), 3))
require.NoError(t, smst.Update([]byte("bar"), []byte("rab"), 6))
require.NoError(t, smst.Update([]byte("baz"), []byte("zab"), 9))
require.NoError(t, smst.Update([]byte("bin"), []byte("nib"), 12))
require.NoError(t, smst.Update([]byte("fiz"), []byte("zif"), 15))
require.NoError(t, smst.Update([]byte("fob"), []byte("bof"), 18))
require.NoError(t, smst.Update([]byte("testKey"), []byte("testValue"), 21))
require.NoError(t, smst.Update([]byte("testKey2"), []byte("testValue2"), 24))
require.NoError(t, smst.Update([]byte("testKey3"), []byte("testValue3"), 27))
require.NoError(t, smst.Update([]byte("testKey4"), []byte("testValue4"), 30))

root = smst.Root()

// testKey2 is the child of an inner node which is the child of an extension node
// the extension node has the path bounds of [3, 7] by flipping these bits we force
h5law marked this conversation as resolved.
Show resolved Hide resolved
// a double backstep to return to avoid nil nodes and find the closest key which is
// then testKey2
path := sha256.Sum256([]byte("testKey2"))
flipPathBit(path[:], 3)
flipPathBit(path[:], 6)
closestKey, closestValueHash, closestSum, proof, err = smst.ProveClosest(path[:])
require.NoError(t, err)
require.NotEqual(t, proof, &SparseMerkleProof{})

result = VerifySumProof(proof, root, closestKey, closestValueHash, closestSum, NoPrehashSpec(sha256.New(), true))
require.True(t, result)
closestPath := sha256.Sum256([]byte("testKey2"))
require.Equal(t, closestPath[:], closestKey)
Olshansk marked this conversation as resolved.
Show resolved Hide resolved
require.Equal(t, []byte("testValue2"), closestValueHash)
require.Equal(t, closestSum, uint64(24))

// testKey4 is the neighbour of testKey2, by flipping the final bit of the
// extension node we change the longest common prefix to that of testKey4
path2 := sha256.Sum256([]byte("testKey2"))
flipPathBit(path2[:], 3)
flipPathBit(path2[:], 7)
closestKey, closestValueHash, closestSum, proof, err = smst.ProveClosest(path2[:])
require.NoError(t, err)
require.NotEqual(t, proof, &SparseMerkleProof{})

result = VerifySumProof(proof, root, closestKey, closestValueHash, closestSum, NoPrehashSpec(sha256.New(), true))
require.True(t, result)
closestPath = sha256.Sum256([]byte("testKey4"))
require.Equal(t, closestPath[:], closestKey)
require.Equal(t, []byte("testValue4"), closestValueHash)
require.Equal(t, closestSum, uint64(30))

require.NoError(t, smn.Stop())
}

func TestSMST_ProveClosestEmptyAndOneNode(t *testing.T) {
var smn KVStore
var smst *SMST
var proof *SparseMerkleProof
var err error
var closestPath, closestValueHash []byte
var closestSum uint64

smn, err = NewKVStore("")
require.NoError(t, err)
smst = NewSparseMerkleSumTree(smn, sha256.New(), WithValueHasher(nil))

path := sha256.Sum256([]byte("testKey2"))
flipPathBit(path[:], 3)
flipPathBit(path[:], 6)
closestPath, closestValueHash, closestSum, proof, err = smst.ProveClosest(path[:])
require.NoError(t, err)
require.Equal(t, proof, &SparseMerkleProof{})

result := VerifySumProof(proof, smst.Root(), closestPath, closestValueHash, closestSum, NoPrehashSpec(sha256.New(), true))
require.True(t, result)

require.NoError(t, smst.Update([]byte("foo"), []byte("bar"), 5))
closestPath, closestValueHash, closestSum, proof, err = smst.ProveClosest(path[:])
require.NoError(t, err)
require.Equal(t, proof, &SparseMerkleProof{})
require.Equal(t, closestValueHash, []byte("bar"))
require.Equal(t, closestSum, uint64(5))

result = VerifySumProof(proof, smst.Root(), closestPath, closestValueHash, closestSum, NoPrehashSpec(sha256.New(), true))
require.True(t, result)

require.NoError(t, smn.Stop())
}
133 changes: 133 additions & 0 deletions smt.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,137 @@
return proof, nil
}

// ProveClosest generates a SparseMerkleProof of inclusion for the first
// key with the most common bits as the path provided.
//
// This method will follow the path provided until it hits a leaf node and then
// exit. If the leaf is along the path it will produce an inclusion proof for
// the key (and return the key-value internal pair) as they share a common
// prefix. If however, during the tree traversal according to the path, a nil
// node is encountered, the traversal backsteps and flips the path bit for that
// depth (ie tries left if it tried right and vice versa). This guarentees that
// a proof of inclusion is found that has the most common bits with the path
// provided, biased to the longest common prefix
func (smt *SMT) ProveClosest(path []byte) (
Olshansk marked this conversation as resolved.
Show resolved Hide resolved
closestPath, closestValueHash []byte, // the closest leaf info for the key provided
proof *SparseMerkleProof, // proof of the key-value pair found
err error, // the error value encountered
) {
workingPath := make([]byte, len(path))
copy(workingPath, path)
var siblings []treeNode
var sib treeNode
var parent treeNode
// depthDelta is used to track the depth increase when traversing down the tree
h5law marked this conversation as resolved.
Show resolved Hide resolved
// it is used when back-stepping to go back to the correct depth in the path
// if we hit a nil node during tree traversal
var depthDelta int

node := smt.tree
depth := 0
// continuously traverse the tree until we hit a leaf node
for depth < smt.depth() {
h5law marked this conversation as resolved.
Show resolved Hide resolved
// save current node information as "parent" info
if node != nil {
parent = node
}
// resolve current node
node, err = smt.resolveLazy(node)
if err != nil {
return nil, nil, nil, err
}

Check warning on line 418 in smt.go

View check run for this annotation

Codecov / codecov/patch

smt.go#L417-L418

Added lines #L417 - L418 were not covered by tests
if node != nil {
// reset depthDelta if node is non nil
depthDelta = 0
} else {
// if we hit a nil node we backstep to the parent node and flip the
// path bit at the parent depth and select the other child
node, err = smt.resolveLazy(parent)
if err != nil {
return nil, nil, nil, err
}

Check warning on line 428 in smt.go

View check run for this annotation

Codecov / codecov/patch

smt.go#L427-L428

Added lines #L427 - L428 were not covered by tests
// trim the last sibling node added as it is no longer relevant
h5law marked this conversation as resolved.
Show resolved Hide resolved
// due to back-stepping we are now going to traverse to the
// most recent sibling, including it here would result in an
// incorrect root hash when calculated
if len(siblings) > 0 {
siblings = siblings[:len(siblings)-1]
h5law marked this conversation as resolved.
Show resolved Hide resolved
}
depth -= depthDelta
// flip the path bit at the parent depth
flipPathBit(workingPath, depth)
}
// end traversal when we hit a leaf node
if _, ok := node.(*leafNode); ok {
h5law marked this conversation as resolved.
Show resolved Hide resolved
break
}
if ext, ok := node.(*extensionNode); ok {
length, match := ext.match(workingPath, depth)
// workingPath from depth to end of extension node's path bounds
// is a perfect match
if !match {
node = ext.expand()
} else {
// extension nodes represent a singly linked list of inner nodes
// add nil siblings to represent the empty neighbours
for i := 0; i < length; i++ {
h5law marked this conversation as resolved.
Show resolved Hide resolved
siblings = append(siblings, nil)
}
depth += length
depthDelta += length
node = ext.child
node, err = smt.resolveLazy(node)
if err != nil {
return nil, nil, nil, err
}

Check warning on line 462 in smt.go

View check run for this annotation

Codecov / codecov/patch

smt.go#L461-L462

Added lines #L461 - L462 were not covered by tests
}
}
inner, ok := node.(*innerNode)
if !ok { // this can only happen for an empty tree
break
Olshansk marked this conversation as resolved.
Show resolved Hide resolved
}
if GetPathBit(workingPath, depth) == left {
node, sib = inner.leftChild, inner.rightChild
} else {
node, sib = inner.rightChild, inner.leftChild
}
siblings = append(siblings, sib)
depth += 1
depthDelta += 1
}

// Retrieve the closest path and value hash if found
h5law marked this conversation as resolved.
Show resolved Hide resolved
if node == nil { // tree was empty
Olshansk marked this conversation as resolved.
Show resolved Hide resolved
return placeholder(smt.Spec()), nil, &SparseMerkleProof{}, nil
}
leaf, ok := node.(*leafNode)
if !ok {
// if no leaf was found and the tree is not empty something went wrong
panic("expected leaf node")

Check warning on line 486 in smt.go

View check run for this annotation

Codecov / codecov/patch

smt.go#L485-L486

Added lines #L485 - L486 were not covered by tests
}
closestPath, closestValueHash = leaf.path, leaf.valueHash
// Hash siblings from bottom up.
var sideNodes [][]byte
for i := range siblings {
Olshansk marked this conversation as resolved.
Show resolved Hide resolved
var sideNode []byte
sibling := siblings[len(siblings)-i-1]
sideNode = hashNode(smt.Spec(), sibling)
sideNodes = append(sideNodes, sideNode)
}
proof = &SparseMerkleProof{
SideNodes: sideNodes,
}
if sib != nil {
h5law marked this conversation as resolved.
Show resolved Hide resolved
sib, err = smt.resolveLazy(sib)
if err != nil {
return nil, nil, nil, err
}

Check warning on line 504 in smt.go

View check run for this annotation

Codecov / codecov/patch

smt.go#L503-L504

Added lines #L503 - L504 were not covered by tests
proof.SiblingData = serialize(smt.Spec(), sib)
}

return closestPath, closestValueHash, proof, nil
}

//nolint:unused
func (smt *SMT) recursiveLoad(hash []byte) (treeNode, error) {
return smt.resolve(hash, smt.recursiveLoad)
Expand Down Expand Up @@ -636,6 +767,8 @@
return head, &b, index
}

// expand returns the inner node that represents the start of the singly
// linked list that this extension node represents
func (ext *extensionNode) expand() treeNode {
last := ext.child
for i := ext.pathEnd() - 1; i >= ext.pathStart(); i-- {
Expand Down
Loading
Loading