Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add an admin endpoint to allow authorizing server to signal token revocations #16125

Merged
merged 19 commits into from
Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16125.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add an admin endpoint to allow authorizing server to signal token revocations.
12 changes: 12 additions & 0 deletions synapse/api/auth/msc3861_delegated.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,15 @@ async def get_user_by_access_token(
scope=scope,
is_guest=(has_guest_scope and not has_user_scope),
)

def invalidate_cached_tokens(self, keys: List[str]) -> None:
"""
Invalidate the entry(s) in the introspection token cache corresponding to the given key
"""
self._token_cache.invalidate(keys)

def invalidate_token_cache(self) -> None:
"""
Invalidate the entire token cache.
"""
self._token_cache.invalidate_all()
12 changes: 12 additions & 0 deletions synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.streams import (
AccountDataStream,
CachesStream,
DeviceListsStream,
PushersStream,
PushRulesStream,
Expand Down Expand Up @@ -73,6 +74,7 @@ def __init__(self, hs: "HomeServer"):
self._instance_name = hs.get_instance_name()
self._typing_handler = hs.get_typing_handler()
self._state_storage_controller = hs.get_storage_controllers().state
self.auth = hs.get_auth()

self._notify_pushers = hs.config.worker.start_pushers
self._pusher_pool = hs.get_pusherpool()
Expand Down Expand Up @@ -218,6 +220,16 @@ async def on_rdata(
self._state_storage_controller.notify_event_un_partial_stated(
row.event_id
)
# invalidate the introspection token cache
elif stream_name == CachesStream.NAME:
for row in rows:
if row.cache_func == "introspection_token_invalidation":
if row.keys[0] is None:
# invalidate the whole cache
# mypy ignore - the token cache is defined on MSC3861DelegatedAuth
self.auth.invalidate_token_cache() # type: ignore[attr-defined]
else:
self.auth.invalidate_cached_tokens(row.keys) # type: ignore[attr-defined]

await self._presence_handler.process_replication_rows(
stream_name, instance_name, token, rows
Expand Down
3 changes: 3 additions & 0 deletions synapse/rest/admin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
ListDestinationsRestServlet,
)
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.oidc import OIDCTokenRevocationRestServlet
from synapse.rest.admin.registration_tokens import (
ListRegistrationTokensRestServlet,
NewRegistrationTokenRestServlet,
Expand Down Expand Up @@ -297,6 +298,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
BackgroundUpdateRestServlet(hs).register(http_server)
BackgroundUpdateStartJobRestServlet(hs).register(http_server)
ExperimentalFeaturesRestServlet(hs).register(http_server)
if hs.config.experimental.msc3861.enabled:
OIDCTokenRevocationRestServlet(hs).register(http_server)


def register_servlets_for_client_rest_resource(
Expand Down
50 changes: 50 additions & 0 deletions synapse/rest/admin/oidc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2023 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, Tuple

from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin

if TYPE_CHECKING:
from synapse.server import HomeServer


class OIDCTokenRevocationRestServlet(RestServlet):
"""
Delete a given token introspection response - identified by the `jti` field - from the
introspection token cache when a token is revoked at the authorizing server
"""

PATTERNS = admin_patterns("/OIDC_token_revocation/(?P<token_id>[^/]*)")

def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main

async def on_DELETE(
self, request: SynapseRequest, token_id: str
) -> Tuple[HTTPStatus, Dict]:
await assert_requester_is_admin(self.auth, request)

# mypy ignore - this attribute is defined on MSC3861DelegatedAuth, which is loaded via a config flag
# this endpoint will only be loaded if the same config flag is present
self.auth._token_cache.invalidate([token_id]) # type: ignore[attr-defined]

# make sure we invalidate the cache on any workers
await self.store.stream_introspection_token_invalidation((token_id,))

return HTTPStatus.OK, {}
13 changes: 13 additions & 0 deletions synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,19 @@ def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
else:
return 0

async def stream_introspection_token_invalidation(
self, key: Tuple[Optional[str]]
) -> None:
"""
Stream an invalidation request for the introspection token cache to workers

Args:
key: token_id of the introspection token to remove from the cache
"""
await self.send_invalidation_to_replication(
"introspection_token_invalidation", key
)

@wrap_as_background_process("clean_up_old_cache_invalidations")
async def _clean_up_cache_invalidation_wrapper(self) -> None:
"""
Expand Down
9 changes: 9 additions & 0 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from synapse.api.constants import EduTypes
from synapse.api.errors import Codes, StoreError
from synapse.config.homeserver import HomeServerConfig
from synapse.logging.opentracing import (
get_active_span_text_map,
set_tag,
Expand Down Expand Up @@ -1663,6 +1664,7 @@ def __init__(
self.device_id_exists_cache: LruCache[
Tuple[str, str], Literal[True]
] = LruCache(cache_name="device_id_exists", max_size=10000)
self.config: HomeServerConfig = hs.config

async def store_device(
self,
Expand Down Expand Up @@ -1784,6 +1786,13 @@ def _delete_devices_txn(txn: LoggingTransaction) -> None:
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))

# TODO: don't nuke the entire cache once there is a way to associate
# device_id -> introspection_token
if self.config.experimental.msc3861.enabled:
# mypy ignore - the token cache is defined on MSC3861DelegatedAuth
self.auth._token_cache.invalidate_all() # type: ignore[attr-defined]
await self.stream_introspection_token_invalidation((None,))

async def update_device(
self, user_id: str, device_id: str, new_display_name: Optional[str] = None
) -> None:
Expand Down
25 changes: 24 additions & 1 deletion synapse/util/caches/expiringcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import logging
from collections import OrderedDict
from typing import Any, Generic, Optional, TypeVar, Union, overload
from typing import Any, Generic, List, Optional, TypeVar, Union, overload

import attr
from typing_extensions import Literal
Expand Down Expand Up @@ -140,6 +140,21 @@ def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:

return value.value

def invalidate(self, keys: List[KT]) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

invalidate on the other caches just invalidates a single key. Let's do the same thing here to avoid confusion

"""
Remove the given key(s) from the cache.
"""

for key in keys:
value = self._cache.pop(key, None)
if value:
if self.iterable:
self.metrics.inc_evictions(
EvictionReason.invalidation, len(value.value)
)
else:
self.metrics.inc_evictions(EvictionReason.invalidation)

def __contains__(self, key: KT) -> bool:
return key in self._cache

Expand Down Expand Up @@ -193,6 +208,14 @@ async def _prune_cache(self) -> None:
len(self),
)

def invalidate_all(self) -> None:
"""
Remove all items from the cache.
"""
keys = set(self._cache.keys())
for key in keys:
self._cache.pop(key)

def __len__(self) -> int:
if self.iterable:
return sum(len(entry.value) for entry in self._cache.values())
Expand Down
34 changes: 33 additions & 1 deletion tests/handlers/test_oauth_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from http import HTTPStatus
from typing import Any, Dict, Union
from unittest.mock import ANY, Mock
from unittest.mock import ANY, AsyncMock, Mock
from urllib.parse import parse_qs

from signedjson.key import (
Expand Down Expand Up @@ -588,6 +588,38 @@ def test_introspection_token_cache(self) -> None:
)
self.assertEqual(self.http_client.request.call_count, 2)

def test_revocation_endpoint(self) -> None:
# mock introspection response and then admin verification response
self.http_client.request = AsyncMock(
side_effect=[
FakeResponse.json(
code=200, payload={"active": True, "jti": "open_sesame"}
),
FakeResponse.json(
code=200,
payload={
"active": True,
"sub": SUBJECT,
"scope": " ".join([SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE]),
"username": USERNAME,
},
),
]
)

# cache a token to delete
introspection_token = self.get_success(
self.auth._introspect_token("open_sesame") # type: ignore[attr-defined]
)
self.assertEqual(self.auth._token_cache.get("open_sesame"), introspection_token) # type: ignore[attr-defined]

# delete the revoked token
introspection_token_id = "open_sesame"
url = f"/_synapse/admin/v1/OIDC_token_revocation/{introspection_token_id}"
channel = self.make_request("DELETE", url, access_token="mockAccessToken")
self.assertEqual(channel.code, 200)
self.assertEqual(self.auth._token_cache.get("open_sesame"), None) # type: ignore[attr-defined]

def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
# We only generate a master key to simplify the test.
master_signing_key = generate_signing_key(device_id)
Expand Down
48 changes: 48 additions & 0 deletions tests/replication/test_intro_token_invalidation.py
clokep marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import Any, Dict

import synapse.rest.admin._base

from tests.replication._base import BaseMultiWorkerStreamTestCase


class IntrospectionTokenCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
servlets = [synapse.rest.admin.register_servlets]

def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["disable_registration"] = True
config["experimental_features"] = {
"msc3861": {
"enabled": True,
"issuer": "some_dude",
"client_id": "ID",
"client_auth_method": "client_secret_post",
"client_secret": "secret",
}
}
return config

def test_stream_introspection_token_invalidation(self) -> None:
worker_hs = self.make_worker_hs("synapse.app.generic_worker")
auth = worker_hs.get_auth()
store = self.hs.get_datastores().main

# add a token to the cache on the worker
auth._token_cache["open_sesame"] = "intro_token" # type: ignore[attr-defined]

# stream the invalidation from the master
self.get_success(
store.stream_introspection_token_invalidation(("open_sesame",))
)

# check that the cache on the worker was invalidated
self.assertEqual(auth._token_cache.get("open_sesame"), None) # type: ignore[attr-defined]

# test invalidating whole cache
for i in range(0, 5):
auth._token_cache[f"open_sesame_{i}"] = f"intro_token_{i}" # type: ignore[attr-defined]
self.assertEqual(len(auth._token_cache), 5) # type: ignore[attr-defined]

self.get_success(store.stream_introspection_token_invalidation((None,)))

self.assertEqual(len(auth._token_cache), 0) # type: ignore[attr-defined]