Skip to content

Commit

Permalink
fix: mypy issues when following imports (#328)
Browse files Browse the repository at this point in the history
Fix a few issues issues with star imports along the way.
  • Loading branch information
browniebroke authored Feb 2, 2021
1 parent 67f09d7 commit 4fb45f8
Show file tree
Hide file tree
Showing 14 changed files with 205 additions and 109 deletions.
4 changes: 1 addition & 3 deletions django_codemod/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,7 @@ def call_command(command_instance: BaseCodemodCommand, files: List[Path]):
# Super simplified call
result = parallel_exec_transform_with_prettyprint(
command_instance,
files,
# Number of jobs to use when processing files. Defaults to number of cores
jobs=None,
files, # type: ignore
)
except KeyboardInterrupt:
raise click.Abort("Interrupted!")
Expand Down
3 changes: 3 additions & 0 deletions django_codemod/visitors/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
ClassDef,
FunctionDef,
ImportFrom,
ImportStar,
Module,
Name,
Param,
Expand All @@ -32,6 +33,8 @@ class InlineHasAddPermissionsTransformer(BaseDjCodemodTransformer):
def leave_ImportFrom(
self, original_node: ImportFrom, updated_node: ImportFrom
) -> Union[BaseSmallStatement, RemovalSentinel]:
if isinstance(updated_node.names, ImportStar):
return super().leave_ImportFrom(original_node, updated_node)
base_cls_matcher = []
if m.matches(
updated_node,
Expand Down
12 changes: 6 additions & 6 deletions django_codemod/visitors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class BaseDjCodemodTransformer(ContextAwareTransformer, ABC):

def module_matcher(
import_parts: Sequence[str],
) -> Union[m.BaseMatcherNode, m.DoNotCare]:
) -> Union[m.Attribute, m.Name]:
"""Build matcher for a module given sequence of import parts."""
# If only one element, it is just a Name
if len(import_parts) == 1:
Expand Down Expand Up @@ -87,8 +87,8 @@ def leave_ImportFrom(
return updated_node.with_changes(names=cleaned_names)

def gen_new_imported_names(
self, old_names: Union[Sequence[ImportAlias], ImportStar]
) -> Generator:
self, old_names: Sequence[ImportAlias]
) -> Generator[ImportAlias, None, None]:
"""Update import if the entity we're interested in is imported."""
for import_alias in old_names:
if not self.old_name or import_alias.evaluated_name == self.old_name:
Expand All @@ -100,12 +100,12 @@ def gen_new_imported_names(
yield import_alias

def resolve_parent_node(self, node: CSTNode) -> CSTNode:
parent_nodes = self.context.wrapper.resolve(ParentNodeProvider)
parent_nodes = self.context.wrapper.resolve(ParentNodeProvider) # type: ignore
return parent_nodes[node]

def resolve_scope(self, node: CSTNode) -> Scope:
scopes_map = self.context.wrapper.resolve(ScopeProvider)
return scopes_map[node]
scopes_map = self.context.wrapper.resolve(ScopeProvider) # type: ignore
return scopes_map[node] # type: ignore

def save_import_scope(self, import_from: ImportFrom) -> None:
scope = self.resolve_scope(import_from)
Expand Down
6 changes: 5 additions & 1 deletion django_codemod/visitors/http.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from libcst import BaseExpression, Call
from libcst import Attribute, BaseExpression, Call
from libcst import matchers as m

from django_codemod.constants import DJANGO_2_0, DJANGO_2_1, DJANGO_3_0, DJANGO_4_0
Expand Down Expand Up @@ -75,6 +75,10 @@ class HttpRequestXReadLinesTransformer(BaseDjCodemodTransformer):
)

def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression:
if not isinstance(updated_node.func, Attribute):
# A bit redundant with matcher below,
# but this is to make type checker happy
return super().leave_Call(original_node, updated_node)
if m.matches(updated_node, self.matcher):
return updated_node.func.value
return super().leave_Call(original_node, updated_node)
Expand Down
18 changes: 10 additions & 8 deletions django_codemod/visitors/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Call,
FunctionDef,
ImportFrom,
ImportStar,
MaybeSentinel,
Name,
RemovalSentinel,
Expand All @@ -35,6 +36,8 @@ class ModelsPermalinkTransformer(BaseDjCodemodTransformer):
def leave_ImportFrom(
self, original_node: ImportFrom, updated_node: ImportFrom
) -> Union[BaseSmallStatement, RemovalSentinel]:
if isinstance(updated_node.names, ImportStar):
return super().leave_ImportFrom(original_node, updated_node)
if m.matches(
updated_node,
m.ImportFrom(module=module_matcher(["django", "db"])),
Expand All @@ -55,13 +58,11 @@ def leave_ImportFrom(
updated_names = []
for imported_name in updated_node.names:
if m.matches(imported_name, m.ImportAlias(name=m.Name("permalink"))):
decorator_name = (
imported_name.asname.name.value
if imported_name.asname
else "permalink"
decorator_name_str = (
imported_name.evaluated_alias or imported_name.evaluated_name
)
self.add_decorator_matcher(
m.Decorator(decorator=m.Name(decorator_name))
m.Decorator(decorator=m.Name(decorator_name_str))
)
else:
updated_names.append(imported_name)
Expand Down Expand Up @@ -122,9 +123,10 @@ def visiting_permalink_method(self):
def leave_Return(
self, original_node: Return, updated_node: Return
) -> Union[BaseSmallStatement, RemovalSentinel]:
if self.visiting_permalink_method and m.matches(updated_node.value, m.Tuple()):
elem_0 = updated_node.value.elements[0]
elem_1_3 = updated_node.value.elements[1:3]
if self.visiting_permalink_method and m.matches(
updated_node.value, m.Tuple() # type: ignore
):
elem_0, *elem_1_3 = updated_node.value.elements[:3] # type: ignore
args = (
Arg(elem_0.value),
Arg(Name("None")),
Expand Down
43 changes: 22 additions & 21 deletions django_codemod/visitors/signals.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Optional

from libcst import BaseExpression, Call, ImportFrom, MaybeSentinel, Module
from libcst import BaseExpression, Call, ImportFrom, ImportStar, MaybeSentinel, Module
from libcst import matchers as m

from django_codemod.constants import DJANGO_1_9, DJANGO_2_0
Expand Down Expand Up @@ -46,26 +46,27 @@ def leave_Module(self, original_node: Module, updated_node: Module) -> Module:

def visit_ImportFrom(self, node: ImportFrom) -> Optional[bool]:
"""Set the `Call` matcher depending on which signals are imported.."""
if import_from_matches(node, ["django", "db", "models", "signals"]):
for import_alias in node.names:
if m.matches(import_alias, self.import_alias_matcher):
# We're visiting an import statement for a built-in signal
# Get the actual name it's imported as (in case of import alias)
imported_name = (
import_alias.asname
and import_alias.asname.name
or import_alias.name
if not import_from_matches(
node, ["django", "db", "models", "signals"]
) or isinstance(node.names, ImportStar):
return False
for import_alias in node.names:
if m.matches(import_alias, self.import_alias_matcher):
# We're visiting an import statement for a built-in signal
# Get the actual name it's imported as (in case of import alias)
imported_name_str = (
import_alias.evaluated_alias or import_alias.evaluated_name
)
# Add the call matcher for the current signal to the list
self.add_disconnect_call_matcher(
m.Call(
func=m.Attribute(
value=m.Name(imported_name_str),
attr=m.Name("disconnect"),
),
)
# Add the call matcher for the current signal to the list
self.add_disconnect_call_matcher(
m.Call(
func=m.Attribute(
value=m.Name(imported_name.value),
attr=m.Name("disconnect"),
),
)
)
return super().visit_ImportFrom(node)
)
return None

def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression:
"""
Expand All @@ -87,7 +88,7 @@ def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression:
should_change = True
else:
updated_args.append(arg)
last_comma = arg.comma
last_comma = arg.comma # type: ignore
if should_change:
# Make sure the end of line is formatted as initially
updated_args[-1] = updated_args[-1].with_changes(comma=last_comma)
Expand Down
81 changes: 46 additions & 35 deletions django_codemod/visitors/template_tags.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
from typing import Optional, Union

from libcst import Assign, Decorator, ImportFrom, Module, Name, RemovalSentinel
from libcst import (
Assign,
Decorator,
ImportFrom,
ImportStar,
Module,
Name,
RemovalSentinel,
)
from libcst import matchers as m

from django_codemod.constants import DJANGO_1_9, DJANGO_2_0
Expand Down Expand Up @@ -33,46 +41,48 @@ def leave_Module(self, original_node: Module, updated_node: Module) -> Module:

def visit_ImportFrom(self, node: ImportFrom) -> Optional[bool]:
"""Record whether an interesting import is detected."""
return self._check_template_imported(node) or self._check_libary_imported(node)
return self._check_template_imported(node) or self._check_library_imported(node)

def _check_template_imported(self, node: ImportFrom) -> bool:
"""Record matcher if django.template is imported."""
if import_from_matches(node, ["django"]):
for import_alias in node.names:
if m.matches(import_alias, m.ImportAlias(name=m.Name("template"))):
# We're visiting the `from django import template` statement
# Get the actual name it's imported as (in case of import alias)
imported_name = (
import_alias.asname
and import_alias.asname.name
or import_alias.name
)
# Build the `Call` matcher to look out for, eg `template.Library()`
self.context.scratch[self.ctx_key_library_call_matcher] = m.Call(
func=m.Attribute(
attr=m.Name("Library"), value=m.Name(imported_name.value)
)
if not import_from_matches(node, ["django"]) or isinstance(
node.names, ImportStar
):
return False
for import_alias in node.names:
if m.matches(import_alias, m.ImportAlias(name=m.Name("template"))):
# We're visiting the `from django import template` statement
# Get the actual name it's imported as (in case of import alias)
imported_name_str = (
import_alias.evaluated_alias or import_alias.evaluated_name
)
# Build the `Call` matcher to look out for, eg `template.Library()`
self.context.scratch[self.ctx_key_library_call_matcher] = m.Call(
func=m.Attribute(
attr=m.Name("Library"), value=m.Name(value=imported_name_str)
)
return True
)
return True
return False

def _check_libary_imported(self, node: ImportFrom) -> bool:
def _check_library_imported(self, node: ImportFrom) -> bool:
"""Record matcher if django.template.Library is imported."""
if import_from_matches(node, ["django", "template"]):
for import_alias in node.names:
if m.matches(import_alias, m.ImportAlias(name=m.Name("Library"))):
# We're visiting the `from django.template import Library` statement
# Get the actual name it's imported as (in case of import alias)
imported_name = (
import_alias.asname
and import_alias.asname.name
or import_alias.name
)
# Build the `Call` matcher to look out for, eg `Library()`
self.context.scratch[self.ctx_key_library_call_matcher] = m.Call(
func=m.Name(imported_name.value)
)
return True
if not import_from_matches(node, ["django", "template"]) or isinstance(
node.names, ImportStar
):
return False
for import_alias in node.names:
if m.matches(import_alias, m.ImportAlias(name=m.Name("Library"))):
# We're visiting the `from django.template import Library` statement
# Get the actual name it's imported as (in case of import alias)
imported_name_str = (
import_alias.evaluated_alias or import_alias.evaluated_name
)
# Build the `Call` matcher to look out for, eg `Library()`
self.context.scratch[self.ctx_key_library_call_matcher] = m.Call(
func=m.Name(imported_name_str)
)
return True
return False

def visit_Assign(self, node: Assign) -> Optional[bool]:
Expand All @@ -84,7 +94,8 @@ def visit_Assign(self, node: Assign) -> Optional[bool]:
# Visiting a `register = template.Library()` statement
# Get all names on the left side of the assignment
target_names = (
assign_target.target.value for assign_target in node.targets
assign_target.target.value # type: ignore
for assign_target in node.targets
)
# Build the decorator matchers to look out for
target_matchers = (
Expand Down
12 changes: 4 additions & 8 deletions django_codemod/visitors/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@ def update_call(self, updated_node: Call) -> BaseExpression:
self.add_new_import()
return super().update_call(updated_node)

def update_call_to_path(self, updated_node: Call):
def update_call_to_path(self, updated_node: Call) -> Call:
"""Update an URL pattern to `path()` in simple cases."""
first_arg, *other_args = updated_node.args
self.check_not_simple_string(first_arg)
if not isinstance(first_arg.value, SimpleString):
raise PatternNotSupported()
# Extract the URL pattern from the first argument
pattern = first_arg.value.evaluated_value
# If we reach this point, we might be able to use `path()`
Expand All @@ -61,12 +62,7 @@ def update_call_to_path(self, updated_node: Call):
)
return call

def check_not_simple_string(self, first_arg: Arg):
"""Translated patterns are not supported."""
if not m.matches(first_arg, m.Arg(value=m.SimpleString())):
raise PatternNotSupported()

def build_path_call(self, pattern, other_args):
def build_path_call(self, pattern: str, other_args: Sequence[Arg]) -> Call:
"""Build the `Call` node using Django 2.0's `path()` function."""
route = self.build_route(pattern)
updated_args = (Arg(value=SimpleString(f"'{route}'")), *other_args)
Expand Down
3 changes: 0 additions & 3 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
# Config taken from flake8-mypy for sane defaults
[mypy]
# do not follow imports (except for ones found in typeshed)
follow_imports=skip

# since we're ignoring imports, writing .mypy_cache doesn't make any sense
cache_dir=/dev/null

Expand Down
Loading

0 comments on commit 4fb45f8

Please sign in to comment.