Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add missing type hints to synapse.appservice #11360

Merged
merged 11 commits into from
Dec 14, 2021
1 change: 1 addition & 0 deletions changelog.d/11360.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to `synapse.appservice`.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ disallow_untyped_defs = True
[mypy-synapse.app.*]
disallow_untyped_defs = True

[mypy-synapse.appservice.*]
disallow_untyped_defs = True

[mypy-synapse.config._base]
disallow_untyped_defs = True

Expand Down
101 changes: 61 additions & 40 deletions synapse/appservice/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import re
from enum import Enum
from typing import TYPE_CHECKING, Iterable, List, Match, Optional
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Pattern

import attr
from netaddr import IPSet

from synapse.api.constants import EventTypes
from synapse.events import EventBase
Expand All @@ -33,6 +37,13 @@ class ApplicationServiceState(Enum):
UP = "up"


@attr.s(slots=True, frozen=True, auto_attribs=True)
class Namespace:
exclusive: bool
group_id: Optional[str]
regex: Pattern


class ApplicationService:
"""Defines an application service. This definition is mostly what is
provided to the /register AS API.
Expand All @@ -50,17 +61,17 @@ class ApplicationService:

def __init__(
self,
token,
hostname,
id,
sender,
url=None,
namespaces=None,
hs_token=None,
protocols=None,
rate_limited=True,
ip_range_whitelist=None,
supports_ephemeral=False,
token: str,
hostname: str,
id: str,
sender: str,
url: Optional[str] = None,
namespaces: Optional[JsonDict] = None,
hs_token: Optional[str] = None,
protocols: Optional[Iterable[str]] = None,
rate_limited: bool = True,
ip_range_whitelist: Optional[IPSet] = None,
supports_ephemeral: bool = False,
):
self.token = token
self.url = (
Expand All @@ -85,27 +96,33 @@ def __init__(

self.rate_limited = rate_limited

def _check_namespaces(self, namespaces):
def _check_namespaces(
self, namespaces: Optional[JsonDict]
) -> Dict[str, List[Namespace]]:
# Sanity check that it is of the form:
# {
# users: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# aliases: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# }
if not namespaces:
if namespaces is None:
namespaces = {}

result: Dict[str, List[Namespace]] = {}

for ns in ApplicationService.NS_LIST:
result[ns] = []

if ns not in namespaces:
namespaces[ns] = []
continue

if type(namespaces[ns]) != list:
if not isinstance(namespaces[ns], list):
raise ValueError("Bad namespace value for '%s'" % ns)
for regex_obj in namespaces[ns]:
if not isinstance(regex_obj, dict):
raise ValueError("Expected dict regex for ns '%s'" % ns)
if not isinstance(regex_obj.get("exclusive"), bool):
exclusive = regex_obj.get("exclusive")
if not isinstance(exclusive, bool):
raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns)
group_id = regex_obj.get("group_id")
if group_id:
Expand All @@ -126,22 +143,26 @@ def _check_namespaces(self, namespaces):
)

regex = regex_obj.get("regex")
if isinstance(regex, str):
regex_obj["regex"] = re.compile(regex) # Pre-compile regex
else:
if not isinstance(regex, str):
raise ValueError("Expected string for 'regex' in ns '%s'" % ns)
return namespaces

def _matches_regex(self, test_string: str, namespace_key: str) -> Optional[Match]:
for regex_obj in self.namespaces[namespace_key]:
if regex_obj["regex"].match(test_string):
return regex_obj
# Pre-compile regex.
result[ns].append(Namespace(exclusive, group_id, re.compile(regex)))

return result

def _matches_regex(
self, namespace_key: str, test_string: str
) -> Optional[Namespace]:
for namespace in self.namespaces[namespace_key]:
if namespace.regex.match(test_string):
return namespace
return None

def _is_exclusive(self, ns_key: str, test_string: str) -> bool:
regex_obj = self._matches_regex(test_string, ns_key)
if regex_obj:
return regex_obj["exclusive"]
def _is_exclusive(self, namespace_key: str, test_string: str) -> bool:
namespace = self._matches_regex(namespace_key, test_string)
if namespace:
return namespace.exclusive
return False

async def _matches_user(
Expand Down Expand Up @@ -260,15 +281,15 @@ async def is_interested_in_presence(

def is_interested_in_user(self, user_id: str) -> bool:
return (
bool(self._matches_regex(user_id, ApplicationService.NS_USERS))
bool(self._matches_regex(ApplicationService.NS_USERS, user_id))
or user_id == self.sender
)

def is_interested_in_alias(self, alias: str) -> bool:
return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
return bool(self._matches_regex(ApplicationService.NS_ALIASES, alias))

def is_interested_in_room(self, room_id: str) -> bool:
return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
return bool(self._matches_regex(ApplicationService.NS_ROOMS, room_id))

def is_exclusive_user(self, user_id: str) -> bool:
return (
Expand All @@ -285,14 +306,14 @@ def is_exclusive_alias(self, alias: str) -> bool:
def is_exclusive_room(self, room_id: str) -> bool:
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)

def get_exclusive_user_regexes(self):
def get_exclusive_user_regexes(self) -> List[Pattern]:
clokep marked this conversation as resolved.
Show resolved Hide resolved
"""Get the list of regexes used to determine if a user is exclusively
registered by the AS
"""
return [
regex_obj["regex"]
for regex_obj in self.namespaces[ApplicationService.NS_USERS]
if regex_obj["exclusive"]
namespace.regex
for namespace in self.namespaces[ApplicationService.NS_USERS]
if namespace.exclusive
]

def get_groups_for_user(self, user_id: str) -> Iterable[str]:
Expand All @@ -305,15 +326,15 @@ def get_groups_for_user(self, user_id: str) -> Iterable[str]:
An iterable that yields group_id strings.
"""
return (
regex_obj["group_id"]
for regex_obj in self.namespaces[ApplicationService.NS_USERS]
if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
namespace.group_id
for namespace in self.namespaces[ApplicationService.NS_USERS]
if namespace.group_id and namespace.regex.match(user_id)
)

def is_rate_limited(self) -> bool:
return self.rate_limited

def __str__(self):
def __str__(self) -> str:
# copy dictionary and redact token fields so they don't get logged
dict_copy = self.__dict__.copy()
dict_copy["token"] = "<redacted>"
Expand Down
48 changes: 35 additions & 13 deletions synapse/appservice/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib
from typing import TYPE_CHECKING, List, Optional, Tuple
import urllib.parse
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple

from prometheus_client import Counter

Expand Down Expand Up @@ -53,15 +53,15 @@
APP_SERVICE_PREFIX = "/_matrix/app/unstable"


def _is_valid_3pe_metadata(info):
def _is_valid_3pe_metadata(info: JsonDict) -> bool:
if "instances" not in info:
return False
if not isinstance(info["instances"], list):
return False
return True


def _is_valid_3pe_result(r, field):
def _is_valid_3pe_result(r: JsonDict, field: str) -> bool:
if not isinstance(r, dict):
return False

Expand Down Expand Up @@ -93,9 +93,13 @@ def __init__(self, hs: "HomeServer"):
hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
)

async def query_user(self, service, user_id):
async def query_user(self, service: "ApplicationService", user_id: str) -> bool:
if service.url is None:
return False

# This is required by the configuration.
assert service.hs_token is not None

uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
try:
response = await self.get_json(uri, {"access_token": service.hs_token})
Expand All @@ -109,9 +113,13 @@ async def query_user(self, service, user_id):
logger.warning("query_user to %s threw exception %s", uri, ex)
return False

async def query_alias(self, service, alias):
async def query_alias(self, service: "ApplicationService", alias: str) -> bool:
if service.url is None:
return False

# This is required by the configuration.
assert service.hs_token is not None

uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
try:
response = await self.get_json(uri, {"access_token": service.hs_token})
Expand All @@ -125,7 +133,13 @@ async def query_alias(self, service, alias):
logger.warning("query_alias to %s threw exception %s", uri, ex)
return False

async def query_3pe(self, service, kind, protocol, fields):
async def query_3pe(
self,
service: "ApplicationService",
kind: str,
protocol: str,
fields: Dict[bytes, List[bytes]],
) -> List[JsonDict]:
if kind == ThirdPartyEntityKind.USER:
required_field = "userid"
elif kind == ThirdPartyEntityKind.LOCATION:
Expand Down Expand Up @@ -205,11 +219,14 @@ async def push_bulk(
events: List[EventBase],
ephemeral: List[JsonDict],
txn_id: Optional[int] = None,
):
) -> bool:
if service.url is None:
return True

events = self._serialize(service, events)
# This is required by the configuration.
assert service.hs_token is not None

serialized_events = self._serialize(service, events)

if txn_id is None:
logger.warning(
Expand All @@ -221,9 +238,12 @@ async def push_bulk(

# Never send ephemeral events to appservices that do not support it
if service.supports_ephemeral:
body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral}
body = {
"events": serialized_events,
"de.sorunome.msc2409.ephemeral": ephemeral,
}
else:
body = {"events": events}
body = {"events": serialized_events}

try:
await self.put_json(
Expand All @@ -238,7 +258,7 @@ async def push_bulk(
[event.get("event_id") for event in events],
)
sent_transactions_counter.labels(service.id).inc()
sent_events_counter.labels(service.id).inc(len(events))
sent_events_counter.labels(service.id).inc(len(serialized_events))
return True
except CodeMessageException as e:
logger.warning(
Expand All @@ -260,7 +280,9 @@ async def push_bulk(
failed_transactions_counter.labels(service.id).inc()
return False

def _serialize(self, service, events):
def _serialize(
self, service: "ApplicationService", events: Iterable[EventBase]
) -> List[JsonDict]:
time_now = self.clock.time_msec()
return [
serialize_event(
Expand Down
Loading