diff --git a/src/ssort/_parsing.py b/src/ssort/_parsing.py index 24d3e5f..969d81f 100644 --- a/src/ssort/_parsing.py +++ b/src/ssort/_parsing.py @@ -8,6 +8,19 @@ from ssort._statements import Statement +def _build_row_lengths_offsets(text): + # Build an index of row lengths and start offsets to enable fast string + # indexing using ast row/column coordinates. + row_lengths = [] + row_offsets = [0] + for offset, char in enumerate(text): + if char == "\n": + row_lengths.append(offset - row_offsets[-1]) + row_offsets.append(offset + 1) + row_lengths.append(len(text) - row_offsets[-1]) + return row_lengths, row_offsets + + def _find_start(node): if ( isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) @@ -29,17 +42,8 @@ def split( nodes, next_row=0, next_col=0, - indent=0, ): - # Build an index of row lengths and start offsets to enable fast string - # indexing using ast row/column coordinates. - row_lengths = [] - row_offsets = [0] - for offset, char in enumerate(root_text): - if char == "\n": - row_lengths.append(offset - row_offsets[-1]) - row_offsets.append(offset + 1) - row_lengths.append(len(root_text) - row_offsets[-1]) + row_lengths, row_offsets = _build_row_lengths_offsets(root_text) nodes = iter(nodes) @@ -106,15 +110,7 @@ def split_class(statement): text = statement.text text_padded = statement.text_padded() - # Build an index of row lengths and start offsets to enable fast string - # indexing using ast row/column coordinates. - row_lengths = [] - row_offsets = [0] - for offset, char in enumerate(text_padded): - if char == "\n": - row_lengths.append(offset - row_offsets[-1]) - row_offsets.append(offset + 1) - row_lengths.append(len(text_padded) - row_offsets[-1]) + _, row_offsets = _build_row_lengths_offsets(text_padded) tokens = iter(generate_tokens(StringIO(text_padded).readline)) diff --git a/tests/test_ast.py b/tests/test_ast.py index 77c445f..5f400c0 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -30,6 +30,10 @@ def _nodes_types( node_type: type[ast.AST] = ast.AST, ) -> Iterable[type[ast.AST]]: + # coverage package adds a coverage.parser.NodeList subclass + if node_type.__module__ != "ast": + return + # Skip deprecated AST nodes. if issubclass(node_type, _deprecated_node_types): return diff --git a/tests/test_split.py b/tests/test_split.py index 754c99c..8a81773 100644 --- a/tests/test_split.py +++ b/tests/test_split.py @@ -1,5 +1,14 @@ +import sys + +import pytest + from ssort._parsing import parse, split_class +type_parameter_syntax = pytest.mark.skipif( + sys.version_info < (3, 12), + reason="type parameter syntax was introduced in python 3.12", +) + def _split_text(source): return [ @@ -117,13 +126,38 @@ def test_split_class_decorators(): assert actual == expected +def test_split_class_decorator_single_line(): + actual = _split_class("@decorator()\nclass A: key: int") + expected = "@decorator()\nclass A:", [" key: int"] + assert actual == expected + + def test_split_class_leading_comment(): actual = _split_class("# Comment.\nclass A:\n pass") expected = "# Comment.\nclass A:", [" pass"] assert actual == expected +def test_split_class_leading_comment_single_line(): + actual = _split_class("# Comment.\nclass A: pass") + expected = "# Comment.\nclass A:", [" pass"] + assert actual == expected + + def test_split_class_multiple(): actual = _split_class("def a():\n pass\n\nclass A:\n pass") expected = "\nclass A:", [" pass"] assert actual == expected + + +def test_split_class_with_bases(): + actual = _split_class("class A(\n B,\n): pass") + expected = "class A(\n B,\n):", [" pass"] + assert actual == expected + + +@type_parameter_syntax +def test_split_class_with_type_params(): + actual = _split_class("class A[\n B,\n]: pass") + expected = "class A[\n B,\n]:", [" pass"] + assert actual == expected