diff --git a/compiler/lib-wasm/curry.ml b/compiler/lib-wasm/curry.ml index b6d5ab0cab..255a41391a 100644 --- a/compiler/lib-wasm/curry.ml +++ b/compiler/lib-wasm/curry.ml @@ -256,20 +256,25 @@ module Make (Target : Target_sig.S) = struct match l with | [] -> let* y = y in - instr (Push y) + instr (Return (Some y)) | x :: rem -> let* x = load x in - build_applies (call ~cps:false ~arity:1 y [ x ]) rem + let* c = call ~cps:false ~arity:1 y [ x ] in + build_applies (return (W.Br_on_null (0, c))) rem in build_applies (load f) l) in + let body = + let* () = block { params = []; result = [] } body in + instr (Return (Some (RefNull Any))) + in let param_names = l @ [ f ] in let locals, body = function_body ~context ~param_names ~body in W.Function { name ; exported_name = None ; typ = None - ; signature = Type.primitive_type (arity + 1) + ; signature = Type.func_type arity ; param_names ; locals ; body diff --git a/compiler/lib-wasm/gc_target.ml b/compiler/lib-wasm/gc_target.ml index 36ca054e4c..e7e0420e98 100644 --- a/compiler/lib-wasm/gc_target.ml +++ b/compiler/lib-wasm/gc_target.ml @@ -27,6 +27,8 @@ let include_closure_arity = false module Type = struct let value = W.Ref { nullable = false; typ = Eq } + let value_or_exn = W.Ref { nullable = true; typ = Eq } + let block_type = register_type "block" (fun () -> return @@ -205,7 +207,8 @@ module Type = struct let primitive_type n = { W.params = List.init ~len:n ~f:(fun _ -> value); result = [ value ] } - let func_type n = primitive_type (n + 1) + let func_type n = + { W.params = List.init ~len:(n + 1) ~f:(fun _ -> value); result = [ value_or_exn ] } let function_type ~cps n = let n = if cps then n + 1 else n in diff --git a/compiler/lib-wasm/generate.ml b/compiler/lib-wasm/generate.ml index 7d82219aa1..9bd6a7195d 100644 --- a/compiler/lib-wasm/generate.ml +++ b/compiler/lib-wasm/generate.ml @@ -782,6 +782,8 @@ module Generate (Target : Target_sig.S) = struct in Memory.allocate ~tag:0 ~deadcode_sentinal:ctx.deadcode_sentinal ~load l) + let exception_handler_pc = -3 + let rec translate_expr ctx context x e = match e with | Apply { f; args; exact } @@ -799,8 +801,9 @@ module Generate (Target : Target_sig.S) = struct (load funct) in let* b = is_closure f in + let label = label_index context exception_handler_pc in if b - then return (W.Call (f, List.rev (closure :: acc))) + then return (W.Br_on_null (label, W.Call (f, List.rev (closure :: acc)))) else match funct with | W.RefFunc g -> @@ -808,8 +811,11 @@ module Generate (Target : Target_sig.S) = struct environment. In case of partial application, we still need the closure. *) let* cl = if exact then Value.unit else return closure in - return (W.Call (g, List.rev (cl :: acc))) - | _ -> return (W.Call_ref (ty, funct, List.rev (closure :: acc)))) + return (W.Br_on_null (label, W.Call (g, List.rev (cl :: acc)))) + | _ -> + return + (W.Br_on_null + (label, W.Call_ref (ty, funct, List.rev (closure :: acc))))) | x :: r -> let* x = load_and_box ctx x in loop (x :: acc) r @@ -821,7 +827,9 @@ module Generate (Target : Target_sig.S) = struct in let* args = expression_list (fun x -> load_and_box ctx x) args in let* closure = load f in - return (W.Call (apply, args @ [ closure ])) + return + (W.Br_on_null + (label_index context exception_handler_pc, W.Call (apply, args @ [ closure ]))) | Block (tag, a, _, _) -> Memory.allocate ~deadcode_sentinal:ctx.deadcode_sentinal @@ -1075,32 +1083,55 @@ module Generate (Target : Target_sig.S) = struct { params = []; result = [] } (body ~result_typ:[] ~fall_through:(`Block pc) ~context:(`Block pc :: context)) in - if List.is_empty result_typ + if true && List.is_empty result_typ then handler else let* () = handler in - instr (W.Return (Some (RefI31 (Const (I32 0l))))) + let* u = Value.unit in + instr (W.Return (Some u)) else body ~result_typ ~fall_through ~context - let wrap_with_handlers p pc ~result_typ ~fall_through ~context body = + let wrap_with_handlers ~location p pc ~result_typ ~fall_through ~context body = let need_zero_divide_handler, need_bound_error_handler = needed_handlers p pc in wrap_with_handler - need_bound_error_handler - bound_error_pc - (let* f = - register_import ~name:"caml_bound_error" (Fun { params = []; result = [] }) - in - instr (CallInstr (f, []))) + true + exception_handler_pc + (match location with + | `Toplevel -> + let* exn = + register_import + ~import_module:"env" + ~name:"caml_exception" + (Global { mut = true; typ = Type.value }) + in + let* tag = register_import ~name:exception_name (Tag Type.value) in + instr (Throw (tag, GlobalGet exn)) + | `Exception_handler -> + let* exn = + register_import + ~import_module:"env" + ~name:"caml_exception" + (Global { mut = true; typ = Type.value }) + in + instr (Br (2, Some (GlobalGet exn))) + | `Function -> instr (Return (Some (RefNull Any)))) (wrap_with_handler - need_zero_divide_handler - zero_divide_pc + need_bound_error_handler + bound_error_pc (let* f = - register_import - ~name:"caml_raise_zero_divide" - (Fun { params = []; result = [] }) + register_import ~name:"caml_bound_error" (Fun { params = []; result = [] }) in instr (CallInstr (f, []))) - body) + (wrap_with_handler + need_zero_divide_handler + zero_divide_pc + (let* f = + register_import + ~name:"caml_raise_zero_divide" + (Fun { params = []; result = [] }) + in + instr (CallInstr (f, []))) + body)) ~result_typ ~fall_through ~context @@ -1208,19 +1239,34 @@ module Generate (Target : Target_sig.S) = struct instr (Br_table (e, List.map ~f:dest l, dest a.(len - 1))) | Raise (x, _) -> ( let* e = load x in - let* tag = register_import ~name:exception_name (Tag Type.value) in match fall_through with | `Catch -> instr (Push e) | `Block _ | `Return | `Skip -> ( match catch_index context with | Some i -> instr (Br (i, Some e)) - | None -> instr (Throw (tag, e)))) + | None -> + if Option.is_some name_opt + then + let* exn = + register_import + ~import_module:"env" + ~name:"caml_exception" + (Global { mut = true; typ = Type.value }) + in + let* () = instr (GlobalSet (exn, e)) in + instr (Return (Some (RefNull Any))) + else + let* tag = + register_import ~name:exception_name (Tag Type.value) + in + instr (Throw (tag, e)))) | Pushtrap (cont, x, cont') -> handle_exceptions ~result_typ ~fall_through ~context:(extend_context fall_through context) (wrap_with_handlers + ~location:`Exception_handler p (fst cont) (fun ~result_typ ~fall_through ~context -> @@ -1291,6 +1337,10 @@ module Generate (Target : Target_sig.S) = struct let* () = build_initial_env in let* () = wrap_with_handlers + ~location: + (match name_opt with + | None -> `Toplevel + | Some _ -> `Function) p pc ~result_typ:[ Type.value ] @@ -1342,7 +1392,9 @@ module Generate (Target : Target_sig.S) = struct in let* () = instr (Drop (Call (f, []))) in cont) - ~init:(instr (Push (RefI31 (Const (I32 0l))))) + ~init: + (let* u = Value.unit in + instr (Push u)) to_link) in context.other_fields <- diff --git a/compiler/lib-wasm/tail_call.ml b/compiler/lib-wasm/tail_call.ml index b52142d72d..a39a0b23d1 100644 --- a/compiler/lib-wasm/tail_call.ml +++ b/compiler/lib-wasm/tail_call.ml @@ -30,6 +30,10 @@ let rewrite_tail_call ~y i = Some (Wasm_ast.Return_call (symb, l)) | LocalSet (x, Call_ref (ty, e, l)) when Code.Var.equal x y -> Some (Return_call_ref (ty, e, l)) + | LocalSet (x, Br_on_null (_, Call (symb, l))) when Code.Var.equal x y -> + Some (Wasm_ast.Return_call (symb, l)) + | LocalSet (x, Br_on_null (_, Call_ref (ty, e, l))) when Code.Var.equal x y -> + Some (Return_call_ref (ty, e, l)) | _ -> None let rec instruction ~tail i = @@ -42,6 +46,11 @@ let rec instruction ~tail i = | Push (Call (symb, l)) when tail -> Return_call (symb, l) | Push (Call_ref (ty, e, l)) when tail -> Return_call_ref (ty, e, l) | Push (Call_ref _) -> i + | Return (Some (Br_on_null (_, Call (symb, l)))) -> Return_call (symb, l) + | Return (Some (Br_on_null (_, Call_ref (ty, e, l)))) -> Return_call_ref (ty, e, l) + | Push (Br_on_null (_, Call (symb, l))) when tail -> Return_call (symb, l) + | Push (Br_on_null (_, Call_ref (ty, e, l))) when tail -> Return_call_ref (ty, e, l) + | Push (Br_on_null (_, Call_ref _)) -> i | Drop (BlockExpr (typ, l)) -> Drop (BlockExpr (typ, instructions ~tail:false l)) | Drop _ | LocalSet _ diff --git a/runtime/wasm/effect.wat b/runtime/wasm/effect.wat index 05bc0ad9c2..7755c65896 100644 --- a/runtime/wasm/effect.wat +++ b/runtime/wasm/effect.wat @@ -41,13 +41,14 @@ (import "obj" "caml_callback_1" (func $caml_callback_1 (param (ref eq)) (param (ref eq)) (result (ref eq)))) + (import "stdlib" "caml_exception" (global $caml_exception (mut (ref eq)))) (type $block (array (mut (ref eq)))) (type $bytes (array (mut i8))) - (type $function_1 (func (param (ref eq) (ref eq)) (result (ref eq)))) + (type $function_1 (func (param (ref eq) (ref eq)) (result eqref))) (type $closure (sub (struct (;(field i32);) (field (ref $function_1))))) (type $function_3 - (func (param (ref eq) (ref eq) (ref eq) (ref eq)) (result (ref eq)))) + (func (param (ref eq) (ref eq) (ref eq) (ref eq)) (result eqref))) (type $closure_3 (sub $closure (struct (field (ref $function_1)) (field (ref $function_3))))) @@ -66,7 +67,7 @@ (@string $effect_unhandled "Effect.Unhandled") (func $raise_unhandled - (param $eff (ref eq)) (param (ref eq)) (result (ref eq)) + (param $eff (ref eq)) (param (ref eq)) (result eqref) (block $null (call $caml_raise_with_arg (br_on_null $null @@ -140,9 +141,15 @@ (func $apply_pair (param $p (ref $pair)) (result (ref eq)) (local $f (ref eq)) - (return_call_ref $function_1 (struct.get $pair 1 (local.get $p)) - (local.tee $f (struct.get $pair 0 (local.get $p))) - (struct.get $closure 0 (ref.cast (ref $closure) (local.get $f))))) + (local $res eqref) + (local.set $res + (call_ref $function_1 (struct.get $pair 1 (local.get $p)) + (local.tee $f (struct.get $pair 0 (local.get $p))) + (struct.get $closure 0 (ref.cast (ref $closure) (local.get $f))))) + (if (ref.is_null (local.get $res)) + (then + (throw $ocaml_exception (global.get $caml_exception)))) + (return (ref.as_non_null (local.get $res)))) ;; Low-level primitives @@ -298,7 +305,7 @@ (field $cont (ref eq))))) (func $call_effect_handler - (param $tail (ref eq)) (param $venv (ref eq)) (result (ref eq)) + (param $tail (ref eq)) (param $venv (ref eq)) (result eqref) (local $env (ref $call_handler_env)) (local $handler (ref $closure_3)) (local.set $env (ref.cast (ref $call_handler_env) (local.get $venv))) @@ -339,8 +346,10 @@ (if (i32.or (i32.eqz (global.get $effect_allowed)) (ref.is_null (struct.get $fiber $next (global.get $stack)))) (then - (return_call $raise_unhandled - (local.get $eff) (ref.i31 (i32.const 0))))) + (return + (ref.as_non_null + (call $raise_unhandled + (local.get $eff) (ref.i31 (i32.const 0))))))) (return_call $capture_continuation (ref.func $do_perform) (local.get $eff))) @@ -413,7 +422,8 @@ (do (try (result (ref eq)) (do - (call $apply_pair (ref.cast (ref $pair) (local.get $p)))) + (call $apply_pair + (ref.cast (ref $pair) (local.get $p)))) (catch $javascript_exception (throw $ocaml_exception (call $caml_wrap_exception (pop externref)))))) @@ -438,10 +448,10 @@ (@if (= effects "cps") (@then (type $function_2 - (func (param (ref eq) (ref eq) (ref eq)) (result (ref eq)))) + (func (param (ref eq) (ref eq) (ref eq)) (result eqref))) (type $function_4 (func (param (ref eq) (ref eq) (ref eq) (ref eq) (ref eq)) - (result (ref eq)))) + (result eqref))) (type $cps_closure (sub (struct (field (ref $function_2))))) (type $cps_closure_0 (sub (struct (field (ref $function_1))))) (type $cps_closure_3 @@ -485,7 +495,7 @@ (ref.i31 (i32.const 0))) (func $raise_exception - (param $exn (ref eq)) (param (ref eq)) (result (ref eq)) + (param $exn (ref eq)) (param (ref eq)) (result eqref) (throw $ocaml_exception (local.get $exn))) (global $raise_exception (ref eq) @@ -506,13 +516,13 @@ (param $exn (ref eq)) (param (ref eq)) (result (ref eq)) (local.get $exn)) - (func $identity (param (ref eq)) (param (ref eq)) (result (ref eq)) + (func $identity (param (ref eq)) (param (ref eq)) (result eqref) (local.get 0)) (global $identity (ref $closure) (struct.new $closure (ref.func $identity))) (func $trampoline_iterator - (param $f (ref eq)) (param $venv (ref eq)) (result (ref eq)) + (param $f (ref eq)) (param $venv (ref eq)) (result eqref) (local $env (ref $iterator)) (local $i i32) (local $args (ref $block)) (local.set $env (ref.cast (ref $iterator) (local.get $venv))) @@ -532,7 +542,7 @@ (ref.cast (ref $cps_closure) (local.get $f))))) (func $apply_iterator - (param $f (ref eq)) (param $venv (ref eq)) (result (ref eq)) + (param $f (ref eq)) (param $venv (ref eq)) (result eqref) (local $env (ref $iterator)) (local $i i32) (local $args (ref $block)) (local.set $env (ref.cast (ref $iterator) (local.get $venv))) @@ -562,13 +572,13 @@ (ref.cast (ref $block) (local.get $args)))) (func $dummy_cps_fun - (param (ref eq)) (param (ref eq)) (param (ref eq)) (result (ref eq)) + (param (ref eq)) (param (ref eq)) (param (ref eq)) (result eqref) (unreachable)) (func $caml_trampoline (export "caml_cps_trampoline") (param $f (ref eq)) (param $vargs (ref eq)) (result (ref eq)) (local $args (ref $block)) - (local $i i32) (local $res (ref eq)) + (local $i i32) (local $res eqref) (local $exn (ref eq)) (local $top (ref $exn_stack)) (local $saved_fiber_stack (ref $cps_fiber)) (local.set $saved_fiber_stack (global.get $cps_fiber_stack)) @@ -585,7 +595,7 @@ (try (result (ref eq)) (do (local.set $res - (if (result (ref eq)) + (if (result eqref) (i32.eq (array.len (local.get $args)) (i32.const 1)) (then (call_ref $function_1 (global.get $identity) @@ -607,8 +617,13 @@ (local.get $f) (struct.get $cps_closure 0 (ref.cast (ref $cps_closure) (local.get $f))))))) - (global.set $cps_fiber_stack (local.get $saved_fiber_stack)) - (return (local.get $res))) + (if (result (ref eq)) (ref.is_null (local.get $res)) + (then + (global.get $caml_exception)) + (else + (global.set $cps_fiber_stack + (local.get $saved_fiber_stack)) + (return (ref.as_non_null (local.get $res)))))) (catch $ocaml_exception (pop (ref eq))) (catch $javascript_exception @@ -632,7 +647,11 @@ (struct.get $closure 0 (ref.cast (ref $closure) (local.get $f))))) (global.set $cps_fiber_stack (local.get $saved_fiber_stack)) - (return (local.get $res))) + (if (ref.is_null (local.get $res)) + (then + (local.set $exn (global.get $caml_exception)) + (br $loop))) + (return (ref.as_non_null (local.get $res)))) (catch $ocaml_exception (local.set $exn (pop (ref eq))) (br $loop)) @@ -685,15 +704,17 @@ (ref.i31 (i32.const 0))) (func (export "caml_perform_effect") - (param $eff (ref eq)) (param $k0 (ref eq)) (result (ref eq)) + (param $eff (ref eq)) (param $k0 (ref eq)) (result eqref) (local $handler (ref eq)) (local $k1 (ref eq)) (local $cont (ref $block)) (local $last_fiber (ref $cps_fiber)) (if (ref.is_null (struct.get $cps_fiber $next (global.get $cps_fiber_stack))) (then - (return_call $raise_unhandled - (local.get $eff) (ref.i31 (i32.const 0))))) + (return + (ref.as_non_null + (call $raise_unhandled + (local.get $eff) (ref.i31 (i32.const 0))))))) (local.set $handler (struct.get $cps_fiber $effect (global.get $cps_fiber_stack))) (local.set $last_fiber (global.get $cps_fiber_stack)) @@ -710,7 +731,7 @@ (func (export "caml_reperform_effect") (param $eff (ref eq)) (param $vcont (ref eq)) (param $vtail (ref eq)) - (param $k0 (ref eq)) (result (ref eq)) + (param $k0 (ref eq)) (result eqref) (local $handler (ref eq)) (local $k1 (ref eq)) (local $cont (ref $block)) (local $tail (ref $cps_fiber)) (local $last_fiber (ref $cps_fiber)) @@ -722,8 +743,10 @@ (call $caml_continuation_use_noexc (local.get $vcont)) (local.get $vtail) (local.get $k0))) - (return_call $raise_unhandled - (local.get $eff) (ref.i31 (i32.const 0))))) + (return + (ref.as_non_null + (call $raise_unhandled + (local.get $eff) (ref.i31 (i32.const 0))))))) (local.set $cont (ref.cast (ref $block) (local.get $vcont))) (local.set $tail (ref.cast (ref $cps_fiber) (local.get $vtail))) (local.set $handler @@ -741,7 +764,7 @@ (ref.cast (ref $cps_closure_3) (local.get $handler))))) (func $cps_call_handler - (param $handler (ref eq)) (param $x (ref eq)) (result (ref eq)) + (param $handler (ref eq)) (param $x (ref eq)) (result eqref) (return_call_ref $function_2 (local.get $x) (call $caml_pop_fiber) @@ -749,7 +772,7 @@ (struct.get $cps_closure 0 (ref.cast (ref $cps_closure) (local.get $handler))))) - (func $value_handler (param $x (ref eq)) (param (ref eq)) (result (ref eq)) + (func $value_handler (param $x (ref eq)) (param (ref eq)) (result eqref) (return_call $cps_call_handler (struct.get $cps_fiber $value (global.get $cps_fiber_stack)) (local.get $x))) @@ -757,7 +780,7 @@ (global $value_handler (ref $closure) (struct.new $closure (ref.func $value_handler))) - (func $exn_handler (param $x (ref eq)) (param (ref eq)) (result (ref eq)) + (func $exn_handler (param $x (ref eq)) (param (ref eq)) (result eqref) (return_call $cps_call_handler (struct.get $cps_fiber $exn (global.get $cps_fiber_stack)) (local.get $x))) diff --git a/runtime/wasm/obj.wat b/runtime/wasm/obj.wat index 5e06a4a5ed..49038f0571 100644 --- a/runtime/wasm/obj.wat +++ b/runtime/wasm/obj.wat @@ -28,17 +28,19 @@ (import "effect" "caml_cps_trampoline" (func $caml_cps_trampoline (param (ref eq) (ref eq)) (result (ref eq)))) )) + (import "stdlib" "caml_exception" (global $caml_exception (mut (ref eq)))) + (import "fail" "ocaml_exception" (tag $ocaml_exception (param (ref eq)))) (type $block (array (mut (ref eq)))) (type $bytes (array (mut i8))) (type $float (struct (field f64))) (type $float_array (array (mut f64))) - (type $function_1 (func (param (ref eq) (ref eq)) (result (ref eq)))) + (type $function_1 (func (param (ref eq) (ref eq)) (result eqref))) (type $closure (sub (struct (;(field i32);) (field (ref $function_1))))) (type $closure_last_arg (sub $closure (struct (;(field i32);) (field (ref $function_1))))) (type $function_2 - (func (param (ref eq) (ref eq) (ref eq)) (result (ref eq)))) + (func (param (ref eq) (ref eq) (ref eq)) (result eqref))) (type $cps_closure (sub (struct (field (ref $function_2))))) (type $cps_closure_last_arg (sub $cps_closure (struct (field (ref $function_2))))) @@ -59,7 +61,7 @@ (field (mut (ref null $closure_2)))))) (type $function_3 - (func (param (ref eq) (ref eq) (ref eq) (ref eq)) (result (ref eq)))) + (func (param (ref eq) (ref eq) (ref eq) (ref eq)) (result eqref))) (type $closure_3 (sub $closure @@ -72,7 +74,7 @@ (type $function_4 (func (param (ref eq) (ref eq) (ref eq) (ref eq) (ref eq)) - (result (ref eq)))) + (result eqref))) (type $closure_4 (sub $closure @@ -494,19 +496,30 @@ (@else (func $caml_callback_1 (export "caml_callback_1") (param $f (ref eq)) (param $x (ref eq)) (result (ref eq)) - (return_call_ref $function_1 (local.get $x) - (local.get $f) - (struct.get $closure 0 (ref.cast (ref $closure) (local.get $f))))) + (local $res eqref) + (local.set $res + (call_ref $function_1 (local.get $x) + (local.get $f) + (struct.get $closure 0 + (ref.cast (ref $closure) (local.get $f))))) + (if (ref.is_null (local.get $res)) + (then (throw $ocaml_exception (global.get $caml_exception)))) + (return (ref.as_non_null (local.get $res)))) (func (export "caml_callback_2") (param $f (ref eq)) (param $x (ref eq)) (param $y (ref eq)) (result (ref eq)) + (local $res eqref) (drop (block $not_direct (result (ref eq)) - (return_call_ref $function_2 (local.get $x) (local.get $y) - (local.get $f) - (struct.get $closure_2 1 - (br_on_cast_fail $not_direct (ref eq) (ref $closure_2) - (local.get $f)))))) + (local.set $res + (call_ref $function_2 (local.get $x) (local.get $y) + (local.get $f) + (struct.get $closure_2 1 + (br_on_cast_fail $not_direct (ref eq) (ref $closure_2) + (local.get $f))))) + (if (ref.is_null (local.get $res)) + (then (throw $ocaml_exception (global.get $caml_exception)))) + (return (ref.as_non_null (local.get $res))))) (return_call $caml_callback_1 (call $caml_callback_1 (local.get $f) (local.get $x)) (local.get $y))) diff --git a/runtime/wasm/stdlib.wat b/runtime/wasm/stdlib.wat index 62ff000f26..db7082dec4 100644 --- a/runtime/wasm/stdlib.wat +++ b/runtime/wasm/stdlib.wat @@ -232,4 +232,6 @@ (call $caml_format_exception (local.get $exn)) (@string "\n"))))))) (call $exit (i32.const 2))))) + + (global (export "caml_exception") (mut (ref eq)) (ref.i31 (i32.const 0))) )