Skip to content

Commit

Permalink
fix(dot/network): implement a handshake timeout (ChainSafe#1615)
Browse files Browse the repository at this point in the history
* chore: return the bytes read by leb128

* feat: add handshake timeout

* chore: remove debug network loglevel

* chore: get back to trace

* chore: stop ticker no matter the case

* chore: implement timer and use unbuffered

* chore: using defer

* chore: remove the timer.Stop() when the <-timer.C was called
  • Loading branch information
EclesioMeloJunior authored and timwu20 committed Dec 6, 2021
1 parent 005932b commit 4ab856e
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 26 deletions.
84 changes: 59 additions & 25 deletions dot/network/notifications.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package network
import (
"errors"
"sync"
"time"
"unsafe"

libp2pnetwork "github.com/libp2p/go-libp2p-core/network"
Expand All @@ -29,6 +30,7 @@ import (
var errCannotValidateHandshake = errors.New("failed to validate handshake")

const maxHandshakeSize = unsafe.Sizeof(BlockAnnounceHandshake{}) //nolint
const handshakeTimeout = time.Second * 10

// Handshake is the interface all handshakes for notifications protocols must implement
type Handshake interface {
Expand All @@ -53,6 +55,11 @@ type (
NotificationsMessageHandler = func(peer peer.ID, msg NotificationsMessage) (propagate bool, err error)
)

type handshakeReader struct {
hs Handshake
err error
}

type notificationsProtocol struct {
protocolID protocol.ID
getHandshake HandshakeGetter
Expand All @@ -63,16 +70,17 @@ type notificationsProtocol struct {
}

func (n *notificationsProtocol) getHandshakeData(pid peer.ID, inbound bool) (handshakeData, bool) {
if inbound {
data, has := n.inboundHandshakeData.Load(pid)
if !has {
return handshakeData{}, false
}
var (
data interface{}
has bool
)

return data.(handshakeData), true
if inbound {
data, has = n.inboundHandshakeData.Load(pid)
} else {
data, has = n.outboundHandshakeData.Load(pid)
}

data, has := n.outboundHandshakeData.Load(pid)
if !has {
return handshakeData{}, false
}
Expand Down Expand Up @@ -174,7 +182,7 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
return nil
}

logger.Debug("received message on notifications sub-protocol", "protocol", info.protocolID,
logger.Trace("received message on notifications sub-protocol", "protocol", info.protocolID,
"message", msg,
"peer", stream.Conn().RemotePeer(),
)
Expand Down Expand Up @@ -226,14 +234,29 @@ func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtoc
return
}

hs, err := s.readHandshake(stream, decodeBlockAnnounceHandshake)
if err != nil {
logger.Trace("failed to read handshake", "protocol", info.protocolID, "peer", peer, "error", err)
hsTimer := time.NewTimer(handshakeTimeout)

var hs Handshake
select {
case <-hsTimer.C:
logger.Trace("handshake timeout reached", "protocol", info.protocolID, "peer", peer)
_ = stream.Close()
info.outboundHandshakeData.Delete(peer)
return
}

hsData.received = true
case hsResponse := <-s.readHandshake(stream, decodeBlockAnnounceHandshake):
hsTimer.Stop()
if hsResponse.err != nil {
logger.Trace("failed to read handshake", "protocol", info.protocolID, "peer", peer, "error", err)
_ = stream.Close()

info.outboundHandshakeData.Delete(peer)
return
}

hs = hsResponse.hs
hsData.received = true
}

err = info.handshakeValidator(peer, hs)
if err != nil {
Expand Down Expand Up @@ -294,19 +317,30 @@ func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer
}
}

func (s *Service) readHandshake(stream libp2pnetwork.Stream, decoder HandshakeDecoder) (Handshake, error) {
msgBytes := s.bufPool.get()
defer s.bufPool.put(&msgBytes)
func (s *Service) readHandshake(stream libp2pnetwork.Stream, decoder HandshakeDecoder) <-chan *handshakeReader {
hsC := make(chan *handshakeReader)

tot, err := readStream(stream, msgBytes[:])
if err != nil {
return nil, err
}
go func() {
msgBytes := s.bufPool.get()
defer func() {
s.bufPool.put(&msgBytes)
close(hsC)
}()

hs, err := decoder(msgBytes[:tot])
if err != nil {
return nil, err
}
tot, err := readStream(stream, msgBytes[:])
if err != nil {
hsC <- &handshakeReader{hs: nil, err: err}
return
}

hs, err := decoder(msgBytes[:tot])
if err != nil {
hsC <- &handshakeReader{hs: nil, err: err}
return
}

hsC <- &handshakeReader{hs: hs, err: nil}
}()

return hs, nil
return hsC
}
82 changes: 82 additions & 0 deletions dot/network/notifications_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package network

import (
"context"
"fmt"
"math/big"
"sync"
"testing"
Expand All @@ -25,7 +27,10 @@ import (
"github.com/ChainSafe/gossamer/dot/types"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/utils"
ma "github.com/multiformats/go-multiaddr"

"github.com/libp2p/go-libp2p"
libp2pnetwork "github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -240,3 +245,80 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T)
require.True(t, data.received)
require.True(t, data.validated)
}

func Test_HandshakeTimeout(t *testing.T) {
// create service A
config := &Config{
BasePath: utils.NewTestBasePath(t, "nodeA"),
Port: 7001,
RandSeed: 1,
NoBootstrap: true,
NoMDNS: true,
}
ha := createTestService(t, config)

// create info and handler
info := &notificationsProtocol{
protocolID: ha.host.protocolID + blockAnnounceID,
getHandshake: ha.getBlockAnnounceHandshake,
handshakeValidator: ha.validateBlockAnnounceHandshake,
inboundHandshakeData: new(sync.Map),
outboundHandshakeData: new(sync.Map),
}

// creating host b with will never respond to a handshake
addrB, err := ma.NewMultiaddr(fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", 7002))
require.NoError(t, err)

hb, err := libp2p.New(
context.Background(), libp2p.ListenAddrs(addrB),
)
require.NoError(t, err)

testHandshakeMsg := &BlockAnnounceHandshake{
Roles: 4,
BestBlockNumber: 77,
BestBlockHash: common.Hash{1},
GenesisHash: common.Hash{2},
}

hb.SetStreamHandler(info.protocolID, func(stream libp2pnetwork.Stream) {
fmt.Println("never respond a handshake message")
})

addrBInfo := peer.AddrInfo{
ID: hb.ID(),
Addrs: hb.Addrs(),
}

err = ha.host.connect(addrBInfo)
if failedToDial(err) {
time.Sleep(TestBackoffTimeout)
err = ha.host.connect(addrBInfo)
}
require.NoError(t, err)

go ha.sendData(hb.ID(), testHandshakeMsg, info, nil)

time.Sleep(handshakeTimeout / 2)
// peer should be stored in handshake data until timeout
_, ok := info.outboundHandshakeData.Load(hb.ID())
require.True(t, ok)

// a stream should be open until timeout
connAToB := ha.host.h.Network().ConnsToPeer(hb.ID())
require.Len(t, connAToB, 1)
require.Len(t, connAToB[0].GetStreams(), 1)

// after the timeout
time.Sleep(handshakeTimeout)

// handshake data should be removed
_, ok = info.outboundHandshakeData.Load(hb.ID())
require.False(t, ok)

// stream should be closed
connAToB = ha.host.h.Network().ConnsToPeer(hb.ID())
require.Len(t, connAToB, 1)
require.Len(t, connAToB[0].GetStreams(), 0)
}
2 changes: 1 addition & 1 deletion dot/network/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) {
}

if tot != int(length) {
return tot, fmt.Errorf("failed to read entire message: expected %d bytes", length)
return tot, fmt.Errorf("failed to read entire message: expected %d bytes, received %d bytes", length, tot)
}

return tot, nil
Expand Down

0 comments on commit 4ab856e

Please sign in to comment.