Skip to content

Commit

Permalink
Better story for import redefinitions (#13969)
Browse files Browse the repository at this point in the history
This changes our importing logic to be more consistent and to treat
import statements more like assignments.

Fixes #13803, fixes #13914, fixes half of #12965, probably fixes #12574

The primary motivation for this is when typing modules as protocols, as
in #13803. But it turns out we already allowed redefinition with "from"
imports, so this also seems like a nice consistency win.

We move shared logic from visit_import_all and visit_import_from (via
process_imported_symbol) into add_imported_symbol. We then reuse it in
visit_import.

To simplify stuff, we inline the code from add_module_symbol into
visit_import. Then we copy over logic from add_symbol, because MypyFile
is not a SymbolTableNode, but this isn't the worst thing ever.

Finally, we now need to check non-from import statements like
assignments, which was a thing we weren't doing earlier.
  • Loading branch information
hauntsaninja committed Nov 3, 2022
1 parent 0457d33 commit a4da89e
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 54 deletions.
4 changes: 2 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2527,8 +2527,8 @@ def visit_import_from(self, node: ImportFrom) -> None:
def visit_import_all(self, node: ImportAll) -> None:
self.check_import(node)

def visit_import(self, s: Import) -> None:
pass
def visit_import(self, node: Import) -> None:
self.check_import(node)

def check_import(self, node: ImportBase) -> None:
for assign in node.assignments:
Expand Down
84 changes: 39 additions & 45 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2235,13 +2235,33 @@ def visit_import(self, i: Import) -> None:
base_id = id.split(".")[0]
imported_id = base_id
module_public = use_implicit_reexport
self.add_module_symbol(
base_id,
imported_id,
context=i,
module_public=module_public,
module_hidden=not module_public,
)

if base_id in self.modules:
node = self.modules[base_id]
if self.is_func_scope():
kind = LDEF
elif self.type is not None:
kind = MDEF
else:
kind = GDEF
symbol = SymbolTableNode(
kind, node, module_public=module_public, module_hidden=not module_public
)
self.add_imported_symbol(
imported_id,
symbol,
context=i,
module_public=module_public,
module_hidden=not module_public,
)
else:
self.add_unknown_imported_symbol(
imported_id,
context=i,
target_name=base_id,
module_public=module_public,
module_hidden=not module_public,
)

def visit_import_from(self, imp: ImportFrom) -> None:
self.statement = imp
Expand Down Expand Up @@ -2377,19 +2397,6 @@ def process_imported_symbol(
module_hidden=module_hidden,
becomes_typeinfo=True,
)
existing_symbol = self.globals.get(imported_id)
if (
existing_symbol
and not isinstance(existing_symbol.node, PlaceholderNode)
and not isinstance(node.node, PlaceholderNode)
):
# Import can redefine a variable. They get special treatment.
if self.process_import_over_existing_name(imported_id, existing_symbol, node, context):
return
if existing_symbol and isinstance(node.node, PlaceholderNode):
# Imports are special, some redefinitions are allowed, so wait until
# we know what is the new symbol node.
return
# NOTE: we take the original node even for final `Var`s. This is to support
# a common pattern when constants are re-exported (same applies to import *).
self.add_imported_symbol(
Expand Down Expand Up @@ -2507,14 +2514,9 @@ def visit_import_all(self, i: ImportAll) -> None:
if isinstance(node.node, MypyFile):
# Star import of submodule from a package, add it as a dependency.
self.imports.add(node.node.fullname)
existing_symbol = self.lookup_current_scope(name)
if existing_symbol and not isinstance(node.node, PlaceholderNode):
# Import can redefine a variable. They get special treatment.
if self.process_import_over_existing_name(name, existing_symbol, node, i):
continue
# `from x import *` always reexports symbols
self.add_imported_symbol(
name, node, i, module_public=True, module_hidden=False
name, node, context=i, module_public=True, module_hidden=False
)

else:
Expand Down Expand Up @@ -5589,24 +5591,6 @@ def add_local(self, node: Var | FuncDef | OverloadedFuncDef, context: Context) -
node._fullname = name
self.add_symbol(name, node, context)

def add_module_symbol(
self, id: str, as_id: str, context: Context, module_public: bool, module_hidden: bool
) -> None:
"""Add symbol that is a reference to a module object."""
if id in self.modules:
node = self.modules[id]
self.add_symbol(
as_id, node, context, module_public=module_public, module_hidden=module_hidden
)
else:
self.add_unknown_imported_symbol(
as_id,
context,
target_name=id,
module_public=module_public,
module_hidden=module_hidden,
)

def _get_node_for_class_scoped_import(
self, name: str, symbol_node: SymbolNode | None, context: Context
) -> SymbolNode | None:
Expand Down Expand Up @@ -5653,13 +5637,23 @@ def add_imported_symbol(
self,
name: str,
node: SymbolTableNode,
context: Context,
context: ImportBase,
module_public: bool,
module_hidden: bool,
) -> None:
"""Add an alias to an existing symbol through import."""
assert not module_hidden or not module_public

existing_symbol = self.lookup_current_scope(name)
if (
existing_symbol
and not isinstance(existing_symbol.node, PlaceholderNode)
and not isinstance(node.node, PlaceholderNode)
):
# Import can redefine a variable. They get special treatment.
if self.process_import_over_existing_name(name, existing_symbol, node, context):
return

symbol_node: SymbolNode | None = node.node

if self.is_class_scope():
Expand Down
3 changes: 1 addition & 2 deletions test-data/unit/check-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -7414,8 +7414,7 @@ class Foo:
def meth1(self, a: str) -> str: ... # E: Name "meth1" already defined on line 5

def meth2(self, a: str) -> str: ...
from mod1 import meth2 # E: Unsupported class scoped import \
# E: Name "meth2" already defined on line 8
from mod1 import meth2 # E: Incompatible import of "meth2" (imported name has type "Callable[[int], int]", local name has type "Callable[[Foo, str], str]")

class Bar:
from mod1 import foo # E: Unsupported class scoped import
Expand Down
5 changes: 1 addition & 4 deletions test-data/unit/check-incremental.test
Original file line number Diff line number Diff line change
Expand Up @@ -1025,10 +1025,7 @@ import a.b

[file a/b.py]

[rechecked b]
[stale]
[out2]
tmp/b.py:4: error: Name "a" already defined on line 3
[stale b]

[case testIncrementalSilentImportsAndImportsInClass]
# flags: --ignore-missing-imports
Expand Down
19 changes: 19 additions & 0 deletions test-data/unit/check-modules.test
Original file line number Diff line number Diff line change
Expand Up @@ -651,10 +651,29 @@ try:
from m import f, g # E: Incompatible import of "g" (imported name has type "Callable[[Any, Any], Any]", local name has type "Callable[[Any], Any]")
except:
pass

import m as f # E: Incompatible import of "f" (imported name has type "object", local name has type "Callable[[Any], Any]")

[file m.py]
def f(x): pass
def g(x, y): pass

[case testRedefineTypeViaImport]
from typing import Type
import mod

X: Type[mod.A]
Y: Type[mod.B]
from mod import B as X
from mod import A as Y # E: Incompatible import of "Y" (imported name has type "Type[A]", local name has type "Type[B]")

import mod as X # E: Incompatible import of "X" (imported name has type "object", local name has type "Type[A]")

[file mod.py]
class A: ...
class B(A): ...


[case testImportVariableAndAssignNone]
try:
from m import x
Expand Down
56 changes: 56 additions & 0 deletions test-data/unit/check-protocols.test
Original file line number Diff line number Diff line change
Expand Up @@ -3787,3 +3787,59 @@ from typing_extensions import Final

a: Final = 1
[builtins fixtures/module.pyi]


[case testModuleAsProtocolRedefinitionTopLevel]
from typing import Protocol

class P(Protocol):
def f(self) -> str: ...

cond: bool
t: P
if cond:
import mod1 as t
else:
import mod2 as t

import badmod as t # E: Incompatible import of "t" (imported name has type Module, local name has type "P")

[file mod1.py]
def f() -> str: ...

[file mod2.py]
def f() -> str: ...

[file badmod.py]
def nothing() -> int: ...
[builtins fixtures/module.pyi]

[case testModuleAsProtocolRedefinitionImportFrom]
from typing import Protocol

class P(Protocol):
def f(self) -> str: ...

cond: bool
t: P
if cond:
from package import mod1 as t
else:
from package import mod2 as t

from package import badmod as t # E: Incompatible import of "t" (imported name has type Module, local name has type "P")

package: int = 10

import package.mod1 as t
import package.mod1 # E: Incompatible import of "package" (imported name has type Module, local name has type "int")

[file package/mod1.py]
def f() -> str: ...

[file package/mod2.py]
def f() -> str: ...

[file package/badmod.py]
def nothing() -> int: ...
[builtins fixtures/module.pyi]
2 changes: 1 addition & 1 deletion test-data/unit/check-redefine.test
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def f() -> None:
import typing as m
m = 1 # E: Incompatible types in assignment (expression has type "int", variable has type Module)
n = 1
import typing as n # E: Name "n" already defined on line 5
import typing as n # E: Incompatible import of "n" (imported name has type Module, local name has type "int")
[builtins fixtures/module.pyi]

[case testRedefineLocalWithTypeAnnotation]
Expand Down

0 comments on commit a4da89e

Please sign in to comment.