Skip to content

Commit

Permalink
quic: fill out connection id handling
Browse files Browse the repository at this point in the history
Add support for sending and receiving NEW_CONNECTION_ID
and RETIRE_CONNECTION_ID frames. Keep the peer supplied
with up to 4 connection IDs. Retire connection IDs as
required by the peer.

Support connection IDs provided in the preferred_address
transport parameter.

RFC 9000, Section 5.1.

For golang/go#58547

Change-Id: I015a69b94c40a6396e9f117a92c88acaf83c594e
Reviewed-on: https://go-review.googlesource.com/c/net/+/513440
TryBot-Result: Gopher Robot <[email protected]>
Run-TryBot: Damien Neil <[email protected]>
Reviewed-by: Jonathan Amsterdam <[email protected]>
  • Loading branch information
neild committed Jul 28, 2023
1 parent 08001cc commit bd8ac9e
Show file tree
Hide file tree
Showing 15 changed files with 998 additions and 75 deletions.
32 changes: 28 additions & 4 deletions internal/quic/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type connListener interface {
type connTestHooks interface {
nextMessage(msgc chan any, nextTimeout time.Time) (now time.Time, message any)
handleTLSEvent(tls.QUICEvent)
newConnID(seq int64) ([]byte, error)
}

func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.AddrPort, config *Config, l connListener, hooks connTestHooks) (*Conn, error) {
Expand All @@ -90,12 +91,12 @@ func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.
c.msgc = make(chan any, 1)

if c.side == clientSide {
if err := c.connIDState.initClient(newRandomConnID); err != nil {
if err := c.connIDState.initClient(c.newConnIDFunc()); err != nil {
return nil, err
}
initialConnID = c.connIDState.dstConnID()
initialConnID, _ = c.connIDState.dstConnID()
} else {
if err := c.connIDState.initServer(newRandomConnID, initialConnID); err != nil {
if err := c.connIDState.initServer(c.newConnIDFunc(), initialConnID); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -154,11 +155,27 @@ func (c *Conn) discardKeys(now time.Time, space numberSpace) {
}

// receiveTransportParameters applies transport parameters sent by the peer.
func (c *Conn) receiveTransportParameters(p transportParameters) {
func (c *Conn) receiveTransportParameters(p transportParameters) error {
c.peerAckDelayExponent = p.ackDelayExponent
c.loss.setMaxAckDelay(p.maxAckDelay)
if err := c.connIDState.setPeerActiveConnIDLimit(p.activeConnIDLimit, c.newConnIDFunc()); err != nil {
return err
}
if p.preferredAddrConnID != nil {
var (
seq int64 = 1 // sequence number of this conn id is 1
retirePriorTo int64 = 0 // retire nothing
resetToken [16]byte
)
copy(resetToken[:], p.preferredAddrResetToken)
if err := c.connIDState.handleNewConnID(seq, retirePriorTo, p.preferredAddrConnID, resetToken); err != nil {
return err
}
}

// TODO: Many more transport parameters to come.

return nil
}

type timerEvent struct{}
Expand Down Expand Up @@ -295,3 +312,10 @@ func firstTime(a, b time.Time) time.Time {
return b
}
}

func (c *Conn) newConnIDFunc() newConnIDFunc {
if c.testHooks != nil {
return c.testHooks.newConnID
}
return newRandomConnID
}
238 changes: 231 additions & 7 deletions internal/quic/conn_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package quic

import (
"bytes"
"crypto/rand"
)

Expand All @@ -18,8 +19,16 @@ type connIDState struct {
// Local IDs are usually issued by us, and remote IDs by the peer.
// The exception is the transient destination connection ID sent in
// a client's Initial packets, which is chosen by the client.
//
// These are []connID rather than []*connID to minimize allocations.
local []connID
remote []connID

nextLocalSeq int64
retireRemotePriorTo int64 // largest Retire Prior To value sent by the peer
peerActiveConnIDLimit int64 // peer's active_connection_id_limit transport parameter

needSend bool
}

// A connID is a connection ID and associated metadata.
Expand All @@ -32,23 +41,36 @@ type connID struct {
//
// For the transient destination ID in a client's Initial packet, this is -1.
seq int64

// retired is set when the connection ID is retired.
retired bool

// send is set when the connection ID's state needs to be sent to the peer.
//
// For local IDs, this indicates a new ID that should be sent
// in a NEW_CONNECTION_ID frame.
//
// For remote IDs, this indicates a retired ID that should be sent
// in a RETIRE_CONNECTION_ID frame.
send sentVal
}

func (s *connIDState) initClient(newID newConnIDFunc) error {
// Client chooses its initial connection ID, and sends it
// in the Source Connection ID field of the first Initial packet.
locid, err := newID()
locid, err := newID(0)
if err != nil {
return err
}
s.local = append(s.local, connID{
seq: 0,
cid: locid,
})
s.nextLocalSeq = 1

// Client chooses an initial, transient connection ID for the server,
// and sends it in the Destination Connection ID field of the first Initial packet.
remid, err := newID()
remid, err := newID(-1)
if err != nil {
return err
}
Expand All @@ -70,14 +92,15 @@ func (s *connIDState) initServer(newID newConnIDFunc, dstConnID []byte) error {

// Server chooses a connection ID, and sends it in the Source Connection ID of
// the response to the clent.
locid, err := newID()
locid, err := newID(0)
if err != nil {
return err
}
s.local = append(s.local, connID{
seq: 0,
cid: locid,
})
s.nextLocalSeq = 1
return nil
}

Expand All @@ -91,8 +114,44 @@ func (s *connIDState) srcConnID() []byte {
}

// dstConnID is the Destination Connection ID to use in a sent packet.
func (s *connIDState) dstConnID() []byte {
return s.remote[0].cid
func (s *connIDState) dstConnID() (cid []byte, ok bool) {
for i := range s.remote {
if !s.remote[i].retired {
return s.remote[i].cid, true
}
}
return nil, false
}

// setPeerActiveConnIDLimit sets the active_connection_id_limit
// transport parameter received from the peer.
func (s *connIDState) setPeerActiveConnIDLimit(lim int64, newID newConnIDFunc) error {
s.peerActiveConnIDLimit = lim
return s.issueLocalIDs(newID)
}

func (s *connIDState) issueLocalIDs(newID newConnIDFunc) error {
toIssue := min(int(s.peerActiveConnIDLimit), maxPeerActiveConnIDLimit)
for i := range s.local {
if s.local[i].seq != -1 && !s.local[i].retired {
toIssue--
}
}
for toIssue > 0 {
cid, err := newID(s.nextLocalSeq)
if err != nil {
return err
}
s.local = append(s.local, connID{
seq: s.nextLocalSeq,
cid: cid,
})
s.local[len(s.local)-1].send.setUnsent()
s.nextLocalSeq++
s.needSend = true
toIssue--
}
return nil
}

// handlePacket updates the connection ID state during the handshake
Expand Down Expand Up @@ -128,19 +187,184 @@ func (s *connIDState) handlePacket(side connSide, ptype packetType, srcConnID []
}
}

func (s *connIDState) handleNewConnID(seq, retire int64, cid []byte, resetToken [16]byte) error {
if len(s.remote[0].cid) == 0 {
// "An endpoint that is sending packets with a zero-length
// Destination Connection ID MUST treat receipt of a NEW_CONNECTION_ID
// frame as a connection error of type PROTOCOL_VIOLATION."
// https://www.rfc-editor.org/rfc/rfc9000.html#section-19.15-6
return localTransportError(errProtocolViolation)
}

if retire > s.retireRemotePriorTo {
s.retireRemotePriorTo = retire
}

have := false // do we already have this connection ID?
active := 0
for i := range s.remote {
rcid := &s.remote[i]
if !rcid.retired && rcid.seq < s.retireRemotePriorTo {
s.retireRemote(rcid)
}
if !rcid.retired {
active++
}
if rcid.seq == seq {
if !bytes.Equal(rcid.cid, cid) {
return localTransportError(errProtocolViolation)
}
have = true // yes, we've seen this sequence number
}
}

if !have {
// This is a new connection ID that we have not seen before.
//
// We could take steps to keep the list of remote connection IDs
// sorted by sequence number, but there's no particular need
// so we don't bother.
s.remote = append(s.remote, connID{
seq: seq,
cid: cloneBytes(cid),
})
if seq < s.retireRemotePriorTo {
// This ID was already retired by a previous NEW_CONNECTION_ID frame.
s.retireRemote(&s.remote[len(s.remote)-1])
} else {
active++
}
}

if active > activeConnIDLimit {
// Retired connection IDs (including newly-retired ones) do not count
// against the limit.
// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1.1-5
return localTransportError(errConnectionIDLimit)
}

// "An endpoint SHOULD limit the number of connection IDs it has retired locally
// for which RETIRE_CONNECTION_ID frames have not yet been acknowledged."
// https://www.rfc-editor.org/rfc/rfc9000#section-5.1.2-6
//
// Set a limit of four times the active_connection_id_limit for
// the total number of remote connection IDs we keep state for locally.
if len(s.remote) > 4*activeConnIDLimit {
return localTransportError(errConnectionIDLimit)
}

return nil
}

// retireRemote marks a remote connection ID as retired.
func (s *connIDState) retireRemote(rcid *connID) {
rcid.retired = true
rcid.send.setUnsent()
s.needSend = true
}

func (s *connIDState) handleRetireConnID(seq int64, newID newConnIDFunc) error {
if seq >= s.nextLocalSeq {
return localTransportError(errProtocolViolation)
}
for i := range s.local {
if s.local[i].seq == seq {
s.local = append(s.local[:i], s.local[i+1:]...)
break
}
}
s.issueLocalIDs(newID)
return nil
}

func (s *connIDState) ackOrLossNewConnectionID(pnum packetNumber, seq int64, fate packetFate) {
for i := range s.local {
if s.local[i].seq != seq {
continue
}
s.local[i].send.ackOrLoss(pnum, fate)
if fate != packetAcked {
s.needSend = true
}
return
}
}

func (s *connIDState) ackOrLossRetireConnectionID(pnum packetNumber, seq int64, fate packetFate) {
for i := 0; i < len(s.remote); i++ {
if s.remote[i].seq != seq {
continue
}
if fate == packetAcked {
// We have retired this connection ID, and the peer has acked.
// Discard its state completely.
s.remote = append(s.remote[:i], s.remote[i+1:]...)
} else {
// RETIRE_CONNECTION_ID frame was lost, mark for retransmission.
s.needSend = true
s.remote[i].send.ackOrLoss(pnum, fate)
}
return
}
}

// appendFrames appends NEW_CONNECTION_ID and RETIRE_CONNECTION_ID frames
// to the current packet.
//
// It returns true if no more frames need appending,
// false if not everything fit in the current packet.
func (s *connIDState) appendFrames(w *packetWriter, pnum packetNumber, pto bool) bool {
if !s.needSend && !pto {
// Fast path: We don't need to send anything.
return true
}
retireBefore := int64(0)
if s.local[0].seq != -1 {
retireBefore = s.local[0].seq
}
for i := range s.local {
if !s.local[i].send.shouldSendPTO(pto) {
continue
}
if !w.appendNewConnectionIDFrame(
s.local[i].seq,
retireBefore,
s.local[i].cid,
[16]byte{}, // TODO: stateless reset token
) {
return false
}
s.local[i].send.setSent(pnum)
}
for i := range s.remote {
if !s.remote[i].send.shouldSendPTO(pto) {
continue
}
if !w.appendRetireConnectionIDFrame(s.remote[i].seq) {
return false
}
s.remote[i].send.setSent(pnum)
}
s.needSend = false
return true
}

func cloneBytes(b []byte) []byte {
n := make([]byte, len(b))
copy(n, b)
return n
}

type newConnIDFunc func() ([]byte, error)
type newConnIDFunc func(seq int64) ([]byte, error)

func newRandomConnID() ([]byte, error) {
func newRandomConnID(_ int64) ([]byte, error) {
// It is not necessary for connection IDs to be cryptographically secure,
// but it doesn't hurt.
id := make([]byte, connIDLen)
if _, err := rand.Read(id); err != nil {
// TODO: Surface this error as a metric or log event or something.
// rand.Read really shouldn't ever fail, but if it does, we should
// have a way to inform the user.
return nil, err
}
return id, nil
Expand Down
Loading

0 comments on commit bd8ac9e

Please sign in to comment.