Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streamline some elements of variadic types support #15924

Merged
merged 8 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4665,10 +4665,7 @@ def analyze_iterable_item_type(self, expr: Expression) -> tuple[Type, Type]:
isinstance(iterable, TupleType)
and iterable.partial_fallback.type.fullname == "builtins.tuple"
):
joined: Type = UninhabitedType()
for item in iterable.items:
joined = join_types(joined, item)
return iterator, joined
return iterator, tuple_fallback(iterable).args[0]
else:
# Non-tuple iterable.
return iterator, echk.check_method_call_by_name("__next__", iterator, [], [], expr)[0]
Expand Down
11 changes: 7 additions & 4 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@
UninhabitedType,
UnionType,
UnpackType,
flatten_nested_tuples,
find_unpack_in_list,
flatten_nested_unions,
get_proper_type,
get_proper_types,
Expand All @@ -185,7 +185,6 @@
)
from mypy.typestate import type_state
from mypy.typevars import fill_typevars
from mypy.typevartuples import find_unpack_in_list
from mypy.util import split_module_names
from mypy.visitor import ExpressionVisitor

Expand Down Expand Up @@ -1600,7 +1599,7 @@ def check_callable_call(
See the docstring of check_call for more information.
"""
# Always unpack **kwargs before checking a call.
callee = callee.with_unpacked_kwargs()
callee = callee.with_unpacked_kwargs().with_normalized_var_args()
if callable_name is None and callee.name:
callable_name = callee.name
ret_type = get_proper_type(callee.ret_type)
Expand Down Expand Up @@ -2409,7 +2408,12 @@ def check_argument_types(
+ unpacked_type.items[inner_unpack_index + 1 :]
)
callee_arg_kinds = [ARG_POS] * len(actuals)
elif isinstance(unpacked_type, TypeVarTupleType):
callee_arg_types = [orig_callee_arg_type]
callee_arg_kinds = [ARG_STAR]
else:
# TODO: Any and <nothing> can appear in Unpack (as a result of user error),
# fail gracefully here and elsewhere (and/or normalize them away).
assert isinstance(unpacked_type, Instance)
assert unpacked_type.type.fullname == "builtins.tuple"
callee_arg_types = [unpacked_type.args[0]] * len(actuals)
Expand Down Expand Up @@ -4451,7 +4455,6 @@ class C(Generic[T, Unpack[Ts]]): ...

prefix = next(i for (i, v) in enumerate(vars) if isinstance(v, TypeVarTupleType))
suffix = len(vars) - prefix - 1
args = flatten_nested_tuples(args)
if len(args) < len(vars) - 1:
self.msg.incompatible_type_application(len(vars), len(args), ctx)
return [AnyType(TypeOfAny.from_error)] * len(vars)
Expand Down
46 changes: 37 additions & 9 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
UninhabitedType,
UnionType,
UnpackType,
find_unpack_in_list,
get_proper_type,
has_recursive_types,
has_type_vars,
Expand All @@ -57,7 +58,7 @@
)
from mypy.types_utils import is_union_with_any
from mypy.typestate import type_state
from mypy.typevartuples import extract_unpack, find_unpack_in_list, split_with_mapped_and_template
from mypy.typevartuples import extract_unpack, split_with_mapped_and_template

if TYPE_CHECKING:
from mypy.infer import ArgumentInferContext
Expand Down Expand Up @@ -155,16 +156,33 @@ def infer_constraints_for_callable(
# not to hold we can always handle the prefixes too.
inner_unpack = unpacked_type.items[0]
assert isinstance(inner_unpack, UnpackType)
inner_unpacked_type = inner_unpack.type
assert isinstance(inner_unpacked_type, TypeVarTupleType)
inner_unpacked_type = get_proper_type(inner_unpack.type)
suffix_len = len(unpacked_type.items) - 1
constraints.append(
Constraint(
inner_unpacked_type,
SUPERTYPE_OF,
TupleType(actual_types[:-suffix_len], inner_unpacked_type.tuple_fallback),
if isinstance(inner_unpacked_type, TypeVarTupleType):
# Variadic item can be either *Ts...
constraints.append(
Constraint(
inner_unpacked_type,
SUPERTYPE_OF,
TupleType(
actual_types[:-suffix_len], inner_unpacked_type.tuple_fallback
),
)
)
)
else:
# ...or it can be a homogeneous tuple.
assert (
isinstance(inner_unpacked_type, Instance)
and inner_unpacked_type.type.fullname == "builtins.tuple"
)
for at in actual_types[:-suffix_len]:
constraints.extend(
infer_constraints(inner_unpacked_type.args[0], at, SUPERTYPE_OF)
)
# Now handle the suffix (if any).
if suffix_len:
for tt, at in zip(unpacked_type.items[1:], actual_types[-suffix_len:]):
constraints.extend(infer_constraints(tt, at, SUPERTYPE_OF))
else:
assert False, "mypy bug: unhandled constraint inference case"
else:
Expand Down Expand Up @@ -863,6 +881,16 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
and self.direction == SUPERTYPE_OF
):
for item in actual.items:
if isinstance(item, UnpackType):
unpacked = get_proper_type(item.type)
if isinstance(unpacked, TypeVarType):
# Cannot infer anything for T from [T, ...] <: *Ts
continue
assert (
isinstance(unpacked, Instance)
and unpacked.type.fullname == "builtins.tuple"
)
item = unpacked.args[0]
cb = infer_constraints(template.args[0], item, SUPERTYPE_OF)
res.extend(cb)
return res
Expand Down
111 changes: 18 additions & 93 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Final, Iterable, Mapping, Sequence, TypeVar, cast, overload

from mypy.nodes import ARG_POS, ARG_STAR, ArgKind, Var
from mypy.nodes import ARG_STAR, Var
from mypy.state import state
from mypy.types import (
ANY_STRATEGY,
Expand Down Expand Up @@ -35,12 +35,11 @@
UninhabitedType,
UnionType,
UnpackType,
flatten_nested_tuples,
flatten_nested_unions,
get_proper_type,
split_with_prefix_and_suffix,
)
from mypy.typevartuples import find_unpack_in_list, split_with_instance
from mypy.typevartuples import split_with_instance

# Solving the import cycle:
import mypy.type_visitor # ruff: isort: skip
Expand Down Expand Up @@ -294,101 +293,30 @@ def expand_unpack(self, t: UnpackType) -> list[Type] | AnyType | UninhabitedType
def visit_parameters(self, t: Parameters) -> Type:
return t.copy_modified(arg_types=self.expand_types(t.arg_types))

# TODO: can we simplify this method? It is too long.
def interpolate_args_for_unpack(
self, t: CallableType, var_arg: UnpackType
) -> tuple[list[str | None], list[ArgKind], list[Type]]:
def interpolate_args_for_unpack(self, t: CallableType, var_arg: UnpackType) -> list[Type]:
star_index = t.arg_kinds.index(ARG_STAR)
prefix = self.expand_types(t.arg_types[:star_index])
suffix = self.expand_types(t.arg_types[star_index + 1 :])

var_arg_type = get_proper_type(var_arg.type)
# We have something like Unpack[Tuple[Unpack[Ts], X1, X2]]
if isinstance(var_arg_type, TupleType):
expanded_tuple = var_arg_type.accept(self)
assert isinstance(expanded_tuple, ProperType) and isinstance(expanded_tuple, TupleType)
expanded_items = expanded_tuple.items
fallback = var_arg_type.partial_fallback
else:
# We have plain Unpack[Ts]
assert isinstance(var_arg_type, TypeVarTupleType)
fallback = var_arg_type.tuple_fallback
expanded_items_res = self.expand_unpack(var_arg)
if isinstance(expanded_items_res, list):
expanded_items = expanded_items_res
else:
# We got Any or <nothing>
arg_types = (
t.arg_types[:star_index] + [expanded_items_res] + t.arg_types[star_index + 1 :]
)
return t.arg_names, t.arg_kinds, arg_types

expanded_unpack_index = find_unpack_in_list(expanded_items)
# This is the case where we just have Unpack[Tuple[X1, X2, X3]]
# (for example if either the tuple had no unpacks, or the unpack in the
# tuple got fully expanded to something with fixed length)
if expanded_unpack_index is None:
arg_names = (
t.arg_names[:star_index]
+ [None] * len(expanded_items)
+ t.arg_names[star_index + 1 :]
)
arg_kinds = (
t.arg_kinds[:star_index]
+ [ARG_POS] * len(expanded_items)
+ t.arg_kinds[star_index + 1 :]
)
arg_types = (
self.expand_types(t.arg_types[:star_index])
+ expanded_items
+ self.expand_types(t.arg_types[star_index + 1 :])
)
else:
# If Unpack[Ts] simplest form still has an unpack or is a
# homogenous tuple, then only the prefix can be represented as
# positional arguments, and we pass Tuple[Unpack[Ts-1], Y1, Y2]
# as the star arg, for example.
expanded_unpack = expanded_items[expanded_unpack_index]
assert isinstance(expanded_unpack, UnpackType)

# Extract the TypeVarTuple, so we can get a tuple fallback from it.
expanded_unpacked_tvt = expanded_unpack.type
if isinstance(expanded_unpacked_tvt, TypeVarTupleType):
fallback = expanded_unpacked_tvt.tuple_fallback
else:
# This can happen when tuple[Any, ...] is used to "patch" a variadic
# generic type without type arguments provided, or when substitution is
# homogeneous tuple.
assert isinstance(expanded_unpacked_tvt, ProperType)
assert isinstance(expanded_unpacked_tvt, Instance)
assert expanded_unpacked_tvt.type.fullname == "builtins.tuple"
fallback = expanded_unpacked_tvt

prefix_len = expanded_unpack_index
arg_names = t.arg_names[:star_index] + [None] * prefix_len + t.arg_names[star_index:]
arg_kinds = (
t.arg_kinds[:star_index] + [ARG_POS] * prefix_len + t.arg_kinds[star_index:]
)
if (
len(expanded_items) == 1
and isinstance(expanded_unpack.type, ProperType)
and isinstance(expanded_unpack.type, Instance)
):
assert expanded_unpack.type.type.fullname == "builtins.tuple"
# Normalize *args: *tuple[X, ...] -> *args: X
arg_types = (
self.expand_types(t.arg_types[:star_index])
+ [expanded_unpack.type.args[0]]
+ self.expand_types(t.arg_types[star_index + 1 :])
)
else:
arg_types = (
self.expand_types(t.arg_types[:star_index])
+ expanded_items[:prefix_len]
# Constructing the Unpack containing the tuple without the prefix.
+ [
UnpackType(TupleType(expanded_items[prefix_len:], fallback))
if len(expanded_items) - prefix_len > 1
else expanded_items[prefix_len]
]
+ self.expand_types(t.arg_types[star_index + 1 :])
)
return arg_names, arg_kinds, arg_types
return prefix + [expanded_items_res] + suffix
new_unpack = UnpackType(TupleType(expanded_items, fallback))
return prefix + [new_unpack] + suffix

def visit_callable_type(self, t: CallableType) -> CallableType:
param_spec = t.param_spec()
Expand Down Expand Up @@ -427,20 +355,20 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
)

var_arg = t.var_arg()
needs_normalization = False
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
arg_names, arg_kinds, arg_types = self.interpolate_args_for_unpack(t, var_arg.typ)
needs_normalization = True
arg_types = self.interpolate_args_for_unpack(t, var_arg.typ)
else:
arg_names = t.arg_names
arg_kinds = t.arg_kinds
arg_types = self.expand_types(t.arg_types)

return t.copy_modified(
expanded = t.copy_modified(
arg_types=arg_types,
arg_names=arg_names,
arg_kinds=arg_kinds,
ret_type=t.ret_type.accept(self),
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
)
if needs_normalization:
return expanded.with_normalized_var_args()
return expanded

def visit_overloaded(self, t: Overloaded) -> Type:
items: list[CallableType] = []
Expand All @@ -460,9 +388,6 @@ def expand_types_with_unpack(
indicates use of Any or some error occurred earlier. In this case callers should
simply propagate the resulting type.
"""
# TODO: this will cause a crash on aliases like A = Tuple[int, Unpack[A]].
# Although it is unlikely anyone will write this, we should fail gracefully.
typs = flatten_nested_tuples(typs)
items: list[Type] = []
for item in typs:
if isinstance(item, UnpackType) and isinstance(item.type, TypeVarTupleType):
Expand Down
3 changes: 2 additions & 1 deletion mypy/message_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
IMPLICIT_GENERIC_ANY_BUILTIN: Final = (
'Implicit generic "Any". Use "{}" and specify generic parameters'
)
INVALID_UNPACK = "{} cannot be unpacked (must be tuple or TypeVarTuple)"
INVALID_UNPACK: Final = "{} cannot be unpacked (must be tuple or TypeVarTuple)"
INVALID_UNPACK_POSITION: Final = "Unpack is only valid in a variadic position"

# TypeVar
INCOMPATIBLE_TYPEVAR_VALUE: Final = 'Value of type variable "{}" of {} cannot be {}'
Expand Down
2 changes: 1 addition & 1 deletion mypy/mixedtraverser.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def visit_class_def(self, o: ClassDef) -> None:
def visit_type_alias_expr(self, o: TypeAliasExpr) -> None:
super().visit_type_alias_expr(o)
self.in_type_alias_expr = True
o.type.accept(self)
o.node.target.accept(self)
self.in_type_alias_expr = False

def visit_type_var_expr(self, o: TypeVarExpr) -> None:
Expand Down
17 changes: 2 additions & 15 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2625,27 +2625,14 @@ def deserialize(cls, data: JsonDict) -> TypeVarTupleExpr:
class TypeAliasExpr(Expression):
"""Type alias expression (rvalue)."""

__slots__ = ("type", "tvars", "no_args", "node")
__slots__ = ("node",)

__match_args__ = ("type", "tvars", "no_args", "node")
__match_args__ = ("node",)

# The target type.
type: mypy.types.Type
# Names of type variables used to define the alias
tvars: list[str]
# Whether this alias was defined in bare form. Used to distinguish
# between
# A = List
# and
# A = List[Any]
no_args: bool
node: TypeAlias

def __init__(self, node: TypeAlias) -> None:
super().__init__()
self.type = node.target
self.tvars = [v.name for v in node.alias_tvars]
self.no_args = node.no_args
self.node = node

def accept(self, visitor: ExpressionVisitor[T]) -> T:
Expand Down
Loading