Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rpc-servers: Allow chainHead methods to be called from a single connection #3343

Closed
wants to merge 8 commits into from
20 changes: 18 additions & 2 deletions substrate/client/rpc-servers/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@

pub mod middleware;

use std::{convert::Infallible, error::Error as StdError, net::SocketAddr, time::Duration};
use std::{
convert::Infallible,
error::Error as StdError,
net::SocketAddr,
sync::{Arc, Mutex},
time::Duration,
};

use http::header::HeaderValue;
use hyper::{
Expand All @@ -49,6 +55,8 @@ pub use jsonrpsee::core::{
};
pub use middleware::{MetricsLayer, RpcMetrics};

use crate::middleware::chain_head::{ChainHeadLayer, ConnectionData};

const MEGABYTE: u32 = 1024 * 1024;

/// Type alias for the JSON-RPC server.
Expand Down Expand Up @@ -142,6 +150,9 @@ pub async fn start_server<M: Send + Sync + 'static>(
let make_service = make_service_fn(move |_conn: &AddrStream| {
let cfg = cfg.clone();

// Chain head data is per connection.
let chain_head_data = Arc::new(Mutex::new(ConnectionData::default()));

async move {
let cfg = cfg.clone();

Expand All @@ -152,8 +163,13 @@ pub async fn start_server<M: Send + Sync + 'static>(
let is_websocket = ws::is_upgrade_request(&req);
let transport_label = if is_websocket { "ws" } else { "http" };

// Order of the requests matter here, the metrics layer should be the first to not
// miss metrics.
let metrics = metrics.map(|m| MetricsLayer::new(m, transport_label));
let rpc_middleware = RpcServiceBuilder::new().option_layer(metrics.clone());
let chain_head = ChainHeadLayer::new(chain_head_data.clone());

let rpc_middleware =
RpcServiceBuilder::new().option_layer(metrics.clone()).layer(chain_head);
let mut svc =
service_builder.set_rpc_middleware(rpc_middleware).build(methods, stop_handle);

Expand Down
237 changes: 237 additions & 0 deletions substrate/client/rpc-servers/src/middleware/chain_head.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
// This file is part of Substrate.

// Copyright (C) Parity Technologies (UK) Ltd.
// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0

// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! RPC middleware to ensure chainHead methods are called from a single connection.

use std::{
collections::HashSet,
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};

use jsonrpsee::{
server::middleware::rpc::RpcServiceT,
types::{Params, Request},
MethodResponse,
};
use parking_lot::Mutex;
use pin_project::pin_project;

/// The per connectin data needed to manage chainHead subscriptions.
#[derive(Default)]
pub struct ConnectionData {
/// Active `chainHeda_follow` subscriptions for this connection.
lexnv marked this conversation as resolved.
Show resolved Hide resolved
subscriptions: HashSet<String>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would probably add a type alias SubscriptionId here for readability.

}

/// Layer to allow the `chainHead` RPC methods to be called from a single connection.
#[derive(Clone)]
pub struct ChainHeadLayer {
connection_data: Arc<Mutex<ConnectionData>>,
}

impl ChainHeadLayer {
/// Create a new [`ChainHeadLayer`].
pub fn new(connection_data: Arc<Mutex<ConnectionData>>) -> Self {
Self { connection_data }
}
}

impl<S> tower::Layer<S> for ChainHeadLayer {
type Service = ChainHeadMiddleware<S>;

fn layer(&self, inner: S) -> Self::Service {
ChainHeadMiddleware::new(inner, self.connection_data.clone())
}
}

/// Chain head middleware.
#[derive(Clone)]
pub struct ChainHeadMiddleware<S> {
service: S,
connection_data: Arc<Mutex<ConnectionData>>,
}

impl<S> ChainHeadMiddleware<S> {
/// Create a new chain head middleware.
pub fn new(service: S, connection_data: Arc<Mutex<ConnectionData>>) -> ChainHeadMiddleware<S> {
ChainHeadMiddleware { service, connection_data }
}
}

impl<'a, S> RpcServiceT<'a> for ChainHeadMiddleware<S>
where
S: Send + Sync + RpcServiceT<'a>,
{
type Future = ResponseFuture<S::Future>;

fn call(&self, req: Request<'a>) -> Self::Future {
const CHAIN_HEAD_FOLLOW: &str = "chainHead_unstable_follow";
const CHAIN_HEAD_CALL_METHODS: [&str; 8] = [
Copy link
Member

@niklasad1 niklasad1 Feb 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A random thought is whether we can fetch these method names from the rpc module itself i.e, to keep in sync if a method is removed, modified and so on.

Before the server start you could do:

let chain_head_api = Vec<_> = rpc_api.method_names().filter(|m| m.starts_with("chainHead")).collect();
let chain_middleware = ChainHeadMiddleware::new(service, connection_data, chain_api);

Perhaps, we could add some integration tests for it at some point.

"chainHead_unstable_body",
"chainHead_unstable_header",
"chainHead_unstable_call",
"chainHead_unstable_unpin",
"chainHead_unstable_continue",
"chainHead_unstable_storage",
"chainHead_unstable_stopOperation",
"chainHead_unstable_unfollow",
];

let method_name = req.method_name();

// Intercept the subscription ID returned by the `chainHead_follow` method.
if method_name == CHAIN_HEAD_FOLLOW {
return ResponseFuture::Register {
fut: self.service.call(req.clone()),
connection_data: self.connection_data.clone(),
}
}

// Ensure the subscription ID of those methods corresponds to a subscription ID
// of this connection.
if CHAIN_HEAD_CALL_METHODS.contains(&method_name) {
let params = req.params();
let follow_subscription = get_subscription_id(params);

if let Some(follow_subscription) = follow_subscription {
if !self.connection_data.lock().subscriptions.contains(&follow_subscription) {
return ResponseFuture::Ready {
response: Some(MethodResponse::error(
req.id(),
jsonrpsee::types::error::ErrorObject::owned(
-32602,
"Invalid subscription ID",
Copy link
Member

@niklasad1 niklasad1 Feb 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A JSON-RPC error with error code -32602 is generated if one of the parameters doesn't correspond to the expected type (similarly to a missing parameter or an invalid parameter type).

I would probably add another error for this, the subscription ID is decoded successfully but the subscription is not "active" on the connection....

None::<()>,
),
)),
};
}
}
}

ResponseFuture::Forward { fut: self.service.call(req.clone()) }
}
}

/// Extract the subscription ID from the provided parameters.
///
/// We make the assumption that all `chainHead` methods are given the
/// subscription ID as a first parameter.
///
/// This method handles positional and named `camelCase` parameters.
fn get_subscription_id<'a>(params: Params<'a>) -> Option<String> {
Copy link
Member

@niklasad1 niklasad1 Feb 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

subscription_id_from_request

// Support positional parameters.
if let Ok(follow_subscription) = params.sequence().next::<String>() {
Copy link
Member

@niklasad1 niklasad1 Feb 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can call params.is_object() here to do either parse it as a sequence or map.

however, you use params.one::<String>() instead sequence here since you only care about one item

return Some(follow_subscription)
}

// Support named parameters.
let Ok(value) = params.parse::<serde_json::Value>() else { return None };
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again custom type that implement Deserialize here would be cleaner

    #[derive(serde::Deserialize)]
    struct FollowSubscriptionPayload {
        result: String,
    }


let serde_json::Value::Object(map) = value else { return None };
if let Some(serde_json::Value::String(subscription_id)) = map.get("followSubscription") {
return Some(subscription_id.clone())
}

None
}

/// Extract the result of a jsonrpc object.
///
/// The function extracts the `result` field from the JSON-RPC response.
///
/// In this example, the result is `tfMQUZekzJLorGlR`.
/// ```ignore
/// "{"jsonrpc":"2.0","result":"tfMQUZekzJLorGlR","id":0}"
/// ```
fn get_method_result(response: &MethodResponse) -> Option<String> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

subscription_id_from_response

if response.is_error() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if response.is_error() {
if response.is_error() || !response.is_subscription() {

return None
}

let result = response.as_result();
let Ok(value) = serde_json::from_str(result) else { return None };
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, this is awkward to decode the response from the final JSON string.

It would be cleaner and easier with a type that implements deserialize.

    #[derive(serde::Deserialize)]
    struct SubscriptionPayload {
        result: String,
    }

I guess the rpc v2 requires to subscription ID to be a string but jsonrpsee accepts integers as well.


let serde_json::Value::Object(map) = value else { return None };
let Some(serde_json::Value::String(res)) = map.get("result") else { return None };

Some(res.clone())
}

/// Response future for chainHead middleware.
#[pin_project(project = ResponseFutureProj)]
pub enum ResponseFuture<F> {
/// The response is propagated immediately without calling other layers.
///
/// This is used in case of an error.
Ready {
/// The response provided to the client directly.
///
/// This is `Option` to consume the value and return a `MethodResponse`
/// from the projected structure.
response: Option<MethodResponse>,
},

/// Forward the call to another layer.
Forward {
/// The future response value.
#[pin]
fut: F,
},

/// Forward the call to another layer and store the subscription ID of the response.
Register {
/// The future response value.
#[pin]
fut: F,
/// Connection data that captures the subscription ID.
connection_data: Arc<Mutex<ConnectionData>>,
},
}

impl<'a, F> std::fmt::Debug for ResponseFuture<F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("ResponseFuture")
}
}

impl<F: Future<Output = MethodResponse>> Future for ResponseFuture<F> {
type Output = F::Output;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure but may be worth boxing the future for readability

let this = self.project();

match this {
ResponseFutureProj::Ready { response } =>
Poll::Ready(response.take().expect("Value is set; qed")),
ResponseFutureProj::Forward { fut } => fut.poll(cx),
Copy link
Member

@niklasad1 niklasad1 Feb 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are you not removing subscriptions when chainHead_unstable_unfollow is called?

ResponseFutureProj::Register { fut, connection_data } => {
let res = fut.poll(cx);
if let Poll::Ready(response) = &res {
if let Some(subscription_id) = get_method_result(response) {
connection_data.lock().subscriptions.insert(subscription_id);
}
}
res
},
}
}
}
1 change: 1 addition & 0 deletions substrate/client/rpc-servers/src/middleware/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

//! JSON-RPC specific middleware.

pub mod chain_head;
pub mod metrics;

pub use metrics::*;
Loading