Skip to content

Commit

Permalink
Add support for generic attrs converters
Browse files Browse the repository at this point in the history
  • Loading branch information
chadrik committed Mar 16, 2023
1 parent a6b5b1e commit f33f58c
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 3 deletions.
20 changes: 17 additions & 3 deletions mypy/plugins/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import mypy.plugin # To avoid circular imports.
from mypy.checker import TypeChecker
from mypy.errorcodes import LITERAL_REQ
from mypy.expandtype import expand_type
from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type
from mypy.messages import format_type_bare
from mypy.nodes import (
Expand Down Expand Up @@ -49,7 +50,7 @@
deserialize_and_fixup_type,
)
from mypy.server.trigger import make_wildcard_trigger
from mypy.typeops import make_simplified_union, map_type_from_supertype
from mypy.typeops import get_type_vars, make_simplified_union, map_type_from_supertype
from mypy.types import (
AnyType,
CallableType,
Expand All @@ -61,6 +62,7 @@
TupleType,
Type,
TypeOfAny,
TypeVarId,
TypeVarType,
UnionType,
get_proper_type,
Expand All @@ -85,8 +87,9 @@
class Converter:
"""Holds information about a `converter=` argument"""

def __init__(self, init_type: Type | None = None) -> None:
def __init__(self, init_type: Type | None = None, ret_type: Type | None = None) -> None:
self.init_type = init_type
self.ret_type = ret_type


class Attribute:
Expand Down Expand Up @@ -115,11 +118,20 @@ def __init__(
def argument(self, ctx: mypy.plugin.ClassDefContext) -> Argument:
"""Return this attribute as an argument to __init__."""
assert self.init

init_type: Type | None = None
if self.converter:
if self.converter.init_type:
init_type = self.converter.init_type
if init_type and self.converter.ret_type:
# The converter return type should be the same type as the attribute type.
# Copy type vars from attr type to converter.
converter_vars = get_type_vars(self.converter.ret_type)
init_vars = get_type_vars(self.init_type)
if converter_vars and len(converter_vars) == len(init_vars):
variables = {
binder.id: arg for binder, arg in zip(converter_vars, init_vars)
}
init_type = expand_type(init_type, variables)
else:
ctx.api.fail("Cannot determine __init__ type from converter", self.context)
init_type = AnyType(TypeOfAny.from_error)
Expand Down Expand Up @@ -671,6 +683,8 @@ def _parse_converter(
converter_type = get_proper_type(converter_type)
if isinstance(converter_type, CallableType) and converter_type.arg_types:
converter_info.init_type = converter_type.arg_types[0]
if not is_attr_converters_optional:
converter_info.ret_type = converter_type.ret_type
elif isinstance(converter_type, Overloaded):
types: list[Type] = []
for item in converter_type.items:
Expand Down
42 changes: 42 additions & 0 deletions test-data/unit/check-attr.test
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,48 @@ A([1], '2') # E: Cannot infer type argument 1 of "A"

[builtins fixtures/list.pyi]

[case testAttrsGenericWithConverter]
from typing import TypeVar, Generic, List, Iterable, Iterator
import attr
T = TypeVar('T')

def int_gen() -> Iterator[int]:
yield 1

def list_converter(x: Iterable[T]) -> List[T]:
return list(x)

@attr.s(auto_attribs=True)
class A(Generic[T]):
x: List[T] = attr.ib(converter=list_converter)
y: T = attr.ib()
def foo(self) -> List[T]:
return [self.y]
def bar(self) -> T:
return self.x[0]
def problem(self) -> T:
return self.x # E: Incompatible return value type (got "List[T]", expected "T")
reveal_type(A) # N: Revealed type is "def [T] (x: typing.Iterable[T`1], y: T`1) -> __main__.A[T`1]"
a1 = A([1], 2)
reveal_type(a1) # N: Revealed type is "__main__.A[builtins.int]"
reveal_type(a1.x) # N: Revealed type is "builtins.list[builtins.int]"
reveal_type(a1.y) # N: Revealed type is "builtins.int"

a2 = A(int_gen(), 2)
reveal_type(a2) # N: Revealed type is "__main__.A[builtins.int]"
reveal_type(a2.x) # N: Revealed type is "builtins.list[builtins.int]"
reveal_type(a2.y) # N: Revealed type is "builtins.int"

# Leaving this as a sanity check
class B(Generic[T]):
def __init__(self, x: Iterable[T], y: T) -> None:
pass

B(['str'], 7)
B([1], '2')

[builtins fixtures/list.pyi]


[case testAttrsUntypedGenericInheritance]
from typing import Generic, TypeVar
Expand Down

0 comments on commit f33f58c

Please sign in to comment.