diff --git a/server/src/server.rs b/server/src/server.rs index 1d11d60394..f2b0fdb277 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -155,6 +155,8 @@ where conn_id: id, logger: logger.clone(), max_connections: self.cfg.max_connections, + enable_http: self.cfg.enable_http, + enable_ws: self.cfg.enable_ws, }; process_connection(&self.service_builder, &connection_guard, data, socket, &mut connections); id = id.wrapping_add(1); @@ -193,6 +195,10 @@ struct Settings { tokio_runtime: Option, /// The interval at which `Ping` frames are submitted. ping_interval: Duration, + /// Enable HTTP. + enable_http: bool, + /// Enable WS. + enable_ws: bool, } impl Default for Settings { @@ -207,6 +213,8 @@ impl Default for Settings { allow_hosts: AllowHosts::Any, tokio_runtime: None, ping_interval: Duration::from_secs(60), + enable_http: true, + enable_ws: true, } } } @@ -425,6 +433,26 @@ impl Builder { } } + /// Configure the server to only serve JSON-RPC HTTP requests. + /// + /// Default: both http and ws are enabled. + pub fn http_only(mut self) -> Self { + self.settings.enable_http = true; + self.settings.enable_ws = false; + self + } + + /// Configure the server to only serve JSON-RPC WebSocket requests. + /// + /// That implies that server just denies HTTP requests which isn't a WebSocket upgrade request + /// + /// Default: both http and ws are enabled. + pub fn ws_only(mut self) -> Self { + self.settings.enable_http = false; + self.settings.enable_ws = true; + self + } + /// Finalize the configuration of the server. Consumes the [`Builder`]. /// /// ```rust @@ -547,6 +575,10 @@ pub(crate) struct ServiceData { pub(crate) logger: L, /// Handle to hold a `connection permit`. pub(crate) conn: Arc, + /// Enable HTTP. + pub(crate) enable_http: bool, + /// Enable WS. + pub(crate) enable_ws: bool, } /// JsonRPSee service compatible with `tower`. @@ -589,7 +621,9 @@ impl hyper::service::Service> for TowerSe return async { Ok(http::response::host_not_allowed()) }.boxed(); } - if is_upgrade_request(&request) { + let is_upgrade_request = is_upgrade_request(&request); + + if self.inner.enable_ws && is_upgrade_request { let mut server = soketto::handshake::http::Server::new(); let response = match server.receive_request(&request) { @@ -626,7 +660,7 @@ impl hyper::service::Service> for TowerSe }; async { Ok(response) }.boxed() - } else { + } else if self.inner.enable_http && !is_upgrade_request { // The request wasn't an upgrade request; let's treat it as a standard HTTP request: let data = http::HandleRequest { methods: self.inner.methods.clone(), @@ -643,6 +677,8 @@ impl hyper::service::Service> for TowerSe self.inner.logger.on_connect(self.inner.remote_addr, &request, TransportProtocol::Http); Box::pin(http::handle_request(request, data).map(Ok)) + } else { + Box::pin(async { http::response::denied() }.map(Ok)) } } } @@ -730,6 +766,10 @@ struct ProcessConnection { conn_id: u32, /// Logger. logger: L, + /// Allow JSON-RPC HTTP requests. + enable_http: bool, + /// Allow JSON-RPC WS request and WS upgrade requests. + enable_ws: bool, } #[instrument(name = "connection", skip_all, fields(remote_addr = %cfg.remote_addr, conn_id = %cfg.conn_id), level = "INFO")] @@ -787,6 +827,8 @@ fn process_connection<'a, L: Logger, B, U>( conn_id: cfg.conn_id, logger: cfg.logger, conn: Arc::new(conn), + enable_http: cfg.enable_http, + enable_ws: cfg.enable_ws, }, }; diff --git a/server/src/tests/shared.rs b/server/src/tests/shared.rs index d756cb15e5..bb25395a13 100644 --- a/server/src/tests/shared.rs +++ b/server/src/tests/shared.rs @@ -1,6 +1,11 @@ use crate::tests::helpers::{init_logger, server_with_handles}; +use http::StatusCode; use jsonrpsee_core::Error; -use jsonrpsee_test_utils::TimeoutFutureExt; +use jsonrpsee_test_utils::{ + helpers::{http_request, ok_response, to_http_uri}, + mocks::{Id, WebSocketTestClient, WebSocketTestError}, + TimeoutFutureExt, +}; use std::time::Duration; #[tokio::test] @@ -35,3 +40,54 @@ async fn run_forever() { // Send the shutdown request from one handle and await the server on the second one. server_handle.stopped().with_timeout(TIMEOUT).await.unwrap(); } + +#[tokio::test] +async fn http_only_works() { + use crate::{RpcModule, ServerBuilder}; + + let server = + ServerBuilder::default().http_only().build("127.0.0.1:0").with_default_timeout().await.unwrap().unwrap(); + let mut module = RpcModule::new(()); + module + .register_method("say_hello", |_, _| { + tracing::debug!("server respond to hello"); + Ok("hello") + }) + .unwrap(); + + let addr = server.local_addr().unwrap(); + let _server_handle = server.start(module).unwrap(); + + let req = r#"{"jsonrpc":"2.0","method":"say_hello","id":1}"#; + let response = http_request(req.into(), to_http_uri(addr)).with_default_timeout().await.unwrap().unwrap(); + assert_eq!(response.status, StatusCode::OK); + assert_eq!(response.body, ok_response("hello".to_string().into(), Id::Num(1))); + + let err = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap_err(); + assert!(matches!(err, WebSocketTestError::RejectedWithStatusCode(code) if code == 403)); +} + +#[tokio::test] +async fn ws_only_works() { + use crate::{RpcModule, ServerBuilder}; + + let server = ServerBuilder::default().ws_only().build("127.0.0.1:0").with_default_timeout().await.unwrap().unwrap(); + let mut module = RpcModule::new(()); + module + .register_method("say_hello", |_, _| { + tracing::debug!("server respond to hello"); + Ok("hello") + }) + .unwrap(); + + let addr = server.local_addr().unwrap(); + let _server_handle = server.start(module).unwrap(); + + let req = r#"{"jsonrpc":"2.0","method":"say_hello","id":1}"#; + let response = http_request(req.into(), to_http_uri(addr)).with_default_timeout().await.unwrap().unwrap(); + assert_eq!(response.status, StatusCode::FORBIDDEN); + + let mut client = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap(); + let response = client.send_request_text(req.to_string()).await.unwrap(); + assert_eq!(response, ok_response("hello".to_string().into(), Id::Num(1))); +} diff --git a/server/src/transport/http.rs b/server/src/transport/http.rs index 8ddeb74d87..ad693bd708 100644 --- a/server/src/transport/http.rs +++ b/server/src/transport/http.rs @@ -426,4 +426,9 @@ pub(crate) mod response { TEXT, ) } + + /// Create a response for when the server denied the request. + pub(crate) fn denied() -> hyper::Response { + from_template(hyper::StatusCode::FORBIDDEN, "".to_owned(), TEXT) + } }