Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to tests.replication #14987

Merged
merged 6 commits into from
Feb 6, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/14987.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ disallow_untyped_defs = True
[mypy-tests.push.*]
disallow_untyped_defs = True

[mypy-tests.replication.*]
disallow_untyped_defs = True

[mypy-tests.rest.*]
disallow_untyped_defs = True

Expand Down
70 changes: 40 additions & 30 deletions tests/replication/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from typing import Any, Dict, List, Optional, Set, Tuple

from twisted.internet.address import IPv4Address
from twisted.internet.protocol import Protocol
from twisted.internet.protocol import Protocol, connectionDone
from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource

from synapse.app.generic_worker import GenericWorkerServer
Expand All @@ -30,6 +32,7 @@
)
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
from synapse.util import Clock

from tests import unittest
from tests.server import FakeTransport
Expand All @@ -51,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis:
skip = "Requires hiredis"

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
Expand Down Expand Up @@ -92,8 +95,8 @@ def prepare(self, reactor, clock, hs):
repl_handler,
)

self._client_transport = None
self._server_transport = None
self._client_transport: Optional[FakeTransport] = None
self._server_transport: Optional[FakeTransport] = None

def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
Expand All @@ -107,10 +110,10 @@ def _get_worker_hs_config(self) -> dict:
config["worker_replication_http_port"] = "8765"
return config

def _build_replication_data_handler(self):
def _build_replication_data_handler(self) -> "TestReplicationDataHandler":
return TestReplicationDataHandler(self.worker_hs)

def reconnect(self):
def reconnect(self) -> None:
if self._client_transport:
self.client.close()

Expand All @@ -123,7 +126,7 @@ def reconnect(self):
self._server_transport = FakeTransport(self.client, self.reactor)
self.server.makeConnection(self._server_transport)

def disconnect(self):
def disconnect(self) -> None:
if self._client_transport:
self._client_transport = None
self.client.close()
Expand All @@ -132,7 +135,7 @@ def disconnect(self):
self._server_transport = None
self.server.close()

def replicate(self):
def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
Expand Down Expand Up @@ -168,7 +171,7 @@ def handle_http_replication_attempt(self) -> SynapseRequest:
requests: List[SynapseRequest] = []
real_request_factory = channel.requestFactory

def request_factory(*args, **kwargs):
def request_factory(*args: Any, **kwargs: Any) -> SynapseRequest:
request = real_request_factory(*args, **kwargs)
requests.append(request)
return request
Expand Down Expand Up @@ -202,7 +205,7 @@ def request_factory(*args, **kwargs):

def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
):
) -> None:
"""Asserts that the given request is a HTTP replication request for
fetching updates for given stream.
"""
Expand Down Expand Up @@ -244,7 +247,7 @@ def default_config(self) -> Dict[str, Any]:
base["redis"] = {"enabled": True}
return base

def setUp(self):
def setUp(self) -> None:
super().setUp()

# build a replication server
Expand Down Expand Up @@ -287,7 +290,7 @@ def setUp(self):
lambda: self._handle_http_replication_attempt(self.hs, 8765),
)

def create_test_resource(self):
def create_test_resource(self) -> ReplicationRestResource:
"""Overrides `HomeserverTestCase.create_test_resource`."""
# We override this so that it automatically registers all the HTTP
# replication servlets, without having to explicitly do that in all
Expand All @@ -301,7 +304,7 @@ def create_test_resource(self):
return resource

def make_worker_hs(
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
) -> HomeServer:
"""Make a new worker HS instance, correctly connecting replcation
stream to the master HS.
Expand Down Expand Up @@ -385,14 +388,14 @@ def _get_worker_hs_config(self) -> dict:
config["worker_replication_http_port"] = "8765"
return config

def replicate(self):
def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
self.streamer.on_notifier_poke()
self.pump()

def _handle_http_replication_attempt(self, hs, repl_port):
def _handle_http_replication_attempt(self, hs: HomeServer, repl_port: int) -> None:
"""Handles a connection attempt to the given HS replication HTTP
listener on the given port.
"""
Expand Down Expand Up @@ -429,7 +432,7 @@ def _handle_http_replication_attempt(self, hs, repl_port):
# inside `connecTCP` before the connection has been passed back to the
# code that requested the TCP connection.

def connect_any_redis_attempts(self):
def connect_any_redis_attempts(self) -> None:
"""If redis is enabled we need to deal with workers connecting to a
redis server. We don't want to use a real Redis server so we use a
fake one.
Expand All @@ -440,8 +443,11 @@ def connect_any_redis_attempts(self):
self.assertEqual(host, "localhost")
self.assertEqual(port, 6379)

client_protocol = client_factory.buildProtocol(None)
server_protocol = self._redis_server.buildProtocol(None)
client_address = IPv4Address("TCP", "127.0.0.1", 6379)
client_protocol = client_factory.buildProtocol(client_address)

server_address = IPv4Address("TCP", host, port)
server_protocol = self._redis_server.buildProtocol(server_address)
Comment on lines -443 to +450
Copy link
Contributor

@DMRobertson DMRobertson Feb 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went for a different approach in #14988 (comment) of passing in a dummy address. I should probably just go for something simpler like this.


client_to_server_transport = FakeTransport(
server_protocol, self.reactor, client_protocol
Expand All @@ -463,7 +469,9 @@ def __init__(self, hs: HomeServer):
# list of received (stream_name, token, row) tuples
self.received_rdata_rows: List[Tuple[str, int, Any]] = []

async def on_rdata(self, stream_name, instance_name, token, rows):
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
) -> None:
await super().on_rdata(stream_name, instance_name, token, rows)
for r in rows:
self.received_rdata_rows.append((stream_name, token, r))
Expand All @@ -472,28 +480,30 @@ async def on_rdata(self, stream_name, instance_name, token, rows):
class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub."""

def __init__(self):
def __init__(self) -> None:
self._subscribers_by_channel: Dict[
bytes, Set["FakeRedisPubSubProtocol"]
] = defaultdict(set)

def add_subscriber(self, conn, channel: bytes):
def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None:
"""A connection has called SUBSCRIBE"""
self._subscribers_by_channel[channel].add(conn)

def remove_subscriber(self, conn):
def remove_subscriber(self, conn: "FakeRedisPubSubProtocol") -> None:
"""A connection has lost connection"""
for subscribers in self._subscribers_by_channel.values():
subscribers.discard(conn)

def publish(self, conn, channel: bytes, msg) -> int:
def publish(
self, conn: "FakeRedisPubSubProtocol", channel: bytes, msg: object
) -> int:
"""A connection want to publish a message to subscribers."""
for sub in self._subscribers_by_channel[channel]:
sub.send(["message", channel, msg])

return len(self._subscribers_by_channel)

def buildProtocol(self, addr):
def buildProtocol(self, addr: IPv4Address) -> "FakeRedisPubSubProtocol":
return FakeRedisPubSubProtocol(self)


Expand All @@ -506,7 +516,7 @@ def __init__(self, server: FakeRedisPubSubServer):
self._server = server
self._reader = hiredis.Reader()

def dataReceived(self, data):
def dataReceived(self, data: bytes) -> None:
self._reader.feed(data)

# We might get multiple messages in one packet.
Expand All @@ -523,7 +533,7 @@ def dataReceived(self, data):

self.handle_command(msg[0], *msg[1:])

def handle_command(self, command, *args):
def handle_command(self, command: bytes, *args: bytes) -> None:
"""Received a Redis command from the client."""

# We currently only support pub/sub.
Expand All @@ -548,9 +558,9 @@ def handle_command(self, command, *args):
self.send("PONG")

else:
raise Exception(f"Unknown command: {command}")
raise Exception(f"Unknown command: {command!r}")

def send(self, msg):
def send(self, msg: object) -> None:
"""Send a message back to the client."""
assert self.transport is not None

Expand All @@ -559,7 +569,7 @@ def send(self, msg):
self.transport.write(raw)
self.transport.flush()

def encode(self, obj):
def encode(self, obj: object) -> str:
"""Encode an object to its Redis format.

Supports: strings/bytes, integers and list/tuples.
Expand All @@ -581,5 +591,5 @@ def encode(self, obj):

raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)

def connectionLost(self, reason):
def connectionLost(self, reason: Failure = connectionDone) -> None:
self._server.remove_subscriber(self)
2 changes: 1 addition & 1 deletion tests/replication/http/test__base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ async def _handle_request( # type: ignore[override]
class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
"""Tests for `ReplicationEndpoint` cancellation."""

def create_test_resource(self):
def create_test_resource(self) -> JsonResource:
"""Overrides `HomeserverTestCase.create_test_resource`."""
resource = JsonResource(self.hs)

Expand Down
25 changes: 16 additions & 9 deletions tests/replication/slave/storage/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,35 +13,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Iterable, Optional
from unittest.mock import Mock

from tests.replication._base import BaseStreamTestCase
from twisted.test.proto_helpers import MemoryReactor

from synapse.server import HomeServer
from synapse.util import Clock

class BaseSlavedStoreTestCase(BaseStreamTestCase):
def make_homeserver(self, reactor, clock):
from tests.replication._base import BaseStreamTestCase

hs = self.setup_test_homeserver(federation_client=Mock())

return hs
class BaseSlavedStoreTestCase(BaseStreamTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(federation_client=Mock())

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)

self.reconnect()

self.master_store = hs.get_datastores().main
self.slaved_store = self.worker_hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
persistence = hs.get_storage_controllers().persistence
assert persistence is not None
self.persistance = persistence

def replicate(self):
def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
self.streamer.on_notifier_poke()
self.pump(0.1)

def check(self, method, args, expected_result=None):
def check(
self, method: str, args: Iterable[Any], expected_result: Optional[Any] = None
) -> None:
master_result = self.get_success(getattr(self.master_store, method)(*args))
slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
if expected_result is not None:
Expand Down
Loading