This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Another batch of type annotations #12726

Merged: 11 commits, merged May 13, 2022.

1 change: 1 addition & 0 deletions changelog.d/12726.misc
@@ -0,0 +1 @@
Add type annotations to increase the number of modules passing `disallow-untyped-defs`.
21 changes: 21 additions & 0 deletions mypy.ini
@@ -128,15 +128,30 @@ disallow_untyped_defs = True
[mypy-synapse.http.federation.*]
disallow_untyped_defs = True

[mypy-synapse.http.connectproxyclient]
disallow_untyped_defs = True

[mypy-synapse.http.proxyagent]
disallow_untyped_defs = True

[mypy-synapse.http.request_metrics]
disallow_untyped_defs = True

[mypy-synapse.http.server]
disallow_untyped_defs = True

[mypy-synapse.logging._remote]
disallow_untyped_defs = True

[mypy-synapse.logging.context]
disallow_untyped_defs = True

[mypy-synapse.logging.formatter]
disallow_untyped_defs = True

[mypy-synapse.logging.handlers]
disallow_untyped_defs = True

[mypy-synapse.metrics.*]
disallow_untyped_defs = True

@@ -166,6 +181,9 @@ disallow_untyped_defs = True
[mypy-synapse.state.*]
disallow_untyped_defs = True

[mypy-synapse.storage.databases.background_updates]
disallow_untyped_defs = True

[mypy-synapse.storage.databases.main.account_data]
disallow_untyped_defs = True

@@ -232,6 +250,9 @@ disallow_untyped_defs = True
[mypy-synapse.streams.*]
disallow_untyped_defs = True

[mypy-synapse.types]
disallow_untyped_defs = True

[mypy-synapse.util.*]
disallow_untyped_defs = True

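Each new [mypy-...] section above opts that module into disallow_untyped_defs, which makes mypy reject any function in the module whose parameters or return value are left unannotated. A minimal sketch of what the flag catches (illustrative only, not Synapse code):

    # With disallow_untyped_defs = True, mypy reports the first function
    # ("Function is missing a type annotation") and accepts the second.
    def untyped(x):
        return x * 2

    def typed(x: int) -> int:
        return x * 2
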
29 changes: 11 additions & 18 deletions synapse/handlers/e2e_keys.py
@@ -15,7 +15,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple

import attr
from canonicaljson import encode_canonical_json
@@ -1105,22 +1105,19 @@ async def _get_e2e_cross_signing_verify_key(
# can request over federation
raise NotFoundError("No %s key found for %s" % (key_type, user_id))

(
key,
key_id,
verify_key,
) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type)

if key is None:
cross_signing_keys = await self._retrieve_cross_signing_keys_for_remote_user(
user, key_type
)
if cross_signing_keys is None:
raise NotFoundError("No %s key found for %s" % (key_type, user_id))

return key, key_id, verify_key
return cross_signing_keys

async def _retrieve_cross_signing_keys_for_remote_user(
self,
user: UserID,
desired_key_type: str,
) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]:
) -> Optional[Tuple[Dict[str, Any], str, VerifyKey]]:
"""Queries cross-signing keys for a remote user and saves them to the database
Only the key specified by `key_type` will be returned, while all retrieved keys
@@ -1146,12 +1143,10 @@ async def _retrieve_cross_signing_keys_for_remote_user(
type(e),
e,
)
return None, None, None
return None

# Process each of the retrieved cross-signing keys
desired_key = None
desired_key_id = None
desired_verify_key = None
desired_key_data = None
retrieved_device_ids = []
for key_type in ["master", "self_signing"]:
key_content = remote_result.get(key_type + "_key")
@@ -1196,9 +1191,7 @@ async def _retrieve_cross_signing_keys_for_remote_user(

# If this is the desired key type, save it and its ID/VerifyKey
if key_type == desired_key_type:
desired_key = key_content
desired_verify_key = verify_key
desired_key_id = key_id
desired_key_data = key_content, key_id, verify_key

# At the same time, store this key in the db for subsequent queries
await self.store.set_e2e_cross_signing_key(
@@ -1212,7 +1205,7 @@ async def _retrieve_cross_signing_keys_for_remote_user(
user.to_string(), retrieved_device_ids
)

return desired_key, desired_key_id, desired_verify_key
return desired_key_data


def _check_cross_signing_key(
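The signature change above trades a tuple of three Optionals for an Optional tuple: callers perform a single None check and mypy then narrows the whole result at once. A minimal sketch of the difference (the VerifyKey stand-in is hypothetical, only there to keep the example self-contained):

    from typing import Any, Dict, Optional, Tuple

    class VerifyKey:  # hypothetical stand-in for the real signing key class
        pass

    # Old style: three independently Optional values. mypy cannot know they are
    # always all-set or all-None, so each needs its own check at every call site.
    def old_style() -> Tuple[Optional[Dict[str, Any]], Optional[str], Optional[VerifyKey]]:
        return None, None, None

    # New style: the result as a whole is either present or None.
    def new_style() -> Optional[Tuple[Dict[str, Any], str, VerifyKey]]:
        return None

    result = new_style()
    if result is None:
        raise LookupError("no key found")
    # After the single check, mypy narrows `result` to the full tuple, so the
    # unpacked values need no further Optional handling.
    key, key_id, verify_key = result
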
39 changes: 23 additions & 16 deletions synapse/http/connectproxyclient.py
@@ -14,15 +14,22 @@

import base64
import logging
from typing import Optional
from typing import Optional, Union

import attr
from zope.interface import implementer

from twisted.internet import defer, protocol
from twisted.internet.error import ConnectError
from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
from twisted.internet.interfaces import (
IAddress,
IConnector,
IProtocol,
IReactorCore,
IStreamClientEndpoint,
)
from twisted.internet.protocol import ClientFactory, Protocol, connectionDone
from twisted.python.failure import Failure
from twisted.web import http

logger = logging.getLogger(__name__)
@@ -81,14 +88,14 @@ def __init__(
self._port = port
self._proxy_creds = proxy_creds

def __repr__(self):
def __repr__(self) -> str:
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)

# Mypy encounters a false positive here: it complains that ClientFactory
# is incompatible with IProtocolFactory. But ClientFactory inherits from
# Factory, which implements IProtocolFactory. So I think this is a bug
# in mypy-zope.
def connect(self, protocolFactory: ClientFactory): # type: ignore[override]
def connect(self, protocolFactory: ClientFactory) -> "defer.Deferred[IProtocol]": # type: ignore[override]
f = HTTPProxiedClientFactory(
self._host, self._port, protocolFactory, self._proxy_creds
)
@@ -125,10 +132,10 @@ def __init__(
self.proxy_creds = proxy_creds
self.on_connection: "defer.Deferred[None]" = defer.Deferred()

def startedConnecting(self, connector):
def startedConnecting(self, connector: IConnector) -> None:
return self.wrapped_factory.startedConnecting(connector)

def buildProtocol(self, addr):
def buildProtocol(self, addr: IAddress) -> "HTTPConnectProtocol":
wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
if wrapped_protocol is None:
raise TypeError("buildProtocol produced None instead of a Protocol")
@@ -141,13 +148,13 @@ def buildProtocol(self, addr):
self.proxy_creds,
)

def clientConnectionFailed(self, connector, reason):
def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
logger.debug("Connection to proxy failed: %s", reason)
if not self.on_connection.called:
self.on_connection.errback(reason)
return self.wrapped_factory.clientConnectionFailed(connector, reason)

def clientConnectionLost(self, connector, reason):
def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
logger.debug("Connection to proxy lost: %s", reason)
if not self.on_connection.called:
self.on_connection.errback(reason)
@@ -191,10 +198,10 @@ def __init__(
)
self.http_setup_client.on_connected.addCallback(self.proxyConnected)

def connectionMade(self):
def connectionMade(self) -> None:
self.http_setup_client.makeConnection(self.transport)

def connectionLost(self, reason=connectionDone):
def connectionLost(self, reason: Failure = connectionDone) -> None:
if self.wrapped_protocol.connected:
self.wrapped_protocol.connectionLost(reason)

@@ -203,7 +210,7 @@ def connectionLost(self, reason=connectionDone):
if not self.connected_deferred.called:
self.connected_deferred.errback(reason)

def proxyConnected(self, _):
def proxyConnected(self, _: Union[None, "defer.Deferred[None]"]) -> None:
self.wrapped_protocol.makeConnection(self.transport)

self.connected_deferred.callback(self.wrapped_protocol)
@@ -213,7 +220,7 @@ def proxyConnected(self, _):
if buf:
self.wrapped_protocol.dataReceived(buf)

def dataReceived(self, data: bytes):
def dataReceived(self, data: bytes) -> None:
# if we've set up the HTTP protocol, we can send the data there
if self.wrapped_protocol.connected:
return self.wrapped_protocol.dataReceived(data)
@@ -243,7 +250,7 @@ def __init__(
self.proxy_creds = proxy_creds
self.on_connected: "defer.Deferred[None]" = defer.Deferred()

def connectionMade(self):
def connectionMade(self) -> None:
logger.debug("Connected to proxy, sending CONNECT")
self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))

@@ -257,14 +264,14 @@ def connectionMade(self):

self.endHeaders()

def handleStatus(self, version: bytes, status: bytes, message: bytes):
def handleStatus(self, version: bytes, status: bytes, message: bytes) -> None:
logger.debug("Got Status: %s %s %s", status, message, version)
if status != b"200":
raise ProxyConnectError(f"Unexpected status on CONNECT: {status!s}")

def handleEndHeaders(self):
def handleEndHeaders(self) -> None:
logger.debug("End Headers")
self.on_connected.callback(None)

def handleResponse(self, body):
def handleResponse(self, body: bytes) -> None:
pass
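
The annotations added above use the Twisted interface types that the reactor actually passes into ClientFactory hooks (IAddress, IConnector, Failure). A minimal standalone sketch of the same pattern, unrelated to the proxy logic in this file:

    from twisted.internet.interfaces import IAddress, IConnector
    from twisted.internet.protocol import ClientFactory, Protocol
    from twisted.python.failure import Failure

    class EchoFactory(ClientFactory):
        def startedConnecting(self, connector: IConnector) -> None:
            print("connecting to", connector.getDestination())

        def buildProtocol(self, addr: IAddress) -> Protocol:
            # Returning a concrete Protocol type avoids the implicit Any that an
            # unannotated override would give callers.
            return Protocol()

        def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
            print("connection failed:", reason.getErrorMessage())

        def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
            print("connection lost:", reason.getErrorMessage())
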
2 changes: 1 addition & 1 deletion synapse/http/proxyagent.py
@@ -245,7 +245,7 @@ def http_proxy_endpoint(
proxy: Optional[bytes],
reactor: IReactorCore,
tls_options_factory: Optional[IPolicyForHTTPS],
**kwargs,
**kwargs: object,
Contributor Author: bit of a cop-out.

) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]:
"""Parses an http proxy setting and returns an endpoint for the proxy
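Annotating the catch-all as **kwargs: object (the "cop-out" noted above) satisfies disallow_untyped_defs without committing to a particular set of keyword arguments: every value is typed object, so using one concretely needs an isinstance check or a cast. A hypothetical sketch, not the real Synapse helper:

    from typing import Tuple

    def endpoint_sketch(proxy: bytes, **kwargs: object) -> Tuple[bytes, int]:
        # kwargs values are typed `object`; narrow them before use.
        timeout = kwargs.get("timeout")
        if not isinstance(timeout, int):
            timeout = 30
        return proxy, timeout

    endpoint_sketch(b"http://proxy:8888", timeout=10, verify=True)
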
20 changes: 12 additions & 8 deletions synapse/logging/_remote.py
@@ -31,7 +31,11 @@
TCP4ClientEndpoint,
TCP6ClientEndpoint,
)
from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint
from twisted.internet.interfaces import (
IPushProducer,
IReactorTCP,
IStreamClientEndpoint,
)
from twisted.internet.protocol import Factory, Protocol
from twisted.internet.tcp import Connection
from twisted.python.failure import Failure
@@ -59,14 +63,14 @@ class LogProducer:
_buffer: Deque[logging.LogRecord]
_paused: bool = attr.ib(default=False, init=False)

def pauseProducing(self):
def pauseProducing(self) -> None:
self._paused = True

def stopProducing(self):
def stopProducing(self) -> None:
self._paused = True
self._buffer = deque()

def resumeProducing(self):
def resumeProducing(self) -> None:
# If we're already producing, nothing to do.
self._paused = False

@@ -102,8 +106,8 @@ def __init__(
host: str,
port: int,
maximum_buffer: int = 1000,
level=logging.NOTSET,
_reactor=None,
level: int = logging.NOTSET,
_reactor: Optional[IReactorTCP] = None,
):
super().__init__(level=level)
self.host = host
@@ -118,7 +122,7 @@ def __init__(
if _reactor is None:
from twisted.internet import reactor

_reactor = reactor
_reactor = reactor # type: ignore[assignment]
Contributor Author: twisted.internet.reactor is a module that, upon import, removes itself from sys.modules and calls twisted.internet.default.install(). This eventually reinserts twisted.internet.reactor back into sys.modules, but now it points to a reactor instance.

Contributor Author: No wonder mypy is confused.

Member: In other spots we do:

    from twisted.internet import reactor as _reactor

    reactor = cast(ISynapseReactor, _reactor)

🤷 (Could probably just put a cast here instead of an ignore to be consistent.)

try:
ip = ip_address(self.host)
@@ -139,7 +143,7 @@ def __init__(
self._stopping = False
self._connect()

def close(self):
def close(self) -> None:
self._stopping = True
self._service.stopService()

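As the thread above explains, twisted.internet.reactor starts out as a module and is swapped for a reactor instance at import time, so mypy has to be overruled either way. A sketch of the two options discussed, the ignore used in this change and the cast used elsewhere in the codebase, assuming IReactorTCP is the interface of interest:

    from typing import cast

    from twisted.internet.interfaces import IReactorTCP

    def reactor_via_ignore() -> IReactorTCP:
        from twisted.internet import reactor

        # mypy still sees `reactor` as a module, hence the targeted ignore.
        r: IReactorTCP = reactor  # type: ignore[assignment]
        return r

    def reactor_via_cast() -> IReactorTCP:
        from twisted.internet import reactor as _reactor

        # The cast states the same assumption explicitly.
        return cast(IReactorTCP, _reactor)
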
14 changes: 10 additions & 4 deletions synapse/logging/formatter.py
@@ -16,6 +16,8 @@
import logging
import traceback
from io import StringIO
from types import TracebackType
from typing import Optional, Tuple, Type


class LogFormatter(logging.Formatter):
@@ -28,10 +30,14 @@ class LogFormatter(logging.Formatter):
where it was caught are logged).
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def formatException(self, ei):
def formatException(
self,
ei: Tuple[
Optional[Type[BaseException]],
Optional[BaseException],
Optional[TracebackType],
],
) -> str:
sio = StringIO()
(typ, val, tb) = ei

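The ei parameter matches the triple returned by sys.exc_info(), in which every slot can be None outside an exception handler. A minimal sketch of the same annotation on a custom formatter (not Synapse's LogFormatter):

    import logging
    import sys
    import traceback
    from io import StringIO
    from types import TracebackType
    from typing import Optional, Tuple, Type

    ExcInfo = Tuple[
        Optional[Type[BaseException]],
        Optional[BaseException],
        Optional[TracebackType],
    ]

    class VerboseFormatter(logging.Formatter):
        def formatException(self, ei: ExcInfo) -> str:
            sio = StringIO()
            typ, val, tb = ei
            # Render the full traceback into the string buffer.
            traceback.print_exception(typ, val, tb, None, sio)
            return sio.getvalue()

    try:
        1 / 0
    except ZeroDivisionError:
        print(VerboseFormatter().formatException(sys.exc_info()))
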
4 changes: 2 additions & 2 deletions synapse/logging/handlers.py
@@ -49,7 +49,7 @@ def __init__(
)
self._flushing_thread.start()

def on_reactor_running():
def on_reactor_running() -> None:
self._reactor_started = True

reactor_to_use: IReactorCore
Expand All @@ -74,7 +74,7 @@ def shouldFlush(self, record: LogRecord) -> bool:
else:
return True

def _flush_periodically(self):
def _flush_periodically(self) -> None:
"""
Whilst this handler is active, flush the handler periodically.
"""