Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to some tests files #12240

Merged
merged 2 commits into from
Mar 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12240.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to tests files.
4 changes: 0 additions & 4 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ exclude = (?x)
|tests/federation/test_federation_server.py
|tests/federation/transport/test_knocking.py
|tests/federation/transport/test_server.py
|tests/handlers/test_cas.py
|tests/handlers/test_federation.py
|tests/handlers/test_presence.py
|tests/handlers/test_typing.py
|tests/http/federation/test_matrix_federation_agent.py
|tests/http/federation/test_srv_resolver.py
Expand All @@ -80,7 +77,6 @@ exclude = (?x)
|tests/logging/test_terse_json.py
|tests/module_api/test_api.py
|tests/push/test_email.py
|tests/push/test_http.py
|tests/push/test_presentable_names.py
|tests/push/test_push_rule_evaluator.py
|tests/rest/client/test_transactions.py
Expand Down
19 changes: 12 additions & 7 deletions tests/handlers/test_cas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict
from unittest.mock import Mock

from twisted.test.proto_helpers import MemoryReactor

from synapse.handlers.cas import CasResponse
from synapse.server import HomeServer
from synapse.util import Clock

from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
Expand All @@ -24,7 +29,7 @@


class CasHandlerTestCase(HomeserverTestCase):
def default_config(self):
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["public_baseurl"] = BASE_URL
cas_config = {
Expand All @@ -40,7 +45,7 @@ def default_config(self):

return config

def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver()

self.handler = hs.get_cas_handler()
Expand All @@ -51,7 +56,7 @@ def make_homeserver(self, reactor, clock):

return hs

def test_map_cas_user_to_user(self):
def test_map_cas_user_to_user(self) -> None:
"""Ensure that mapping the CAS user returned from a provider to an MXID works properly."""

# stub out the auth handler
Expand All @@ -75,7 +80,7 @@ def test_map_cas_user_to_user(self):
auth_provider_session_id=None,
)

def test_map_cas_user_to_existing_user(self):
def test_map_cas_user_to_existing_user(self) -> None:
"""Existing users can log in with CAS account."""
store = self.hs.get_datastores().main
self.get_success(
Expand Down Expand Up @@ -119,7 +124,7 @@ def test_map_cas_user_to_existing_user(self):
auth_provider_session_id=None,
)

def test_map_cas_user_to_invalid_localpart(self):
def test_map_cas_user_to_invalid_localpart(self) -> None:
"""CAS automaps invalid characters to base-64 encoding."""

# stub out the auth handler
Expand Down Expand Up @@ -150,7 +155,7 @@ def test_map_cas_user_to_invalid_localpart(self):
}
}
)
def test_required_attributes(self):
def test_required_attributes(self) -> None:
"""The required attributes must be met from the CAS response."""

# stub out the auth handler
Expand All @@ -166,7 +171,7 @@ def test_required_attributes(self):
auth_handler.complete_sso_login.assert_not_called()

# The response doesn't have any department.
cas_response = CasResponse("test_user", {"userGroup": "staff"})
cas_response = CasResponse("test_user", {"userGroup": ["staff"]})
request.reset_mock()
self.get_success(
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
Expand Down
36 changes: 21 additions & 15 deletions tests/handlers/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List
from typing import List, cast
from unittest import TestCase

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
from synapse.api.room_versions import RoomVersions
Expand All @@ -23,7 +25,9 @@
from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import create_requester
from synapse.util import Clock
from synapse.util.stringutils import random_string

from tests import unittest
Expand All @@ -42,15 +46,15 @@ class FederationTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]

def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(federation_http_client=None)
self.handler = hs.get_federation_handler()
self.store = hs.get_datastores().main
self.state_store = hs.get_storage().state
self._event_auth_handler = hs.get_event_auth_handler()
return hs

def test_exchange_revoked_invite(self):
def test_exchange_revoked_invite(self) -> None:
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")

Expand Down Expand Up @@ -96,7 +100,7 @@ def test_exchange_revoked_invite(self):
self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
self.assertEqual(failure.msg, "You are not invited to this room.")

def test_rejected_message_event_state(self):
def test_rejected_message_event_state(self) -> None:
"""
Check that we store the state group correctly for rejected non-state events.

Expand Down Expand Up @@ -126,7 +130,7 @@ def test_rejected_message_event_state(self):
"content": {},
"room_id": room_id,
"sender": "@yetanotheruser:" + OTHER_SERVER,
"depth": join_event["depth"] + 1,
"depth": cast(int, join_event["depth"]) + 1,
"prev_events": [join_event.event_id],
"auth_events": [],
"origin_server_ts": self.clock.time_msec(),
Expand All @@ -149,7 +153,7 @@ def test_rejected_message_event_state(self):

self.assertEqual(sg, sg2)

def test_rejected_state_event_state(self):
def test_rejected_state_event_state(self) -> None:
"""
Check that we store the state group correctly for rejected state events.

Expand Down Expand Up @@ -180,7 +184,7 @@ def test_rejected_state_event_state(self):
"content": {},
"room_id": room_id,
"sender": "@yetanotheruser:" + OTHER_SERVER,
"depth": join_event["depth"] + 1,
"depth": cast(int, join_event["depth"]) + 1,
"prev_events": [join_event.event_id],
"auth_events": [],
"origin_server_ts": self.clock.time_msec(),
Expand All @@ -203,7 +207,7 @@ def test_rejected_state_event_state(self):

self.assertEqual(sg, sg2)

def test_backfill_with_many_backward_extremities(self):
def test_backfill_with_many_backward_extremities(self) -> None:
"""
Check that we can backfill with many backward extremities.
The goal is to make sure that when we only use a portion
Expand Down Expand Up @@ -262,7 +266,7 @@ def test_backfill_with_many_backward_extremities(self):
)
self.get_success(d)

def test_backfill_floating_outlier_membership_auth(self):
def test_backfill_floating_outlier_membership_auth(self) -> None:
"""
As the local homeserver, check that we can properly process a federated
event from the OTHER_SERVER with auth_events that include a floating
Expand Down Expand Up @@ -377,7 +381,7 @@ async def get_event_auth(
for ae in auth_events
]

self.handler.federation_client.get_event_auth = get_event_auth
self.handler.federation_client.get_event_auth = get_event_auth # type: ignore[assignment]

with LoggingContext("receive_pdu"):
# Fake the OTHER_SERVER federating the message event over to our local homeserver
Expand All @@ -397,7 +401,7 @@ async def get_event_auth(
@unittest.override_config(
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
)
def test_invite_by_user_ratelimit(self):
def test_invite_by_user_ratelimit(self) -> None:
"""Tests that invites from federation to a particular user are
actually rate-limited.
"""
Expand Down Expand Up @@ -446,7 +450,9 @@ def create_invite():
exc=LimitExceededError,
)

def _build_and_send_join_event(self, other_server, other_user, room_id):
def _build_and_send_join_event(
self, other_server: str, other_user: str, room_id: str
) -> EventBase:
join_event = self.get_success(
self.handler.on_make_join_request(other_server, room_id, other_user)
)
Expand All @@ -469,7 +475,7 @@ def _build_and_send_join_event(self, other_server, other_user, room_id):


class EventFromPduTestCase(TestCase):
def test_valid_json(self):
def test_valid_json(self) -> None:
"""Valid JSON should be turned into an event."""
ev = event_from_pdu_json(
{
Expand All @@ -487,7 +493,7 @@ def test_valid_json(self):

self.assertIsInstance(ev, EventBase)

def test_invalid_numbers(self):
def test_invalid_numbers(self) -> None:
"""Invalid values for an integer should be rejected, all floats should be rejected."""
for value in [
-(2 ** 53),
Expand All @@ -512,7 +518,7 @@ def test_invalid_numbers(self):
RoomVersions.V6,
)

def test_invalid_nested(self):
def test_invalid_nested(self) -> None:
"""List and dictionaries are recursively searched."""
with self.assertRaises(SynapseError):
event_from_pdu_json(
Expand Down
13 changes: 9 additions & 4 deletions tests/handlers/test_presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,11 +331,11 @@ def test_persisting_presence_updates(self):

# Extract presence update user ID and state information into lists of tuples
db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]]
presence_states = [(ps.user_id, ps.state) for ps in presence_states]
presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states]

# Compare what we put into the storage with what we got out.
# They should be identical.
self.assertEqual(presence_states, db_presence_states)
self.assertEqual(presence_states_compare, db_presence_states)


class PresenceTimeoutTestCase(unittest.TestCase):
Expand All @@ -357,6 +357,7 @@ def test_idle_timer(self):
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)

self.assertIsNotNone(new_state)
assert new_state is not None
self.assertEqual(new_state.state, PresenceState.UNAVAILABLE)
self.assertEqual(new_state.status_msg, status_msg)

Expand All @@ -380,6 +381,7 @@ def test_busy_no_idle(self):
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)

self.assertIsNotNone(new_state)
assert new_state is not None
self.assertEqual(new_state.state, PresenceState.BUSY)
self.assertEqual(new_state.status_msg, status_msg)

Expand All @@ -399,6 +401,7 @@ def test_sync_timeout(self):
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)

self.assertIsNotNone(new_state)
assert new_state is not None
self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg)

Expand All @@ -420,6 +423,7 @@ def test_sync_online(self):
)

self.assertIsNotNone(new_state)
assert new_state is not None
self.assertEqual(new_state.state, PresenceState.ONLINE)
self.assertEqual(new_state.status_msg, status_msg)

Expand Down Expand Up @@ -477,6 +481,7 @@ def test_federation_timeout(self):
)

self.assertIsNotNone(new_state)
assert new_state is not None
self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg)

Expand Down Expand Up @@ -653,13 +658,13 @@ def test_set_presence_with_status_msg_none(self):
self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)

def _set_presencestate_with_status_msg(
self, user_id: str, state: PresenceState, status_msg: Optional[str]
self, user_id: str, state: str, status_msg: Optional[str]
):
"""Set a PresenceState and status_msg and check the result.

Args:
user_id: User for that the status is to be set.
PresenceState: The new PresenceState.
state: The new PresenceState.
status_msg: Status message that is to be set.
"""
self.get_success(
Expand Down
Loading