diff --git a/compiler/lib-wasm/generate.ml b/compiler/lib-wasm/generate.ml index 9a049ab9aa..c707b1589a 100644 --- a/compiler/lib-wasm/generate.ml +++ b/compiler/lib-wasm/generate.ml @@ -2012,9 +2012,28 @@ module Generate (Target : Target_sig.S) = struct | (`Block _ | `Catch | `Skip) as b -> b :: context | `Return -> `Skip :: context - let needed_handlers (p : program) pc = + (* Walk the dominator subtree of [pc] (the structural region of the + loop body, try body, or function body that the wrap covers), + skipping any nested try body since it carries its own wrap. *) + let needed_handlers (p : program) ~dom pc = + let fold : 'c. _ -> _ -> (Addr.t -> 'c -> 'c) -> 'c -> 'c = + fun _blocks pc' f accu -> + let block = Addr.Map.find pc' p.blocks in + let try_body = + match block.branch with + | Pushtrap ((pc'', _), _, _) -> Some pc'' + | _ -> None + in + Addr.Set.fold + (fun child acc -> + match try_body with + | Some pc'' when pc'' = child -> acc + | _ -> f child acc) + (Structure.get_edges dom pc') + accu + in Code.traverse - { fold = fold_children_skip_try_body } + { fold } (fun pc n -> let block = Addr.Map.find pc p.blocks in List.fold_left @@ -2084,8 +2103,8 @@ module Generate (Target : Target_sig.S) = struct instr W.Unreachable else body ~result_typ ~fall_through ~context - let wrap_with_handlers p pc ~result_typ ~fall_through ~context body = - let need_zero_divide_handler, need_bound_error_handler = needed_handlers p pc in + let wrap_with_handlers p ~dom pc ~result_typ ~fall_through ~context body = + let need_zero_divide_handler, need_bound_error_handler = needed_handlers p ~dom pc in wrap_with_handler need_bound_error_handler bound_error_pc @@ -2132,7 +2151,7 @@ module Generate (Target : Target_sig.S) = struct | Cond (_, (pc1, _), (pc2, _)) when pc' = pc1 && pc' = pc2 -> true | _ -> Structure.is_merge_node g pc' in - let code ~context = + let code ~result_typ ~fall_through ~context = let block = Addr.Map.find pc ctx.blocks in let* () = translate_instrs ctx context block.body in translate_node_within @@ -2149,8 +2168,39 @@ module Generate (Target : Target_sig.S) = struct in if Structure.is_loop_header g pc then - loop { params = []; result = result_typ } (code ~context:(`Block pc :: context)) - else code ~context + let outermost_toplevel_loop = + Option.is_none name_opt + && not + (List.exists + ~f:(function + | `Block pc' when pc' >= 0 -> Structure.is_loop_header g pc' + | _ -> false) + context) + in + loop + { params = []; result = result_typ } + (if outermost_toplevel_loop + then + (* The outermost loops of the toplevel function are later + hoisted into helper functions by [Hoist_loops], which + requires them to be self-contained (no [Br] escaping + the loop body). Bounds and zero-divide checks normally + branch to handler blocks installed at the top of the + function; we install handlers around the loop body + itself so those branches stay inside the loop. Nested + loops do not need their own wrap: only the outermost + one is extracted, and its body is copied verbatim into + the helper along with the surrounding handler blocks. *) + wrap_with_handlers + p + ~dom + pc + ~result_typ + ~fall_through + ~context:(`Block pc :: context) + code + else code ~result_typ ~fall_through ~context:(`Block pc :: context)) + else code ~result_typ ~fall_through ~context and translate_node_within ~result_typ ~fall_through ~pc ~l ~context = match l with | pc' :: rem -> @@ -2229,6 +2279,7 @@ module Generate (Target : Target_sig.S) = struct ~context:(extend_context fall_through context) (wrap_with_handlers p + ~dom (fst cont) (fun ~result_typ ~fall_through ~context -> translate_branch result_typ fall_through pc cont context)) @@ -2300,6 +2351,7 @@ module Generate (Target : Target_sig.S) = struct let* () = wrap_with_handlers p + ~dom pc ~result_typ:[ Option.value ~default:Type.value (unboxed_type return_type) ] ~fall_through:`Return @@ -2446,6 +2498,7 @@ module Generate (Target : Target_sig.S) = struct functions in global_context.init_code <- []; + let functions = Hoist_loops.f ~toplevel:toplevel_name functions in global_context.other_fields <- List.rev_append functions global_context.other_fields; let js_code = StringMap.bindings global_context.fragments in global_context.fragments <- StringMap.empty; diff --git a/compiler/lib-wasm/hoist_loops.ml b/compiler/lib-wasm/hoist_loops.ml new file mode 100644 index 0000000000..cd081862f9 --- /dev/null +++ b/compiler/lib-wasm/hoist_loops.ml @@ -0,0 +1,739 @@ +(* Wasm_of_ocaml compiler + * http://www.ocsigen.org/js_of_ocaml/ + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, with linking exception; + * either version 2.1 of the License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. + *) + +(* Extract loops from the toplevel function into separate helper + functions. Wasm engines tier up (compile to optimised native code) + functions containing loops. Since the toplevel function is large and + called only once, we want to avoid this: only the small helpers + should be tiered up. + + A loop is extractable when all its branches stay within the loop + body (no escaping [Br], [Return], or [Rethrow]). Whether it is + actually hoisted is then decided by a backward liveness analysis with + both normal and exceptional continuations: if a local written in the + loop may be observed after an exceptional exit, the loop is left in + place because the helper call has no write-back path on exceptions. + + Variables used in the loop are split into parameters and locals. + A variable becomes a parameter when its current value is live at the + loop head: either the body may read it before rewriting it on some + reachable path, or the loop may exit to a caller read without first + rewriting it. The remaining variables become locals of the helper. + This includes non-nullable refs whose set/get discipline is already + known to be valid because [Initialize_locals] has run on the original + function. + + Modified variables that are live after the loop are returned via a + struct (or directly when there are zero or one), whether they are + parameters or helper locals. *) + +open! Stdlib +module W = Wasm_ast + +let times = Debug.find "times" + +let stats = Debug.find "stats" + +(* Check that all branches in a loop body target labels within the loop. + [depth] counts the number of enclosing control flow constructs + including the loop itself, so it starts at 1 when called on the loop + body. A [Br n] escapes the loop when [n >= depth]. *) + +let rec is_contained_expr ~depth (e : W.expression) = + match e with + | Const _ | GlobalGet _ | Pop _ | RefFunc _ | RefNull _ -> true + | LocalGet _ -> true + | UnOp (_, e') + | I32WrapI64 e' + | I64ExtendI32 (_, e') + | F32DemoteF64 e' + | F64PromoteF32 e' + | RefI31 e' + | I31Get (_, e') + | ArrayLen e' + | StructGet (_, _, _, e') + | RefCast (_, e') + | RefTest (_, e') + | ExternConvertAny e' + | AnyConvertExtern e' -> is_contained_expr ~depth e' + | LocalTee (_, e') -> is_contained_expr ~depth e' + | BinOp (_, e1, e2) + | ArrayNew (_, e1, e2) + | ArrayNewData (_, _, e1, e2) + | ArrayGet (_, _, e1, e2) + | RefEq (e1, e2) -> is_contained_expr ~depth e1 && is_contained_expr ~depth e2 + | Br_on_cast (n, _, _, e') | Br_on_cast_fail (n, _, _, e') -> + n < depth && is_contained_expr ~depth e' + | Br_on_null (n, e') -> n < depth && is_contained_expr ~depth e' + | Call (_, l) | ArrayNewFixed (_, l) | StructNew (_, l) -> + List.for_all ~f:(is_contained_expr ~depth) l + | Call_ref (_, e', l) -> + is_contained_expr ~depth e' && List.for_all ~f:(is_contained_expr ~depth) l + | BlockExpr (_, body) -> is_contained_instrs ~depth:(depth + 1) body + | Seq (instrs, e') -> is_contained_instrs ~depth instrs && is_contained_expr ~depth e' + | IfExpr (_, cond, e1, e2) -> + is_contained_expr ~depth cond + && is_contained_expr ~depth:(depth + 1) e1 + && is_contained_expr ~depth:(depth + 1) e2 + | Try (_, body, catches) -> + is_contained_instrs ~depth:(depth + 1) body + && List.for_all ~f:(fun (_, l, _) -> l < depth) catches + +and is_contained_instr ~depth (i : W.instruction) = + match i with + | Drop e | GlobalSet (_, e) | Push e | Throw (_, e) -> is_contained_expr ~depth e + | LocalSet (_, e) -> is_contained_expr ~depth e + | Br (n, e_opt) -> ( + n < depth + && + match e_opt with + | None -> true + | Some e -> is_contained_expr ~depth e) + | Br_if (n, e) -> n < depth && is_contained_expr ~depth e + | Br_table (e, targets, default) -> + List.for_all ~f:(fun n -> n < depth) targets + && default < depth + && is_contained_expr ~depth e + | Return _ | Return_call _ | Return_call_ref _ -> false + | Loop (_, body) | Block (_, body) -> is_contained_instrs ~depth:(depth + 1) body + | If (_, e, l1, l2) -> + is_contained_expr ~depth e + && is_contained_instrs ~depth:(depth + 1) l1 + && is_contained_instrs ~depth:(depth + 1) l2 + | CallInstr (_, l) -> List.for_all ~f:(is_contained_expr ~depth) l + | Rethrow n -> n < depth + | Nop | Unreachable | Event _ -> true + | ArraySet (_, e1, e2, e3) -> + is_contained_expr ~depth e1 + && is_contained_expr ~depth e2 + && is_contained_expr ~depth e3 + | StructSet (_, _, e1, e2) -> is_contained_expr ~depth e1 && is_contained_expr ~depth e2 + +and is_contained_instrs ~depth l = List.for_all ~f:(is_contained_instr ~depth) l + +let is_extractable_loop_body body = is_contained_instrs ~depth:1 body + +(* Collect local variables referenced in an instruction list. + [reads]: variables appearing in [LocalGet]. + [writes]: variables appearing in [LocalSet] or [LocalTee]. *) + +type var_sets = + { reads : Code.Var.Set.t + ; writes : Code.Var.Set.t + } + +let rec collect_expr acc (e : W.expression) = + match e with + | Const _ | GlobalGet _ | Pop _ | RefFunc _ | RefNull _ -> acc + | LocalGet v -> { acc with reads = Code.Var.Set.add v acc.reads } + | LocalTee (v, e') -> + collect_expr { acc with writes = Code.Var.Set.add v acc.writes } e' + | UnOp (_, e') + | I32WrapI64 e' + | I64ExtendI32 (_, e') + | F32DemoteF64 e' + | F64PromoteF32 e' + | RefI31 e' + | I31Get (_, e') + | ArrayLen e' + | StructGet (_, _, _, e') + | RefCast (_, e') + | RefTest (_, e') + | ExternConvertAny e' + | AnyConvertExtern e' -> collect_expr acc e' + | BinOp (_, e1, e2) + | ArrayNew (_, e1, e2) + | ArrayNewData (_, _, e1, e2) + | ArrayGet (_, _, e1, e2) + | RefEq (e1, e2) -> collect_expr (collect_expr acc e1) e2 + | Br_on_cast (_, _, _, e') | Br_on_cast_fail (_, _, _, e') | Br_on_null (_, e') -> + collect_expr acc e' + | Call (_, l) | ArrayNewFixed (_, l) | StructNew (_, l) -> collect_exprs acc l + | Call_ref (_, e', l) -> collect_expr (collect_exprs acc l) e' + | BlockExpr (_, body) -> collect_instrs acc body + | Seq (instrs, e') -> collect_expr (collect_instrs acc instrs) e' + | IfExpr (_, cond, e1, e2) -> collect_expr (collect_expr (collect_expr acc cond) e1) e2 + | Try (_, body, _) -> collect_instrs acc body + +and collect_exprs acc l = List.fold_left ~f:collect_expr ~init:acc l + +and collect_instr acc (i : W.instruction) = + match i with + | Drop e | GlobalSet (_, e) | Push e | Throw (_, e) -> collect_expr acc e + | LocalSet (v, e) -> collect_expr { acc with writes = Code.Var.Set.add v acc.writes } e + | Br (_, Some e) | Br_if (_, e) | Br_table (e, _, _) -> collect_expr acc e + | Br (_, None) | Return None | Nop | Unreachable | Event _ | Rethrow _ -> acc + | Return (Some e) -> collect_expr acc e + | Loop (_, body) | Block (_, body) -> collect_instrs acc body + | If (_, e, l1, l2) -> collect_instrs (collect_instrs (collect_expr acc e) l1) l2 + | CallInstr (_, l) | Return_call (_, l) -> collect_exprs acc l + | Return_call_ref (_, e', l) -> collect_expr (collect_exprs acc l) e' + | ArraySet (_, e1, e2, e3) -> collect_expr (collect_expr (collect_expr acc e1) e2) e3 + | StructSet (_, _, e1, e2) -> collect_expr (collect_expr acc e1) e2 + +and collect_instrs acc l = List.fold_left ~f:collect_instr ~init:acc l + +let empty_var_sets = { reads = Code.Var.Set.empty; writes = Code.Var.Set.empty } + +let empty_vars = Code.Var.Set.empty + +(* The liveness analysis below only needs to track variables that are + written in some [Loop] body — those are the only ones whose value + the helper would write back. Other variables are either never + written (read-only in the loop, always need the pre-loop value) or + only written outside loops (never appear in any [returned_vars] or + [live_in] decision). Restricting the dataflow to this set keeps the + live sets small and speeds up the unions inside the fixpoint. + + We use a hashtable since the only operations needed are insertion + (during this pre-pass) and O(1) membership lookup at every LocalGet + in the liveness analysis. *) +let writes_in_loops body = + let acc = Code.Var.Hashtbl.create 16 in + let add_writes body = + let { writes; _ } = collect_instrs empty_var_sets body in + Code.Var.Set.iter (fun v -> Code.Var.Hashtbl.replace acc v ()) writes + in + let rec expr (e : W.expression) = + match e with + | Const _ | GlobalGet _ | Pop _ | RefFunc _ | RefNull _ | LocalGet _ -> () + | LocalTee (_, e') + | UnOp (_, e') + | I32WrapI64 e' + | I64ExtendI32 (_, e') + | F32DemoteF64 e' + | F64PromoteF32 e' + | RefI31 e' + | I31Get (_, e') + | ArrayLen e' + | StructGet (_, _, _, e') + | RefCast (_, e') + | RefTest (_, e') + | Br_on_cast (_, _, _, e') + | Br_on_cast_fail (_, _, _, e') + | Br_on_null (_, e') + | ExternConvertAny e' + | AnyConvertExtern e' -> expr e' + | BinOp (_, e1, e2) + | ArrayNew (_, e1, e2) + | ArrayNewData (_, _, e1, e2) + | ArrayGet (_, _, e1, e2) + | RefEq (e1, e2) -> + expr e1; + expr e2 + | Call (_, l) | ArrayNewFixed (_, l) | StructNew (_, l) -> List.iter ~f:expr l + | Call_ref (_, e', l) -> + List.iter ~f:expr l; + expr e' + | BlockExpr (_, body) | Try (_, body, _) -> instrs body + | Seq (l, e') -> + instrs l; + expr e' + | IfExpr (_, cond, e1, e2) -> + expr cond; + expr e1; + expr e2 + and instr (i : W.instruction) = + match i with + | Loop (_, body) -> + (* [collect_instrs] descends into nested structures, so this + captures writes from any nested loops as well. *) + add_writes body + | Block (_, body) -> instrs body + | If (_, e, l1, l2) -> + expr e; + instrs l1; + instrs l2 + | Drop e | GlobalSet (_, e) | Push e | Throw (_, e) | LocalSet (_, e) -> expr e + | Br (_, Some e) | Br_if (_, e) | Br_table (e, _, _) | Return (Some e) -> expr e + | Br (_, None) | Return None | Nop | Unreachable | Event _ | Rethrow _ -> () + | CallInstr (_, l) | Return_call (_, l) -> List.iter ~f:expr l + | Return_call_ref (_, e', l) -> + List.iter ~f:expr l; + expr e' + | ArraySet (_, e1, e2, e3) -> + expr e1; + expr e2; + expr e3 + | StructSet (_, _, e1, e2) -> + expr e1; + expr e2 + and instrs l = List.iter ~f:instr l in + instrs body; + acc + +let label_reads labels depth = + let rec find labels depth = + match labels, depth with + | live :: _, 0 -> live + | _ :: tl, n -> find tl (n - 1) + | [], _ -> assert false + in + find labels depth + +let catches_live_out labels ~exn_live_out catches = + List.fold_left + catches + ~init:exn_live_out + ~f:(fun acc (_, label, _) -> Code.Var.Set.union acc (label_reads labels label)) + +let rec live_before_expr ~labels ~live_out ~exn_live_out ~tracked_vars (e : W.expression) = + match e with + | Const _ | GlobalGet _ | Pop _ | RefFunc _ | RefNull _ -> live_out + | LocalGet v -> + if Code.Var.Hashtbl.mem tracked_vars v + then Code.Var.Set.add v live_out + else live_out + | LocalTee (v, e') -> + live_before_expr + ~labels + ~live_out:(Code.Var.Set.remove v live_out) + ~exn_live_out + ~tracked_vars + e' + | UnOp (_, e') + | I32WrapI64 e' + | I64ExtendI32 (_, e') + | F32DemoteF64 e' + | F64PromoteF32 e' + | RefI31 e' + | I31Get (_, e') + | ArrayLen e' + | StructGet (_, _, _, e') + | RefCast (_, e') + | RefTest (_, e') + | ExternConvertAny e' + | AnyConvertExtern e' -> + live_before_expr ~labels ~live_out ~exn_live_out ~tracked_vars e' + | BinOp (_, e1, e2) + | ArrayNew (_, e1, e2) + | ArrayNewData (_, _, e1, e2) + | ArrayGet (_, _, e1, e2) + | RefEq (e1, e2) -> + let live_out = live_before_expr ~labels ~live_out ~exn_live_out ~tracked_vars e2 in + live_before_expr ~labels ~live_out ~exn_live_out ~tracked_vars e1 + | Br_on_cast (n, _, _, e') | Br_on_cast_fail (n, _, _, e') | Br_on_null (n, e') -> + let live_out = Code.Var.Set.union live_out (label_reads labels n) in + live_before_expr ~labels ~live_out ~exn_live_out ~tracked_vars e' + | Call (_, l) -> + let live_out = + if Code.Var.Set.is_empty exn_live_out + then live_out + else Code.Var.Set.union live_out exn_live_out + in + live_before_exprs ~labels ~live_out ~exn_live_out ~tracked_vars l + | ArrayNewFixed (_, l) | StructNew (_, l) -> + live_before_exprs ~labels ~live_out ~exn_live_out ~tracked_vars l + | Call_ref (_, e', l) -> + let live_out = + if Code.Var.Set.is_empty exn_live_out + then live_out + else Code.Var.Set.union live_out exn_live_out + in + live_before_exprs ~labels ~live_out ~exn_live_out ~tracked_vars (l @ [ e' ]) + | BlockExpr (_, body) -> + let _, live_in = + live_before_instrs + ~labels:(live_out :: labels) + ~live_out + ~exn_live_out + ~tracked_vars + body + in + live_in + | Seq (instrs, e') -> + let live_out = live_before_expr ~labels ~live_out ~exn_live_out ~tracked_vars e' in + let _, live_in = + live_before_instrs ~labels ~live_out ~exn_live_out ~tracked_vars instrs + in + live_in + | IfExpr (_, cond, e1, e2) -> + let branch_labels = live_out :: labels in + let live1 = + live_before_expr ~labels:branch_labels ~live_out ~exn_live_out ~tracked_vars e1 + in + let live2 = + live_before_expr ~labels:branch_labels ~live_out ~exn_live_out ~tracked_vars e2 + in + let live_out = Code.Var.Set.union live1 live2 in + live_before_expr ~labels ~live_out ~exn_live_out ~tracked_vars cond + | Try (_, body, catches) -> + let exn_live_out = catches_live_out labels ~exn_live_out catches in + let _, live_in = + live_before_instrs + ~labels:(live_out :: labels) + ~live_out + ~exn_live_out + ~tracked_vars + body + in + live_in + +and live_before_exprs ~labels ~live_out ~exn_live_out ~tracked_vars l = + List.fold_right + l + ~init:live_out + ~f:(fun e live_out -> live_before_expr ~labels ~live_out ~exn_live_out ~tracked_vars e) + +and loop_live_in ~labels ~live_out ~exn_live_out ~tracked_vars body = + let rec fix live_head = + let _, live_head' = + live_before_instrs + ~labels:(live_head :: labels) + ~live_out + ~exn_live_out + ~tracked_vars + body + in + if Code.Var.Set.equal live_head live_head' then live_head else fix live_head' + in + fix empty_vars + +and live_before_loop_body ~labels ~live_out ~exn_live_out ~tracked_vars body = + let live_head = loop_live_in ~labels ~live_out ~exn_live_out ~tracked_vars body in + live_before_instrs + ~labels:(live_head :: labels) + ~live_out + ~exn_live_out + ~tracked_vars + body + +and live_before_instr + ~labels + ~rest_loops + ~live_out + ~exn_live_out + ~tracked_vars + (i : W.instruction) = + match i with + | Drop e | Push e -> + rest_loops, live_before_expr ~labels ~live_out ~exn_live_out ~tracked_vars e + | GlobalSet (_, e) -> + rest_loops, live_before_expr ~labels ~live_out ~exn_live_out ~tracked_vars e + | Throw (_, e) -> + ( rest_loops + , live_before_expr ~labels ~live_out:exn_live_out ~exn_live_out ~tracked_vars e ) + | LocalSet (v, e) -> + let live_out = Code.Var.Set.remove v live_out in + rest_loops, live_before_expr ~labels ~live_out ~exn_live_out ~tracked_vars e + | Br (n, None) -> rest_loops, label_reads labels n + | Br (n, Some e) -> + let live_out = label_reads labels n in + rest_loops, live_before_expr ~labels ~live_out ~exn_live_out ~tracked_vars e + | Br_if (n, e) -> + let live_out = Code.Var.Set.union live_out (label_reads labels n) in + rest_loops, live_before_expr ~labels ~live_out ~exn_live_out ~tracked_vars e + | Br_table (e, targets, default) -> + let live_out = + List.fold_left + ~init:(label_reads labels default) + ~f:(fun acc n -> Code.Var.Set.union acc (label_reads labels n)) + targets + in + rest_loops, live_before_expr ~labels ~live_out ~exn_live_out ~tracked_vars e + | Return None -> rest_loops, empty_vars + | Return (Some e) -> + ( rest_loops + , live_before_expr ~labels ~live_out:empty_vars ~exn_live_out ~tracked_vars e ) + | Loop (ty, body) -> + let extractable = + List.is_empty ty.result + && is_extractable_loop_body body + && + let { writes; _ } = collect_instrs empty_var_sets body in + Code.Var.Set.is_empty (Code.Var.Set.inter writes exn_live_out) + in + if extractable + then + let live_in = + loop_live_in ~labels ~live_out ~exn_live_out ~tracked_vars body + in + Some (live_out, live_in) :: rest_loops, live_in + else + let body_loops, live_in = + live_before_loop_body ~labels ~live_out ~exn_live_out ~tracked_vars body + in + (None :: body_loops) @ rest_loops, live_in + | Block (_, body) -> + let body_loops, live_in = + live_before_instrs + ~labels:(live_out :: labels) + ~live_out + ~exn_live_out + ~tracked_vars + body + in + body_loops @ rest_loops, live_in + | If (_, e, l1, l2) -> + let branch_labels = live_out :: labels in + let loops1, live1 = + live_before_instrs + ~labels:branch_labels + ~live_out + ~exn_live_out + ~tracked_vars + l1 + in + let loops2, live2 = + live_before_instrs + ~labels:branch_labels + ~live_out + ~exn_live_out + ~tracked_vars + l2 + in + let live_out = Code.Var.Set.union live1 live2 in + let live_in = live_before_expr ~labels ~live_out ~exn_live_out ~tracked_vars e in + loops1 @ loops2 @ rest_loops, live_in + | CallInstr (_, l) -> + let live_out = + if Code.Var.Set.is_empty exn_live_out + then live_out + else Code.Var.Set.union live_out exn_live_out + in + rest_loops, live_before_exprs ~labels ~live_out ~exn_live_out ~tracked_vars l + | Nop | Event _ -> rest_loops, live_out + | ArraySet (_, e1, e2, e3) -> + let live_out = live_before_expr ~labels ~live_out ~exn_live_out ~tracked_vars e3 in + let live_out = live_before_expr ~labels ~live_out ~exn_live_out ~tracked_vars e2 in + rest_loops, live_before_expr ~labels ~live_out ~exn_live_out ~tracked_vars e1 + | StructSet (_, _, e1, e2) -> + let live_out = live_before_expr ~labels ~live_out ~exn_live_out ~tracked_vars e2 in + rest_loops, live_before_expr ~labels ~live_out ~exn_live_out ~tracked_vars e1 + | Return_call (_, l) -> + ( rest_loops + , live_before_exprs ~labels ~live_out:empty_vars ~exn_live_out ~tracked_vars l ) + | Return_call_ref (_, e', l) -> + ( rest_loops + , live_before_exprs + ~labels + ~live_out:empty_vars + ~exn_live_out + ~tracked_vars + (l @ [ e' ]) ) + | Rethrow _ -> rest_loops, exn_live_out + | Unreachable -> rest_loops, empty_vars + +and live_before_instrs ~labels ~live_out ~exn_live_out ~tracked_vars l = + List.fold_right + l + ~init:([], live_out) + ~f:(fun i (rest_loops, live_out) -> + live_before_instr ~labels ~rest_loops ~live_out ~exn_live_out ~tracked_vars i) + +(* Backward dataflow over the function body, producing one entry per + [Loop] encountered in source order: [Some (live_out, live_in)] for + loops that will be hoisted — [live_out] is the set of variables read + after the loop on any normal path through the rest of the function, + and [live_in] is the fixpoint set of variables whose pre-loop value + the body may need; [None] for loops that are left in place (either + not contained, or contained but writing a variable that is read on + an exceptional exit, which the helper has no way to write back). + The forward pass in [transform_instrs] consumes the list in the same + order. *) +let loops_after_reads ~tracked_vars body = + let loops, _ = + live_before_instrs + ~labels:[] + ~live_out:empty_vars + ~exn_live_out:empty_vars + ~tracked_vars + body + in + loops + +(* Transformation context *) + +type ctx = + { var_types : W.value_type Code.Var.Hashtbl.t + ; tracked_vars : unit Code.Var.Hashtbl.t + ; mutable new_fields : W.module_field list + ; mutable extra_locals : (Code.Var.t * W.value_type) list + } + +let lookup_types ctx vars = + Code.Var.Set.fold + (fun v acc -> + match Code.Var.Hashtbl.find_opt ctx.var_types v with + | Some t -> (v, t) :: acc + | None -> acc) + vars + [] + +let extract_loop ctx ~is_initialized ~after_reads ~live_in body = + let { reads; writes } = collect_instrs empty_var_sets body in + let all_vars = Code.Var.Set.union reads writes in + (* Variables in [all_vars] that are not tracked are read in the body + and written nowhere in any loop — they trivially need their + pre-loop value, so treat them as live-in. *) + let param_vars = + Code.Var.Set.filter + (fun v -> + is_initialized v + && (Code.Var.Set.mem v live_in + || not (Code.Var.Hashtbl.mem ctx.tracked_vars v))) + all_vars + in + let local_vars = Code.Var.Set.diff all_vars param_vars in + let param_with_types = lookup_types ctx param_vars in + let local_with_types = lookup_types ctx local_vars in + let returned_vars = Code.Var.Set.inter writes after_reads in + let modified_with_types = lookup_types ctx returned_vars in + let helper_name = Code.Var.fresh_n "loop_helper" in + let args = List.map ~f:(fun (v, _) -> W.LocalGet v) param_with_types in + let param_types = List.map ~f:snd param_with_types in + let param_names = List.map ~f:fst param_with_types in + let loop_instr = W.Loop ({ W.params = []; result = [] }, body) in + let make_helper ~signature ~extra_body = + W.Function + { name = helper_name + ; exported_name = None + ; typ = None + ; signature + ; param_names + ; locals = local_with_types + ; body = loop_instr :: extra_body + } + in + match modified_with_types with + | [] -> + let signature = { W.params = param_types; result = [] } in + ctx.new_fields <- make_helper ~signature ~extra_body:[] :: ctx.new_fields; + [ W.CallInstr (helper_name, args) ] + | [ (v, vt) ] -> + let signature = { W.params = param_types; result = [ vt ] } in + ctx.new_fields <- + make_helper ~signature ~extra_body:[ Push (LocalGet v) ] :: ctx.new_fields; + [ W.LocalSet (v, Call (helper_name, args)) ] + | _ -> + let ret_type_name = Code.Var.fresh_n "loop_ret" in + let fields = + List.map ~f:(fun (_, t) -> { W.mut = false; typ = W.Value t }) modified_with_types + in + ctx.new_fields <- + W.Type + [ { name = ret_type_name; typ = Struct fields; supertype = None; final = true } + ] + :: ctx.new_fields; + let ret_ref_type = W.Ref { nullable = false; typ = Type ret_type_name } in + let signature = { W.params = param_types; result = [ ret_ref_type ] } in + let struct_new = + W.StructNew + (ret_type_name, List.map ~f:(fun (v, _) -> W.LocalGet v) modified_with_types) + in + ctx.new_fields <- + make_helper ~signature ~extra_body:[ Push struct_new ] :: ctx.new_fields; + let tmp = Code.Var.fresh_n "loop_ret" in + ctx.extra_locals <- (tmp, ret_ref_type) :: ctx.extra_locals; + W.LocalSet (tmp, Call (helper_name, args)) + :: List.mapi + ~f:(fun i (v, _) -> + W.LocalSet (v, StructGet (None, ret_type_name, i, LocalGet tmp))) + modified_with_types + +let fork_il_ctx = Initialize_locals.fork_context + +let rec transform_instrs ctx il_ctx pending_loops instrs = + List.concat_map ~f:(transform_instr ctx il_ctx pending_loops) instrs + +and transform_instr ctx il_ctx pending_loops (i : W.instruction) = + match i with + | Loop (ty, body) -> ( + match !pending_loops with + | Some (after_reads, live_in) :: tl -> + pending_loops := tl; + let result = + extract_loop + ctx + ~is_initialized:(Initialize_locals.is_initialized il_ctx) + ~after_reads + ~live_in + body + in + Initialize_locals.scan_instruction il_ctx i; + result + | None :: tl -> + pending_loops := tl; + let inner = fork_il_ctx il_ctx in + let body' = transform_instrs ctx inner pending_loops body in + Initialize_locals.scan_instruction il_ctx i; + [ W.Loop (ty, body') ] + | [] -> assert false) + | Block (ty, body) -> + let inner = fork_il_ctx il_ctx in + let body' = transform_instrs ctx inner pending_loops body in + Initialize_locals.scan_instruction il_ctx i; + [ W.Block (ty, body') ] + | If (ty, e, l1, l2) -> + let inner1 = fork_il_ctx il_ctx in + let inner2 = fork_il_ctx il_ctx in + let l1' = transform_instrs ctx inner1 pending_loops l1 in + let l2' = transform_instrs ctx inner2 pending_loops l2 in + Initialize_locals.scan_instruction il_ctx i; + [ W.If (ty, e, l1', l2') ] + | _ -> + Initialize_locals.scan_instruction il_ctx i; + [ i ] + +let f ~toplevel fields = + let t = Timer.make () in + let hoisted = ref 0 in + let left_in_place = ref 0 in + let result = + List.concat_map + ~f:(fun field -> + match field with + | W.Function ({ name; _ } as func) when Code.Var.equal name toplevel -> + let var_types = Code.Var.Hashtbl.create 16 in + List.iter2 + ~f:(fun v t -> Code.Var.Hashtbl.add var_types v t) + func.param_names + func.signature.params; + List.iter ~f:(fun (v, t) -> Code.Var.Hashtbl.add var_types v t) func.locals; + let tracked_vars = writes_in_loops func.body in + let ctx = { var_types; tracked_vars; new_fields = []; extra_locals = [] } in + let il_ctx = Initialize_locals.create_context () in + List.iter ~f:(Initialize_locals.mark_initialized il_ctx) func.param_names; + List.iter + ~f:(fun (var, typ) -> + match (typ : W.value_type) with + | I32 | I64 | F32 | F64 | Ref { nullable = true; _ } -> + Initialize_locals.mark_initialized il_ctx var + | Ref { nullable = false; _ } -> ()) + func.locals; + let loops = loops_after_reads ~tracked_vars func.body in + List.iter loops ~f:(function + | Some _ -> incr hoisted + | None -> incr left_in_place); + let pending_loops = ref loops in + let body = transform_instrs ctx il_ctx pending_loops func.body in + let func' = + W.Function { func with body; locals = func.locals @ ctx.extra_locals } + in + List.rev ctx.new_fields @ [ func' ] + | _ -> [ field ]) + fields + in + if times () then Format.eprintf " loop hoisting: %a@." Timer.print t; + if stats () + then + Format.eprintf + "Stats - loop hoisting: hoisted %d, left in place %d@." + !hoisted + !left_in_place; + result diff --git a/compiler/lib-wasm/hoist_loops.mli b/compiler/lib-wasm/hoist_loops.mli new file mode 100644 index 0000000000..1d8ef97ad9 --- /dev/null +++ b/compiler/lib-wasm/hoist_loops.mli @@ -0,0 +1,25 @@ +(* Wasm_of_ocaml compiler + * http://www.ocsigen.org/js_of_ocaml/ + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, with linking exception; + * either version 2.1 of the License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. + *) + +(** Extract loops from a function into separate helper functions. + + This avoids Wasm engines unnecessarily tiering up (compiling to optimised + native code) a large function that is called only once but contains a loop. + Only the small helper functions containing the loops will be tiered up. *) + +val f : toplevel:Code.Var.t -> Wasm_ast.module_field list -> Wasm_ast.module_field list diff --git a/compiler/lib-wasm/initialize_locals.ml b/compiler/lib-wasm/initialize_locals.ml index bb9733286a..4f37c34386 100644 --- a/compiler/lib-wasm/initialize_locals.ml +++ b/compiler/lib-wasm/initialize_locals.ml @@ -23,8 +23,13 @@ type ctx = ; uninitialized : Code.Var.Set.t ref } +let create_context () = + { initialized = Code.Var.Set.empty; uninitialized = ref Code.Var.Set.empty } + let mark_initialized ctx i = ctx.initialized <- Code.Var.Set.add i ctx.initialized +let is_initialized ctx i = Code.Var.Set.mem i ctx.initialized + let fork_context { initialized; uninitialized } = { initialized; uninitialized } let check_initialized ctx i = @@ -217,9 +222,7 @@ let has_default (ty : Wasm_ast.heap_type) = | Func | Extern | Array | Struct | None_ | Type _ -> false let f ~param_names ~locals instrs = - let ctx = - { initialized = Code.Var.Set.empty; uninitialized = ref Code.Var.Set.empty } - in + let ctx = create_context () in List.iter ~f:(fun x -> mark_initialized ctx x) param_names; List.iter ~f:(fun (var, typ) -> diff --git a/compiler/lib-wasm/initialize_locals.mli b/compiler/lib-wasm/initialize_locals.mli index c356aa396b..ac9a898f51 100644 --- a/compiler/lib-wasm/initialize_locals.mli +++ b/compiler/lib-wasm/initialize_locals.mli @@ -16,6 +16,18 @@ * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *) +type ctx + +val create_context : unit -> ctx + +val mark_initialized : ctx -> Code.Var.t -> unit + +val is_initialized : ctx -> Code.Var.t -> bool + +val fork_context : ctx -> ctx + +val scan_instruction : ctx -> Wasm_ast.instruction -> unit + val f : param_names:Wasm_ast.var list -> locals:(Wasm_ast.var * Wasm_ast.value_type) list