Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Make synapse._scripts pass typechecks (#12421)
Browse files Browse the repository at this point in the history
  • Loading branch information
David Robertson authored Apr 8, 2022
1 parent dd5cc37 commit 0cd182f
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 43 deletions.
1 change: 1 addition & 0 deletions changelog.d/12421.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make `synapse._scripts` pass type checks.
5 changes: 0 additions & 5 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,6 @@ exclude = (?x)
|scripts-dev/federation_client.py
|scripts-dev/release.py

|synapse/_scripts/export_signing_key.py
|synapse/_scripts/move_remote_media_to_new_store.py
|synapse/_scripts/synapse_port_db.py
|synapse/_scripts/update_synapse_database.py

|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
Expand Down
6 changes: 3 additions & 3 deletions synapse/_scripts/export_signing_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import time
from typing import Optional

import nacl.signing
from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys
from signedjson.types import VerifyKey


def exit(status: int = 0, message: Optional[str] = None):
Expand All @@ -27,7 +27,7 @@ def exit(status: int = 0, message: Optional[str] = None):
sys.exit(status)


def format_plain(public_key: nacl.signing.VerifyKey):
def format_plain(public_key: VerifyKey):
print(
"%s:%s %s"
% (
Expand All @@ -38,7 +38,7 @@ def format_plain(public_key: nacl.signing.VerifyKey):
)


def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int):
def format_for_config(public_key: VerifyKey, expiry_ts: int):
print(
' "%s:%s": { key: "%s", expired_ts: %i }'
% (
Expand Down
9 changes: 4 additions & 5 deletions synapse/_scripts/move_remote_media_to_new_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,9 @@ def mkdir_and_move(original_file, dest_file):
parser.add_argument("dest_repo", help="Path to source content repo")
args = parser.parse_args()

logging_config = {
"level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
}
logging.basicConfig(**logging_config)
logging.basicConfig(
level=logging.DEBUG if args.v else logging.INFO,
format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
)

main(args.src_repo, args.dest_repo)
53 changes: 31 additions & 22 deletions synapse/_scripts/synapse_port_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@
import sys
import time
import traceback
from typing import Dict, Iterable, Optional, Set
from types import TracebackType
from typing import Dict, Iterable, Optional, Set, Tuple, Type, cast

import yaml
from matrix_common.versionstring import get_distribution_version_string

from twisted.internet import defer, reactor
from twisted.internet import defer, reactor as reactor_

from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
Expand Down Expand Up @@ -66,8 +67,12 @@
from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
from synapse.types import ISynapseReactor
from synapse.util import Clock

# Cast safety: Twisted does some naughty magic which replaces the
# twisted.internet.reactor module with a Reactor instance at runtime.
reactor = cast(ISynapseReactor, reactor_)
logger = logging.getLogger("synapse_port_db")


Expand Down Expand Up @@ -159,12 +164,14 @@

# Error returned by the run function. Used at the top-level part of the script to
# handle errors and return codes.
end_error = None # type: Optional[str]
end_error: Optional[str] = None
# The exec_info for the error, if any. If error is defined but not exec_info the script
# will show only the error message without the stacktrace, if exec_info is defined but
# not the error then the script will show nothing outside of what's printed in the run
# function. If both are defined, the script will print both the error and the stacktrace.
end_error_exec_info = None
end_error_exec_info: Optional[
Tuple[Type[BaseException], BaseException, TracebackType]
] = None


class Store(
Expand Down Expand Up @@ -236,9 +243,12 @@ def get_instance_name(self):
return "master"


class Porter(object):
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
class Porter:
def __init__(self, sqlite_config, progress, batch_size, hs_config):
self.sqlite_config = sqlite_config
self.progress = progress
self.batch_size = batch_size
self.hs_config = hs_config

async def setup_table(self, table):
if table in APPEND_ONLY_TABLES:
Expand Down Expand Up @@ -323,7 +333,7 @@ def _get_constraints(txn):
"""
txn.execute(sql)

results = {}
results: Dict[str, Set[str]] = {}
for table, foreign_table in txn:
results.setdefault(table, set()).add(foreign_table)
return results
Expand Down Expand Up @@ -540,7 +550,8 @@ def build_db_store(
db_conn, allow_outdated_version=allow_outdated_version
)
prepare_database(db_conn, engine, config=self.hs_config)
store = Store(DatabasePool(hs, db_config, engine), db_conn, hs)
# Type safety: ignore that we're using Mock homeservers here.
store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) # type: ignore[arg-type]
db_conn.commit()

return store
Expand Down Expand Up @@ -724,7 +735,9 @@ def alter_table(txn):
except Exception as e:
global end_error_exec_info
end_error = str(e)
end_error_exec_info = sys.exc_info()
# Type safety: we're in an exception handler, so the exc_info() tuple
# will not be (None, None, None).
end_error_exec_info = sys.exc_info() # type: ignore[assignment]
logger.exception("")
finally:
reactor.stop()
Expand Down Expand Up @@ -1023,7 +1036,7 @@ def __init__(self, stdscr):
curses.init_pair(1, curses.COLOR_RED, -1)
curses.init_pair(2, curses.COLOR_GREEN, -1)

self.last_update = 0
self.last_update = 0.0

self.finished = False

Expand Down Expand Up @@ -1082,8 +1095,7 @@ def render(self, force=False):
left_margin = 5
middle_space = 1

items = self.tables.items()
items = sorted(items, key=lambda i: (i[1]["perc"], i[0]))
items = sorted(self.tables.items(), key=lambda i: (i[1]["perc"], i[0]))

for i, (table, data) in enumerate(items):
if i + 2 >= rows:
Expand Down Expand Up @@ -1179,15 +1191,11 @@ def main():

args = parser.parse_args()

logging_config = {
"level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
}

if args.curses:
logging_config["filename"] = "port-synapse.log"

logging.basicConfig(**logging_config)
logging.basicConfig(
level=logging.DEBUG if args.v else logging.INFO,
format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
filename="port-synapse.log" if args.curses else None,
)

sqlite_config = {
"name": "sqlite3",
Expand Down Expand Up @@ -1218,6 +1226,7 @@ def main():
config.parse_config_dict(hs_config, "", "")

def start(stdscr=None):
progress: Progress
if stdscr:
progress = CursesProgress(stdscr)
else:
Expand Down
19 changes: 11 additions & 8 deletions synapse/_scripts/update_synapse_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,27 @@
import argparse
import logging
import sys
from typing import cast

import yaml
from matrix_common.versionstring import get_distribution_version_string

from twisted.internet import defer, reactor
from twisted.internet import defer, reactor as reactor_

from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.types import ISynapseReactor

# Cast safety: Twisted does some naughty magic which replaces the
# twisted.internet.reactor module with a Reactor instance at runtime.
reactor = cast(ISynapseReactor, reactor_)
logger = logging.getLogger("update_database")


class MockHomeserver(HomeServer):
DATASTORE_CLASS = DataStore
DATASTORE_CLASS = DataStore # type: ignore [assignment]

def __init__(self, config, **kwargs):
super(MockHomeserver, self).__init__(
Expand Down Expand Up @@ -85,12 +90,10 @@ def main():

args = parser.parse_args()

logging_config = {
"level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
}

logging.basicConfig(**logging_config)
logging.basicConfig(
level=logging.DEBUG if args.v else logging.INFO,
format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
)

# Load, process and sanity-check the config.
hs_config = yaml.safe_load(args.database_config)
Expand Down

0 comments on commit 0cd182f

Please sign in to comment.