From bbd479c314d0584fe76136b317030e7062499595 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 17 Apr 2020 18:23:14 +0100 Subject: [PATCH 1/5] Add some type annotations to StreamChangeCache including dragging in some stubs for SortedDict --- stubs/sortedcontainers/__init__.pyi | 13 +++ stubs/sortedcontainers/sorteddict.pyi | 124 +++++++++++++++++++++ synapse/util/caches/stream_change_cache.py | 37 ++++-- tox.ini | 4 +- 4 files changed, 167 insertions(+), 11 deletions(-) create mode 100644 stubs/sortedcontainers/__init__.pyi create mode 100644 stubs/sortedcontainers/sorteddict.pyi diff --git a/stubs/sortedcontainers/__init__.pyi b/stubs/sortedcontainers/__init__.pyi new file mode 100644 index 000000000000..073b806d3c98 --- /dev/null +++ b/stubs/sortedcontainers/__init__.pyi @@ -0,0 +1,13 @@ +from .sorteddict import ( + SortedDict, + SortedKeysView, + SortedItemsView, + SortedValuesView, +) + +__all__ = [ + "SortedDict", + "SortedKeysView", + "SortedItemsView", + "SortedValuesView", +] diff --git a/stubs/sortedcontainers/sorteddict.pyi b/stubs/sortedcontainers/sorteddict.pyi new file mode 100644 index 000000000000..68779f968ed6 --- /dev/null +++ b/stubs/sortedcontainers/sorteddict.pyi @@ -0,0 +1,124 @@ +# stub for SortedDict. This is a lightly edited copy of +# https://github.com/grantjenks/python-sortedcontainers/blob/eea42df1f7bad2792e8da77335ff888f04b9e5ae/sortedcontainers/sorteddict.pyi +# (from https://github.com/grantjenks/python-sortedcontainers/pull/107) + +from typing import ( + Any, + Callable, + Dict, + Hashable, + Iterator, + Iterable, + ItemsView, + KeysView, + List, + Mapping, + Optional, + Sequence, + Type, + TypeVar, + Tuple, + Union, + ValuesView, + overload, +) + +_T = TypeVar("_T") +_S = TypeVar("_S") +_T_h = TypeVar("_T_h", bound=Hashable) +_KT = TypeVar("_KT", bound=Hashable) # Key type. +_VT = TypeVar("_VT") # Value type. +_KT_co = TypeVar("_KT_co", covariant=True, bound=Hashable) +_VT_co = TypeVar("_VT_co", covariant=True) +_SD = TypeVar("_SD", bound=SortedDict) +_Key = Callable[[_T], Any] + +class SortedDict(Dict[_KT, _VT]): + @overload + def __init__(self, **kwargs: _VT) -> None: ... + @overload + def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ... + @overload + def __init__( + self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT + ) -> None: ... + @overload + def __init__(self, __key: _Key[_KT], **kwargs: _VT) -> None: ... + @overload + def __init__( + self, __key: _Key[_KT], __map: Mapping[_KT, _VT], **kwargs: _VT + ) -> None: ... + @overload + def __init__( + self, __key: _Key[_KT], __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT + ) -> None: ... + @property + def key(self) -> Optional[_Key[_KT]]: ... + @property + def iloc(self) -> SortedKeysView[_KT]: ... + def clear(self) -> None: ... + def __delitem__(self, key: _KT) -> None: ... + def __iter__(self) -> Iterator[_KT]: ... + def __reversed__(self) -> Iterator[_KT]: ... + def __setitem__(self, key: _KT, value: _VT) -> None: ... + def _setitem(self, key: _KT, value: _VT) -> None: ... + def copy(self: _SD) -> _SD: ... + def __copy__(self: _SD) -> _SD: ... + @classmethod + @overload + def fromkeys(cls, seq: Iterable[_T_h]) -> SortedDict[_T_h, None]: ... + @classmethod + @overload + def fromkeys(cls, seq: Iterable[_T_h], value: _S) -> SortedDict[_T_h, _S]: ... + def keys(self) -> SortedKeysView[_KT]: ... + def items(self) -> SortedItemsView[_KT, _VT]: ... + def values(self) -> SortedValuesView[_VT]: ... + @overload + def pop(self, key: _KT) -> _VT: ... + @overload + def pop(self, key: _KT, default: _T = ...) -> Union[_VT, _T]: ... + def popitem(self, index: int = ...) -> Tuple[_KT, _VT]: ... + def peekitem(self, index: int = ...) -> Tuple[_KT, _VT]: ... + def setdefault(self, key: _KT, default: Optional[_VT] = ...) -> _VT: ... + @overload + def update(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ... + @overload + def update(self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT) -> None: ... + @overload + def update(self, **kwargs: _VT) -> None: ... + def __reduce__( + self, + ) -> Tuple[ + Type[SortedDict[_KT, _VT]], Tuple[Callable[[_KT], Any], List[Tuple[_KT, _VT]]], + ]: ... + def __repr__(self) -> str: ... + def _check(self) -> None: ... + def islice( + self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool, + ) -> Iterator[_KT]: ... + def bisect_left(self, value: _KT) -> int: ... + def bisect_right(self, value: _KT) -> int: ... + +class SortedKeysView(KeysView[_KT_co], Sequence[_KT_co]): + @overload + def __getitem__(self, index: int) -> _KT_co: ... + @overload + def __getitem__(self, index: slice) -> List[_KT_co]: ... + def __delitem__(self, index: Union[int, slice]) -> None: ... + +class SortedItemsView( # type: ignore + ItemsView[_KT_co, _VT_co], Sequence[Tuple[_KT_co, _VT_co]] +): + def __iter__(self) -> Iterator[Tuple[_KT_co, _VT_co]]: ... + @overload + def __getitem__(self, index: int) -> Tuple[_KT_co, _VT_co]: ... + @overload + def __getitem__(self, index: slice) -> List[Tuple[_KT_co, _VT_co]]: ... + def __delitem__(self, index: Union[int, slice]) -> None: ... + +class SortedValuesView(ValuesView[_VT_co], Sequence[_VT_co]): + @overload + def __getitem__(self, index: int) -> _VT_co: ... + @overload + def __getitem__(self, index: slice) -> List[_VT_co]: ... + def __delitem__(self, index: Union[int, slice]) -> None: ... diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 235f64049c95..a1c98a506366 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Collection, Dict, Iterable, Mapping, Optional, Set from six import integer_types @@ -23,8 +24,11 @@ logger = logging.getLogger(__name__) +# for now, assume all entities in the cache are strings +EntityType = str -class StreamChangeCache(object): + +class StreamChangeCache: """Keeps track of the stream positions of the latest change in a set of entities. Typically the entity will be a room or user id. @@ -34,10 +38,19 @@ class StreamChangeCache(object): old then the cache will simply return all given entities. """ - def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache=None): + def __init__( + self, + name: str, + current_stream_pos: int, + max_size=10000, + prefilled_cache: Optional[Mapping[EntityType, int]] = None, + ): self._max_size = int(max_size * caches.CACHE_SIZE_FACTOR) - self._entity_to_key = {} - self._cache = SortedDict() + self._entity_to_key = {} # type: Dict[EntityType, int] + + # map from stream id to the entity which changed at that stream id. + self._cache = SortedDict() # type: SortedDict[int, EntityType] + self._earliest_known_stream_pos = current_stream_pos self.name = name self.metrics = caches.register_cache("cache", self.name, self._cache) @@ -46,7 +59,7 @@ def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache=Non for entity, stream_pos in prefilled_cache.items(): self.entity_has_changed(entity, stream_pos) - def has_entity_changed(self, entity, stream_pos): + def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool: """Returns True if the entity may have been updated since stream_pos """ assert type(stream_pos) in integer_types @@ -67,7 +80,9 @@ def has_entity_changed(self, entity, stream_pos): self.metrics.inc_hits() return False - def get_entities_changed(self, entities, stream_pos): + def get_entities_changed( + self, entities: Iterable[EntityType], stream_pos: int + ) -> Set[EntityType]: """ Returns subset of entities that have had new things since the given position. Entities unknown to the cache will be returned. If the @@ -90,7 +105,7 @@ def get_entities_changed(self, entities, stream_pos): return result - def has_any_entity_changed(self, stream_pos): + def has_any_entity_changed(self, stream_pos: int) -> bool: """Returns if any entity has changed """ assert type(stream_pos) is int @@ -106,7 +121,9 @@ def has_any_entity_changed(self, stream_pos): self.metrics.inc_misses() return True - def get_all_entities_changed(self, stream_pos): + def get_all_entities_changed( + self, stream_pos: int + ) -> Optional[Collection[EntityType]]: """Returns all entites that have had new things since the given position. If the position is too old it will return None. """ @@ -120,7 +137,7 @@ def get_all_entities_changed(self, stream_pos): else: return None - def entity_has_changed(self, entity, stream_pos): + def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None: """Informs the cache that the entity has been changed at the given position. """ @@ -141,7 +158,7 @@ def entity_has_changed(self, entity, stream_pos): ) self._entity_to_key.pop(r, None) - def get_max_pos_of_last_change(self, entity): + def get_max_pos_of_last_change(self, entity: EntityType) -> int: """Returns an upper bound of the stream id of the last change to an entity. """ diff --git a/tox.ini b/tox.ini index 42b2d7489132..31011d7436f5 100644 --- a/tox.ini +++ b/tox.ini @@ -202,7 +202,9 @@ commands = mypy \ synapse/spam_checker_api \ synapse/storage/engines \ synapse/storage/database.py \ - synapse/streams + synapse/streams \ + synapse/util/caches/stream_change_cache.py \ + tests/util/test_stream_change_cache.py # To find all folders that pass mypy you run: # From 5a2efc02415f7e8bed3a75bccf77b2705e259911 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 17 Apr 2020 20:46:54 +0100 Subject: [PATCH 2/5] Add some more tests for cache expiry --- synapse/util/caches/stream_change_cache.py | 4 ++++ tests/util/test_stream_change_cache.py | 11 ++++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index a1c98a506366..867d3e1c65ad 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -51,6 +51,10 @@ def __init__( # map from stream id to the entity which changed at that stream id. self._cache = SortedDict() # type: SortedDict[int, EntityType] + # the earliest stream_pos for which we can reliably answer + # get_all_entities_changed. In other words, one less than the earliest + # stream_pos for which we know _cache is valid. + # self._earliest_known_stream_pos = current_stream_pos self.name = name self.metrics = caches.register_cache("cache", self.name, self._cache) diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 72a9de5370ec..bb732c2f9cf4 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -47,7 +47,7 @@ def test_has_entity_changed(self): self.assertTrue(cache.has_entity_changed("not@here.website", 0)) @patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0) - def test_has_entity_changed_pops_off_start(self): + def test_entity_has_changed_pops_off_start(self): """ StreamChangeCache.entity_has_changed will respect the max size and purge the oldest items upon reaching that max size. @@ -64,11 +64,20 @@ def test_has_entity_changed_pops_off_start(self): # The oldest item has been popped off self.assertTrue("user@foo.com" not in cache._entity_to_key) + self.assertEqual( + cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"], + ) + self.assertIsNone(cache.get_all_entities_changed(1)) + # If we update an existing entity, it keeps the two existing entities cache.entity_has_changed("bar@baz.net", 5) self.assertEqual( {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key) ) + self.assertEqual( + cache.get_all_entities_changed(2), ["user@elsewhere.org", "bar@baz.net",], + ) + self.assertIsNone(cache.get_all_entities_changed(1)) def test_get_all_entities_changed(self): """ From fa7bb6b911d496a461cf03a9e106f830018d6081 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 17 Apr 2020 14:24:59 +0100 Subject: [PATCH 3/5] Add support for multiple entities per stream id --- synapse/util/caches/stream_change_cache.py | 82 ++++++++++++---------- tests/util/test_stream_change_cache.py | 58 ++++++++++++--- 2 files changed, 94 insertions(+), 46 deletions(-) diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 867d3e1c65ad..65d69ca918f5 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Collection, Dict, Iterable, Mapping, Optional, Set +from typing import Dict, Iterable, List, Mapping, Optional, Set from six import integer_types @@ -48,8 +48,8 @@ def __init__( self._max_size = int(max_size * caches.CACHE_SIZE_FACTOR) self._entity_to_key = {} # type: Dict[EntityType, int] - # map from stream id to the entity which changed at that stream id. - self._cache = SortedDict() # type: SortedDict[int, EntityType] + # map from stream id to the a set of entities which changed at that stream id. + self._cache = SortedDict() # type: SortedDict[int, Set[EntityType]] # the earliest stream_pos for which we can reliably answer # get_all_entities_changed. In other words, one less than the earliest @@ -92,16 +92,9 @@ def get_entities_changed( position. Entities unknown to the cache will be returned. If the position is too old it will just return the given list. """ - assert type(stream_pos) is int - - if stream_pos >= self._earliest_known_stream_pos: - changed_entities = { - self._cache[k] - for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)) - } - - result = changed_entities.intersection(entities) - + changed_entities = self.get_all_entities_changed(stream_pos) + if changed_entities is not None: + result = set(changed_entities).intersection(entities) self.metrics.inc_hits() else: result = set(entities) @@ -115,7 +108,7 @@ def has_any_entity_changed(self, stream_pos: int) -> bool: assert type(stream_pos) is int if not self._cache: - # If we have no cache, nothing can have changed. + # If the cache is empty, nothing can have changed. return False if stream_pos >= self._earliest_known_stream_pos: @@ -125,42 +118,55 @@ def has_any_entity_changed(self, stream_pos: int) -> bool: self.metrics.inc_misses() return True - def get_all_entities_changed( - self, stream_pos: int - ) -> Optional[Collection[EntityType]]: - """Returns all entites that have had new things since the given + def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]: + """Returns all entities that have had new things since the given position. If the position is too old it will return None. + + Returns the entities in the order that they were changed. """ assert type(stream_pos) is int - if stream_pos >= self._earliest_known_stream_pos: - return [ - self._cache[k] - for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)) - ] - else: + if stream_pos < self._earliest_known_stream_pos: return None + changed_entities = [] # type: List[EntityType] + + for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)): + changed_entities.extend(self._cache[k]) + return changed_entities + def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None: """Informs the cache that the entity has been changed at the given position. """ assert type(stream_pos) is int - if stream_pos > self._earliest_known_stream_pos: - old_pos = self._entity_to_key.get(entity, None) - if old_pos is not None: - stream_pos = max(stream_pos, old_pos) - self._cache.pop(old_pos, None) - self._cache[stream_pos] = entity - self._entity_to_key[entity] = stream_pos - - while len(self._cache) > self._max_size: - k, r = self._cache.popitem(0) - self._earliest_known_stream_pos = max( - k, self._earliest_known_stream_pos - ) - self._entity_to_key.pop(r, None) + if stream_pos <= self._earliest_known_stream_pos: + return + + old_pos = self._entity_to_key.get(entity, None) + if old_pos is not None: + if old_pos >= stream_pos: + # nothing to do + return + e = self._cache[old_pos] + e.remove(entity) + if not e: + # cache at this point is now empty + del self._cache[old_pos] + + e1 = self._cache.get(stream_pos) + if e1 is None: + e1 = self._cache[stream_pos] = set() + e1.add(entity) + self._entity_to_key[entity] = stream_pos + + # if the cache is too big, remove entries + while len(self._cache) > self._max_size: + k, r = self._cache.popitem(0) + self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) + for entity in r: + del self._entity_to_key[entity] def get_max_pos_of_last_change(self, entity: EntityType) -> int: """Returns an upper bound of the stream id of the last change to an diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index bb732c2f9cf4..9a2e4a77b1f2 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -28,18 +28,26 @@ def test_has_entity_changed(self): cache.entity_has_changed("user@foo.com", 6) cache.entity_has_changed("bar@baz.net", 7) + # also test multiple things changing on the same stream ID + cache.entity_has_changed("user2@foo.com", 8) + cache.entity_has_changed("bar2@baz.net", 8) + # If it's been changed after that stream position, return True self.assertTrue(cache.has_entity_changed("user@foo.com", 4)) self.assertTrue(cache.has_entity_changed("bar@baz.net", 4)) + self.assertTrue(cache.has_entity_changed("bar2@baz.net", 4)) + self.assertTrue(cache.has_entity_changed("user2@foo.com", 4)) # If it's been changed at that stream position, return False self.assertFalse(cache.has_entity_changed("user@foo.com", 6)) + self.assertFalse(cache.has_entity_changed("user2@foo.com", 8)) # If there's no changes after that stream position, return False self.assertFalse(cache.has_entity_changed("user@foo.com", 7)) + self.assertFalse(cache.has_entity_changed("user2@foo.com", 9)) # If the entity does not exist, return False. - self.assertFalse(cache.has_entity_changed("not@here.website", 7)) + self.assertFalse(cache.has_entity_changed("not@here.website", 9)) # If we request before the stream cache's earliest known position, # return True, whether it's a known entity or not. @@ -89,18 +97,52 @@ def test_get_all_entities_changed(self): cache.entity_has_changed("user@foo.com", 2) cache.entity_has_changed("bar@baz.net", 3) + cache.entity_has_changed("anotheruser@foo.com", 3) cache.entity_has_changed("user@elsewhere.org", 4) - self.assertEqual( - cache.get_all_entities_changed(1), - ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], - ) - self.assertEqual( - cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"] - ) + r = cache.get_all_entities_changed(1) + + # either of these are valid + ok1 = [ + "user@foo.com", + "bar@baz.net", + "anotheruser@foo.com", + "user@elsewhere.org", + ] + ok2 = [ + "user@foo.com", + "anotheruser@foo.com", + "bar@baz.net", + "user@elsewhere.org", + ] + self.assertTrue(r == ok1 or r == ok2) + + r = cache.get_all_entities_changed(2) + self.assertTrue(r == ok1[1:] or r == ok2[1:]) + self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"]) self.assertEqual(cache.get_all_entities_changed(0), None) + # ... later, things gest more updates + cache.entity_has_changed("user@foo.com", 5) + cache.entity_has_changed("bar@baz.net", 5) + cache.entity_has_changed("anotheruser@foo.com", 6) + + ok1 = [ + "user@elsewhere.org", + "user@foo.com", + "bar@baz.net", + "anotheruser@foo.com", + ] + ok2 = [ + "user@elsewhere.org", + "bar@baz.net", + "user@foo.com", + "anotheruser@foo.com", + ] + r = cache.get_all_entities_changed(3) + self.assertTrue(r == ok1 or r == ok2) + def test_has_any_entity_changed(self): """ StreamChangeCache.has_any_entity_changed will return True if any From d85be5d2515b35c03bb67af417daf7bd22448e43 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 17 Apr 2020 21:00:27 +0100 Subject: [PATCH 4/5] changelog --- changelog.d/7303.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/7303.misc diff --git a/changelog.d/7303.misc b/changelog.d/7303.misc new file mode 100644 index 000000000000..aa89c2b25444 --- /dev/null +++ b/changelog.d/7303.misc @@ -0,0 +1 @@ +Fix StreamChangeCache to work with multiple entities changing on the same stream id. From 7cb6db354e726f5350f656e811ab99099734ec57 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 17 Apr 2020 23:24:21 +0100 Subject: [PATCH 5/5] fix lint --- tests/util/test_stream_change_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 9a2e4a77b1f2..6857933540d6 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -83,7 +83,7 @@ def test_entity_has_changed_pops_off_start(self): {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key) ) self.assertEqual( - cache.get_all_entities_changed(2), ["user@elsewhere.org", "bar@baz.net",], + cache.get_all_entities_changed(2), ["user@elsewhere.org", "bar@baz.net"], ) self.assertIsNone(cache.get_all_entities_changed(1))