diff --git a/http-client/Cargo.toml b/http-client/Cargo.toml index 87369fd228..e9b41e0c1a 100644 --- a/http-client/Cargo.toml +++ b/http-client/Cargo.toml @@ -13,7 +13,7 @@ documentation = "https://docs.rs/jsonrpsee-http-client" async-trait = "0.1" fnv = "1" hyper = { version = "0.14.10", features = ["client", "http1", "http2", "tcp"] } -hyper-rustls = { version = "0.23", features = ["webpki-tokio"] } +hyper-rustls = { version = "0.23", optional = true } jsonrpsee-types = { path = "../types", version = "0.6.0" } jsonrpsee-utils = { path = "../utils", version = "0.6.0", features = ["client", "http-helpers"] } serde = { version = "1.0", default-features = false, features = ["derive"] } @@ -21,8 +21,11 @@ serde_json = "1.0" thiserror = "1.0" tokio = { version = "1.8", features = ["time"] } tracing = "0.1" -url = "2.2" [dev-dependencies] jsonrpsee-test-utils = { path = "../test-utils" } tokio = { version = "1.8", features = ["net", "rt-multi-thread", "macros"] } + +[features] +default = ["tls"] +tls = ["hyper-rustls/webpki-tokio"] diff --git a/http-client/src/transport.rs b/http-client/src/transport.rs index b512e6749d..73daa3a703 100644 --- a/http-client/src/transport.rs +++ b/http-client/src/transport.rs @@ -8,20 +8,39 @@ use crate::types::error::GenericTransportError; use hyper::client::{Client, HttpConnector}; -use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder}; +use hyper::Uri; use jsonrpsee_types::CertificateStore; use jsonrpsee_utils::http_helpers; use thiserror::Error; const CONTENT_TYPE_JSON: &str = "application/json"; +#[derive(Debug, Clone)] +enum HyperClient { + /// Hyper client with https connector. + #[cfg(feature = "tls")] + Https(Client>), + /// Hyper client with http connector. + Http(Client), +} + +impl HyperClient { + fn request(&self, req: hyper::Request) -> hyper::client::ResponseFuture { + match self { + Self::Http(client) => client.request(req), + #[cfg(feature = "tls")] + Self::Https(client) => client.request(req), + } + } +} + /// HTTP Transport Client. #[derive(Debug, Clone)] pub(crate) struct HttpTransportClient { /// Target to connect to. - target: url::Url, + target: Uri, /// HTTP client - client: Client>, + client: HyperClient, /// Configurable max request body size max_request_body_size: u32, } @@ -33,22 +52,40 @@ impl HttpTransportClient { max_request_body_size: u32, cert_store: CertificateStore, ) -> Result { - let target = url::Url::parse(target.as_ref()).map_err(|e| Error::Url(format!("Invalid URL: {}", e)))?; - if target.scheme() == "http" || target.scheme() == "https" { - let connector = match cert_store { - CertificateStore::Native => { - HttpsConnectorBuilder::new().with_native_roots().https_or_http().enable_http1() - } - CertificateStore::WebPki => { - HttpsConnectorBuilder::new().with_webpki_roots().https_or_http().enable_http1() - } - _ => return Err(Error::InvalidCertficateStore), - }; - let client = Client::builder().build::<_, hyper::Body>(connector.build()); - Ok(HttpTransportClient { target, client, max_request_body_size }) - } else { - Err(Error::Url("URL scheme not supported, expects 'http' or 'https'".into())) + let target: Uri = target.as_ref().parse().map_err(|e| Error::Url(format!("Invalid URL: {}", e)))?; + if target.port_u16().is_none() { + return Err(Error::Url("Port number is missing in the URL".into())); } + + let client = match target.scheme_str() { + Some("http") => { + let connector = HttpConnector::new(); + let client = Client::builder().build::<_, hyper::Body>(connector); + HyperClient::Http(client) + } + #[cfg(feature = "tls")] + Some("https") => { + let connector = match cert_store { + CertificateStore::Native => { + hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_or_http().enable_http1() + } + CertificateStore::WebPki => { + hyper_rustls::HttpsConnectorBuilder::new().with_webpki_roots().https_or_http().enable_http1() + } + _ => return Err(Error::InvalidCertficateStore), + }; + let client = Client::builder().build::<_, hyper::Body>(connector.build()); + HyperClient::Https(client) + } + _ => { + #[cfg(feature = "tls")] + let err = "URL scheme not supported, expects 'http' or 'https'"; + #[cfg(not(feature = "tls"))] + let err = "URL scheme not supported, expects 'http'"; + return Err(Error::Url(err.into())); + } + }; + Ok(Self { target, client, max_request_body_size }) } async fn inner_send(&self, body: String) -> Result, Error> { @@ -58,7 +95,9 @@ impl HttpTransportClient { return Err(Error::RequestTooLarge); } - let req = hyper::Request::post(self.target.as_str()) + // NOTE(niklasad1): this annoying we could just take `&str` here but more user-friendly to check + // that the URI is well-formed in the constructor. + let req = hyper::Request::post(self.target.clone()) .header(hyper::header::CONTENT_TYPE, hyper::header::HeaderValue::from_static(CONTENT_TYPE_JSON)) .header(hyper::header::ACCEPT, hyper::header::HeaderValue::from_static(CONTENT_TYPE_JSON)) .body(From::from(body)) @@ -135,12 +174,74 @@ where mod tests { use super::{CertificateStore, Error, HttpTransportClient}; + fn assert_target( + client: &HttpTransportClient, + host: &str, + scheme: &str, + path_and_query: &str, + port: u16, + max_request_size: u32, + ) { + assert_eq!(client.target.scheme_str(), Some(scheme)); + assert_eq!(client.target.path_and_query().map(|pq| pq.as_str()), Some(path_and_query)); + assert_eq!(client.target.host(), Some(host)); + assert_eq!(client.target.port_u16(), Some(port)); + assert_eq!(client.max_request_body_size, max_request_size); + } + #[test] fn invalid_http_url_rejected() { let err = HttpTransportClient::new("ws://localhost:9933", 80, CertificateStore::Native).unwrap_err(); assert!(matches!(err, Error::Url(_))); } + #[cfg(feature = "tls")] + #[test] + fn https_works() { + let client = HttpTransportClient::new("https://localhost:9933", 80, CertificateStore::Native).unwrap(); + assert_target(&client, "localhost", "https", "/", 9933, 80); + } + + #[cfg(not(feature = "tls"))] + #[test] + fn https_fails_without_tls_feature() { + let err = HttpTransportClient::new("https://localhost:9933", 80, CertificateStore::Native).unwrap_err(); + assert!(matches!(err, Error::Url(_))); + } + + #[test] + fn faulty_port() { + let err = HttpTransportClient::new("http://localhost:-43", 80, CertificateStore::Native).unwrap_err(); + assert!(matches!(err, Error::Url(_))); + let err = HttpTransportClient::new("http://localhost:-99999", 80, CertificateStore::Native).unwrap_err(); + assert!(matches!(err, Error::Url(_))); + } + + #[test] + fn url_with_path_works() { + let client = + HttpTransportClient::new("http://localhost:9944/my-special-path", 1337, CertificateStore::Native).unwrap(); + assert_target(&client, "localhost", "http", "/my-special-path", 9944, 1337); + } + + #[test] + fn url_with_query_works() { + let client = HttpTransportClient::new( + "http://127.0.0.1:9999/my?name1=value1&name2=value2", + u32::MAX, + CertificateStore::WebPki, + ) + .unwrap(); + assert_target(&client, "127.0.0.1", "http", "/my?name1=value1&name2=value2", 9999, u32::MAX); + } + + #[test] + fn url_with_fragment_is_ignored() { + let client = + HttpTransportClient::new("http://127.0.0.1:9944/my.htm#ignore", 999, CertificateStore::Native).unwrap(); + assert_target(&client, "127.0.0.1", "http", "/my.htm", 9944, 999); + } + #[tokio::test] async fn request_limit_works() { let eighty_bytes_limit = 80; diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index 267a8205ca..8070b42ba5 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -258,8 +258,7 @@ async fn ws_with_non_ascii_url_doesnt_hang_or_panic() { #[tokio::test] async fn http_with_non_ascii_url_doesnt_hang_or_panic() { - let client = HttpClientBuilder::default().build("http://♥♥♥♥♥♥∀∂").unwrap(); - let err: Result<(), Error> = client.request("system_chain", None).await; + let err = HttpClientBuilder::default().build("http://♥♥♥♥♥♥∀∂"); assert!(matches!(err, Err(Error::Transport(_)))); } diff --git a/ws-client/Cargo.toml b/ws-client/Cargo.toml index 2b054d7aaa..7002950231 100644 --- a/ws-client/Cargo.toml +++ b/ws-client/Cargo.toml @@ -22,7 +22,7 @@ serde_json = "1" soketto = "0.7.1" thiserror = "1" tokio = { version = "1.8", features = ["net", "time", "rt-multi-thread", "macros"] } -tokio-rustls = "0.23" +tokio-rustls = { version = "0.23", optional = true } tokio-util = { version = "0.6", features = ["compat"] } tracing = "0.1" webpki-roots = "0.22.0" @@ -32,3 +32,7 @@ env_logger = "0.9" jsonrpsee-test-utils = { path = "../test-utils" } jsonrpsee-utils = { path = "../utils", features = ["client"] } tokio = { version = "1.8", features = ["macros"] } + +[features] +default = ["tls"] +tls = ["tokio-rustls"] diff --git a/ws-client/src/stream.rs b/ws-client/src/stream.rs index 6dee6d8a66..bb859e9c9c 100644 --- a/ws-client/src/stream.rs +++ b/ws-client/src/stream.rs @@ -32,23 +32,21 @@ use futures::{ }; use pin_project::pin_project; use std::{io::Error as IoError, pin::Pin, task::Context, task::Poll}; +use tokio::net::TcpStream; use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; /// Stream to represent either a unencrypted or encrypted socket stream. #[pin_project(project = EitherStreamProj)] -#[derive(Debug, Copy, Clone)] -pub enum EitherStream { +#[derive(Debug)] +pub enum EitherStream { /// Unencrypted socket stream. - Plain(#[pin] S), + Plain(#[pin] TcpStream), /// Encrypted socket stream. - Tls(#[pin] T), + #[cfg(feature = "tls")] + Tls(#[pin] tokio_rustls::client::TlsStream), } -impl AsyncRead for EitherStream -where - S: TokioAsyncReadCompatExt, - T: TokioAsyncReadCompatExt, -{ +impl AsyncRead for EitherStream { fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { match self.project() { EitherStreamProj::Plain(s) => { @@ -56,6 +54,7 @@ where futures::pin_mut!(compat); AsyncRead::poll_read(compat, cx, buf) } + #[cfg(feature = "tls")] EitherStreamProj::Tls(t) => { let compat = t.compat(); futures::pin_mut!(compat); @@ -75,6 +74,7 @@ where futures::pin_mut!(compat); AsyncRead::poll_read_vectored(compat, cx, bufs) } + #[cfg(feature = "tls")] EitherStreamProj::Tls(t) => { let compat = t.compat(); futures::pin_mut!(compat); @@ -84,11 +84,7 @@ where } } -impl AsyncWrite for EitherStream -where - S: TokioAsyncWriteCompatExt, - T: TokioAsyncWriteCompatExt, -{ +impl AsyncWrite for EitherStream { fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { match self.project() { EitherStreamProj::Plain(s) => { @@ -96,6 +92,7 @@ where futures::pin_mut!(compat); AsyncWrite::poll_write(compat, cx, buf) } + #[cfg(feature = "tls")] EitherStreamProj::Tls(t) => { let compat = t.compat_write(); futures::pin_mut!(compat); @@ -111,6 +108,7 @@ where futures::pin_mut!(compat); AsyncWrite::poll_write_vectored(compat, cx, bufs) } + #[cfg(feature = "tls")] EitherStreamProj::Tls(t) => { let compat = t.compat_write(); futures::pin_mut!(compat); @@ -126,6 +124,7 @@ where futures::pin_mut!(compat); AsyncWrite::poll_flush(compat, cx) } + #[cfg(feature = "tls")] EitherStreamProj::Tls(t) => { let compat = t.compat_write(); futures::pin_mut!(compat); @@ -141,6 +140,7 @@ where futures::pin_mut!(compat); AsyncWrite::poll_close(compat, cx) } + #[cfg(feature = "tls")] EitherStreamProj::Tls(t) => { let compat = t.compat_write(); futures::pin_mut!(compat); diff --git a/ws-client/src/transport.rs b/ws-client/src/transport.rs index e16e94e90a..f33885422d 100644 --- a/ws-client/src/transport.rs +++ b/ws-client/src/transport.rs @@ -35,25 +35,21 @@ use std::{ convert::TryFrom, io, net::{SocketAddr, ToSocketAddrs}, - sync::Arc, time::Duration, }; use thiserror::Error; use tokio::net::TcpStream; -use tokio_rustls::{client::TlsStream, rustls, webpki::InvalidDnsNameError, TlsConnector}; - -type TlsOrPlain = EitherStream>; /// Sending end of WebSocket transport. #[derive(Debug)] pub struct Sender { - inner: connection::Sender>>, + inner: connection::Sender>>, } /// Receiving end of WebSocket transport. #[derive(Debug)] pub struct Receiver { - inner: connection::Receiver>>, + inner: connection::Receiver>>, } /// Builder for a WebSocket transport [`Sender`] and ['Receiver`] pair. @@ -106,8 +102,9 @@ pub enum WsHandshakeError { Transport(#[source] soketto::handshake::Error), /// Invalid DNS name error for TLS + #[cfg(feature = "tls")] #[error("Invalid DNS name: {0}")] - InvalidDnsName(#[source] InvalidDnsNameError), + InvalidDnsName(#[source] tokio_rustls::webpki::InvalidDnsNameError), /// Server rejected the handshake. #[error("Connection rejected with status code: {status_code}")] @@ -169,31 +166,28 @@ impl Receiver { impl<'a> WsTransportClientBuilder<'a> { /// Try to establish the connection. pub async fn build(self) -> Result<(Sender, Receiver), WsHandshakeError> { - let connector = match self.target.mode { - Mode::Tls => { - let tls_connector = build_tls_config(&self.certificate_store)?; - Some(tls_connector) - } - Mode::Plain => None, - }; - - self.try_connect(connector).await + self.try_connect().await } - async fn try_connect( - self, - mut tls_connector: Option, - ) -> Result<(Sender, Receiver), WsHandshakeError> { + async fn try_connect(self) -> Result<(Sender, Receiver), WsHandshakeError> { let mut target = self.target; let mut err = None; + // Only build TLS connector if `wss` in URL. + #[cfg(feature = "tls")] + let mut connector = match target.mode { + Mode::Tls => Some(build_tls_config(&self.certificate_store)?), + Mode::Plain => None, + }; + for _ in 0..self.max_redirections { tracing::debug!("Connecting to target: {:?}", target); // The sockaddrs might get reused if the server replies with a relative URI. let sockaddrs = std::mem::take(&mut target.sockaddrs); for sockaddr in &sockaddrs { - let tcp_stream = match connect(*sockaddr, self.timeout, &target.host, &tls_connector).await { + #[cfg(feature = "tls")] + let tcp_stream = match connect(*sockaddr, self.timeout, &target.host, connector.as_ref()).await { Ok(stream) => stream, Err(e) => { tracing::debug!("Failed to connect to sockaddr: {:?}", sockaddr); @@ -201,6 +195,17 @@ impl<'a> WsTransportClientBuilder<'a> { continue; } }; + + #[cfg(not(feature = "tls"))] + let tcp_stream = match connect(*sockaddr, self.timeout).await { + Ok(stream) => stream, + Err(e) => { + tracing::debug!("Failed to connect to sockaddr: {:?}", sockaddr); + err = Some(Err(e)); + continue; + } + }; + let mut client = WsHandshakeClient::new( BufReader::new(BufWriter::new(tcp_stream)), &target.host_header, @@ -231,13 +236,17 @@ impl<'a> WsTransportClientBuilder<'a> { // Absolute URI. if uri.scheme().is_some() { target = uri.try_into()?; + + // Only build TLS connector if `wss` in redirection URL. + #[cfg(feature = "tls")] match target.mode { - Mode::Tls if tls_connector.is_none() => { - tls_connector = Some(build_tls_config(&self.certificate_store)?); + Mode::Tls if connector.is_none() => { + connector = Some(build_tls_config(&self.certificate_store)?); } Mode::Tls => (), + // Drop connector if it was configured previously. Mode::Plain => { - tls_connector = None; + connector = None; } }; } @@ -282,12 +291,13 @@ impl<'a> WsTransportClientBuilder<'a> { } } +#[cfg(feature = "tls")] async fn connect( sockaddr: SocketAddr, timeout_dur: Duration, host: &str, - tls_connector: &Option, -) -> Result>, WsHandshakeError> { + tls_connector: Option<&tokio_rustls::TlsConnector>, +) -> Result { let socket = TcpStream::connect(sockaddr); let timeout = tokio::time::sleep(timeout_dur); tokio::select! { @@ -297,11 +307,11 @@ async fn connect( tracing::warn!("set nodelay failed: {:?}", err); } match tls_connector { - None => Ok(TlsOrPlain::Plain(socket)), + None => Ok(EitherStream::Plain(socket)), Some(connector) => { - let server_name: rustls::ServerName = host.try_into().map_err(|e| WsHandshakeError::Url(format!("Invalid host: {} {:?}", host, e).into()))?; + let server_name: tokio_rustls::rustls::ServerName = host.try_into().map_err(|e| WsHandshakeError::Url(format!("Invalid host: {} {:?}", host, e).into()))?; let tls_stream = connector.connect(server_name, socket).await?; - Ok(TlsOrPlain::Tls(tls_stream)) + Ok(EitherStream::Tls(tls_stream)) } } } @@ -309,14 +319,31 @@ async fn connect( } } +#[cfg(not(feature = "tls"))] +async fn connect(sockaddr: SocketAddr, timeout_dur: Duration) -> Result { + let socket = TcpStream::connect(sockaddr); + let timeout = tokio::time::sleep(timeout_dur); + tokio::select! { + socket = socket => { + let socket = socket?; + if let Err(err) = socket.set_nodelay(true) { + tracing::warn!("set nodelay failed: {:?}", err); + } + Ok(EitherStream::Plain(socket)) + } + _ = timeout => Err(WsHandshakeError::Timeout(timeout_dur)) + } +} + impl From for WsHandshakeError { fn from(err: io::Error) -> WsHandshakeError { WsHandshakeError::Io(err) } } -impl From for WsHandshakeError { - fn from(err: InvalidDnsNameError) -> WsHandshakeError { +#[cfg(feature = "tls")] +impl From for WsHandshakeError { + fn from(err: tokio_rustls::webpki::InvalidDnsNameError) -> WsHandshakeError { WsHandshakeError::InvalidDnsName(err) } } @@ -354,8 +381,15 @@ impl TryFrom for Target { fn try_from(uri: Uri) -> Result { let mode = match uri.scheme_str() { Some("ws") => Mode::Plain, + #[cfg(feature = "tls")] Some("wss") => Mode::Tls, - _ => return Err(WsHandshakeError::Url("URL scheme not supported, expects 'ws' or 'wss'".into())), + _ => { + #[cfg(feature = "tls")] + let err = "URL scheme not supported, expects 'ws' or 'wss'"; + #[cfg(not(feature = "tls"))] + let err = "URL scheme not supported, expects 'ws'"; + return Err(WsHandshakeError::Url(err.into())); + } }; let host = uri.host().map(ToOwned::to_owned).ok_or_else(|| WsHandshakeError::Url("No host in URL".into()))?; let port = uri @@ -370,8 +404,11 @@ impl TryFrom for Target { } // NOTE: this is slow and should be used sparingly. -fn build_tls_config(cert_store: &CertificateStore) -> Result { - let mut roots = tokio_rustls::rustls::RootCertStore::empty(); +#[cfg(feature = "tls")] +fn build_tls_config(cert_store: &CertificateStore) -> Result { + use tokio_rustls::rustls; + + let mut roots = rustls::RootCertStore::empty(); match cert_store { CertificateStore::Native => { @@ -403,7 +440,7 @@ fn build_tls_config(cert_store: &CertificateStore) -> Result