Skip to content

Commit

Permalink
Custom errors (#977)
Browse files Browse the repository at this point in the history
* Implement support for custom errors

* Remove unneded for<'a> from E bound

* Fix doctest

* Handle the case where there are not exactly two arguments

* Support for other Result paths

* Rewrite with a more explicit rewriting logic

* Back to rewriting the error argument

* Add UI error for non-result

* Apply suggestions from code review

Co-authored-by: Niklas Adolfsson <[email protected]>

* Fix a typo

* Fix errors in the rest of the targets

Co-authored-by: Niklas Adolfsson <[email protected]>
  • Loading branch information
MOZGIII and niklasad1 authored Jan 24, 2023
1 parent ed9a3ee commit dab1bfc
Show file tree
Hide file tree
Showing 14 changed files with 179 additions and 25 deletions.
12 changes: 9 additions & 3 deletions benches/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,17 @@ fn gen_rpc_module() -> jsonrpsee::RpcModule<()> {
let mut module = jsonrpsee::RpcModule::new(());

module.register_method(SYNC_FAST_CALL, |_, _| Ok("lo")).unwrap();
module.register_async_method(ASYNC_FAST_CALL, |_, _| async { Ok("lo") }).unwrap();
module
.register_async_method(ASYNC_FAST_CALL, |_, _| async { Result::<_, jsonrpsee::core::Error>::Ok("lo") })
.unwrap();

module.register_method(SYNC_MEM_CALL, |_, _| Ok("A".repeat(MIB))).unwrap();

module.register_async_method(ASYNC_MEM_CALL, |_, _| async move { Ok("A".repeat(MIB)) }).unwrap();
module
.register_async_method(ASYNC_MEM_CALL, |_, _| async move {
Result::<_, jsonrpsee::core::Error>::Ok("A".repeat(MIB))
})
.unwrap();

module
.register_method(SYNC_SLOW_CALL, |_, _| {
Expand All @@ -179,7 +185,7 @@ fn gen_rpc_module() -> jsonrpsee::RpcModule<()> {
module
.register_async_method(ASYNC_SLOW_CALL, |_, _| async move {
tokio::time::sleep(SLOW_CALL).await;
Ok("slow call async")
Result::<_, jsonrpsee::core::Error>::Ok("slow call async")
})
.unwrap();

Expand Down
2 changes: 1 addition & 1 deletion core/src/server/resource_limiting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
//! module
//! .register_async_method("my_expensive_method", |_, _| async move {
//! // Do work
//! Ok("hello")
//! Result::<_, jsonrpsee::core::Error>::Ok("hello")
//! })?
//! .resource("cpu", 5)?
//! .resource("mem", 2)?;
Expand Down
14 changes: 8 additions & 6 deletions core/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -569,14 +569,15 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
}

/// Register a new asynchronous RPC method, which computes the response with the given callback.
pub fn register_async_method<R, Fun, Fut>(
pub fn register_async_method<R, E, Fun, Fut>(
&mut self,
method_name: &'static str,
callback: Fun,
) -> Result<MethodResourcesBuilder, Error>
where
R: Serialize + Send + Sync + 'static,
Fut: Future<Output = Result<R, Error>> + Send,
E: Into<Error>,
Fut: Future<Output = Result<R, E>> + Send,
Fun: (Fn(Params<'static>, Arc<Context>) -> Fut) + Clone + Send + Sync + 'static,
{
let ctx = self.ctx.clone();
Expand All @@ -589,7 +590,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
let future = async move {
let result = match callback(params, ctx).await {
Ok(res) => MethodResponse::response(id, res, max_response_size),
Err(err) => MethodResponse::error(id, err),
Err(err) => MethodResponse::error(id, err.into()),
};

// Release claimed resources
Expand All @@ -606,15 +607,16 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {

/// Register a new **blocking** synchronous RPC method, which computes the response with the given callback.
/// Unlike the regular [`register_method`](RpcModule::register_method), this method can block its thread and perform expensive computations.
pub fn register_blocking_method<R, F>(
pub fn register_blocking_method<R, E, F>(
&mut self,
method_name: &'static str,
callback: F,
) -> Result<MethodResourcesBuilder, Error>
where
Context: Send + Sync + 'static,
R: Serialize,
F: Fn(Params, Arc<Context>) -> Result<R, Error> + Clone + Send + Sync + 'static,
E: Into<Error>,
F: Fn(Params, Arc<Context>) -> Result<R, E> + Clone + Send + Sync + 'static,
{
let ctx = self.ctx.clone();
let callback = self.methods.verify_and_insert(
Expand All @@ -626,7 +628,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
tokio::task::spawn_blocking(move || {
let result = match callback(params, ctx) {
Ok(result) => MethodResponse::response(id, result, max_response_size),
Err(err) => MethodResponse::error(id, err),
Err(err) => MethodResponse::error(id, err.into()),
};

// Release claimed resources
Expand Down
3 changes: 2 additions & 1 deletion examples/examples/tokio_console.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
use std::net::SocketAddr;

use jsonrpsee::core::Error;
use jsonrpsee::server::ServerBuilder;
use jsonrpsee::RpcModule;

Expand All @@ -55,7 +56,7 @@ async fn run_server() -> anyhow::Result<SocketAddr> {
module.register_method("memory_call", |_, _| Ok("A".repeat(1024 * 1024)))?;
module.register_async_method("sleep", |_, _| async {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
Ok("lo")
Result::<_, Error>::Ok("lo")
})?;

let addr = server.local_addr()?;
Expand Down
1 change: 1 addition & 0 deletions proc-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ trybuild = "1.0"
tokio = { version = "1.16", features = ["rt", "macros"] }
futures-channel = { version = "0.3.14", default-features = false }
futures-util = { version = "0.3.14", default-features = false }
serde_json = "1"
46 changes: 44 additions & 2 deletions proc-macros/src/render_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ use crate::attributes::ParamKind;
use crate::helpers::generate_where_clause;
use crate::rpc_macro::{RpcDescription, RpcMethod, RpcSubscription};
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{FnArg, Pat, PatIdent, PatType, TypeParam};
use quote::{quote, quote_spanned};
use syn::spanned::Spanned;
use syn::{AngleBracketedGenericArguments, FnArg, Pat, PatIdent, PatType, PathArguments, TypeParam};

impl RpcDescription {
pub(super) fn render_client(&self) -> Result<TokenStream2, syn::Error> {
Expand Down Expand Up @@ -68,6 +69,46 @@ impl RpcDescription {
Ok(trait_impl)
}

/// Verify and rewrite the return type (for methods).
fn return_result_type(&self, mut ty: syn::Type) -> TokenStream2 {
// We expect a valid type path.
let syn::Type::Path(ref mut type_path) = ty else {
return quote_spanned!(ty.span() => compile_error!("Expecting something like 'Result<Foo, Err>' here. (1)"));
};

// The path (eg std::result::Result) should have a final segment like 'Result'.
let Some(type_name) = type_path.path.segments.last_mut() else {
return quote_spanned!(ty.span() => compile_error!("Expecting this path to end in something like 'Result<Foo, Err>'"));
};

// Get the generic args eg the <T, E> in Result<T, E>.
let PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) = &mut type_name.arguments else {
return quote_spanned!(ty.span() => compile_error!("Expecting something like 'Result<Foo, Err>' here, but got no generic args (eg no '<Foo,Err>')."));
};

if type_name.ident == "Result" {
// Result<T, E> should have 2 generic args.
if args.len() != 2 {
return quote_spanned!(args.span() => compile_error!("Result must be have two arguments"));
}

// Force the last argument to be `jsonrpsee::core::Error`:
let error_arg = args.last_mut().unwrap();
*error_arg = syn::GenericArgument::Type(syn::Type::Verbatim(self.jrps_client_item(quote! { core::Error })));

quote!(#ty)
} else if type_name.ident == "RpcResult" {
// RpcResult<T> (an alias we export) should have 1 generic arg.
if args.len() != 1 {
return quote_spanned!(args.span() => compile_error!("RpcResult must have one argument"));
}
quote!(#ty)
} else {
// Any other type name isn't allowed.
quote_spanned!(type_name.span() => compile_error!("The return type must be Result or RpcResult"))
}
}

fn render_method(&self, method: &RpcMethod) -> Result<TokenStream2, syn::Error> {
// `jsonrpsee::Error`
let jrps_error = self.jrps_client_item(quote! { core::Error });
Expand All @@ -83,6 +124,7 @@ impl RpcDescription {
// `returns` represent the return type of the *rust method* (`Result< <..>, jsonrpsee::core::Error`).
let (called_method, returns) = if let Some(returns) = &method.returns {
let called_method = quote::format_ident!("request");
let returns = self.return_result_type(returns.clone());
let returns = quote! { #returns };

(called_method, returns)
Expand Down
88 changes: 88 additions & 0 deletions proc-macros/tests/ui/correct/errors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
//! Example of using custom errors.
use std::net::SocketAddr;

use jsonrpsee::core::async_trait;
use jsonrpsee::proc_macros::rpc;
use jsonrpsee::server::ServerBuilder;
use jsonrpsee::ws_client::*;

pub enum CustomError {
One,
Two { custom_data: u32 },
}

impl From<CustomError> for jsonrpsee::core::Error {
fn from(err: CustomError) -> Self {
let code = match &err {
CustomError::One => 101,
CustomError::Two { .. } => 102,
};
let data = match &err {
CustomError::One => None,
CustomError::Two { custom_data } => Some(serde_json::json!({ "customData": custom_data })),
};

let data = data.map(|val| serde_json::value::to_raw_value(&val).unwrap());

let error_object = jsonrpsee::types::ErrorObjectOwned::owned(code, "custom_error", data);

Self::Call(jsonrpsee::types::error::CallError::Custom(error_object))
}
}

#[rpc(client, server, namespace = "foo")]
pub trait Rpc {
#[method(name = "method1")]
async fn method1(&self) -> Result<u16, CustomError>;

#[method(name = "method2")]
async fn method2(&self) -> Result<u16, CustomError>;
}

pub struct RpcServerImpl;

#[async_trait]
impl RpcServer for RpcServerImpl {
async fn method1(&self) -> Result<u16, CustomError> {
Err(CustomError::One)
}

async fn method2(&self) -> Result<u16, CustomError> {
Err(CustomError::Two { custom_data: 123 })
}
}

pub async fn server() -> SocketAddr {
let server = ServerBuilder::default().build("127.0.0.1:0").await.unwrap();
let addr = server.local_addr().unwrap();
let server_handle = server.start(RpcServerImpl.into_rpc()).unwrap();

tokio::spawn(server_handle.stopped());

addr
}

#[tokio::main]
async fn main() {
let server_addr = server().await;
let server_url = format!("ws://{}", server_addr);
let client = WsClientBuilder::default().build(&server_url).await.unwrap();

let get_error_object = |err| match err {
jsonrpsee::core::Error::Call(jsonrpsee::types::error::CallError::Custom(object)) => object,
_ => panic!("wrong error kind: {:?}", err),
};

let error = client.method1().await.unwrap_err();
let error_object = get_error_object(error);
assert_eq!(error_object.code(), 101);
assert_eq!(error_object.message(), "custom_error");
assert!(error_object.data().is_none());

let error = client.method2().await.unwrap_err();
let error_object = get_error_object(error);
assert_eq!(error_object.code(), 102);
assert_eq!(error_object.message(), "custom_error");
assert_eq!(error_object.data().unwrap().get(), r#"{"customData":123}"#);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
use jsonrpsee::proc_macros::rpc;

#[rpc(client)]
pub trait NonResultReturnType {
#[method(name = "a")]
async fn a(&self) -> u16;
}

fn main() {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
error: Expecting something like 'Result<Foo, Err>' here, but got no generic args (eg no '<Foo,Err>').
--> tests/ui/incorrect/method/method_non_result_return_type.rs:6:23
|
6 | async fn a(&self) -> u16;
| ^^^
10 changes: 5 additions & 5 deletions server/src/tests/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,15 @@ pub(crate) async fn server_with_handles() -> (SocketAddr, ServerHandle) {
tracing::debug!("server respond to hello");
// Call some async function inside.
futures_util::future::ready(()).await;
Ok("hello")
Result::<_, Error>::Ok("hello")
}
})
.unwrap();
module
.register_async_method("add_async", |params, _| async move {
let params: Vec<u64> = params.parse()?;
let sum: u64 = params.into_iter().sum();
Ok(sum)
Result::<_, Error>::Ok(sum)
})
.unwrap();
module
Expand Down Expand Up @@ -111,7 +111,7 @@ pub(crate) async fn server_with_handles() -> (SocketAddr, ServerHandle) {
module
.register_async_method("should_ok_async", |_p, ctx| async move {
ctx.ok().map_err(CallError::Failed)?;
Ok("ok")
Result::<_, Error>::Ok("ok")
})
.unwrap();

Expand Down Expand Up @@ -146,15 +146,15 @@ pub(crate) async fn server_with_context() -> SocketAddr {
.register_async_method("should_ok_async", |_p, ctx| async move {
ctx.ok().map_err(CallError::Failed)?;
// Call some async function inside.
Ok(futures_util::future::ready("ok!").await)
Result::<_, Error>::Ok(futures_util::future::ready("ok!").await)
})
.unwrap();

rpc_module
.register_async_method("err_async", |_p, ctx| async move {
ctx.ok().map_err(CallError::Failed)?;
// Async work that returns an error
futures_util::future::err::<(), _>(anyhow!("nah").into()).await
futures_util::future::err::<(), Error>(anyhow!("nah").into()).await
})
.unwrap();

Expand Down
4 changes: 2 additions & 2 deletions server/src/tests/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ async fn server() -> (SocketAddr, ServerHandle) {
let mut module = RpcModule::new(ctx);
let addr = server.local_addr().unwrap();
module.register_method("say_hello", |_, _| Ok("lo")).unwrap();
module.register_async_method("say_hello_async", |_, _| async move { Ok("lo") }).unwrap();
module.register_async_method("say_hello_async", |_, _| async move { Result::<_, Error>::Ok("lo") }).unwrap();
module
.register_method("add", |params, _| {
let params: Vec<u64> = params.parse()?;
Expand Down Expand Up @@ -78,7 +78,7 @@ async fn server() -> (SocketAddr, ServerHandle) {
module
.register_async_method("should_ok_async", |_p, ctx| async move {
ctx.ok().map_err(CallError::Failed)?;
Ok("ok")
Result::<_, Error>::Ok("ok")
})
.unwrap();

Expand Down
2 changes: 1 addition & 1 deletion tests/tests/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ pub async fn server() -> SocketAddr {
module
.register_async_method("slow_hello", |_, _| async {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
Ok("hello")
Result::<_, Error>::Ok("hello")
})
.unwrap();

Expand Down
6 changes: 3 additions & 3 deletions tests/tests/resource_limiting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,20 @@ fn module_manual() -> Result<RpcModule<()>, Error> {

module.register_async_method("say_hello", |_, _| async move {
sleep(Duration::from_millis(50)).await;
Ok("hello")
Result::<_, Error>::Ok("hello")
})?;

module
.register_async_method("expensive_call", |_, _| async move {
sleep(Duration::from_millis(50)).await;
Ok("hello expensive call")
Result::<_, Error>::Ok("hello expensive call")
})?
.resource("CPU", 3)?;

module
.register_async_method("memory_hog", |_, _| async move {
sleep(Duration::from_millis(50)).await;
Ok("hello memory hog")
Result::<_, Error>::Ok("hello memory hog")
})?
.resource("CPU", 0)?
.resource("MEM", 8)?;
Expand Down
2 changes: 1 addition & 1 deletion tests/tests/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ async fn calling_method_without_server() {
module
.register_async_method("roo", |params, ctx| {
let ns: Vec<u8> = params.parse().expect("valid params please");
async move { Ok(ctx.roo(ns)) }
async move { Result::<_, Error>::Ok(ctx.roo(ns)) }
})
.unwrap();
let res: u64 = module.call("roo", [12, 13]).await.unwrap();
Expand Down

0 comments on commit dab1bfc

Please sign in to comment.