Skip to content

Commit

Permalink
Merge pull request #98 from jgberry/fix-list-comp-bindings
Browse files Browse the repository at this point in the history
Fix comprehension bindings and requirements when target conflicts with outer scope
  • Loading branch information
jgberry authored Jun 22, 2023
2 parents f986986 + 0701313 commit 3bd130d
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 20 deletions.
8 changes: 8 additions & 0 deletions src/ssort/_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,14 @@ def _get_bindings_for_name(node: ast.Name) -> Iterable[str]:
yield node.id


@get_bindings.register(ast.comprehension)
def _get_bindings_for_comprehension(node: ast.comprehension) -> Iterable[str]:
# Ignore target, it can never produce bindings
yield from get_bindings(node.iter)
for condition in node.ifs:
yield from get_bindings(condition)


@get_bindings.register(ast.ExceptHandler)
def _get_bindings_for_except_handler(node: ast.ExceptHandler) -> Iterable[str]:
if node.type:
Expand Down
10 changes: 8 additions & 2 deletions src/ssort/_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,14 @@ def _get_requirements_for_lambda(node: ast.Lambda) -> Iterable[Requirement]:
@get_requirements.register(ast.SetComp)
@get_requirements.register(ast.DictComp)
@get_requirements.register(ast.GeneratorExp)
def _get_requirements_for_comp(node: ast.AST) -> Iterable[Requirement]:
bindings = set(get_bindings(node))
def _get_requirements_for_comp(
node: ast.ListComp | ast.SetComp | ast.DictComp | ast.GeneratorExp,
) -> Iterable[Requirement]:
bindings = {
binding
for generator in node.generators
for binding in get_bindings(generator.target)
}
for child in iter_child_nodes(node):
for requirement in get_requirements(child):
if requirement.name not in bindings:
Expand Down
36 changes: 18 additions & 18 deletions tests/test_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,22 +1023,22 @@ def test_list_comp_bindings():
ListComp(expr elt, comprehension* generators)
"""
node = _parse("[item for item in iterator if condition(item)]")
assert list(get_bindings(node)) == ["item"]
assert list(get_bindings(node)) == []


def test_list_comp_bindings_walrus_target():
node = _parse("[( a:= item) for item in iterator if condition(item)]")
assert list(get_bindings(node)) == ["a", "item"]
assert list(get_bindings(node)) == ["a"]


def test_list_comp_bindings_walrus_iter():
node = _parse("[item for item in (it := iterator) if condition(item)]")
assert list(get_bindings(node)) == ["item", "it"]
assert list(get_bindings(node)) == ["it"]


def test_list_comp_bindings_walrus_condition():
node = _parse("[item for item in iterator if (c := condition(item))]")
assert list(get_bindings(node)) == ["item", "c"]
assert list(get_bindings(node)) == ["c"]


def test_set_comp_bindings():
Expand All @@ -1049,22 +1049,22 @@ def test_set_comp_bindings():
SetComp(expr elt, comprehension* generators)
"""
node = _parse("{item for item in iterator if condition(item)}")
assert list(get_bindings(node)) == ["item"]
assert list(get_bindings(node)) == []


def test_set_comp_bindings_walrus_target():
node = _parse("{( a:= item) for item in iterator if condition(item)}")
assert list(get_bindings(node)) == ["a", "item"]
assert list(get_bindings(node)) == ["a"]


def test_set_comp_bindings_walrus_iter():
node = _parse("{item for item in (it := iterator) if condition(item)}")
assert list(get_bindings(node)) == ["item", "it"]
assert list(get_bindings(node)) == ["it"]


def test_set_comp_bindings_walrus_condition():
node = _parse("{item for item in iterator if (c := condition(item))}")
assert list(get_bindings(node)) == ["item", "c"]
assert list(get_bindings(node)) == ["c"]


def test_dict_comp_bindings():
Expand All @@ -1074,40 +1074,40 @@ def test_dict_comp_bindings():
DictComp(expr key, expr value, comprehension* generators)
"""
node = _parse("{item[0]: item[1] for item in iterator if check(item)}")
assert list(get_bindings(node)) == ["item"]
assert list(get_bindings(node)) == []


def test_dict_comp_bindings_unpack():
node = _parse("{key: value for key, value in iterator}")
assert list(get_bindings(node)) == ["key", "value"]
assert list(get_bindings(node)) == []


def test_dict_comp_bindings_walrus_key():
node = _parse(
"{(key := item[0]): item[1] for item in iterator if check(item)}"
)
assert list(get_bindings(node)) == ["key", "item"]
assert list(get_bindings(node)) == ["key"]


def test_dict_comp_bindings_walrus_value():
node = _parse(
"{item[0]: (value := item[1]) for item in iterator if check(item)}"
)
assert list(get_bindings(node)) == ["value", "item"]
assert list(get_bindings(node)) == ["value"]


def test_dict_comp_bindings_walrus_iter():
node = _parse(
"{item[0]: item[1] for item in (it := iterator) if check(item)}"
)
assert list(get_bindings(node)) == ["item", "it"]
assert list(get_bindings(node)) == ["it"]


def test_dict_comp_bindings_walrus_condition():
node = _parse(
"{item[0]: item[1] for item in iterator if (c := check(item))}"
)
assert list(get_bindings(node)) == ["item", "c"]
assert list(get_bindings(node)) == ["c"]


def test_generator_exp_bindings():
Expand All @@ -1117,22 +1117,22 @@ def test_generator_exp_bindings():
GeneratorExp(expr elt, comprehension* generators)
"""
node = _parse("(item for item in iterator if condition(item))")
assert list(get_bindings(node)) == ["item"]
assert list(get_bindings(node)) == []


def test_generator_exp_bindings_walrus_target():
node = _parse("(( a:= item) for item in iterator if condition(item))")
assert list(get_bindings(node)) == ["a", "item"]
assert list(get_bindings(node)) == ["a"]


def test_generator_exp_bindings_walrus_iter():
node = _parse("(item for item in (it := iterator) if condition(item))")
assert list(get_bindings(node)) == ["item", "it"]
assert list(get_bindings(node)) == ["it"]


def test_generator_exp_bindings_walrus_condition():
node = _parse("(item for item in iterator if (c := condition(item)))")
assert list(get_bindings(node)) == ["item", "c"]
assert list(get_bindings(node)) == ["c"]


def test_await_bindings():
Expand Down
96 changes: 96 additions & 0 deletions tests/test_ssort.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,3 +557,99 @@ def test_ssort_preserve_crlf_endlines_str():

actual = ssort(original)
assert actual == expected


def test_ssort_list_comp_conflicts_with_global_scope():
original = _clean(
"""
def f():
g()
return [g for g in range(10)]
def g():
pass
"""
)
expected = _clean(
"""
def g():
pass
def f():
g()
return [g for g in range(10)]
"""
)

actual = ssort(original)
assert actual == expected


def test_ssort_set_comp_conflicts_with_global_scope():
original = _clean(
"""
def f():
g()
return {g for g in range(10)}
def g():
pass
"""
)
expected = _clean(
"""
def g():
pass
def f():
g()
return {g for g in range(10)}
"""
)

actual = ssort(original)
assert actual == expected


def test_ssort_dict_comp_conflicts_with_global_scope():
original = _clean(
"""
def f():
g()
return {g: 1 for g in range(10)}
def g():
pass
"""
)
expected = _clean(
"""
def g():
pass
def f():
g()
return {g: 1 for g in range(10)}
"""
)

actual = ssort(original)
assert actual == expected


def test_ssort_generator_exp_conflicts_with_global_scope():
original = _clean(
"""
def f():
g()
return (g for g in range(10))
def g():
pass
"""
)
expected = _clean(
"""
def g():
pass
def f():
g()
return (g for g in range(10))
"""
)

actual = ssort(original)
assert actual == expected

0 comments on commit 3bd130d

Please sign in to comment.