Skip to content

Commit

Permalink
fix(ws server): batch wait until all methods has been executed. (#542)
Browse files Browse the repository at this point in the history
* reproduce Kian's issue

* fix ws server wait until batches has completed

* fix nit

* clippify

* enable benches for ws batch requests

* use stream instead of futures::join_all

* clippify

* address grumbles: better assert
  • Loading branch information
niklasad1 authored Nov 1, 2021
1 parent 6815422 commit 092081a
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 31 deletions.
6 changes: 2 additions & 4 deletions benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,14 @@ criterion_group!(
SyncBencher::http_requests,
SyncBencher::batched_http_requests,
SyncBencher::websocket_requests,
// TODO: https://github.com/paritytech/jsonrpsee/issues/528
// SyncBencher::batched_ws_requests,
SyncBencher::batched_ws_requests,
);
criterion_group!(
async_benches,
AsyncBencher::http_requests,
AsyncBencher::batched_http_requests,
AsyncBencher::websocket_requests,
// TODO: https://github.com/paritytech/jsonrpsee/issues/528
// AsyncBencher::batched_ws_requests
AsyncBencher::batched_ws_requests
);
criterion_group!(subscriptions, AsyncBencher::subscriptions);
criterion_main!(types_benches, sync_benches, async_benches, subscriptions);
Expand Down
3 changes: 2 additions & 1 deletion tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ license = "MIT"
publish = false

[dev-dependencies]
env_logger = "0.8"
beef = { version = "0.5.1", features = ["impl_serde"] }
futures = { version = "0.3.14", default-features = false, features = ["std"] }
jsonrpsee = { path = "../jsonrpsee", features = ["full"] }
tokio = { version = "1", features = ["full"] }
serde_json = "1"
tracing = "0.1"
serde_json = "1"
7 changes: 7 additions & 0 deletions tests/tests/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,13 @@ pub async fn websocket_server() -> SocketAddr {
let mut module = RpcModule::new(());
module.register_method("say_hello", |_, _| Ok("hello")).unwrap();

module
.register_async_method("slow_hello", |_, _| async {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
Ok("hello")
})
.unwrap();

let addr = server.local_addr().unwrap();

server.start(module).unwrap();
Expand Down
15 changes: 15 additions & 0 deletions tests/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,3 +360,18 @@ async fn ws_server_should_stop_subscription_after_client_drop() {
// assert that the server received `SubscriptionClosed` after the client was dropped.
assert!(matches!(rx.next().await.unwrap(), SubscriptionClosedError { .. }));
}

#[tokio::test]
async fn ws_batch_works() {
let server_addr = websocket_server().await;
let server_url = format!("ws://{}", server_addr);
let client = WsClientBuilder::default().build(&server_url).await.unwrap();

let mut batch = Vec::new();

batch.push(("say_hello", rpc_params![]));
batch.push(("slow_hello", rpc_params![]));

let responses: Vec<String> = client.batch_request(batch).await.unwrap();
assert_eq!(responses, vec!["hello".to_string(), "hello".to_string()]);
}
64 changes: 38 additions & 26 deletions ws-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ use crate::types::{
TEN_MB_SIZE_BYTES,
};
use futures_channel::mpsc;
use futures_util::future::FutureExt;
use futures_util::io::{BufReader, BufWriter};
use futures_util::stream::StreamExt;
use futures_util::stream::{self, StreamExt};
use soketto::handshake::{server::Response, Server as SokettoServer};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
Expand Down Expand Up @@ -296,34 +297,45 @@ async fn background_task(
}
}
Some(b'[') => {
if let Ok(batch) = serde_json::from_slice::<Vec<Request>>(&data) {
if !batch.is_empty() {
// Batch responses must be sent back as a single message so we read the results from each
// request in the batch and read the results off of a new channel, `rx_batch`, and then send the
// complete batch response back to the client over `tx`.
let (tx_batch, mut rx_batch) = mpsc::unbounded::<String>();

for fut in batch
.into_iter()
.filter_map(|req| methods.execute_with_resources(&tx_batch, req, conn_id, &resources))
{
method_executors.add(fut);
}

// Closes the receiving half of a channel without dropping it. This prevents any further
// messages from being sent on the channel.
rx_batch.close();
let results = collect_batch_response(rx_batch).await;
if let Err(err) = tx.unbounded_send(results) {
tracing::error!("Error sending batch response to the client: {:?}", err)
// Make sure the following variables are not moved into async closure below.
let d = std::mem::take(&mut data);
let resources = &resources;
let methods = &methods;
let tx2 = tx.clone();

let fut = async move {
// Batch responses must be sent back as a single message so we read the results from each
// request in the batch and read the results off of a new channel, `rx_batch`, and then send the
// complete batch response back to the client over `tx`.
let (tx_batch, mut rx_batch) = mpsc::unbounded();
if let Ok(batch) = serde_json::from_slice::<Vec<Request>>(&d) {
if !batch.is_empty() {
let methods_stream =
stream::iter(batch.into_iter().filter_map(|req| {
methods.execute_with_resources(&tx_batch, req, conn_id, resources)
}));

let results = methods_stream
.for_each_concurrent(None, |item| item)
.then(|_| {
rx_batch.close();
collect_batch_response(rx_batch)
})
.await;

if let Err(err) = tx2.unbounded_send(results) {
tracing::error!("Error sending batch response to the client: {:?}", err)
}
} else {
send_error(Id::Null, &tx2, ErrorCode::InvalidRequest.into());
}
} else {
send_error(Id::Null, &tx, ErrorCode::InvalidRequest.into());
let (id, code) = prepare_error(&d);
send_error(id, &tx2, code.into());
}
} else {
let (id, code) = prepare_error(&data);
send_error(id, &tx, code.into());
}
};

method_executors.add(Box::pin(fut));
}
_ => send_error(Id::Null, &tx, ErrorCode::ParseError.into()),
}
Expand Down

0 comments on commit 092081a

Please sign in to comment.