Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to tests.replication #14987

Merged
merged 6 commits into from
Feb 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/replication/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
repl_handler,
)

self._client_transport = None
self._server_transport = None
self._client_transport: Optional[FakeTransport] = None
self._server_transport: Optional[FakeTransport] = None

def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
Expand Down
2 changes: 1 addition & 1 deletion tests/replication/http/test__base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ async def _handle_request( # type: ignore[override]
class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
"""Tests for `ReplicationEndpoint` cancellation."""

def create_test_resource(self):
def create_test_resource(self) -> JsonResource:
"""Overrides `HomeserverTestCase.create_test_resource`."""
resource = JsonResource(self.hs)

Expand Down
23 changes: 15 additions & 8 deletions tests/replication/slave/storage/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Iterable, Optional
from unittest.mock import Mock

from tests.replication._base import BaseStreamTestCase
from twisted.test.proto_helpers import MemoryReactor

from synapse.server import HomeServer
from synapse.util import Clock

class BaseSlavedStoreTestCase(BaseStreamTestCase):
def make_homeserver(self, reactor, clock):
from tests.replication._base import BaseStreamTestCase

hs = self.setup_test_homeserver(federation_client=Mock())

return hs
class BaseSlavedStoreTestCase(BaseStreamTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(federation_client=Mock())

def prepare(self, reactor, clock, hs) -> None:
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)

self.reconnect()

self.master_store = hs.get_datastores().main
self.slaved_store = self.worker_hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
persistence = hs.get_storage_controllers().persistence
assert persistence is not None
self.persistance = persistence

def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then
Expand All @@ -41,7 +46,9 @@ def replicate(self) -> None:
self.streamer.on_notifier_poke()
self.pump(0.1)

def check(self, method, args, expected_result=None) -> None:
def check(
self, method: str, args: Iterable[Any], expected_result: Optional[Any] = None
) -> None:
master_result = self.get_success(getattr(self.master_store, method)(*args))
slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
if expected_result is not None:
Expand Down
69 changes: 34 additions & 35 deletions tests/replication/slave/storage/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Iterable, Optional
from typing import Any, Callable, Iterable, List, Optional, Tuple

from canonicaljson import encode_canonical_json
from parameterized import parameterized

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import ReceiptTypes
from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.handlers.room import RoomEventSource
from synapse.server import HomeServer
from synapse.storage.databases.main.event_push_actions import (
NotifCounts,
RoomNotifCounts,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
from synapse.types import PersistedEventPosition
from synapse.util import Clock

from tests.server import FakeTransport

Expand All @@ -41,19 +46,19 @@
logger = logging.getLogger(__name__)


def dict_equals(self, other):
def dict_equals(self: EventBase, other: EventBase) -> bool:
me = encode_canonical_json(self.get_pdu_json())
them = encode_canonical_json(other.get_pdu_json())
return me == them


def patch__eq__(cls):
def patch__eq__(cls: object) -> Callable[[], None]:
eq = getattr(cls, "__eq__", None)
cls.__eq__ = dict_equals
cls.__eq__ = dict_equals # type: ignore[assignment]

def unpatch() -> None:
if eq is not None:
cls.__eq__ = eq
cls.__eq__ = eq # type: ignore[assignment]

return unpatch

Expand All @@ -62,14 +67,14 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):

STORE_TYPE = EventsWorkerStore

def setUp(self):
def setUp(self) -> None:
# Patch up the equality operator for events so that we can check
# whether lists of events match using assertEqual
self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)]
return super().setUp()
self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(EventBase)]
super().setUp()

def prepare(self, *args, **kwargs):
super().prepare(*args, **kwargs)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)

self.get_success(
self.master_store.store_room(
Expand Down Expand Up @@ -163,7 +168,7 @@ def test_invites(self) -> None:
)

@parameterized.expand([(True,), (False,)])
def test_push_actions_for_user(self, send_receipt: bool):
def test_push_actions_for_user(self, send_receipt: bool) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.member", key=USER_ID, membership="join")
self.persist(
Expand Down Expand Up @@ -243,7 +248,9 @@ def test_get_rooms_for_user_with_stream_ordering(self) -> None:
{GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
)

def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self) -> None:
def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(
self,
) -> None:
"""Check that current_state invalidation happens correctly with multiple events
in the persistence batch.

Expand Down Expand Up @@ -283,11 +290,7 @@ def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self)
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
)
msg, msgctx = self.build_event()
self.get_success(
self._storage_controllers.persistence.persist_events(
[(j2, j2ctx), (msg, msgctx)]
)
)
self.get_success(self.persistance.persist_events([(j2, j2ctx), (msg, msgctx)]))
self.replicate()
assert j2.internal_metadata.stream_ordering is not None

Expand Down Expand Up @@ -339,7 +342,7 @@ def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self)

event_id = 0

def persist(self, backfill=False, **kwargs) -> FrozenEvent:
def persist(self, backfill: bool = False, **kwargs: Any) -> EventBase:
"""
Returns:
The event that was persisted.
Expand All @@ -348,32 +351,28 @@ def persist(self, backfill=False, **kwargs) -> FrozenEvent:

if backfill:
self.get_success(
self._storage_controllers.persistence.persist_events(
[(event, context)], backfilled=True
)
self.persistance.persist_events([(event, context)], backfilled=True)
)
else:
self.get_success(
self._storage_controllers.persistence.persist_event(event, context)
)
self.get_success(self.persistance.persist_event(event, context))

return event

def build_event(
self,
sender=USER_ID,
room_id=ROOM_ID,
type="m.room.message",
key=None,
sender: str = USER_ID,
room_id: str = ROOM_ID,
type: str = "m.room.message",
key: Optional[str] = None,
internal: Optional[dict] = None,
depth=None,
depth: Optional[int] = None,
prev_events: Optional[list] = None,
clokep marked this conversation as resolved.
Show resolved Hide resolved
auth_events: Optional[list] = None,
prev_state: Optional[list] = None,
redacts=None,
auth_events: Optional[List[str]] = None,
prev_state: Optional[List[str]] = None,
redacts: Optional[str] = None,
push_actions: Iterable = frozenset(),
**content,
):
**content: object,
) -> Tuple[EventBase, EventContext]:
prev_events = prev_events or []
auth_events = auth_events or []
prev_state = prev_state or []
Expand Down
18 changes: 11 additions & 7 deletions tests/replication/tcp/streams/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional
from typing import Any, List, Optional

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
Expand All @@ -25,6 +27,8 @@
)
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.util import Clock

from tests.replication._base import BaseStreamTestCase
from tests.test_utils.event_injection import inject_event, inject_member_event
Expand All @@ -37,7 +41,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
room.register_servlets,
]

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
self.user_id = self.register_user("u1", "pass")
self.user_tok = self.login("u1", "pass")
Expand All @@ -47,7 +51,7 @@ def prepare(self, reactor, clock, hs):
self.room_id = self.helper.create_room_as(tok=self.user_tok)
self.test_handler.received_rdata_rows.clear()

def test_update_function_event_row_limit(self):
def test_update_function_event_row_limit(self) -> None:
"""Test replication with many non-state events

Checks that all events are correctly replicated when there are lots of
Expand Down Expand Up @@ -102,7 +106,7 @@ def test_update_function_event_row_limit(self):

self.assertEqual([], received_rows)

def test_update_function_huge_state_change(self):
def test_update_function_huge_state_change(self) -> None:
"""Test replication with many state events

Ensures that all events are correctly replicated when there are lots of
Expand Down Expand Up @@ -256,7 +260,7 @@ def test_update_function_huge_state_change(self):
# "None" indicates the state has been deleted
self.assertIsNone(sr.event_id)

def test_update_function_state_row_limit(self):
def test_update_function_state_row_limit(self) -> None:
"""Test replication with many state events over several stream ids."""

# we want to generate lots of state changes, but for this test, we want to
Expand Down Expand Up @@ -376,7 +380,7 @@ def test_update_function_state_row_limit(self):

self.assertEqual([], received_rows)

def test_backwards_stream_id(self):
def test_backwards_stream_id(self) -> None:
"""
Test that RDATA that comes after the current position should be discarded.
"""
Expand Down Expand Up @@ -437,7 +441,7 @@ def test_backwards_stream_id(self):
event_count = 0

def _inject_test_event(
self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs
self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs: Any
) -> EventBase:
if sender is None:
sender = self.user_id
Expand Down
2 changes: 1 addition & 1 deletion tests/replication/tcp/streams/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


class TypingStreamTestCase(BaseStreamTestCase):
def _build_replication_data_handler(self):
def _build_replication_data_handler(self) -> Mock:
return Mock(wraps=super()._build_replication_data_handler())

def test_typing(self) -> None:
Expand Down
6 changes: 4 additions & 2 deletions tests/replication/tcp/test_remote_server_up.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@

from twisted.internet.address import IPv4Address
from twisted.internet.interfaces import IProtocol
from twisted.test.proto_helpers import StringTransport
from twisted.test.proto_helpers import MemoryReactor, StringTransport

from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
from synapse.util import Clock

from tests.unittest import HomeserverTestCase


class RemoteServerUpTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.factory = ReplicationStreamProtocolFactory(hs)

def _make_client(self) -> Tuple[IProtocol, StringTransport]:
Expand Down
6 changes: 5 additions & 1 deletion tests/replication/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
# limitations under the License.
import logging

from twisted.test.proto_helpers import MemoryReactor

from synapse.rest.client import register
from synapse.server import HomeServer
from synapse.util import Clock

from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, make_request
Expand All @@ -27,7 +31,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):

servlets = [register.register_servlets]

def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
# This isn't a real configuration option but is used to provide the main
# homeserver and worker homeserver different options.
Expand Down
10 changes: 6 additions & 4 deletions tests/replication/test_federation_ack.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@

from unittest import mock

from twisted.test.proto_helpers import MemoryReactor

from synapse.app.generic_worker import GenericWorkerServer
from synapse.replication.tcp.commands import FederationAckCommand
from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams.federation import FederationStream
from synapse.server import HomeServer
from synapse.util import Clock

from tests.unittest import HomeserverTestCase

Expand All @@ -30,10 +34,8 @@ def default_config(self) -> dict:
config["federation_sender_instances"] = ["federation_sender1"]
return config

def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)

return hs
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)

def test_federation_ack_sent(self) -> None:
"""A FEDERATION_ACK should be sent back after each RDATA federation
Expand Down
4 changes: 3 additions & 1 deletion tests/replication/test_federation_sender_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,9 @@ def test_send_typing_sharded(self) -> None:
self.assertTrue(sent_on_1)
self.assertTrue(sent_on_2)

def create_room_with_remote_server(self, user, token, remote_server="other_server"):
def create_room_with_remote_server(
self, user: str, token: str, remote_server: str = "other_server"
) -> str:
room = self.helper.create_room_as(user, tok=token)
store = self.hs.get_datastores().main
federation = self.hs.get_federation_event_handler()
Expand Down
Loading