Skip to content

Commit

Permalink
Add support for MSC3202 in appservice module
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Jul 14, 2021
1 parent e2ce035 commit c41b515
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 19 deletions.
50 changes: 33 additions & 17 deletions mautrix/appservice/as_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# Partly based on github.com/Cadair/python-appservice-framework (MIT license)
from typing import Optional, Callable, Awaitable, List, Set
from typing import Optional, Callable, Awaitable, List, Set, Dict, Any
from json import JSONDecodeError
from aiohttp import web
import asyncio
import logging

from mautrix.types import JSON, UserID, RoomAlias, Event, EphemeralEvent, SerializerError
from mautrix.types import (JSON, UserID, RoomAlias, Event, EphemeralEvent, SerializerError,
DeviceOTKCount, DeviceLists)

QueryFunc = Callable[[web.Request], Awaitable[Optional[web.Response]]]
HandlerFunc = Callable[[Event], Awaitable]
Expand Down Expand Up @@ -102,6 +103,17 @@ async def _http_query_alias(self, request: web.Request) -> web.Response:
return web.json_response({}, status=404)
return web.json_response(response)

@staticmethod
def _get_with_fallback(json: Dict[str, Any], field: str, unstable_prefix: str,
default: Any = None) -> Any:
try:
return json.pop(field)
except KeyError:
try:
return json.pop(f"{unstable_prefix}.{field}")
except KeyError:
return default

async def _http_handle_transaction(self, request: web.Request) -> web.Response:
if not self._check_token(request):
return web.json_response({"error": "Invalid auth token"}, status=401)
Expand All @@ -116,29 +128,30 @@ async def _http_handle_transaction(self, request: web.Request) -> web.Response:
return web.json_response({"error": "Body is not JSON"}, status=400)

try:
events = json["events"]
events = json.pop("events")
except KeyError:
return web.json_response({"error": "Missing events object in body"}, status=400)

if self.ephemeral_events:
try:
ephemeral = json["ephemeral"]
except KeyError:
try:
ephemeral = json["de.sorunome.msc2409.ephemeral"]
except KeyError:
ephemeral = None
else:
ephemeral = None
ephemeral = (self._get_with_fallback(json, "ephemeral", "de.sorunome.msc2409")
if self.ephemeral_events else None)
device_lists = DeviceLists.deserialize(
self._get_with_fallback(json, "device_lists", "org.matrix.msc3202"))
otk_counts = {user_id: DeviceOTKCount.deserialize(count)
for user_id, count
in self._get_with_fallback(json, "device_one_time_keys_count",
"org.matrix.msc3202", default={}).items()}

try:
await self.handle_transaction(transaction_id, events=events, ephemeral=ephemeral)
output = await self.handle_transaction(transaction_id, events=events, extra_data=json,
ephemeral=ephemeral, device_lists=device_lists,
device_otk_count=otk_counts)
except Exception:
self.log.exception("Exception in transaction handler")
output = None

self.transactions.add(transaction_id)

return web.json_response({})
return web.json_response(output or {})

@staticmethod
def _fix_prev_content(raw_event: JSON) -> None:
Expand All @@ -150,8 +163,10 @@ def _fix_prev_content(raw_event: JSON) -> None:
except KeyError:
pass

async def handle_transaction(self, txn_id: str, events: List[JSON],
ephemeral: Optional[List[JSON]] = None) -> None:
async def handle_transaction(self, txn_id: str, *, events: List[JSON], extra_data: JSON,
ephemeral: Optional[List[JSON]] = None,
device_otk_count: Optional[Dict[UserID, DeviceOTKCount]] = None,
device_lists: Optional[DeviceLists] = None) -> Optional[JSON]:
for raw_edu in ephemeral or []:
try:
edu = EphemeralEvent.deserialize(raw_edu)
Expand All @@ -167,6 +182,7 @@ async def handle_transaction(self, txn_id: str, events: List[JSON],
self.log.exception("Failed to deserialize event %s", raw_event)
else:
self.handle_matrix_event(event)
return {}

def handle_matrix_event(self, event: Event) -> None:
if event.type.is_state and event.state_key is None:
Expand Down
13 changes: 11 additions & 2 deletions mautrix/types/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,17 @@
from .util import SerializableAttrs
from .event import Event

DeviceLists = NamedTuple("DeviceLists", changed=List[UserID], left=List[UserID])
DeviceOTKCount = NamedTuple("DeviceOTKCount", curve25519=int, signed_curve25519=int)

@dataclass
class DeviceLists(SerializableAttrs):
changed: List[UserID] = attr.ib(factory=lambda: [])
left: List[UserID] = attr.ib(factory=lambda: [])


@dataclass
class DeviceOTKCount(SerializableAttrs):
curve25519: int
signed_curve25519: int


class RoomCreatePreset(Enum):
Expand Down

0 comments on commit c41b515

Please sign in to comment.