Skip to content

Commit

Permalink
refactoring + repo structure improvements (#35)
Browse files Browse the repository at this point in the history
* improve structure a bit

* refactor
  • Loading branch information
jleni authored Sep 1, 2024
1 parent 9251957 commit 6981a36
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 100 deletions.
19 changes: 19 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates

version: 2
updates:
# Enable version updates for npm
- package-ecosystem: 'gomod'
# Look for `package.json` and `lock` files in the `root` directory
directory: '/'
# Check the npm registry for updates every day (weekdays)
schedule:
interval: 'daily'
commit-message:
prefix: 'chore'
prefix-development: 'chore'
include: 'scope'
target-branch: dev
35 changes: 35 additions & 0 deletions .github/workflows/checks.golang.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#
# Generated by @zondax/cli
#
name: Checks

on:
push:
branches: [ main, dev ]
pull_request:
branches: [ main, dev ]

jobs:
checks:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v3
with:
submodules: true
- uses: actions/setup-go@v3
with:
go-version: '1.21'
- name: Build
run: |
make build
- name: ModTidy check
run: make check-modtidy
- name: Lint check
run: |
export PATH=$PATH:$(go env GOPATH)/bin
make install_lint
make lint
- name: Run tests
run: |
make test
127 changes: 60 additions & 67 deletions apduWrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,21 @@ import (
"github.com/pkg/errors"
)

const (
MinPacketSize = 3
TagValue = 0x05
)

var codec = binary.BigEndian

var (
ErrPacketSize = errors.New("packet size must be at least 3")
ErrInvalidChannel = errors.New("invalid channel")
ErrInvalidTag = errors.New("invalid tag")
ErrWrongSequenceIdx = errors.New("wrong sequenceIdx")
)

// ErrorMessage returns a human-readable error message for a given APDU error code.
func ErrorMessage(errorCode uint16) string {
switch errorCode {
// FIXME: Code and description don't match for 0x6982 and 0x6983 based on
Expand Down Expand Up @@ -63,164 +76,144 @@ func ErrorMessage(errorCode uint16) string {
}
}

// SerializePacket serializes a command into a packet for transmission.
func SerializePacket(
channel uint16,
command []byte,
packetSize int,
sequenceIdx uint16) (result []byte, offset int, err error) {
sequenceIdx uint16) ([]byte, int, error) {

if packetSize < 3 {
return nil, 0, errors.New("Packet size must be at least 3")
return nil, 0, ErrPacketSize
}

var headerOffset uint8
headerOffset := 5
if sequenceIdx == 0 {
headerOffset += 2
}

result = make([]byte, packetSize)
var buffer = result
result := make([]byte, packetSize)
buffer := result

// Insert channel (2 bytes)
codec.PutUint16(buffer, channel)
headerOffset += 2

// Insert tag (1 byte)
buffer[headerOffset] = 0x05
headerOffset += 1

var commandLength uint16
commandLength = uint16(len(command))
buffer[2] = 0x05

// Insert sequenceIdx (2 bytes)
codec.PutUint16(buffer[headerOffset:], sequenceIdx)
headerOffset += 2
codec.PutUint16(buffer[3:], sequenceIdx)

// Only insert total size of the command in the first package
if sequenceIdx == 0 {
// Insert sequenceIdx (2 bytes)
codec.PutUint16(buffer[headerOffset:], commandLength)
headerOffset += 2
commandLength := uint16(len(command))
codec.PutUint16(buffer[5:], commandLength)
}

buffer = buffer[headerOffset:]
offset = copy(buffer, command)
offset := copy(buffer[headerOffset:], command)
return result, offset, nil
}

// DeserializePacket deserializes a packet into its original command.
func DeserializePacket(
channel uint16,
buffer []byte,
sequenceIdx uint16) (result []byte, totalResponseLength uint16, isSequenceZero bool, err error) {
packet []byte,
sequenceIdx uint16) ([]byte, uint16, bool, error) {

isSequenceZero = false
const (
minFirstPacketSize = 7
minPacketSize = 5
tag = 0x05
)

if (sequenceIdx == 0 && len(buffer) < 7) || (sequenceIdx > 0 && len(buffer) < 5) {
return nil, 0, isSequenceZero, errors.New("Cannot deserialize the packet. Header information is missing.")
if (sequenceIdx == 0 && len(packet) < minFirstPacketSize) || (sequenceIdx > 0 && len(packet) < minPacketSize) {
return nil, 0, false, errors.New("cannot deserialize the packet. header information is missing")
}

var headerOffset uint8
headerOffset := 2

if codec.Uint16(buffer) != channel {
return nil, 0, isSequenceZero, errors.New(fmt.Sprintf("Invalid channel. Expected %d, Got: %d", channel, codec.Uint16(buffer)))
if codec.Uint16(packet) != channel {
return nil, 0, false, fmt.Errorf("%w: expected %d, got %d", ErrInvalidChannel, channel, codec.Uint16(packet))
}
headerOffset += 2

if buffer[headerOffset] != 0x05 {
return nil, 0, isSequenceZero, errors.New(fmt.Sprintf("Invalid tag. Expected %d, Got: %d", 0x05, buffer[headerOffset]))
if packet[headerOffset] != tag {
return nil, 0, false, fmt.Errorf("invalid tag. expected %d, got %d", tag, packet[headerOffset])
}
headerOffset++

foundSequenceIdx := codec.Uint16(buffer[headerOffset:])
if foundSequenceIdx == 0 {
isSequenceZero = true
} else {
isSequenceZero = false
}
foundSequenceIdx := codec.Uint16(packet[headerOffset:])
isSequenceZero := foundSequenceIdx == 0

if foundSequenceIdx != sequenceIdx {
return nil, 0, isSequenceZero, errors.New(fmt.Sprintf("Wrong sequenceIdx. Expected %d, Got: %d", sequenceIdx, foundSequenceIdx))
return nil, 0, isSequenceZero, fmt.Errorf("wrong sequenceIdx: expected %d, got %d", sequenceIdx, foundSequenceIdx)
}
headerOffset += 2

var totalResponseLength uint16
if sequenceIdx == 0 {
totalResponseLength = codec.Uint16(buffer[headerOffset:])
totalResponseLength = codec.Uint16(packet[headerOffset:])
headerOffset += 2
}

result = make([]byte, len(buffer)-int(headerOffset))
copy(result, buffer[headerOffset:])

result := packet[headerOffset:]
return result, totalResponseLength, isSequenceZero, nil
}

// WrapCommandAPDU turns the command into a sequence of 64 byte packets
// WrapCommandAPDU turns the command into a sequence of packets of specified size.
func WrapCommandAPDU(
channel uint16,
command []byte,
packetSize int) (result []byte, err error) {
packetSize int) ([]byte, error) {

var offset int
var totalResult []byte
var sequenceIdx uint16

for len(command) > 0 {
result, offset, err = SerializePacket(channel, command, packetSize, sequenceIdx)
packet, offset, err := SerializePacket(channel, command, packetSize, sequenceIdx)
if err != nil {
return nil, err
}
command = command[offset:]
totalResult = append(totalResult, result...)
totalResult = append(totalResult, packet...)
sequenceIdx++
}

return totalResult, nil
}

// UnwrapResponseAPDU parses a response of 64 byte packets into the real data
// UnwrapResponseAPDU parses a response of 64 byte packets into the real data.
func UnwrapResponseAPDU(channel uint16, pipe <-chan []byte, packetSize int) ([]byte, error) {
var sequenceIdx uint16

var totalResult []byte
var totalSize uint16
var done = false

// return values from DeserializePacket
var result []byte
var responseSize uint16
var err error

foundZeroSequence := false
isSequenceZero := false
var foundZeroSequence bool

for !done {
// Read next packet from the channel
buffer := <-pipe

result, responseSize, isSequenceZero, err = DeserializePacket(channel, buffer, sequenceIdx) // this may fail if the wrong sequence arrives (espeically if left over all 0000 was in the buffer from the last tx)
for buffer := range pipe {
result, responseSize, isSequenceZero, err := DeserializePacket(channel, buffer, sequenceIdx)
if err != nil {
return nil, err
}

// Recover from a known error condition:
// * Discard messages left over from previous exchange until isSequenceZero == true
if foundZeroSequence == false && isSequenceZero == false {
if !foundZeroSequence && !isSequenceZero {
continue
}
foundZeroSequence = true

// Initialize totalSize (previously we did this if sequenceIdx == 0, but sometimes Nano X can provide the first sequenceIdx == 0 packet with all zeros, then a useful packet with sequenceIdx == 1
// Initialize totalSize
if totalSize == 0 {
totalSize = responseSize
}

buffer = buffer[packetSize:]
totalResult = append(totalResult, result...)
sequenceIdx++

if len(totalResult) >= int(totalSize) {
done = true
break
}
}

// Remove trailing zeros
totalResult = totalResult[:totalSize]
return totalResult, nil
return totalResult[:totalSize], nil
}
37 changes: 32 additions & 5 deletions apduWrapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package ledger_go

import (
"bytes"
"fmt"
"math"
"testing"
"unsafe"
Expand Down Expand Up @@ -217,7 +218,7 @@ func Test_DeserializePacket_FirstPacket(t *testing.T) {
assert.Equal(t, len(sampleCommand), int(totalSize), "TotalSize is incorrect")
assert.Equal(t, packetSize-firstPacketHeaderSize, len(output), "Size of the deserialized packet is wrong")
assert.Equal(t, true, isSequenceZero, "Test Case Should Find Sequence == 0")
assert.True(t, bytes.Compare(output[:len(sampleCommand)], sampleCommand) == 0, "Deserialized message does not match the original")
assert.True(t, bytes.Equal(output[:len(sampleCommand)], sampleCommand), "Deserialized message does not match the original")
}

func Test_DeserializePacket_SecondMessage(t *testing.T) {
Expand Down Expand Up @@ -250,7 +251,7 @@ func Test_UnwrapApdu_SmokeTest(t *testing.T) {

serialized, _ := WrapCommandAPDU(channel, input, packetSize)

// Allocate enough buffers to keep all the packets
// Allocate enough packets
pipe := make(chan []byte, int(math.Ceil(float64(inputSize)/float64(packetSize))))
// Send all the packets to the pipe
for len(serialized) > 0 {
Expand All @@ -260,12 +261,38 @@ func Test_UnwrapApdu_SmokeTest(t *testing.T) {

output, _ := UnwrapResponseAPDU(channel, pipe, packetSize)

//fmt.Printf("INPUT : %x\n", input)
//fmt.Printf("SERIALIZED: %x\n", serialized)
//fmt.Printf("OUTPUT : %x\n", output)
fmt.Printf("INPUT : %x\n", input)
fmt.Printf("SERIALIZED: %x\n", serialized)
fmt.Printf("OUTPUT : %x\n", output)

assert.Equal(t, len(input), len(output), "Input and output messages have different size")
assert.True(t,
bytes.Equal(input, output),
"Input message does not match message which was serialized and then deserialized")
}

func TestSerializePacketWithInvalidSize(t *testing.T) {
_, _, err := SerializePacket(0x0101, []byte{1, 2}, 2, 0)
assert.ErrorIs(t, err, ErrPacketSize)
}

func TestDeserializePacketWithInvalidChannel(t *testing.T) {
packet := []byte{0x02, 0x02, 0x05, 0x00, 0x00, 0x00, 0x20}
_, _, _, err := DeserializePacket(0x0101, packet, 0)
assert.ErrorIs(t, err, ErrInvalidChannel)
}

func TestSerializeDeserialize(t *testing.T) {
sampleCommand := []byte{0x01, 0x02, 0x03, 0x04, 0x05}
channel := uint16(0x0101)
packetSize := 64
sequenceIdx := uint16(0)

packet, _, err := SerializePacket(channel, sampleCommand, packetSize, sequenceIdx)
assert.NoError(t, err)

output, _, _, err := DeserializePacket(channel, packet, sequenceIdx)
assert.NoError(t, err)

assert.True(t, bytes.Equal(output[:len(sampleCommand)], sampleCommand), "Deserialized message does not match the original")
}
5 changes: 4 additions & 1 deletion ledger_hid.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ package ledger_go
import (
"errors"
"fmt"
"log"
"sync"
"time"

Expand Down Expand Up @@ -131,7 +132,7 @@ func (admin *LedgerAdminHID) Connect(requiredIndex int) (LedgerDevice, error) {
}
}

return nil, fmt.Errorf("LedgerHID device (idx %d) not found. Ledger LOCKED OR Other Program/Web Browser may have control of device.", requiredIndex)
return nil, fmt.Errorf("LedgerHID device (idx %d) not found: device may be locked or in use by another application", requiredIndex)
}

func (ledger *LedgerDeviceHID) write(buffer []byte) (int, error) {
Expand Down Expand Up @@ -209,6 +210,7 @@ func (ledger *LedgerDeviceHID) drainRead() {
}

func (ledger *LedgerDeviceHID) Exchange(command []byte) ([]byte, error) {
log.Printf("Sending command: %X", command)
// Purge messages that arrived after previous exchange completed
ledger.drainRead()

Expand Down Expand Up @@ -249,6 +251,7 @@ func (ledger *LedgerDeviceHID) Exchange(command []byte) ([]byte, error) {
return response[:swOffset], errors.New(ErrorMessage(sw))
}

log.Printf("Received response: %X", response)
return response[:swOffset], nil
}

Expand Down
Loading

0 comments on commit 6981a36

Please sign in to comment.