From 584aa618a196265fc734e7467412f06f0f0d7c34 Mon Sep 17 00:00:00 2001 From: Pierre Krieger Date: Fri, 17 Nov 2023 14:08:41 +0100 Subject: [PATCH] Add a `WebRtcFraming` utility for WebRTC messages (#1350) --- lib/src/libp2p/collection/multi_stream.rs | 175 ++------- lib/src/libp2p/connection.rs | 1 + .../connection/established/multi_stream.rs | 251 ++---------- lib/src/libp2p/connection/webrtc_framing.rs | 359 ++++++++++++++++++ 4 files changed, 421 insertions(+), 365 deletions(-) create mode 100644 lib/src/libp2p/connection/webrtc_framing.rs diff --git a/lib/src/libp2p/collection/multi_stream.rs b/lib/src/libp2p/collection/multi_stream.rs index 2e6ed37e37..8b2608d289 100644 --- a/lib/src/libp2p/collection/multi_stream.rs +++ b/lib/src/libp2p/collection/multi_stream.rs @@ -15,11 +15,9 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . -use crate::util::{leb128, protobuf}; - use super::{ super::{ - connection::{established, noise}, + connection::{established, noise, webrtc_framing}, read_write::ReadWrite, }, ConnectionToCoordinator, ConnectionToCoordinatorInner, CoordinatorToConnection, @@ -27,11 +25,9 @@ use super::{ SubstreamId, }; -use alloc::{collections::VecDeque, string::ToString as _, sync::Arc, vec::Vec}; +use alloc::{collections::VecDeque, string::ToString as _, sync::Arc}; use core::{ - cmp, hash::Hash, - mem, ops::{Add, Sub}, time::Duration, }; @@ -44,15 +40,11 @@ enum MultiStreamConnectionTaskInner { /// Connection is still in its handshake phase. Handshake { /// Substream that has been opened to perform the handshake, if any. - opened_substream: Option, + opened_substream: Option<(TSubId, webrtc_framing::WebRtcFraming)>, /// Noise handshake in progress. Always `Some`, except to be temporarily extracted. handshake: Option, - /// All incoming data for the handshake substream is first transferred to this buffer. 
- // TODO: this is very suboptimal code, instead the parsing should be done in a streaming way - handshake_read_buffer: Vec, - /// Other substreams, besides [`MultiStreamConnectionTaskInner::Handshake::opened_substream`], /// that have been opened. For each substream, contains a boolean indicating whether the /// substream is outbound (`true`) or inbound (`false`). @@ -151,7 +143,6 @@ where // TODO: the handshake doesn't have a timeout handshake: Some(handshake), opened_substream: None, - handshake_read_buffer: Vec::new(), extra_open_substreams: hashbrown::HashMap::with_capacity_and_hasher( 0, Default::default(), @@ -683,14 +674,16 @@ where opened_substream: ref mut opened_substream @ None, .. } if outbound => { - *opened_substream = Some(id); + *opened_substream = Some((id, webrtc_framing::WebRtcFraming::new())); } MultiStreamConnectionTaskInner::Handshake { opened_substream, extra_open_substreams, .. } => { - assert!(opened_substream.as_ref().map_or(true, |open| *open != id)); + assert!(opened_substream + .as_ref() + .map_or(true, |(open, _)| *open != id)); // TODO: add a limit to the number allowed? let _was_in = extra_open_substreams.insert(id, outbound); assert!(_was_in.is_none()); @@ -791,12 +784,10 @@ where established.reset_substream(substream_id) } MultiStreamConnectionTaskInner::Handshake { - opened_substream: Some(opened_substream), - handshake_read_buffer, + opened_substream: Some((opened_substream, _)), .. } if opened_substream == substream_id => { // TODO: the handshake has failed, kill the connection? 
- handshake_read_buffer.clear(); } MultiStreamConnectionTaskInner::Handshake { extra_open_substreams, @@ -843,149 +834,38 @@ where match &mut self.connection { MultiStreamConnectionTaskInner::Handshake { handshake, - opened_substream, - handshake_read_buffer, + opened_substream: Some((opened_handshake_substream, handshake_webrtc_framing)), established, extra_open_substreams, - } if opened_substream - .as_ref() - .map_or(false, |s| s == substream_id) => - { + } if opened_handshake_substream == substream_id => { // TODO: check the handshake timeout - // The Noise data is not directly the data of the substream. Instead, everything - // is wrapped within a Protobuf frame. For this reason, we first transfer the data - // to a buffer. - // - // According to the libp2p WebRTC spec, a frame and its length prefix must not be - // larger than 16kiB, meaning that the read buffer never has to exceed this size. - // - // Try to add data to `handshake_read_buffer`. - // TODO: this is very suboptimal; improve - // TODO: this doesn't properly back-pressure, because we read unconditionally - let protobuf_frame_size = { - let mut parser = - nom::combinator::map_parser::<_, _, _, nom::error::Error<&[u8]>, _, _>( - nom::multi::length_data(crate::util::leb128::nom_leb128_usize), - protobuf::message_decode! { - #[optional] flags = 1 => protobuf::enum_tag_decode, - #[optional] message = 2 => protobuf::bytes_tag_decode, - }, - ); - - match parser(&read_write.incoming_buffer) { - Ok((rest, framed_message)) => { - if let Some(message) = framed_message.message { - handshake_read_buffer.extend_from_slice(message); - } - - let protobuf_frame_size = handshake_read_buffer.len() - rest.len(); - // If the remote has sent a `FIN` or `RESET_STREAM` flag, mark the - // remote writing side as closed. 
- if framed_message.flags.map_or(false, |f| f == 0 || f == 2) { - // TODO: no, handshake error - return SubstreamFate::Reset; - } - protobuf_frame_size - } - Err(nom::Err::Incomplete(needed)) => { - read_write.expected_incoming_bytes = Some( - handshake_read_buffer.len() - + match needed { - nom::Needed::Size(s) => s.get(), - nom::Needed::Unknown => 1, - }, - ); - 0 - } - Err(_) => { - // Message decoding error. - // TODO: no, handshake error + // Progress the Noise handshake. + let handshake_outcome = { + // The Noise data is not directly the data of the substream. Instead, + // everything is wrapped within a Protobuf frame. + let mut with_framing = match handshake_webrtc_framing.read_write(read_write) { + Ok(f) => f, + Err(_err) => { + // TODO: not great for diagnostic to just ignore the error; also, the connection should just reset entirely return SubstreamFate::Reset; } - } + }; + handshake.take().unwrap().read_write(&mut with_framing) }; - let _ = read_write.incoming_bytes_take(protobuf_frame_size); - - let mut sub_read_write = ReadWrite { - now: read_write.now.clone(), - incoming_buffer: mem::take(handshake_read_buffer), - read_bytes: 0, - expected_incoming_bytes: Some(0), - write_buffers: Vec::new(), - write_bytes_queued: read_write.write_bytes_queued, - // Don't write out more than one frame. - // TODO: this `10` is here for the length and protobuf frame size and is a bit hacky - write_bytes_queueable: Some( - cmp::min(read_write.write_bytes_queueable.unwrap(), 16384) - .saturating_sub(10), - ), - wake_up_after: None, - }; - - let handshake_outcome = handshake.take().unwrap().read_write(&mut sub_read_write); - *handshake_read_buffer = sub_read_write.incoming_buffer; - if let Some(wake_up_after) = &sub_read_write.wake_up_after { - read_write.wake_up_after(wake_up_after) - } - - // Send out the message that the Noise handshake has written - // into `intermediary_write_buffer`. 
- if sub_read_write.write_bytes_queued != read_write.write_bytes_queued { - let written_bytes = - sub_read_write.write_bytes_queued - read_write.write_bytes_queued; - debug_assert_eq!( - written_bytes, - sub_read_write - .write_buffers - .iter() - .fold(0, |s, b| s + b.len()) - ); - - // TODO: don't do the encoding manually but use the protobuf module? - let tag = protobuf::tag_encode(2, 2).collect::>(); - let data_len = leb128::encode_usize(written_bytes).collect::>(); - let libp2p_prefix = - leb128::encode_usize(tag.len() + data_len.len() + written_bytes) - .collect::>(); - - // The spec mentions that a frame plus its length prefix shouldn't exceed - // 16kiB. This is normally ensured by forbidding the substream from writing - // more data than would fit in 16kiB. - debug_assert!( - libp2p_prefix.len() + tag.len() + data_len.len() + written_bytes <= 16384 - ); - - read_write.write_out(libp2p_prefix); - read_write.write_out(tag); - read_write.write_out(data_len); - for buffer in sub_read_write.write_buffers { - read_write.write_out(buffer); - } - } match handshake_outcome { Ok(noise::NoiseHandshake::InProgress(handshake_update)) => { *handshake = Some(handshake_update); SubstreamFate::Continue } - Err(_err) => todo!("{:?}", _err), // TODO: /!\ + Err(_err) => return SubstreamFate::Reset, // TODO: /!\ Ok(noise::NoiseHandshake::Success { cipher: _, remote_peer_id, }) => { // The handshake has succeeded and we will transition into "established" // mode. - // However the rest of the body of this function still needs to deal with - // the substream used for the handshake. - // We close the writing side. If the reading side is closed, we indicate - // that the substream is dead. If the reading side is still open, we - // indicate that it's not dead and store it in the state machine while - // waiting for it to be closed by the remote. 
- read_write.close_write(); - let handshake_substream_still_open = - read_write.expected_incoming_bytes.is_some(); - let mut established = established.take().unwrap(); for (substream_id, outbound) in extra_open_substreams.drain() { established.add_substream(substream_id, outbound); @@ -994,11 +874,7 @@ where self.connection = MultiStreamConnectionTaskInner::Established { established, handshake_finished_message_to_send: Some(remote_peer_id), - handshake_substream: if handshake_substream_still_open { - Some(opened_substream.take().unwrap()) - } else { - None - }, + handshake_substream: None, // TODO: do properly outbound_substreams_map: hashbrown::HashMap::with_capacity_and_hasher( 0, Default::default(), @@ -1008,11 +884,8 @@ where inbound_accept_cancel_events: VecDeque::with_capacity(2), }; - if handshake_substream_still_open { - SubstreamFate::Continue - } else { - SubstreamFate::Reset - } + // TODO: hacky + SubstreamFate::Reset } } } diff --git a/lib/src/libp2p/connection.rs b/lib/src/libp2p/connection.rs index 5cefd4558a..d0f48be43d 100644 --- a/lib/src/libp2p/connection.rs +++ b/lib/src/libp2p/connection.rs @@ -83,4 +83,5 @@ pub mod established; pub mod multistream_select; pub mod noise; pub mod single_stream_handshake; +pub mod webrtc_framing; pub mod yamux; diff --git a/lib/src/libp2p/connection/established/multi_stream.rs b/lib/src/libp2p/connection/established/multi_stream.rs index 57024deff5..a74fbee4ce 100644 --- a/lib/src/libp2p/connection/established/multi_stream.rs +++ b/lib/src/libp2p/connection/established/multi_stream.rs @@ -20,13 +20,12 @@ use super::{ super::super::read_write::ReadWrite, substream, Config, Event, SubstreamId, SubstreamIdInner, }; -use crate::util::{self, leb128, protobuf}; +use crate::{libp2p::connection::webrtc_framing, util}; use alloc::{collections::VecDeque, string::String, vec::Vec}; use core::{ - cmp, fmt, + fmt, hash::Hash, - mem, ops::{Add, Index, IndexMut, Sub}, time::Duration, }; @@ -95,11 +94,8 @@ struct Substream { /// 
Underlying state machine for the substream. Always `Some` while the substream is alive, /// and `None` if it has been reset. inner: Option>, - /// All incoming data is first transferred to this buffer. - // TODO: this is very suboptimal code, instead the parsing should be done in a streaming way - read_buffer: Vec, - remote_writing_side_closed: bool, - local_writing_side_closed: bool, + /// State of the message frames. + framing: webrtc_framing::WebRtcFraming, } const MAX_PENDING_EVENTS: usize = 4; @@ -195,9 +191,7 @@ where id: out_substream_id, inner: Some(substream::Substream::ingoing(self.max_protocol_name_len)), user_data: None, - read_buffer: Vec::new(), - local_writing_side_closed: false, - remote_writing_side_closed: false, + framing: webrtc_framing::WebRtcFraming::new(), } } else if self.ping_substream.is_none() { let out_substream_id = self.next_out_substream_id; @@ -209,9 +203,7 @@ where id: out_substream_id, inner: Some(substream::Substream::ping_out(self.ping_protocol.clone())), user_data: None, - read_buffer: Vec::new(), - local_writing_side_closed: false, - remote_writing_side_closed: false, + framing: webrtc_framing::WebRtcFraming::new(), } } else if let Some(desired) = self.desired_out_substreams.pop_front() { desired @@ -306,207 +298,42 @@ where read_write.wake_up_after(&self.next_ping); } - loop { - // Don't process any more data before events are pulled. - if self.pending_events.len() >= MAX_PENDING_EVENTS { - return SubstreamFate::Continue; - } - - // In the situation where there's not enough space in the outgoing buffer to write an - // outgoing Protobuf frame, we just return immediately. - // This is necessary because calling `substream.read_write` can generate a write - // close message. 
- // TODO: this is error-prone, as we have no guarantee that the outgoing buffer will ever be > 6 bytes, for example in principle the API user could decide to use only a write buffer of 2 bytes, although that would be a very stupid thing to do - if read_write.write_bytes_queueable.unwrap_or(0) < 6 { - return SubstreamFate::Continue; - } - - // If this flag is still `false` at the end of the loop, we break out of it. - let mut continue_looping = false; - - // The incoming data is not directly the data of the substream. Instead, everything - // is wrapped within a Protobuf frame. For this reason, we first transfer the data to - // a buffer. - // - // According to the libp2p WebRTC spec, a frame and its length prefix must not be - // larger than 16kiB, meaning that the read buffer never has to exceed this size. - // - // Try to add data to `substream.read_buffer`. - // TODO: this is very suboptimal; improve - // TODO: this doesn't properly back-pressure, because we read unconditionally - let must_reset = { - let (protobuf_frame_size, must_reset) = { - let mut parser = - nom::combinator::map_parser::<_, _, _, nom::error::Error<&[u8]>, _, _>( - nom::multi::length_data(crate::util::leb128::nom_leb128_usize), - protobuf::message_decode! { - #[optional] flags = 1 => protobuf::enum_tag_decode, - #[optional] message = 2 => protobuf::bytes_tag_decode, - }, - ); - match parser(&read_write.incoming_buffer) { - Ok((rest, framed_message)) => { - if let Some(message) = framed_message.message { - substream.read_buffer.extend_from_slice(message); - } - - let protobuf_frame_size = read_write.incoming_buffer.len() - rest.len(); - - // If the remote has sent a `FIN` or `RESET_STREAM` flag, mark the remote writing - // side as closed. - if framed_message.flags.map_or(false, |f| f == 0 || f == 2) { - substream.remote_writing_side_closed = true; - } - - // If the remote has sent a `RESET_STREAM` flag, also reset the substream. 
- let must_reset = framed_message.flags.map_or(false, |f| f == 2); - - (protobuf_frame_size, must_reset) - } - Err(nom::Err::Incomplete(needed)) => { - read_write.expected_incoming_bytes = Some( - read_write.incoming_buffer.len() - + match needed { - nom::Needed::Size(s) => s.get(), - nom::Needed::Unknown => 1, - }, - ); - (0, false) - } - Err(_) => { - // Message decoding error. - // TODO: no, must ask the state machine to reset - return SubstreamFate::Reset; - } - } - }; - - let _ = read_write.incoming_bytes_take(protobuf_frame_size); - must_reset - }; - - let event = if must_reset { - substream.inner.take().unwrap().reset() - } else { - let mut sub_read_write = ReadWrite { - now: read_write.now.clone(), - incoming_buffer: mem::take(&mut substream.read_buffer), - read_bytes: 0, - expected_incoming_bytes: if substream.remote_writing_side_closed { - None - } else { - Some(0) - }, - write_buffers: Vec::new(), - write_bytes_queued: read_write.write_bytes_queued, - // Don't write out more than one frame. - // TODO: this `10` is here for the length and protobuf frame size and is a bit hacky - write_bytes_queueable: if !substream.local_writing_side_closed { - Some( - cmp::min(read_write.write_bytes_queueable.unwrap(), 16384) - .saturating_sub(10), - ) - } else { - None - }, - wake_up_after: None, - }; - - let (substream_update, event) = substream - .inner - .take() - .unwrap() - .read_write(&mut sub_read_write); + // Don't process any more data before events are pulled. + if self.pending_events.len() >= MAX_PENDING_EVENTS { + return SubstreamFate::Continue; + } + // Now process the substream. 
+ let event = match substream.framing.read_write(read_write) { + Ok(mut framing) => { + let (substream_update, event) = + substream.inner.take().unwrap().read_write(&mut framing); substream.inner = substream_update; - substream.read_buffer = sub_read_write.incoming_buffer; - if let Some(wake_up_after) = &sub_read_write.wake_up_after { - read_write.wake_up_after(wake_up_after) - } - - // Determine whether we should send a message on that substream with a specific - // flag. - let flag_to_write_out = if substream.inner.is_none() - && (!substream.remote_writing_side_closed - || !substream.local_writing_side_closed) - { - // Send a `RESET_STREAM` if the state machine has reset while a side was still - // open. - Some(2) - } else if !substream.local_writing_side_closed - && sub_read_write.write_bytes_queueable.is_none() - { - // Send a `FIN` if the state machine has closed the writing side while it - // wasn't closed before. - substream.local_writing_side_closed = true; - Some(0) - } else { - None - }; - - // Send out message. - if flag_to_write_out.is_some() - || sub_read_write.write_bytes_queued != read_write.write_bytes_queued - { - let written_bytes = - sub_read_write.write_bytes_queued - read_write.write_bytes_queued; - - // TODO: flags not written - - // TODO: don't do the encoding manually but use the protobuf module? - let tag = protobuf::tag_encode(2, 2).collect::>(); - let data_len = leb128::encode_usize(written_bytes).collect::>(); - let libp2p_prefix = - leb128::encode_usize(tag.len() + data_len.len() + written_bytes) - .collect::>(); - - // The spec mentions that a frame plus its length prefix shouldn't exceed - // 16kiB. This is normally ensured by forbidding the substream from writing - // more data than would fit in 16kiB. 
- debug_assert!( - libp2p_prefix.len() + tag.len() + data_len.len() + written_bytes <= 16384 - ); - - read_write.write_out(libp2p_prefix); - read_write.write_out(tag); - read_write.write_out(data_len); - for buffer in sub_read_write.write_buffers { - read_write.write_out(buffer); - } - - // We continue looping because the substream might have more data to send. - continue_looping = true; - } - event - }; - - match event { - None => {} - Some(other) => { - continue_looping = true; - Self::on_substream_event( - &mut self.pending_events, - substream.id, - &mut substream.user_data, - other, - ) - } } + Err(_) => substream.inner.take().unwrap().reset(), + }; - // WebRTC never closes the writing side. - debug_assert!(read_write.write_bytes_queueable.is_none()); + if let Some(event) = event { + read_write.wake_up_asap(); + Self::on_substream_event( + &mut self.pending_events, + substream.id, + &mut substream.user_data, + event, + ) + } - if substream.inner.is_none() { - if Some(substream_id) == self.ping_substream.as_ref() { - self.ping_substream = None; - } - self.out_in_substreams_map.remove(&substream.id); - self.in_substreams.remove(substream_id); - break SubstreamFate::Reset; - } else if !continue_looping { - break SubstreamFate::Continue; + // The substream is `None` if it needs to be reset. + if substream.inner.is_none() { + if Some(substream_id) == self.ping_substream.as_ref() { + self.ping_substream = None; } + self.out_in_substreams_map.remove(&substream.id); + self.in_substreams.remove(substream_id); + SubstreamFate::Reset + } else { + SubstreamFate::Continue } } @@ -624,9 +451,7 @@ where max_response_size, )), user_data: Some(user_data), - read_buffer: Vec::new(), - local_writing_side_closed: false, - remote_writing_side_closed: false, + framing: webrtc_framing::WebRtcFraming::new(), }); // TODO: ? do this? 
substream.reserve_window(128 * 1024 * 1024 + 128); // TODO: proper max size @@ -687,9 +512,7 @@ where max_handshake_size, )), user_data: Some(user_data), - read_buffer: Vec::new(), - local_writing_side_closed: false, - remote_writing_side_closed: false, + framing: webrtc_framing::WebRtcFraming::new(), }); SubstreamId(SubstreamIdInner::MultiStream(substream_id)) diff --git a/lib/src/libp2p/connection/webrtc_framing.rs b/lib/src/libp2p/connection/webrtc_framing.rs new file mode 100644 index 0000000000..6351f723ae --- /dev/null +++ b/lib/src/libp2p/connection/webrtc_framing.rs @@ -0,0 +1,359 @@ +// Smoldot +// Copyright (C) 2023 Pierre Krieger +// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0 + +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. + +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. + +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +//! +//! See . + +use crate::{ + libp2p::read_write::ReadWrite, + util::{leb128, protobuf}, +}; + +use alloc::{borrow::ToOwned as _, vec::Vec}; +use core::{cmp, fmt, mem, ops}; + +/// State of the framing. +pub struct WebRtcFraming { + /// Value of [`ReadWrite::expected_incoming_bytes`] of the inner stream the last time that + /// [`WebRtcFraming::read_write`] was called. `None` if unknown. + inner_stream_expected_incoming_bytes: Option, + + /// Buffer containing data from a previous frame, but that doesn't contain enough data for + /// the underlying substream to accept it. 
+ /// + /// In other words, `receive_buffer.len() < inner_stream_expected_incoming_bytes`. + // TODO: shrink_to_fit? + receive_buffer: Vec, + + /// State of the writing side of the remote. + remote_write_state: RemoteWriteState, + + /// State of the local writing side. + local_write_state: LocalWriteState, +} + +enum LocalWriteState { + Open, + FinBuffered, + FinAcked, +} + +enum RemoteWriteState { + Open, + /// The remote has sent a `FIN` in the past. Any data in [`WebRtcFraming::receive_buffer`] + /// is still valid and was received before the remote writing side was closed. + Closed, + ClosedAckBuffered, +} + +const RECEIVE_BUFFER_CAPACITY: usize = 2048; +/// Minimum size in bytes of the protobuf frame surrounding the message. +const PROTOBUF_FRAME_MIN_LEN: usize = 2; +/// Maximum size in bytes of the protobuf frame surrounding the message. +const PROTOBUF_FRAME_MAX_LEN: usize = 8; // TODO: calculate better? +const MAX_PROTOBUF_MESSAGE_LEN: usize = 16384; + +impl WebRtcFraming { + /// Initializes a new [`WebRtcFraming`]. + pub fn new() -> Self { + WebRtcFraming { + inner_stream_expected_incoming_bytes: None, + receive_buffer: Vec::with_capacity(RECEIVE_BUFFER_CAPACITY), + remote_write_state: RemoteWriteState::Open, + local_write_state: LocalWriteState::Open, + } + } + + /// Feeds data coming from a socket and outputs data to write to the socket. + /// + /// Returns an object that implements `Deref<Target = ReadWrite>`. This object represents the + /// unframed stream of data. + /// + /// An error is returned if the protocol is being violated by the remote, or if the remote wants + /// to reset the substream. + pub fn read_write<'a, TNow: Clone>( + &'a mut self, + outer_read_write: &'a mut ReadWrite, + ) -> Result, Error> { + // Read from the incoming buffer until we have enough data for the underlying substream. + loop { + // Immediately stop looping if there is enough data for the underlying substream. 
+ // Also stop looping if `inner_stream_expected_incoming_bytes` is `None`, as we always + // want to process the inner substream the first time ever. + if self + .inner_stream_expected_incoming_bytes + .map_or(true, |rq_bytes| rq_bytes <= self.receive_buffer.len()) + { + break; + } + + // Try to parse a frame from the incoming buffer. + let bytes_to_discard = { + // TODO: we could in theory demand from the outside just the protobuf header, and then later the data, which would save some copying but might considerably complicate the code + let mut parser = + nom::combinator::map_parser::<_, _, _, nom::error::Error<&[u8]>, _, _>( + nom::multi::length_data(crate::util::leb128::nom_leb128_usize), + protobuf::message_decode! { + #[optional] flags = 1 => protobuf::enum_tag_decode, + #[optional] message = 2 => protobuf::bytes_tag_decode, + }, + ); + match parser(&outer_read_write.incoming_buffer) { + Ok((rest, framed_message)) => { + // The remote has sent a `RESET_STREAM` flag, immediately stop with an error. + // The specification mentions that the receiver may discard any data already + // received, which we do. + if framed_message.flags.map_or(false, |f| f == 2) { + return Err(Error::RemoteResetDesired); + } + + // Protocol check: the remote must not send data after its `FIN`. + if framed_message.message.map_or(false, |msg| !msg.is_empty()) + && !matches!(self.remote_write_state, RemoteWriteState::Open) + { + return Err(Error::DataAfterFin); + } + + // Process the `FIN_ACK` flag sent by the remote. + // Note that we don't treat it as an error if the remote sends the + // `FIN_ACK` flag multiple times, although this is opinionated. + if framed_message.flags.map_or(false, |f| f == 3) { + if matches!(self.local_write_state, LocalWriteState::Open) { + return Err(Error::FinAckWithoutFin); + } + self.local_write_state = LocalWriteState::FinAcked; + } + + // Process the `FIN` flag sent by the remote. 
+ if matches!(self.remote_write_state, RemoteWriteState::Open) + && framed_message.flags.map_or(false, |f| f == 0) + { + self.remote_write_state = RemoteWriteState::Closed; + } + + // Note that any `STOP_SENDING` flag sent by the remote is ignored. + + // Copy the message of the remote out from the incoming buffer. + if let Some(message) = framed_message.message { + self.receive_buffer.extend_from_slice(message); + } + + // Number of bytes to discard is the size of the protobuf frame. + outer_read_write.incoming_buffer.len() - rest.len() + } + Err(nom::Err::Incomplete(needed)) => { + // Not enough data in the incoming buffer for a full frame. Requesting + // more. + let Some(expected_incoming_bytes) = + &mut outer_read_write.expected_incoming_bytes + else { + // TODO: is this correct anyway? substreams are never supposed to close? + return Err(Error::EofIncompleteFrame); + }; + *expected_incoming_bytes = outer_read_write.incoming_buffer.len() + + match needed { + nom::Needed::Size(s) => s.get(), + nom::Needed::Unknown => 1, + }; + break; + } + Err(_) => { + // Frame decoding error. + return Err(Error::InvalidFrame); + } + } + }; + + // Discard the frame data. + let _extract_result = outer_read_write.incoming_bytes_take(bytes_to_discard); + debug_assert!(matches!(_extract_result, Ok(Some(_)))); + } + + Ok(InnerReadWrite { + inner_read_write: ReadWrite { + now: outer_read_write.now.clone(), + incoming_buffer: mem::take(&mut self.receive_buffer), + read_bytes: 0, + expected_incoming_bytes: if matches!( + self.remote_write_state, + RemoteWriteState::Open + ) { + Some(0) + } else { + None + }, + write_buffers: Vec::new(), + write_bytes_queued: 0, + write_bytes_queueable: if matches!(self.local_write_state, LocalWriteState::Open) { + outer_read_write + .write_bytes_queueable + .map(|outer_writable| { + cmp::min( + // TODO: what if the outer maximum queueable is <= PROTOBUF_FRAME_MAX_LEN? 
this will never happen in practice, but in theory it could + outer_writable.saturating_sub(PROTOBUF_FRAME_MAX_LEN), + MAX_PROTOBUF_MESSAGE_LEN - PROTOBUF_FRAME_MAX_LEN, + ) + }) + } else { + None + }, + wake_up_after: outer_read_write.wake_up_after.clone(), + }, + framing: self, + outer_read_write, + }) + } +} + +impl fmt::Debug for WebRtcFraming { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("WebRtcFraming").finish() + } +} + +/// Stream of data without the frames. See [`WebRtcFraming::read_write`]. +pub struct InnerReadWrite<'a, TNow: Clone> { + framing: &'a mut WebRtcFraming, + outer_read_write: &'a mut ReadWrite, + inner_read_write: ReadWrite, +} + +impl<'a, TNow: Clone> ops::Deref for InnerReadWrite<'a, TNow> { + type Target = ReadWrite; + + fn deref(&self) -> &Self::Target { + &self.inner_read_write + } +} + +impl<'a, TNow: Clone> ops::DerefMut for InnerReadWrite<'a, TNow> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner_read_write + } +} + +impl<'a, TNow: Clone> Drop for InnerReadWrite<'a, TNow> { + fn drop(&mut self) { + // It is possible that the inner stream processes some bytes of `self.receive_buffer` + // and expects to be called again while no bytes were pulled from the outer `ReadWrite`. + // If that happens, the API user will not call `read_write` again and we will have a stall. + // For this reason, if the inner stream has read some bytes, we make sure that the outer + // `ReadWrite` wakes up as soon as possible. + // Additionally, we also do a first dummy substream processing without reading anything, + // in order to populate `inner_stream_expected_incoming_bytes`. If this is the case, we + // also immediately wake up again. + if self.framing.inner_stream_expected_incoming_bytes.is_none() + || self.inner_read_write.read_bytes != 0 + { + self.outer_read_write.wake_up_asap(); + } + + // Updating the timer and reading side of things. 
+ self.outer_read_write.wake_up_after = self.inner_read_write.wake_up_after.clone(); + self.framing.receive_buffer = mem::take(&mut self.inner_read_write.incoming_buffer); + self.framing.inner_stream_expected_incoming_bytes = + Some(self.inner_read_write.expected_incoming_bytes.unwrap_or(0)); + if let Some(expected_incoming_bytes) = &mut self.outer_read_write.expected_incoming_bytes { + *expected_incoming_bytes = cmp::max( + *expected_incoming_bytes, + self.inner_read_write.expected_incoming_bytes.unwrap_or(0) + PROTOBUF_FRAME_MIN_LEN, + ); + } + + // Update the local state, and figure out the flag (if any) that we want to send out to + // the remote. + // Note that we never send the `RESET_STREAM` flag. It is unclear to me what purpose this + // flag serves compared to simply closing the substream. + // We also never send `STOP_SENDING`, as it doesn't fit in our API. + let flag_to_send_out: Option = + if matches!(self.framing.local_write_state, LocalWriteState::Open) + && self.inner_read_write.write_bytes_queueable.is_none() + { + self.framing.local_write_state = LocalWriteState::FinBuffered; + Some(0) + } else if matches!(self.framing.remote_write_state, RemoteWriteState::Closed) { + // `FIN_ACK` + self.framing.remote_write_state = RemoteWriteState::ClosedAckBuffered; + Some(3) + } else { + None + }; + + // Write out a message only if there is anything to write. + // TODO: consider buffering data more before flushing, to reduce the overhead of the protobuf frame? + if flag_to_send_out.is_some() || self.inner_read_write.write_bytes_queued != 0 { + // Reserve some space in `write_buffers` to later write the message length prefix. + let message_length_prefix_index = self.outer_read_write.write_buffers.len(); + self.outer_read_write + .write_buffers + .push(Vec::with_capacity(4)); + + // Total number of bytes written below, excluding the length prefix. + let mut length_prefix_value = 0; + + // Write the flags, if any. 
+ if let Some(flag_to_send_out) = flag_to_send_out { + for buffer in protobuf::uint32_tag_encode(1, flag_to_send_out) { + let buffer = buffer.as_ref(); + length_prefix_value += buffer.len(); + self.outer_read_write.write_buffers.push(buffer.to_owned()); + } + } + + // Write the data. This consists in a protobuf tag, a length, and the data itself. + let data_protobuf_tag = protobuf::tag_encode(2, 2).collect::>(); + length_prefix_value += data_protobuf_tag.len(); + self.outer_read_write.write_buffers.push(data_protobuf_tag); + let data_len = + leb128::encode_usize(self.inner_read_write.write_bytes_queued).collect::>(); + length_prefix_value += data_len.len(); + self.outer_read_write.write_buffers.push(data_len); + length_prefix_value += self.inner_read_write.write_bytes_queued; + self.outer_read_write + .write_buffers + .extend(mem::take(&mut self.inner_read_write.write_buffers)); + + // Now write the length prefix. + let length_prefix = leb128::encode_usize(length_prefix_value).collect::>(); + let total_length = length_prefix_value + length_prefix.len(); + self.outer_read_write.write_buffers[message_length_prefix_index] = length_prefix; + + // Properly update the outer `ReadWrite`. + self.outer_read_write.write_bytes_queued += total_length; + *self + .outer_read_write + .write_bytes_queueable + .as_mut() + .unwrap() -= total_length; + } + } +} + +/// Error while decoding data. +#[derive(Debug, derive_more::Display)] +pub enum Error { + /// The remote wants to reset the substream. This is a normal situation. + RemoteResetDesired, + /// Failed to decode the protobuf header. + InvalidFrame, + /// Remote has sent data after having sent a `FIN` flag in the past. + DataAfterFin, + /// Outer substream has closed in the middle of a frame. + EofIncompleteFrame, + /// Received a `FIN_ACK` flag without having sent a `FIN` flag. + FinAckWithoutFin, +}