From 050dfacca32211e5a019d017a20699bfd53af0f8 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Sat, 30 Oct 2021 22:11:02 +0200 Subject: [PATCH 1/5] clients: introduce tls feature flag --- http-client/Cargo.toml | 6 +- http-client/src/transport.rs | 50 ++++++++++--- tests/tests/integration_tests.rs | 2 +- ws-client/Cargo.toml | 9 ++- ws-client/src/stream.rs | 28 +++---- ws-client/src/transport.rs | 122 +++++++++++++++++-------------- 6 files changed, 134 insertions(+), 83 deletions(-) diff --git a/http-client/Cargo.toml b/http-client/Cargo.toml index 5f29bcf01d..0206a28358 100644 --- a/http-client/Cargo.toml +++ b/http-client/Cargo.toml @@ -11,7 +11,7 @@ documentation = "https://docs.rs/jsonrpsee-http-client" [dependencies] async-trait = "0.1" -hyper-rustls = "0.22" +hyper-rustls = { version = "0.22", optional = true } hyper = { version = "0.14.10", features = ["client", "http1", "http2", "tcp"] } jsonrpsee-types = { path = "../types", version = "0.4.1" } jsonrpsee-utils = { path = "../utils", version = "0.4.1", features = ["http-helpers"] } @@ -26,3 +26,7 @@ fnv = "1" [dev-dependencies] jsonrpsee-test-utils = { path = "../test-utils" } tokio = { package = "tokio", version = "1", features = ["net", "rt-multi-thread", "macros"] } + +[features] +default = ["tls"] +tls = ["hyper-rustls"] \ No newline at end of file diff --git a/http-client/src/transport.rs b/http-client/src/transport.rs index 1fecdd0f18..16115b4f0c 100644 --- a/http-client/src/transport.rs +++ b/http-client/src/transport.rs @@ -8,19 +8,37 @@ use crate::types::error::GenericTransportError; use hyper::client::{Client, HttpConnector}; -use hyper_rustls::HttpsConnector; use jsonrpsee_utils::http_helpers; use thiserror::Error; const CONTENT_TYPE_JSON: &str = "application/json"; +#[derive(Debug, Clone)] +enum HyperClient { + /// Hyper client with https connector. + #[cfg(feature = "tls")] + Https(Client>), + /// Hyper client with http connector. + Http(Client), +} + +impl HyperClient { + fn request(&self, req: hyper::Request) -> hyper::client::ResponseFuture { + match self { + Self::Http(client) => client.request(req), + #[cfg(feature = "tls")] + Self::Https(client) => client.request(req), + } + } +} + /// HTTP Transport Client. #[derive(Debug, Clone)] pub(crate) struct HttpTransportClient { /// Target to connect to. target: url::Url, /// HTTP client - client: Client>, + client: HyperClient, /// Configurable max request body size max_request_body_size: u32, } @@ -29,13 +47,27 @@ impl HttpTransportClient { /// Initializes a new HTTP client. pub(crate) fn new(target: impl AsRef, max_request_body_size: u32) -> Result { let target = url::Url::parse(target.as_ref()).map_err(|e| Error::Url(format!("Invalid URL: {}", e)))?; - if target.scheme() == "http" || target.scheme() == "https" { - let connector = HttpsConnector::with_native_roots(); - let client = Client::builder().build::<_, hyper::Body>(connector); - Ok(HttpTransportClient { target, client, max_request_body_size }) - } else { - Err(Error::Url("URL scheme not supported, expects 'http' or 'https'".into())) - } + let client = match target.scheme() { + "http" => { + let connector = HttpConnector::new(); + let client = Client::builder().build::<_, hyper::Body>(connector); + HyperClient::Http(client) + } + #[cfg(feature = "tls")] + "https" => { + let connector = hyper_rustls::HttpsConnector::with_native_roots(); + let client = Client::builder().build::<_, hyper::Body>(connector); + HyperClient::Https(client) + } + _ => { + #[cfg(feature = "tls")] + let err = "URL scheme not supported, expects 'http' or 'https'"; + #[cfg(not(feature = "tls"))] + let err = "URL scheme not supported, expects 'http'"; + return Err(Error::Url(err.into())); + } + }; + Ok(Self { target, client, max_request_body_size }) } async fn inner_send(&self, body: String) -> Result, Error> { diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index 2cc7570bd1..d9b7e6544c 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -245,7 +245,7 @@ async fn https_works() { #[tokio::test] #[ignore] async fn wss_works() { - let client = WsClientBuilder::default().build("wss://kusama-rpc.polkadot.io").await.unwrap(); + let client = WsClientBuilder::default().build("wss://kusama-rpc.polkadot.io:443").await.unwrap(); let response: String = client.request("system_chain", None).await.unwrap(); assert_eq!(&response, "Kusama"); } diff --git a/ws-client/Cargo.toml b/ws-client/Cargo.toml index 127676fd38..2236648bfd 100644 --- a/ws-client/Cargo.toml +++ b/ws-client/Cargo.toml @@ -11,7 +11,7 @@ documentation = "https://docs.rs/jsonrpsee-ws-client" [dependencies] tokio = { version = "1", features = ["net", "time", "rt-multi-thread", "macros"] } -tokio-rustls = "0.22" +tokio-rustls = { version = "0.22", optional = true } tokio-util = { version = "0.6", features = ["compat"] } async-trait = "0.1" @@ -26,9 +26,14 @@ serde_json = "1" soketto = "0.7" thiserror = "1" tracing = "0.1" +webpki-roots = "0.21" [dev-dependencies] env_logger = "0.9" jsonrpsee-test-utils = { path = "../test-utils" } jsonrpsee-utils = { path = "../utils" } -tokio = { version = "1", features = ["macros"] } \ No newline at end of file +tokio = { version = "1", features = ["macros"] } + +[features] +default = ["tls"] +tls = ["tokio-rustls"] diff --git a/ws-client/src/stream.rs b/ws-client/src/stream.rs index 6dee6d8a66..bb859e9c9c 100644 --- a/ws-client/src/stream.rs +++ b/ws-client/src/stream.rs @@ -32,23 +32,21 @@ use futures::{ }; use pin_project::pin_project; use std::{io::Error as IoError, pin::Pin, task::Context, task::Poll}; +use tokio::net::TcpStream; use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; /// Stream to represent either a unencrypted or encrypted socket stream. #[pin_project(project = EitherStreamProj)] -#[derive(Debug, Copy, Clone)] -pub enum EitherStream { +#[derive(Debug)] +pub enum EitherStream { /// Unencrypted socket stream. - Plain(#[pin] S), + Plain(#[pin] TcpStream), /// Encrypted socket stream. - Tls(#[pin] T), + #[cfg(feature = "tls")] + Tls(#[pin] tokio_rustls::client::TlsStream), } -impl AsyncRead for EitherStream -where - S: TokioAsyncReadCompatExt, - T: TokioAsyncReadCompatExt, -{ +impl AsyncRead for EitherStream { fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { match self.project() { EitherStreamProj::Plain(s) => { @@ -56,6 +54,7 @@ where futures::pin_mut!(compat); AsyncRead::poll_read(compat, cx, buf) } + #[cfg(feature = "tls")] EitherStreamProj::Tls(t) => { let compat = t.compat(); futures::pin_mut!(compat); @@ -75,6 +74,7 @@ where futures::pin_mut!(compat); AsyncRead::poll_read_vectored(compat, cx, bufs) } + #[cfg(feature = "tls")] EitherStreamProj::Tls(t) => { let compat = t.compat(); futures::pin_mut!(compat); @@ -84,11 +84,7 @@ where } } -impl AsyncWrite for EitherStream -where - S: TokioAsyncWriteCompatExt, - T: TokioAsyncWriteCompatExt, -{ +impl AsyncWrite for EitherStream { fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { match self.project() { EitherStreamProj::Plain(s) => { @@ -96,6 +92,7 @@ where futures::pin_mut!(compat); AsyncWrite::poll_write(compat, cx, buf) } + #[cfg(feature = "tls")] EitherStreamProj::Tls(t) => { let compat = t.compat_write(); futures::pin_mut!(compat); @@ -111,6 +108,7 @@ where futures::pin_mut!(compat); AsyncWrite::poll_write_vectored(compat, cx, bufs) } + #[cfg(feature = "tls")] EitherStreamProj::Tls(t) => { let compat = t.compat_write(); futures::pin_mut!(compat); @@ -126,6 +124,7 @@ where futures::pin_mut!(compat); AsyncWrite::poll_flush(compat, cx) } + #[cfg(feature = "tls")] EitherStreamProj::Tls(t) => { let compat = t.compat_write(); futures::pin_mut!(compat); @@ -141,6 +140,7 @@ where futures::pin_mut!(compat); AsyncWrite::poll_close(compat, cx) } + #[cfg(feature = "tls")] EitherStreamProj::Tls(t) => { let compat = t.compat_write(); futures::pin_mut!(compat); diff --git a/ws-client/src/transport.rs b/ws-client/src/transport.rs index 987d84d1f0..5a0a8744c8 100644 --- a/ws-client/src/transport.rs +++ b/ws-client/src/transport.rs @@ -40,25 +40,17 @@ use std::{ }; use thiserror::Error; use tokio::net::TcpStream; -use tokio_rustls::{ - client::TlsStream, - rustls::ClientConfig, - webpki::{DNSNameRef, InvalidDNSNameError}, - TlsConnector, -}; - -type TlsOrPlain = EitherStream>; /// Sending end of WebSocket transport. #[derive(Debug)] pub struct Sender { - inner: connection::Sender>>, + inner: connection::Sender>>, } /// Receiving end of WebSocket transport. #[derive(Debug)] pub struct Receiver { - inner: connection::Receiver>>, + inner: connection::Receiver>>, } /// Builder for a WebSocket transport [`Sender`] and ['Receiver`] pair. @@ -121,8 +113,9 @@ pub enum WsHandshakeError { Transport(#[source] soketto::handshake::Error), /// Invalid DNS name error for TLS + #[cfg(feature = "tls")] #[error("Invalid DNS name: {0}")] - InvalidDnsName(#[source] InvalidDNSNameError), + InvalidDnsName(#[source] tokio_rustls::webpki::InvalidDNSNameError), /// Server rejected the handshake. #[error("Connection rejected with status code: {status_code}")] @@ -184,25 +177,10 @@ impl Receiver { impl<'a> WsTransportClientBuilder<'a> { /// Try to establish the connection. pub async fn build(self) -> Result<(Sender, Receiver), WsHandshakeError> { - let connector = match self.target.mode { - Mode::Tls => { - let mut client_config = ClientConfig::default(); - if let CertificateStore::Native = self.certificate_store { - client_config.root_store = rustls_native_certs::load_native_certs() - .map_err(|(_, e)| WsHandshakeError::CertificateStore(e))?; - } - Some(Arc::new(client_config).into()) - } - Mode::Plain => None, - }; - - self.try_connect(connector).await + self.try_connect().await } - async fn try_connect( - self, - mut tls_connector: Option, - ) -> Result<(Sender, Receiver), WsHandshakeError> { + async fn try_connect(self) -> Result<(Sender, Receiver), WsHandshakeError> { let mut target = self.target; let mut err = None; @@ -212,14 +190,15 @@ impl<'a> WsTransportClientBuilder<'a> { // The sockaddrs might get reused if the server replies with a relative URI. let sockaddrs = std::mem::take(&mut target.sockaddrs); for sockaddr in &sockaddrs { - let tcp_stream = match connect(*sockaddr, self.timeout, &target.host, &tls_connector).await { - Ok(stream) => stream, - Err(e) => { - tracing::debug!("Failed to connect to sockaddr: {:?}", sockaddr); - err = Some(Err(e)); - continue; - } - }; + let tcp_stream = + match connect(*sockaddr, self.timeout, &target.host, &self.certificate_store, &target.mode).await { + Ok(stream) => stream, + Err(e) => { + tracing::debug!("Failed to connect to sockaddr: {:?}", sockaddr); + err = Some(Err(e)); + continue; + } + }; let mut client = WsHandshakeClient::new( BufReader::new(BufWriter::new(tcp_stream)), &target.host_header, @@ -250,17 +229,6 @@ impl<'a> WsTransportClientBuilder<'a> { // Absolute URI. if uri.scheme().is_some() { target = uri.try_into()?; - tls_connector = match target.mode { - Mode::Tls => { - let mut client_config = ClientConfig::default(); - if let CertificateStore::Native = self.certificate_store { - client_config.root_store = rustls_native_certs::load_native_certs() - .map_err(|(_, e)| WsHandshakeError::CertificateStore(e))?; - } - Some(Arc::new(client_config).into()) - } - Mode::Plain => None, - }; } // Relative URI. else { @@ -303,12 +271,14 @@ impl<'a> WsTransportClientBuilder<'a> { } } +#[cfg(feature = "tls")] async fn connect( sockaddr: SocketAddr, timeout_dur: Duration, host: &str, - tls_connector: &Option, -) -> Result>, WsHandshakeError> { + cert_store: &CertificateStore, + mode: &Mode, +) -> Result { let socket = TcpStream::connect(sockaddr); let timeout = tokio::time::sleep(timeout_dur); tokio::select! { @@ -317,12 +287,14 @@ async fn connect( if let Err(err) = socket.set_nodelay(true) { tracing::warn!("set nodelay failed: {:?}", err); } - match tls_connector { - None => Ok(TlsOrPlain::Plain(socket)), - Some(connector) => { - let dns_name = DNSNameRef::try_from_ascii_str(host)?; + match mode { + Mode::Plain => Ok(EitherStream::Plain(socket)), + Mode::Tls => { + // TODO(niklasad1): cache this. + let connector = build_tls_config(cert_store)?; + let dns_name = tokio_rustls::webpki::DNSNameRef::try_from_ascii_str(host)?; let tls_stream = connector.connect(dns_name, socket).await?; - Ok(TlsOrPlain::Tls(tls_stream)) + Ok(EitherStream::Tls(tls_stream)) } } } @@ -330,14 +302,36 @@ async fn connect( } } +#[cfg(not(feature = "tls"))] +async fn connect( + sockaddr: SocketAddr, + timeout_dur: Duration, + host: &str, + cert_store: &CertificateStore, +) -> Result { + let socket = TcpStream::connect(sockaddr); + let timeout = tokio::time::sleep(timeout_dur); + tokio::select! { + socket = socket => { + let socket = socket?; + if let Err(err) = socket.set_nodelay(true) { + tracing::warn!("set nodelay failed: {:?}", err); + } + Ok(EitherStream::Plain(socket)) + } + _ = timeout => Err(WsHandshakeError::Timeout(timeout_dur)) + } +} + impl From for WsHandshakeError { fn from(err: io::Error) -> WsHandshakeError { WsHandshakeError::Io(err) } } -impl From for WsHandshakeError { - fn from(err: InvalidDNSNameError) -> WsHandshakeError { +#[cfg(feature = "tls")] +impl From for WsHandshakeError { + fn from(err: tokio_rustls::webpki::InvalidDNSNameError) -> WsHandshakeError { WsHandshakeError::InvalidDnsName(err) } } @@ -375,6 +369,7 @@ impl TryFrom for Target { fn try_from(uri: Uri) -> Result { let mode = match uri.scheme_str() { Some("ws") => Mode::Plain, + #[cfg(feature = "tls")] Some("wss") => Mode::Tls, _ => return Err(WsHandshakeError::Url("URL scheme not supported, expects 'ws' or 'wss'".into())), }; @@ -390,6 +385,21 @@ impl TryFrom for Target { } } +#[cfg(feature = "tls")] +fn build_tls_config(cert_store: &CertificateStore) -> Result { + let mut client_config = tokio_rustls::rustls::ClientConfig::default(); + match cert_store { + CertificateStore::Native => { + client_config.root_store = + rustls_native_certs::load_native_certs().map_err(|(_, e)| WsHandshakeError::CertificateStore(e))?; + } + CertificateStore::WebPki => { + client_config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + } + }; + Ok(Arc::new(client_config).into()) +} + #[cfg(test)] mod tests { use super::{Mode, Target, Uri, WsHandshakeError}; From fab8e246a7ae7397f1c3e9a9d0b160ce89d8d31c Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Mon, 1 Nov 2021 16:11:00 +0100 Subject: [PATCH 2/5] Update tests/tests/integration_tests.rs --- tests/tests/integration_tests.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index d9b7e6544c..2cc7570bd1 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -245,7 +245,7 @@ async fn https_works() { #[tokio::test] #[ignore] async fn wss_works() { - let client = WsClientBuilder::default().build("wss://kusama-rpc.polkadot.io:443").await.unwrap(); + let client = WsClientBuilder::default().build("wss://kusama-rpc.polkadot.io").await.unwrap(); let response: String = client.request("system_chain", None).await.unwrap(); assert_eq!(&response, "Kusama"); } From 2d9020a2b4fe7df6e993ff619018b6d9b4b5f81c Mon Sep 17 00:00:00 2001 From: Niklas Date: Fri, 3 Dec 2021 19:20:42 +0100 Subject: [PATCH 3/5] fix: don't rebuild tls connector of every connect --- ws-client/src/transport.rs | 72 +++++++++++++++++++++++++------------- 1 file changed, 47 insertions(+), 25 deletions(-) diff --git a/ws-client/src/transport.rs b/ws-client/src/transport.rs index 41a5debd02..483e26b064 100644 --- a/ws-client/src/transport.rs +++ b/ws-client/src/transport.rs @@ -35,7 +35,6 @@ use std::{ convert::TryFrom, io, net::{SocketAddr, ToSocketAddrs}, - sync::Arc, time::Duration, }; use thiserror::Error; @@ -174,21 +173,39 @@ impl<'a> WsTransportClientBuilder<'a> { let mut target = self.target; let mut err = None; + // Only build TLS connector if `wss` in URL. + #[cfg(feature = "tls")] + let mut connector = match target.mode { + Mode::Tls => Some(build_tls_config(&self.certificate_store)?), + Mode::Plain => None, + }; + for _ in 0..self.max_redirections { tracing::debug!("Connecting to target: {:?}", target); // The sockaddrs might get reused if the server replies with a relative URI. let sockaddrs = std::mem::take(&mut target.sockaddrs); for sockaddr in &sockaddrs { - let tcp_stream = - match connect(*sockaddr, self.timeout, &target.host, &self.certificate_store, &target.mode).await { - Ok(stream) => stream, - Err(e) => { - tracing::debug!("Failed to connect to sockaddr: {:?}", sockaddr); - err = Some(Err(e)); - continue; - } - }; + #[cfg(feature = "tls")] + let tcp_stream = match connect(*sockaddr, self.timeout, &target.host, connector.as_ref()).await { + Ok(stream) => stream, + Err(e) => { + tracing::debug!("Failed to connect to sockaddr: {:?}", sockaddr); + err = Some(Err(e)); + continue; + } + }; + + #[cfg(not(feature = "tls"))] + let tcp_stream = match connect(*sockaddr, self.timeout).await { + Ok(stream) => stream, + Err(e) => { + tracing::debug!("Failed to connect to sockaddr: {:?}", sockaddr); + err = Some(Err(e)); + continue; + } + }; + let mut client = WsHandshakeClient::new( BufReader::new(BufWriter::new(tcp_stream)), &target.host_header, @@ -219,6 +236,19 @@ impl<'a> WsTransportClientBuilder<'a> { // Absolute URI. if uri.scheme().is_some() { target = uri.try_into()?; + + // Only build TLS connector if `wss` in redirection URL. + #[cfg(feature = "tls")] + match target.mode { + Mode::Tls if connector.is_none() => { + connector = Some(build_tls_config(&self.certificate_store)?); + } + Mode::Tls => (), + // Drop connector if it was configured previously. + Mode::Plain => { + connector = None; + } + }; } // Relative URI. else { @@ -266,8 +296,7 @@ async fn connect( sockaddr: SocketAddr, timeout_dur: Duration, host: &str, - cert_store: &CertificateStore, - mode: &Mode, + tls_connector: Option<&tokio_rustls::TlsConnector>, ) -> Result { let socket = TcpStream::connect(sockaddr); let timeout = tokio::time::sleep(timeout_dur); @@ -277,11 +306,9 @@ async fn connect( if let Err(err) = socket.set_nodelay(true) { tracing::warn!("set nodelay failed: {:?}", err); } - match mode { - Mode::Plain => Ok(EitherStream::Plain(socket)), - Mode::Tls => { - // TODO(niklasad1): cache this. - let connector = build_tls_config(cert_store)?; + match tls_connector { + None => Ok(EitherStream::Plain(socket)), + Some(connector) => { let server_name: tokio_rustls::rustls::ServerName = host.try_into().map_err(|e| WsHandshakeError::Url(format!("Invalid host: {} {:?}", host, e).into()))?; let tls_stream = connector.connect(server_name, socket).await?; Ok(EitherStream::Tls(tls_stream)) @@ -293,12 +320,7 @@ async fn connect( } #[cfg(not(feature = "tls"))] -async fn connect( - sockaddr: SocketAddr, - timeout_dur: Duration, - host: &str, - cert_store: &CertificateStore, -) -> Result { +async fn connect(sockaddr: SocketAddr, timeout_dur: Duration) -> Result { let socket = TcpStream::connect(sockaddr); let timeout = tokio::time::sleep(timeout_dur); tokio::select! { @@ -378,7 +400,7 @@ impl TryFrom for Target { // NOTE: this is slow and should be used sparingly. #[cfg(feature = "tls")] fn build_tls_config(cert_store: &CertificateStore) -> Result { - use tokio_rustls::rustls as rustls; + use tokio_rustls::rustls; let mut roots = rustls::RootCertStore::empty(); @@ -412,7 +434,7 @@ fn build_tls_config(cert_store: &CertificateStore) -> Result Date: Mon, 6 Dec 2021 13:56:50 +0100 Subject: [PATCH 4/5] fix tests + remove url dep --- http-client/Cargo.toml | 1 - http-client/src/transport.rs | 81 +++++++++++++++++++++++++++++++++--- ws-client/src/transport.rs | 28 +++++++++---- 3 files changed, 96 insertions(+), 14 deletions(-) diff --git a/http-client/Cargo.toml b/http-client/Cargo.toml index 1cbaaa3341..e9b41e0c1a 100644 --- a/http-client/Cargo.toml +++ b/http-client/Cargo.toml @@ -21,7 +21,6 @@ serde_json = "1.0" thiserror = "1.0" tokio = { version = "1.8", features = ["time"] } tracing = "0.1" -url = "2.2" [dev-dependencies] jsonrpsee-test-utils = { path = "../test-utils" } diff --git a/http-client/src/transport.rs b/http-client/src/transport.rs index 75a474aa6b..7137c6ba11 100644 --- a/http-client/src/transport.rs +++ b/http-client/src/transport.rs @@ -8,6 +8,7 @@ use crate::types::error::GenericTransportError; use hyper::client::{Client, HttpConnector}; +use hyper::Uri; use jsonrpsee_types::CertificateStore; use jsonrpsee_utils::http_helpers; use thiserror::Error; @@ -37,7 +38,7 @@ impl HyperClient { #[derive(Debug, Clone)] pub(crate) struct HttpTransportClient { /// Target to connect to. - target: url::Url, + target: Uri, /// HTTP client client: HyperClient, /// Configurable max request body size @@ -51,15 +52,19 @@ impl HttpTransportClient { max_request_body_size: u32, cert_store: CertificateStore, ) -> Result { - let target = url::Url::parse(target.as_ref()).map_err(|e| Error::Url(format!("Invalid URL: {}", e)))?; - let client = match target.scheme() { - "http" => { + let target: Uri = target.as_ref().parse().map_err(|e| Error::Url(format!("Invalid URL: {}", e)))?; + if target.port_u16().is_none() { + return Err(Error::Url("Port number is missing in the URL".into())); + } + + let client = match target.scheme_str() { + Some("http") => { let connector = HttpConnector::new(); let client = Client::builder().build::<_, hyper::Body>(connector); HyperClient::Http(client) } #[cfg(feature = "tls")] - "https" => { + Some("https") => { let connector = match cert_store { CertificateStore::Native => { hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_or_http().enable_http1() @@ -90,7 +95,9 @@ impl HttpTransportClient { return Err(Error::RequestTooLarge); } - let req = hyper::Request::post(self.target.as_str()) + // TODO(niklasad1): this annoying we could just take `&str` here but more user-friendly to check + // that the uri is well-formed in the constructor. + let req = hyper::Request::post(self.target.clone()) .header(hyper::header::CONTENT_TYPE, hyper::header::HeaderValue::from_static(CONTENT_TYPE_JSON)) .header(hyper::header::ACCEPT, hyper::header::HeaderValue::from_static(CONTENT_TYPE_JSON)) .body(From::from(body)) @@ -167,12 +174,74 @@ where mod tests { use super::{CertificateStore, Error, HttpTransportClient}; + fn assert_target( + client: &HttpTransportClient, + host: &str, + scheme: &str, + path_and_query: &str, + port: u16, + max_request_size: u32, + ) { + assert_eq!(client.target.scheme_str(), Some(scheme)); + assert_eq!(client.target.path_and_query().map(|pq| pq.as_str()), Some(path_and_query)); + assert_eq!(client.target.host(), Some(host)); + assert_eq!(client.target.port_u16(), Some(port)); + assert_eq!(client.max_request_body_size, max_request_size); + } + #[test] fn invalid_http_url_rejected() { let err = HttpTransportClient::new("ws://localhost:9933", 80, CertificateStore::Native).unwrap_err(); assert!(matches!(err, Error::Url(_))); } + #[cfg(feature = "tls")] + #[test] + fn https_works() { + let client = HttpTransportClient::new("https://localhost:9933", 80, CertificateStore::Native).unwrap(); + assert_target(&client, "localhost", "https", "/", 9933, 80); + } + + #[cfg(not(feature = "tls"))] + #[test] + fn https_fails_without_tls_feature() { + let err = HttpTransportClient::new("https://localhost:9933", 80, CertificateStore::Native).unwrap_err(); + assert!(matches!(err, Error::Url(_))); + } + + #[test] + fn faulty_port() { + let err = HttpTransportClient::new("http://localhost:-43", 80, CertificateStore::Native).unwrap_err(); + assert!(matches!(err, Error::Url(_))); + let err = HttpTransportClient::new("http://localhost:-99999", 80, CertificateStore::Native).unwrap_err(); + assert!(matches!(err, Error::Url(_))); + } + + #[test] + fn url_with_path_works() { + let client = + HttpTransportClient::new("http://localhost:9944/my-special-path", 1337, CertificateStore::Native).unwrap(); + assert_target(&client, "localhost", "http", "/my-special-path", 9944, 1337); + } + + #[test] + fn url_with_query_works() { + let client = HttpTransportClient::new( + "http://127.0.0.1:9999/my?name1=value1&name2=value2", + u32::MAX, + CertificateStore::WebPki, + ) + .unwrap(); + assert_target(&client, "127.0.0.1", "http", "/my?name1=value1&name2=value2", 9999, u32::MAX); + } + + #[test] + fn url_with_fragment_is_ignored() { + let client = + HttpTransportClient::new("http://127.0.0.1:9944/my.htm#ignore", 999, CertificateStore::Native).unwrap(); + assert_target(&client, "127.0.0.1", "http", "/my.htm", 9944, 999); + } + #[tokio::test] async fn request_limit_works() { let eighty_bytes_limit = 80; diff --git a/ws-client/src/transport.rs b/ws-client/src/transport.rs index 483e26b064..f33885422d 100644 --- a/ws-client/src/transport.rs +++ b/ws-client/src/transport.rs @@ -383,7 +383,13 @@ impl TryFrom for Target { Some("ws") => Mode::Plain, #[cfg(feature = "tls")] Some("wss") => Mode::Tls, - _ => return Err(WsHandshakeError::Url("URL scheme not supported, expects 'ws' or 'wss'".into())), + _ => { + #[cfg(feature = "tls")] + let err = "URL scheme not supported, expects 'ws' or 'wss'"; + #[cfg(not(feature = "tls"))] + let err = "URL scheme not supported, expects 'ws'"; + return Err(WsHandshakeError::Url(err.into())); + } }; let host = uri.host().map(ToOwned::to_owned).ok_or_else(|| WsHandshakeError::Url("No host in URL".into()))?; let port = uri @@ -460,12 +466,20 @@ mod tests { assert_ws_target(target, "127.0.0.1", "127.0.0.1:9933", Mode::Plain, "/"); } + #[cfg(feature = "tls")] #[test] fn wss_works() { let target = parse_target("wss://kusama-rpc.polkadot.io:443").unwrap(); assert_ws_target(target, "kusama-rpc.polkadot.io", "kusama-rpc.polkadot.io:443", Mode::Tls, "/"); } + #[cfg(not(feature = "tls"))] + #[test] + fn wss_fails_with_tls_feature() { + let err = parse_target("wss://kusama-rpc.polkadot.io:443").unwrap_err(); + assert!(matches!(err, WsHandshakeError::Url(_))); + } + #[test] fn faulty_url_scheme() { let err = parse_target("http://kusama-rpc.polkadot.io:443").unwrap_err(); @@ -482,19 +496,19 @@ mod tests { #[test] fn url_with_path_works() { - let target = parse_target("wss://127.0.0.1:443/my-special-path").unwrap(); - assert_ws_target(target, "127.0.0.1", "127.0.0.1:443", Mode::Tls, "/my-special-path"); + let target = parse_target("ws://127.0.0.1:443/my-special-path").unwrap(); + assert_ws_target(target, "127.0.0.1", "127.0.0.1:443", Mode::Plain, "/my-special-path"); } #[test] fn url_with_query_works() { - let target = parse_target("wss://127.0.0.1:443/my?name1=value1&name2=value2").unwrap(); - assert_ws_target(target, "127.0.0.1", "127.0.0.1:443", Mode::Tls, "/my?name1=value1&name2=value2"); + let target = parse_target("ws://127.0.0.1:443/my?name1=value1&name2=value2").unwrap(); + assert_ws_target(target, "127.0.0.1", "127.0.0.1:443", Mode::Plain, "/my?name1=value1&name2=value2"); } #[test] fn url_with_fragment_is_ignored() { - let target = parse_target("wss://127.0.0.1:443/my.htm#ignore").unwrap(); - assert_ws_target(target, "127.0.0.1", "127.0.0.1:443", Mode::Tls, "/my.htm"); + let target = parse_target("ws://127.0.0.1:443/my.htm#ignore").unwrap(); + assert_ws_target(target, "127.0.0.1", "127.0.0.1:443", Mode::Plain, "/my.htm"); } } From 6b10b435453d88b0944bd0b28ebfea59febbfd15 Mon Sep 17 00:00:00 2001 From: Niklas Date: Mon, 6 Dec 2021 15:12:19 +0100 Subject: [PATCH 5/5] fix tests again --- http-client/src/transport.rs | 4 ++-- tests/tests/integration_tests.rs | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/http-client/src/transport.rs b/http-client/src/transport.rs index 7137c6ba11..73daa3a703 100644 --- a/http-client/src/transport.rs +++ b/http-client/src/transport.rs @@ -95,8 +95,8 @@ impl HttpTransportClient { return Err(Error::RequestTooLarge); } - // TODO(niklasad1): this annoying we could just take `&str` here but more user-friendly to check - // that the uri is well-formed in the constructor. + // NOTE(niklasad1): this annoying we could just take `&str` here but more user-friendly to check + // that the URI is well-formed in the constructor. let req = hyper::Request::post(self.target.clone()) .header(hyper::header::CONTENT_TYPE, hyper::header::HeaderValue::from_static(CONTENT_TYPE_JSON)) .header(hyper::header::ACCEPT, hyper::header::HeaderValue::from_static(CONTENT_TYPE_JSON)) diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index 267a8205ca..8070b42ba5 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -258,8 +258,7 @@ async fn ws_with_non_ascii_url_doesnt_hang_or_panic() { #[tokio::test] async fn http_with_non_ascii_url_doesnt_hang_or_panic() { - let client = HttpClientBuilder::default().build("http://♥♥♥♥♥♥∀∂").unwrap(); - let err: Result<(), Error> = client.request("system_chain", None).await; + let err = HttpClientBuilder::default().build("http://♥♥♥♥♥♥∀∂"); assert!(matches!(err, Err(Error::Transport(_)))); }