Skip to content

Commit

Permalink
net/mock: support ConnectionGater in MockNet
Browse files Browse the repository at this point in the history
  • Loading branch information
ajsutton committed May 17, 2023
1 parent 8719fc4 commit f235c5c
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 13 deletions.
2 changes: 2 additions & 0 deletions p2p/net/mock/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"io"
"time"

"github.com/libp2p/go-libp2p/core/connmgr"
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
Expand All @@ -22,6 +23,7 @@ import (
type Mocknet interface {
// GenPeer generates a peer and its network.Network in the Mocknet
GenPeer() (host.Host, error)
GenPeerWithConnGater(connmgr.ConnectionGater) (host.Host, error)

// AddPeer adds an existing peer. we need both a privkey and addr.
// ID is derived from PrivKey
Expand Down
36 changes: 30 additions & 6 deletions p2p/net/mock/mock_net.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"sort"
"sync"

"github.com/libp2p/go-libp2p/core/connmgr"
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
Expand Down Expand Up @@ -64,6 +65,10 @@ func (mn *mocknet) Close() error {
}

func (mn *mocknet) GenPeer() (host.Host, error) {
return mn.GenPeerWithConnGater(nil)
}

func (mn *mocknet) GenPeerWithConnGater(gater connmgr.ConnectionGater) (host.Host, error) {
sk, _, err := ic.GenerateECDSAKeyPair(rand.Reader)
if err != nil {
return nil, err
Expand All @@ -83,7 +88,14 @@ func (mn *mocknet) GenPeer() (host.Host, error) {
return nil, fmt.Errorf("failed to create test multiaddr: %s", err)
}

h, err := mn.AddPeer(sk, a)
p, ps, err := mn.createPeerstore(sk, a)
if err != nil {
return nil, err
}
h, err := mn.AddPeerWithOptions(p, &peerOptions{
ps: ps,
gater: gater,
})
if err != nil {
return nil, err
}
Expand All @@ -92,25 +104,37 @@ func (mn *mocknet) GenPeer() (host.Host, error) {
}

func (mn *mocknet) AddPeer(k ic.PrivKey, a ma.Multiaddr) (host.Host, error) {
p, err := peer.IDFromPublicKey(k.GetPublic())
p, ps, err := mn.createPeerstore(k, a)
if err != nil {
return nil, err
}

return mn.AddPeerWithPeerstore(p, ps)
}

func (mn *mocknet) createPeerstore(k ic.PrivKey, a ma.Multiaddr) (peer.ID, peerstore.Peerstore, error) {
p, err := peer.IDFromPublicKey(k.GetPublic())
if err != nil {
return "", nil, err
}

ps, err := pstoremem.NewPeerstore()
if err != nil {
return nil, err
return "", nil, err
}
ps.AddAddr(p, a, peerstore.PermanentAddrTTL)
ps.AddPrivKey(p, k)
ps.AddPubKey(p, k.GetPublic())

return mn.AddPeerWithPeerstore(p, ps)
return p, ps, nil
}

func (mn *mocknet) AddPeerWithPeerstore(p peer.ID, ps peerstore.Peerstore) (host.Host, error) {
return mn.AddPeerWithOptions(p, &peerOptions{ps: ps})
}

func (mn *mocknet) AddPeerWithOptions(p peer.ID, netOpts *peerOptions) (host.Host, error) {
bus := eventbus.NewBus()
n, err := newPeernet(mn, p, ps, bus)
n, err := newPeernet(mn, p, netOpts, bus)
if err != nil {
return nil, err
}
Expand Down
61 changes: 54 additions & 7 deletions p2p/net/mock/mock_peernet.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@ import (
"math/rand"
"sync"

"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
ma "github.com/multiformats/go-multiaddr"
)

type peerOptions struct {
ps peerstore.Peerstore
gater connmgr.ConnectionGater
}

// peernet implements network.Network
type peernet struct {
mocknet *mocknet // parent
Expand All @@ -28,6 +34,9 @@ type peernet struct {
connsByPeer map[peer.ID]map[*conn]struct{}
connsByLink map[*link]map[*conn]struct{}

// connection gater to check before dialing or accepting connections. May be nil to allow all.
gater connmgr.ConnectionGater

// implement network.Network
streamHandler network.StreamHandler

Expand All @@ -38,7 +47,7 @@ type peernet struct {
}

// newPeernet constructs a new peernet
func newPeernet(m *mocknet, p peer.ID, ps peerstore.Peerstore, bus event.Bus) (*peernet, error) {
func newPeernet(m *mocknet, p peer.ID, opts *peerOptions, bus event.Bus) (*peernet, error) {
emitter, err := bus.Emitter(&event.EvtPeerConnectednessChanged{})
if err != nil {
return nil, err
Expand All @@ -47,7 +56,8 @@ func newPeernet(m *mocknet, p peer.ID, ps peerstore.Peerstore, bus event.Bus) (*
n := &peernet{
mocknet: m,
peer: p,
ps: ps,
ps: opts.ps,
gater: opts.gater,
emitter: emitter,

connsByPeer: map[peer.ID]map[*conn]struct{}{},
Expand Down Expand Up @@ -124,6 +134,10 @@ func (pn *peernet) connect(p peer.ID) (*conn, error) {
}
pn.RUnlock()

if pn.gater != nil && !pn.gater.InterceptPeerDial(p) {
log.Debugf("gater disallowed outbound connection to peer %s", p)
return nil, fmt.Errorf("%v connection gater disallowed connection to %v", pn.peer, p)
}
log.Debugf("%s (newly) dialing %s", pn.peer, p)

// ok, must create a new connection. we need a link
Expand All @@ -139,18 +153,51 @@ func (pn *peernet) connect(p peer.ID) (*conn, error) {

log.Debugf("%s dialing %s openingConn", pn.peer, p)
// create a new connection with link
c := pn.openConn(p, l.(*link))
return c, nil
return pn.openConn(p, l.(*link))
}

func (pn *peernet) openConn(r peer.ID, l *link) *conn {
func (pn *peernet) openConn(r peer.ID, l *link) (*conn, error) {
lc, rc := l.newConnPair(pn)
log.Debugf("%s opening connection to %s", pn.LocalPeer(), lc.RemotePeer())
addConnPair(pn, rc.net, lc, rc)
log.Debugf("%s opening connection to %s", pn.LocalPeer(), lc.RemotePeer())
abort := func() {
_ = lc.Close()
_ = rc.Close()
}
if pn.gater != nil && !pn.gater.InterceptAddrDial(lc.remote, lc.remoteAddr) {
abort()
return nil, fmt.Errorf("%v rejected dial to %v on addr %v", lc.local, lc.remote, lc.remoteAddr)
}
if rc.net.gater != nil && !rc.net.gater.InterceptAccept(rc) {
abort()
return nil, fmt.Errorf("%v rejected connection from %v", rc.local, rc.remote)
}
if err := checkSecureAndUpgrade(network.DirOutbound, pn.gater, lc); err != nil {
abort()
return nil, err
}
if err := checkSecureAndUpgrade(network.DirInbound, rc.net.gater, rc); err != nil {
abort()
return nil, err
}

go rc.net.remoteOpenedConn(rc)
pn.addConn(lc)
return lc
return lc, nil
}

func checkSecureAndUpgrade(dir network.Direction, gater connmgr.ConnectionGater, c *conn) error {
if gater == nil {
return nil
}
if !gater.InterceptSecured(dir, c.remote, c) {
return fmt.Errorf("%v rejected secure handshake with %v", c.local, c.remote)
}
allow, _ := gater.InterceptUpgraded(c)
if !allow {
return fmt.Errorf("%v rejected upgrade with %v", c.local, c.remote)
}
return nil
}

// addConnPair adds connection to both peernets at the same time
Expand Down
68 changes: 68 additions & 0 deletions p2p/net/mock/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@ import (

"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/p2p/net/conngater"
manet "github.com/multiformats/go-multiaddr/net"

"github.com/libp2p/go-libp2p-testing/ci"
tetc "github.com/libp2p/go-libp2p-testing/etc"
Expand Down Expand Up @@ -681,3 +684,68 @@ func TestEventBus(t *testing.T) {
}
}
}

func TestBlockByPeerID(t *testing.T) {
m, gater1, host1, _, host2 := WithConnectionGaters(t)

err := gater1.BlockPeer(host2.ID())
if err != nil {
t.Fatal(err)
}

_, err = m.ConnectPeers(host1.ID(), host2.ID())
if err == nil {
t.Fatal("Should have blocked connection to banned peer")
}

_, err = m.ConnectPeers(host2.ID(), host1.ID())
if err == nil {
t.Fatal("Should have blocked connection from banned peer")
}
}

func TestBlockByIP(t *testing.T) {
m, gater1, host1, _, host2 := WithConnectionGaters(t)

ip, err := manet.ToIP(host2.Addrs()[0])
if err != nil {
t.Fatal(err)
}
err = gater1.BlockAddr(ip)
if err != nil {
t.Fatal(err)
}

_, err = m.ConnectPeers(host1.ID(), host2.ID())
if err == nil {
t.Fatal("Should have blocked connection to banned IP")
}

_, err = m.ConnectPeers(host2.ID(), host1.ID())
if err == nil {
t.Fatal("Should have blocked connection from banned IP")
}
}

func WithConnectionGaters(t *testing.T) (Mocknet, *conngater.BasicConnectionGater, host.Host, *conngater.BasicConnectionGater, host.Host) {
m := New()
addPeer := func() (*conngater.BasicConnectionGater, host.Host) {
gater, err := conngater.NewBasicConnectionGater(nil)
if err != nil {
t.Fatal(err)
}
h, err := m.GenPeerWithConnGater(gater)
if err != nil {
t.Fatal(err)
}
return gater, h
}
gater1, host1 := addPeer()
gater2, host2 := addPeer()

err := m.LinkAll()
if err != nil {
t.Fatal(err)
}
return m, gater1, host1, gater2, host2
}

0 comments on commit f235c5c

Please sign in to comment.