diff --git a/crypto/statetrie/nibbles/nibbles.go b/crypto/statetrie/nibbles/nibbles.go index 90a385fd30..3eafb6bd26 100644 --- a/crypto/statetrie/nibbles/nibbles.go +++ b/crypto/statetrie/nibbles/nibbles.go @@ -25,29 +25,23 @@ import ( type Nibbles []byte const ( - // evenIndicator for serialization when the last nibble in a byte array - // is part of the nibble array. - evenIndicator = 0x01 - // oddIndicator for when it is not. - oddIndicator = 0x03 + // oddIndicator for serialization when the last nibble in a byte array + // is not part of the nibble array. + oddIndicator = 0x01 + // evenIndicator for when it is. + evenIndicator = 0x03 ) // MakeNibbles returns a nibble array from the byte array. If oddLength is true, -// the last 4 bits of the last byte of the array are ignored. -func MakeNibbles(data []byte, oddLength bool) Nibbles { - return Unpack(data, oddLength) -} - -// Unpack the byte array into a nibble array. If oddLength is true, the last 4 -// bits of the last byte of the array are ignored. Allocates a new byte -// slice. +// the last 4 bits of the last byte of the array are ignored. // // [0x12, 0x30], true -> [0x1, 0x2, 0x3] // [0x12, 0x34], false -> [0x1, 0x2, 0x3, 0x4] // [0x12, 0x34], true -> [0x1, 0x2, 0x3] <-- last byte last 4 bits ignored // [], false -> [] // never to be called with [], true -func Unpack(data []byte, oddLength bool) Nibbles { +// Allocates a new byte slice. +func MakeNibbles(data []byte, oddLength bool) Nibbles { length := len(data) * 2 if oddLength { length = length - 1 @@ -142,11 +136,11 @@ func Serialize(nyb Nibbles) (data []byte) { output := make([]byte, length+1) copy(output, p) if h { - // 0x1 is the arbitrary odd length indicator - output[length] = evenIndicator - } else { - // 0x3 is the arbitrary even length indicator + // 0x01 is the odd length indicator output[length] = oddIndicator + } else { + // 0x03 is the even length indicator + output[length] = evenIndicator } return output @@ -159,10 +153,13 @@ func Deserialize(encoding []byte) (Nibbles, error) { if length == 0 { return nil, errors.New("invalid encoding") } - if encoding[length-1] == evenIndicator { - ns = Unpack(encoding[:length-1], true) - } else if encoding[length-1] == oddIndicator { - ns = Unpack(encoding[:length-1], false) + if encoding[length-1] == oddIndicator { + if length == 1 { + return nil, errors.New("invalid encoding") + } + ns = MakeNibbles(encoding[:length-1], true) + } else if encoding[length-1] == evenIndicator { + ns = MakeNibbles(encoding[:length-1], false) } else { return nil, errors.New("invalid encoding") } diff --git a/crypto/statetrie/nibbles/nibbles_test.go b/crypto/statetrie/nibbles/nibbles_test.go index 91cc59487f..ef8ff32c1b 100644 --- a/crypto/statetrie/nibbles/nibbles_test.go +++ b/crypto/statetrie/nibbles/nibbles_test.go @@ -60,17 +60,25 @@ func TestNibblesRandom(t *testing.T) { packed, odd := Pack(nibbles) require.Equal(t, odd, half) require.Equal(t, packed, data) - unpacked := Unpack(packed, odd) + unpacked := MakeNibbles(packed, odd) require.Equal(t, nibbles, unpacked) packed, odd = Pack(nibbles2) require.Equal(t, odd, half) require.Equal(t, packed, data) - unpacked = Unpack(packed, odd) + unpacked = MakeNibbles(packed, odd) require.Equal(t, nibbles2, unpacked) } } +func TestNibblesDeserialize(t *testing.T) { + partitiontest.PartitionTest(t) + t.Parallel() + enc := []byte{0x01} + _, err := Deserialize(enc) + require.Error(t, err, "should return invalid encoding error") +} + func TestNibbles(t *testing.T) { partitiontest.PartitionTest(t) t.Parallel() @@ -121,7 +129,7 @@ func TestNibbles(t *testing.T) { require.Equal(t, oddLength == (len(n)%2 == 1), true) require.Equal(t, bytes.Equal(b, sampleNibblesPacked[i]), true) - unp := Unpack(b, oddLength) + unp := MakeNibbles(b, oddLength) require.Equal(t, bytes.Equal(unp, n), true) }