diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 85722040..334fc01b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,7 +2,7 @@ name: Rust on: push: - # Run jobs when commits are pushed to + # Run jobs when commits are pushed to # develop or release-like branches: branches: - develop @@ -40,7 +40,7 @@ jobs: uses: actions-rs/cargo@v1.0.3 with: command: check - args: --all-targets + args: --all-targets --all-features fmt: name: Run rustfmt diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ca9ebb6..3de880a4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,11 @@ The format is based on [Keep a Changelog]. [Keep a Changelog]: http://keepachangelog.com/en/1.0.0/ +## 0.7.0 + +- [added] Added the `handshake::http` module and example usage at `examples/hyper_server.rs` to make using Soketto in conjunction with libraries that use the `http` types (like Hyper) simpler [#45](https://github.com/paritytech/soketto/pull/45) [#48](https://github.com/paritytech/soketto/pull/48) +- [added] Allow setting custom headers on the client to be sent to WebSocket servers when the opening handshake is performed [#47](https://github.com/paritytech/soketto/pull/47) + ## 0.6.0 - [changed] Expose the `Origin` headers from the client handshake on `ClientRequest` [#35](https://github.com/paritytech/soketto/pull/35) diff --git a/Cargo.toml b/Cargo.toml index 65bf7cdd..46c2c206 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "soketto" -version = "0.6.0" +version = "0.7.0" authors = ["Parity Technologies ", "Jason Ozias "] description = "A websocket protocol implementation." keywords = ["websocket", "codec", "async", "futures"] @@ -26,9 +26,17 @@ httparse = { default-features = false, features = ["std"], version = "1.3.4" } log = { default-features = false, version = "0.4.8" } rand = { default-features = false, features = ["std", "std_rng"], version = "0.8" } sha-1 = { default-features = false, version = "0.9" } +http = { default-features = false, version = "0.2", optional = true } [dev-dependencies] quickcheck = "0.9" tokio = { version = "1", features = ["full"] } tokio-util = { version = "0.6", features = ["compat"] } tokio-stream = { version = "0.1", features = ["net"] } +hyper = { version = "0.14.10", features = ["full"] } +env_logger = "0.9.0" + +[[example]] +name = "hyper_server" +required-features = ["http"] + diff --git a/RELEASING.md b/RELEASING.md index 297fd377..65947253 100644 --- a/RELEASING.md +++ b/RELEASING.md @@ -11,7 +11,7 @@ These steps assume that you've checked out the Soketto repository and are in the 3. Check that you're happy with the current documentation. ``` - cargo doc --open + cargo doc --open --all-features ``` CI checks for broken internal links at the moment. Optionally you can also confirm that any external links @@ -19,7 +19,7 @@ These steps assume that you've checked out the Soketto repository and are in the ``` cargo install cargo-deadlinks - cargo deadlinks --check-http + cargo deadlinks --check-http -- --all-features ``` If there are minor issues with the documentation, they can be fixed in the release branch. @@ -65,5 +65,9 @@ These steps assume that you've checked out the Soketto repository and are in the git push --tags ``` + Once this is pushed, go along to [the releases page on GitHub](https://github.com/paritytech/soketto/releases) + and draft a new release which points to the tag you just pushed to `master` above. Copy the changelog comments + for the current release into the release description. + 10. Merge the `master` branch back to develop so that we keep track of any changes that we made on the release branch. diff --git a/examples/hyper_server.rs b/examples/hyper_server.rs new file mode 100644 index 00000000..9190da34 --- /dev/null +++ b/examples/hyper_server.rs @@ -0,0 +1,128 @@ +// Copyright (c) 2021 Parity Technologies (UK) Ltd. +// +// Licensed under the Apache License, Version 2.0 +// or the MIT +// license , at your +// option. All files in the project carrying such notice may not be copied, +// modified, or distributed except according to those terms. + +// An example of how to use of Soketto alongside Hyper, so that we can handle +// standard HTTP traffic with Hyper, and WebSocket connections with Soketto, on +// the same port. +// +// To try this, start up the example (`cargo run --example hyper_server`) and then +// navigate to localhost:3000 and, in the browser JS console, run: +// +// ``` +// var socket = new WebSocket("ws://localhost:3000"); +// socket.onmessage = function(msg) { console.log(msg) }; +// socket.send("Hello!"); +// ``` +// +// You'll see any messages you send echoed back. + +use futures::io::{BufReader, BufWriter}; +use hyper::{Body, Request, Response}; +use soketto::{ + handshake::http::{is_upgrade_request, Server}, + BoxedError, +}; +use tokio_util::compat::TokioAsyncReadCompatExt; + +/// Start up a hyper server. +#[tokio::main] +async fn main() -> Result<(), BoxedError> { + env_logger::init(); + + let addr = ([127, 0, 0, 1], 3000).into(); + + let service = + hyper::service::make_service_fn(|_| async { Ok::<_, hyper::Error>(hyper::service::service_fn(handler)) }); + let server = hyper::Server::bind(&addr).serve(service); + + println!("Listening on http://{} — connect and I'll echo back anything you send!", server.local_addr()); + server.await?; + + Ok(()) +} + +/// Handle incoming HTTP Requests. +async fn handler(req: Request) -> Result, BoxedError> { + if is_upgrade_request(&req) { + // Create a new handshake server. + let mut server = Server::new(); + + // Add any extensions that we want to use. + #[cfg(feature = "deflate")] + { + let deflate = soketto::extension::deflate::Deflate::new(soketto::Mode::Server); + server.add_extension(Box::new(deflate)); + } + + // Attempt the handshake. + match server.receive_request(&req) { + // The handshake has been successful so far; return the response we're given back + // and spawn a task to handle the long-running WebSocket server: + Ok(response) => { + tokio::spawn(async move { + if let Err(e) = websocket_echo_messages(server, req).await { + log::error!("Error upgrading to websocket connection: {}", e); + } + }); + Ok(response.map(|()| Body::empty())) + } + // We tried to upgrade and failed early on; tell the client about the failure however we like: + Err(e) => { + log::error!("Could not upgrade connection: {}", e); + Ok(Response::new(Body::from("Something went wrong upgrading!"))) + } + } + } else { + // The request wasn't an upgrade request; let's treat it as a standard HTTP request: + Ok(Response::new(Body::from("Hello HTTP!"))) + } +} + +/// Echo any messages we get from the client back to them +async fn websocket_echo_messages(server: Server, req: Request) -> Result<(), BoxedError> { + // The negotiation to upgrade to a WebSocket connection has been successful so far. Next, we get back the underlying + // stream using `hyper::upgrade::on`, and hand this to a Soketto server to use to handle the WebSocket communication + // on this socket. + // + // Note: awaiting this won't succeed until the handshake response has been returned to the client, so this must be + // spawned on a separate task so as not to block that response being handed back. + let stream = hyper::upgrade::on(req).await?; + let stream = BufReader::new(BufWriter::new(stream.compat())); + + // Get back a reader and writer that we can use to send and receive websocket messages. + let (mut sender, mut receiver) = server.into_builder(stream).finish(); + + // Echo any received messages back to the client: + let mut message = Vec::new(); + loop { + message.clear(); + match receiver.receive_data(&mut message).await { + Ok(soketto::Data::Binary(n)) => { + assert_eq!(n, message.len()); + sender.send_binary_mut(&mut message).await?; + sender.flush().await? + } + Ok(soketto::Data::Text(n)) => { + assert_eq!(n, message.len()); + if let Ok(txt) = std::str::from_utf8(&message) { + sender.send_text(txt).await?; + sender.flush().await? + } else { + break; + } + } + Err(soketto::connection::Error::Closed) => break, + Err(e) => { + eprintln!("Websocket connection error: {}", e); + break; + } + } + } + + Ok(()) +} diff --git a/src/handshake.rs b/src/handshake.rs index 433430a2..f9d9415c 100644 --- a/src/handshake.rs +++ b/src/handshake.rs @@ -11,10 +11,13 @@ //! [handshake]: https://tools.ietf.org/html/rfc6455#section-4 pub mod client; +#[cfg(feature = "http")] +pub mod http; pub mod server; use crate::extension::{Extension, Param}; use bytes::BytesMut; +use sha1::{Digest, Sha1}; use std::{fmt, io, str}; pub use client::{Client, ServerResponse}; @@ -105,7 +108,15 @@ where bytes.extend_from_slice(b"\r\nSec-WebSocket-Extensions: ") } - while let Some(e) = iter.next() { + append_extension_header_value(iter, bytes) +} + +// Write the extension header value to the given buffer. +fn append_extension_header_value<'a, I>(mut extensions_iter: std::iter::Peekable, bytes: &mut BytesMut) +where + I: Iterator>, +{ + while let Some(e) = extensions_iter.next() { bytes.extend_from_slice(e.name().as_bytes()); for p in e.params() { bytes.extend_from_slice(b"; "); @@ -115,12 +126,33 @@ where bytes.extend_from_slice(v.as_bytes()) } } - if iter.peek().is_some() { + if extensions_iter.peek().is_some() { bytes.extend_from_slice(b", ") } } } +// This function takes a 16 byte key (base64 encoded, and so 24 bytes of input) that is expected via +// the `Sec-WebSocket-Key` header during a websocket handshake, and writes the response that's expected +// to be handed back in the response header `Sec-WebSocket-Accept`. +// +// The response is a base64 encoding of a 160bit hash. base64 encoding uses 1 ascii character per 6 bits, +// and 160 / 6 = 26.66 characters. The output is padded with '=' to the nearest 4 characters, so we need 28 +// bytes in total for all of the characters. +// +// See https://datatracker.ietf.org/doc/html/rfc6455#section-1.3 for more information on this. +fn generate_accept_key<'k>(key_base64: &WebSocketKey) -> [u8; 28] { + let mut digest = Sha1::new(); + digest.update(key_base64); + digest.update(KEY); + let d = digest.finalize(); + + let mut output_buf = [0; 28]; + let n = base64::encode_config_slice(&d, base64::STANDARD, &mut output_buf); + debug_assert_eq!(n, 28, "encoding to base64 should be exactly 28 bytes"); + output_buf +} + /// Enumeration of possible handshake errors. #[non_exhaustive] #[derive(Debug)] diff --git a/src/handshake/client.rs b/src/handshake/client.rs index a7b6cbd2..cfb9306f 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -21,6 +21,8 @@ use futures::prelude::*; use sha1::{Digest, Sha1}; use std::{mem, str}; +pub use httparse::Header; + const BLOCK_SIZE: usize = 8 * 1024; /// Websocket client handshake. @@ -32,8 +34,8 @@ pub struct Client<'a, T> { host: &'a str, /// The HTTP host ressource. resource: &'a str, - /// The HTTP origin header. - origin: Option<&'a str>, + /// The HTTP headers. + headers: &'a [Header<'a>], /// A buffer holding the base-64 encoded request nonce. nonce: WebSocketKey, /// The protocols to include in the handshake. @@ -51,7 +53,7 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> { socket, host, resource, - origin: None, + headers: &[], nonce: [0; 24], protocols: Vec::new(), extensions: Vec::new(), @@ -70,9 +72,11 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> { mem::take(&mut self.buffer) } - /// Set the handshake origin header. - pub fn set_origin(&mut self, o: &'a str) -> &mut Self { - self.origin = Some(o); + /// Set connection headers to a slice. These headers are not checked for validity, + /// the caller of this method is responsible for verification as well as avoiding + /// conflicts with internally set headers. + pub fn set_headers(&mut self, h: &'a [Header]) -> &mut Self { + self.headers = h; self } @@ -135,10 +139,12 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> { self.buffer.extend_from_slice(b"\r\nUpgrade: websocket\r\nConnection: Upgrade"); self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Key: "); self.buffer.extend_from_slice(&self.nonce); - if let Some(o) = &self.origin { - self.buffer.extend_from_slice(b"\r\nOrigin: "); - self.buffer.extend_from_slice(o.as_bytes()) - } + self.headers.iter().for_each(|h| { + self.buffer.extend_from_slice(b"\r\n"); + self.buffer.extend_from_slice(h.name.as_bytes()); + self.buffer.extend_from_slice(b": "); + self.buffer.extend_from_slice(h.value); + }); if let Some((last, prefix)) = self.protocols.split_last() { self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Protocol: "); for p in prefix { diff --git a/src/handshake/http.rs b/src/handshake/http.rs new file mode 100644 index 00000000..c429019d --- /dev/null +++ b/src/handshake/http.rs @@ -0,0 +1,158 @@ +// Copyright (c) 2021 Parity Technologies (UK) Ltd. +// +// Licensed under the Apache License, Version 2.0 +// or the MIT +// license , at your +// option. All files in the project carrying such notice may not be copied, +// modified, or distributed except according to those terms. + +/*! +This module somewhat mirrors [`crate::handshake::server`], except it's focus is on working +with [`http::Request`] and [`http::Response`] types, making it easier to integrate with +external web servers such as Hyper. + +See `examples/hyper_server.rs` from this crate's repository for example usage. +*/ + +use super::{WebSocketKey, SEC_WEBSOCKET_EXTENSIONS}; +use crate::connection::{self, Mode}; +use crate::extension::Extension; +use crate::handshake; +use bytes::BytesMut; +use futures::prelude::*; +use http::{header, HeaderMap, Response}; +use std::convert::TryInto; +use std::mem; + +/// A re-export of [`handshake::Error`]. +pub type Error = handshake::Error; + +/// Websocket handshake server. This is similar to [`handshake::Server`], but it is +/// focused on performing the WebSocket handshake using a provided [`http::Request`], as opposed +/// to decoding the request internally. +pub struct Server { + // Extensions the server supports. + extensions: Vec>, + // Encoding/decoding buffer. + buffer: BytesMut, +} + +impl Server { + /// Create a new server handshake. + pub fn new() -> Self { + Server { extensions: Vec::new(), buffer: BytesMut::new() } + } + + /// Override the buffer to use for request/response handling. + pub fn set_buffer(&mut self, b: BytesMut) -> &mut Self { + self.buffer = b; + self + } + + /// Extract the buffer. + pub fn take_buffer(&mut self) -> BytesMut { + mem::take(&mut self.buffer) + } + + /// Add an extension the server supports. + pub fn add_extension(&mut self, e: Box) -> &mut Self { + self.extensions.push(e); + self + } + + /// Get back all extensions. + pub fn drain_extensions(&mut self) -> impl Iterator> + '_ { + self.extensions.drain(..) + } + + /// Attempt to interpret the provided [`http::Request`] as a WebSocket Upgrade request. If successful, this + /// returns an [`http::Response`] that should be returned to the client to complete the handshake. + pub fn receive_request(&mut self, req: &http::Request) -> Result, Error> { + if !is_upgrade_request(&req) { + return Err(Error::InvalidSecWebSocketAccept); + } + + let key = match req.headers().get("Sec-WebSocket-Key") { + Some(key) => key, + None => { + return Err(Error::HeaderNotFound("Sec-WebSocket-Key".into()).into()); + } + }; + + if req.headers().get("Sec-WebSocket-Version").map(|v| v.as_bytes()) != Some(b"13") { + return Err(Error::HeaderNotFound("Sec-WebSocket-Version".into()).into()); + } + + // Pull out the Sec-WebSocket-Key and generate the appropriate response to it. + let key: &WebSocketKey = match key.as_bytes().try_into() { + Ok(key) => key, + Err(_) => return Err(Error::InvalidSecWebSocketAccept), + }; + let accept_key = handshake::generate_accept_key(key); + + // Get extension information out of the request as we'll need this as well. + let extension_config = req + .headers() + .iter() + .filter(|&(name, _)| name.as_str().eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS)) + .map(|(_, value)| Ok(std::str::from_utf8(value.as_bytes())?.to_string())) + .collect::, Error>>()?; + + // Attempt to set the extension configuration params that the client requested. + for config_str in &extension_config { + handshake::configure_extensions(&mut self.extensions, &config_str)?; + } + + // Build a response that should be sent back to the client to acknowledge the upgrade. + let mut response = Response::builder() + .status(http::StatusCode::SWITCHING_PROTOCOLS) + .header(http::header::CONNECTION, "upgrade") + .header(http::header::UPGRADE, "websocket") + .header("Sec-WebSocket-Accept", &accept_key[..]); + + // Tell the client about the agreed-upon extension configuration. We reuse code to build up the + // extension header value, but that does make this a little more clunky. + if !self.extensions.is_empty() { + let mut buf = bytes::BytesMut::new(); + let enabled_extensions = self.extensions.iter().filter(|e| e.is_enabled()).peekable(); + handshake::append_extension_header_value(enabled_extensions, &mut buf); + response = response.header("Sec-WebSocket-Extensions", buf.as_ref()); + } + + let response = response.body(()).expect("bug: failed to build response"); + Ok(response) + } + + /// Turn this handshake into a [`connection::Builder`]. + pub fn into_builder(mut self, socket: T) -> connection::Builder { + let mut builder = connection::Builder::new(socket, Mode::Server); + builder.set_buffer(self.buffer); + builder.add_extensions(self.extensions.drain(..)); + builder + } +} + +/// Check if an [`http::Request`] looks like a valid websocket upgrade request. +pub fn is_upgrade_request(request: &http::Request) -> bool { + header_contains_value(request.headers(), header::CONNECTION, b"upgrade") + && header_contains_value(request.headers(), header::UPGRADE, b"websocket") +} + +// Check if there is a header of the given name containing the wanted value. +fn header_contains_value(headers: &HeaderMap, header: header::HeaderName, value: &[u8]) -> bool { + pub fn trim(x: &[u8]) -> &[u8] { + let from = match x.iter().position(|x| !x.is_ascii_whitespace()) { + Some(i) => i, + None => return &[], + }; + let to = x.iter().rposition(|x| !x.is_ascii_whitespace()).unwrap(); + &x[from..=to] + } + + for header in headers.get_all(header) { + if header.as_bytes().split(|&c| c == b',').any(|x| trim(x).eq_ignore_ascii_case(value)) { + return true; + } + } + false +} diff --git a/src/handshake/server.rs b/src/handshake/server.rs index a4ba6b13..99fd5f29 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -11,21 +11,20 @@ //! [handshake]: https://tools.ietf.org/html/rfc6455#section-4 use super::{ - append_extensions, configure_extensions, expect_ascii_header, with_first_header, Error, WebSocketKey, KEY, + append_extensions, configure_extensions, expect_ascii_header, with_first_header, Error, WebSocketKey, MAX_NUM_HEADERS, SEC_WEBSOCKET_EXTENSIONS, SEC_WEBSOCKET_PROTOCOL, }; use crate::connection::{self, Mode}; use crate::extension::Extension; use bytes::BytesMut; use futures::prelude::*; -use sha1::{Digest, Sha1}; use std::{mem, str}; // Most HTTP servers default to 8KB limit on headers const MAX_HEADERS_SIZE: usize = 8 * 1024; const BLOCK_SIZE: usize = 8 * 1024; -/// Websocket handshake client. +/// Websocket handshake server. #[derive(Debug)] pub struct Server<'a, T> { socket: T, @@ -187,15 +186,7 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> { fn encode_response(&mut self, response: &Response<'_>) { match response { Response::Accept { key, protocol } => { - let mut key_buf = [0; 32]; - let accept_value = { - let mut digest = Sha1::new(); - digest.update(key); - digest.update(KEY); - let d = digest.finalize(); - let n = base64::encode_config_slice(&d, base64::STANDARD, &mut key_buf); - &key_buf[..n] - }; + let accept_value = super::generate_accept_key(&key); self.buffer.extend_from_slice( concat![ "HTTP/1.1 101 Switching Protocols", @@ -207,7 +198,7 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> { ] .as_bytes(), ); - self.buffer.extend_from_slice(accept_value); + self.buffer.extend_from_slice(&accept_value); if let Some(p) = protocol { self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Protocol: "); self.buffer.extend_from_slice(p.as_bytes()) diff --git a/src/lib.rs b/src/lib.rs index e1665b1c..1408742c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -102,6 +102,10 @@ //! # } //! //! ``` +//! +//! See `examples/hyper_server.rs` from this crate's repository for an example of +//! starting up a WebSocket server alongside an Hyper HTTP server. +//! //! [client]: handshake::Client //! [server]: handshake::Server //! [Sender]: connection::Sender