Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit pings #32

Merged
merged 1 commit into from
Aug 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ type Config struct {
// waiting an accept.
AcceptBacklog int

// PingBacklog is used to limit how many ping acks we can queue.
PingBacklog int

// EnableKeepalive is used to do a period keep alive
// messages using a ping.
EnableKeepAlive bool
Expand Down Expand Up @@ -53,6 +56,7 @@ type Config struct {
func DefaultConfig() *Config {
return &Config{
AcceptBacklog: 256,
PingBacklog: 32,
EnableKeepAlive: true,
KeepAliveInterval: 30 * time.Second,
ConnectionWriteTimeout: 10 * time.Second,
Expand Down Expand Up @@ -81,6 +85,9 @@ func VerifyConfig(config *Config) error {
if config.WriteCoalesceDelay < 0 {
return fmt.Errorf("WriteCoalesceDelay must be >= 0")
}
if config.PingBacklog < 1 {
return fmt.Errorf("PingBacklog must be > 0")
}
return nil
}

Expand Down
34 changes: 34 additions & 0 deletions ping.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package yamux

import "time"

type ping struct {
id uint32
// written to by the session on ping response
pingResponse chan struct{}

// closed by the Ping call that sent the ping when done.
done chan struct{}
// result set before done is closed.
err error
duration time.Duration
}

func newPing(id uint32) *ping {
return &ping{
id: id,
pingResponse: make(chan struct{}, 1),
done: make(chan struct{}),
}
}

func (p *ping) finish(val time.Duration, err error) {
p.err = err
p.duration = val
close(p.done)
}

func (p *ping) wait() (time.Duration, error) {
<-p.done
return p.duration, p.err
}
75 changes: 48 additions & 27 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ type Session struct {
reader io.Reader

// pings is used to track inflight pings
pings map[uint32]chan struct{}
pingID uint32
pingLock sync.Mutex
pingLock sync.Mutex
pingID uint32
activePing *ping

// streams maps a stream id to a stream, and inflight has an entry
// for any outgoing stream that has not yet been established. Both are
Expand All @@ -66,6 +66,8 @@ type Session struct {

// sendCh is used to send messages
sendCh chan []byte
// pingCh is used to send pongs (responses to pings)
pongCh chan uint32

// recvDoneCh is closed when recv() exits to avoid a race
// between stream registration and stream shutdown
Expand Down Expand Up @@ -104,12 +106,12 @@ func newSession(config *Config, conn net.Conn, client bool, readBuf int) *Sessio
logger: log.New(config.LogOutput, "", log.LstdFlags),
conn: conn,
reader: reader,
pings: make(map[uint32]chan struct{}),
streams: make(map[uint32]*Stream),
inflight: make(map[uint32]struct{}),
synCh: make(chan struct{}, config.AcceptBacklog),
acceptCh: make(chan *Stream, config.AcceptBacklog),
sendCh: make(chan []byte, 64),
pongCh: make(chan uint32, config.PingBacklog),
recvDoneCh: make(chan struct{}),
sendDoneCh: make(chan struct{}),
shutdownCh: make(chan struct{}),
Expand Down Expand Up @@ -281,19 +283,33 @@ func (s *Session) goAway(reason uint32) header {
}

// Ping is used to measure the RTT response time
func (s *Session) Ping() (time.Duration, error) {
// Get a channel for the ping
ch := make(chan struct{})

// Get a new ping id, mark as pending
func (s *Session) Ping() (dur time.Duration, err error) {
// Prepare a ping.
s.pingLock.Lock()
id := s.pingID
// If there's an active ping, jump on the bandwagon.
if activePing := s.activePing; activePing != nil {
s.pingLock.Unlock()
return activePing.wait()
}

// Ok, our job to send the ping.
activePing := newPing(s.pingID)
s.pingID++
s.pings[id] = ch
s.activePing = activePing
s.pingLock.Unlock()

defer func() {
// complete ping promise
activePing.finish(dur, err)

// Unset it.
s.pingLock.Lock()
s.activePing = nil
s.pingLock.Unlock()
}()

// Send the ping request
hdr := encode(typePing, flagSYN, 0, id)
hdr := encode(typePing, flagSYN, 0, activePing.id)
if err := s.sendMsg(hdr, nil, nil); err != nil {
return 0, err
}
Expand All @@ -303,11 +319,8 @@ func (s *Session) Ping() (time.Duration, error) {
timer := time.NewTimer(s.config.ConnectionWriteTimeout)
defer timer.Stop()
select {
case <-ch:
case <-activePing.pingResponse:
case <-timer.C:
s.pingLock.Lock()
delete(s.pings, id) // Ignore it if a response comes later.
s.pingLock.Unlock()
return 0, ErrTimeout
case <-s.shutdownCh:
return 0, s.shutdownErr
Expand Down Expand Up @@ -456,6 +469,10 @@ func (s *Session) sendLoop() error {
var buf []byte
select {
case buf = <-s.sendCh:
case pingID := <-s.pongCh:
buf = pool.Get(headerSize)
hdr := encode(typePing, flagACK, 0, pingID)
copy(buf, hdr[:])
case <-s.shutdownCh:
return nil
//default:
Expand Down Expand Up @@ -605,29 +622,33 @@ func (s *Session) handleStreamMessage(hdr header) error {
return nil
}

// handlePing is invokde for a typePing frame
// handlePing is invoked for a typePing frame
func (s *Session) handlePing(hdr header) error {
flags := hdr.Flags()
pingID := hdr.Length()

// Check if this is a query, respond back in a separate context so we
// don't interfere with the receiving thread blocking for the write.
if flags&flagSYN == flagSYN {
go func() {
hdr := encode(typePing, flagACK, 0, pingID)
if err := s.sendMsg(hdr, nil, nil); err != nil {
s.logger.Printf("[WARN] yamux: failed to send ping reply: %v", err)
}
}()
select {
case s.pongCh <- pingID:
default:
s.logger.Printf("[WARN] yamux: dropped ping reply")
}
return nil
}

// Handle a response
s.pingLock.Lock()
ch := s.pings[pingID]
if ch != nil {
delete(s.pings, pingID)
close(ch)
// If we have an active ping, and this is a response to that active
// ping, complete the ping.
if s.activePing != nil && s.activePing.id == pingID {
// Don't assume that the peer won't send multiple responses for
// the same ping.
select {
case s.activePing.pingResponse <- struct{}{}:
default:
}
}
s.pingLock.Unlock()
return nil
Expand Down