diff --git a/examples/autobahn_client.rs b/examples/autobahn_client.rs index 8b8a9f0b..6319ad90 100644 --- a/examples/autobahn_client.rs +++ b/examples/autobahn_client.rs @@ -15,7 +15,7 @@ // See https://github.com/crossbario/autobahn-testsuite for details. use futures::io::{BufReader, BufWriter}; -use soketto::{BoxedError, connection, handshake}; +use soketto::{connection, handshake, BoxedError}; use std::str::FromStr; use tokio::net::TcpStream; use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; @@ -25,7 +25,7 @@ const SOKETTO_VERSION: &str = env!("CARGO_PKG_VERSION"); #[tokio::main] async fn main() -> Result<(), BoxedError> { let n = num_of_cases().await?; - for i in 1 ..= n { + for i in 1..=n { if let Err(e) = run_case(i).await { log::error!("case {}: {:?}", i, e) } @@ -37,7 +37,10 @@ async fn main() -> Result<(), BoxedError> { async fn num_of_cases() -> Result { let socket = TcpStream::connect("127.0.0.1:9001").await?; let mut client = new_client(socket, "/getCaseCount"); - assert!(matches!(client.handshake().await?, handshake::ServerResponse::Accepted {..})); + assert!(matches!( + client.handshake().await?, + handshake::ServerResponse::Accepted { .. } + )); let (_, mut receiver) = client.into_builder().finish(); let mut data = Vec::new(); let kind = receiver.receive_data(&mut data).await?; @@ -52,7 +55,10 @@ async fn run_case(n: usize) -> Result<(), BoxedError> { let resource = format!("/runCase?case={}&agent=soketto-{}", n, SOKETTO_VERSION); let socket = TcpStream::connect("127.0.0.1:9001").await?; let mut client = new_client(socket, &resource); - assert!(matches!(client.handshake().await?, handshake::ServerResponse::Accepted {..})); + assert!(matches!( + client.handshake().await?, + handshake::ServerResponse::Accepted { .. } + )); let (mut sender, mut receiver) = client.into_builder().finish(); let mut message = Vec::new(); loop { @@ -69,7 +75,7 @@ async fn run_case(n: usize) -> Result<(), BoxedError> { sender.flush().await? } Err(connection::Error::Closed) => return Ok(()), - Err(e) => return Err(e.into()) + Err(e) => return Err(e.into()), } } } @@ -79,19 +85,35 @@ async fn update_report() -> Result<(), BoxedError> { let resource = format!("/updateReports?agent=soketto-{}", SOKETTO_VERSION); let socket = TcpStream::connect("127.0.0.1:9001").await?; let mut client = new_client(socket, &resource); - assert!(matches!(client.handshake().await?, handshake::ServerResponse::Accepted {..})); + assert!(matches!( + client.handshake().await?, + handshake::ServerResponse::Accepted { .. } + )); client.into_builder().finish().0.close().await?; Ok(()) } #[cfg(not(feature = "deflate"))] -fn new_client(socket: TcpStream, path: &str) -> handshake::Client<'_, BufReader>>> { - handshake::Client::new(BufReader::new(BufWriter::new(socket.compat())), "127.0.0.1:9001", path) +fn new_client( + socket: TcpStream, + path: &str, +) -> handshake::Client<'_, BufReader>>> { + handshake::Client::new( + BufReader::new(BufWriter::new(socket.compat())), + "127.0.0.1:9001", + path, + ) } #[cfg(feature = "deflate")] -fn new_client(socket: TcpStream, path: &str) -> handshake::Client<'_, BufReader>>> { - let socket = BufReader::with_capacity(8 * 1024, BufWriter::with_capacity(64 * 1024, socket.compat())); +fn new_client( + socket: TcpStream, + path: &str, +) -> handshake::Client<'_, BufReader>>> { + let socket = BufReader::with_capacity( + 8 * 1024, + BufWriter::with_capacity(64 * 1024, socket.compat()), + ); let mut client = handshake::Client::new(socket, "127.0.0.1:9001", path); let deflate = soketto::extension::deflate::Deflate::new(soketto::Mode::Client); client.add_extension(Box::new(deflate)); diff --git a/examples/autobahn_server.rs b/examples/autobahn_server.rs index c0bdc8ad..1b37b420 100644 --- a/examples/autobahn_server.rs +++ b/examples/autobahn_server.rs @@ -15,10 +15,10 @@ // See https://github.com/crossbario/autobahn-testsuite for details. use futures::io::{BufReader, BufWriter}; -use soketto::{BoxedError, connection, handshake}; +use soketto::{connection, handshake, BoxedError}; use tokio::net::{TcpListener, TcpStream}; -use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; use tokio_stream::{wrappers::TcpListenerStream, StreamExt}; +use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; #[tokio::main] async fn main() -> Result<(), BoxedError> { let listener = TcpListener::bind("127.0.0.1:9001").await?; @@ -29,7 +29,10 @@ async fn main() -> Result<(), BoxedError> { let req = server.receive_request().await?; req.key() }; - let accept = handshake::server::Response::Accept { key, protocol: None }; + let accept = handshake::server::Response::Accept { + key, + protocol: None, + }; server.send_response(&accept).await?; let (mut sender, mut receiver) = server.into_builder().finish(); let mut message = Vec::new(); @@ -47,13 +50,13 @@ async fn main() -> Result<(), BoxedError> { sender.send_text(txt).await?; sender.flush().await? } else { - break + break; } } Err(connection::Error::Closed) => break, Err(e) => { log::error!("connection error: {}", e); - break + break; } } } @@ -62,13 +65,20 @@ async fn main() -> Result<(), BoxedError> { } #[cfg(not(feature = "deflate"))] -fn new_server<'a>(socket: TcpStream) -> handshake::Server<'a, BufReader>>> { +fn new_server<'a>( + socket: TcpStream, +) -> handshake::Server<'a, BufReader>>> { handshake::Server::new(BufReader::new(BufWriter::new(socket.compat()))) } #[cfg(feature = "deflate")] -fn new_server<'a>(socket: TcpStream) -> handshake::Server<'a, BufReader>>> { - let socket = BufReader::with_capacity(8 * 1024, BufWriter::with_capacity(16 * 1024, socket.compat())); +fn new_server<'a>( + socket: TcpStream, +) -> handshake::Server<'a, BufReader>>> { + let socket = BufReader::with_capacity( + 8 * 1024, + BufWriter::with_capacity(16 * 1024, socket.compat()), + ); let mut server = handshake::Server::new(socket); let deflate = soketto::extension::deflate::Deflate::new(soketto::Mode::Server); server.add_extension(Box::new(deflate)); diff --git a/src/base.rs b/src/base.rs index d5bfcd82..4944cff3 100644 --- a/src/base.rs +++ b/src/base.rs @@ -62,7 +62,7 @@ pub enum OpCode { /// A reserved op code. Reserved14, /// A reserved op code. - Reserved15 + Reserved15, } impl OpCode { @@ -88,7 +88,7 @@ impl OpCode { | OpCode::Reserved13 | OpCode::Reserved14 | OpCode::Reserved15 => true, - _ => false + _ => false, } } } @@ -111,7 +111,7 @@ impl fmt::Display for OpCode { OpCode::Reserved12 => f.write_str("Reserved:12"), OpCode::Reserved13 => f.write_str("Reserved:13"), OpCode::Reserved14 => f.write_str("Reserved:14"), - OpCode::Reserved15 => f.write_str("Reserved:15") + OpCode::Reserved15 => f.write_str("Reserved:15"), } } } @@ -150,7 +150,7 @@ impl TryFrom for OpCode { 13 => Ok(OpCode::Reserved13), 14 => Ok(OpCode::Reserved14), 15 => Ok(OpCode::Reserved15), - _ => Err(UnknownOpCode(())) + _ => Err(UnknownOpCode(())), } } } @@ -173,7 +173,7 @@ impl From for u8 { OpCode::Reserved12 => 12, OpCode::Reserved13 => 13, OpCode::Reserved14 => 14, - OpCode::Reserved15 => 15 + OpCode::Reserved15 => 15, } } } @@ -190,12 +190,14 @@ pub struct Header { masked: bool, opcode: OpCode, mask: u32, - payload_len: usize + payload_len: usize, } impl fmt::Display for Header { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "({} (fin {}) (rsv {}{}{}) (mask ({} {:x})) (len {}))", + write!( + f, + "({} (fin {}) (rsv {}{}{}) (mask ({} {:x})) (len {}))", self.opcode, self.fin as u8, self.rsv1 as u8, @@ -203,7 +205,8 @@ impl fmt::Display for Header { self.rsv3 as u8, self.masked as u8, self.mask, - self.payload_len) + self.payload_len + ) } } @@ -218,7 +221,7 @@ impl Header { masked: false, opcode: oc, mask: 0, - payload_len: 0 + payload_len: 0, } } @@ -331,7 +334,7 @@ pub struct Codec { /// Bits reserved by an extension. reserved_bits: u8, /// Scratch buffer used during header encoding. - header_buffer: [u8; MAX_HEADER_SIZE] + header_buffer: [u8; MAX_HEADER_SIZE], } impl Default for Codec { @@ -339,7 +342,7 @@ impl Default for Codec { Codec { max_data_size: 256 * 1024 * 1024, reserved_bits: 0, - header_buffer: [0; MAX_HEADER_SIZE] + header_buffer: [0; MAX_HEADER_SIZE], } } } @@ -385,7 +388,7 @@ impl Codec { /// Decode a websocket frame header. pub fn decode_header(&self, bytes: &[u8]) -> Result, Error> { if bytes.len() < 2 { - return Ok(Parsing::NeedMore(2 - bytes.len())) + return Ok(Parsing::NeedMore(2 - bytes.len())); } let first = bytes[0]; @@ -396,11 +399,11 @@ impl Codec { let opcode = OpCode::try_from(first & 0xF)?; if opcode.is_reserved() { - return Err(Error::ReservedOpCode) + return Err(Error::ReservedOpCode); } if opcode.is_control() && !fin { - return Err(Error::FragmentedControl) + return Err(Error::FragmentedControl); } let mut header = Header::new(opcode); @@ -408,19 +411,19 @@ impl Codec { let rsv1 = first & 0x40 != 0; if rsv1 && (self.reserved_bits & 4 == 0) { - return Err(Error::InvalidReservedBit(1)) + return Err(Error::InvalidReservedBit(1)); } header.set_rsv1(rsv1); let rsv2 = first & 0x20 != 0; if rsv2 && (self.reserved_bits & 2 == 0) { - return Err(Error::InvalidReservedBit(2)) + return Err(Error::InvalidReservedBit(2)); } header.set_rsv2(rsv2); let rsv3 = first & 0x10 != 0; if rsv3 && (self.reserved_bits & 1 == 0) { - return Err(Error::InvalidReservedBit(3)) + return Err(Error::InvalidReservedBit(3)); } header.set_rsv3(rsv3); header.set_masked(second & 0x80 != 0); @@ -428,7 +431,7 @@ impl Codec { let len: u64 = match second & 0x7F { TWO_EXT => { if bytes.len() < offset + 2 { - return Ok(Parsing::NeedMore(offset + 2 - bytes.len())) + return Ok(Parsing::NeedMore(offset + 2 - bytes.len())); } let len = u16::from_be_bytes([bytes[offset], bytes[offset + 1]]); offset += 2; @@ -436,43 +439,45 @@ impl Codec { } EIGHT_EXT => { if bytes.len() < offset + 8 { - return Ok(Parsing::NeedMore(offset + 8 - bytes.len())) + return Ok(Parsing::NeedMore(offset + 8 - bytes.len())); } let mut b = [0; 8]; - b.copy_from_slice(&bytes[offset .. offset + 8]); + b.copy_from_slice(&bytes[offset..offset + 8]); offset += 8; u64::from_be_bytes(b) } - n => u64::from(n) + n => u64::from(n), }; if len > MAX_CTRL_BODY_SIZE && header.opcode().is_control() { - return Err(Error::InvalidControlFrameLen) + return Err(Error::InvalidControlFrameLen); } - let len: usize = - if len > as_u64(self.max_data_size) { - return Err(Error::PayloadTooLarge { - actual: len, - maximum: as_u64(self.max_data_size) - }) - } else { - len as usize - }; + let len: usize = if len > as_u64(self.max_data_size) { + return Err(Error::PayloadTooLarge { + actual: len, + maximum: as_u64(self.max_data_size), + }); + } else { + len as usize + }; header.set_payload_len(len); if header.is_masked() { if bytes.len() < offset + 4 { - return Ok(Parsing::NeedMore(offset + 4 - bytes.len())) + return Ok(Parsing::NeedMore(offset + 4 - bytes.len())); } let mut b = [0; 4]; - b.copy_from_slice(&bytes[offset .. offset + 4]); + b.copy_from_slice(&bytes[offset..offset + 4]); offset += 4; header.set_mask(u32::from_be_bytes(b)); } - Ok(Parsing::Done { value: header, offset }) + Ok(Parsing::Done { + value: header, + offset, + }) } /// Encode a websocket frame header. @@ -514,22 +519,22 @@ impl Codec { second_byte |= TWO_EXT; self.header_buffer[offset] = second_byte; offset += 1; - self.header_buffer[offset .. offset + 2].copy_from_slice(&(len as u16).to_be_bytes()); + self.header_buffer[offset..offset + 2].copy_from_slice(&(len as u16).to_be_bytes()); offset += 2; } else { second_byte |= EIGHT_EXT; self.header_buffer[offset] = second_byte; offset += 1; - self.header_buffer[offset .. offset + 8].copy_from_slice(&as_u64(len).to_be_bytes()); + self.header_buffer[offset..offset + 8].copy_from_slice(&as_u64(len).to_be_bytes()); offset += 8; } if header.is_masked() { - self.header_buffer[offset .. offset + 4].copy_from_slice(&header.mask().to_be_bytes()); + self.header_buffer[offset..offset + 4].copy_from_slice(&header.mask().to_be_bytes()); offset += 4; } - &self.header_buffer[.. offset] + &self.header_buffer[..offset] } /// Use the given header's mask and apply it to the data. @@ -560,26 +565,23 @@ pub enum Error { /// The reserved bit is invalid. InvalidReservedBit(u8), /// The payload length of a frame exceeded the configured maximum. - PayloadTooLarge { actual: u64, maximum: u64 } + PayloadTooLarge { actual: u64, maximum: u64 }, } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Error::Io(e) => - write!(f, "i/o error: {}", e), - Error::UnknownOpCode => - f.write_str("unknown opcode"), - Error::ReservedOpCode => - f.write_str("reserved opcode"), - Error::FragmentedControl => - f.write_str("fragmented control frame"), - Error::InvalidControlFrameLen => - f.write_str("invalid control frame length"), - Error::InvalidReservedBit(n) => - write!(f, "invalid reserved bit: {}", n), - Error::PayloadTooLarge { actual, maximum } => - write!(f, "payload too large: len = {}, maximum = {}", actual, maximum) + Error::Io(e) => write!(f, "i/o error: {}", e), + Error::UnknownOpCode => f.write_str("unknown opcode"), + Error::ReservedOpCode => f.write_str("reserved opcode"), + Error::FragmentedControl => f.write_str("fragmented control frame"), + Error::InvalidControlFrameLen => f.write_str("invalid control frame length"), + Error::InvalidReservedBit(n) => write!(f, "invalid reserved bit: {}", n), + Error::PayloadTooLarge { actual, maximum } => write!( + f, + "payload too large: len = {}, maximum = {}", + actual, maximum + ), } } } @@ -593,8 +595,7 @@ impl std::error::Error for Error { | Error::FragmentedControl | Error::InvalidControlFrameLen | Error::InvalidReservedBit(_) - | Error::PayloadTooLarge {..} - => None + | Error::PayloadTooLarge { .. } => None, } } } @@ -611,14 +612,13 @@ impl From for Error { } } - // Tests ////////////////////////////////////////////////////////////////////////////////////////// #[cfg(test)] mod test { + use super::{Codec, Error, OpCode}; use crate::Parsing; use quickcheck::QuickCheck; - use super::{OpCode, Codec, Error}; #[test] fn decode_partial_header() { @@ -665,7 +665,7 @@ mod test { #[test] fn decode_invalid_control_payload_len() { // Payload on control frame must be 125 bytes or less. 2nd byte must be 0xFD or less. - let ctrl_payload_len : &[u8] = &[0x89, 0xFE, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; + let ctrl_payload_len: &[u8] = &[0x89, 0xFE, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; assert!(matches! { Codec::new().decode_header(ctrl_payload_len), Err(Error::InvalidControlFrameLen) @@ -742,4 +742,3 @@ mod test { QuickCheck::new().quickcheck(property as fn((bool, bool, bool)) -> bool) } } - diff --git a/src/connection.rs b/src/connection.rs index 1e1110fc..d30a883b 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -9,10 +9,18 @@ //! A persistent websocket connection after the handshake phase, represented //! as a [`Sender`] and [`Receiver`] pair. -use bytes::{Buf, BytesMut}; -use crate::{Storage, Parsing, base::{self, Header, MAX_HEADER_SIZE, OpCode}, extension::Extension}; use crate::data::{ByteSlice125, Data, Incoming}; -use futures::{io::{ReadHalf, WriteHalf}, lock::BiLock, prelude::*}; +use crate::{ + base::{self, Header, OpCode, MAX_HEADER_SIZE}, + extension::Extension, + Parsing, Storage, +}; +use bytes::{Buf, BytesMut}; +use futures::{ + io::{ReadHalf, WriteHalf}, + lock::BiLock, + prelude::*, +}; use std::{fmt, io, str}; /// Accumulated max. size of a complete message. @@ -27,7 +35,7 @@ pub enum Mode { /// Client-side of a connection (implies masking of payload data). Client, /// Server-side of a connection. - Server + Server, } impl Mode { @@ -63,7 +71,7 @@ pub struct Sender { writer: BiLock>, mask_buffer: Vec, extensions: BiLock>>, - has_extensions: bool + has_extensions: bool, } /// The receiving half of a connection. @@ -79,7 +87,7 @@ pub struct Receiver { buffer: BytesMut, ctrl_buffer: BytesMut, max_message_size: usize, - is_closed: bool + is_closed: bool, } /// A connection builder. @@ -95,7 +103,7 @@ pub struct Builder { codec: base::Codec, extensions: Vec>, buffer: BytesMut, - max_message_size: usize + max_message_size: usize, } impl Builder { @@ -117,7 +125,7 @@ impl Builder { codec, extensions: Vec::new(), buffer: BytesMut::new(), - max_message_size: MAX_MESSAGE_SIZE + max_message_size: MAX_MESSAGE_SIZE, } } @@ -131,7 +139,7 @@ impl Builder { /// Only enabled extensions will be considered. pub fn add_extensions(&mut self, extensions: I) where - I: IntoIterator> + I: IntoIterator>, { for e in extensions.into_iter().filter(|e| e.is_enabled()) { log::debug!("{}: using extension: {}", self.id, e.name()); @@ -173,7 +181,7 @@ impl Builder { buffer: self.buffer, ctrl_buffer: BytesMut::new(), max_message_size: self.max_message_size, - is_closed: false + is_closed: false, }; let send = Sender { @@ -183,7 +191,7 @@ impl Builder { mask_buffer: Vec::new(), codec: self.codec, extensions: ext2, - has_extensions + has_extensions, }; (send, recv) @@ -235,7 +243,10 @@ impl Receiver { // Check if total message does not exceed maximum. if length > self.max_message_size { log::warn!("{}: accumulated message length exceeds maximum", self.id); - return Err(Error::MessageTooLarge { current: length, maximum: self.max_message_size }) + return Err(Error::MessageTooLarge { + current: length, + maximum: self.max_message_size, + }); } // Get the frame's payload data bytes from buffer or socket. @@ -261,46 +272,66 @@ impl Receiver { if bytes_to_read > 0 { let n = message.len(); message.resize(n + bytes_to_read, 0u8); - self.reader.read_exact(&mut message[n ..]).await? + self.reader.read_exact(&mut message[n..]).await? } debug_assert_eq!(header.payload_len(), message.len() - old_msg_len); - base::Codec::apply_mask(&header, &mut message[old_msg_len ..]); + base::Codec::apply_mask(&header, &mut message[old_msg_len..]); } match (header.is_fin(), header.opcode()) { - (false, OpCode::Continue) => { // Intermediate message fragment. + (false, OpCode::Continue) => { + // Intermediate message fragment. if first_fragment_opcode.is_none() { - log::debug!("{}: continue frame while not processing message fragments", self.id); - return Err(Error::UnexpectedOpCode(OpCode::Continue)) + log::debug!( + "{}: continue frame while not processing message fragments", + self.id + ); + return Err(Error::UnexpectedOpCode(OpCode::Continue)); } - continue + continue; } - (false, oc) => { // Initial message fragment. + (false, oc) => { + // Initial message fragment. if first_fragment_opcode.is_some() { - log::debug!("{}: initial fragment while processing a fragmented message", self.id); - return Err(Error::UnexpectedOpCode(oc)) + log::debug!( + "{}: initial fragment while processing a fragmented message", + self.id + ); + return Err(Error::UnexpectedOpCode(oc)); } first_fragment_opcode = Some(oc); self.decode_with_extensions(&mut header, message).await?; - continue + continue; } - (true, OpCode::Continue) => { // Last message fragment. + (true, OpCode::Continue) => { + // Last message fragment. if let Some(oc) = first_fragment_opcode.take() { header.set_payload_len(message.len()); - log::trace!("{}: last fragment: total length = {} bytes", self.id, message.len()); + log::trace!( + "{}: last fragment: total length = {} bytes", + self.id, + message.len() + ); self.decode_with_extensions(&mut header, message).await?; header.set_opcode(oc); } else { - log::debug!("{}: last continue frame while not processing message fragments", self.id); - return Err(Error::UnexpectedOpCode(OpCode::Continue)) + log::debug!( + "{}: last continue frame while not processing message fragments", + self.id + ); + return Err(Error::UnexpectedOpCode(OpCode::Continue)); } } - (true, oc) => { // Regular non-fragmented message. + (true, oc) => { + // Regular non-fragmented message. if first_fragment_opcode.is_some() { - log::debug!("{}: regular message while processing fragmented message", self.id); - return Err(Error::UnexpectedOpCode(oc)) + log::debug!( + "{}: regular message while processing fragmented message", + self.id + ); + return Err(Error::UnexpectedOpCode(oc)); } self.decode_with_extensions(&mut header, message).await? } @@ -309,9 +340,9 @@ impl Receiver { let num_bytes = message.len() - message_len; if header.opcode() == OpCode::Text { - return Ok(Incoming::Data(Data::Text(num_bytes))) + return Ok(Incoming::Data(Data::Text(num_bytes))); } else { - return Ok(Incoming::Data(Data::Binary(num_bytes))) + return Ok(Incoming::Data(Data::Binary(num_bytes))); } } } @@ -320,7 +351,7 @@ impl Receiver { pub async fn receive_data(&mut self, message: &mut Vec) -> Result { loop { if let Incoming::Data(d) = self.receive(message).await? { - return Ok(d) + return Ok(d); } } } @@ -329,14 +360,15 @@ impl Receiver { async fn receive_header(&mut self) -> Result { loop { match self.codec.decode_header(&self.buffer)? { - Parsing::Done { value: header, offset } => { + Parsing::Done { + value: header, + offset, + } => { debug_assert!(offset <= MAX_HEADER_SIZE); self.buffer.advance(offset); - return Ok(header) - } - Parsing::NeedMore(n) => { - crate::read(&mut self.reader, &mut self.buffer, n).await? + return Ok(header); } + Parsing::NeedMore(n) => crate::read(&mut self.reader, &mut self.buffer, n).await?, } } } @@ -344,12 +376,12 @@ impl Receiver { /// Read the complete payload data into the read buffer. async fn read_buffer(&mut self, header: &Header) -> Result<(), Error> { if header.payload_len() <= self.buffer.len() { - return Ok(()) + return Ok(()); } let i = self.buffer.len(); let d = header.payload_len() - i; self.buffer.resize(i + d, 0u8); - self.reader.read_exact(&mut self.buffer[i ..]).await?; + self.reader.read_exact(&mut self.buffer[i..]).await?; Ok(()) } @@ -364,7 +396,16 @@ impl Receiver { let mut answer = Header::new(OpCode::Pong); let mut unused = Vec::new(); let mut data = Storage::Unique(&mut self.ctrl_buffer); - write(self.id, self.mode, &mut self.codec, &mut self.writer, &mut answer, &mut data, &mut unused).await?; + write( + self.id, + self.mode, + &mut self.codec, + &mut self.writer, + &mut answer, + &mut data, + &mut unused, + ) + .await?; self.flush().await?; Ok(None) } @@ -378,10 +419,28 @@ impl Receiver { if let Some(CloseReason { code, .. }) = reason { let mut data = code.to_be_bytes(); let mut data = Storage::Unique(&mut data); - let _ = write(self.id, self.mode, &mut self.codec, &mut self.writer, &mut header, &mut data, &mut unused).await; + let _ = write( + self.id, + self.mode, + &mut self.codec, + &mut self.writer, + &mut header, + &mut data, + &mut unused, + ) + .await; } else { let mut data = Storage::Unique(&mut []); - let _ = write(self.id, self.mode, &mut self.codec, &mut self.writer, &mut header, &mut data, &mut unused).await; + let _ = write( + self.id, + self.mode, + &mut self.codec, + &mut self.writer, + &mut header, + &mut data, + &mut unused, + ) + .await; } self.flush().await?; self.writer.lock().await.close().await?; @@ -399,14 +458,18 @@ impl Receiver { | OpCode::Reserved12 | OpCode::Reserved13 | OpCode::Reserved14 - | OpCode::Reserved15 => Err(Error::UnexpectedOpCode(header.opcode())) + | OpCode::Reserved15 => Err(Error::UnexpectedOpCode(header.opcode())), } } /// Apply all extensions to the given header and the internal message buffer. - async fn decode_with_extensions(&mut self, header: &mut Header, message: &mut Vec) -> Result<(), Error> { + async fn decode_with_extensions( + &mut self, + header: &mut Header, + message: &mut Vec, + ) -> Result<(), Error> { if !self.has_extensions { - return Ok(()) + return Ok(()); } for e in self.extensions.lock().await.iter_mut() { log::trace!("{}: decoding with extension: {}", self.id, e.name()); @@ -419,9 +482,14 @@ impl Receiver { async fn flush(&mut self) -> Result<(), Error> { log::trace!("{}: Receiver flushing connection", self.id); if self.is_closed { - return Ok(()) + return Ok(()); } - self.writer.lock().await.flush().await.or(Err(Error::Closed)) + self.writer + .lock() + .await + .flush() + .await + .or(Err(Error::Closed)) } } @@ -429,7 +497,8 @@ impl Sender { /// Send a text value over the websocket connection. pub async fn send_text(&mut self, data: impl AsRef) -> Result<(), Error> { let mut header = Header::new(OpCode::Text); - self.send_frame(&mut header, &mut Storage::Shared(data.as_ref().as_bytes())).await + self.send_frame(&mut header, &mut Storage::Shared(data.as_ref().as_bytes())) + .await } /// Send a text value over the websocket connection. @@ -437,13 +506,15 @@ impl Sender { /// This method performs one copy fewer than [`Sender::send_text`]. pub async fn send_text_owned(&mut self, data: String) -> Result<(), Error> { let mut header = Header::new(OpCode::Text); - self.send_frame(&mut header, &mut Storage::Owned(data.into_bytes())).await + self.send_frame(&mut header, &mut Storage::Owned(data.into_bytes())) + .await } /// Send some binary data over the websocket connection. pub async fn send_binary(&mut self, data: impl AsRef<[u8]>) -> Result<(), Error> { let mut header = Header::new(OpCode::Binary); - self.send_frame(&mut header, &mut Storage::Shared(data.as_ref())).await + self.send_frame(&mut header, &mut Storage::Shared(data.as_ref())) + .await } /// Send some binary data over the websocket connection. @@ -452,25 +523,33 @@ impl Sender { /// The `data` buffer may be modified by this method, e.g. if masking is necessary. pub async fn send_binary_mut(&mut self, mut data: impl AsMut<[u8]>) -> Result<(), Error> { let mut header = Header::new(OpCode::Binary); - self.send_frame(&mut header, &mut Storage::Unique(data.as_mut())).await + self.send_frame(&mut header, &mut Storage::Unique(data.as_mut())) + .await } /// Ping the remote end. pub async fn send_ping(&mut self, data: ByteSlice125<'_>) -> Result<(), Error> { let mut header = Header::new(OpCode::Ping); - self.write(&mut header, &mut Storage::Shared(data.as_ref())).await + self.write(&mut header, &mut Storage::Shared(data.as_ref())) + .await } /// Send an unsolicited Pong to the remote. pub async fn send_pong(&mut self, data: ByteSlice125<'_>) -> Result<(), Error> { let mut header = Header::new(OpCode::Pong); - self.write(&mut header, &mut Storage::Shared(data.as_ref())).await + self.write(&mut header, &mut Storage::Shared(data.as_ref())) + .await } /// Flush the socket buffer. pub async fn flush(&mut self) -> Result<(), Error> { log::trace!("{}: Sender flushing connection", self.id); - self.writer.lock().await.flush().await.or(Err(Error::Closed)) + self.writer + .lock() + .await + .flush() + .await + .or(Err(Error::Closed)) } /// Send a close message and close the connection. @@ -478,17 +557,27 @@ impl Sender { log::trace!("{}: closing connection", self.id); let mut header = Header::new(OpCode::Close); let code = 1000_u16.to_be_bytes(); // 1000 = normal closure - self.write(&mut header, &mut Storage::Shared(&code[..])).await?; + self.write(&mut header, &mut Storage::Shared(&code[..])) + .await?; self.flush().await?; - self.writer.lock().await.close().await.or(Err(Error::Closed)) + self.writer + .lock() + .await + .close() + .await + .or(Err(Error::Closed)) } /// Send arbitrary websocket frames. /// /// Before sending, extensions will be applied to header and payload data. - async fn send_frame(&mut self, header: &mut Header, data: &mut Storage<'_>) -> Result<(), Error> { + async fn send_frame( + &mut self, + header: &mut Header, + data: &mut Storage<'_>, + ) -> Result<(), Error> { if !self.has_extensions { - return self.write(header, data).await + return self.write(header, data).await; } for e in self.extensions.lock().await.iter_mut() { @@ -504,21 +593,29 @@ impl Sender { /// The data will be masked if necessary. /// No extensions will be applied to header and payload data. async fn write(&mut self, header: &mut Header, data: &mut Storage<'_>) -> Result<(), Error> { - write(self.id, self.mode, &mut self.codec, &mut self.writer, header, data, &mut self.mask_buffer).await + write( + self.id, + self.mode, + &mut self.codec, + &mut self.writer, + header, + data, + &mut self.mask_buffer, + ) + .await } } /// Write header and payload data to socket. -async fn write - ( id: Id - , mode: Mode - , codec: &mut base::Codec - , writer: &mut BiLock> - , header: &mut Header - , data: &mut Storage<'_> - , mask_buffer: &mut Vec - ) -> Result<(), Error> -{ +async fn write( + id: Id, + mode: Mode, + codec: &mut base::Codec, + writer: &mut BiLock>, + header: &mut Header, + data: &mut Storage<'_>, + mask_buffer: &mut Vec, +) -> Result<(), Error> { if mode.is_client() { header.set_masked(true); header.set_mask(rand::random()); @@ -532,7 +629,7 @@ async fn write w.write_all(&header_bytes).await.or(Err(Error::Closed))?; if !header.is_masked() { - return w.write_all(data.as_ref()).await.or(Err(Error::Closed)) + return w.write_all(data.as_ref()).await.or(Err(Error::Closed)); } match data { @@ -563,7 +660,10 @@ fn close_answer(data: &[u8]) -> Result<(Header, Option), Error> { // Check that the reason string is properly encoded let descr = std::str::from_utf8(&data[2..])?.into(); let code = u16::from_be_bytes([data[0], data[1]]); - let reason = CloseReason { code, descr: Some(descr) }; + let reason = CloseReason { + code, + descr: Some(descr), + }; // Status codes are defined in // https://tools.ietf.org/html/rfc6455#section-7.4.1 and @@ -612,19 +712,17 @@ pub struct CloseReason { impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Error::Io(e) => - write!(f, "i/o error: {}", e), - Error::Codec(e) => - write!(f, "codec error: {}", e), - Error::Extension(e) => - write!(f, "extension error: {}", e), - Error::UnexpectedOpCode(c) => - write!(f, "unexpected opcode: {}", c), - Error::Utf8(e) => - write!(f, "utf-8 error: {}", e), - Error::MessageTooLarge { current, maximum } => - write!(f, "message too large: len >= {}, maximum = {}", current, maximum), - Error::Closed => f.write_str("connection closed") + Error::Io(e) => write!(f, "i/o error: {}", e), + Error::Codec(e) => write!(f, "codec error: {}", e), + Error::Extension(e) => write!(f, "extension error: {}", e), + Error::UnexpectedOpCode(c) => write!(f, "unexpected opcode: {}", c), + Error::Utf8(e) => write!(f, "utf-8 error: {}", e), + Error::MessageTooLarge { current, maximum } => write!( + f, + "message too large: len >= {}, maximum = {}", + current, maximum + ), + Error::Closed => f.write_str("connection closed"), } } } @@ -636,10 +734,7 @@ impl std::error::Error for Error { Error::Codec(e) => Some(e), Error::Extension(e) => Some(&**e), Error::Utf8(e) => Some(e), - Error::UnexpectedOpCode(_) - | Error::MessageTooLarge {..} - | Error::Closed - => None + Error::UnexpectedOpCode(_) | Error::MessageTooLarge { .. } | Error::Closed => None, } } } diff --git a/src/data.rs b/src/data.rs index 827dac27..de6d6679 100644 --- a/src/data.rs +++ b/src/data.rs @@ -26,12 +26,20 @@ pub enum Incoming<'a> { impl Incoming<'_> { /// Is this text or binary data? pub fn is_data(&self) -> bool { - if let Incoming::Data(_) = self { true } else { false } + if let Incoming::Data(_) = self { + true + } else { + false + } } /// Is this a PONG? pub fn is_pong(&self) -> bool { - if let Incoming::Pong(_) = self { true } else { false } + if let Incoming::Pong(_) = self { + true + } else { + false + } } /// Is this text data? @@ -58,25 +66,33 @@ pub enum Data { /// Textual data (number of bytes). Text(usize), /// Binary data (number of bytes). - Binary(usize) + Binary(usize), } impl Data { /// Is this text data? pub fn is_text(&self) -> bool { - if let Data::Text(_) = self { true } else { false } + if let Data::Text(_) = self { + true + } else { + false + } } /// Is this binary data? pub fn is_binary(&self) -> bool { - if let Data::Binary(_) = self { true } else { false } + if let Data::Binary(_) = self { + true + } else { + false + } } /// The length of data (number of bytes). pub fn len(&self) -> usize { match self { Data::Text(n) => *n, - Data::Binary(n) => *n + Data::Binary(n) => *n, } } } diff --git a/src/extension.rs b/src/extension.rs index 47428a73..cdadcdea 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -14,7 +14,7 @@ #[cfg(feature = "deflate")] pub mod deflate; -use crate::{BoxedError, Storage, base::Header}; +use crate::{base::Header, BoxedError, Storage}; use std::{borrow::Cow, fmt}; /// A websocket extension as per RFC 6455, section 9. @@ -108,7 +108,7 @@ impl Extension for Box { #[derive(Debug, Clone, PartialEq, Eq)] pub struct Param<'a> { name: Cow<'a, str>, - value: Option> + value: Option>, } impl<'a> fmt::Display for Param<'a> { @@ -123,10 +123,10 @@ impl<'a> fmt::Display for Param<'a> { impl<'a> Param<'a> { /// Create a new parameter with the given name. - pub fn new(name: impl Into>) -> Self{ + pub fn new(name: impl Into>) -> Self { Param { name: name.into(), - value: None + value: None, } } @@ -150,8 +150,7 @@ impl<'a> Param<'a> { pub fn acquire(self) -> Param<'static> { Param { name: Cow::Owned(self.name.into_owned()), - value: self.value.map(|v| Cow::Owned(v.into_owned())) + value: self.value.map(|v| Cow::Owned(v.into_owned())), } } } - diff --git a/src/extension/deflate.rs b/src/extension/deflate.rs index d3f4aad5..43e80f08 100644 --- a/src/extension/deflate.rs +++ b/src/extension/deflate.rs @@ -12,14 +12,17 @@ use crate::{ as_u64, - BoxedError, - Storage, base::{Header, OpCode}, connection::Mode, - extension::{Extension, Param} + extension::{Extension, Param}, + BoxedError, Storage, +}; +use flate2::{write::DeflateDecoder, Compress, Compression, FlushCompress, Status}; +use std::{ + convert::TryInto, + io::{self, Write}, + mem, }; -use flate2::{Compress, Compression, FlushCompress, Status, write::DeflateDecoder}; -use std::{convert::TryInto, io::{self, Write}, mem}; const SERVER_NO_CONTEXT_TAKEOVER: &str = "server_no_context_takeover"; const SERVER_MAX_WINDOW_BITS: &str = "server_max_window_bits"; @@ -39,7 +42,7 @@ pub struct Deflate { params: Vec>, our_max_window_bits: u8, their_max_window_bits: u8, - await_last_fragment: bool + await_last_fragment: bool, } impl Deflate { @@ -62,7 +65,7 @@ impl Deflate { params, our_max_window_bits: 15, their_max_window_bits: 15, - await_last_fragment: false + await_last_fragment: false, } } @@ -76,8 +79,14 @@ impl Deflate { /// by including the "server_max_window_bits" extension parameter in the /// response with the same or smaller value as the offer. pub fn set_max_server_window_bits(&mut self, max: u8) { - assert!(self.mode == Mode::Client, "setting max. server window bits requires client mode"); - assert!(max > 8 && max <= 15, "max. server window bits have to be within 9 ..= 15"); + assert!( + self.mode == Mode::Client, + "setting max. server window bits requires client mode" + ); + assert!( + max > 8 && max <= 15, + "max. server window bits have to be within 9 ..= 15" + ); self.their_max_window_bits = max; // upper bound of the server's window let mut p = Param::new(SERVER_MAX_WINDOW_BITS); p.set_value(Some(max.to_string())); @@ -97,10 +106,20 @@ impl Deflate { /// The server may also respond with a smaller value which allows the client /// to reduce its sliding window even more. pub fn set_max_client_window_bits(&mut self, max: u8) { - assert!(self.mode == Mode::Client, "setting max. client window bits requires client mode"); - assert!(max > 8 && max <= 15, "max. client window bits have to be within 9 ..= 15"); + assert!( + self.mode == Mode::Client, + "setting max. client window bits requires client mode" + ); + assert!( + max > 8 && max <= 15, + "max. client window bits have to be within 9 ..= 15" + ); self.our_max_window_bits = max; // upper bound of the client's window - if let Some(p) = self.params.iter_mut().find(|p| p.name() == CLIENT_MAX_WINDOW_BITS) { + if let Some(p) = self + .params + .iter_mut() + .find(|p| p.name() == CLIENT_MAX_WINDOW_BITS) + { p.set_value(Some(max.to_string())); } else { let mut p = Param::new(CLIENT_MAX_WINDOW_BITS); @@ -113,12 +132,12 @@ impl Deflate { if let Some(Ok(v)) = p.value().map(|s| s.parse::()) { if v < 8 || v > 15 { log::debug!("invalid {}: {} (expected range: 8 ..= 15)", p.name(), v); - return Err(()) + return Err(()); } if let Some(x) = expected { if v > x { log::debug!("invalid {}: {} (expected: {} <= {})", p.name(), v, v, x); - return Err(()) + return Err(()); } } self.their_max_window_bits = std::cmp::max(9, v); @@ -147,18 +166,19 @@ impl Extension for Deflate { for p in params { log::trace!("configure server with: {}", p); match p.name() { - CLIENT_MAX_WINDOW_BITS => + CLIENT_MAX_WINDOW_BITS => { if self.set_their_max_window_bits(&p, None).is_err() { // we just accept the client's offer as is => no need to reply - return Ok(()) + return Ok(()); } + } SERVER_MAX_WINDOW_BITS => { if let Some(Ok(v)) = p.value().map(|s| s.parse::()) { // The RFC allows 8 to 15 bits, but due to zlib limitations we // only support 9 to 15. if v < 9 || v > 15 { log::debug!("unacceptable server_max_window_bits: {}", v); - return Ok(()) + return Ok(()); } let mut x = Param::new(SERVER_MAX_WINDOW_BITS); x.set_value(Some(v.to_string())); @@ -166,16 +186,18 @@ impl Extension for Deflate { self.our_max_window_bits = v; } else { log::debug!("invalid server_max_window_bits: {:?}", p.value()); - return Ok(()) + return Ok(()); } } - CLIENT_NO_CONTEXT_TAKEOVER => - self.params.push(Param::new(CLIENT_NO_CONTEXT_TAKEOVER)), - SERVER_NO_CONTEXT_TAKEOVER => - self.params.push(Param::new(SERVER_NO_CONTEXT_TAKEOVER)), + CLIENT_NO_CONTEXT_TAKEOVER => { + self.params.push(Param::new(CLIENT_NO_CONTEXT_TAKEOVER)) + } + SERVER_NO_CONTEXT_TAKEOVER => { + self.params.push(Param::new(SERVER_NO_CONTEXT_TAKEOVER)) + } _ => { log::debug!("{}: unknown parameter: {}", self.name(), p.name()); - return Ok(()) + return Ok(()); } } } @@ -190,29 +212,33 @@ impl Extension for Deflate { SERVER_MAX_WINDOW_BITS => { let expected = Some(self.their_max_window_bits); if self.set_their_max_window_bits(&p, expected).is_err() { - return Ok(()) + return Ok(()); } } - CLIENT_MAX_WINDOW_BITS => + CLIENT_MAX_WINDOW_BITS => { if let Some(Ok(v)) = p.value().map(|s| s.parse::()) { if v < 8 || v > 15 { log::debug!("unacceptable client_max_window_bits: {}", v); - return Ok(()) + return Ok(()); } - use std::cmp::{min, max}; + use std::cmp::{max, min}; // Due to zlib limitations we have to use 9 as a lower bound // here, even if the server allowed us to go down to 8 bits. self.our_max_window_bits = min(self.our_max_window_bits, max(9, v)); } + } _ => { log::debug!("{}: unknown parameter: {}", self.name(), p.name()); - return Ok(()) + return Ok(()); } } } if !server_no_context_takeover { - log::debug!("{}: server did not confirm no context takeover", self.name()); - return Ok(()) + log::debug!( + "{}: server did not confirm no context takeover", + self.name() + ); + return Ok(()); } } } @@ -226,7 +252,7 @@ impl Extension for Deflate { fn decode(&mut self, header: &mut Header, data: &mut Vec) -> Result<(), BoxedError> { if data.is_empty() { - return Ok(()) + return Ok(()); } match header.opcode() { @@ -234,7 +260,7 @@ impl Extension for Deflate { if !header.is_fin() { self.await_last_fragment = true; log::trace!("deflate: not decoding {}; awaiting last fragment", header); - return Ok(()) + return Ok(()); } log::trace!("deflate: decoding {}", header) } @@ -244,7 +270,7 @@ impl Extension for Deflate { } _ => { log::trace!("deflate: not decoding {}", header); - return Ok(()) + return Ok(()); } } @@ -265,14 +291,14 @@ impl Extension for Deflate { fn encode(&mut self, header: &mut Header, data: &mut Storage) -> Result<(), BoxedError> { if data.as_ref().is_empty() { - return Ok(()) + return Ok(()); } if let OpCode::Binary | OpCode::Text = header.opcode() { log::trace!("deflate: encoding {}", header) } else { log::trace!("deflate: not encoding {}", header); - return Ok(()) + return Ok(()); } self.buffer.clear(); @@ -284,10 +310,14 @@ impl Extension for Deflate { // Compress all input bytes. while encoder.total_in() < as_u64(data.as_ref().len()) { let i: usize = encoder.total_in().try_into()?; - match encoder.compress_vec(&data.as_ref()[i ..], &mut self.buffer, FlushCompress::None)? { + match encoder.compress_vec( + &data.as_ref()[i..], + &mut self.buffer, + FlushCompress::None, + )? { Status::BufError => self.buffer.reserve(4096), Status::Ok => continue, - Status::StreamEnd => break + Status::StreamEnd => break, } } @@ -297,14 +327,14 @@ impl Extension for Deflate { match encoder.compress_vec(&[], &mut self.buffer, FlushCompress::Sync)? { Status::Ok => continue, Status::BufError => continue, // more capacity is reserved above - Status::StreamEnd => break + Status::StreamEnd => break, } } // If we still have not seen the empty deflate block appended, something is wrong. if !self.buffer.ends_with(&[0, 0, 0xFF, 0xFF]) { log::error!("missing 00 00 FF FF"); - return Err(io::Error::new(io::ErrorKind::Other, "missing 00 00 FF FF").into()) + return Err(io::Error::new(io::ErrorKind::Other, "missing 00 00 FF FF").into()); } self.buffer.truncate(self.buffer.len() - 4); // Remove 00 00 FF FF; cf. RFC 7692, 7.2.1 @@ -319,4 +349,3 @@ impl Extension for Deflate { Ok(()) } } - diff --git a/src/handshake.rs b/src/handshake.rs index 761654c4..9b6b15f2 100644 --- a/src/handshake.rs +++ b/src/handshake.rs @@ -13,12 +13,12 @@ pub mod client; pub mod server; +use crate::extension::{Extension, Param}; use bytes::BytesMut; -use crate::extension::{Param, Extension}; use std::{fmt, io, str}; pub use client::{Client, ServerResponse}; -pub use server::{Server, ClientRequest}; +pub use server::{ClientRequest, Server}; // Defined in RFC 6455 and used to generate the `Sec-WebSocket-Accept` header // in the server handshake response. @@ -34,38 +34,41 @@ const SEC_WEBSOCKET_PROTOCOL: &str = "Sec-WebSocket-Protocol"; /// Check a set of headers contains a specific one. fn expect_ascii_header(headers: &[httparse::Header], name: &str, ours: &str) -> Result<(), Error> { enum State { - Init, // Start state - Name, // Header name found - Match // Header value matches + Init, // Start state + Name, // Header name found + Match, // Header value matches } - headers.iter() + headers + .iter() .filter(|h| h.name.eq_ignore_ascii_case(name)) .fold(Ok(State::Init), |result, header| { if let Ok(State::Match) = result { - return result + return result; } if str::from_utf8(header.value)? .split(',') .any(|v| v.trim().eq_ignore_ascii_case(ours)) { - return Ok(State::Match) + return Ok(State::Match); } Ok(State::Name) }) - .and_then(|state| { - match state { - State::Init => Err(Error::HeaderNotFound(name.into())), - State::Name => Err(Error::UnexpectedHeader(name.into())), - State::Match => Ok(()) - } + .and_then(|state| match state { + State::Init => Err(Error::HeaderNotFound(name.into())), + State::Name => Err(Error::UnexpectedHeader(name.into())), + State::Match => Ok(()), }) } /// Pick the first header with the given name and apply the given closure to it. -fn with_first_header<'a, F, R>(headers: &[httparse::Header<'a>], name: &str, f: F) -> Result +fn with_first_header<'a, F, R>( + headers: &[httparse::Header<'a>], + name: &str, + f: F, +) -> Result where - F: Fn(&'a [u8]) -> Result + F: Fn(&'a [u8]) -> Result, { if let Some(h) = headers.iter().find(|h| h.name.eq_ignore_ascii_case(name)) { f(h.value) @@ -75,12 +78,18 @@ where } // Configure all extensions with parsed parameters. -fn configure_extensions(extensions: &mut [Box], line: &str) -> Result<(), Error> { +fn configure_extensions( + extensions: &mut [Box], + line: &str, +) -> Result<(), Error> { for e in line.split(',') { let mut ext_parts = e.split(';'); if let Some(name) = ext_parts.next() { let name = name.trim(); - if let Some(ext) = extensions.iter_mut().find(|x| x.name().eq_ignore_ascii_case(name)) { + if let Some(ext) = extensions + .iter_mut() + .find(|x| x.name().eq_ignore_ascii_case(name)) + { let mut params = Vec::new(); for p in ext_parts { let mut key_value = p.split('='); @@ -101,7 +110,7 @@ fn configure_extensions(extensions: &mut [Box], line: &str // Write all extensions to the given buffer. fn append_extensions<'a, I>(extensions: I, bytes: &mut BytesMut) where - I: IntoIterator> + I: IntoIterator>, { let mut iter = extensions.into_iter().peekable(); @@ -154,38 +163,29 @@ pub enum Error { /// The HTTP entity could not be parsed successfully. Http(crate::BoxedError), /// UTF-8 decoding failed. - Utf8(str::Utf8Error) + Utf8(str::Utf8Error), } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Error::Io(e) => - write!(f, "i/o error: {}", e), - Error::UnsupportedHttpVersion => - f.write_str("http version was not 1.1"), - Error::IncompleteHttpRequest => - f.write_str("http request was incomplete"), - Error::SecWebSocketKeyInvalidLength(len) => - write!(f, "Sec-WebSocket-Key header was {} bytes long, expected 24", len), - Error::InvalidRequestMethod => - f.write_str("handshake was not a GET request"), - Error::HeaderNotFound(name) => - write!(f, "header {} not found", name), - Error::UnexpectedHeader(name) => - write!(f, "header {} had an unexpected value", name), - Error::InvalidSecWebSocketAccept => - f.write_str("websocket key mismatch"), - Error::UnsolicitedExtension => - f.write_str("unsolicited extension returned"), - Error::UnsolicitedProtocol => - f.write_str("unsolicited protocol returned"), - Error::Extension(e) => - write!(f, "extension error: {}", e), - Error::Http(e) => - write!(f, "http parser error: {}", e), - Error::Utf8(e) => - write!(f, "utf-8 decoding error: {}", e) + Error::Io(e) => write!(f, "i/o error: {}", e), + Error::UnsupportedHttpVersion => f.write_str("http version was not 1.1"), + Error::IncompleteHttpRequest => f.write_str("http request was incomplete"), + Error::SecWebSocketKeyInvalidLength(len) => write!( + f, + "Sec-WebSocket-Key header was {} bytes long, expected 24", + len + ), + Error::InvalidRequestMethod => f.write_str("handshake was not a GET request"), + Error::HeaderNotFound(name) => write!(f, "header {} not found", name), + Error::UnexpectedHeader(name) => write!(f, "header {} had an unexpected value", name), + Error::InvalidSecWebSocketAccept => f.write_str("websocket key mismatch"), + Error::UnsolicitedExtension => f.write_str("unsolicited extension returned"), + Error::UnsolicitedProtocol => f.write_str("unsolicited protocol returned"), + Error::Extension(e) => write!(f, "extension error: {}", e), + Error::Http(e) => write!(f, "http parser error: {}", e), + Error::Utf8(e) => write!(f, "utf-8 decoding error: {}", e), } } } @@ -205,8 +205,7 @@ impl std::error::Error for Error { | Error::UnexpectedHeader(_) | Error::InvalidSecWebSocketAccept | Error::UnsolicitedExtension - | Error::UnsolicitedProtocol - => None + | Error::UnsolicitedProtocol => None, } } } @@ -243,12 +242,30 @@ mod tests { #[test] fn header_match() { let headers = &[ - httparse::Header { name: "foo", value: b"a,b,c,d" }, - httparse::Header { name: "foo", value: b"x" }, - httparse::Header { name: "foo", value: b"y, z, a" }, - httparse::Header { name: "bar", value: b"xxx" }, - httparse::Header { name: "bar", value: b"sdfsdf 423 42 424" }, - httparse::Header { name: "baz", value: b"123" } + httparse::Header { + name: "foo", + value: b"a,b,c,d", + }, + httparse::Header { + name: "foo", + value: b"x", + }, + httparse::Header { + name: "foo", + value: b"y, z, a", + }, + httparse::Header { + name: "bar", + value: b"xxx", + }, + httparse::Header { + name: "bar", + value: b"sdfsdf 423 42 424", + }, + httparse::Header { + name: "baz", + value: b"123", + }, ]; assert!(expect_ascii_header(headers, "foo", "a").is_ok()); diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 37b215d3..786b40d8 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -10,24 +10,16 @@ //! //! [handshake]: https://tools.ietf.org/html/rfc6455#section-4 -use bytes::{Buf, BytesMut}; -use crate::{Parsing, extension::Extension}; +use super::{ + append_extensions, configure_extensions, expect_ascii_header, with_first_header, Error, + WebSocketKey, KEY, MAX_NUM_HEADERS, SEC_WEBSOCKET_EXTENSIONS, SEC_WEBSOCKET_PROTOCOL, +}; use crate::connection::{self, Mode}; +use crate::{extension::Extension, Parsing}; +use bytes::{Buf, BytesMut}; use futures::prelude::*; use sha1::{Digest, Sha1}; use std::{mem, str}; -use super::{ - WebSocketKey, - Error, - KEY, - MAX_NUM_HEADERS, - SEC_WEBSOCKET_EXTENSIONS, - SEC_WEBSOCKET_PROTOCOL, - append_extensions, - configure_extensions, - expect_ascii_header, - with_first_header -}; const BLOCK_SIZE: usize = 8 * 1024; @@ -49,7 +41,7 @@ pub struct Client<'a, T> { /// The extensions the client wishes to include in the request. extensions: Vec>, /// Encoding/decoding buffer. - buffer: BytesMut + buffer: BytesMut, } impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> { @@ -63,7 +55,7 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> { nonce: [0; 24], protocols: Vec::new(), extensions: Vec::new(), - buffer: BytesMut::new() + buffer: BytesMut::new(), } } @@ -113,7 +105,7 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> { crate::read(&mut self.socket, &mut self.buffer, BLOCK_SIZE).await?; if let Parsing::Done { value, offset } = self.decode_response()? { self.buffer.advance(offset); - return Ok(value) + return Ok(value); } } } @@ -140,7 +132,8 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> { self.buffer.extend_from_slice(b" HTTP/1.1"); self.buffer.extend_from_slice(b"\r\nHost: "); self.buffer.extend_from_slice(self.host.as_bytes()); - self.buffer.extend_from_slice(b"\r\nUpgrade: websocket\r\nConnection: Upgrade"); + self.buffer + .extend_from_slice(b"\r\nUpgrade: websocket\r\nConnection: Upgrade"); self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Key: "); self.buffer.extend_from_slice(&self.nonce); if let Some(o) = &self.origin { @@ -148,7 +141,8 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> { self.buffer.extend_from_slice(o.as_bytes()) } if let Some((last, prefix)) = self.protocols.split_last() { - self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Protocol: "); + self.buffer + .extend_from_slice(b"\r\nSec-WebSocket-Protocol: "); for p in prefix { self.buffer.extend_from_slice(p.as_bytes()); self.buffer.extend_from_slice(b",") @@ -156,7 +150,8 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> { self.buffer.extend_from_slice(last.as_bytes()) } append_extensions(&self.extensions, &mut self.buffer); - self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Version: 13\r\n\r\n") + self.buffer + .extend_from_slice(b"\r\nSec-WebSocket-Version: 13\r\n\r\n") } /// Decode the server response to this client request. @@ -167,25 +162,37 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> { let offset = match response.parse(self.buffer.as_ref()) { Ok(httparse::Status::Complete(off)) => off, Ok(httparse::Status::Partial) => return Ok(Parsing::NeedMore(())), - Err(e) => return Err(Error::Http(Box::new(e))) + Err(e) => return Err(Error::Http(Box::new(e))), }; if response.version != Some(1) { - return Err(Error::UnsupportedHttpVersion) + return Err(Error::UnsupportedHttpVersion); } match response.code { Some(101) => (), - Some(code@(301 ..= 303)) | Some(code@307) | Some(code@308) => { // redirect response + Some(code @ (301..=303)) | Some(code @ 307) | Some(code @ 308) => { + // redirect response let location = with_first_header(response.headers, "Location", |loc| { Ok(String::from(std::str::from_utf8(loc)?)) })?; - let response = ServerResponse::Redirect { status_code: code, location }; - return Ok(Parsing::Done { value: response, offset }) + let response = ServerResponse::Redirect { + status_code: code, + location, + }; + return Ok(Parsing::Done { + value: response, + offset, + }); } other => { - let response = ServerResponse::Rejected { status_code: other.unwrap_or(0) }; - return Ok(Parsing::Done { value: response, offset }) + let response = ServerResponse::Rejected { + status_code: other.unwrap_or(0), + }; + return Ok(Parsing::Done { + value: response, + offset, + }); } } @@ -198,14 +205,16 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> { digest.update(KEY); let ours = base64::encode(&digest.finalize()); if ours.as_bytes() != theirs { - return Err(Error::InvalidSecWebSocketAccept) + return Err(Error::InvalidSecWebSocketAccept); } Ok(()) })?; // Parse `Sec-WebSocket-Extensions` headers. - for h in response.headers.iter() + for h in response + .headers + .iter() .filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS)) { configure_extensions(&mut self.extensions, std::str::from_utf8(h.value)?)? @@ -214,18 +223,25 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> { // Match `Sec-WebSocket-Protocol` header. let mut selected_proto = None; - if let Some(tp) = response.headers.iter() + if let Some(tp) = response + .headers + .iter() .find(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_PROTOCOL)) { if let Some(&p) = self.protocols.iter().find(|x| x.as_bytes() == tp.value) { selected_proto = Some(String::from(p)) } else { - return Err(Error::UnsolicitedProtocol) + return Err(Error::UnsolicitedProtocol); } } - let response = ServerResponse::Accepted { protocol: selected_proto }; - Ok(Parsing::Done { value: response, offset }) + let response = ServerResponse::Accepted { + protocol: selected_proto, + }; + Ok(Parsing::Done { + value: response, + offset, + }) } } @@ -235,19 +251,18 @@ pub enum ServerResponse { /// The server has accepted our request. Accepted { /// The protocol (if any) the server has selected. - protocol: Option + protocol: Option, }, /// The server is redirecting us to some other location. Redirect { /// The HTTP response status code. status_code: u16, /// The location URL we should go to. - location: String + location: String, }, /// The server rejected our request. Rejected { /// HTTP response status code. - status_code: u16 - } + status_code: u16, + }, } - diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 67e46422..2fcd554d 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -10,24 +10,16 @@ //! //! [handshake]: https://tools.ietf.org/html/rfc6455#section-4 -use bytes::BytesMut; -use crate::extension::Extension; +use super::{ + append_extensions, configure_extensions, expect_ascii_header, with_first_header, Error, + WebSocketKey, KEY, MAX_NUM_HEADERS, SEC_WEBSOCKET_EXTENSIONS, SEC_WEBSOCKET_PROTOCOL, +}; use crate::connection::{self, Mode}; +use crate::extension::Extension; +use bytes::BytesMut; use futures::prelude::*; use sha1::{Digest, Sha1}; use std::{mem, str}; -use super::{ - WebSocketKey, - Error, - KEY, - MAX_NUM_HEADERS, - SEC_WEBSOCKET_EXTENSIONS, - SEC_WEBSOCKET_PROTOCOL, - append_extensions, - configure_extensions, - expect_ascii_header, - with_first_header -}; // Most HTTP servers default to 8KB limit on headers const MAX_HEADERS_SIZE: usize = 8 * 1024; @@ -42,7 +34,7 @@ pub struct Server<'a, T> { /// Extensions the server supports. extensions: Vec>, /// Encoding/decoding buffer. - buffer: BytesMut + buffer: BytesMut, } impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> { @@ -52,7 +44,7 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> { socket, protocols: Vec::new(), extensions: Vec::new(), - buffer: BytesMut::new() + buffer: BytesMut::new(), } } @@ -97,7 +89,11 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> { // We don't expect body, so can search for the CRLF headers tail from // the end of the buffer. - if self.buffer[skip..limit].windows(4).rev().any(|w| w == b"\r\n\r\n") { + if self.buffer[skip..limit] + .windows(4) + .rev() + .any(|w| w == b"\r\n\r\n") + { break; } @@ -147,13 +143,13 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> { match request.parse(self.buffer.as_ref()) { Ok(httparse::Status::Complete(_)) => (), Ok(httparse::Status::Partial) => return Err(Error::IncompleteHttpRequest), - Err(e) => return Err(Error::Http(Box::new(e))) + Err(e) => return Err(Error::Http(Box::new(e))), }; if request.method != Some("GET") { - return Err(Error::InvalidRequestMethod) + return Err(Error::InvalidRequestMethod); } if request.version != Some(1) { - return Err(Error::UnsupportedHttpVersion) + return Err(Error::UnsupportedHttpVersion); } let host = with_first_header(&request.headers, "Host", Ok)?; @@ -177,14 +173,18 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> { WebSocketKey::try_from(k).map_err(|_| Error::SecWebSocketKeyInvalidLength(k.len())) })?; - for h in request.headers.iter() + for h in request + .headers + .iter() .filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS)) { configure_extensions(&mut self.extensions, std::str::from_utf8(h.value)?)? } let mut protocols = Vec::new(); - for p in request.headers.iter() + for p in request + .headers + .iter() .filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_PROTOCOL)) { if let Some(&p) = self.protocols.iter().find(|x| x.as_bytes() == p.value) { @@ -194,7 +194,12 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> { let path = request.path.unwrap_or("/"); - Ok(ClientRequest { ws_key, protocols, path, headers }) + Ok(ClientRequest { + ws_key, + protocols, + path, + headers, + }) } // Encode server handshake response. @@ -208,21 +213,29 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> { digest.update(KEY); let d = digest.finalize(); let n = base64::encode_config_slice(&d, base64::STANDARD, &mut key_buf); - &key_buf[.. n] + &key_buf[..n] }; - self.buffer.extend_from_slice(concat![ - "HTTP/1.1 101 Switching Protocols", - "\r\nServer: soketto-", env!("CARGO_PKG_VERSION"), - "\r\nUpgrade: websocket", - "\r\nConnection: upgrade", - "\r\nSec-WebSocket-Accept: ", - ].as_bytes()); + self.buffer.extend_from_slice( + concat![ + "HTTP/1.1 101 Switching Protocols", + "\r\nServer: soketto-", + env!("CARGO_PKG_VERSION"), + "\r\nUpgrade: websocket", + "\r\nConnection: upgrade", + "\r\nSec-WebSocket-Accept: ", + ] + .as_bytes(), + ); self.buffer.extend_from_slice(accept_value); if let Some(p) = protocol { - self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Protocol: "); + self.buffer + .extend_from_slice(b"\r\nSec-WebSocket-Protocol: "); self.buffer.extend_from_slice(p.as_bytes()) } - append_extensions(self.extensions.iter().filter(|e| e.is_enabled()), &mut self.buffer); + append_extensions( + self.extensions.iter().filter(|e| e.is_enabled()), + &mut self.buffer, + ); self.buffer.extend_from_slice(b"\r\n\r\n") } Response::Reject { status_code } => { @@ -286,12 +299,10 @@ pub enum Response<'a> { /// The server accepts the handshake request. Accept { key: WebSocketKey, - protocol: Option<&'a str> + protocol: Option<&'a str>, }, /// The server rejects the handshake request. - Reject { - status_code: u16 - } + Reject { status_code: u16 }, } /// Known status codes and their reason phrases. @@ -355,5 +366,5 @@ const STATUSCODES: &[(u16, &str)] = &[ (507, "507 Insufficient Storage"), (508, "508 Loop Detected"), (510, "510 Not Extended"), - (511, "511 Network Authentication Required") + (511, "511 Network Authentication Required"), ]; diff --git a/src/lib.rs b/src/lib.rs index b2cc7514..1445263c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -112,10 +112,10 @@ #![forbid(unsafe_code)] pub mod base; +pub mod connection; pub mod data; pub mod extension; pub mod handshake; -pub mod connection; use bytes::BytesMut; use futures::io::{AsyncRead, AsyncReadExt}; @@ -134,10 +134,10 @@ pub enum Parsing { /// The parsed value. value: T, /// The offset into the byte slice that has been consumed. - offset: usize + offset: usize, }, /// Parsing is incomplete and needs more data. - NeedMore(N) + NeedMore(N), } /// A buffer type used for implementing `Extension`s. @@ -148,7 +148,7 @@ pub enum Storage<'a> { /// A mutable byte slice. Unique(&'a mut [u8]), /// An owned byte buffer. - Owned(Vec) + Owned(Vec), } impl AsRef<[u8]> for Storage<'_> { @@ -156,7 +156,7 @@ impl AsRef<[u8]> for Storage<'_> { match self { Storage::Shared(d) => d, Storage::Unique(d) => d, - Storage::Owned(b) => b.as_ref() + Storage::Owned(b) => b.as_ref(), } } } @@ -171,16 +171,15 @@ const fn as_u64(a: usize) -> u64 { /// Fill the buffer from the given `AsyncRead` impl with up to `max` bytes. async fn read(reader: &mut R, dest: &mut BytesMut, max: usize) -> io::Result<()> where - R: AsyncRead + Unpin + R: AsyncRead + Unpin, { let i = dest.len(); dest.resize(i + max, 0u8); - let n = reader.read(&mut dest[i ..]).await?; + let n = reader.read(&mut dest[i..]).await?; dest.truncate(i + n); if n == 0 { - return Err(io::ErrorKind::UnexpectedEof.into()) + return Err(io::ErrorKind::UnexpectedEof.into()); } log::trace!("read {} bytes", n); Ok(()) } -