Skip to content

Commit

Permalink
Fix narrowing on match with function subject
Browse files Browse the repository at this point in the history
Fixes #12998

mypy can't narrow match statements with functions subjects because the
callexpr node is not a literal node. This adds a 'dummy' literal node
that the match statement visitor can use to do the type narrowing.

The python grammar describes the the match subject as a named expression
so this uses that nameexpr node as it's literal.
  • Loading branch information
edpaget committed Nov 16, 2023
1 parent e4c43cb commit 5806bc7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
11 changes: 8 additions & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5043,8 +5043,13 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None:
return None

def visit_match_stmt(self, s: MatchStmt) -> None:
# Create a dummy subject expression to handle cases where a match
# statement's subject is not a literal value which prevent us from correctly
# narrowing types and checking exhaustivity
named_subject = NameExpr("match") if isinstance(s.subject, CallExpr) else s.subject
with self.binder.frame_context(can_skip=False, fall_through=0):
subject_type = get_proper_type(self.expr_checker.accept(s.subject))
self.store_type(named_subject, subject_type)

if isinstance(subject_type, DeletedType):
self.msg.deleted_as_rvalue(subject_type, s)
Expand All @@ -5061,7 +5066,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
# The second pass narrows down the types and type checks bodies.
for p, g, b in zip(s.patterns, s.guards, s.bodies):
current_subject_type = self.expr_checker.narrow_type_from_binder(
s.subject, subject_type
named_subject, subject_type
)
pattern_type = self.pattern_checker.accept(p, current_subject_type)
with self.binder.frame_context(can_skip=True, fall_through=2):
Expand All @@ -5072,7 +5077,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
else_map: TypeMap = {}
else:
pattern_map, else_map = conditional_types_to_typemaps(
s.subject, pattern_type.type, pattern_type.rest_type
named_subject, pattern_type.type, pattern_type.rest_type
)
self.remove_capture_conflicts(pattern_type.captures, inferred_types)
self.push_type_map(pattern_map)
Expand Down Expand Up @@ -5100,7 +5105,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
and expr.fullname == case_target.fullname
):
continue
type_map[s.subject] = type_map[expr]
type_map[named_subject] = type_map[expr]

self.push_type_map(guard_map)
self.accept(b)
Expand Down
12 changes: 12 additions & 0 deletions test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -1139,6 +1139,18 @@ match m:

reveal_type(a) # N: Revealed type is "builtins.str"

[case testMatchCapturePatternFromFunctionReturningUnion]
def func(arg: bool) -> str | int:
if arg:
return 1
return "a"

match func(True):
case str(a):
reveal_type(a) # N: Revealed type is "builtins.str"
case a:
reveal_type(a) # N: Revealed type is "builtins.int"

-- Guards --

[case testMatchSimplePatternGuard]
Expand Down

0 comments on commit 5806bc7

Please sign in to comment.