diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 6e465893607d..b9754ba1a147 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -50,6 +50,7 @@ Integer, LoadAddress, LoadErrorValue, + MethodCall, RaiseStandardError, Register, Return, @@ -61,6 +62,7 @@ RInstance, exc_rtuple, is_tagged, + none_rprimitive, object_pointer_rprimitive, object_rprimitive, ) @@ -657,14 +659,45 @@ def transform_with( al = "a" if is_async else "" mgr_v = builder.accept(expr) - typ = builder.call_c(type_op, [mgr_v], line) - exit_ = builder.maybe_spill(builder.py_get_attr(typ, f"__{al}exit__", line)) - value = builder.py_call(builder.py_get_attr(typ, f"__{al}enter__", line), [mgr_v], line) + is_native = isinstance(mgr_v.type, RInstance) + if is_native: + value = builder.add(MethodCall(mgr_v, f"__{al}enter__", args=[], line=line)) + exit_ = None + else: + typ = builder.call_c(type_op, [mgr_v], line) + exit_ = builder.maybe_spill(builder.py_get_attr(typ, f"__{al}exit__", line)) + value = builder.py_call(builder.py_get_attr(typ, f"__{al}enter__", line), [mgr_v], line) + mgr = builder.maybe_spill(mgr_v) exc = builder.maybe_spill_assignable(builder.true()) if is_async: value = emit_await(builder, value, line) + def maybe_natively_call_exit(exc_info: bool) -> Value: + if exc_info: + args = get_sys_exc_info(builder) + else: + none = builder.none_object() + args = [none, none, none] + + if is_native: + assert isinstance(mgr_v.type, RInstance) + exit_val = builder.gen_method_call( + builder.read(mgr), + f"__{al}exit__", + arg_values=args, + line=line, + result_type=none_rprimitive, + ) + else: + assert exit_ is not None + exit_val = builder.py_call(builder.read(exit_), [builder.read(mgr)] + args, line) + + if is_async: + return emit_await(builder, exit_val, line) + else: + return exit_val + def try_body() -> None: if target: builder.assign(builder.get_assignment_target(target), value, line) @@ -673,13 +706,7 @@ def try_body() -> None: def except_body() -> None: builder.assign(exc, builder.false(), line) out_block, reraise_block = BasicBlock(), BasicBlock() - exit_val = builder.py_call( - builder.read(exit_), [builder.read(mgr)] + get_sys_exc_info(builder), line - ) - if is_async: - exit_val = emit_await(builder, exit_val, line) - - builder.add_bool_branch(exit_val, out_block, reraise_block) + builder.add_bool_branch(maybe_natively_call_exit(exc_info=True), out_block, reraise_block) builder.activate_block(reraise_block) builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO) builder.add(Unreachable()) @@ -689,13 +716,8 @@ def finally_body() -> None: out_block, exit_block = BasicBlock(), BasicBlock() builder.add(Branch(builder.read(exc), exit_block, out_block, Branch.BOOL)) builder.activate_block(exit_block) - none = builder.none_object() - exit_val = builder.py_call( - builder.read(exit_), [builder.read(mgr), none, none, none], line - ) - if is_async: - emit_await(builder, exit_val, line) + maybe_natively_call_exit(exc_info=False) builder.goto_and_activate(out_block) transform_try_finally_stmt( diff --git a/mypyc/test-data/irbuild-try.test b/mypyc/test-data/irbuild-try.test index d1119c5deefd..faf3fa1dbd2f 100644 --- a/mypyc/test-data/irbuild-try.test +++ b/mypyc/test-data/irbuild-try.test @@ -416,3 +416,108 @@ L19: L20: return 1 +[case testWithNativeSimple] +class DummyContext: + def __enter__(self) -> None: + pass + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + pass + +def foo(x: DummyContext) -> None: + with x: + print('hello') +[out] +def DummyContext.__enter__(self): + self :: __main__.DummyContext +L0: + return 1 +def DummyContext.__exit__(self, exc_type, exc_val, exc_tb): + self :: __main__.DummyContext + exc_type, exc_val, exc_tb :: object +L0: + return 1 +def foo(x): + x :: __main__.DummyContext + r0 :: None + r1 :: bool + r2 :: str + r3 :: object + r4 :: str + r5, r6 :: object + r7, r8 :: tuple[object, object, object] + r9, r10, r11 :: object + r12 :: None + r13 :: object + r14 :: int32 + r15 :: bit + r16 :: bool + r17 :: bit + r18, r19, r20 :: tuple[object, object, object] + r21 :: object + r22 :: None + r23 :: bit +L0: + r0 = x.__enter__() + r1 = 1 +L1: +L2: + r2 = 'hello' + r3 = builtins :: module + r4 = 'print' + r5 = CPyObject_GetAttr(r3, r4) + r6 = PyObject_CallFunctionObjArgs(r5, r2, 0) + goto L8 +L3: (handler for L2) + r7 = CPy_CatchError() + r1 = 0 + r8 = CPy_GetExcInfo() + r9 = r8[0] + r10 = r8[1] + r11 = r8[2] + r12 = x.__exit__(r9, r10, r11) + r13 = box(None, r12) + r14 = PyObject_IsTrue(r13) + r15 = r14 >= 0 :: signed + r16 = truncate r14: int32 to builtins.bool + if r16 goto L5 else goto L4 :: bool +L4: + CPy_Reraise() + unreachable +L5: +L6: + CPy_RestoreExcInfo(r7) + goto L8 +L7: (handler for L3, L4, L5) + CPy_RestoreExcInfo(r7) + r17 = CPy_KeepPropagating() + unreachable +L8: +L9: +L10: + r18 = :: tuple[object, object, object] + r19 = r18 + goto L12 +L11: (handler for L1, L6, L7, L8) + r20 = CPy_CatchError() + r19 = r20 +L12: + if r1 goto L13 else goto L14 :: bool +L13: + r21 = load_address _Py_NoneStruct + r22 = x.__exit__(r21, r21, r21) +L14: + if is_error(r19) goto L16 else goto L15 +L15: + CPy_Reraise() + unreachable +L16: + goto L20 +L17: (handler for L12, L13, L14, L15) + if is_error(r19) goto L19 else goto L18 +L18: + CPy_RestoreExcInfo(r19) +L19: + r23 = CPy_KeepPropagating() + unreachable +L20: + return 1 diff --git a/mypyc/test-data/run-generators.test b/mypyc/test-data/run-generators.test index 0f2cbe152fc0..bcf9da1846ae 100644 --- a/mypyc/test-data/run-generators.test +++ b/mypyc/test-data/run-generators.test @@ -662,3 +662,20 @@ def list_comp() -> List[int]: [file driver.py] from native import list_comp assert list_comp() == [5] + +[case testWithNative] +class DummyContext: + def __init__(self) -> None: + self.x = 0 + + def __enter__(self) -> None: + self.x += 1 + + def __exit__(self, exc_type, exc_value, exc_tb) -> None: + self.x -= 1 + +def test_basic() -> None: + context = DummyContext() + with context: + assert context.x == 1 + assert context.x == 0 diff --git a/mypyc/test-data/run-misc.test b/mypyc/test-data/run-misc.test index 001e0aa41b25..267a3441808f 100644 --- a/mypyc/test-data/run-misc.test +++ b/mypyc/test-data/run-misc.test @@ -1116,3 +1116,33 @@ i = b"foo" def test_redefinition() -> None: assert i == b"foo" + +[case testWithNative] +class DummyContext: + def __init__(self): + self.c = 0 + def __enter__(self) -> None: + self.c += 1 + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.c -= 1 + +def test_dummy_context() -> None: + c = DummyContext() + with c: + assert c.c == 1 + assert c.c == 0 + +[case testWithNativeVarArgs] +class DummyContext: + def __init__(self): + self.c = 0 + def __enter__(self) -> None: + self.c += 1 + def __exit__(self, *args: object) -> None: + self.c -= 1 + +def test_dummy_context() -> None: + c = DummyContext() + with c: + assert c.c == 1 + assert c.c == 0