Skip to content

Commit

Permalink
don't drop keys for key phase N before receiving a N+1-protected packet
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Sep 9, 2020
1 parent bed802a commit af96570
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 18 deletions.
30 changes: 19 additions & 11 deletions internal/handshake/updatable_aead.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,13 @@ func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer,
}
}

func (a *updatableAEAD) rollKeys(now time.Time) {
func (a *updatableAEAD) rollKeys() {
a.keyPhase++
a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
a.numRcvdWithCurrentKey = 0
a.numSentWithCurrentKey = 0
a.prevRcvAEAD = a.rcvAEAD
a.prevRcvAEADExpiry = now.Add(3 * a.rttStats.PTO(true))
a.rcvAEAD = a.nextRcvAEAD
a.sendAEAD = a.nextSendAEAD

Expand All @@ -112,6 +111,10 @@ func (a *updatableAEAD) rollKeys(now time.Time) {
a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret)
}

func (a *updatableAEAD) startKeyDropTimer(now time.Time) {
a.prevRcvAEADExpiry = now.Add(3 * a.rttStats.PTO(true))
}

func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte {
return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size())
}
Expand Down Expand Up @@ -147,7 +150,7 @@ func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret
}

func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
if a.prevRcvAEAD != nil && rcvTime.After(a.prevRcvAEADExpiry) {
if a.prevRcvAEAD != nil && !a.prevRcvAEADExpiry.IsZero() && rcvTime.After(a.prevRcvAEADExpiry) {
a.prevRcvAEAD = nil
a.prevRcvAEADExpiry = time.Time{}
}
Expand Down Expand Up @@ -179,7 +182,10 @@ func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.Pac
if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
return nil, qerr.NewError(qerr.ProtocolViolation, "keys updated too quickly")
}
a.rollKeys(rcvTime)
a.rollKeys()
// The peer initiated this key update. It's safe to drop the keys for the previous generation now.
// Start a timer to drop the previous key generation.
a.startKeyDropTimer(rcvTime)
a.logger.Debugf("Peer updated keys to %s", a.keyPhase)
if a.tracer != nil {
a.tracer.UpdatedKey(a.keyPhase, true)
Expand All @@ -191,12 +197,14 @@ func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.Pac
// It uses the nonce provided here and XOR it with the IV.
dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad)
if err != nil {
err = ErrDecryptionFailed
} else {
a.numRcvdWithCurrentKey++
if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber {
a.firstRcvdWithCurrentKey = pn
}
return dec, ErrDecryptionFailed
}
a.numRcvdWithCurrentKey++
if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber {
// We initiated the key updated, and now we received the first packet protected with the new key phase.
// Therefore, we are certain that the peer rolled its keys as well. Start a timer to drop the old keys.
a.startKeyDropTimer(rcvTime)
a.firstRcvdWithCurrentKey = pn
}
return dec, err
}
Expand Down Expand Up @@ -245,7 +253,7 @@ func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit {
if a.tracer != nil {
a.tracer.UpdatedKey(a.keyPhase, false)
}
a.rollKeys(time.Now())
a.rollKeys()
}
return a.keyPhase.Bit()
}
Expand Down
42 changes: 35 additions & 7 deletions internal/handshake/updatable_aead_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,14 @@ var _ = Describe("Updatable AEAD", func() {
now := time.Now()
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
encrypted0 := server.Seal(nil, msg, 0x1337, ad)
server.rollKeys(now)
server.rollKeys()
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
encrypted1 := server.Seal(nil, msg, 0x1337, ad)
Expect(encrypted0).ToNot(Equal(encrypted1))
// expect opening to fail. The client didn't roll keys yet
_, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseZero, ad)
Expect(err).To(MatchError(ErrDecryptionFailed))
client.rollKeys(now)
client.rollKeys()
decrypted, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred())
Expect(decrypted).To(Equal(msg))
Expand All @@ -142,7 +142,7 @@ var _ = Describe("Updatable AEAD", func() {
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
_ = server.Seal(nil, msg, 0x1, ad)
// now received a message at key phase one
client.rollKeys(now)
client.rollKeys()
encrypted1 := client.Seal(nil, msg, 0x43, ad)
decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred())
Expand All @@ -160,7 +160,7 @@ var _ = Describe("Updatable AEAD", func() {
// send one packet at key phase zero
_ = server.Seal(nil, msg, 0x1, ad)
// now receive a packet with key phase 1
client.rollKeys(now)
client.rollKeys()
encrypted1 := client.Seal(nil, msg, 0x44, ad)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
_, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad)
Expand All @@ -185,7 +185,7 @@ var _ = Describe("Updatable AEAD", func() {
// send one packet at key phase zero
_ = server.Seal(nil, msg, 0x1, ad)
// now receive a packet with key phase 1
client.rollKeys(now)
client.rollKeys()
encrypted1 := client.Seal(nil, msg, 0x44, ad)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
_, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad)
Expand All @@ -197,7 +197,7 @@ var _ = Describe("Updatable AEAD", func() {
})

It("errors when the peer starts with key phase 1", func() {
client.rollKeys(time.Now())
client.rollKeys()
encrypted := client.Seal(nil, msg, 0x1337, ad)
_, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad)
Expect(err).To(MatchError("PROTOCOL_VIOLATION: wrong initial keyphase"))
Expand All @@ -209,7 +209,7 @@ var _ = Describe("Updatable AEAD", func() {
_, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
// now receive a packet at key phase one, before having sent any packets
client.rollKeys(time.Now())
client.rollKeys()
encrypted1 := client.Seal(nil, msg, 0x42, ad)
_, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseOne, ad)
Expect(err).To(MatchError("PROTOCOL_VIOLATION: keys updated too quickly"))
Expand Down Expand Up @@ -250,6 +250,34 @@ var _ = Describe("Updatable AEAD", func() {
server.SetLargestAcked(1)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
})

It("drops keys 3 PTOs after a key update", func() {
now := time.Now()
for i := 0; i < keyUpdateInterval; i++ {
pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, pn, ad)
server.SetLargestAcked(pn)
}
// Now we've initiated the first key update.
// Decrypt a message sent from the client more than 3 PTO later to make sure the key is still there
threePTO := 3 * rttStats.PTO(false)
dataKeyPhaseZero := client.Seal(nil, msg, 1, ad)
_, err := server.Open(nil, dataKeyPhaseZero, now.Add(threePTO).Add(time.Second), 1, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
// Now receive a packet with key phase 1.
// This should start the timer to drop the keys after 3 PTOs.
client.rollKeys()
dataKeyPhaseOne := client.Seal(nil, msg, 10, ad)
t := now.Add(threePTO).Add(time.Second)
_, err = server.Open(nil, dataKeyPhaseOne, t, 10, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred())
// Make sure the keys are still here.
_, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO*9/10), 1, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
_, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO).Add(time.Nanosecond), 1, protocol.KeyPhaseZero, ad)
Expect(err).To(MatchError(ErrKeysDropped))
})
})

Context("reading the key update env", func() {
Expand Down

0 comments on commit af96570

Please sign in to comment.