Skip to content

Commit

Permalink
rpc module: report error on invalid subscription (#561)
Browse files Browse the repository at this point in the history
* rpc module: report error on invalid subscription

* fix tests

* remove some boiler plate

* remove unused code
  • Loading branch information
niklasad1 authored Nov 18, 2021
1 parent 6af6db2 commit fff8460
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 5 deletions.
10 changes: 10 additions & 0 deletions test-utils/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
use crate::mocks::{Body, HttpResponse, Id, Uri};
use hyper::service::{make_service_fn, service_fn};
use hyper::{Request, Response, Server};
use serde::Serialize;
use serde_json::Value;
use std::convert::Infallible;
use std::net::SocketAddr;
Expand Down Expand Up @@ -95,6 +96,15 @@ pub fn invalid_params(id: Id) -> String {
)
}

pub fn call<T: Serialize>(method: &str, params: Vec<T>, id: Id) -> String {
format!(
r#"{{"jsonrpc":"2.0","method":{},"params":{},"id":{}}}"#,
serde_json::to_string(method).unwrap(),
serde_json::to_string(&params).unwrap(),
serde_json::to_string(&id).unwrap()
)
}

pub fn call_execution_failed(msg: &str, id: Id) -> String {
format!(
r#"{{"jsonrpc":"2.0","error":{{"code":-32000,"message":"{}"}},"id":{}}}"#,
Expand Down
23 changes: 22 additions & 1 deletion types/src/v2/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ pub struct RpcError<'a> {
pub id: Id<'a>,
}

impl<'a> RpcError<'a> {
/// Create a new `RpcError`.
pub fn new(error: ErrorObject<'a>, id: Id<'a>) -> Self {
Self { jsonrpc: TwoPointZero, error, id }
}
}

impl<'a> fmt::Display for RpcError<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", serde_json::to_string(&self).expect("infallible; qed"))
Expand All @@ -64,6 +71,13 @@ pub struct ErrorObject<'a> {
pub data: Option<&'a RawValue>,
}

impl<'a> ErrorObject<'a> {
/// Create a new `ErrorObject` with optional data.
pub fn new(code: ErrorCode, data: Option<&'a RawValue>) -> ErrorObject<'a> {
Self { code, message: code.message(), data }
}
}

impl<'a> From<ErrorCode> for ErrorObject<'a> {
fn from(code: ErrorCode) -> Self {
Self { code, message: code.message(), data: None }
Expand All @@ -73,7 +87,7 @@ impl<'a> From<ErrorCode> for ErrorObject<'a> {
impl<'a> PartialEq for ErrorObject<'a> {
fn eq(&self, other: &Self) -> bool {
let this_raw = self.data.map(|r| r.get());
let other_raw = self.data.map(|r| r.get());
let other_raw = other.data.map(|r| r.get());
self.code == other.code && self.message == other.message && this_raw == other_raw
}
}
Expand All @@ -98,6 +112,8 @@ pub const SERVER_IS_BUSY_CODE: i32 = -32604;
pub const CALL_EXECUTION_FAILED_CODE: i32 = -32000;
/// Unknown error.
pub const UNKNOWN_ERROR_CODE: i32 = -32001;
/// Invalid subscription error code.
pub const INVALID_SUBSCRIPTION_CODE: i32 = -32002;

/// Parse error message
pub const PARSE_ERROR_MSG: &str = "Parse error";
Expand Down Expand Up @@ -212,6 +228,11 @@ impl serde::Serialize for ErrorCode {
}
}

/// Create a invalid subscription ID error.
pub fn invalid_subscription_err(data: Option<&RawValue>) -> ErrorObject {
ErrorObject::new(ErrorCode::ServerError(INVALID_SUBSCRIPTION_CODE), data)
}

#[cfg(test)]
mod tests {
use super::{ErrorCode, ErrorObject, Id, RpcError, TwoPointZero};
Expand Down
17 changes: 13 additions & 4 deletions utils/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ use crate::server::resource_limiting::{ResourceGuard, ResourceTable, ResourceVec
use beef::Cow;
use futures_channel::{mpsc, oneshot};
use futures_util::{future::BoxFuture, FutureExt, StreamExt};
use jsonrpsee_types::to_json_raw_value;
use jsonrpsee_types::v2::error::{invalid_subscription_err, CALL_EXECUTION_FAILED_CODE};
use jsonrpsee_types::{
error::{Error, SubscriptionClosedError},
traits::ToRpcParams,
Expand Down Expand Up @@ -587,7 +589,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
err,
id
);
send_error(id, method_sink, ErrorCode::ServerError(-1).into());
send_error(id, method_sink, ErrorCode::ServerError(CALL_EXECUTION_FAILED_CODE).into());
}
})),
);
Expand All @@ -605,12 +607,18 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
unsubscribe_method_name,
id
);
send_error(id, tx, ErrorCode::ServerError(-1).into());
let err = to_json_raw_value(&"Invalid subscription ID type, must be integer").ok();
send_error(id, tx, invalid_subscription_err(err.as_deref()));
return;
}
};
subscribers.lock().remove(&SubscriptionKey { conn_id, sub_id });
send_response(id, tx, "Unsubscribed", max_response_size);

if subscribers.lock().remove(&SubscriptionKey { conn_id, sub_id }).is_some() {
send_response(id, tx, "Unsubscribed", max_response_size);
} else {
let err = to_json_raw_value(&format!("Invalid subscription ID={}", sub_id)).ok();
send_error(id, tx, invalid_subscription_err(err.as_deref()))
}
})),
);
}
Expand Down Expand Up @@ -698,6 +706,7 @@ impl SubscriptionSink {
fn inner_close(&mut self, err: &SubscriptionClosedError) {
self.is_connected.take();
if let Some((sink, _)) = self.subscribers.lock().remove(&self.uniq_sub) {
tracing::debug!("Closing subscription: {:?}", self.uniq_sub.sub_id);
let msg = self.build_message(err).expect("valid json infallible; qed");
let _ = sink.unbounded_send(msg);
}
Expand Down
53 changes: 53 additions & 0 deletions ws-server/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,16 @@
#![cfg(test)]

use crate::types::error::{CallError, Error};
use crate::types::v2::{self, Response, RpcError};
use crate::types::DeserializeOwned;
use crate::{future::ServerHandle, RpcModule, WsServerBuilder};
use anyhow::anyhow;
use futures_util::future::join;
use jsonrpsee_test_utils::helpers::*;
use jsonrpsee_test_utils::mocks::{Id, TestContext, WebSocketTestClient, WebSocketTestError};
use jsonrpsee_test_utils::TimeoutFutureExt;
use jsonrpsee_types::to_json_raw_value;
use jsonrpsee_types::v2::error::invalid_subscription_err;
use serde_json::Value as JsonValue;
use std::{fmt, net::SocketAddr, time::Duration};
use tracing_subscriber::{EnvFilter, FmtSubscriber};
Expand All @@ -41,6 +45,11 @@ fn init_logger() {
let _ = FmtSubscriber::builder().with_env_filter(EnvFilter::from_default_env()).try_init();
}

fn deser_call<T: DeserializeOwned>(raw: String) -> T {
let out: Response<T> = serde_json::from_str(&raw).unwrap();
out.result
}

/// Applications can/should provide their own error.
#[derive(Debug)]
struct MyAppError;
Expand Down Expand Up @@ -107,6 +116,15 @@ async fn server_with_handles() -> (SocketAddr, ServerHandle) {
Ok("Yawn!")
})
.unwrap();
module
.register_subscription("subscribe_hello", "unsubscribe_hello", |_, sink, _| {
std::thread::spawn(move || loop {
let _ = sink;
std::thread::sleep(std::time::Duration::from_secs(30));
});
Ok(())
})
.unwrap();

let addr = server.local_addr().unwrap();

Expand Down Expand Up @@ -569,3 +587,38 @@ async fn run_forever() {
// Send the shutdown request from one handle and await the server on the second one.
join(server_handle.clone().stop().unwrap(), server_handle).with_timeout(TIMEOUT).await.unwrap();
}

#[tokio::test]
async fn unsubscribe_twice_should_indicate_error() {
init_logger();
let addr = server().await;
let mut client = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap();

let sub_call = call("subscribe_hello", Vec::<()>::new(), Id::Num(0));
let sub_id: u64 = deser_call(client.send_request_text(sub_call).await.unwrap());

let unsub_call = call("unsubscribe_hello", vec![sub_id], Id::Num(1));
let unsub_1: String = deser_call(client.send_request_text(unsub_call).await.unwrap());
assert_eq!(&unsub_1, "Unsubscribed");

let unsub_call = call("unsubscribe_hello", vec![sub_id], Id::Num(2));
let unsub_2 = client.send_request_text(unsub_call).await.unwrap();
let unsub_2_err: RpcError = serde_json::from_str(&unsub_2).unwrap();
let sub_id = to_json_raw_value(&sub_id).unwrap();

let err = Some(to_json_raw_value(&format!("Invalid subscription ID={}", sub_id)).unwrap());
assert_eq!(unsub_2_err, RpcError::new(invalid_subscription_err(err.as_deref()), v2::Id::Number(2)));
}

#[tokio::test]
async fn unsubscribe_wrong_sub_id_type() {
init_logger();
let addr = server().await;
let mut client = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap();

let unsub =
client.send_request_text(call("unsubscribe_hello", vec!["string_is_not_supported"], Id::Num(0))).await.unwrap();
let unsub_2_err: RpcError = serde_json::from_str(&unsub).unwrap();
let err = Some(to_json_raw_value(&"Invalid subscription ID type, must be integer").unwrap());
assert_eq!(unsub_2_err, RpcError::new(invalid_subscription_err(err.as_deref()), v2::Id::Number(0)));
}

0 comments on commit fff8460

Please sign in to comment.