diff --git a/CHANGES.md b/CHANGES.md index 3490684715..8e7e25a169 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,7 @@ ## Features/Changes * Compiler: exit-loop-early in more cases (#2077) +* Compiler/wasm: omit code pointer from closures when not used (#2059) # 6.1.1 (2025-07-07) - Lille diff --git a/compiler/lib-wasm/call_graph_analysis.ml b/compiler/lib-wasm/call_graph_analysis.ml new file mode 100644 index 0000000000..f897a046e1 --- /dev/null +++ b/compiler/lib-wasm/call_graph_analysis.ml @@ -0,0 +1,63 @@ +open! Stdlib +open Code + +let debug = Debug.find "call-graph" + +let times = Debug.find "times" + +let block_deps ~info ~non_escaping ~ambiguous ~blocks pc = + let block = Addr.Map.find pc blocks in + List.iter block.body ~f:(fun i -> + match i with + | Let (_, Apply { f; _ }) -> ( + try + match Var.Tbl.get info.Global_flow.info_approximation f with + | Top -> () + | Values { known; others } -> + if others || Var.Set.cardinal known > 1 + then Var.Set.iter (fun x -> Var.Hashtbl.replace ambiguous x ()) known; + if debug () + then + Format.eprintf + "CALL others:%b known:%d@." + others + (Var.Set.cardinal known) + with Invalid_argument _ -> ()) + | Let (x, Closure _) -> ( + match Var.Tbl.get info.Global_flow.info_approximation x with + | Top -> () + | Values { known; others } -> + if Var.Set.cardinal known = 1 && (not others) && Var.Set.mem x known + then ( + let may_escape = Var.ISet.mem info.Global_flow.info_may_escape x in + if debug () then Format.eprintf "CLOSURE may-escape:%b@." may_escape; + if not may_escape then Var.Hashtbl.replace non_escaping x ())) + | Let (_, (Prim _ | Block _ | Constant _ | Field _ | Special _)) + | Event _ | Assign _ | Set_field _ | Offset_ref _ | Array_set _ -> ()) + +type t = { unambiguous_non_escaping : unit Var.Hashtbl.t } + +let direct_calls_only info f = + Config.Flag.optcall () && Var.Hashtbl.mem info.unambiguous_non_escaping f + +let f p info = + let t = Timer.make () in + let non_escaping = Var.Hashtbl.create 128 in + let ambiguous = Var.Hashtbl.create 128 in + fold_closures + p + (fun _ _ (pc, _) _ () -> + traverse + { fold = Code.fold_children } + (fun pc () -> block_deps ~info ~non_escaping ~ambiguous ~blocks:p.blocks pc) + pc + p.blocks + ()) + (); + if debug () + then Format.eprintf "SUMMARY non-escaping:%d" (Var.Hashtbl.length non_escaping); + Var.Hashtbl.iter (fun x () -> Var.Hashtbl.remove non_escaping x) ambiguous; + if debug () + then Format.eprintf " unambiguous-non-escaping:%d@." (Var.Hashtbl.length non_escaping); + if times () then Format.eprintf " call graph analysis: %a@." Timer.print t; + { unambiguous_non_escaping = non_escaping } diff --git a/compiler/lib-wasm/call_graph_analysis.mli b/compiler/lib-wasm/call_graph_analysis.mli new file mode 100644 index 0000000000..3188253a2a --- /dev/null +++ b/compiler/lib-wasm/call_graph_analysis.mli @@ -0,0 +1,5 @@ +type t + +val direct_calls_only : t -> Code.Var.t -> bool + +val f : Code.program -> Global_flow.info -> t diff --git a/compiler/lib-wasm/gc_target.ml b/compiler/lib-wasm/gc_target.ml index 36ca054e4c..306846aae5 100644 --- a/compiler/lib-wasm/gc_target.ml +++ b/compiler/lib-wasm/gc_target.ml @@ -22,8 +22,6 @@ open Code_generation type expression = Wasm_ast.expression Code_generation.t -let include_closure_arity = false - module Type = struct let value = W.Ref { nullable = false; typ = Eq } @@ -215,13 +213,7 @@ module Type = struct let closure_common_fields ~cps = let* fun_ty = function_type ~cps 1 in return - (let function_pointer = - [ { W.mut = false; typ = W.Value (Ref { nullable = false; typ = Type fun_ty }) } - ] - in - if include_closure_arity - then { W.mut = false; typ = W.Value I32 } :: function_pointer - else function_pointer) + [ { W.mut = false; typ = W.Value (Ref { nullable = false; typ = Type fun_ty }) } ] let closure_type_1 ~cps = register_type @@ -289,36 +281,41 @@ module Type = struct }) env_type - let env_type ~cps ~arity ~env_type_id ~env_type = + let env_type ~cps ~arity ~no_code_pointer ~env_type_id ~env_type = register_type (if cps then Printf.sprintf "cps_env_%d_%d" arity env_type_id else Printf.sprintf "env_%d_%d" arity env_type_id) (fun () -> - let* cl_typ = closure_type ~usage:`Alloc ~cps arity in - let* common = closure_common_fields ~cps in - let* fun_ty' = function_type ~cps arity in - return - { supertype = Some cl_typ - ; final = true - ; typ = - W.Struct - ((if arity = 1 - then common - else if arity = 0 - then - [ { mut = false - ; typ = Value (Ref { nullable = false; typ = Type fun_ty' }) - } - ] - else - common - @ [ { mut = false + if no_code_pointer + then + return + { supertype = None; final = true; typ = W.Struct (make_env_type env_type) } + else + let* cl_typ = closure_type ~usage:`Alloc ~cps arity in + let* common = closure_common_fields ~cps in + let* fun_ty' = function_type ~cps arity in + return + { supertype = Some cl_typ + ; final = true + ; typ = + W.Struct + ((if arity = 1 + then common + else if arity = 0 + then + [ { mut = false ; typ = Value (Ref { nullable = false; typ = Type fun_ty' }) } - ]) - @ make_env_type env_type) - }) + ] + else + common + @ [ { mut = false + ; typ = Value (Ref { nullable = false; typ = Type fun_ty' }) + } + ]) + @ make_env_type env_type) + }) let rec_env_type ~function_count ~env_type_id ~env_type = register_type (Printf.sprintf "rec_env_%d_%d" function_count env_type_id) (fun () -> @@ -336,34 +333,48 @@ module Type = struct @ make_env_type env_type) }) - let rec_closure_type ~cps ~arity ~function_count ~env_type_id ~env_type = + let rec_closure_type ~cps ~arity ~no_code_pointer ~function_count ~env_type_id ~env_type + = register_type (if cps then Printf.sprintf "cps_closure_rec_%d_%d_%d" arity function_count env_type_id else Printf.sprintf "closure_rec_%d_%d_%d" arity function_count env_type_id) (fun () -> - let* cl_typ = closure_type ~usage:`Alloc ~cps arity in - let* common = closure_common_fields ~cps in - let* fun_ty' = function_type ~cps arity in let* env_ty = rec_env_type ~function_count ~env_type_id ~env_type in - return - { supertype = Some cl_typ - ; final = true - ; typ = - W.Struct - ((if arity = 1 - then common - else - common - @ [ { mut = false - ; typ = Value (Ref { nullable = false; typ = Type fun_ty' }) - } - ]) - @ [ { W.mut = false + if no_code_pointer + then + return + { supertype = None + ; final = true + ; typ = + W.Struct + [ { W.mut = false ; typ = W.Value (Ref { nullable = false; typ = Type env_ty }) } - ]) - }) + ] + } + else + let* cl_typ = closure_type ~usage:`Alloc ~cps arity in + let* common = closure_common_fields ~cps in + let* fun_ty' = function_type ~cps arity in + return + { supertype = Some cl_typ + ; final = true + ; typ = + W.Struct + ((if arity = 1 + then common + else + common + @ [ { mut = false + ; typ = Value (Ref { nullable = false; typ = Type fun_ty' }) + } + ]) + @ [ { W.mut = false + ; typ = W.Value (Ref { nullable = false; typ = Type env_ty }) + } + ]) + }) let rec curry_type ~cps arity m = register_type @@ -806,17 +817,22 @@ module Memory = struct let set_field e idx e' = wasm_array_set e (Arith.const (Int32.of_int (idx + 1))) e' - let env_start arity = - if arity = 0 - then 1 - else (if include_closure_arity then 1 else 0) + if arity = 1 then 1 else 2 + let env_start ~no_code_pointer arity = + if no_code_pointer + then 0 + else + match arity with + | 0 | 1 -> 1 + | _ -> 2 let load_function_pointer ~cps ~arity ?(skip_cast = false) closure = let arity = if cps then arity - 1 else arity in let* ty = Type.closure_type ~usage:`Access ~cps arity in let* fun_ty = Type.function_type ~cps arity in let casted_closure = if skip_cast then closure else wasm_cast ty closure in - let* e = wasm_struct_get ty casted_closure (env_start arity - 1) in + let* e = + wasm_struct_get ty casted_closure (env_start ~no_code_pointer:false arity - 1) + in return (fun_ty, e) let load_real_closure ~cps ~arity closure = @@ -824,7 +840,12 @@ module Memory = struct let* ty = Type.dummy_closure_type ~cps ~arity in let* cl_typ = Type.closure_type ~usage:`Access ~cps arity in let* e = - wasm_cast cl_typ (wasm_struct_get ty (wasm_cast ty closure) (env_start arity)) + wasm_cast + cl_typ + (wasm_struct_get + ty + (wasm_cast ty closure) + (env_start ~no_code_pointer:false arity)) in return (cl_typ, e) @@ -1061,7 +1082,7 @@ module Closure = struct | [ (g, _) ] -> Code.Var.equal f g | _ :: r -> is_last_fun r f - let translate ~context ~closures ~cps f = + let translate ~context ~closures ~cps ~no_code_pointer f = let info = Code.Var.Map.find f closures in let free_variables = get_free_variables ~context info in assert ( @@ -1070,29 +1091,29 @@ module Closure = struct ~f:(fun x -> Code.Var.Set.mem x context.globalized_variables) free_variables)); let _, arity = List.find ~f:(fun (f', _) -> Code.Var.equal f f') info.functions in - let arity = if cps then arity - 1 else arity in + let arity = if no_code_pointer then 0 else if cps then arity - 1 else arity in let* curry_fun = if arity > 1 then need_curry_fun ~cps ~arity else return f in if List.is_empty free_variables then - let* typ = Type.closure_type ~usage:`Alloc ~cps arity in - let name = Code.Var.fork f in - let* () = - register_global - name - { mut = false; typ = Type.value } - (W.StructNew - ( typ - , if arity = 0 - then [ W.RefFunc f ] - else - let code_pointers = - if arity = 1 then [ W.RefFunc f ] else [ RefFunc curry_fun; RefFunc f ] - in - if include_closure_arity - then Const (I32 (Int32.of_int arity)) :: code_pointers - else code_pointers )) - in - return (W.GlobalGet name) + if no_code_pointer + then Value.unit + else + let* typ = Type.closure_type ~usage:`Alloc ~cps arity in + let name = Code.Var.fork f in + let* () = + register_global + name + { mut = false; typ = Type.value } + (W.StructNew + ( typ + , if no_code_pointer + then [] + else + match arity with + | 0 | 1 -> [ W.RefFunc f ] + | _ -> [ RefFunc curry_fun; RefFunc f ] )) + in + return (W.GlobalGet name) else let* env_type = expression_list variable_type free_variables in let env_type_id = @@ -1106,22 +1127,17 @@ module Closure = struct match info.Closure_conversion.functions with | [] -> assert false | [ _ ] -> - let* typ = Type.env_type ~cps ~arity ~env_type_id ~env_type in + let* typ = Type.env_type ~cps ~arity ~no_code_pointer ~env_type_id ~env_type in let* l = expression_list load free_variables in return (W.StructNew ( typ - , (if arity = 0 - then [ W.RefFunc f ] + , (if no_code_pointer + then [] else - let code_pointers = - if arity = 1 - then [ W.RefFunc f ] - else [ RefFunc curry_fun; RefFunc f ] - in - if include_closure_arity - then W.Const (I32 (Int32.of_int arity)) :: code_pointers - else code_pointers) + match arity with + | 0 | 1 -> [ W.RefFunc f ] + | _ -> [ RefFunc curry_fun; RefFunc f ]) @ l )) | (g, _) :: _ as functions -> let function_count = List.length functions in @@ -1147,21 +1163,25 @@ module Closure = struct load env in let* typ = - Type.rec_closure_type ~cps ~arity ~function_count ~env_type_id ~env_type + Type.rec_closure_type + ~cps + ~arity + ~no_code_pointer + ~function_count + ~env_type_id + ~env_type in let res = let* env = env in return (W.StructNew ( typ - , (let code_pointers = - if arity = 1 - then [ W.RefFunc f ] - else [ RefFunc curry_fun; RefFunc f ] - in - if include_closure_arity - then W.Const (I32 (Int32.of_int arity)) :: code_pointers - else code_pointers) + , (if no_code_pointer + then [] + else + match arity with + | 0 | 1 -> [ W.RefFunc f ] + | _ -> [ RefFunc curry_fun; RefFunc f ]) @ [ env ] )) in if is_last_fun functions f @@ -1182,11 +1202,10 @@ module Closure = struct (load f) else res - let bind_environment ~context ~closures ~cps f = + let bind_environment ~context ~closures ~cps ~no_code_pointer f = let info = Code.Var.Map.find f closures in let free_variables = get_free_variables ~context info in - let free_variable_count = List.length free_variables in - if free_variable_count = 0 + if List.is_empty free_variables then (* The closures are all constants and the environment is empty. *) let* _ = add_var (Code.Var.fresh ()) in @@ -1194,11 +1213,13 @@ module Closure = struct else let env_type_id = Option.value ~default:(-1) info.id in let _, arity = List.find ~f:(fun (f', _) -> Code.Var.equal f f') info.functions in - let arity = if cps then arity - 1 else arity in - let offset = Memory.env_start arity in + let arity = if no_code_pointer then 0 else if cps then arity - 1 else arity in + let offset = Memory.env_start ~no_code_pointer arity in match info.Closure_conversion.functions with | [ _ ] -> - let* typ = Type.env_type ~cps ~arity ~env_type_id ~env_type:[] in + let* typ = + Type.env_type ~cps ~arity ~no_code_pointer ~env_type_id ~env_type:[] + in let* _ = add_var f in let env = Code.Var.fresh_n "env" in let* () = @@ -1218,7 +1239,13 @@ module Closure = struct | functions -> let function_count = List.length functions in let* typ = - Type.rec_closure_type ~cps ~arity ~function_count ~env_type_id ~env_type:[] + Type.rec_closure_type + ~cps + ~arity + ~no_code_pointer + ~function_count + ~env_type_id + ~env_type:[] in let* _ = add_var f in let env = Code.Var.fresh_n "env" in @@ -1247,13 +1274,7 @@ module Closure = struct in let* closure = Memory.wasm_cast cl_ty (load closure) in let* arg = load arg in - let closure_contents = [ W.RefFunc f; closure; arg ] in - return - (W.StructNew - ( ty - , if include_closure_arity - then Const (I32 1l) :: closure_contents - else closure_contents )) + return (W.StructNew (ty, [ W.RefFunc f; closure; arg ])) let curry_load ~cps ~arity m closure = let m = m + 1 in @@ -1264,7 +1285,7 @@ module Closure = struct else Type.curry_type ~cps arity (m + 1) in let cast e = if m = 2 then Memory.wasm_cast ty e else e in - let offset = Memory.env_start 1 in + let offset = Memory.env_start ~no_code_pointer:false 1 in return ( Memory.wasm_struct_get ty (cast (load closure)) (offset + 1) , Memory.wasm_struct_get ty (cast (load closure)) offset @@ -1283,12 +1304,7 @@ module Closure = struct then [ W.RefFunc dummy_fun; RefNull (Type cl_typ) ] else [ RefFunc curry_fun; RefFunc dummy_fun; RefNull (Type cl_typ) ] in - return - (W.StructNew - ( ty - , if include_closure_arity - then Const (I32 1l) :: closure_contents - else closure_contents )) + return (W.StructNew (ty, closure_contents)) end module Math = struct diff --git a/compiler/lib-wasm/generate.ml b/compiler/lib-wasm/generate.ml index 6bbe9830c6..02a84ef131 100644 --- a/compiler/lib-wasm/generate.ml +++ b/compiler/lib-wasm/generate.ml @@ -37,6 +37,7 @@ module Generate (Target : Target_sig.S) = struct ; in_cps : Effects.in_cps ; deadcode_sentinal : Var.t ; global_flow_info : Global_flow.info + ; fun_info : Call_graph_analysis.t ; types : Typing.typ Var.Tbl.t ; blocks : block Addr.Map.t ; closures : Closure_conversion.closure Var.Map.t @@ -786,51 +787,39 @@ module Generate (Target : Target_sig.S) = struct let rec translate_expr ctx context x e = match e with | Apply { f; args; exact; _ } -> + let* closure = load f in + let* args = expression_list (fun x -> load_and_box ctx x) args in if exact || List.length args = if Var.Set.mem x ctx.in_cps then 2 else 1 then - let rec loop acc l = - match l with - | [] -> ( - let arity = List.length args in - let funct = Var.fresh () in - let* closure = tee funct (load f) in - let* ty, funct = - Memory.load_function_pointer - ~cps:(Var.Set.mem x ctx.in_cps) - ~arity - (load funct) - in - let* b = is_closure f in - if b - then return (W.Call (f, List.rev (closure :: acc))) - else - match funct with - | W.RefFunc g -> - (* Functions with constant closures ignore their - environment. In case of partial application, we - still need the closure. *) - let* cl = if exact then Value.unit else return closure in - return (W.Call (g, List.rev (cl :: acc))) - | _ -> ( - match - if exact - then Global_flow.get_unique_closure ctx.global_flow_info f - else None - with - | Some g -> return (W.Call (g, List.rev (closure :: acc))) - | None -> return (W.Call_ref (ty, funct, List.rev (closure :: acc))) - )) - | x :: r -> - let* x = load_and_box ctx x in - loop (x :: acc) r - in - loop [] args + match + if exact then Global_flow.get_unique_closure ctx.global_flow_info f else None + with + | Some g -> + let* cl = + (* Functions with constant closures ignore their environment. *) + match closure with + | GlobalGet global -> + let* init = get_global global in + if Option.is_some init then Value.unit else return closure + | _ -> return closure + in + return (W.Call (g, args @ [ cl ])) + | None -> ( + let funct = Var.fresh () in + let* closure = tee funct (return closure) in + let* ty, funct = + Memory.load_function_pointer + ~cps:(Var.Set.mem x ctx.in_cps) + ~arity:(List.length args) + (load funct) + in + match funct with + | W.RefFunc g -> return (W.Call (g, args @ [ closure ])) + | _ -> return (W.Call_ref (ty, funct, args @ [ closure ]))) else let* apply = need_apply_fun ~cps:(Var.Set.mem x ctx.in_cps) ~arity:(List.length args) in - let* args = expression_list (fun x -> load_and_box ctx x) args in - let* closure = load f in return (W.Call (apply, args @ [ closure ])) | Block (tag, a, _, _) -> Memory.allocate @@ -848,6 +837,7 @@ module Generate (Target : Target_sig.S) = struct ~context:ctx.global_context ~closures:ctx.closures ~cps:(Var.Set.mem x ctx.in_cps) + ~no_code_pointer:(Call_graph_analysis.direct_calls_only ctx.fun_info x) x | Constant c -> Constant.translate c | Special (Alias_prim _) -> assert false @@ -1272,6 +1262,7 @@ module Generate (Target : Target_sig.S) = struct ~context:ctx.global_context ~closures:ctx.closures ~cps:(Var.Set.mem f ctx.in_cps) + ~no_code_pointer:(Call_graph_analysis.direct_calls_only ctx.fun_info f) f | None -> return () in @@ -1400,6 +1391,7 @@ module Generate (Target : Target_sig.S) = struct *) ~deadcode_sentinal ~global_flow_info + ~fun_info ~types = global_context.unit_name <- unit_name; let p, closures = Closure_conversion.f p in @@ -1411,6 +1403,7 @@ module Generate (Target : Target_sig.S) = struct ; in_cps ; deadcode_sentinal ; global_flow_info + ; fun_info ; types ; blocks = p.blocks ; closures @@ -1520,9 +1513,10 @@ let start () = make_context ~value_type:Gc_target.Type.value let f ~context ~unit_name p ~live_vars ~in_cps ~deadcode_sentinal ~global_flow_data = let state, info = global_flow_data in - let p = Structure.norm p in + let fun_info = Call_graph_analysis.f p info in let types = Typing.f ~state ~info ~deadcode_sentinal p in let t = Timer.make () in + let p = Structure.norm p in let p = fix_switch_branches p in let res = G.f @@ -1532,6 +1526,7 @@ let f ~context ~unit_name p ~live_vars ~in_cps ~deadcode_sentinal ~global_flow_d ~in_cps ~deadcode_sentinal ~global_flow_info:info + ~fun_info ~types p in diff --git a/compiler/lib-wasm/target_sig.ml b/compiler/lib-wasm/target_sig.ml index 053e3be066..f3ee8be13c 100644 --- a/compiler/lib-wasm/target_sig.ml +++ b/compiler/lib-wasm/target_sig.ml @@ -174,6 +174,7 @@ module type S = sig context:Code_generation.context -> closures:Closure_conversion.closure Code.Var.Map.t -> cps:bool + -> no_code_pointer:bool -> Code.Var.t -> expression @@ -181,6 +182,7 @@ module type S = sig context:Code_generation.context -> closures:Closure_conversion.closure Code.Var.Map.t -> cps:bool + -> no_code_pointer:bool -> Code.Var.t -> unit Code_generation.t diff --git a/compiler/lib-wasm/typing.ml b/compiler/lib-wasm/typing.ml index 3e4781fcc9..2e1be67c99 100644 --- a/compiler/lib-wasm/typing.ml +++ b/compiler/lib-wasm/typing.ml @@ -4,6 +4,8 @@ open Global_flow let debug = Debug.find "typing" +let times = Debug.find "times" + module Integer = struct type kind = | Ref @@ -420,10 +422,12 @@ let solver st = Solver.f () g (propagate st) let f ~state ~info ~deadcode_sentinal p = + let t = Timer.make () in update_deps state p; let function_parameters = mark_function_parameters p in let typ = solver { state; info; function_parameters } in Var.Tbl.set typ deadcode_sentinal (Int Normalized); + if times () then Format.eprintf " type analysis: %a@." Timer.print t; if debug () then ( Var.ISet.iter diff --git a/runtime/wasm/effect.wat b/runtime/wasm/effect.wat index 05bc0ad9c2..40aa1dc7c1 100644 --- a/runtime/wasm/effect.wat +++ b/runtime/wasm/effect.wat @@ -45,7 +45,7 @@ (type $block (array (mut (ref eq)))) (type $bytes (array (mut i8))) (type $function_1 (func (param (ref eq) (ref eq)) (result (ref eq)))) - (type $closure (sub (struct (;(field i32);) (field (ref $function_1))))) + (type $closure (sub (struct (field (ref $function_1))))) (type $function_3 (func (param (ref eq) (ref eq) (ref eq) (ref eq)) (result (ref eq)))) (type $closure_3 diff --git a/runtime/wasm/obj.wat b/runtime/wasm/obj.wat index 5e6a96ea5f..4fc39ee904 100644 --- a/runtime/wasm/obj.wat +++ b/runtime/wasm/obj.wat @@ -34,11 +34,9 @@ (type $float (struct (field f64))) (type $float_array (array (mut f64))) (type $function_1 (func (param (ref eq) (ref eq)) (result (ref eq)))) - (type $closure (sub (struct (;(field i32);) (field (ref $function_1))))) - (type $closure_last_arg - (sub $closure (struct (;(field i32);) (field (ref $function_1))))) - (type $function_2 - (func (param (ref eq) (ref eq) (ref eq)) (result (ref eq)))) + (type $closure (sub (struct (field (ref $function_1))))) + (type $closure_last_arg (sub $closure (struct (field (ref $function_1))))) + (type $function_2 (func (param (ref eq) (ref eq) (ref eq)) (result (ref eq)))) (type $cps_closure (sub (struct (field (ref $function_2))))) (type $cps_closure_last_arg (sub $cps_closure (struct (field (ref $function_2)))))