Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add interactions between Literal and Final #6081

Merged
7 changes: 5 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1810,8 +1810,11 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type
self.check_indexed_assignment(index_lvalue, rvalue, lvalue)

if inferred:
self.infer_variable_type(inferred, lvalue, self.expr_checker.accept(rvalue),
rvalue)
rvalue_type = self.expr_checker.accept(
rvalue,
in_final_declaration=inferred.is_final,
)
self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue)

def check_compatibility_all_supers(self, lvalue: RefExpr, lvalue_type: Optional[Type],
rvalue: Expression) -> bool:
Expand Down
95 changes: 69 additions & 26 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from mypy.types import (
Type, AnyType, CallableType, Overloaded, NoneTyp, TypeVarDef,
TupleType, TypedDictType, Instance, TypeVarType, ErasedType, UnionType,
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType,
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue,
true_only, false_only, is_named_instance, function_type, callable_type, FunctionLike,
StarType, is_optional, remove_optional, is_generic_instance
)
Expand Down Expand Up @@ -139,6 +139,16 @@ def __init__(self,
self.msg = msg
self.plugin = plugin
self.type_context = [None]

# Set to 'True' whenever we are checking the expression in some 'Final' declaration.
# For example, if we're checking the "3" in a statement like "var: Final = 3".
#
# This flag changes the type that eventually gets inferred for "var". Instead of
# inferring *just* a 'builtins.int' instance, we infer an instance that keeps track
# of the underlying literal value. See the comments in Instance's constructors for
# more details.
self.in_final_declaration = False
Michael0x2a marked this conversation as resolved.
Show resolved Hide resolved

# Temporary overrides for expression types. This is currently
# used by the union math in overloads.
# TODO: refactor this to use a pattern similar to one in
Expand Down Expand Up @@ -210,10 +220,12 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:

def analyze_var_ref(self, var: Var, context: Context) -> Type:
if var.type:
if is_literal_type_like(self.type_context[-1]) and var.name() in {'True', 'False'}:
return LiteralType(var.name() == 'True', self.named_type('builtins.bool'))
else:
return var.type
if isinstance(var.type, Instance):
if self.is_literal_context() and var.type.final_value is not None:
return var.type.final_value
if var.name() in {'True', 'False'}:
return self.infer_literal_expr_type(var.name() == 'True', 'builtins.bool')
return var.type
else:
if not var.is_ready and self.chk.in_checked_function():
self.chk.handle_cannot_determine_type(var.name(), context)
Expand Down Expand Up @@ -691,7 +703,8 @@ def check_call(self,
elif isinstance(callee, Instance):
call_function = analyze_member_access('__call__', callee, context,
False, False, False, self.msg,
original_type=callee, chk=self.chk)
original_type=callee, chk=self.chk,
in_literal_context=self.is_literal_context())
return self.check_call(call_function, args, arg_kinds, context, arg_names,
callable_node, arg_messages)
elif isinstance(callee, TypeVarType):
Expand Down Expand Up @@ -1755,7 +1768,8 @@ def analyze_ordinary_member_access(self, e: MemberExpr,
original_type = self.accept(e.expr)
member_type = analyze_member_access(
e.name, original_type, e, is_lvalue, False, False,
self.msg, original_type=original_type, chk=self.chk)
self.msg, original_type=original_type, chk=self.chk,
in_literal_context=self.is_literal_context())
return member_type

def analyze_external_member_access(self, member: str, base_type: Type,
Expand All @@ -1765,35 +1779,57 @@ def analyze_external_member_access(self, member: str, base_type: Type,
"""
# TODO remove; no private definitions in mypy
return analyze_member_access(member, base_type, context, False, False, False,
self.msg, original_type=base_type, chk=self.chk)
self.msg, original_type=base_type, chk=self.chk,
in_literal_context=self.is_literal_context())

def is_literal_context(self) -> bool:
return is_literal_type_like(self.type_context[-1])
Michael0x2a marked this conversation as resolved.
Show resolved Hide resolved

def infer_literal_expr_type(self, value: LiteralValue, fallback_name: str) -> Type:
"""Analyzes the given literal expression and determines if we should be
inferring an Instance type, a Literal[...] type, or an Instance that
remembers the original literal. We...

1. ...Infer a normal Instance in most circumstances.

2. ...Infer a Literal[...] if we're in a literal context. For example, if we
were analyzing the "3" in "foo(3)" where "foo" has a signature of
"def foo(Literal[3]) -> None", we'd want to infer that the "3" has a
type of Literal[3] instead of Instance.

3. ...Infer an Instance that remembers the original Literal if we're declaring
a Final variable with an inferred type -- for example, "bar" in "bar: Final = 3"
would be assigned an Instance that remembers it originated from a '3'. See
the comments in Instance's constructor for more details.
"""
typ = self.named_type(fallback_name)
if self.is_literal_context():
return LiteralType(value=value, fallback=typ)
elif self.in_final_declaration:
return typ.copy_modified(final_value=LiteralType(
value=value,
fallback=typ,
line=typ.line,
column=typ.column,
))
else:
return typ

def visit_int_expr(self, e: IntExpr) -> Type:
"""Type check an integer literal (trivial)."""
typ = self.named_type('builtins.int')
if is_literal_type_like(self.type_context[-1]):
return LiteralType(value=e.value, fallback=typ)
return typ
return self.infer_literal_expr_type(e.value, 'builtins.int')

def visit_str_expr(self, e: StrExpr) -> Type:
"""Type check a string literal (trivial)."""
typ = self.named_type('builtins.str')
if is_literal_type_like(self.type_context[-1]):
return LiteralType(value=e.value, fallback=typ)
return typ
return self.infer_literal_expr_type(e.value, 'builtins.str')

def visit_bytes_expr(self, e: BytesExpr) -> Type:
"""Type check a bytes literal (trivial)."""
typ = self.named_type('builtins.bytes')
if is_literal_type_like(self.type_context[-1]):
return LiteralType(value=e.value, fallback=typ)
return typ
return self.infer_literal_expr_type(e.value, 'builtins.bytes')

def visit_unicode_expr(self, e: UnicodeExpr) -> Type:
"""Type check a unicode literal (trivial)."""
typ = self.named_type('builtins.unicode')
if is_literal_type_like(self.type_context[-1]):
return LiteralType(value=e.value, fallback=typ)
return typ
return self.infer_literal_expr_type(e.value, 'builtins.unicode')

def visit_float_expr(self, e: FloatExpr) -> Type:
"""Type check a float literal (trivial)."""
Expand Down Expand Up @@ -1930,7 +1966,8 @@ def check_method_call_by_name(self,
"""
local_errors = local_errors or self.msg
method_type = analyze_member_access(method, base_type, context, False, False, True,
local_errors, original_type=base_type, chk=self.chk)
local_errors, original_type=base_type, chk=self.chk,
in_literal_context=self.is_literal_context())
return self.check_method_call(
method, base_type, method_type, args, arg_kinds, context, local_errors)

Expand Down Expand Up @@ -1994,6 +2031,7 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]:
context=context,
msg=local_errors,
chk=self.chk,
in_literal_context=self.is_literal_context()
)
if local_errors.is_errors():
return None
Expand Down Expand Up @@ -2950,7 +2988,8 @@ def analyze_super(self, e: SuperExpr, is_lvalue: bool) -> Type:
override_info=base,
context=e,
msg=self.msg,
chk=self.chk)
chk=self.chk,
in_literal_context=self.is_literal_context())
assert False, 'unreachable'
else:
# Invalid super. This has been reported by the semantic analyzer.
Expand Down Expand Up @@ -3117,13 +3156,16 @@ def accept(self,
type_context: Optional[Type] = None,
allow_none_return: bool = False,
always_allow_any: bool = False,
in_final_declaration: bool = False,
) -> Type:
"""Type check a node in the given type context. If allow_none_return
is True and this expression is a call, allow it to return None. This
applies only to this expression and not any subexpressions.
"""
if node in self.type_overrides:
return self.type_overrides[node]
old_in_final_declaration = self.in_final_declaration
self.in_final_declaration = in_final_declaration
self.type_context.append(type_context)
try:
if allow_none_return and isinstance(node, CallExpr):
Expand All @@ -3136,6 +3178,7 @@ def accept(self,
report_internal_error(err, self.chk.errors.file,
node.line, self.chk.errors, self.chk.options)
self.type_context.pop()
self.in_final_declaration = old_in_final_declaration
assert typ is not None
self.chk.store_type(node, typ)

Expand Down
9 changes: 7 additions & 2 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def analyze_member_access(name: str,
msg: MessageBuilder, *,
original_type: Type,
chk: 'mypy.checker.TypeChecker',
override_info: Optional[TypeInfo] = None) -> Type:
override_info: Optional[TypeInfo] = None,
in_literal_context: bool = False) -> Type:
"""Return the type of attribute 'name' of 'typ'.

The actual implementation is in '_analyze_member_access' and this docstring
Expand All @@ -96,7 +97,11 @@ def analyze_member_access(name: str,
context,
msg,
chk=chk)
return _analyze_member_access(name, typ, mx, override_info)
result = _analyze_member_access(name, typ, mx, override_info)
if in_literal_context and isinstance(result, Instance) and result.final_value is not None:
return result.final_value
else:
return result


def _analyze_member_access(name: str,
Expand Down
2 changes: 2 additions & 0 deletions mypy/fixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ def visit_instance(self, inst: Instance) -> None:
base.accept(self)
for a in inst.args:
a.accept(self)
if inst.final_value is not None:
inst.final_value.accept(self)

def visit_any(self, o: Any) -> None:
pass # Nothing to descend into.
Expand Down
3 changes: 2 additions & 1 deletion mypy/sametypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def visit_deleted_type(self, left: DeletedType) -> bool:
def visit_instance(self, left: Instance) -> bool:
return (isinstance(self.right, Instance) and
left.type == self.right.type and
is_same_types(left.args, self.right.args))
is_same_types(left.args, self.right.args) and
left.final_value == self.right.final_value)

def visit_type_var(self, left: TypeVarType) -> bool:
return (isinstance(self.right, TypeVarType) and
Expand Down
37 changes: 27 additions & 10 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from mypy.messages import CANNOT_ASSIGN_TO_TYPE, MessageBuilder
from mypy.types import (
FunctionLike, UnboundType, TypeVarDef, TupleType, UnionType, StarType, function_type,
CallableType, Overloaded, Instance, Type, AnyType,
CallableType, Overloaded, Instance, Type, AnyType, LiteralType, LiteralValue,
TypeTranslator, TypeOfAny, TypeType, NoneTyp,
)
from mypy.nodes import implicit_module_attrs
Expand Down Expand Up @@ -1756,9 +1756,9 @@ def final_cb(keep_final: bool) -> None:
self.type and self.type.is_protocol and not self.is_func_scope()):
self.fail('All protocol members must have explicitly declared types', s)
# Set the type if the rvalue is a simple literal (even if the above error occurred).
if len(s.lvalues) == 1 and isinstance(s.lvalues[0], NameExpr):
if len(s.lvalues) == 1 and isinstance(s.lvalues[0], RefExpr):
if s.lvalues[0].is_inferred_def:
s.type = self.analyze_simple_literal_type(s.rvalue)
s.type = self.analyze_simple_literal_type(s.rvalue, s.is_final_def)
if s.type:
# Store type into nodes.
for lvalue in s.lvalues:
Expand Down Expand Up @@ -1896,8 +1896,10 @@ def unbox_literal(self, e: Expression) -> Optional[Union[int, float, bool, str]]
return True if e.name == 'True' else False
return None

def analyze_simple_literal_type(self, rvalue: Expression) -> Optional[Type]:
"""Return builtins.int if rvalue is an int literal, etc."""
def analyze_simple_literal_type(self, rvalue: Expression, is_final: bool) -> Optional[Type]:
"""Return builtins.int if rvalue is an int literal, etc.

If this is a 'Final' context, we return "Literal[...]" instead."""
if self.options.semantic_analysis_only or self.function_stack:
# Skip this if we're only doing the semantic analysis pass.
# This is mostly to avoid breaking unit tests.
Expand All @@ -1906,16 +1908,31 @@ def analyze_simple_literal_type(self, rvalue: Expression) -> Optional[Type]:
# inside type variables with value restrictions (like
# AnyStr).
return None
if isinstance(rvalue, IntExpr):
return self.named_type_or_none('builtins.int')
if isinstance(rvalue, FloatExpr):
return self.named_type_or_none('builtins.float')

value = None # type: LiteralValue
type_name = None # type: Optional[str]
if isinstance(rvalue, IntExpr):
value, type_name = rvalue.value, 'builtins.int'
if isinstance(rvalue, StrExpr):
return self.named_type_or_none('builtins.str')
value, type_name = rvalue.value, 'builtins.str'
if isinstance(rvalue, BytesExpr):
return self.named_type_or_none('builtins.bytes')
value, type_name = rvalue.value, 'builtins.bytes'
if isinstance(rvalue, UnicodeExpr):
return self.named_type_or_none('builtins.unicode')
value, type_name = rvalue.value, 'builtins.unicode'

if type_name is not None:
typ = self.named_type_or_none(type_name)
if typ and is_final:
return typ.copy_modified(final_value=LiteralType(
value=value,
fallback=typ,
line=typ.line,
column=typ.column,
))
return typ

return None

def analyze_alias(self, rvalue: Expression) -> Tuple[Optional[Type], List[str],
Expand Down
3 changes: 2 additions & 1 deletion mypy/server/astdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,8 @@ def visit_deleted_type(self, typ: DeletedType) -> SnapshotItem:
def visit_instance(self, typ: Instance) -> SnapshotItem:
return ('Instance',
typ.type.fullname(),
snapshot_types(typ.args))
snapshot_types(typ.args),
None if typ.final_value is None else snapshot_type(typ.final_value))

def visit_type_var(self, typ: TypeVarType) -> SnapshotItem:
return ('TypeVar',
Expand Down
2 changes: 2 additions & 0 deletions mypy/server/astmerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,8 @@ def visit_instance(self, typ: Instance) -> None:
typ.type = self.fixup(typ.type)
for arg in typ.args:
arg.accept(self)
if typ.final_value:
typ.final_value.accept(self)

def visit_any(self, typ: AnyType) -> None:
pass
Expand Down
2 changes: 2 additions & 0 deletions mypy/server/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,8 @@ def visit_instance(self, typ: Instance) -> List[str]:
triggers = [trigger]
for arg in typ.args:
triggers.extend(self.get_type_triggers(arg))
if typ.final_value:
triggers.extend(self.get_type_triggers(typ.final_value))
return triggers

def visit_any(self, typ: AnyType) -> List[str]:
Expand Down
15 changes: 13 additions & 2 deletions mypy/type_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from abc import abstractmethod
from collections import OrderedDict
from typing import Generic, TypeVar, cast, Any, List, Callable, Iterable
from typing import Generic, TypeVar, cast, Any, List, Callable, Iterable, Optional
from mypy_extensions import trait

T = TypeVar('T')
Expand Down Expand Up @@ -159,7 +159,18 @@ def visit_deleted_type(self, t: DeletedType) -> Type:
return t

def visit_instance(self, t: Instance) -> Type:
return Instance(t.type, self.translate_types(t.args), t.line, t.column)
final_value = None # type: Optional[LiteralType]
if t.final_value is not None:
raw_final_value = t.final_value.accept(self)
assert isinstance(raw_final_value, LiteralType)
final_value = raw_final_value
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't you just write:

final_value = t.final_value.accept(self)
assert isinstance(final_value, LiteralType)

thus avoiding the extra variable.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's because t.final_value.accept(self) has a return type of Type, which means we wouldn't be able to assign it to final_value without raising an error or without casting.

I can replace this entire if statement with a cast if you want, similar to what we're doing in visit_tuple_type or visit_typeddict_type below, but I wanted to add a runtime check mostly for peace of mind.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will my version work if you remove the type comment on final_value?

(This is not important however.)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alas, it doesn't: I get a 'Argument "final_value" to "Instance" has incompatible type "Optional[Type]"; expected "Optional[LiteralType]"` error a little later on when we use the variable.

return Instance(
typ=t.type,
args=self.translate_types(t.args),
line=t.line,
column=t.column,
final_value=final_value,
)

def visit_type_var(self, t: TypeVarType) -> Type:
return t
Expand Down
3 changes: 3 additions & 0 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,9 @@ def analyze_literal_param(self, idx: int, arg: Type, ctx: Context) -> Optional[L
elif isinstance(arg, (NoneTyp, LiteralType)):
# Types that we can just add directly to the literal/potential union of literals.
return [arg]
elif isinstance(arg, Instance) and arg.final_value is not None:
# Types generated from declarations like "var: Final = 4".
return [arg.final_value]
elif isinstance(arg, UnionType):
out = []
for union_arg in arg.items:
Expand Down
Loading