From e38c01351cd478b82f703041b7f9f7b7a8fa38c7 Mon Sep 17 00:00:00 2001 From: Eric Nielsen <4120606+ericbn@users.noreply.github.com> Date: Tue, 9 Apr 2024 21:39:45 -0500 Subject: [PATCH 1/4] Support traversal of TypeAlias for Python 3.12 Also update tox and GitHub Actions to use Python 3.12 by default. --- .github/workflows/ci.yaml | 38 +++++++++++++++++----------------- .github/workflows/release.yaml | 2 +- src/ssort/_ast.py | 28 +++++++++++++++++++++++++ tests/test_requirements.py | 23 ++++++++++++++++++++ tests/test_ssort.py | 34 ++++++++++++++++++++++++++++++ tox.ini | 20 +++++++++--------- 6 files changed, 115 insertions(+), 30 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 09d105e..207f33c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -10,7 +10,7 @@ jobs: name: "Unit Tests" strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] runs-on: ["ubuntu-22.04", "windows-2019", "macos-11"] runs-on: ${{ matrix.runs-on }} steps: @@ -23,7 +23,7 @@ jobs: run: | python -m pip install --upgrade pip pip install pytest - pip install pyyaml==6.0 + pip install pyyaml==6.0.1 pip install -e .[test] - name: Run tests run: | @@ -34,15 +34,15 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v2 - - name: Set up Python 3.11 + - name: Set up Python 3.12 uses: actions/setup-python@v2 with: - python-version: "3.11" + python-version: "3.12" - name: Install dependencies run: | python -m pip install --upgrade pip pip install pytest pytest-cov coveralls - pip install pyyaml==6.0 + pip install pyyaml==6.0.1 pip install -e .[test] - name: Run tests run: | @@ -58,10 +58,10 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v2 - - name: Set up Python 3.11 + - name: Set up Python 3.12 uses: actions/setup-python@v2 with: - python-version: "3.11" + python-version: "3.12" - name: Install dependencies run: | python -m pip install --upgrade pip @@ -75,10 +75,10 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v2 - - name: Set up Python 3.11 + - name: Set up Python 3.12 uses: actions/setup-python@v2 with: - python-version: "3.11" + python-version: "3.12" - name: Install dependencies run: | python -m pip install --upgrade pip @@ -92,10 +92,10 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v2 - - name: Set up Python 3.11 + - name: Set up Python 3.12 uses: actions/setup-python@v2 with: - python-version: "3.11" + python-version: "3.12" - name: Install dependencies run: | python -m pip install --upgrade pip @@ -109,10 +109,10 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v2 - - name: Set up Python 3.11 + - name: Set up Python 3.12 uses: actions/setup-python@v2 with: - python-version: "3.11" + python-version: "3.12" - name: Install dependencies run: | python -m pip install --upgrade pip @@ -126,16 +126,16 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v2 - - name: Set up Python 3.11 + - name: Set up Python 3.12 uses: actions/setup-python@v2 with: - python-version: "3.11" + python-version: "3.12" - name: Install dependencies run: | python -m pip install --upgrade pip pip install -e .[test] pip install pytest - pip install pyyaml==6.0 + pip install pyyaml==6.0.1 pip install pylint - name: Run pylint run: | @@ -146,16 +146,16 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v2 - - name: Set up Python 3.11 + - name: Set up Python 3.12 uses: actions/setup-python@v2 with: - python-version: "3.11" + python-version: "3.12" - name: Install dependencies run: | python -m pip install --upgrade pip pip install mypy pip install pytest - pip install pyyaml==6.0 + pip install pyyaml==6.0.1 pip install types-PyYAML pip install types-setuptools - name: Run mypy diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index b03ac75..ce9264c 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -15,7 +15,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v3 with: - python-version: 3.11 + python-version: 3.12 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/src/ssort/_ast.py b/src/ssort/_ast.py index 4987045..7047c08 100644 --- a/src/ssort/_ast.py +++ b/src/ssort/_ast.py @@ -496,3 +496,31 @@ def _iter_child_nodes_of_type_ignore( node: ast.TypeIgnore, ) -> Iterable[ast.AST]: return () + + +if sys.version_info >= (3, 12): + + @iter_child_nodes.register(ast.TypeAlias) + def _iter_child_nodes_of_type_alias( + node: ast.TypeAlias, + ) -> Iterable[ast.AST]: + yield node.name + yield from node.type_params + yield node.value + + @iter_child_nodes.register(ast.TypeVar) + def _iter_child_nodes_of_type_var(node: ast.TypeVar) -> Iterable[ast.AST]: + if node.bound is not None: + yield node.bound + + @iter_child_nodes.register(ast.ParamSpec) + def _iter_child_nodes_of_param_spec( + node: ast.ParamSpec, + ) -> Iterable[ast.AST]: + return () + + @iter_child_nodes.register(ast.TypeVarTuple) + def _iter_child_nodes_of_type_var_tuple( + node: ast.TypeVarTuple, + ) -> Iterable[ast.AST]: + return () diff --git a/tests/test_requirements.py b/tests/test_requirements.py index 29bb044..9c949e5 100644 --- a/tests/test_requirements.py +++ b/tests/test_requirements.py @@ -16,6 +16,11 @@ reason="exception groups were introduced in python 3.11", ) +type_parameter_syntax = pytest.mark.skipif( + sys.version_info < (3, 12), + reason="type parameter syntax was introduced in python 3.12", +) + def _parse(source): source = textwrap.dedent(source) @@ -1195,3 +1200,21 @@ def test_try_star_requirements(): "otherwise", "finish", ] + + +@type_parameter_syntax +def test_type_var_requirements(): + node = _parse("type Alias[T: (str, bytes)] = list[T]") + assert _dep_names(node) == ["str", "bytes", "list", "T"] + + +@type_parameter_syntax +def test_param_spec_requirements(): + node = _parse("type Alias[**P] = Callable[P, int]") + assert _dep_names(node) == ["Callable", "P", "int"] + + +@type_parameter_syntax +def test_type_var_tuple_requirements(): + node = _parse("type Alias[*Ts] = tuple[*Ts]") + assert _dep_names(node) == ["tuple", "Ts"] diff --git a/tests/test_ssort.py b/tests/test_ssort.py index 87da854..2447b5f 100644 --- a/tests/test_ssort.py +++ b/tests/test_ssort.py @@ -1,7 +1,15 @@ +import sys import textwrap +import pytest + from ssort import ssort +type_parameter_syntax = pytest.mark.skipif( + sys.version_info < (3, 12), + reason="type parameter syntax was introduced in python 3.12", +) + def _clean(text): return textwrap.dedent(text).strip() + "\n" @@ -694,3 +702,29 @@ def test_single_line_dummy_class(): actual = ssort(original) assert actual == expected + + +@type_parameter_syntax +def test_ssort_type_alias(): + original = _clean( + """ + from decimal import Decimal + + def roundint(n: N) -> int: + return int(round(n)) + + type N = Decimal | float + """ + ) + expected = _clean( + """ + from decimal import Decimal + + type N = Decimal | float + + def roundint(n: N) -> int: + return int(round(n)) + """ + ) + actual = ssort(original) + assert actual == expected diff --git a/tox.ini b/tox.ini index 609a4a3..4daf9a1 100644 --- a/tox.ini +++ b/tox.ini @@ -1,16 +1,16 @@ [tox] -envlist = py38,py39,py310,py311,black,isort,ssort,pyflakes,pylint,mypy +envlist = py38,py39,py310,py311,py312,black,isort,ssort,pyflakes,pylint,mypy isolated_build = true [testenv] deps = pytest - pyyaml==6.0 + pyyaml==6.0.1 commands = pytest -vv tests/ [testenv:black] -basepython = py311 +basepython = py312 deps = black skip_install = True @@ -18,7 +18,7 @@ commands = black --check --diff . [testenv:isort] -basepython = py311 +basepython = py312 deps = isort skip_install = True @@ -26,12 +26,12 @@ commands = isort --check-only --diff . [testenv:ssort] -basepython = py311 +basepython = py312 commands = ssort --check --diff src/ tests/ [testenv:pyflakes] -basepython = py311 +basepython = py312 deps = pyflakes skip_install = True @@ -39,10 +39,10 @@ commands = pyflakes src/ tests/ [testenv:pylint] -basepython = py311 +basepython = py312 deps = pytest - pyyaml==6.0 + pyyaml==6.0.1 pylint extras= test @@ -50,11 +50,11 @@ commands = pylint -E src/ tests/ [testenv:mypy] -basepython = py311 +basepython = py312 deps = mypy pytest - pyyaml==6.0 + pyyaml==6.0.1 types-PyYAML types-setuptools skip_install = True From 6e765f45d2691a6a23b6d217ac5f06871a4b3f8c Mon Sep 17 00:00:00 2001 From: Eric Nielsen <4120606+ericbn@users.noreply.github.com> Date: Fri, 12 Apr 2024 21:21:39 -0500 Subject: [PATCH 2/4] Add requirements processing for type parameter syntax --- src/ssort/_ast.py | 4 +++ src/ssort/_parsing.py | 10 ++++++ src/ssort/_requirements.py | 47 ++++++++++++++++++++++++++--- tests/test_bindings.py | 52 ++++++++++++++++++++++++++++++++ tests/test_requirements.py | 62 +++++++++++++++++++++++++++++++++++--- tests/test_ssort.py | 48 +++++++++++++++++++++++++++++ 6 files changed, 214 insertions(+), 9 deletions(-) diff --git a/src/ssort/_ast.py b/src/ssort/_ast.py index 7047c08..2cab1a6 100644 --- a/src/ssort/_ast.py +++ b/src/ssort/_ast.py @@ -50,6 +50,8 @@ def _iter_child_nodes_of_function_def( if node.returns is not None: yield node.returns yield from node.body + if sys.version_info >= (3, 12): + yield from node.type_params @iter_child_nodes.register(ast.ClassDef) @@ -58,6 +60,8 @@ def _iter_child_nodes_of_class_def(node: ast.ClassDef) -> Iterable[ast.AST]: yield from node.bases yield from node.keywords yield from node.body + if sys.version_info >= (3, 12): + yield from node.type_params @iter_child_nodes.register(ast.Return) diff --git a/src/ssort/_parsing.py b/src/ssort/_parsing.py index 960e6cd..6798140 100644 --- a/src/ssort/_parsing.py +++ b/src/ssort/_parsing.py @@ -128,6 +128,16 @@ def split_class(statement): assert token.type == NAME token = next(tokens) + if token.string == "[": + token = next(tokens) + depth = 1 + while depth: + if token.string == "[": + depth += 1 + if token.string == "]": + depth -= 1 + token = next(tokens) + if token.string == "(": token = next(tokens) depth = 1 diff --git a/src/ssort/_requirements.py b/src/ssort/_requirements.py index 4d960cf..a1792da 100644 --- a/src/ssort/_requirements.py +++ b/src/ssort/_requirements.py @@ -3,6 +3,7 @@ import ast import dataclasses import enum +import sys from typing import Iterable from ssort._ast import iter_child_nodes @@ -52,12 +53,22 @@ def _get_requirements_for_function_def( for decorator in node.decorator_list: yield from get_requirements(decorator) - yield from get_requirements(node.args) + scope: set[str] = set() + if sys.version_info >= (3, 12): + for type_param in node.type_params: + yield from get_requirements(type_param) + scope.update(type_param.name for type_param in node.type_params) # type: ignore[attr-defined] + + for requirement in get_requirements(node.args): + if requirement.name not in scope: + yield requirement if node.returns is not None: - yield from get_requirements(node.returns) + for requirement in get_requirements(node.returns): + if requirement.name not in scope: + yield requirement - scope = _get_scope_from_arguments(node.args) + scope.update(_get_scope_from_arguments(node.args)) requirements = [] for statement in node.body: @@ -83,10 +94,18 @@ def _get_requirements_for_class_def( for decorator in node.decorator_list: yield from get_requirements(decorator) + scope: set[str] = set() + if sys.version_info >= (3, 12): + for type_param in node.type_params: + yield from get_requirements(type_param) + scope.update(type_param.name for type_param in node.type_params) # type: ignore[attr-defined] + for base in node.bases: - yield from get_requirements(base) + for requirement in get_requirements(base): + if requirement.name not in scope: + yield requirement - scope = set(CLASS_BUILTINS) + scope.update(CLASS_BUILTINS) for statement in node.body: for stmt_dep in get_requirements(statement): @@ -191,3 +210,21 @@ def _get_requirements_for_name(node: ast.Name) -> Iterable[Requirement]: yield Requirement( name=node.id, lineno=node.lineno, col_offset=node.col_offset ) + + +if sys.version_info >= (3, 12): + + @get_requirements.register(ast.TypeAlias) + def _get_requirements_for_type_alias( + node: ast.TypeAlias, + ) -> Iterable[Requirement]: + for type_param in node.type_params: + yield from get_requirements(type_param) + scope: set[str] = set() + scope.update(type_param.name for type_param in node.type_params) # type: ignore[attr-defined] + scope.add(node.name.id) + for requirement in get_requirements(node.value): + if not requirement.deferred: + requirement = dataclasses.replace(requirement, deferred=True) + if requirement.name not in scope: + yield requirement diff --git a/tests/test_bindings.py b/tests/test_bindings.py index 2bfc3d5..2bffd69 100644 --- a/tests/test_bindings.py +++ b/tests/test_bindings.py @@ -24,6 +24,11 @@ reason="exception groups were introduced in python 3.11", ) +type_parameter_syntax = pytest.mark.skipif( + sys.version_info < (3, 12), + reason="type parameter syntax was introduced in python 3.12", +) + def _parse(source): source = textwrap.dedent(source) @@ -1632,3 +1637,50 @@ def test_try_star_binding_walrus_exeption_type(): """ ) assert list(get_bindings(node)) == ["a", "grp_type", "grp", "b", "c", "d"] + + +@type_parameter_syntax +def test_type_var_bindings(): + node = _parse("type RecursiveList[T] = T | list[RecursiveList[T]]") + assert list(get_bindings(node)) == ["RecursiveList"] + + +@type_parameter_syntax +def test_param_spec_bindings(): + node = _parse("type Alias[**P] = Callable[P, int]") + assert list(get_bindings(node)) == ["Alias"] + + +@type_parameter_syntax +def test_type_var_tuple_bindings(): + node = _parse("type Alias[*Ts] = tuple[*Ts]") + assert list(get_bindings(node)) == ["Alias"] + + +@type_parameter_syntax +def test_generic_type_alias_bindings(): + node = _parse("type ListOrSet[T] = list[T] | set[T]") + assert list(get_bindings(node)) == ["ListOrSet"] + + +@type_parameter_syntax +def test_generic_function_bindings(): + node = _parse( + """ + def func[T](a: T, b: T) -> T: + pass + """ + ) + assert list(get_bindings(node)) == ["func"] + + +@type_parameter_syntax +def test_generic_class_bindings(): + node = _parse( + """ + class ClassA[AnyStr: (str, bytes)]: + def method1(self) -> AnyStr: + pass + """ + ) + assert list(get_bindings(node)) == ["ClassA"] diff --git a/tests/test_requirements.py b/tests/test_requirements.py index 9c949e5..af3f0d1 100644 --- a/tests/test_requirements.py +++ b/tests/test_requirements.py @@ -1204,17 +1204,71 @@ def test_try_star_requirements(): @type_parameter_syntax def test_type_var_requirements(): - node = _parse("type Alias[T: (str, bytes)] = list[T]") - assert _dep_names(node) == ["str", "bytes", "list", "T"] + node = _parse('type Alias[T: ("ForwardReference", bytes)] = list[T]') + assert _dep_names(node) == ["bytes", "list"] + + +@type_parameter_syntax +def test_type_var_requirements_recursive(): + node = _parse("type RecursiveList[T] = T | list[RecursiveList[T]]") + assert _dep_names(node) == ["list"] @type_parameter_syntax def test_param_spec_requirements(): node = _parse("type Alias[**P] = Callable[P, int]") - assert _dep_names(node) == ["Callable", "P", "int"] + assert _dep_names(node) == ["Callable", "int"] @type_parameter_syntax def test_type_var_tuple_requirements(): node = _parse("type Alias[*Ts] = tuple[*Ts]") - assert _dep_names(node) == ["tuple", "Ts"] + assert _dep_names(node) == ["tuple"] + + +@type_parameter_syntax +def test_generic_type_alias_requirements(): + node = _parse("type ListOrSet[T] = list[T] | set[T]") + assert _dep_names(node) == ["list", "set"] + + +@type_parameter_syntax +def test_generic_function_requirements(): + node = _parse( + """ + def func[T1, T2](a: T1, b: T2) -> tuple[T1, T2]: + return a, b + """ + ) + assert _dep_names(node) == ["tuple"] + + +@type_parameter_syntax +def test_generic_class_requirements(): + node = _parse( + """ + class ClassA[T]: + attr1: T + def method1(self) -> T: + return self.attr1 + """ + ) + assert _dep_names(node) == [] + + +@type_parameter_syntax +def test_generic_inner_class_requirements(): + node = _parse( + """ + class Outer: + class Private: + pass + + class Inner[T](Private, Sequence[T]): + pass + + def method1[T](self, a: Inner[T]) -> Inner[T]: + return a + """ + ) + assert _dep_names(node) == ["Sequence"] diff --git a/tests/test_ssort.py b/tests/test_ssort.py index 2447b5f..c1adaee 100644 --- a/tests/test_ssort.py +++ b/tests/test_ssort.py @@ -710,6 +710,8 @@ def test_ssort_type_alias(): """ from decimal import Decimal + roundint(3.14) + def roundint(n: N) -> int: return int(round(n)) @@ -724,6 +726,52 @@ def roundint(n: N) -> int: def roundint(n: N) -> int: return int(round(n)) + + roundint(3.14) + """ + ) + actual = ssort(original) + assert actual == expected + + +@type_parameter_syntax +def test_ssort_generic_function(): + original = _clean( + """ + func(4) + def func[T](a: T) -> T: + return a + """ + ) + expected = _clean( + """ + def func[T](a: T) -> T: + return a + func(4) + """ + ) + actual = ssort(original) + assert actual == expected + + +@type_parameter_syntax +def test_ssort_generic_class(): + original = _clean( + """ + obj = ClassA[str]() + class ClassA[T: (str, bytes)](BaseClass[T]): + attr1: T + class BaseClass[T]: + pass + """ + ) + expected = _clean( + """ + class BaseClass[T]: + pass + class ClassA[T: (str, bytes)](BaseClass[T]): + attr1: T + obj = ClassA[str]() """ ) actual = ssort(original) From ac899c9a8d3837c631cb06ec69ff9951f15cbcdf Mon Sep 17 00:00:00 2001 From: Eric Nielsen <4120606+ericbn@users.noreply.github.com> Date: Sat, 13 Apr 2024 21:18:29 -0500 Subject: [PATCH 3/4] Fix type parameter scopes Also extract _get_requirements_from_nodes, _get_scope_from_type_params and _optional_nested_brackets functions. --- src/ssort/_parsing.py | 30 +++++------- src/ssort/_requirements.py | 98 +++++++++++++++++++++----------------- tests/test_bindings.py | 2 +- tests/test_requirements.py | 30 +++++++++++- 4 files changed, 96 insertions(+), 64 deletions(-) diff --git a/src/ssort/_parsing.py b/src/ssort/_parsing.py index 6798140..7afc3e9 100644 --- a/src/ssort/_parsing.py +++ b/src/ssort/_parsing.py @@ -128,25 +128,21 @@ def split_class(statement): assert token.type == NAME token = next(tokens) - if token.string == "[": - token = next(tokens) - depth = 1 - while depth: - if token.string == "[": - depth += 1 - if token.string == "]": - depth -= 1 - token = next(tokens) - if token.string == "(": - token = next(tokens) - depth = 1 - while depth: - if token.string == "(": - depth += 1 - if token.string == ")": - depth -= 1 + def _optional_nested_brackets(open: str, close: str): + nonlocal token + if token.string == open: token = next(tokens) + depth = 1 + while depth: + if token.string == open: + depth += 1 + if token.string == close: + depth -= 1 + token = next(tokens) + + _optional_nested_brackets("[", "]") + _optional_nested_brackets("(", ")") assert token.string == ":" diff --git a/src/ssort/_requirements.py b/src/ssort/_requirements.py index a1792da..b8201da 100644 --- a/src/ssort/_requirements.py +++ b/src/ssort/_requirements.py @@ -33,6 +33,13 @@ def get_requirements(node: ast.AST) -> Iterable[Requirement]: yield from get_requirements(child) +def _get_requirements_from_nodes( + nodes: Iterable[ast.AST], +) -> Iterable[Requirement]: + for node in nodes: + yield from get_requirements(node) + + def _get_scope_from_arguments(args: ast.arguments) -> set[str]: scope: set[str] = set() scope.update(arg.arg for arg in args.posonlyargs) @@ -45,19 +52,43 @@ def _get_scope_from_arguments(args: ast.arguments) -> set[str]: return scope +if sys.version_info >= (3, 12): + + def _get_scope_from_type_params( + type_params: list[ast.type_param], + ) -> set[str]: + return set(type_param.name for type_param in type_params) # type: ignore[attr-defined] + + @get_requirements.register(ast.TypeAlias) + def _get_requirements_for_type_alias( + node: ast.TypeAlias, + ) -> Iterable[Requirement]: + scope = _get_scope_from_type_params(node.type_params) + for requirement in _get_requirements_from_nodes(node.type_params): + if requirement.name not in scope: + yield requirement + + scope.add(node.name.id) + for requirement in get_requirements(node.value): + if not requirement.deferred: + requirement = dataclasses.replace(requirement, deferred=True) + if requirement.name not in scope: + yield requirement + + @get_requirements.register(ast.FunctionDef) @get_requirements.register(ast.AsyncFunctionDef) def _get_requirements_for_function_def( node: ast.FunctionDef | ast.AsyncFunctionDef, ) -> Iterable[Requirement]: - for decorator in node.decorator_list: - yield from get_requirements(decorator) + yield from _get_requirements_from_nodes(node.decorator_list) scope: set[str] = set() if sys.version_info >= (3, 12): - for type_param in node.type_params: - yield from get_requirements(type_param) - scope.update(type_param.name for type_param in node.type_params) # type: ignore[attr-defined] + scope.update(_get_scope_from_type_params(node.type_params)) + for requirement in _get_requirements_from_nodes(node.type_params): + if requirement.name not in scope: + yield requirement for requirement in get_requirements(node.args): if requirement.name not in scope: @@ -91,20 +122,19 @@ def _get_requirements_for_function_def( def _get_requirements_for_class_def( node: ast.ClassDef, ) -> Iterable[Requirement]: - for decorator in node.decorator_list: - yield from get_requirements(decorator) + yield from _get_requirements_from_nodes(node.decorator_list) scope: set[str] = set() if sys.version_info >= (3, 12): - for type_param in node.type_params: - yield from get_requirements(type_param) - scope.update(type_param.name for type_param in node.type_params) # type: ignore[attr-defined] - - for base in node.bases: - for requirement in get_requirements(base): + scope.update(_get_scope_from_type_params(node.type_params)) + for requirement in _get_requirements_from_nodes(node.type_params): if requirement.name not in scope: yield requirement + for requirement in _get_requirements_from_nodes(node.bases): + if requirement.name not in scope: + yield requirement + scope.update(CLASS_BUILTINS) for statement in node.body: @@ -125,15 +155,13 @@ def _get_requirements_for_for( yield from get_requirements(node.target) yield from get_requirements(node.iter) - for stmt in node.body: - for requirement in get_requirements(stmt): - if requirement.name not in bindings: - yield requirement + for requirement in _get_requirements_from_nodes(node.body): + if requirement.name not in bindings: + yield requirement - for stmt in node.orelse: - for requirement in get_requirements(stmt): - if requirement.name not in bindings: - yield requirement + for requirement in _get_requirements_from_nodes(node.orelse): + if requirement.name not in bindings: + yield requirement @get_requirements.register(ast.With) @@ -143,13 +171,11 @@ def _get_requirements_for_with( ) -> Iterable[Requirement]: bindings = set(get_bindings(node)) - for item in node.items: - yield from get_requirements(item) + yield from _get_requirements_from_nodes(node.items) - for stmt in node.body: - for requirement in get_requirements(stmt): - if requirement.name not in bindings: - yield requirement + for requirement in _get_requirements_from_nodes(node.body): + if requirement.name not in bindings: + yield requirement @get_requirements.register(ast.Global) @@ -210,21 +236,3 @@ def _get_requirements_for_name(node: ast.Name) -> Iterable[Requirement]: yield Requirement( name=node.id, lineno=node.lineno, col_offset=node.col_offset ) - - -if sys.version_info >= (3, 12): - - @get_requirements.register(ast.TypeAlias) - def _get_requirements_for_type_alias( - node: ast.TypeAlias, - ) -> Iterable[Requirement]: - for type_param in node.type_params: - yield from get_requirements(type_param) - scope: set[str] = set() - scope.update(type_param.name for type_param in node.type_params) # type: ignore[attr-defined] - scope.add(node.name.id) - for requirement in get_requirements(node.value): - if not requirement.deferred: - requirement = dataclasses.replace(requirement, deferred=True) - if requirement.name not in scope: - yield requirement diff --git a/tests/test_bindings.py b/tests/test_bindings.py index 2bffd69..289801a 100644 --- a/tests/test_bindings.py +++ b/tests/test_bindings.py @@ -1640,7 +1640,7 @@ def test_try_star_binding_walrus_exeption_type(): @type_parameter_syntax -def test_type_var_bindings(): +def test_type_var_recursive_bindings(): node = _parse("type RecursiveList[T] = T | list[RecursiveList[T]]") assert list(get_bindings(node)) == ["RecursiveList"] diff --git a/tests/test_requirements.py b/tests/test_requirements.py index af3f0d1..941dc09 100644 --- a/tests/test_requirements.py +++ b/tests/test_requirements.py @@ -1209,11 +1209,17 @@ def test_type_var_requirements(): @type_parameter_syntax -def test_type_var_requirements_recursive(): +def test_type_var_recursive_requirements(): node = _parse("type RecursiveList[T] = T | list[RecursiveList[T]]") assert _dep_names(node) == ["list"] +@type_parameter_syntax +def test_type_var_scope_requirements(): + node = _parse("type Alias[S, T: Sequence[S]] = dict[S, T]") + assert _dep_names(node) == ["Sequence", "dict"] + + @type_parameter_syntax def test_param_spec_requirements(): node = _parse("type Alias[**P] = Callable[P, int]") @@ -1243,6 +1249,17 @@ def func[T1, T2](a: T1, b: T2) -> tuple[T1, T2]: assert _dep_names(node) == ["tuple"] +@type_parameter_syntax +def test_generic_function_scope_requirements(): + node = _parse( + """ + def func[S, T: Sequence[S]](a: S) -> T: + pass + """ + ) + assert _dep_names(node) == ["Sequence"] + + @type_parameter_syntax def test_generic_class_requirements(): node = _parse( @@ -1256,6 +1273,17 @@ def method1(self) -> T: assert _dep_names(node) == [] +@type_parameter_syntax +def test_generic_class_scope_requirements(): + node = _parse( + """ + class ClassA[S, T: Sequence[S]]: + pass + """ + ) + assert _dep_names(node) == ["Sequence"] + + @type_parameter_syntax def test_generic_inner_class_requirements(): node = _parse( From 20b3f63525926c2df1ae43577150b335727a2fdc Mon Sep 17 00:00:00 2001 From: Eric Nielsen <4120606+ericbn@users.noreply.github.com> Date: Thu, 18 Apr 2024 16:58:53 -0500 Subject: [PATCH 4/4] Revert refactoring of _optional_nested_brackets --- src/ssort/_parsing.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/src/ssort/_parsing.py b/src/ssort/_parsing.py index 7afc3e9..24d3e5f 100644 --- a/src/ssort/_parsing.py +++ b/src/ssort/_parsing.py @@ -128,21 +128,24 @@ def split_class(statement): assert token.type == NAME token = next(tokens) - - def _optional_nested_brackets(open: str, close: str): - nonlocal token - if token.string == open: + if token.string == "[": + token = next(tokens) + depth = 1 + while depth: + if token.string == "[": + depth += 1 + if token.string == "]": + depth -= 1 + token = next(tokens) + if token.string == "(": + token = next(tokens) + depth = 1 + while depth: + if token.string == "(": + depth += 1 + if token.string == ")": + depth -= 1 token = next(tokens) - depth = 1 - while depth: - if token.string == open: - depth += 1 - if token.string == close: - depth -= 1 - token = next(tokens) - - _optional_nested_brackets("[", "]") - _optional_nested_brackets("(", ")") assert token.string == ":"