Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use TypeVar defaults instead of Any when fixing instance types (PEP 696) #16812

Merged
merged 1 commit into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -3017,12 +3017,15 @@ def for_function(callee: CallableType) -> str:
return ""


def wrong_type_arg_count(n: int, act: str, name: str) -> str:
s = f"{n} type arguments"
if n == 0:
s = "no type arguments"
elif n == 1:
s = "1 type argument"
def wrong_type_arg_count(low: int, high: int, act: str, name: str) -> str:
if low == high:
s = f"{low} type arguments"
if low == 0:
s = "no type arguments"
elif low == 1:
s = "1 type argument"
else:
s = f"between {low} and {high} type arguments"
if act == "0":
act = "none"
return f'"{name}" expects {s}, but {act} given'
Expand Down
87 changes: 57 additions & 30 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from mypy import errorcodes as codes, message_registry, nodes
from mypy.errorcodes import ErrorCode
from mypy.expandtype import expand_type
from mypy.messages import MessageBuilder, format_type_bare, quote_type_string, wrong_type_arg_count
from mypy.nodes import (
ARG_NAMED,
Expand Down Expand Up @@ -75,6 +76,7 @@
TypeOfAny,
TypeQuery,
TypeType,
TypeVarId,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
Expand Down Expand Up @@ -1834,14 +1836,14 @@ def get_omitted_any(
return any_type


def fix_type_var_tuple_argument(any_type: Type, t: Instance) -> None:
def fix_type_var_tuple_argument(t: Instance) -> None:
if t.type.has_type_var_tuple_type:
args = list(t.args)
assert t.type.type_var_tuple_prefix is not None
tvt = t.type.defn.type_vars[t.type.type_var_tuple_prefix]
assert isinstance(tvt, TypeVarTupleType)
args[t.type.type_var_tuple_prefix] = UnpackType(
Instance(tvt.tuple_fallback.type, [any_type])
Instance(tvt.tuple_fallback.type, [args[t.type.type_var_tuple_prefix]])
)
t.args = tuple(args)

Expand All @@ -1855,26 +1857,42 @@ def fix_instance(
use_generic_error: bool = False,
unexpanded_type: Type | None = None,
) -> None:
"""Fix a malformed instance by replacing all type arguments with Any.
"""Fix a malformed instance by replacing all type arguments with TypeVar default or Any.

Also emit a suitable error if this is not due to implicit Any's.
"""
if len(t.args) == 0:
if use_generic_error:
fullname: str | None = None
else:
fullname = t.type.fullname
any_type = get_omitted_any(disallow_any, fail, note, t, options, fullname, unexpanded_type)
t.args = (any_type,) * len(t.type.type_vars)
fix_type_var_tuple_argument(any_type, t)
return
# Construct the correct number of type arguments, as
# otherwise the type checker may crash as it expects
# things to be right.
any_type = AnyType(TypeOfAny.from_error)
t.args = tuple(any_type for _ in t.type.type_vars)
fix_type_var_tuple_argument(any_type, t)
t.invalid = True
arg_count = len(t.args)
min_tv_count = sum(not tv.has_default() for tv in t.type.defn.type_vars)
max_tv_count = len(t.type.type_vars)
if arg_count < min_tv_count or arg_count > max_tv_count:
# Don't use existing args if arg_count doesn't match
t.args = ()

args: list[Type] = [*(t.args[:max_tv_count])]
any_type: AnyType | None = None
env: dict[TypeVarId, Type] = {}

for tv, arg in itertools.zip_longest(t.type.defn.type_vars, t.args, fillvalue=None):
if tv is None:
continue
if arg is None:
if tv.has_default():
arg = tv.default
else:
if any_type is None:
fullname = None if use_generic_error else t.type.fullname
any_type = get_omitted_any(
disallow_any, fail, note, t, options, fullname, unexpanded_type
)
arg = any_type
args.append(arg)
env[tv.id] = arg
t.args = tuple(args)
fix_type_var_tuple_argument(t)
if not t.type.has_type_var_tuple_type:
fixed = expand_type(t, env)
assert isinstance(fixed, Instance)
t.args = fixed.args


def instantiate_type_alias(
Expand Down Expand Up @@ -1963,7 +1981,7 @@ def instantiate_type_alias(
if use_standard_error:
# This is used if type alias is an internal representation of another type,
# for example a generic TypedDict or NamedTuple.
msg = wrong_type_arg_count(exp_len, str(act_len), node.name)
msg = wrong_type_arg_count(exp_len, exp_len, str(act_len), node.name)
else:
if node.tvar_tuple_index is not None:
exp_len_str = f"at least {exp_len - 1}"
Expand Down Expand Up @@ -2217,24 +2235,27 @@ def validate_instance(t: Instance, fail: MsgCallback, empty_tuple_index: bool) -
# TODO: is it OK to fill with TypeOfAny.from_error instead of special form?
return False
if t.type.has_type_var_tuple_type:
correct = len(t.args) >= len(t.type.type_vars) - 1
min_tv_count = sum(
not tv.has_default() and not isinstance(tv, TypeVarTupleType)
for tv in t.type.defn.type_vars
)
correct = len(t.args) >= min_tv_count
if any(
isinstance(a, UnpackType) and isinstance(get_proper_type(a.type), Instance)
for a in t.args
):
correct = True
if not correct:
exp_len = f"at least {len(t.type.type_vars) - 1}"
if not t.args:
if not (empty_tuple_index and len(t.type.type_vars) == 1):
# The Any arguments should be set by the caller.
return False
elif not correct:
fail(
f"Bad number of arguments, expected: {exp_len}, given: {len(t.args)}",
f"Bad number of arguments, expected: at least {min_tv_count}, given: {len(t.args)}",
t,
code=codes.TYPE_ARG,
)
return False
elif not t.args:
if not (empty_tuple_index and len(t.type.type_vars) == 1):
# The Any arguments should be set by the caller.
return False
else:
# We also need to check if we are not performing a type variable tuple split.
unpack = find_unpack_in_list(t.args)
Expand All @@ -2254,15 +2275,21 @@ def validate_instance(t: Instance, fail: MsgCallback, empty_tuple_index: bool) -
elif any(isinstance(a, UnpackType) for a in t.args):
# A variadic unpack in fixed size instance (fixed unpacks must be flattened by the caller)
fail(message_registry.INVALID_UNPACK_POSITION, t, code=codes.VALID_TYPE)
t.args = ()
return False
elif len(t.args) != len(t.type.type_vars):
# Invalid number of type parameters.
if t.args:
arg_count = len(t.args)
min_tv_count = sum(not tv.has_default() for tv in t.type.defn.type_vars)
max_tv_count = len(t.type.type_vars)
if arg_count and (arg_count < min_tv_count or arg_count > max_tv_count):
fail(
wrong_type_arg_count(len(t.type.type_vars), str(len(t.args)), t.type.name),
wrong_type_arg_count(min_tv_count, max_tv_count, str(arg_count), t.type.name),
t,
code=codes.TYPE_ARG,
)
t.args = ()
t.invalid = True
return False
return True

Expand Down
123 changes: 123 additions & 0 deletions test-data/unit/check-typevar-defaults.test
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,126 @@ def func_c1(x: Union[int, Callable[[Unpack[Ts1]], None]]) -> Tuple[Unpack[Ts1]]:
# reveal_type(func_c1(callback1)) # Revealed type is "builtins.tuple[str]" # TODO
# reveal_type(func_c1(2)) # Revealed type is "builtins.tuple[builtins.int, builtins.str]" # TODO
[builtins fixtures/tuple.pyi]

[case testTypeVarDefaultsClass1]
from typing import Generic, TypeVar

T1 = TypeVar("T1")
T2 = TypeVar("T2", default=int)
T3 = TypeVar("T3", default=str)

class ClassA1(Generic[T2, T3]): ...

def func_a1(
a: ClassA1,
b: ClassA1[float],
c: ClassA1[float, float],
d: ClassA1[float, float, float], # E: "ClassA1" expects between 0 and 2 type arguments, but 3 given
) -> None:
reveal_type(a) # N: Revealed type is "__main__.ClassA1[builtins.int, builtins.str]"
reveal_type(b) # N: Revealed type is "__main__.ClassA1[builtins.float, builtins.str]"
reveal_type(c) # N: Revealed type is "__main__.ClassA1[builtins.float, builtins.float]"
reveal_type(d) # N: Revealed type is "__main__.ClassA1[builtins.int, builtins.str]"

class ClassA2(Generic[T1, T2, T3]): ...

def func_a2(
a: ClassA2,
b: ClassA2[float],
c: ClassA2[float, float],
d: ClassA2[float, float, float],
e: ClassA2[float, float, float, float], # E: "ClassA2" expects between 1 and 3 type arguments, but 4 given
) -> None:
reveal_type(a) # N: Revealed type is "__main__.ClassA2[Any, builtins.int, builtins.str]"
reveal_type(b) # N: Revealed type is "__main__.ClassA2[builtins.float, builtins.int, builtins.str]"
reveal_type(c) # N: Revealed type is "__main__.ClassA2[builtins.float, builtins.float, builtins.str]"
reveal_type(d) # N: Revealed type is "__main__.ClassA2[builtins.float, builtins.float, builtins.float]"
reveal_type(e) # N: Revealed type is "__main__.ClassA2[Any, builtins.int, builtins.str]"

[case testTypeVarDefaultsClass2]
from typing import Generic, ParamSpec

P1 = ParamSpec("P1")
P2 = ParamSpec("P2", default=[int, str])
P3 = ParamSpec("P3", default=...)

class ClassB1(Generic[P2, P3]): ...

def func_b1(
a: ClassB1,
b: ClassB1[[float]],
c: ClassB1[[float], [float]],
d: ClassB1[[float], [float], [float]], # E: "ClassB1" expects between 0 and 2 type arguments, but 3 given
) -> None:
reveal_type(a) # N: Revealed type is "__main__.ClassB1[[builtins.int, builtins.str], ...]"
reveal_type(b) # N: Revealed type is "__main__.ClassB1[[builtins.float], ...]"
reveal_type(c) # N: Revealed type is "__main__.ClassB1[[builtins.float], [builtins.float]]"
reveal_type(d) # N: Revealed type is "__main__.ClassB1[[builtins.int, builtins.str], ...]"

class ClassB2(Generic[P1, P2]): ...

def func_b2(
a: ClassB2,
b: ClassB2[[float]],
c: ClassB2[[float], [float]],
d: ClassB2[[float], [float], [float]], # E: "ClassB2" expects between 1 and 2 type arguments, but 3 given
) -> None:
reveal_type(a) # N: Revealed type is "__main__.ClassB2[Any, [builtins.int, builtins.str]]"
reveal_type(b) # N: Revealed type is "__main__.ClassB2[[builtins.float], [builtins.int, builtins.str]]"
reveal_type(c) # N: Revealed type is "__main__.ClassB2[[builtins.float], [builtins.float]]"
reveal_type(d) # N: Revealed type is "__main__.ClassB2[Any, [builtins.int, builtins.str]]"

[case testTypeVarDefaultsClass3]
from typing import Generic, Tuple, TypeVar
from typing_extensions import TypeVarTuple, Unpack

T1 = TypeVar("T1")
T3 = TypeVar("T3", default=str)

Ts1 = TypeVarTuple("Ts1")
Ts2 = TypeVarTuple("Ts2", default=Unpack[Tuple[int, str]])
Ts3 = TypeVarTuple("Ts3", default=Unpack[Tuple[float, ...]])
Ts4 = TypeVarTuple("Ts4", default=Unpack[Tuple[()]])

class ClassC1(Generic[Unpack[Ts2]]): ...

def func_c1(
a: ClassC1,
b: ClassC1[float],
) -> None:
# reveal_type(a) # Revealed type is "__main__.ClassC1[builtins.int, builtins.str]" # TODO
reveal_type(b) # N: Revealed type is "__main__.ClassC1[builtins.float]"

class ClassC2(Generic[T3, Unpack[Ts3]]): ...

def func_c2(
a: ClassC2,
b: ClassC2[int],
c: ClassC2[int, Unpack[Tuple[()]]],
) -> None:
reveal_type(a) # N: Revealed type is "__main__.ClassC2[builtins.str, Unpack[builtins.tuple[builtins.float, ...]]]"
# reveal_type(b) # Revealed type is "__main__.ClassC2[builtins.int, Unpack[builtins.tuple[builtins.float, ...]]]" # TODO
reveal_type(c) # N: Revealed type is "__main__.ClassC2[builtins.int]"

class ClassC3(Generic[T3, Unpack[Ts4]]): ...

def func_c3(
a: ClassC3,
b: ClassC3[int],
c: ClassC3[int, Unpack[Tuple[float]]]
) -> None:
# reveal_type(a) # Revealed type is "__main__.ClassC3[builtins.str]" # TODO
reveal_type(b) # N: Revealed type is "__main__.ClassC3[builtins.int]"
reveal_type(c) # N: Revealed type is "__main__.ClassC3[builtins.int, builtins.float]"

class ClassC4(Generic[T1, Unpack[Ts1], T3]): ...

def func_c4(
a: ClassC4,
b: ClassC4[int],
c: ClassC4[int, float],
) -> None:
reveal_type(a) # N: Revealed type is "__main__.ClassC4[Any, Unpack[builtins.tuple[Any, ...]], builtins.str]"
# reveal_type(b) # Revealed type is "__main__.ClassC4[builtins.int, builtins.str]" # TODO
reveal_type(c) # N: Revealed type is "__main__.ClassC4[builtins.int, builtins.float]"
[builtins fixtures/tuple.pyi]