Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add StreamStore to mypy #8232

Merged
merged 3 commits into from
Sep 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8232.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to `StreamStore`.
1 change: 1 addition & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ files =
synapse/server_notices,
synapse/spam_checker_api,
synapse/state,
synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py,
synapse/storage/database.py,
synapse/storage/engines,
Expand Down
4 changes: 2 additions & 2 deletions synapse/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import abc
import os
from distutils.util import strtobool
from typing import Dict, Optional, Type
from typing import Dict, Optional, Tuple, Type

from unpaddedbase64 import encode_base64

Expand Down Expand Up @@ -120,7 +120,7 @@ def __init__(self, internal_metadata_dict: JsonDict):
# be here
before = DictProperty("before") # type: str
after = DictProperty("after") # type: str
order = DictProperty("order") # type: int
order = DictProperty("order") # type: Tuple[int, int]

def get_dict(self) -> JsonDict:
return dict(self._dict)
Expand Down
34 changes: 34 additions & 0 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,18 @@ def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
results = [dict(zip(col_headers, row)) for row in cursor]
return results

@overload
async def execute(
self, desc: str, decoder: Literal[None], query: str, *args: Any
) -> List[Tuple[Any, ...]]:
...

@overload
async def execute(
self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
) -> R:
...

async def execute(
self,
desc: str,
Expand Down Expand Up @@ -1088,6 +1100,28 @@ async def simple_select_one(
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
)

@overload
async def simple_select_one_onecol(
self,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: Literal[False] = False,
desc: str = "simple_select_one_onecol",
) -> Any:
...

@overload
async def simple_select_one_onecol(
self,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: Literal[True] = True,
desc: str = "simple_select_one_onecol",
) -> Optional[Any]:
...

async def simple_select_one_onecol(
self,
table: str,
Expand Down
52 changes: 33 additions & 19 deletions synapse/storage/databases/main/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,27 @@
import abc
import logging
from collections import namedtuple
from typing import Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple

from twisted.internet import defer

from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.types import RoomStreamToken
from synapse.types import Collection, RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -202,7 +209,7 @@ def _make_generic_sql_bound(
)


def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]:
def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
# "room_id == X AND room_id != X", which postgres doesn't optimise.
Expand Down Expand Up @@ -260,7 +267,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):

__metaclass__ = abc.ABCMeta

def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super(StreamWorkerStore, self).__init__(database, db_conn, hs)

self._instance_name = hs.get_instance_name()
Expand Down Expand Up @@ -293,16 +300,16 @@ def __init__(self, database: DatabasePool, db_conn, hs):
self._stream_order_on_start = self.get_room_max_stream_ordering()

@abc.abstractmethod
def get_room_max_stream_ordering(self):
def get_room_max_stream_ordering(self) -> int:
raise NotImplementedError()

@abc.abstractmethod
def get_room_min_stream_ordering(self):
def get_room_min_stream_ordering(self) -> int:
raise NotImplementedError()

async def get_room_events_stream_for_rooms(
self,
room_ids: Iterable[str],
room_ids: Collection[str],
from_key: str,
to_key: str,
limit: int = 0,
Expand Down Expand Up @@ -356,19 +363,21 @@ async def get_room_events_stream_for_rooms(

return results

def get_rooms_that_changed(self, room_ids, from_key):
def get_rooms_that_changed(
self, room_ids: Collection[str], from_key: str
) -> Set[str]:
"""Given a list of rooms and a token, return rooms where there may have
been changes.

Args:
room_ids (list)
from_key (str): The room_key portion of a StreamToken
room_ids
from_key: The room_key portion of a StreamToken
"""
from_key = RoomStreamToken.parse_stream_token(from_key).stream
from_id = RoomStreamToken.parse_stream_token(from_key).stream
return {
room_id
for room_id in room_ids
if self._events_stream_cache.has_entity_changed(room_id, from_key)
if self._events_stream_cache.has_entity_changed(room_id, from_id)
}

async def get_room_events_stream_for_room(
Expand Down Expand Up @@ -440,7 +449,9 @@ def f(txn):

return ret, key

async def get_membership_changes_for_user(self, user_id, from_key, to_key):
async def get_membership_changes_for_user(
self, user_id: str, from_key: str, to_key: str
) -> List[EventBase]:
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream

Expand Down Expand Up @@ -646,7 +657,7 @@ async def get_max_topological_token(self, room_id: str, stream_key: int) -> int:
)
return row[0][0] if row else 0

def _get_max_topological_txn(self, txn, room_id):
def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int:
txn.execute(
"SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
(room_id,),
Expand Down Expand Up @@ -719,7 +730,7 @@ async def get_events_around(

def _get_events_around_txn(
self,
txn,
txn: LoggingTransaction,
room_id: str,
event_id: str,
before_limit: int,
Expand Down Expand Up @@ -747,6 +758,9 @@ def _get_events_around_txn(
retcols=["stream_ordering", "topological_ordering"],
)

# This cannot happen as `allow_none=False`.
assert results is not None

# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
before_token = RoomStreamToken(
Expand Down Expand Up @@ -856,7 +870,7 @@ async def update_federation_out_pos(self, typ: str, stream_id: int) -> None:
desc="update_federation_out_pos",
)

def _reset_federation_positions_txn(self, txn) -> None:
def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None:
"""Fiddles with the `federation_stream_position` table to make it match
the configured federation sender instances during start up.
"""
Expand Down Expand Up @@ -895,7 +909,7 @@ def _reset_federation_positions_txn(self, txn) -> None:
GROUP BY type
"""
txn.execute(sql)
min_positions = dict(txn) # Map from type -> min position
min_positions = {typ: pos for typ, pos in txn} # Map from type -> min position
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this to just let mypy figure out the type?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, this is because mypy complains that txn implements Iterable[Tuple[Any, ...]] rather than Iterable[Tuple[Any, Any]], whereas with this form we're (implicitly) explicitly asserting that the tuple length is two. I think. But yes its to appease mypy


# Ensure we do actually have some values here
assert set(min_positions) == {"federation", "events"}
Expand All @@ -922,7 +936,7 @@ def has_room_changed_since(self, room_id: str, stream_id: int) -> bool:

def _paginate_room_events_txn(
self,
txn,
txn: LoggingTransaction,
room_id: str,
from_token: RoomStreamToken,
to_token: Optional[RoomStreamToken] = None,
Expand Down