From a0c3729b9df8c535d02cdc3d43b50ee5d686268b Mon Sep 17 00:00:00 2001 From: David Liu Date: Tue, 20 Jul 2021 16:36:21 -0400 Subject: [PATCH] Improve variable lookup to handle function parameters being overwritten --- ChangeLog | 4 ++ astroid/node_classes.py | 17 ++++-- tests/unittest_inference.py | 18 ++++++ tests/unittest_lookup.py | 108 +++++++++++++++++++++++++++++++++++- 4 files changed, 142 insertions(+), 5 deletions(-) diff --git a/ChangeLog b/ChangeLog index 8faed02351..92919c4d88 100644 --- a/ChangeLog +++ b/ChangeLog @@ -17,6 +17,10 @@ Release date: TBA Closes PyCQA/pylint#3711 +* Fix variable lookup's handling of function parameters + + Closes PyCQA/astroid#180 + What's New in astroid 2.6.5? ============================ diff --git a/astroid/node_classes.py b/astroid/node_classes.py index 0226b096de..042bbb8cf3 100644 --- a/astroid/node_classes.py +++ b/astroid/node_classes.py @@ -1216,10 +1216,10 @@ def _filter_stmts(self, stmts, frame, offset): if are_exclusive(self, node): continue + # An AssignName node overrides previous assignments if: + # 1. node's statement always assigns + # 2. node and self are in the same block (i.e., has the same parent as self) if isinstance(node, AssignName): - # Remove all previously stored assignments if: - # 1. node's statement always assigns - # 2. node has the same parent as self (i.e., they're in the same block) if not optional_assign and stmt.parent is mystmt.parent: _stmts = [] _stmt_parents = [] @@ -1230,7 +1230,16 @@ def _filter_stmts(self, stmts, frame, offset): continue # Add the new assignment _stmts.append(node) - _stmt_parents.append(stmt.parent) + if isinstance(node, Arguments) or isinstance(node.parent, Arguments): + # Special case for _stmt_parents when node is a function parameter; + # in this case, stmt is the enclosing FunctionDef, which is what we + # want to add to _stmt_parents, not stmt.parent. This case occurs when + # node is an Arguments node (representing varargs or kwargs parameter), + # and when node.parent is an Arguments node (other parameters). + # See issue #180. + _stmt_parents.append(stmt) + else: + _stmt_parents.append(stmt.parent) return _stmts diff --git a/tests/unittest_inference.py b/tests/unittest_inference.py index 591696f9e8..5c09fbda12 100644 --- a/tests/unittest_inference.py +++ b/tests/unittest_inference.py @@ -4805,6 +4805,24 @@ def test(*args): return args inferred = next(node.infer()) self.assertEqual(inferred, util.Uninferable) + def test_args_overwritten(self): + # https://github.com/PyCQA/astroid/issues/180 + node = extract_node( + """ + next = 42 + def wrapper(next=next): + next = 24 + def test(): + return next + return test + wrapper()() #@ + """ + ) + inferred = node.inferred() + self.assertEqual(len(inferred), 1) + self.assertIsInstance(inferred[0], nodes.Const, inferred[0]) + self.assertEqual(inferred[0].value, 24) + class SliceTest(unittest.TestCase): def test_slice(self): diff --git a/tests/unittest_lookup.py b/tests/unittest_lookup.py index 1d6e77b598..86f9e8cc95 100644 --- a/tests/unittest_lookup.py +++ b/tests/unittest_lookup.py @@ -18,7 +18,7 @@ import functools import unittest -from astroid import builder, nodes, scoped_nodes +from astroid import builder, nodes, scoped_nodes, test_utils from astroid.exceptions import ( AttributeInferenceError, InferenceError, @@ -743,6 +743,112 @@ def f(b): self.assertEqual(len(stmts), 1) self.assertEqual(stmts[0].lineno, 3) + def test_assign_after_param(self): + """When an assignment statement overwrites a function parameter, only the + assignment is returned, even when the variable and assignment do not have + the same parent. + """ + code = """ + def f1(x): + x = 100 + print(x) + + def f2(x): + x = 100 + if True: + print(x) + """ + astroid = builder.parse(code) + x_name1, x_name2 = ( + n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x" + ) + _, stmts1 = x_name1.lookup("x") + self.assertEqual(len(stmts1), 1) + self.assertEqual(stmts1[0].lineno, 3) + + _, stmts2 = x_name2.lookup("x") + self.assertEqual(len(stmts2), 1) + self.assertEqual(stmts2[0].lineno, 7) + + def test_assign_after_kwonly_param(self): + """When an assignment statement overwrites a function keyword-only parameter, + only the assignment is returned, even when the variable and assignment do + not have the same parent. + """ + code = """ + def f1(*, x): + x = 100 + print(x) + + def f2(*, x): + x = 100 + if True: + print(x) + """ + astroid = builder.parse(code) + x_name1, x_name2 = ( + n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x" + ) + _, stmts1 = x_name1.lookup("x") + self.assertEqual(len(stmts1), 1) + self.assertEqual(stmts1[0].lineno, 3) + + _, stmts2 = x_name2.lookup("x") + self.assertEqual(len(stmts2), 1) + self.assertEqual(stmts2[0].lineno, 7) + + @test_utils.require_version(minver="3.8") + def test_assign_after_posonly_param(self): + """When an assignment statement overwrites a function positional-only parameter, + only the assignment is returned, even when the variable and assignment do + not have the same parent. + """ + code = """ + def f1(x, /): + x = 100 + print(x) + + def f2(x, /): + x = 100 + if True: + print(x) + """ + astroid = builder.parse(code) + x_name1, x_name2 = ( + n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x" + ) + _, stmts1 = x_name1.lookup("x") + self.assertEqual(len(stmts1), 1) + self.assertEqual(stmts1[0].lineno, 3) + + _, stmts2 = x_name2.lookup("x") + self.assertEqual(len(stmts2), 1) + self.assertEqual(stmts2[0].lineno, 7) + + def test_assign_after_args_param(self): + """When an assignment statement overwrites a function parameter, only the + assignment is returned. + """ + code = """ + def f(*args, **kwargs): + args = [100] + kwargs = {} + if True: + print(args, kwargs) + """ + astroid = builder.parse(code) + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "args"][0] + _, stmts1 = x_name.lookup("args") + self.assertEqual(len(stmts1), 1) + self.assertEqual(stmts1[0].lineno, 3) + + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "kwargs"][ + 0 + ] + _, stmts2 = x_name.lookup("kwargs") + self.assertEqual(len(stmts2), 1) + self.assertEqual(stmts2[0].lineno, 4) + if __name__ == "__main__": unittest.main()