From a1109058938d71fbe8cd0fcf68e876ad9552e481 Mon Sep 17 00:00:00 2001 From: Jacob Walls Date: Tue, 12 Dec 2023 08:18:32 -0500 Subject: [PATCH] Add `__main__` as inferred value for `__name__` (#2345) --- ChangeLog | 5 +++++ astroid/nodes/scoped_nodes/scoped_nodes.py | 12 +++++++++++- tests/test_scoped_nodes.py | 4 +++- 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/ChangeLog b/ChangeLog index 7bce59680b..c6d2b17404 100644 --- a/ChangeLog +++ b/ChangeLog @@ -12,6 +12,11 @@ Release date: TBA Refs pylint-dev/pylint#9193 +* Add ``__main__`` as a possible inferred value for ``__name__`` to improve + control flow inference around ``if __name__ == "__main__":`` guards. + + Closes #2071 + What's New in astroid 3.0.3? ============================ diff --git a/astroid/nodes/scoped_nodes/scoped_nodes.py b/astroid/nodes/scoped_nodes/scoped_nodes.py index e8b1aef4f1..2d492f1b64 100644 --- a/astroid/nodes/scoped_nodes/scoped_nodes.py +++ b/astroid/nodes/scoped_nodes/scoped_nodes.py @@ -41,7 +41,15 @@ from astroid.interpreter.dunder_lookup import lookup from astroid.interpreter.objectmodel import ClassModel, FunctionModel, ModuleModel from astroid.manager import AstroidManager -from astroid.nodes import Arguments, Const, NodeNG, Unknown, _base_nodes, node_classes +from astroid.nodes import ( + Arguments, + Const, + NodeNG, + Unknown, + _base_nodes, + const_factory, + node_classes, +) from astroid.nodes.scoped_nodes.mixin import ComprehensionScope, LocalsDictNodeNG from astroid.nodes.scoped_nodes.utils import builtin_lookup from astroid.nodes.utils import Position @@ -346,6 +354,8 @@ def getattr( if name in self.special_attributes and not ignore_locals and not name_in_locals: result = [self.special_attributes.lookup(name)] + if name == "__name__": + result.append(const_factory("__main__")) elif not ignore_locals and name_in_locals: result = self.locals[name] elif self.package: diff --git a/tests/test_scoped_nodes.py b/tests/test_scoped_nodes.py index 1bc5af78b6..995f0428d9 100644 --- a/tests/test_scoped_nodes.py +++ b/tests/test_scoped_nodes.py @@ -79,9 +79,11 @@ def setUp(self) -> None: class ModuleNodeTest(ModuleLoader, unittest.TestCase): def test_special_attributes(self) -> None: - self.assertEqual(len(self.module.getattr("__name__")), 1) + self.assertEqual(len(self.module.getattr("__name__")), 2) self.assertIsInstance(self.module.getattr("__name__")[0], nodes.Const) self.assertEqual(self.module.getattr("__name__")[0].value, "data.module") + self.assertIsInstance(self.module.getattr("__name__")[1], nodes.Const) + self.assertEqual(self.module.getattr("__name__")[1].value, "__main__") self.assertEqual(len(self.module.getattr("__doc__")), 1) self.assertIsInstance(self.module.getattr("__doc__")[0], nodes.Const) self.assertEqual(