From 68507ea212303bdea27697ec1f44911298c2c888 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 23 Sep 2020 16:51:37 -0400 Subject: [PATCH 1/3] Add type hints to response cache. --- mypy.ini | 1 + synapse/handlers/initial_sync.py | 4 +-- synapse/handlers/room.py | 3 ++ synapse/util/caches/response_cache.py | 46 +++++++++++++++------------ 4 files changed, 31 insertions(+), 23 deletions(-) diff --git a/mypy.ini b/mypy.ini index a7ffb81ef133..2475b9c03e92 100644 --- a/mypy.ini +++ b/mypy.ini @@ -60,6 +60,7 @@ files = synapse/types.py, synapse/util/async_helpers.py, synapse/util/caches/descriptors.py, + synapse/util/caches/response_cache.py, synapse/util/caches/stream_change_cache.py, synapse/util/metrics.py, tests/replication, diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 39a85801c1ad..17972e94e708 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -52,7 +52,7 @@ def __init__(self, hs: "HomeServer"): self.storage = hs.get_storage() self.state_store = self.storage.state - def snapshot_all_rooms( + async def snapshot_all_rooms( self, user_id: str, pagin_config: PaginationConfig, @@ -84,7 +84,7 @@ def snapshot_all_rooms( include_archived, ) - return self.snapshot_cache.wrap( + return await self.snapshot_cache.wrap( key, self._snapshot_all_rooms, user_id, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index d0530a446c83..4e87133ff961 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -147,6 +147,9 @@ async def upgrade_room( # Check if this room is already being upgraded by another person for key in self._upgrade_response_cache.pending_result_cache: + # The keys of pending_result_cache just need to be hashable, but in + # this case we know that it is a tuple. + assert isinstance(key, tuple) if key[0] == old_room_id and key[1] != user_id: # Two different people are trying to upgrade the same room. # Send the second an error. diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index df1a721adda8..55e7a83dd17f 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, Optional from twisted.internet import defer @@ -20,6 +21,9 @@ from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches import register_cache +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -31,8 +35,9 @@ class ResponseCache: used rather than trying to compute a new response. """ - def __init__(self, hs, name, timeout_ms=0): - self.pending_result_cache = {} # Requests that haven't finished yet. + def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0): + # Requests that haven't finished yet. + self.pending_result_cache = {} # type: Dict[Hashable, ObservableDeferred] self.clock = hs.get_clock() self.timeout_sec = timeout_ms / 1000.0 @@ -40,13 +45,13 @@ def __init__(self, hs, name, timeout_ms=0): self._name = name self._metrics = register_cache("response_cache", name, self, resizable=False) - def size(self): + def size(self) -> int: return len(self.pending_result_cache) - def __len__(self): + def __len__(self) -> int: return self.size() - def get(self, key): + def get(self, key: Hashable) -> Optional[defer.Deferred]: """Look up the given key. 
Can return either a new Deferred (which also doesn't follow the synapse @@ -58,12 +63,11 @@ def get(self, key): from an absent cache entry. Args: - key (hashable): + key: key to get/set in the cache Returns: - twisted.internet.defer.Deferred|None|E: None if there is no entry - for this key; otherwise either a deferred result or the result - itself. + None if there is no entry for this key; otherwise a deferred which + resolves to the result. """ result = self.pending_result_cache.get(key) if result is not None: @@ -73,7 +77,7 @@ def get(self, key): self._metrics.inc_misses() return None - def set(self, key, deferred): + def set(self, key: Hashable, deferred: defer.Deferred) -> defer.Deferred: """Set the entry for the given key to the given deferred. *deferred* should run its callbacks in the sentinel logcontext (ie, @@ -85,12 +89,11 @@ def set(self, key, deferred): result. You will probably want to make_deferred_yieldable the result. Args: - key (hashable): - deferred (twisted.internet.defer.Deferred[T): + key: key to get/set in the cache + deferred: The deferred which resolves to the result. Returns: - twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual - result. + A new deferred which resolves to the actual result. """ result = ObservableDeferred(deferred, consumeErrors=True) self.pending_result_cache[key] = result @@ -107,7 +110,9 @@ def remove(r): result.addBoth(remove) return result.observe() - def wrap(self, key, callback, *args, **kwargs): + def wrap( + self, key: Hashable, callback: "Callable[..., Any]", *args: Any, **kwargs: Any + ) -> defer.Deferred: """Wrap together a *get* and *set* call, taking care of logcontexts First looks up the key in the cache, and if it is present makes it @@ -118,21 +123,20 @@ def wrap(self, key, callback, *args, **kwargs): Example usage: - @defer.inlineCallbacks - def handle_request(request): + async def handle_request(request): # etc return result - result = yield response_cache.wrap( + result = await response_cache.wrap( key, handle_request, request, ) Args: - key (hashable): key to get/set in the cache + key: key to get/set in the cache - callback (callable): function to call if the key is not found in + callback: function to call if the key is not found in the cache *args: positional parameters to pass to the callback, if it is used @@ -140,7 +144,7 @@ def handle_request(request): **kwargs: named parameters to pass to the callback, if it is used Returns: - twisted.internet.defer.Deferred: yieldable result + Deferred which resolves to the result """ result = self.get(key) if not result: From abe8bb786a418a024982a0d26a357e17fe7681da Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 8 Oct 2020 14:59:09 -0400 Subject: [PATCH 2/3] Add newsfragment. --- changelog.d/8507.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/8507.misc diff --git a/changelog.d/8507.misc b/changelog.d/8507.misc new file mode 100644 index 000000000000..724da8a9960e --- /dev/null +++ b/changelog.d/8507.misc @@ -0,0 +1 @@ + Add type hints to various parts of the code base. From dd1971a63019fddcbe8ed8f6d210efc073f95d50 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 9 Oct 2020 11:02:27 -0400 Subject: [PATCH 3/3] Make ResponseCache a generic. 
--- synapse/appservice/api.py | 4 ++-- synapse/federation/federation_server.py | 8 +++++--- synapse/handlers/initial_sync.py | 6 ++++-- synapse/handlers/room.py | 5 +---- synapse/handlers/sync.py | 4 +++- synapse/replication/http/_base.py | 2 +- synapse/util/caches/response_cache.py | 14 ++++++++------ 7 files changed, 24 insertions(+), 19 deletions(-) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index c526c28b9307..e8f07937952b 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import urllib -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Tuple from prometheus_client import Counter @@ -93,7 +93,7 @@ def __init__(self, hs): self.protocol_meta_cache = ResponseCache( hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS - ) + ) # type: ResponseCache[Tuple[str, str]] async def query_user(self, service, user_id): if service.url is None: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 02f11e120997..3fa520f5b6a4 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -111,7 +111,7 @@ def __init__(self, hs): # We cache results for transaction with the same ID self._transaction_resp_cache = ResponseCache( hs, "fed_txn_handler", timeout_ms=30000 - ) + ) # type: ResponseCache[Tuple[str, str]] self.transaction_actions = TransactionActions(self.store) @@ -119,10 +119,12 @@ def __init__(self, hs): # We cache responses to state queries, as they take a while and often # come in waves. - self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000) + self._state_resp_cache = ResponseCache( + hs, "state_resp", timeout_ms=30000 + ) # type: ResponseCache[Tuple[str, str]] self._state_ids_resp_cache = ResponseCache( hs, "state_ids_resp", timeout_ms=30000 - ) + ) # type: ResponseCache[Tuple[str, str]] self._federation_metrics_domains = ( hs.get_config().federation.federation_metrics_domains diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 17972e94e708..98075f48d2b3 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -14,7 +14,7 @@ # limitations under the License. 
import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Tuple from twisted.internet import defer @@ -47,7 +47,9 @@ def __init__(self, hs: "HomeServer"): self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() - self.snapshot_cache = ResponseCache(hs, "initial_sync_cache") + self.snapshot_cache = ResponseCache( + hs, "initial_sync_cache" + ) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]] self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 4e87133ff961..c2a2783cb7e1 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -120,7 +120,7 @@ def __init__(self, hs: "HomeServer"): # subsequent requests self._upgrade_response_cache = ResponseCache( hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS - ) + ) # type: ResponseCache[Tuple[str, str]] self._server_notices_mxid = hs.config.server_notices_mxid self.third_party_event_rules = hs.get_third_party_event_rules() @@ -147,9 +147,6 @@ async def upgrade_room( # Check if this room is already being upgraded by another person for key in self._upgrade_response_cache.pending_result_cache: - # The keys of pending_result_cache just need to be hashable, but in - # this case we know that it is a tuple. - assert isinstance(key, tuple) if key[0] == old_room_id and key[1] != user_id: # Two different people are trying to upgrade the same room. # Send the second an error. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 6fb8332f9365..a3066310942e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -243,7 +243,9 @@ def __init__(self, hs: "HomeServer"): self.presence_handler = hs.get_presence_handler() self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() - self.response_cache = ResponseCache(hs, "sync") + self.response_cache = ResponseCache( + hs, "sync" + ) # type: ResponseCache[Tuple[Any, ...]] self.state = hs.get_state_handler() self.auth = hs.get_auth() self.storage = hs.get_storage() diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 64edadb624c1..2b3972cb1418 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -92,7 +92,7 @@ def __init__(self, hs): if self.CACHE: self.response_cache = ResponseCache( hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000 - ) + ) # type: ResponseCache[str] # We reserve `instance_name` as a parameter to sending requests, so we # assert here that sub classes don't try and use the name. diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index 55e7a83dd17f..32228f42ee59 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar from twisted.internet import defer @@ -26,8 +26,10 @@ logger = logging.getLogger(__name__) +T = TypeVar("T") -class ResponseCache: + +class ResponseCache(Generic[T]): """ This caches a deferred response. Until the deferred completes it will be returned from the cache. 
This means that if the client retries the request @@ -37,7 +39,7 @@ class ResponseCache: def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0): # Requests that haven't finished yet. - self.pending_result_cache = {} # type: Dict[Hashable, ObservableDeferred] + self.pending_result_cache = {} # type: Dict[T, ObservableDeferred] self.clock = hs.get_clock() self.timeout_sec = timeout_ms / 1000.0 @@ -51,7 +53,7 @@ def size(self) -> int: def __len__(self) -> int: return self.size() - def get(self, key: Hashable) -> Optional[defer.Deferred]: + def get(self, key: T) -> Optional[defer.Deferred]: """Look up the given key. Can return either a new Deferred (which also doesn't follow the synapse @@ -77,7 +79,7 @@ def get(self, key: Hashable) -> Optional[defer.Deferred]: self._metrics.inc_misses() return None - def set(self, key: Hashable, deferred: defer.Deferred) -> defer.Deferred: + def set(self, key: T, deferred: defer.Deferred) -> defer.Deferred: """Set the entry for the given key to the given deferred. *deferred* should run its callbacks in the sentinel logcontext (ie, @@ -111,7 +113,7 @@ def remove(r): return result.observe() def wrap( - self, key: Hashable, callback: "Callable[..., Any]", *args: Any, **kwargs: Any + self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any ) -> defer.Deferred: """Wrap together a *get* and *set* call, taking care of logcontexts
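For reference, here is a minimal sketch (outside the patches above) of the call-site pattern these changes establish: constructing a `ResponseCache` with an explicit key type via a `# type:` comment, and `await`-ing `wrap()`, as done for the federation, room, sync and initial-sync handlers in PATCH 1/3 and PATCH 3/3. The handler class and method names here are hypothetical illustrations; only `ResponseCache`, its constructor arguments and `wrap()` come from the code in the patches.

```python
from typing import TYPE_CHECKING, Tuple

from synapse.util.caches.response_cache import ResponseCache

if TYPE_CHECKING:
    from synapse.app.homeserver import HomeServer


class ExampleHandler:
    """Hypothetical handler showing the keying pattern used in this PR."""

    def __init__(self, hs: "HomeServer"):
        # Pin the key shape with a type comment, mirroring the
        # `ResponseCache[Tuple[str, str]]` annotations added in PATCH 3/3.
        self._example_cache = ResponseCache(
            hs, "example_cache", timeout_ms=30000
        )  # type: ResponseCache[Tuple[str, str]]

    async def do_request(self, room_id: str, user_id: str) -> dict:
        # wrap() returns a Deferred that resolves to the result; awaiting it
        # follows the pattern applied to snapshot_all_rooms in PATCH 1/3.
        return await self._example_cache.wrap(
            (room_id, user_id), self._do_request_uncached, room_id, user_id
        )

    async def _do_request_uncached(self, room_id: str, user_id: str) -> dict:
        # Placeholder for the real (uncached) work.
        return {"room_id": room_id, "user_id": user_id}
```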