Skip to content

Commit

Permalink
[mypyc] Optimize __(a)enter__/__(a)exit__ paths for native case (#14530)
Browse files Browse the repository at this point in the history
Closes mypyc/mypyc#904

Directly calls enter and exit handlers in the case that the context
manager is implemented natively.

Unfortunately the implementation becomes a bit more complicated because
there are two different places where we call exit in different ways, and
they both need to support the native and non-native cases.
  • Loading branch information
jhance committed Jan 30, 2023
1 parent cf2e404 commit b2cf9d1
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 16 deletions.
54 changes: 38 additions & 16 deletions mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
Integer,
LoadAddress,
LoadErrorValue,
MethodCall,
RaiseStandardError,
Register,
Return,
Expand All @@ -61,6 +62,7 @@
RInstance,
exc_rtuple,
is_tagged,
none_rprimitive,
object_pointer_rprimitive,
object_rprimitive,
)
Expand Down Expand Up @@ -657,14 +659,45 @@ def transform_with(
al = "a" if is_async else ""

mgr_v = builder.accept(expr)
typ = builder.call_c(type_op, [mgr_v], line)
exit_ = builder.maybe_spill(builder.py_get_attr(typ, f"__{al}exit__", line))
value = builder.py_call(builder.py_get_attr(typ, f"__{al}enter__", line), [mgr_v], line)
is_native = isinstance(mgr_v.type, RInstance)
if is_native:
value = builder.add(MethodCall(mgr_v, f"__{al}enter__", args=[], line=line))
exit_ = None
else:
typ = builder.call_c(type_op, [mgr_v], line)
exit_ = builder.maybe_spill(builder.py_get_attr(typ, f"__{al}exit__", line))
value = builder.py_call(builder.py_get_attr(typ, f"__{al}enter__", line), [mgr_v], line)

mgr = builder.maybe_spill(mgr_v)
exc = builder.maybe_spill_assignable(builder.true())
if is_async:
value = emit_await(builder, value, line)

def maybe_natively_call_exit(exc_info: bool) -> Value:
if exc_info:
args = get_sys_exc_info(builder)
else:
none = builder.none_object()
args = [none, none, none]

if is_native:
assert isinstance(mgr_v.type, RInstance)
exit_val = builder.gen_method_call(
builder.read(mgr),
f"__{al}exit__",
arg_values=args,
line=line,
result_type=none_rprimitive,
)
else:
assert exit_ is not None
exit_val = builder.py_call(builder.read(exit_), [builder.read(mgr)] + args, line)

if is_async:
return emit_await(builder, exit_val, line)
else:
return exit_val

def try_body() -> None:
if target:
builder.assign(builder.get_assignment_target(target), value, line)
Expand All @@ -673,13 +706,7 @@ def try_body() -> None:
def except_body() -> None:
builder.assign(exc, builder.false(), line)
out_block, reraise_block = BasicBlock(), BasicBlock()
exit_val = builder.py_call(
builder.read(exit_), [builder.read(mgr)] + get_sys_exc_info(builder), line
)
if is_async:
exit_val = emit_await(builder, exit_val, line)

builder.add_bool_branch(exit_val, out_block, reraise_block)
builder.add_bool_branch(maybe_natively_call_exit(exc_info=True), out_block, reraise_block)
builder.activate_block(reraise_block)
builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO)
builder.add(Unreachable())
Expand All @@ -689,13 +716,8 @@ def finally_body() -> None:
out_block, exit_block = BasicBlock(), BasicBlock()
builder.add(Branch(builder.read(exc), exit_block, out_block, Branch.BOOL))
builder.activate_block(exit_block)
none = builder.none_object()
exit_val = builder.py_call(
builder.read(exit_), [builder.read(mgr), none, none, none], line
)
if is_async:
emit_await(builder, exit_val, line)

maybe_natively_call_exit(exc_info=False)
builder.goto_and_activate(out_block)

transform_try_finally_stmt(
Expand Down
105 changes: 105 additions & 0 deletions mypyc/test-data/irbuild-try.test
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,108 @@ L19:
L20:
return 1

[case testWithNativeSimple]
class DummyContext:
def __enter__(self) -> None:
pass
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
pass

def foo(x: DummyContext) -> None:
with x:
print('hello')
[out]
def DummyContext.__enter__(self):
self :: __main__.DummyContext
L0:
return 1
def DummyContext.__exit__(self, exc_type, exc_val, exc_tb):
self :: __main__.DummyContext
exc_type, exc_val, exc_tb :: object
L0:
return 1
def foo(x):
x :: __main__.DummyContext
r0 :: None
r1 :: bool
r2 :: str
r3 :: object
r4 :: str
r5, r6 :: object
r7, r8 :: tuple[object, object, object]
r9, r10, r11 :: object
r12 :: None
r13 :: object
r14 :: int32
r15 :: bit
r16 :: bool
r17 :: bit
r18, r19, r20 :: tuple[object, object, object]
r21 :: object
r22 :: None
r23 :: bit
L0:
r0 = x.__enter__()
r1 = 1
L1:
L2:
r2 = 'hello'
r3 = builtins :: module
r4 = 'print'
r5 = CPyObject_GetAttr(r3, r4)
r6 = PyObject_CallFunctionObjArgs(r5, r2, 0)
goto L8
L3: (handler for L2)
r7 = CPy_CatchError()
r1 = 0
r8 = CPy_GetExcInfo()
r9 = r8[0]
r10 = r8[1]
r11 = r8[2]
r12 = x.__exit__(r9, r10, r11)
r13 = box(None, r12)
r14 = PyObject_IsTrue(r13)
r15 = r14 >= 0 :: signed
r16 = truncate r14: int32 to builtins.bool
if r16 goto L5 else goto L4 :: bool
L4:
CPy_Reraise()
unreachable
L5:
L6:
CPy_RestoreExcInfo(r7)
goto L8
L7: (handler for L3, L4, L5)
CPy_RestoreExcInfo(r7)
r17 = CPy_KeepPropagating()
unreachable
L8:
L9:
L10:
r18 = <error> :: tuple[object, object, object]
r19 = r18
goto L12
L11: (handler for L1, L6, L7, L8)
r20 = CPy_CatchError()
r19 = r20
L12:
if r1 goto L13 else goto L14 :: bool
L13:
r21 = load_address _Py_NoneStruct
r22 = x.__exit__(r21, r21, r21)
L14:
if is_error(r19) goto L16 else goto L15
L15:
CPy_Reraise()
unreachable
L16:
goto L20
L17: (handler for L12, L13, L14, L15)
if is_error(r19) goto L19 else goto L18
L18:
CPy_RestoreExcInfo(r19)
L19:
r23 = CPy_KeepPropagating()
unreachable
L20:
return 1
17 changes: 17 additions & 0 deletions mypyc/test-data/run-generators.test
Original file line number Diff line number Diff line change
Expand Up @@ -662,3 +662,20 @@ def list_comp() -> List[int]:
[file driver.py]
from native import list_comp
assert list_comp() == [5]

[case testWithNative]
class DummyContext:
def __init__(self) -> None:
self.x = 0

def __enter__(self) -> None:
self.x += 1

def __exit__(self, exc_type, exc_value, exc_tb) -> None:
self.x -= 1

def test_basic() -> None:
context = DummyContext()
with context:
assert context.x == 1
assert context.x == 0
30 changes: 30 additions & 0 deletions mypyc/test-data/run-misc.test
Original file line number Diff line number Diff line change
Expand Up @@ -1116,3 +1116,33 @@ i = b"foo"

def test_redefinition() -> None:
assert i == b"foo"

[case testWithNative]
class DummyContext:
def __init__(self):
self.c = 0
def __enter__(self) -> None:
self.c += 1
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.c -= 1

def test_dummy_context() -> None:
c = DummyContext()
with c:
assert c.c == 1
assert c.c == 0

[case testWithNativeVarArgs]
class DummyContext:
def __init__(self):
self.c = 0
def __enter__(self) -> None:
self.c += 1
def __exit__(self, *args: object) -> None:
self.c -= 1

def test_dummy_context() -> None:
c = DummyContext()
with c:
assert c.c == 1
assert c.c == 0

0 comments on commit b2cf9d1

Please sign in to comment.