Skip to content

Commit

Permalink
fix: remove needless Semaphore::(u32::MAX) (#1051)
Browse files Browse the repository at this point in the history
  • Loading branch information
niklasad1 authored Mar 21, 2023
1 parent 014f771 commit a0ce8d4
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 33 deletions.
22 changes: 4 additions & 18 deletions core/src/server/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore};

use super::rpc_module::{DisconnectError, SendTimeoutError, SubscriptionMessage, TrySendError};

/// Subscription permit.
pub type SubscriptionPermit = OwnedSemaphorePermit;

/// Bounded writer that allows writing at most `max_len` bytes.
///
/// ```
Expand Down Expand Up @@ -191,20 +194,6 @@ pub fn prepare_error(data: &[u8]) -> (Id<'_>, ErrorCode) {
}
}

/// A permitted subscription.
#[derive(Debug)]
pub struct SubscriptionPermit {
_permit: OwnedSemaphorePermit,
resource: Arc<Notify>,
}

impl SubscriptionPermit {
/// Get the handle to [`tokio::sync::Notify`].
pub fn handle(&self) -> Arc<Notify> {
self.resource.clone()
}
}

/// Wrapper over [`tokio::sync::Notify`] with bounds check.
#[derive(Debug, Clone)]
pub struct BoundedSubscriptions {
Expand All @@ -227,10 +216,7 @@ impl BoundedSubscriptions {
///
/// Fails if `max_subscriptions` have been exceeded.
pub fn acquire(&self) -> Option<SubscriptionPermit> {
Arc::clone(&self.guard)
.try_acquire_owned()
.ok()
.map(|p| SubscriptionPermit { _permit: p, resource: self.resource.clone() })
Arc::clone(&self.guard).try_acquire_owned().ok()
}

/// Get the maximum number of permitted subscriptions.
Expand Down
33 changes: 18 additions & 15 deletions core/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ pub type MaxResponseSize = usize;
/// A 3-tuple containing:
/// - Call result as a `String`,
/// - a [`mpsc::UnboundedReceiver<String>`] to receive future subscription results
/// - a [`crate::server::helpers::SubscriptionPermit`] to allow subscribers to notify their [`SubscriptionSink`] when
/// they disconnect.
pub type RawRpcResponse = (MethodResponse, mpsc::Receiver<String>, SubscriptionPermit);
pub type RawRpcResponse = (MethodResponse, mpsc::Receiver<String>);

/// Error that may occur during [`SubscriptionSink::try_send`].
#[derive(Debug)]
Expand Down Expand Up @@ -408,7 +406,7 @@ impl Methods {
let params = params.to_rpc_params()?;
let req = Request::new(method.into(), params.as_ref().map(|p| p.as_ref()), Id::Number(0));
tracing::trace!("[Methods::call] Method: {:?}, params: {:?}", method, params);
let (resp, _, _) = self.inner_call(req, 1).await;
let (resp, _) = self.inner_call(req, 1, mock_subscription_permit()).await;

if resp.success {
serde_json::from_str::<Response<T>>(&resp.result).map(|r| r.result).map_err(Into::into)
Expand Down Expand Up @@ -456,27 +454,28 @@ impl Methods {
) -> Result<(MethodResponse, mpsc::Receiver<String>), Error> {
tracing::trace!("[Methods::raw_json_request] Request: {:?}", request);
let req: Request = serde_json::from_str(request)?;
let (resp, rx, _) = self.inner_call(req, buf_size).await;
let (resp, rx) = self.inner_call(req, buf_size, mock_subscription_permit()).await;

Ok((resp, rx))
}

/// Execute a callback.
async fn inner_call(&self, req: Request<'_>, buf_size: usize) -> RawRpcResponse {
async fn inner_call(
&self,
req: Request<'_>,
buf_size: usize,
subscription_permit: SubscriptionPermit,
) -> RawRpcResponse {
let (tx, mut rx) = mpsc::channel(buf_size);
let id = req.id.clone();
let params = Params::new(req.params.map(|params| params.get()));
let bounded_subs = BoundedSubscriptions::new(u32::MAX);
let p1 = bounded_subs.acquire().expect("u32::MAX permits is sufficient; qed");
let p2 = bounded_subs.acquire().expect("u32::MAX permits is sufficient; qed");

let response = match self.method(&req.method) {
None => MethodResponse::error(req.id, ErrorObject::from(ErrorCode::MethodNotFound)),
Some(MethodCallback::Sync(cb)) => (cb)(id, params, usize::MAX),
Some(MethodCallback::Async(cb)) => (cb)(id.into_owned(), params.into_owned(), 0, usize::MAX).await,
Some(MethodCallback::Subscription(cb)) => {
let conn_state =
ConnState { conn_id: 0, id_provider: &RandomIntegerIdProvider, subscription_permit: p1 };
let conn_state = ConnState { conn_id: 0, id_provider: &RandomIntegerIdProvider, subscription_permit };
let res = match (cb)(id, params, MethodSink::new(tx.clone()), conn_state).await {
Ok(rp) => rp,
Err(id) => MethodResponse::error(id, ErrorObject::from(ErrorCode::InternalError)),
Expand All @@ -495,7 +494,7 @@ impl Methods {

tracing::trace!("[Methods::inner_call] Method: {}, response: {:?}", req.method, response);

(response, rx, p2)
(response, rx)
}

/// Helper to create a subscription on the `RPC module` without having to spin up a server.
Expand Down Expand Up @@ -544,7 +543,7 @@ impl Methods {

tracing::trace!("[Methods::subscribe] Method: {}, params: {:?}", sub_method, params);

let (resp, rx, permit) = self.inner_call(req, buf_size).await;
let (resp, rx) = self.inner_call(req, buf_size, mock_subscription_permit()).await;

let subscription_response = match serde_json::from_str::<Response<RpcSubscriptionId>>(&resp.result) {
Ok(r) => r,
Expand All @@ -556,7 +555,7 @@ impl Methods {

let sub_id = subscription_response.result.into_owned();

Ok(Subscription { sub_id, rx, _permit: permit })
Ok(Subscription { sub_id, rx })
}

/// Returns an `Iterator` with all the method names registered on this server.
Expand Down Expand Up @@ -1127,7 +1126,6 @@ impl Drop for SubscriptionSink {
pub struct Subscription {
rx: mpsc::Receiver<String>,
sub_id: RpcSubscriptionId<'static>,
_permit: SubscriptionPermit,
}

impl Subscription {
Expand Down Expand Up @@ -1168,3 +1166,8 @@ impl Drop for Subscription {
self.close();
}
}

// Mock subscription permit to be able to make a call.
fn mock_subscription_permit() -> SubscriptionPermit {
BoundedSubscriptions::new(1).acquire().expect("1 permit should exist; qed")
}

0 comments on commit a0ce8d4

Please sign in to comment.