diff --git a/proofs.go b/proofs.go index 78bbb1f..2874657 100644 --- a/proofs.go +++ b/proofs.go @@ -3,10 +3,16 @@ package smt import ( "bytes" "encoding/binary" + "encoding/gob" "errors" "math" ) +func init() { + gob.Register(SparseMerkleProof{}) + gob.Register(SparseCompactMerkleProof{}) +} + // ErrBadProof is returned when an invalid Merkle proof is supplied. var ErrBadProof = errors.New("bad proof") @@ -25,6 +31,23 @@ type SparseMerkleProof struct { SiblingData []byte } +// Marshal serialises the SparseMerkleProof to bytes +func (proof *SparseMerkleProof) Marshal() ([]byte, error) { + buf := bytes.NewBuffer(nil) + enc := gob.NewEncoder(buf) + if err := enc.Encode(proof); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// Unmarshal deserialises the SparseMerkleProof from bytes +func (proof *SparseMerkleProof) Unmarshal(bz []byte) error { + buf := bytes.NewBuffer(bz) + dec := gob.NewDecoder(buf) + return dec.Decode(proof) +} + func (proof *SparseMerkleProof) sanityCheck(spec *TreeSpec) bool { // Do a basic sanity check on the proof, so that a malicious proof cannot // cause the verifier to fatally exit (e.g. due to an index out-of-range @@ -78,6 +101,23 @@ type SparseCompactMerkleProof struct { SiblingData []byte } +// Marshal serialises the SparseCompactMerkleProof to bytes +func (proof *SparseCompactMerkleProof) Marshal() ([]byte, error) { + buf := bytes.NewBuffer(nil) + enc := gob.NewEncoder(buf) + if err := enc.Encode(proof); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// Unmarshal deserialises the SparseCompactMerkleProof from bytes +func (proof *SparseCompactMerkleProof) Unmarshal(bz []byte) error { + buf := bytes.NewBuffer(bz) + dec := gob.NewDecoder(buf) + return dec.Decode(proof) +} + func (proof *SparseCompactMerkleProof) sanityCheck(spec *TreeSpec) bool { // Do a basic sanity check on the proof on the fields of the proof specific to // the compact proof only. diff --git a/proofs_test.go b/proofs_test.go index 5451ab3..4ab9fd7 100644 --- a/proofs_test.go +++ b/proofs_test.go @@ -3,9 +3,168 @@ package smt import ( "bytes" "crypto/rand" + "crypto/sha256" "testing" + + "github.com/stretchr/testify/require" ) +func TestSparseMerkleProof_Marshal(t *testing.T) { + tree := setupTree(t) + + proof, err := tree.Prove([]byte("key")) + require.NoError(t, err) + bz, err := proof.Marshal() + require.NoError(t, err) + require.NotNil(t, bz) + require.Greater(t, len(bz), 0) + + proof2, err := tree.Prove([]byte("key2")) + require.NoError(t, err) + bz2, err := proof2.Marshal() + require.NoError(t, err) + require.NotNil(t, bz2) + require.Greater(t, len(bz2), 0) + require.NotEqual(t, bz, bz2) + + proof3 := randomiseProof(proof) + bz3, err := proof3.Marshal() + require.NoError(t, err) + require.NotNil(t, bz3) + require.Greater(t, len(bz3), 0) + require.NotEqual(t, bz, bz3) +} + +func TestSparseMerkleProof_Unmarshal(t *testing.T) { + tree := setupTree(t) + + proof, err := tree.Prove([]byte("key")) + require.NoError(t, err) + bz, err := proof.Marshal() + require.NoError(t, err) + require.NotNil(t, bz) + require.Greater(t, len(bz), 0) + uproof := new(SparseMerkleProof) + require.NoError(t, uproof.Unmarshal(bz)) + require.Equal(t, proof, uproof) + + proof2, err := tree.Prove([]byte("key2")) + require.NoError(t, err) + bz2, err := proof2.Marshal() + require.NoError(t, err) + require.NotNil(t, bz2) + require.Greater(t, len(bz2), 0) + uproof2 := new(SparseMerkleProof) + require.NoError(t, uproof2.Unmarshal(bz2)) + require.Equal(t, proof2, uproof2) + + proof3 := randomiseProof(proof) + bz3, err := proof3.Marshal() + require.NoError(t, err) + require.NotNil(t, bz3) + require.Greater(t, len(bz3), 0) + uproof3 := new(SparseMerkleProof) + require.NoError(t, uproof3.Unmarshal(bz3)) + require.Equal(t, proof3, uproof3) +} + +func TestSparseCompactMerkletProof_Marshal(t *testing.T) { + tree := setupTree(t) + + proof, err := tree.Prove([]byte("key")) + require.NoError(t, err) + compactProof, err := CompactProof(proof, tree.Spec()) + require.NoError(t, err) + bz, err := compactProof.Marshal() + require.NoError(t, err) + require.NotNil(t, bz) + require.Greater(t, len(bz), 0) + + proof2, err := tree.Prove([]byte("key2")) + require.NoError(t, err) + compactProof2, err := CompactProof(proof2, tree.Spec()) + require.NoError(t, err) + bz2, err := compactProof2.Marshal() + require.NoError(t, err) + require.NotNil(t, bz2) + require.Greater(t, len(bz2), 0) + require.NotEqual(t, bz, bz2) + + proof3 := randomiseProof(proof) + compactProof3, err := CompactProof(proof3, tree.Spec()) + require.NoError(t, err) + bz3, err := compactProof3.Marshal() + require.NoError(t, err) + require.NotNil(t, bz3) + require.Greater(t, len(bz3), 0) + require.NotEqual(t, bz, bz3) +} + +func TestSparseCompactMerkleProof_Unmarshal(t *testing.T) { + tree := setupTree(t) + + proof, err := tree.Prove([]byte("key")) + require.NoError(t, err) + compactProof, err := CompactProof(proof, tree.Spec()) + require.NoError(t, err) + bz, err := compactProof.Marshal() + require.NoError(t, err) + require.NotNil(t, bz) + require.Greater(t, len(bz), 0) + uCproof := new(SparseCompactMerkleProof) + require.NoError(t, uCproof.Unmarshal(bz)) + require.Equal(t, compactProof, uCproof) + uproof, err := DecompactProof(uCproof, tree.Spec()) + require.NoError(t, err) + require.Equal(t, proof, uproof) + + proof2, err := tree.Prove([]byte("key2")) + require.NoError(t, err) + compactProof2, err := CompactProof(proof2, tree.Spec()) + require.NoError(t, err) + bz2, err := compactProof2.Marshal() + require.NoError(t, err) + require.NotNil(t, bz2) + require.Greater(t, len(bz2), 0) + uCproof2 := new(SparseCompactMerkleProof) + require.NoError(t, uCproof2.Unmarshal(bz2)) + require.Equal(t, compactProof2, uCproof2) + uproof2, err := DecompactProof(uCproof2, tree.Spec()) + require.NoError(t, err) + require.Equal(t, proof2, uproof2) + + proof3 := randomiseProof(proof) + compactProof3, err := CompactProof(proof3, tree.Spec()) + require.NoError(t, err) + bz3, err := compactProof3.Marshal() + require.NoError(t, err) + require.NotNil(t, bz3) + require.Greater(t, len(bz3), 0) + uCproof3 := new(SparseCompactMerkleProof) + require.NoError(t, uCproof3.Unmarshal(bz3)) + require.Equal(t, compactProof3, uCproof3) + uproof3, err := DecompactProof(uCproof3, tree.Spec()) + require.NoError(t, err) + require.Equal(t, proof3, uproof3) +} + +func setupTree(t *testing.T) *SMT { + t.Helper() + + db, err := NewKVStore("") + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, db.Stop()) + }) + + tree := NewSparseMerkleTree(db, sha256.New()) + require.NoError(t, tree.Update([]byte("key"), []byte("value"))) + require.NoError(t, tree.Update([]byte("key2"), []byte("value2"))) + require.NoError(t, tree.Update([]byte("key3"), []byte("value3"))) + + return tree +} + func randomiseProof(proof *SparseMerkleProof) *SparseMerkleProof { sideNodes := make([][]byte, len(proof.SideNodes)) for i := range sideNodes {