Skip to content

Commit

Permalink
Allow new-style self-types in classmethods (#17381)
Browse files Browse the repository at this point in the history
Fixes #16547
Fixes #16410
Fixes #5570

From the upvotes on the issue it looks like an important use case. From
what I see this is an omission in the original implementation, I don't
see any additional unsafety (except for the same that exists for
instance methods/variables). I also incorporate a small refactoring and
remove couple unused `get_proper_type()` calls.

The fix uncovered an unrelated issue with unions in descriptors, so I
fix that one as well.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
ilevkivskyi and pre-commit-ci[bot] committed Jun 17, 2024
1 parent 06c7d26 commit ba5c279
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 31 deletions.
4 changes: 3 additions & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3261,7 +3261,9 @@ def analyze_ordinary_member_access(self, e: MemberExpr, is_lvalue: bool) -> Type
if isinstance(base, RefExpr) and isinstance(base.node, MypyFile):
module_symbol_table = base.node.names
if isinstance(base, RefExpr) and isinstance(base.node, Var):
is_self = base.node.is_self
# This is needed to special case self-types, so we don't need to track
# these flags separately in checkmember.py.
is_self = base.node.is_self or base.node.is_cls
else:
is_self = False

Expand Down
78 changes: 49 additions & 29 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ def analyze_descriptor_access(descriptor_type: Type, mx: MemberContext) -> Type:
Return:
The return type of the appropriate ``__get__`` overload for the descriptor.
"""
instance_type = get_proper_type(mx.original_type)
instance_type = get_proper_type(mx.self_type)
orig_descriptor_type = descriptor_type
descriptor_type = get_proper_type(descriptor_type)

Expand All @@ -647,16 +647,6 @@ def analyze_descriptor_access(descriptor_type: Type, mx: MemberContext) -> Type:
return make_simplified_union(
[analyze_descriptor_access(typ, mx) for typ in descriptor_type.items]
)
elif isinstance(instance_type, UnionType):
# map over the instance types
return make_simplified_union(
[
analyze_descriptor_access(
descriptor_type, mx.copy_modified(original_type=original_type)
)
for original_type in instance_type.relevant_items()
]
)
elif not isinstance(descriptor_type, Instance):
return orig_descriptor_type

Expand Down Expand Up @@ -777,23 +767,10 @@ def analyze_var(
if mx.is_lvalue and var.is_classvar:
mx.msg.cant_assign_to_classvar(name, mx.context)
t = freshen_all_functions_type_vars(typ)
if not (mx.is_self or mx.is_super) or supported_self_type(
get_proper_type(mx.original_type)
):
t = expand_self_type(var, t, mx.original_type)
elif (
mx.is_self
and original_itype.type != var.info
# If an attribute with Self-type was defined in a supertype, we need to
# rebind the Self type variable to Self type variable of current class...
and original_itype.type.self_type is not None
# ...unless `self` has an explicit non-trivial annotation.
and original_itype == mx.chk.scope.active_self_type()
):
t = expand_self_type(var, t, original_itype.type.self_type)
t = get_proper_type(expand_type_by_instance(t, itype))
t = expand_self_type_if_needed(t, mx, var, original_itype)
t = expand_type_by_instance(t, itype)
freeze_all_type_vars(t)
result: Type = t
result = t
typ = get_proper_type(typ)

call_type: ProperType | None = None
Expand Down Expand Up @@ -857,6 +834,50 @@ def analyze_var(
return result


def expand_self_type_if_needed(
t: Type, mx: MemberContext, var: Var, itype: Instance, is_class: bool = False
) -> Type:
"""Expand special Self type in a backwards compatible manner.
This should ensure that mixing old-style and new-style self-types work
seamlessly. Also, re-bind new style self-types in subclasses if needed.
"""
original = get_proper_type(mx.self_type)
if not (mx.is_self or mx.is_super):
repl = mx.self_type
if is_class:
if isinstance(original, TypeType):
repl = original.item
elif isinstance(original, CallableType):
# Problematic access errors should have been already reported.
repl = erase_typevars(original.ret_type)
else:
repl = itype
return expand_self_type(var, t, repl)
elif supported_self_type(
# Support compatibility with plain old style T -> T and Type[T] -> T only.
get_proper_type(mx.self_type),
allow_instances=False,
allow_callable=False,
):
repl = mx.self_type
if is_class and isinstance(original, TypeType):
repl = original.item
return expand_self_type(var, t, repl)
elif (
mx.is_self
and itype.type != var.info
# If an attribute with Self-type was defined in a supertype, we need to
# rebind the Self type variable to Self type variable of current class...
and itype.type.self_type is not None
# ...unless `self` has an explicit non-trivial annotation.
and itype == mx.chk.scope.active_self_type()
):
return expand_self_type(var, t, itype.type.self_type)
else:
return t


def freeze_all_type_vars(member_type: Type) -> None:
member_type.accept(FreezeTypeVarsVisitor())

Expand Down Expand Up @@ -1059,12 +1080,11 @@ def analyze_class_attribute_access(
else:
message = message_registry.GENERIC_INSTANCE_VAR_CLASS_ACCESS
mx.msg.fail(message, mx.context)

t = expand_self_type_if_needed(t, mx, node.node, itype, is_class=True)
# Erase non-mapped variables, but keep mapped ones, even if there is an error.
# In the above example this means that we infer following types:
# C.x -> Any
# C[int].x -> int
t = get_proper_type(expand_self_type(node.node, t, itype))
t = erase_typevars(expand_type_by_instance(t, isuper), {tv.id for tv in def_vars})

is_classmethod = (is_decorated and cast(Decorator, node.node).func.is_class) or (
Expand Down
35 changes: 35 additions & 0 deletions test-data/unit/check-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -1950,6 +1950,41 @@ class B:
def foo(x: Union[A, B]) -> None:
reveal_type(x.attr) # N: Revealed type is "builtins.str"

[case testDescriptorGetUnionRestricted]
from typing import Any, Union

class getter:
def __get__(self, instance: X1, owner: Any) -> str: ...

class X1:
prop = getter()

class X2:
prop: str

def foo(x: Union[X1, X2]) -> None:
reveal_type(x.prop) # N: Revealed type is "builtins.str"

[case testDescriptorGetUnionType]
from typing import Any, Union, Type, overload

class getter:
@overload
def __get__(self, instance: None, owner: Any) -> getter: ...
@overload
def __get__(self, instance: object, owner: Any) -> str: ...
def __get__(self, instance, owner):
...

class X1:
prop = getter()
class X2:
prop = getter()

def foo(x: Type[Union[X1, X2]]) -> None:
reveal_type(x.prop) # N: Revealed type is "__main__.getter"


-- _promote decorators
-- -------------------

Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-recursive-types.test
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ from typing import NamedTuple, TypeVar, Tuple
NT = NamedTuple("NT", [("x", NT), ("y", int)])
nt: NT
reveal_type(nt) # N: Revealed type is "Tuple[..., builtins.int, fallback=__main__.NT]"
reveal_type(nt.x) # N: Revealed type is "Tuple[Tuple[..., builtins.int, fallback=__main__.NT], builtins.int, fallback=__main__.NT]"
reveal_type(nt.x) # N: Revealed type is "Tuple[..., builtins.int, fallback=__main__.NT]"
reveal_type(nt[0]) # N: Revealed type is "Tuple[Tuple[..., builtins.int, fallback=__main__.NT], builtins.int, fallback=__main__.NT]"
y: str
if nt.x is not None:
Expand Down
61 changes: 61 additions & 0 deletions test-data/unit/check-selftype.test
Original file line number Diff line number Diff line change
Expand Up @@ -2071,3 +2071,64 @@ p: Partial
reveal_type(p()) # N: Revealed type is "Never"
p2: Partial2
reveal_type(p2(42)) # N: Revealed type is "builtins.int"

[case testAccessingSelfClassVarInClassMethod]
from typing import Self, ClassVar, Type, TypeVar

T = TypeVar("T", bound="Foo")

class Foo:
instance: ClassVar[Self]
@classmethod
def get_instance(cls) -> Self:
return reveal_type(cls.instance) # N: Revealed type is "Self`0"
@classmethod
def get_instance_old(cls: Type[T]) -> T:
return reveal_type(cls.instance) # N: Revealed type is "T`-1"

class Bar(Foo):
extra: int

@classmethod
def get_instance(cls) -> Self:
reveal_type(cls.instance.extra) # N: Revealed type is "builtins.int"
return cls.instance

@classmethod
def other(cls) -> None:
reveal_type(cls.instance) # N: Revealed type is "Self`0"
reveal_type(cls.instance.extra) # N: Revealed type is "builtins.int"

reveal_type(Bar.instance) # N: Revealed type is "__main__.Bar"
[builtins fixtures/classmethod.pyi]

[case testAccessingSelfClassVarInClassMethodTuple]
from typing import Self, ClassVar, Tuple

class C(Tuple[int, str]):
x: Self
y: ClassVar[Self]

@classmethod
def bar(cls) -> None:
reveal_type(cls.y) # N: Revealed type is "Self`0"
@classmethod
def bar_self(self) -> Self:
return reveal_type(self.y) # N: Revealed type is "Self`0"

c: C
reveal_type(c.x) # N: Revealed type is "Tuple[builtins.int, builtins.str, fallback=__main__.C]"
reveal_type(c.y) # N: Revealed type is "Tuple[builtins.int, builtins.str, fallback=__main__.C]"
reveal_type(C.y) # N: Revealed type is "Tuple[builtins.int, builtins.str, fallback=__main__.C]"
C.x # E: Access to generic instance variables via class is ambiguous
[builtins fixtures/classmethod.pyi]

[case testAccessingTypingSelfUnion]
from typing import Self, Union

class C:
x: Self
class D:
x: int
x: Union[C, D]
reveal_type(x.x) # N: Revealed type is "Union[__main__.C, builtins.int]"

0 comments on commit ba5c279

Please sign in to comment.