Skip to content

Commit

Permalink
Fix a lot of WebRTC-related bugs and panics (#1348)
Browse files Browse the repository at this point in the history
* Fix a lot of WebRTC-related bugs and panics

* PR link
  • Loading branch information
tomaka authored Nov 17, 2023
1 parent 8a31bd4 commit fd3f3d3
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 88 deletions.
44 changes: 27 additions & 17 deletions lib/src/libp2p/collection/multi_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -863,16 +863,15 @@ where
// Try to add data to `handshake_read_buffer`.
// TODO: this is very suboptimal; improve
// TODO: this doesn't properly back-pressure, because we read unconditionally
let (protobuf_frame_size, flags) = {
let mut parser = nom::combinator::complete::<_, _, nom::error::Error<&[u8]>, _>(
nom::combinator::map_parser(
let protobuf_frame_size = {
let mut parser =
nom::combinator::map_parser::<_, _, _, nom::error::Error<&[u8]>, _, _>(
nom::multi::length_data(crate::util::leb128::nom_leb128_usize),
protobuf::message_decode! {
#[optional] flags = 1 => protobuf::enum_tag_decode,
#[optional] message = 2 => protobuf::bytes_tag_decode,
},
),
);
);

match parser(&read_write.incoming_buffer) {
Ok((rest, framed_message)) => {
Expand All @@ -881,7 +880,13 @@ where
}

let protobuf_frame_size = handshake_read_buffer.len() - rest.len();
(protobuf_frame_size, framed_message.flags)
// If the remote has sent a `FIN` or `RESET_STREAM` flag, mark the
// remote writing side as closed.
if framed_message.flags.map_or(false, |f| f == 0 || f == 2) {
// TODO: no, handshake error
return SubstreamFate::Reset;
}
protobuf_frame_size
}
Err(nom::Err::Incomplete(needed)) => {
read_write.expected_incoming_bytes = Some(
Expand All @@ -891,7 +896,7 @@ where
nom::Needed::Unknown => 1,
},
);
return SubstreamFate::Continue;
0
}
Err(_) => {
// Message decoding error.
Expand All @@ -900,16 +905,8 @@ where
}
}
};

let _ = read_write.incoming_bytes_take(protobuf_frame_size);

// If the remote has sent a `FIN` or `RESET_STREAM` flag, mark the
// remote writing side as closed.
if flags.map_or(false, |f| f == 0 || f == 2) {
// TODO: no, handshake error
return SubstreamFate::Reset;
}

let mut sub_read_write = ReadWrite {
now: read_write.now.clone(),
incoming_buffer: mem::take(handshake_read_buffer),
Expand Down Expand Up @@ -937,21 +934,34 @@ where
if sub_read_write.write_bytes_queued != read_write.write_bytes_queued {
let written_bytes =
sub_read_write.write_bytes_queued - read_write.write_bytes_queued;
debug_assert_eq!(
written_bytes,
sub_read_write
.write_buffers
.iter()
.fold(0, |s, b| s + b.len())
);

// TODO: don't do the encoding manually but use the protobuf module?
let tag = protobuf::tag_encode(2, 2).collect::<Vec<_>>();
let data_len = leb128::encode_usize(written_bytes).collect::<Vec<_>>();
let libp2p_prefix =
leb128::encode_usize(tag.len() + data_len.len()).collect::<Vec<_>>();
leb128::encode_usize(tag.len() + data_len.len() + written_bytes)
.collect::<Vec<_>>();

// The spec mentions that a frame plus its length prefix shouldn't exceed
// 16kiB. This is normally ensured by forbidding the substream from writing
// more data than would fit in 16kiB.
debug_assert!(libp2p_prefix.len() + tag.len() + data_len.len() <= 16384);
debug_assert!(
libp2p_prefix.len() + tag.len() + data_len.len() + written_bytes <= 16384
);

read_write.write_out(libp2p_prefix);
read_write.write_out(tag);
read_write.write_out(data_len);
for buffer in sub_read_write.write_buffers {
read_write.write_out(buffer);
}
}

match handshake_outcome {
Expand Down
38 changes: 24 additions & 14 deletions lib/src/libp2p/connection/established/multi_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ where
// TODO: this is very suboptimal; improve
// TODO: this doesn't properly back-pressure, because we read unconditionally
let must_reset = {
let (protobuf_frame_size, flags) = {
let (protobuf_frame_size, must_reset) = {
let mut parser =
nom::combinator::map_parser::<_, _, _, nom::error::Error<&[u8]>, _, _>(
nom::multi::length_data(crate::util::leb128::nom_leb128_usize),
Expand All @@ -351,7 +351,17 @@ where
}

let protobuf_frame_size = read_write.incoming_buffer.len() - rest.len();
(protobuf_frame_size, framed_message.flags)

// If the remote has sent a `FIN` or `RESET_STREAM` flag, mark the remote writing
// side as closed.
if framed_message.flags.map_or(false, |f| f == 0 || f == 2) {
substream.remote_writing_side_closed = true;
}

// If the remote has sent a `RESET_STREAM` flag, also reset the substream.
let must_reset = framed_message.flags.map_or(false, |f| f == 2);

(protobuf_frame_size, must_reset)
}
Err(nom::Err::Incomplete(needed)) => {
read_write.expected_incoming_bytes = Some(
Expand All @@ -361,7 +371,7 @@ where
nom::Needed::Unknown => 1,
},
);
return SubstreamFate::Continue;
(0, false)
}
Err(_) => {
// Message decoding error.
Expand All @@ -372,15 +382,7 @@ where
};

let _ = read_write.incoming_bytes_take(protobuf_frame_size);

// If the remote has sent a `FIN` or `RESET_STREAM` flag, mark the remote writing
// side as closed.
if flags.map_or(false, |f| f == 0 || f == 2) {
substream.remote_writing_side_closed = true;
}

// If the remote has sent a `RESET_STREAM` flag, also reset the substream.
flags.map_or(false, |f| f == 2)
must_reset
};

let event = if must_reset {
Expand Down Expand Up @@ -449,20 +451,28 @@ where
let written_bytes =
sub_read_write.write_bytes_queued - read_write.write_bytes_queued;

// TODO: flags not written

// TODO: don't do the encoding manually but use the protobuf module?
let tag = protobuf::tag_encode(2, 2).collect::<Vec<_>>();
let data_len = leb128::encode_usize(written_bytes).collect::<Vec<_>>();
let libp2p_prefix =
leb128::encode_usize(tag.len() + data_len.len()).collect::<Vec<_>>();
leb128::encode_usize(tag.len() + data_len.len() + written_bytes)
.collect::<Vec<_>>();

// The spec mentions that a frame plus its length prefix shouldn't exceed
// 16kiB. This is normally ensured by forbidding the substream from writing
// more data than would fit in 16kiB.
debug_assert!(libp2p_prefix.len() + tag.len() + data_len.len() <= 16384);
debug_assert!(
libp2p_prefix.len() + tag.len() + data_len.len() + written_bytes <= 16384
);

read_write.write_out(libp2p_prefix);
read_write.write_out(tag);
read_write.write_out(data_len);
for buffer in sub_read_write.write_buffers {
read_write.write_out(buffer);
}

// We continue looping because the substream might have more data to send.
continue_looping = true;
Expand Down
13 changes: 10 additions & 3 deletions light-base/src/network_service/tasks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ pub(super) async fn webrtc_multi_stream_connection_task<TPlat: PlatformRef>(
.desired_outbound_substreams()
.saturating_sub(pending_opening_out_substreams)
{
log::trace!(target: "connections", "Connection({address_string}) <= OpenSubstream");
platform.open_out_substream(&mut connection);
pending_opening_out_substreams += 1;
}
Expand Down Expand Up @@ -255,7 +256,9 @@ pub(super) async fn webrtc_multi_stream_connection_task<TPlat: PlatformRef>(
// must be called. Because we only call `read_write_access` when `message_sending`
// is `None`, we also call `wait_read_write_again` only when `message_sending` is
// `None`.
let fut = if message_sending.as_ref().as_pin_ref().is_none() {
let fut = if message_sending.as_ref().as_pin_ref().is_none()
&& !when_substreams_rw_ready.is_empty()
{
Some(when_substreams_rw_ready.select_next_some())
} else {
None
Expand Down Expand Up @@ -344,6 +347,10 @@ pub(super) async fn webrtc_multi_stream_connection_task<TPlat: PlatformRef>(
);
}

if let SubstreamFate::Reset = substream_fate {
log::trace!(target: "connections", "Connection({address_string}) <= ResetSubstream(substream_id={substream_id})");
}

substream_fate
}
Err(err) => {
Expand Down Expand Up @@ -382,7 +389,7 @@ pub(super) async fn webrtc_multi_stream_connection_task<TPlat: PlatformRef>(
when_substreams_rw_ready.push({
let platform = platform.clone();
Box::pin(async move {
platform.wait_read_write_again(socket.as_mut());
platform.wait_read_write_again(socket.as_mut()).await;
(socket, substream_id)
})
});
Expand All @@ -395,13 +402,13 @@ pub(super) async fn webrtc_multi_stream_connection_task<TPlat: PlatformRef>(
connection_task.reset();
}
WakeUpReason::NewSubstream(substream, direction) => {
log::trace!(target: "connections", "Connection({address_string}) => NewSubstream({direction:?})");
let outbound = match direction {
SubstreamDirection::Outbound => true,
SubstreamDirection::Inbound => false,
};
let substream_id = next_substream_id;
next_substream_id += 1;
log::trace!(target: "connections", "Connection({address_string}) => SubstreamOpened(substream_id={substream_id}, direction={direction:?})");
connection_task.add_substream(substream_id, outbound);
if outbound {
pending_opening_out_substreams -= 1;
Expand Down
4 changes: 4 additions & 0 deletions wasm-node/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## Unreleased

### Fixed

- Fix several WebRTC-related panics and bugs. ([#1348](https://github.com/smol-dot/smoldot/pull/1348))

## 2.0.9 - 2023-11-16

### Changed
Expand Down
6 changes: 3 additions & 3 deletions wasm-node/javascript/src/internals/remote-instance.ts
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ export async function startInstanceServer(config: ServerConfig, initPortToClient
if (!state.connections.has(message.connectionId))
return;
// The stream might have been reset locally in the past.
if (message.streamId && !state.connections.get(message.connectionId)!.has(message.streamId))
if (message.streamId !== undefined && !state.connections.get(message.connectionId)!.has(message.streamId))
return;
state.instance!.streamMessage(message.connectionId, message.message, message.streamId);
break;
Expand All @@ -405,7 +405,7 @@ export async function startInstanceServer(config: ServerConfig, initPortToClient
if (!state.connections.has(message.connectionId))
return;
// The stream might have been reset locally in the past.
if (message.streamId && !state.connections.get(message.connectionId)!.has(message.streamId))
if (message.streamId !== undefined && !state.connections.get(message.connectionId)!.has(message.streamId))
return;
state.instance!.streamWritableBytes(message.connectionId, message.numExtra, message.streamId);
break;
Expand All @@ -415,7 +415,7 @@ export async function startInstanceServer(config: ServerConfig, initPortToClient
if (!state.connections.has(message.connectionId))
return;
// The stream might have been reset locally in the past.
if (message.streamId && !state.connections.get(message.connectionId)!.has(message.streamId))
if (!state.connections.get(message.connectionId)!.has(message.streamId))
return;
state.connections.get(message.connectionId)!.delete(message.streamId);
state.instance!.streamReset(message.connectionId, message.streamId);
Expand Down
Loading

0 comments on commit fd3f3d3

Please sign in to comment.