diff --git a/dot/network/discovery.go b/dot/network/discovery.go index a1e54ddbb4..04d536bb74 100644 --- a/dot/network/discovery.go +++ b/dot/network/discovery.go @@ -60,34 +60,44 @@ func newDiscovery(ctx context.Context, h libp2phost.Host, } } -// start creates the DHT. -func (d *discovery) start() error { - if len(d.bootnodes) == 0 { - // get all currently connected peers and use them to bootstrap the DHT - peers := d.h.Network().Peers() - - t := time.NewTicker(startDHTTimeout) - defer t.Stop() - for { - if len(peers) > 0 { - break - } +func (d *discovery) waitForPeers() (peers []peer.AddrInfo, err error) { + // get all currently connected peers and use them to bootstrap the DHT - select { - case <-t.C: - logger.Debug("no peers yet, waiting to start DHT...") - // wait for peers to connect before starting DHT, otherwise DHT bootstrap nodes - // will be empty and we will fail to fill the routing table - case <-d.ctx.Done(): - return nil - } + currentPeers := d.h.Network().Peers() + + t := time.NewTicker(startDHTTimeout) + defer t.Stop() - peers = d.h.Network().Peers() + for len(currentPeers) == 0 { + select { + case <-t.C: + logger.Debug("no peers yet, waiting to start DHT...") + // wait for peers to connect before starting DHT, otherwise DHT bootstrap nodes + // will be empty and we will fail to fill the routing table + case <-d.ctx.Done(): + return nil, d.ctx.Err() } - for _, p := range peers { - d.bootnodes = append(d.bootnodes, d.h.Peerstore().PeerInfo(p)) + currentPeers = d.h.Network().Peers() + } + + peers = make([]peer.AddrInfo, len(currentPeers)) + for idx, peer := range currentPeers { + peers[idx] = d.h.Peerstore().PeerInfo(peer) + } + + return peers, nil +} + +// start creates the DHT. +func (d *discovery) start() error { + if len(d.bootnodes) == 0 { + peers, err := d.waitForPeers() + if err != nil { + return fmt.Errorf("failed while waiting for peers: %w", err) } + + d.bootnodes = peers } logger.Debugf("starting DHT with bootnodes %v...", d.bootnodes) @@ -141,8 +151,15 @@ func (d *discovery) advertise() { ttl := initialAdvertisementTimeout for { + timer := time.NewTimer(ttl) + select { - case <-time.After(ttl): + case <-d.ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return + case <-timer.C: logger.Debug("advertising ourselves in the DHT...") err := d.dht.Bootstrap(d.ctx) if err != nil { @@ -155,32 +172,29 @@ func (d *discovery) advertise() { logger.Warnf("failed to advertise in the DHT: %s", err) ttl = tryAdvertiseTimeout } - case <-d.ctx.Done(): - return } } } func (d *discovery) checkPeerCount() { - t := time.NewTicker(connectToPeersTimeout) - defer t.Stop() + timer := time.NewTicker(connectToPeersTimeout) + defer timer.Stop() + for { select { case <-d.ctx.Done(): return - case <-t.C: + case <-timer.C: if len(d.h.Network().Peers()) > d.minPeers { continue } - ctx, cancel := context.WithTimeout(d.ctx, findPeersTimeout) - defer cancel() - d.findPeers(ctx) + d.findPeers() } } } -func (d *discovery) findPeers(ctx context.Context) { +func (d *discovery) findPeers() { logger.Debug("attempting to find DHT peers...") peerCh, err := d.rd.FindPeers(d.ctx, string(d.pid)) if err != nil { @@ -188,9 +202,12 @@ func (d *discovery) findPeers(ctx context.Context) { return } + timer := time.NewTimer(findPeersTimeout) + defer timer.Stop() + for { select { - case <-ctx.Done(): + case <-timer.C: return case peer := <-peerCh: if peer.ID == d.h.ID() || peer.ID == "" { @@ -198,10 +215,12 @@ func (d *discovery) findPeers(ctx context.Context) { } logger.Tracef("found new peer %s via DHT", peer.ID) - d.h.Peerstore().AddAddrs(peer.ID, peer.Addrs, peerstore.PermanentAddrTTL) d.handler.AddPeer(0, peer.ID) + if !timer.Stop() { + <-timer.C + } } } }