Skip to content

Commit

Permalink
[Minor] Expose GetPathBit() (#15)
Browse files Browse the repository at this point in the history
## Overview

This PR exposes `GetPathBit()` for use outside of the SMT repo
  • Loading branch information
h5law authored Jun 29, 2023
1 parent dd8ae60 commit f5ef955
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 16 deletions.
4 changes: 2 additions & 2 deletions proofs.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, v
node := make([]byte, hashSize(spec))
copy(node, proof.SideNodes[i])

if getPathBit(path, len(proof.SideNodes)-1-i) == left {
if GetPathBit(path, len(proof.SideNodes)-1-i) == left {
currentHash, currentData = digestNode(spec, currentHash, node)
} else {
currentHash, currentData = digestNode(spec, node, currentHash)
Expand Down Expand Up @@ -235,7 +235,7 @@ func DecompactProof(proof *SparseCompactMerkleProof, spec *TreeSpec) (*SparseMer
decompactedSideNodes := make([][]byte, proof.NumSideNodes)
position := 0
for i := 0; i < proof.NumSideNodes; i++ {
if getPathBit(proof.BitMask, i) == 1 {
if GetPathBit(proof.BitMask, i) == 1 {
decompactedSideNodes[i] = placeholder(spec)
} else {
decompactedSideNodes[i] = proof.SideNodes[position]
Expand Down
20 changes: 10 additions & 10 deletions smt.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (smt *SMT) Get(key []byte) ([]byte, error) {
}
}
inner := (*node).(*innerNode)
if getPathBit(path, depth) == left {
if GetPathBit(path, depth) == left {
node = &inner.leftChild
} else {
node = &inner.rightChild
Expand Down Expand Up @@ -172,7 +172,7 @@ func (smt *SMT) update(
*last = &ext
last = &ext.child
}
if getPathBit(path, prefixlen) == left {
if GetPathBit(path, prefixlen) == left {
*last = &innerNode{leftChild: newLeaf, rightChild: leaf}
} else {
*last = &innerNode{leftChild: leaf, rightChild: newLeaf}
Expand All @@ -195,7 +195,7 @@ func (smt *SMT) update(

inner := node.(*innerNode)
var child *treeNode
if getPathBit(path, depth) == left {
if GetPathBit(path, depth) == left {
child = &inner.leftChild
} else {
child = &inner.rightChild
Expand Down Expand Up @@ -266,7 +266,7 @@ func (smt *SMT) delete(node treeNode, depth int, path []byte, orphans *orphanNod

inner := node.(*innerNode)
var child, sib *treeNode
if getPathBit(path, depth) == left {
if GetPathBit(path, depth) == left {
child, sib = &inner.leftChild, &inner.rightChild
} else {
child, sib = &inner.rightChild, &inner.leftChild
Expand Down Expand Up @@ -335,7 +335,7 @@ func (smt *SMT) Prove(key []byte) (proof *SparseMerkleProof, err error) {
}
}
inner := node.(*innerNode)
if getPathBit(path, depth) == left {
if GetPathBit(path, depth) == left {
node, sib = inner.leftChild, inner.rightChild
} else {
node, sib = inner.rightChild, inner.leftChild
Expand Down Expand Up @@ -558,7 +558,7 @@ func (ext *extensionNode) match(path []byte, depth int) (int, bool) {
panic("depth != path_begin")
}
for i := ext.pathStart(); i < ext.pathEnd(); i++ {
if getPathBit(ext.path, i) != getPathBit(path, i) {
if GetPathBit(ext.path, i) != GetPathBit(path, i) {
return i - ext.pathStart(), false
}
}
Expand All @@ -569,7 +569,7 @@ func (ext *extensionNode) match(path []byte, depth int) (int, bool) {
func (ext *extensionNode) commonPrefix(path []byte) int {
count := 0
for i := ext.pathStart(); i < ext.pathEnd(); i++ {
if getPathBit(ext.path, i) != getPathBit(path, i) {
if GetPathBit(ext.path, i) != GetPathBit(path, i) {
break
}
count++
Expand All @@ -588,8 +588,8 @@ func (ext *extensionNode) split(path []byte, depth int) (treeNode, *treeNode, in
index := ext.pathStart()
var myBit, branchBit int
for ; index < ext.pathEnd(); index++ {
myBit = getPathBit(ext.path, index)
branchBit = getPathBit(path, index)
myBit = GetPathBit(ext.path, index)
branchBit = GetPathBit(path, index)
if myBit != branchBit {
break
}
Expand Down Expand Up @@ -640,7 +640,7 @@ func (ext *extensionNode) expand() treeNode {
last := ext.child
for i := ext.pathEnd() - 1; i >= ext.pathStart(); i-- {
var next innerNode
if getPathBit(ext.path, i) == left {
if GetPathBit(ext.path, i) == left {
next.leftChild = last
} else {
next.rightChild = last
Expand Down
19 changes: 15 additions & 4 deletions utils.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
package smt

// getPathBit gets the bit at an offset from the most significant bit
func getPathBit(data []byte, position int) int {
// GetPathBit gets the bit at an offset from the most significant bit
func GetPathBit(data []byte, position int) int {
// get the byte at the position and then left shift one by the offset of the position
// from the leftmost bit in the byte. Check if the bitwise AND is the same
// Path: []byte{ {0 1 0 1 1 0 1 0}, {0 1 1 0 1 1 0 1}, {1 0 0 1 0 0 1 0} } (length = 24 bits / 3 bytes)
// Position: 13 - 13/8=1
// Path[1] = {0 1 1 0 1 1 0 1}
// uint(13)%8 = 5, 8-1-5=2
// 00000001 << 2 = 00000100
// {0 1 1 0 1 1 0 1}
// & {0 0 0 0 0 1 0 0}
// = {0 0 0 0 0 1 0 0}
// > 0 so Path is on the right at position 13
if int(data[position/8])&(1<<(8-1-uint(position)%8)) > 0 {
return 1
}
Expand All @@ -18,7 +29,7 @@ func setPathBit(data []byte, position int) {
func countSetBits(data []byte) int {
count := 0
for i := 0; i < len(data)*8; i++ {
if getPathBit(data, i) == 1 {
if GetPathBit(data, i) == 1 {
count++
}
}
Expand All @@ -29,7 +40,7 @@ func countSetBits(data []byte) int {
func countCommonPrefix(data1, data2 []byte, from int) int {
count := 0
for i := from; i < len(data1)*8; i++ {
if getPathBit(data1, i) == getPathBit(data2, i) {
if GetPathBit(data1, i) == GetPathBit(data2, i) {
count++
} else {
break
Expand Down

0 comments on commit f5ef955

Please sign in to comment.