Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add a cache around server ACL checking (#16360)
Browse files Browse the repository at this point in the history
* Pre-compiles the server ACLs onto an object per room and
  invalidates them when new events come in.
* Converts the server ACL checking into Rust.
  • Loading branch information
clokep authored Sep 26, 2023
1 parent 17800a0 commit f84da3c
Show file tree
Hide file tree
Showing 11 changed files with 235 additions and 85 deletions.
1 change: 1 addition & 0 deletions changelog.d/16360.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Cache server ACL checking.
102 changes: 102 additions & 0 deletions rust/src/acl/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! An implementation of Matrix server ACL rules.
use std::net::Ipv4Addr;
use std::str::FromStr;

use anyhow::Error;
use pyo3::prelude::*;
use regex::Regex;

use crate::push::utils::{glob_to_regex, GlobMatchType};

/// Called when registering modules with python.
pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let child_module = PyModule::new(py, "acl")?;
child_module.add_class::<ServerAclEvaluator>()?;

m.add_submodule(child_module)?;

// We need to manually add the module to sys.modules to make `from
// synapse.synapse_rust import acl` work.
py.import("sys")?
.getattr("modules")?
.set_item("synapse.synapse_rust.acl", child_module)?;

Ok(())
}

#[derive(Debug, Clone)]
#[pyclass(frozen)]
pub struct ServerAclEvaluator {
allow_ip_literals: bool,
allow: Vec<Regex>,
deny: Vec<Regex>,
}

#[pymethods]
impl ServerAclEvaluator {
#[new]
pub fn py_new(
allow_ip_literals: bool,
allow: Vec<&str>,
deny: Vec<&str>,
) -> Result<Self, Error> {
let allow = allow
.iter()
.map(|s| glob_to_regex(s, GlobMatchType::Whole))
.collect::<Result<_, _>>()?;
let deny = deny
.iter()
.map(|s| glob_to_regex(s, GlobMatchType::Whole))
.collect::<Result<_, _>>()?;

Ok(ServerAclEvaluator {
allow_ip_literals,
allow,
deny,
})
}

pub fn server_matches_acl_event(&self, server_name: &str) -> bool {
// first of all, check if literal IPs are blocked, and if so, whether the
// server name is a literal IP
if !self.allow_ip_literals {
// check for ipv6 literals. These start with '['.
if server_name.starts_with('[') {
return false;
}

// check for ipv4 literals. We can just lift the routine from std::net.
if Ipv4Addr::from_str(server_name).is_ok() {
return false;
}
}

// next, check the deny list
if self.deny.iter().any(|e| e.is_match(server_name)) {
return false;
}

// then the allow list.
if self.allow.iter().any(|e| e.is_match(server_name)) {
return true;
}

// everything else should be rejected.
false
}
}
2 changes: 2 additions & 0 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use lazy_static::lazy_static;
use pyo3::prelude::*;
use pyo3_log::ResetHandle;

pub mod acl;
pub mod push;

lazy_static! {
Expand Down Expand Up @@ -38,6 +39,7 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?;
m.add_function(wrap_pyfunction!(reset_logging_config, m)?)?;

acl::register_module(py, m)?;
push::register_module(py, m)?;

Ok(())
Expand Down
21 changes: 21 additions & 0 deletions stubs/synapse/synapse_rust/acl.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

class ServerAclEvaluator:
def __init__(
self, allow_ip_literals: bool, allow: List[str], deny: List[str]
) -> None: ...
def server_matches_acl_event(self, server_name: str) -> bool: ...
7 changes: 5 additions & 2 deletions synapse/events/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@
CANONICALJSON_MIN_INT,
validate_canonicaljson,
)
from synapse.federation.federation_server import server_matches_acl_event
from synapse.http.servlet import validate_json_object
from synapse.rest.models import RequestBodyModel
from synapse.storage.controllers.state import server_acl_evaluator_from_event
from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID


Expand Down Expand Up @@ -106,7 +106,10 @@ def validate_new(self, event: EventBase, config: HomeServerConfig) -> None:
self._validate_retention(event)

elif event.type == EventTypes.ServerACL:
if not server_matches_acl_event(config.server.server_name, event):
server_acl_evaluator = server_acl_evaluator_from_event(event)
if not server_acl_evaluator.server_matches_acl_event(
config.server.server_name
):
raise SynapseError(
400, "Can't create an ACL event that denies the local server"
)
Expand Down
76 changes: 6 additions & 70 deletions synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,8 @@
Union,
)

from matrix_common.regex import glob_to_regex
from prometheus_client import Counter, Gauge, Histogram

from twisted.internet.abstract import isIPAddress
from twisted.python import failure

from synapse.api.constants import (
Expand Down Expand Up @@ -1324,75 +1322,13 @@ async def check_server_matches_acl(self, server_name: str, room_id: str) -> None
Raises:
AuthError if the server does not match the ACL
"""
acl_event = await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.ServerACL, ""
server_acl_evaluator = (
await self._storage_controllers.state.get_server_acl_for_room(room_id)
)
if not acl_event or server_matches_acl_event(server_name, acl_event):
return

raise AuthError(code=403, msg="Server is banned from room")


def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
"""Check if the given server is allowed by the ACL event
Args:
server_name: name of server, without any port part
acl_event: m.room.server_acl event
Returns:
True if this server is allowed by the ACLs
"""
logger.debug("Checking %s against acl %s", server_name, acl_event.content)

# first of all, check if literal IPs are blocked, and if so, whether the
# server name is a literal IP
allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
if not isinstance(allow_ip_literals, bool):
logger.warning("Ignoring non-bool allow_ip_literals flag")
allow_ip_literals = True
if not allow_ip_literals:
# check for ipv6 literals. These start with '['.
if server_name[0] == "[":
return False

# check for ipv4 literals. We can just lift the routine from twisted.
if isIPAddress(server_name):
return False

# next, check the deny list
deny = acl_event.content.get("deny", [])
if not isinstance(deny, (list, tuple)):
logger.warning("Ignoring non-list deny ACL %s", deny)
deny = []
for e in deny:
if _acl_entry_matches(server_name, e):
# logger.info("%s matched deny rule %s", server_name, e)
return False

# then the allow list.
allow = acl_event.content.get("allow", [])
if not isinstance(allow, (list, tuple)):
logger.warning("Ignoring non-list allow ACL %s", allow)
allow = []
for e in allow:
if _acl_entry_matches(server_name, e):
# logger.info("%s matched allow rule %s", server_name, e)
return True

# everything else should be rejected.
# logger.info("%s fell through", server_name)
return False


def _acl_entry_matches(server_name: str, acl_entry: Any) -> bool:
if not isinstance(acl_entry, str):
logger.warning(
"Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
)
return False
regex = glob_to_regex(acl_entry)
return bool(regex.match(server_name))
if server_acl_evaluator and not server_acl_evaluator.server_matches_acl_event(
server_name
):
raise AuthError(code=403, msg="Server is banned from room")


class FederationHandlerRegistry:
Expand Down
6 changes: 6 additions & 0 deletions synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -2342,6 +2342,12 @@ async def _notify_persisted_event(
# TODO retrieve the previous state, and exclude join -> join transitions
self._notifier.notify_user_joined_room(event.event_id, event.room_id)

# If this is a server ACL event, clear the cache in the storage controller.
if event.type == EventTypes.ServerACL:
self._state_storage_controller.get_server_acl_for_room.invalidate(
(event.room_id,)
)

def _sanity_check_event(self, ev: EventBase) -> None:
"""
Do some early sanity checks of a received event
Expand Down
5 changes: 5 additions & 0 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -1730,6 +1730,11 @@ async def persist_and_notify_client_events(
event.event_id, event.room_id
)

if event.type == EventTypes.ServerACL:
self._storage_controllers.state.get_server_acl_for_room.invalidate(
(event.room_id,)
)

await self._maybe_kick_guest_users(event, context)

if event.type == EventTypes.CanonicalAlias:
Expand Down
6 changes: 6 additions & 0 deletions synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,12 @@ async def on_rdata(
self.notifier.notify_user_joined_room(
row.data.event_id, row.data.room_id
)

# If this is a server ACL event, clear the cache in the storage controller.
if row.data.type == EventTypes.ServerACL:
self._state_storage_controller.get_server_acl_for_room.invalidate(
(row.data.room_id,)
)
elif stream_name == UnPartialStatedRoomStream.NAME:
for row in rows:
assert isinstance(row, UnPartialStatedRoomStreamRow)
Expand Down
59 changes: 59 additions & 0 deletions synapse/storage/controllers/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
PartialCurrentStateTracker,
PartialStateEventsTracker,
)
from synapse.synapse_rust.acl import ServerAclEvaluator
from synapse.types import MutableStateMap, StateMap, get_domain_from_id
from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
Expand Down Expand Up @@ -501,6 +502,31 @@ async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:

return event.content.get("alias")

@cached()
async def get_server_acl_for_room(
self, room_id: str
) -> Optional[ServerAclEvaluator]:
"""Get the server ACL evaluator for room, if any
This does up-front parsing of the content to ignore bad data and pre-compile
regular expressions.
Args:
room_id: The room ID
Returns:
The server ACL evaluator, if any
"""

acl_event = await self.get_current_state_event(
room_id, EventTypes.ServerACL, ""
)

if not acl_event:
return None

return server_acl_evaluator_from_event(acl_event)

@trace
@tag_args
async def get_current_state_deltas(
Expand Down Expand Up @@ -760,3 +786,36 @@ async def _get_joined_hosts(
cache.state_group = object()

return frozenset(cache.hosts_to_joined_users)


def server_acl_evaluator_from_event(acl_event: EventBase) -> "ServerAclEvaluator":
"""
Create a ServerAclEvaluator from a m.room.server_acl event's content.
This does up-front parsing of the content to ignore bad data. It then creates
the ServerAclEvaluator which will pre-compile regular expressions from the globs.
"""

# first of all, parse if literal IPs are blocked.
allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
if not isinstance(allow_ip_literals, bool):
logger.warning("Ignoring non-bool allow_ip_literals flag")
allow_ip_literals = True

# next, parse the deny list by ignoring any non-strings.
deny = acl_event.content.get("deny", [])
if not isinstance(deny, (list, tuple)):
logger.warning("Ignoring non-list deny ACL %s", deny)
deny = []
else:
deny = [s for s in deny if isinstance(s, str)]

# then the allow list.
allow = acl_event.content.get("allow", [])
if not isinstance(allow, (list, tuple)):
logger.warning("Ignoring non-list allow ACL %s", allow)
allow = []
else:
allow = [s for s in allow if isinstance(s, str)]

return ServerAclEvaluator(allow_ip_literals, allow, deny)
Loading

0 comments on commit f84da3c

Please sign in to comment.