Skip to content

Commit

Permalink
server: make it possible to enable ws/http only (#939)
Browse files Browse the repository at this point in the history
* server: make it possible to enable ws/http only

* Update server/src/server.rs

* Update server/src/server.rs

* Update server/src/server.rs

Co-authored-by: Alexandru Vasile <[email protected]>

Co-authored-by: Alexandru Vasile <[email protected]>
  • Loading branch information
niklasad1 and lexnv authored Nov 18, 2022
1 parent 63e7124 commit 0a22c00
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 3 deletions.
46 changes: 44 additions & 2 deletions server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ where
conn_id: id,
logger: logger.clone(),
max_connections: self.cfg.max_connections,
enable_http: self.cfg.enable_http,
enable_ws: self.cfg.enable_ws,
};
process_connection(&self.service_builder, &connection_guard, data, socket, &mut connections);
id = id.wrapping_add(1);
Expand Down Expand Up @@ -193,6 +195,10 @@ struct Settings {
tokio_runtime: Option<tokio::runtime::Handle>,
/// The interval at which `Ping` frames are submitted.
ping_interval: Duration,
/// Enable HTTP.
enable_http: bool,
/// Enable WS.
enable_ws: bool,
}

impl Default for Settings {
Expand All @@ -207,6 +213,8 @@ impl Default for Settings {
allow_hosts: AllowHosts::Any,
tokio_runtime: None,
ping_interval: Duration::from_secs(60),
enable_http: true,
enable_ws: true,
}
}
}
Expand Down Expand Up @@ -425,6 +433,26 @@ impl<B, L> Builder<B, L> {
}
}

/// Configure the server to only serve JSON-RPC HTTP requests.
///
/// Default: both http and ws are enabled.
pub fn http_only(mut self) -> Self {
self.settings.enable_http = true;
self.settings.enable_ws = false;
self
}

/// Configure the server to only serve JSON-RPC WebSocket requests.
///
/// That implies that server just denies HTTP requests which isn't a WebSocket upgrade request
///
/// Default: both http and ws are enabled.
pub fn ws_only(mut self) -> Self {
self.settings.enable_http = false;
self.settings.enable_ws = true;
self
}

/// Finalize the configuration of the server. Consumes the [`Builder`].
///
/// ```rust
Expand Down Expand Up @@ -547,6 +575,10 @@ pub(crate) struct ServiceData<L: Logger> {
pub(crate) logger: L,
/// Handle to hold a `connection permit`.
pub(crate) conn: Arc<OwnedSemaphorePermit>,
/// Enable HTTP.
pub(crate) enable_http: bool,
/// Enable WS.
pub(crate) enable_ws: bool,
}

/// JsonRPSee service compatible with `tower`.
Expand Down Expand Up @@ -589,7 +621,9 @@ impl<L: Logger> hyper::service::Service<hyper::Request<hyper::Body>> for TowerSe
return async { Ok(http::response::host_not_allowed()) }.boxed();
}

if is_upgrade_request(&request) {
let is_upgrade_request = is_upgrade_request(&request);

if self.inner.enable_ws && is_upgrade_request {
let mut server = soketto::handshake::http::Server::new();

let response = match server.receive_request(&request) {
Expand Down Expand Up @@ -626,7 +660,7 @@ impl<L: Logger> hyper::service::Service<hyper::Request<hyper::Body>> for TowerSe
};

async { Ok(response) }.boxed()
} else {
} else if self.inner.enable_http && !is_upgrade_request {
// The request wasn't an upgrade request; let's treat it as a standard HTTP request:
let data = http::HandleRequest {
methods: self.inner.methods.clone(),
Expand All @@ -643,6 +677,8 @@ impl<L: Logger> hyper::service::Service<hyper::Request<hyper::Body>> for TowerSe
self.inner.logger.on_connect(self.inner.remote_addr, &request, TransportProtocol::Http);

Box::pin(http::handle_request(request, data).map(Ok))
} else {
Box::pin(async { http::response::denied() }.map(Ok))
}
}
}
Expand Down Expand Up @@ -730,6 +766,10 @@ struct ProcessConnection<L> {
conn_id: u32,
/// Logger.
logger: L,
/// Allow JSON-RPC HTTP requests.
enable_http: bool,
/// Allow JSON-RPC WS request and WS upgrade requests.
enable_ws: bool,
}

#[instrument(name = "connection", skip_all, fields(remote_addr = %cfg.remote_addr, conn_id = %cfg.conn_id), level = "INFO")]
Expand Down Expand Up @@ -787,6 +827,8 @@ fn process_connection<'a, L: Logger, B, U>(
conn_id: cfg.conn_id,
logger: cfg.logger,
conn: Arc::new(conn),
enable_http: cfg.enable_http,
enable_ws: cfg.enable_ws,
},
};

Expand Down
58 changes: 57 additions & 1 deletion server/src/tests/shared.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
use crate::tests::helpers::{init_logger, server_with_handles};
use http::StatusCode;
use jsonrpsee_core::Error;
use jsonrpsee_test_utils::TimeoutFutureExt;
use jsonrpsee_test_utils::{
helpers::{http_request, ok_response, to_http_uri},
mocks::{Id, WebSocketTestClient, WebSocketTestError},
TimeoutFutureExt,
};
use std::time::Duration;

#[tokio::test]
Expand Down Expand Up @@ -35,3 +40,54 @@ async fn run_forever() {
// Send the shutdown request from one handle and await the server on the second one.
server_handle.stopped().with_timeout(TIMEOUT).await.unwrap();
}

#[tokio::test]
async fn http_only_works() {
use crate::{RpcModule, ServerBuilder};

let server =
ServerBuilder::default().http_only().build("127.0.0.1:0").with_default_timeout().await.unwrap().unwrap();
let mut module = RpcModule::new(());
module
.register_method("say_hello", |_, _| {
tracing::debug!("server respond to hello");
Ok("hello")
})
.unwrap();

let addr = server.local_addr().unwrap();
let _server_handle = server.start(module).unwrap();

let req = r#"{"jsonrpc":"2.0","method":"say_hello","id":1}"#;
let response = http_request(req.into(), to_http_uri(addr)).with_default_timeout().await.unwrap().unwrap();
assert_eq!(response.status, StatusCode::OK);
assert_eq!(response.body, ok_response("hello".to_string().into(), Id::Num(1)));

let err = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap_err();
assert!(matches!(err, WebSocketTestError::RejectedWithStatusCode(code) if code == 403));
}

#[tokio::test]
async fn ws_only_works() {
use crate::{RpcModule, ServerBuilder};

let server = ServerBuilder::default().ws_only().build("127.0.0.1:0").with_default_timeout().await.unwrap().unwrap();
let mut module = RpcModule::new(());
module
.register_method("say_hello", |_, _| {
tracing::debug!("server respond to hello");
Ok("hello")
})
.unwrap();

let addr = server.local_addr().unwrap();
let _server_handle = server.start(module).unwrap();

let req = r#"{"jsonrpc":"2.0","method":"say_hello","id":1}"#;
let response = http_request(req.into(), to_http_uri(addr)).with_default_timeout().await.unwrap().unwrap();
assert_eq!(response.status, StatusCode::FORBIDDEN);

let mut client = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap();
let response = client.send_request_text(req.to_string()).await.unwrap();
assert_eq!(response, ok_response("hello".to_string().into(), Id::Num(1)));
}
5 changes: 5 additions & 0 deletions server/src/transport/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,4 +426,9 @@ pub(crate) mod response {
TEXT,
)
}

/// Create a response for when the server denied the request.
pub(crate) fn denied() -> hyper::Response<hyper::Body> {
from_template(hyper::StatusCode::FORBIDDEN, "".to_owned(), TEXT)
}
}

0 comments on commit 0a22c00

Please sign in to comment.