Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

net/mock: support ConnectionGater in MockNet #2297

Merged
merged 1 commit into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions p2p/net/mock/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"io"
"time"

"github.com/libp2p/go-libp2p/core/connmgr"
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
Expand All @@ -19,14 +20,24 @@ import (
ma "github.com/multiformats/go-multiaddr"
)

type PeerOptions struct {
// ps is the Peerstore to use when adding peer. If nil, a default peerstore will be created.
ps peerstore.Peerstore

// gater is the ConnectionGater to use when adding a peer. If nil, no connection gater will be used.
gater connmgr.ConnectionGater
}

type Mocknet interface {
// GenPeer generates a peer and its network.Network in the Mocknet
GenPeer() (host.Host, error)
GenPeerWithOptions(PeerOptions) (host.Host, error)

// AddPeer adds an existing peer. we need both a privkey and addr.
// ID is derived from PrivKey
AddPeer(ic.PrivKey, ma.Multiaddr) (host.Host, error)
AddPeerWithPeerstore(peer.ID, peerstore.Peerstore) (host.Host, error)
AddPeerWithOptions(peer.ID, PeerOptions) (host.Host, error)

// retrieve things (with randomized iteration order)
Peers() []peer.ID
Expand Down
72 changes: 62 additions & 10 deletions p2p/net/mock/mock_net.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ func (mn *mocknet) Close() error {
}

func (mn *mocknet) GenPeer() (host.Host, error) {
return mn.GenPeerWithOptions(PeerOptions{})
}

func (mn *mocknet) GenPeerWithOptions(opts PeerOptions) (host.Host, error) {
if err := mn.addDefaults(&opts); err != nil {
return nil, err
}
sk, _, err := ic.GenerateECDSAKeyPair(rand.Reader)
if err != nil {
return nil, err
Expand All @@ -83,7 +90,20 @@ func (mn *mocknet) GenPeer() (host.Host, error) {
return nil, fmt.Errorf("failed to create test multiaddr: %s", err)
}

h, err := mn.AddPeer(sk, a)
var ps peerstore.Peerstore
if opts.ps == nil {
ps, err = pstoremem.NewPeerstore()
if err != nil {
return nil, err
}
} else {
ps = opts.ps
}
p, err := mn.updatePeerstore(sk, a, ps)
if err != nil {
return nil, err
}
h, err := mn.AddPeerWithOptions(p, opts)
if err != nil {
return nil, err
}
Expand All @@ -92,36 +112,39 @@ func (mn *mocknet) GenPeer() (host.Host, error) {
}

func (mn *mocknet) AddPeer(k ic.PrivKey, a ma.Multiaddr) (host.Host, error) {
p, err := peer.IDFromPublicKey(k.GetPublic())
ps, err := pstoremem.NewPeerstore()
if err != nil {
return nil, err
}

ps, err := pstoremem.NewPeerstore()
p, err := mn.updatePeerstore(k, a, ps)
if err != nil {
return nil, err
}
ps.AddAddr(p, a, peerstore.PermanentAddrTTL)
ps.AddPrivKey(p, k)
ps.AddPubKey(p, k.GetPublic())

return mn.AddPeerWithPeerstore(p, ps)
}

func (mn *mocknet) AddPeerWithPeerstore(p peer.ID, ps peerstore.Peerstore) (host.Host, error) {
return mn.AddPeerWithOptions(p, PeerOptions{ps: ps})
}

func (mn *mocknet) AddPeerWithOptions(p peer.ID, opts PeerOptions) (host.Host, error) {
bus := eventbus.NewBus()
n, err := newPeernet(mn, p, ps, bus)
if err := mn.addDefaults(&opts); err != nil {
return nil, err
}
n, err := newPeernet(mn, p, opts, bus)
if err != nil {
return nil, err
}

opts := &bhost.HostOpts{
hostOpts := &bhost.HostOpts{
NegotiationTimeout: -1,
DisableSignedPeerRecord: true,
EventBus: bus,
}

h, err := bhost.NewHost(n, opts)
h, err := bhost.NewHost(n, hostOpts)
if err != nil {
return nil, err
}
Expand All @@ -134,6 +157,35 @@ func (mn *mocknet) AddPeerWithPeerstore(p peer.ID, ps peerstore.Peerstore) (host
return h, nil
}

func (mn *mocknet) addDefaults(opts *PeerOptions) error {
if opts.ps == nil {
ps, err := pstoremem.NewPeerstore()
if err != nil {
return err
}
opts.ps = ps
}
return nil
}

func (mn *mocknet) updatePeerstore(k ic.PrivKey, a ma.Multiaddr, ps peerstore.Peerstore) (peer.ID, error) {
p, err := peer.IDFromPublicKey(k.GetPublic())
if err != nil {
return "", err
}

ps.AddAddr(p, a, peerstore.PermanentAddrTTL)
err = ps.AddPrivKey(p, k)
if err != nil {
return "", err
}
err = ps.AddPubKey(p, k.GetPublic())
if err != nil {
return "", err
}
return p, nil
}

func (mn *mocknet) Peers() []peer.ID {
mn.Lock()
defer mn.Unlock()
Expand Down
56 changes: 49 additions & 7 deletions p2p/net/mock/mock_peernet.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"math/rand"
"sync"

"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
Expand All @@ -28,6 +29,9 @@ type peernet struct {
connsByPeer map[peer.ID]map[*conn]struct{}
connsByLink map[*link]map[*conn]struct{}

// connection gater to check before dialing or accepting connections. May be nil to allow all.
gater connmgr.ConnectionGater

// implement network.Network
streamHandler network.StreamHandler

Expand All @@ -38,7 +42,7 @@ type peernet struct {
}

// newPeernet constructs a new peernet
func newPeernet(m *mocknet, p peer.ID, ps peerstore.Peerstore, bus event.Bus) (*peernet, error) {
func newPeernet(m *mocknet, p peer.ID, opts PeerOptions, bus event.Bus) (*peernet, error) {
emitter, err := bus.Emitter(&event.EvtPeerConnectednessChanged{})
if err != nil {
return nil, err
Expand All @@ -47,7 +51,8 @@ func newPeernet(m *mocknet, p peer.ID, ps peerstore.Peerstore, bus event.Bus) (*
n := &peernet{
mocknet: m,
peer: p,
ps: ps,
ps: opts.ps,
gater: opts.gater,
emitter: emitter,

connsByPeer: map[peer.ID]map[*conn]struct{}{},
Expand Down Expand Up @@ -124,6 +129,10 @@ func (pn *peernet) connect(p peer.ID) (*conn, error) {
}
pn.RUnlock()

if pn.gater != nil && !pn.gater.InterceptPeerDial(p) {
log.Debugf("gater disallowed outbound connection to peer %s", p)
return nil, fmt.Errorf("%v connection gater disallowed connection to %v", pn.peer, p)
}
log.Debugf("%s (newly) dialing %s", pn.peer, p)

// ok, must create a new connection. we need a link
Expand All @@ -139,18 +148,51 @@ func (pn *peernet) connect(p peer.ID) (*conn, error) {

log.Debugf("%s dialing %s openingConn", pn.peer, p)
// create a new connection with link
c := pn.openConn(p, l.(*link))
return c, nil
return pn.openConn(p, l.(*link))
}

func (pn *peernet) openConn(r peer.ID, l *link) *conn {
func (pn *peernet) openConn(r peer.ID, l *link) (*conn, error) {
lc, rc := l.newConnPair(pn)
log.Debugf("%s opening connection to %s", pn.LocalPeer(), lc.RemotePeer())
addConnPair(pn, rc.net, lc, rc)
log.Debugf("%s opening connection to %s", pn.LocalPeer(), lc.RemotePeer())
abort := func() {
_ = lc.Close()
_ = rc.Close()
}
if pn.gater != nil && !pn.gater.InterceptAddrDial(lc.remote, lc.remoteAddr) {
abort()
return nil, fmt.Errorf("%v rejected dial to %v on addr %v", lc.local, lc.remote, lc.remoteAddr)
}
if rc.net.gater != nil && !rc.net.gater.InterceptAccept(rc) {
abort()
return nil, fmt.Errorf("%v rejected connection from %v", rc.local, rc.remote)
}
if err := checkSecureAndUpgrade(network.DirOutbound, pn.gater, lc); err != nil {
abort()
return nil, err
}
if err := checkSecureAndUpgrade(network.DirInbound, rc.net.gater, rc); err != nil {
abort()
return nil, err
}

go rc.net.remoteOpenedConn(rc)
pn.addConn(lc)
return lc
return lc, nil
}

func checkSecureAndUpgrade(dir network.Direction, gater connmgr.ConnectionGater, c *conn) error {
if gater == nil {
return nil
}
if !gater.InterceptSecured(dir, c.remote, c) {
return fmt.Errorf("%v rejected secure handshake with %v", c.local, c.remote)
}
allow, _ := gater.InterceptUpgraded(c)
if !allow {
return fmt.Errorf("%v rejected upgrade with %v", c.local, c.remote)
}
return nil
}

// addConnPair adds connection to both peernets at the same time
Expand Down
68 changes: 68 additions & 0 deletions p2p/net/mock/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@ import (

"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/p2p/net/conngater"
manet "github.com/multiformats/go-multiaddr/net"

"github.com/libp2p/go-libp2p-testing/ci"
tetc "github.com/libp2p/go-libp2p-testing/etc"
Expand Down Expand Up @@ -681,3 +684,68 @@ func TestEventBus(t *testing.T) {
}
}
}

func TestBlockByPeerID(t *testing.T) {
m, gater1, host1, _, host2 := WithConnectionGaters(t)

err := gater1.BlockPeer(host2.ID())
if err != nil {
t.Fatal(err)
}

_, err = m.ConnectPeers(host1.ID(), host2.ID())
if err == nil {
t.Fatal("Should have blocked connection to banned peer")
}

_, err = m.ConnectPeers(host2.ID(), host1.ID())
if err == nil {
t.Fatal("Should have blocked connection from banned peer")
}
}

func TestBlockByIP(t *testing.T) {
m, gater1, host1, _, host2 := WithConnectionGaters(t)

ip, err := manet.ToIP(host2.Addrs()[0])
if err != nil {
t.Fatal(err)
}
err = gater1.BlockAddr(ip)
if err != nil {
t.Fatal(err)
}

_, err = m.ConnectPeers(host1.ID(), host2.ID())
if err == nil {
t.Fatal("Should have blocked connection to banned IP")
}

_, err = m.ConnectPeers(host2.ID(), host1.ID())
if err == nil {
t.Fatal("Should have blocked connection from banned IP")
}
}

func WithConnectionGaters(t *testing.T) (Mocknet, *conngater.BasicConnectionGater, host.Host, *conngater.BasicConnectionGater, host.Host) {
m := New()
addPeer := func() (*conngater.BasicConnectionGater, host.Host) {
gater, err := conngater.NewBasicConnectionGater(nil)
if err != nil {
t.Fatal(err)
}
h, err := m.GenPeerWithOptions(PeerOptions{gater: gater})
if err != nil {
t.Fatal(err)
}
return gater, h
}
gater1, host1 := addPeer()
gater2, host2 := addPeer()

err := m.LinkAll()
if err != nil {
t.Fatal(err)
}
return m, gater1, host1, gater2, host2
}