Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simple type hints for aerich/ #341

Merged
merged 2 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions aerich/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from pathlib import Path
from typing import List
from typing import TYPE_CHECKING, List, Optional, Type

from tortoise import Tortoise, generate_schema_for_client
from tortoise.exceptions import OperationalError
Expand All @@ -20,23 +20,26 @@
import_py_file,
)

if TYPE_CHECKING:
from aerich.inspectdb import Inspect # noqa:F401


class Command:
def __init__(
self,
tortoise_config: dict,
app: str = "models",
location: str = "./migrations",
):
) -> None:
self.tortoise_config = tortoise_config
self.app = app
self.location = location
Migrate.app = app

async def init(self):
async def init(self) -> None:
await Migrate.init(self.tortoise_config, self.app, self.location)

async def _upgrade(self, conn, version_file):
async def _upgrade(self, conn, version_file) -> None:
file_path = Path(Migrate.migrate_location, version_file)
m = import_py_file(file_path)
upgrade = getattr(m, "upgrade")
Expand All @@ -47,7 +50,7 @@ async def _upgrade(self, conn, version_file):
content=get_models_describe(self.app),
)

async def upgrade(self, run_in_transaction: bool = True):
async def upgrade(self, run_in_transaction: bool = True) -> List[str]:
migrated = []
for version_file in Migrate.get_all_version_files():
try:
Expand All @@ -65,8 +68,8 @@ async def upgrade(self, run_in_transaction: bool = True):
migrated.append(version_file)
return migrated

async def downgrade(self, version: int, delete: bool):
ret = []
async def downgrade(self, version: int, delete: bool) -> List[str]:
ret: List[str] = []
if version == -1:
specified_version = await Migrate.get_last_version()
else:
Expand All @@ -79,8 +82,8 @@ async def downgrade(self, version: int, delete: bool):
versions = [specified_version]
else:
versions = await Aerich.filter(app=self.app, pk__gte=specified_version.pk)
for version in versions:
file = version.version
for version_obj in versions:
file = version_obj.version
async with in_transaction(
get_app_connection_name(self.tortoise_config, self.app)
) as conn:
Expand All @@ -91,29 +94,29 @@ async def downgrade(self, version: int, delete: bool):
if not downgrade_sql.strip():
raise DowngradeError("No downgrade items found")
await conn.execute_script(downgrade_sql)
await version.delete()
await version_obj.delete()
if delete:
os.unlink(file_path)
ret.append(file)
return ret

async def heads(self):
async def heads(self) -> List[str]:
ret = []
versions = Migrate.get_all_version_files()
for version in versions:
if not await Aerich.exists(version=version, app=self.app):
ret.append(version)
return ret

async def history(self):
async def history(self) -> List[str]:
versions = Migrate.get_all_version_files()
return [version for version in versions]

async def inspectdb(self, tables: List[str] = None) -> str:
async def inspectdb(self, tables: Optional[List[str]] = None) -> str:
connection = get_app_connection(self.tortoise_config, self.app)
dialect = connection.schema_generator.DIALECT
if dialect == "mysql":
cls = InspectMySQL
cls: Type["Inspect"] = InspectMySQL
elif dialect == "postgres":
cls = InspectPostgres
elif dialect == "sqlite":
Expand All @@ -126,7 +129,7 @@ async def inspectdb(self, tables: List[str] = None) -> str:
async def migrate(self, name: str = "update", empty: bool = False) -> str:
return await Migrate.migrate(name, empty)

async def init_db(self, safe: bool):
async def init_db(self, safe: bool) -> None:
location = self.location
app = self.app
dirname = Path(location, app)
Expand Down
36 changes: 19 additions & 17 deletions aerich/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from functools import wraps
from pathlib import Path
from typing import List
from typing import Dict, List, cast

import click
import tomlkit
Expand All @@ -23,7 +23,7 @@

def coro(f):
@wraps(f)
def wrapper(*args, **kwargs):
def wrapper(*args, **kwargs) -> None:
loop = asyncio.get_event_loop()

# Close db connections at the end of all but the cli group function
Expand All @@ -48,7 +48,7 @@ def wrapper(*args, **kwargs):
@click.option("--app", required=False, help="Tortoise-ORM app name.")
@click.pass_context
@coro
async def cli(ctx: Context, config, app):
async def cli(ctx: Context, config, app) -> None:
ctx.ensure_object(dict)
ctx.obj["config_file"] = config

Expand All @@ -58,17 +58,19 @@ async def cli(ctx: Context, config, app):
if not config_path.exists():
raise UsageError("You must exec init first", ctx=ctx)
content = config_path.read_text()
doc = tomlkit.parse(content)
doc: dict = tomlkit.parse(content)
try:
tool = doc["tool"]["aerich"]
tool = cast(Dict[str, str], doc["tool"]["aerich"])
location = tool["location"]
tortoise_orm = tool["tortoise_orm"]
src_folder = tool.get("src_folder", CONFIG_DEFAULT_VALUES["src_folder"])
except NonExistentKey:
raise UsageError("You need run aerich init again when upgrade to 0.6.0+")
add_src_path(src_folder)
tortoise_config = get_tortoise_config(ctx, tortoise_orm)
app = app or list(tortoise_config.get("apps").keys())[0]
if not app:
apps_config = cast(dict, tortoise_config.get("apps"))
app = list(apps_config.keys())[0]
command = Command(tortoise_config=tortoise_config, app=app, location=location)
ctx.obj["command"] = command
if invoked_subcommand != "init-db":
Expand All @@ -82,7 +84,7 @@ async def cli(ctx: Context, config, app):
@click.option("--empty", default=False, is_flag=True, help="Generate empty migration file.")
@click.pass_context
@coro
async def migrate(ctx: Context, name):
async def migrate(ctx: Context, name) -> None:
command = ctx.obj["command"]
ret = await command.migrate(name)
if not ret:
Expand All @@ -100,7 +102,7 @@ async def migrate(ctx: Context, name):
)
@click.pass_context
@coro
async def upgrade(ctx: Context, in_transaction: bool):
async def upgrade(ctx: Context, in_transaction: bool) -> None:
command = ctx.obj["command"]
migrated = await command.upgrade(run_in_transaction=in_transaction)
if not migrated:
Expand Down Expand Up @@ -132,7 +134,7 @@ async def upgrade(ctx: Context, in_transaction: bool):
prompt="Downgrade is dangerous, which maybe lose your data, are you sure?",
)
@coro
async def downgrade(ctx: Context, version: int, delete: bool):
async def downgrade(ctx: Context, version: int, delete: bool) -> None:
command = ctx.obj["command"]
try:
files = await command.downgrade(version, delete)
Expand All @@ -145,7 +147,7 @@ async def downgrade(ctx: Context, version: int, delete: bool):
@cli.command(help="Show current available heads in migrate location.")
@click.pass_context
@coro
async def heads(ctx: Context):
async def heads(ctx: Context) -> None:
command = ctx.obj["command"]
head_list = await command.heads()
if not head_list:
Expand All @@ -157,7 +159,7 @@ async def heads(ctx: Context):
@cli.command(help="List all migrate items.")
@click.pass_context
@coro
async def history(ctx: Context):
async def history(ctx: Context) -> None:
command = ctx.obj["command"]
versions = await command.history()
if not versions:
Expand Down Expand Up @@ -188,7 +190,7 @@ async def history(ctx: Context):
)
@click.pass_context
@coro
async def init(ctx: Context, tortoise_orm, location, src_folder):
async def init(ctx: Context, tortoise_orm, location, src_folder) -> None:
config_file = ctx.obj["config_file"]

if os.path.isabs(src_folder):
Expand All @@ -203,9 +205,9 @@ async def init(ctx: Context, tortoise_orm, location, src_folder):
config_path = Path(config_file)
if config_path.exists():
content = config_path.read_text()
doc = tomlkit.parse(content)
else:
doc = tomlkit.parse("[tool.aerich]")
content = "[tool.aerich]"
doc: dict = tomlkit.parse(content)
table = tomlkit.table()
table["tortoise_orm"] = tortoise_orm
table["location"] = location
Expand All @@ -232,7 +234,7 @@ async def init(ctx: Context, tortoise_orm, location, src_folder):
)
@click.pass_context
@coro
async def init_db(ctx: Context, safe: bool):
async def init_db(ctx: Context, safe: bool) -> None:
command = ctx.obj["command"]
app = command.app
dirname = Path(command.location, app)
Expand All @@ -256,13 +258,13 @@ async def init_db(ctx: Context, safe: bool):
)
@click.pass_context
@coro
async def inspectdb(ctx: Context, table: List[str]):
async def inspectdb(ctx: Context, table: List[str]) -> None:
command = ctx.obj["command"]
ret = await command.inspectdb(table)
click.secho(ret)


def main():
def main() -> None:
cli()


Expand Down
9 changes: 5 additions & 4 deletions aerich/coder.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import base64
import json
import pickle # nosec: B301,B403
from typing import Any, Union

from tortoise.indexes import Index


class JsonEncoder(json.JSONEncoder):
def default(self, obj):
def default(self, obj) -> Any:
if isinstance(obj, Index):
return {
"type": "index",
Expand All @@ -16,16 +17,16 @@ def default(self, obj):
return super().default(obj)


def object_hook(obj):
def object_hook(obj) -> Any:
_type = obj.get("type")
if not _type:
return obj
return pickle.loads(base64.b64decode(obj["val"])) # nosec: B301


def encoder(obj: dict):
def encoder(obj: dict) -> str:
return json.dumps(obj, cls=JsonEncoder)


def decoder(obj: str):
def decoder(obj: Union[str, bytes]) -> Any:
return json.loads(obj, object_hook=object_hook)
13 changes: 7 additions & 6 deletions aerich/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import re
import sys
from pathlib import Path
from typing import Dict
from types import ModuleType
from typing import Dict, Optional

from click import BadOptionUsage, ClickException, Context
from tortoise import BaseDBAsyncClient, Tortoise
Expand Down Expand Up @@ -84,19 +85,19 @@ def get_models_describe(app: str) -> Dict:
:return:
"""
ret = {}
for model in Tortoise.apps.get(app).values():
for model in Tortoise.apps[app].values():
describe = model.describe()
ret[describe.get("name")] = describe
return ret


def is_default_function(string: str):
def is_default_function(string: str) -> Optional[re.Match]:
return re.match(r"^<function.+>$", str(string or ""))


def import_py_file(file: Path):
def import_py_file(file: Path) -> ModuleType:
module_name, file_ext = os.path.splitext(os.path.split(file)[-1])
spec = importlib.util.spec_from_file_location(module_name, file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
module = importlib.util.module_from_spec(spec) # type:ignore[arg-type]
spec.loader.exec_module(module) # type:ignore[union-attr]
return module
Loading