Skip to content

Commit

Permalink
chore: address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
h5law committed Sep 24, 2023
1 parent 62073eb commit b4dff71
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 99 deletions.
3 changes: 2 additions & 1 deletion options.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ func WithValueHasher(vh ValueHasher) Option {
return func(ts *TreeSpec) { ts.vh = vh }
}

// NoPrehashSpec returns a new TreeSpec that has a nil Value and Path Hasher
// NoPrehashSpec returns a new TreeSpec that has a nil Value Hasher and a nil
// Path Hasher
// NOTE: This should only be used when values are already hashed and a path is
// used instead of a key during proof verification, otherwise these will be
// double hashed and produce an incorrect leaf digest invalidating the proof.
Expand Down
3 changes: 3 additions & 0 deletions smst.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ func (smst *SMST) ProveClosest(path []byte) (
if err != nil {
return nil, nil, 0, nil, err
}

Check warning on line 98 in smst.go

View check run for this annotation

Codecov / codecov/patch

smst.go#L97-L98

Added lines #L97 - L98 were not covered by tests
if valueHash == nil {
return closestPath, nil, 0, proof, nil
}
closestValueHash = valueHash[:len(valueHash)-sumSize]
sumBz := valueHash[len(valueHash)-sumSize:]
closestSum = binary.BigEndian.Uint64(sumBz)
Expand Down
97 changes: 29 additions & 68 deletions smst_proofs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,73 +175,9 @@ func TestSMST_ProofsSanityCheck(t *testing.T) {
require.NoError(t, smv.Stop())
}

func TestSMST_ProveClosest(t *testing.T) {
var smn KVStore
var smst *SMST
var proof *SparseMerkleProof
var result bool
var root, closestKey, closestValueHash []byte
var closestSum uint64
var err error

smn, err = NewKVStore("")
require.NoError(t, err)
smst = NewSparseMerkleSumTree(smn, sha256.New())

// insert random values
require.NoError(t, smst.Update([]byte("foo"), []byte("bar"), 5))
require.NoError(t, smst.Update([]byte("baz"), []byte("bin"), 5))
require.NoError(t, smst.Update([]byte("testKey"), []byte("testValue"), 5))
require.NoError(t, smst.Update([]byte("testKey2"), []byte("testValue"), 5))
require.NoError(t, smst.Update([]byte("testKey3"), []byte("testValue"), 5))
require.NoError(t, smst.Update([]byte("testKey4"), []byte("testValue"), 5))
// insert testing values that are similar
require.NoError(t, smst.Update([]byte("jackfruit"), []byte("testValue1"), 7))
require.NoError(t, smst.Update([]byte("xwordA188wordB110"), []byte("testValue2"), 9)) // shares 2 bytes with jackfruit
require.NoError(t, smst.Update([]byte("3xwordA250wordB7"), []byte("testValue3"), 11)) // shares 3 bytes with jackfruit

root = smst.Root()

path := sha256.Sum256([]byte("jackfruit"))
flipPathBit(path[:], 245)
closestKey, closestValueHash, closestSum, proof, err = smst.ProveClosest(path[:])
require.NoError(t, err)
require.NotEqual(t, proof, &SparseMerkleProof{})

result = VerifySumProof(proof, root, closestKey, closestValueHash, closestSum, NoPrehashSpec(sha256.New(), true))
require.True(t, result)
closestPath := sha256.Sum256([]byte("jackfruit"))
require.Equal(t, closestPath[:], closestKey)
require.Equal(t, closestSum, uint64(7))

path = sha256.Sum256([]byte("xwordA188wordB110"))
flipPathBit(path[:], 245)
closestKey, closestValueHash, closestSum, proof, err = smst.ProveClosest(path[:])
require.NoError(t, err)
require.NotEqual(t, proof, &SparseMerkleProof{})

result = VerifySumProof(proof, root, closestKey, closestValueHash, closestSum, NoPrehashSpec(sha256.New(), true))
require.True(t, result)
closestPath = sha256.Sum256([]byte("xwordA188wordB110"))
require.Equal(t, closestPath[:], closestKey)
require.Equal(t, closestSum, uint64(9))

path = sha256.Sum256([]byte("3xwordA250wordB7"))
flipPathBit(path[:], 245)
closestKey, closestValueHash, closestSum, proof, err = smst.ProveClosest(path[:])
require.NoError(t, err)
require.NotEqual(t, proof, &SparseMerkleProof{})

result = VerifySumProof(proof, root, closestKey, closestValueHash, closestSum, NoPrehashSpec(sha256.New(), true))
require.True(t, result)
closestPath = sha256.Sum256([]byte("3xwordA250wordB7"))
require.Equal(t, closestPath[:], closestKey)
require.Equal(t, closestSum, uint64(11))
}

// ProveClosest test against a visual representation of the tree
// See: https://i.imgur.com/cPJObIy.png
func TestSMST_ProveClosestFromVisual(t *testing.T) {
func TestSMST_ProveClosest(t *testing.T) {
var smn KVStore
var smst *SMST
var proof *SparseMerkleProof
Expand All @@ -254,7 +190,7 @@ func TestSMST_ProveClosestFromVisual(t *testing.T) {
require.NoError(t, err)
smst = NewSparseMerkleSumTree(smn, sha256.New(), WithValueHasher(nil))

// insert random values
// insert some unrelated values to populate the tree
require.NoError(t, smst.Update([]byte("foo"), []byte("oof"), 3))
require.NoError(t, smst.Update([]byte("bar"), []byte("rab"), 6))
require.NoError(t, smst.Update([]byte("baz"), []byte("zab"), 9))
Expand Down Expand Up @@ -283,10 +219,11 @@ func TestSMST_ProveClosestFromVisual(t *testing.T) {
require.True(t, result)
closestPath := sha256.Sum256([]byte("testKey2"))
require.Equal(t, closestPath[:], closestKey)
require.Equal(t, []byte("testValue2"), closestValueHash)
require.Equal(t, closestSum, uint64(24))

// testValue4 is the neighbour of testValue2, by flipping the final bit of the
// extension node we change the longest common prefix to that of testValue4
// testKey4 is the neighbour of testKey2, by flipping the final bit of the
// extension node we change the longest common prefix to that of testKey4
path2 := sha256.Sum256([]byte("testKey2"))
flipPathBit(path2[:], 3)
flipPathBit(path2[:], 7)
Expand All @@ -298,5 +235,29 @@ func TestSMST_ProveClosestFromVisual(t *testing.T) {
require.True(t, result)
closestPath = sha256.Sum256([]byte("testKey4"))
require.Equal(t, closestPath[:], closestKey)
require.Equal(t, []byte("testValue4"), closestValueHash)
require.Equal(t, closestSum, uint64(30))
}

func TestSMST_ProveClosestEmpty(t *testing.T) {
var smn KVStore
var smst *SMST
var proof *SparseMerkleProof
var err error
var closestPath, closestValueHash []byte
var closestSum uint64

smn, err = NewKVStore("")
require.NoError(t, err)
smst = NewSparseMerkleSumTree(smn, sha256.New(), WithValueHasher(nil))

path := sha256.Sum256([]byte("testKey2"))
flipPathBit(path[:], 3)
flipPathBit(path[:], 6)
closestPath, closestValueHash, closestSum, proof, err = smst.ProveClosest(path[:])
require.NoError(t, err)
require.Equal(t, proof, &SparseMerkleProof{})

result := VerifySumProof(proof, smst.Root(), closestPath, closestValueHash, closestSum, NoPrehashSpec(sha256.New(), true))
require.True(t, result)
}
24 changes: 20 additions & 4 deletions smt.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,9 @@ func (smt *SMT) ProveClosest(path []byte) (
var siblings []treeNode
var sib treeNode
var parent treeNode
// depthDelta is used to track the depth increase when traversing down the tree
// it is used when back-stepping to go back to the correct depth in the path
// if we hit a nil node during tree traversal
var depthDelta int

node := smt.tree
Expand All @@ -421,13 +424,15 @@ func (smt *SMT) ProveClosest(path []byte) (
return nil, nil, nil, err
}

Check warning on line 425 in smt.go

View check run for this annotation

Codecov / codecov/patch

smt.go#L424-L425

Added lines #L424 - L425 were not covered by tests
// trim the last sibling node added as it is no longer relevant
// due to back-stepping
if len(siblings) > 0 {
siblings = siblings[:len(siblings)-1]
}
depth -= depthDelta
// flip the path bit at the parent depth
flipPathBit(workingPath, depth)
} else {
// reset depthDelta if node is non nil
depthDelta = 0
}
// end traversal when we hit a leaf node
Expand All @@ -436,7 +441,11 @@ func (smt *SMT) ProveClosest(path []byte) (
}
if ext, ok := node.(*extensionNode); ok {
length, match := ext.match(workingPath, depth)
// workingPath from depth to end of extension node's path bounds
// is a perfect match
if match {
// extension nodes represent a singly linked list of inner nodes
// add nil siblings to represent the empty neighbours
for i := 0; i < length; i++ {
siblings = append(siblings, nil)
}
Expand All @@ -451,7 +460,10 @@ func (smt *SMT) ProveClosest(path []byte) (
node = ext.expand()
}
}
inner := node.(*innerNode)
inner, ok := node.(*innerNode)
if !ok { // this can only happen for an empty tree
break
}
if GetPathBit(workingPath, depth) == left {
node, sib = inner.leftChild, inner.rightChild
} else {
Expand All @@ -463,10 +475,14 @@ func (smt *SMT) ProveClosest(path []byte) (
}

// Retrieve the closest path and value hash if found
if node == nil {
panic("no leaf node found")
if node == nil { // tree was empty
return placeholder(smt.Spec()), nil, &SparseMerkleProof{}, nil
}
leaf, ok := node.(*leafNode)
if !ok {
// if no leaf was found and the tree is not empty something went wrong
panic("expected leaf node")

Check warning on line 484 in smt.go

View check run for this annotation

Codecov / codecov/patch

smt.go#L483-L484

Added lines #L483 - L484 were not covered by tests
}
leaf := node.(*leafNode)
closestPath, closestValueHash = leaf.path, leaf.valueHash
// Hash siblings from bottom up.
var sideNodes [][]byte
Expand Down
73 changes: 49 additions & 24 deletions smt_proofs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ func TestSMT_ProofsSanityCheck(t *testing.T) {
require.NoError(t, smv.Stop())
}

// ProveClosest test against a visual representation of the tree
// See: https://i.imgur.com/ux4fchQ.png
func TestSMT_ProveClosest(t *testing.T) {
var smn KVStore
var smt *SMT
Expand All @@ -166,50 +168,73 @@ func TestSMT_ProveClosest(t *testing.T) {

smn, err = NewKVStore("")
require.NoError(t, err)
smt = NewSparseMerkleTree(smn, sha256.New())
smt = NewSparseMerkleTree(smn, sha256.New(), WithValueHasher(nil))

require.NoError(t, smt.Update([]byte("foo"), []byte("bar")))
require.NoError(t, smt.Update([]byte("baz"), []byte("bin")))
// insert some unrelated values to populate the tree
require.NoError(t, smt.Update([]byte("foo"), []byte("oof")))
require.NoError(t, smt.Update([]byte("bar"), []byte("rab")))
require.NoError(t, smt.Update([]byte("baz"), []byte("zab")))
require.NoError(t, smt.Update([]byte("bin"), []byte("nib")))
require.NoError(t, smt.Update([]byte("fiz"), []byte("zif")))
require.NoError(t, smt.Update([]byte("fob"), []byte("bof")))
require.NoError(t, smt.Update([]byte("testKey"), []byte("testValue")))
require.NoError(t, smt.Update([]byte("testKey2"), []byte("testValue")))
require.NoError(t, smt.Update([]byte("testKey3"), []byte("testValue")))
require.NoError(t, smt.Update([]byte("testKey4"), []byte("testValue")))
require.NoError(t, smt.Update([]byte("jackfruit"), []byte("testValue1")))
require.NoError(t, smt.Update([]byte("xwordA188wordB110"), []byte("testValue2"))) // shares 2 bytes with jackfruit
require.NoError(t, smt.Update([]byte("3xwordA250wordB7"), []byte("testValue3"))) // shares 3 bytes with jackfruit
require.NoError(t, smt.Update([]byte("testKey2"), []byte("testValue2")))
require.NoError(t, smt.Update([]byte("testKey3"), []byte("testValue3")))
require.NoError(t, smt.Update([]byte("testKey4"), []byte("testValue4")))

root = smt.Root()

// path = sha256.Sum256([]byte("jackfruit")) change 31st byte
path := []byte{41, 6, 1, 10, 203, 50, 121, 247, 169, 26, 77, 72, 87, 57, 82, 212, 73, 144, 141, 22, 59, 188, 178, 245, 109, 126, 84, 65, 227, 237, 79, 24}
closestKey, closestValueHash, proof, err = smt.ProveClosest(path)
// testKey2 is the child of an inner node which is the child of an extension node
// the extension node has the path bounds of [3, 7] by flipping these bits we force
// a double backstep to return to avoid nil nodes and find the closest key which is
// then testKey2
path := sha256.Sum256([]byte("testKey2"))
flipPathBit(path[:], 3)
flipPathBit(path[:], 6)
closestKey, closestValueHash, proof, err = smt.ProveClosest(path[:])
require.NoError(t, err)
require.NotEqual(t, proof, &SparseMerkleProof{})

result = VerifyProof(proof, root, closestKey, closestValueHash, NoPrehashSpec(sha256.New(), false))
require.True(t, result)
closestPath := sha256.Sum256([]byte("jackfruit"))
closestPath := sha256.Sum256([]byte("testKey2"))
require.Equal(t, closestPath[:], closestKey)
require.Equal(t, []byte("testValue2"), closestValueHash)

// path = sha256.Sum256([]byte("xwordA188wordB110")) change 31st byte
path = []byte{41, 6, 225, 86, 245, 213, 11, 141, 147, 82, 197, 13, 172, 115, 91, 244, 178, 217, 50, 38, 13, 171, 111, 56, 92, 209, 246, 148, 130, 113, 41, 171}
closestKey, closestValueHash, proof, err = smt.ProveClosest(path)
// testKey4 is the neighbour of testKey2, by flipping the final bit of the
// extension node we change the longest common prefix to that of testKey4
path2 := sha256.Sum256([]byte("testKey2"))
flipPathBit(path2[:], 3)
flipPathBit(path2[:], 7)
closestKey, closestValueHash, proof, err = smt.ProveClosest(path2[:])
require.NoError(t, err)
require.NotEqual(t, proof, &SparseMerkleProof{})

result = VerifyProof(proof, root, closestKey, closestValueHash, NoPrehashSpec(sha256.New(), false))
require.True(t, result)
closestPath = sha256.Sum256([]byte("xwordA188wordB110"))
closestPath = sha256.Sum256([]byte("testKey4"))
require.Equal(t, closestPath[:], closestKey)
require.Equal(t, []byte("testValue4"), closestValueHash)
}

func TestSMT_ProveClosestEmpty(t *testing.T) {
var smn KVStore
var smt *SMT
var proof *SparseMerkleProof
var err error
var closestPath, closestValueHash []byte

// path = sha256.Sum256([]byte("3xwordA250wordB7")) change 31st byte
path = []byte{41, 6, 1, 143, 12, 89, 247, 69, 112, 85, 218, 99, 54, 231, 83, 27, 84, 188, 130, 159, 60, 1, 56, 183, 107, 147, 173, 155, 104, 55, 61, 190}
closestKey, closestValueHash, proof, err = smt.ProveClosest(path)
smn, err = NewKVStore("")
require.NoError(t, err)
require.NotEqual(t, proof, &SparseMerkleProof{})
smt = NewSparseMerkleTree(smn, sha256.New(), WithValueHasher(nil))

result = VerifyProof(proof, root, closestKey, closestValueHash, NoPrehashSpec(sha256.New(), false))
path := sha256.Sum256([]byte("testKey2"))
flipPathBit(path[:], 3)
flipPathBit(path[:], 6)
closestPath, closestValueHash, proof, err = smt.ProveClosest(path[:])
require.NoError(t, err)
require.Equal(t, proof, &SparseMerkleProof{})

result := VerifyProof(proof, smt.Root(), closestPath, closestValueHash, NoPrehashSpec(sha256.New(), false))
require.True(t, result)
closestPath = sha256.Sum256([]byte("3xwordA250wordB7"))
require.Equal(t, closestPath[:], closestKey)
}
6 changes: 4 additions & 2 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,16 @@ func GetPathBit(data []byte, position int) int {
return 0
}

// setPathBit sets the bit at an offset from the most significant bit
// setPathBit sets the bit at an offset (see position) in the data
// provided relative to the most significant bit
func setPathBit(data []byte, position int) {
n := int(data[position/8])
n |= 1 << (8 - 1 - uint(position)%8)
data[position/8] = byte(n)
}

// flipPathBit flips the bit at an offset from the most significant bit
// flipPathBit flips the bit at an offset (see position) in the data
// provided relative to most significant bit
func flipPathBit(data []byte, position int) {
n := int(data[position/8]) // get index of byte containing the position
n ^= 1 << (8 - 1 - uint(position)%8) // XOR the bit within the byte at the position
Expand Down

0 comments on commit b4dff71

Please sign in to comment.