diff --git a/CHANGES.md b/CHANGES.md index e7481e2973..e6a570a4ec 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -25,6 +25,11 @@ whole-program builds emit reversed `Filename.concat` operands and silently broke `Filename.temp_file` * Runtime/wasm: fix string conversion from JS to OCaml (#2230) +* Compiler: avoid JS stack overflow on deep mutually recursive + direct-style calls under `--effects=double-translation`; the trampoline + pass now applies the same `caml_stack_check_depth` / trampoline pattern + used for CPS calls to the direct half of cps_needed mutually recursive + closures # 6.3.2 (2026-02-15) - Lille diff --git a/compiler/lib/driver.ml b/compiler/lib/driver.ml index 5573acae27..d8df255dd9 100644 --- a/compiler/lib/driver.ml +++ b/compiler/lib/driver.ml @@ -718,9 +718,8 @@ let optimize ~shapes ~profile ~keep_flow_data p = +> effects_and_exact_calls ~keep_flow_data ~deadcode_sentinel ~shapes profile +> map_fst5 (match Config.target (), Config.effects () with - | `JavaScript, `Disabled -> Generate_closure.f - | `JavaScript, (`Cps | `Double_translation) | `Wasm, (`Disabled | `Jspi | `Cps) - -> Fun.id + | `JavaScript, (`Disabled | `Double_translation) -> Generate_closure.f + | `JavaScript, `Cps | `Wasm, (`Disabled | `Jspi | `Cps) -> Fun.id | `JavaScript, `Jspi | `Wasm, `Double_translation -> assert false) +> map_fst5 deadcode' in diff --git a/compiler/lib/generate_closure.ml b/compiler/lib/generate_closure.ml index 659a5a6479..3581ab7bfc 100644 --- a/compiler/lib/generate_closure.ml +++ b/compiler/lib/generate_closure.ml @@ -22,6 +22,13 @@ open Code let debug_tc = Debug.find "gen_tc" +type cps_pair = + { direct_c : Code.Var.t + ; cps_c : Code.Var.t + ; cps_args : Code.Var.t list + ; cps_cont : Code.cont + } + type closure_info = { f_name : Code.Var.t ; args : Code.Var.t list @@ -29,6 +36,14 @@ type closure_info = ; tc : Code.Addr.Set.t Code.Var.Map.t ; pos : int ; cloc : Parse_info.t option + ; (* Under --effects=double-translation, the [Closure] instruction that + binds [f_name]'s direct version is followed by a sibling [Closure] for + the CPS version and a [caml_cps_closure] primitive pairing them. In + that case, [f_name] is the public paired closure (the [x] of + [Let x = caml_cps_closure(direct_c, cps_c)]), and this field records + the names and body of the CPS half so trampolines can rewrite the + triple. [None] in every other mode and for non-cps_needed closures. *) + cps_pair : cps_pair option } module SCC = Strongly_connected_components.Make (Var) @@ -66,10 +81,27 @@ let rec collect_apply pc blocks visited tc = let rec collect_closures blocks l pos = match l with + | Let (direct_c, Closure (args, ((pc, _) as cont), cloc)) + :: Let (cps_c, Closure (cps_args, cps_cont, _)) + :: Let (x, Prim (Extern "caml_cps_closure", [ Pv d; Pv c ])) + :: rem + when Var.equal d direct_c && Var.equal c cps_c -> + let _, tc = collect_apply pc blocks Addr.Set.empty Var.Map.empty in + let l, rem = collect_closures blocks rem (succ pos) in + { f_name = x + ; args + ; cont + ; tc + ; pos + ; cloc + ; cps_pair = Some { direct_c; cps_c; cps_args; cps_cont } + } + :: l + , rem | Let (f_name, Closure (args, ((pc, _) as cont), cloc)) :: rem -> let _, tc = collect_apply pc blocks Addr.Set.empty Var.Map.empty in let l, rem = collect_closures blocks rem (succ pos) in - { f_name; args; cont; tc; pos; cloc } :: l, rem + { f_name; args; cont; tc; pos; cloc; cps_pair = None } :: l, rem | rem -> [], rem let group_closures closures_map = @@ -96,6 +128,19 @@ type w = ; code : Code.instr ; wrapper : Code.instr } + | Paired_one of + { name : Code.Var.t + ; direct_code : Code.instr + ; cps_code : Code.instr + ; pair_code : Code.instr + } + | Paired_wrapper of + { name : Code.Var.t + ; inner_direct_code : Code.instr + ; wrapper_code : Code.instr + ; cps_code : Code.instr + ; pair_code : Code.instr + } module Trampoline = struct let direct_call_block ~counter ~x ~f ~args = @@ -160,124 +205,341 @@ module Trampoline = struct let wrapper_closure pc args cloc = Closure (args, (pc, []), cloc) - let f free_pc blocks closures_map component = - match component with - | SCC.No_loop id -> - let ci = Var.Map.find id closures_map in - let instr = Let (ci.f_name, Closure (ci.args, ci.cont, ci.cloc)) in - free_pc, blocks, [ One { name = ci.f_name; code = instr } ] - | SCC.Has_loop all -> - if debug_tc () - then ( - Format.eprintf "Detect cycles of size (%d).\n%!" (List.length all); - Format.eprintf - "%a\n%!" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.pp_print_string fmt ", ") - Var.print) - all); - let tailcall_max_depth = Config.Param.tailcall_max_depth () in - let all = + let has_loop free_pc blocks closures_map all = + if debug_tc () + then ( + Format.eprintf "Detect cycles of size (%d).\n%!" (List.length all); + Format.eprintf + "%a\n%!" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.pp_print_string fmt ", ") + Var.print) + all); + let tailcall_max_depth = Config.Param.tailcall_max_depth () in + let all = + List.map all ~f:(fun id -> + ( (if tailcall_max_depth = 0 then None else Some (Code.Var.fresh_n "counter")) + , Var.Map.find id closures_map )) + in + let blocks, free_pc, closures = + List.fold_left + all + ~init:(blocks, free_pc, []) + ~f:(fun (blocks, free_pc, closures) (counter, ci) -> + if debug_tc () + then Format.eprintf "Rewriting for %a\n%!" Var.print ci.f_name; + let new_f = Code.Var.fork ci.f_name in + let new_args = List.map ci.args ~f:Code.Var.fork in + let wrapper_pc = free_pc in + let free_pc = free_pc + 1 in + let new_counter = Option.map counter ~f:Code.Var.fork in + let start_loc = + let block = Addr.Map.find (fst ci.cont) blocks in + match block.body with + | Event loc :: _ -> loc + | _ -> Parse_info.zero + in + let wrapper_block = + wrapper_block new_f ~args:new_args ~counter:new_counter start_loc + in + let blocks = Addr.Map.add wrapper_pc wrapper_block blocks in + let instr_wrapper = + Let (ci.f_name, wrapper_closure wrapper_pc new_args ci.cloc) + in + let instr_real = + match counter with + | None -> Let (new_f, Closure (ci.args, ci.cont, ci.cloc)) + | Some counter -> + Let (new_f, Closure (counter :: ci.args, ci.cont, ci.cloc)) + in + let counter_and_pc = + List.fold_left all ~init:[] ~f:(fun acc (counter, ci2) -> + try + let pcs = Addr.Set.elements (Var.Map.find ci.f_name ci2.tc) in + List.map pcs ~f:(fun x -> counter, x) @ acc + with Not_found -> acc) + in + let blocks, free_pc = + List.fold_left + counter_and_pc + ~init:(blocks, free_pc) + ~f:(fun (blocks, free_pc) (counter, pc) -> + if debug_tc () then Format.eprintf "Rewriting tc in %d\n%!" pc; + let block = Addr.Map.find pc blocks in + let direct_call_pc = free_pc in + let bounce_call_pc = free_pc + 1 in + let free_pc = free_pc + 2 in + match List.rev block.body with + | Let (x, Apply { f; args; exact = true }) :: rem_rev -> + assert (Var.equal f ci.f_name); + let blocks = + Addr.Map.add + direct_call_pc + (direct_call_block ~counter ~x ~f:new_f ~args) + blocks + in + let blocks = + Addr.Map.add + bounce_call_pc + (bounce_call_block ~x ~f:new_f ~args) + blocks + in + let block = + match counter with + | None -> + let branch = Branch (bounce_call_pc, []) in + { block with body = List.rev rem_rev; branch } + | Some counter -> + let direct = Code.Var.fresh () in + let branch = + Cond (direct, (direct_call_pc, []), (bounce_call_pc, [])) + in + let last = + Let + ( direct + , Prim + ( Lt + , [ Pv counter + ; Pc + (Int (Targetint.of_int_exn tailcall_max_depth)) + ] ) ) + in + { block with body = List.rev (last :: rem_rev); branch } + in + let blocks = Addr.Map.remove pc blocks in + Addr.Map.add pc block blocks, free_pc + | _ -> assert false) + in + ( blocks + , free_pc + , Wrapper { name = ci.f_name; code = instr_real; wrapper = instr_wrapper } + :: closures )) + in + free_pc, blocks, closures +end + +(* Trampoline variant for --effects=double-translation. The SCC consists of + [caml_cps_closure]-paired closures. We apply the same depth-guarded + trampoline strategy that [--effects=cps] uses for ordinary CPS calls + (cf. effects.ml emit of [caml_stack_check_depth ? f(args) : + caml_trampoline_return(f, args, 0)]), only here the call we are guarding + is a plain direct-style call between mutually recursive functions. + + For each member of the SCC: + - The original direct closure is renamed [new_direct_c], and a small + wrapper that drives a [caml_direct_trampoline] loop takes the direct + slot of [caml_cps_closure]. External direct callers go through the + wrapper; sibling tail calls within the SCC bypass it. + - Every recursive tail call in the inner-direct body is split into a + direct branch ([Apply new_direct_c_i]) and a bounce branch that returns + [caml_trampoline_return(new_direct_c_i, args, 1)]. The bounce object + bubbles up to the wrapper's trampoline loop, which reapplies the + callee with a fresh stack budget. [caml_stack_check_depth] gates + between the two, exactly like the CPS-side check. *) +module Trampoline_dt = struct + let direct_call_block ~x ~f ~args = + let return = Code.Var.fork x in + { params = [] + ; body = [ Let (return, Apply { f; args; exact = true }) ] + ; branch = Return return + } + + let bounce_call_block ~x ~f ~args = + let return = Code.Var.fork x in + let new_args = Code.Var.fresh () in + { params = [] + ; body = + [ Let + (new_args, Prim (Extern "%js_array", List.map args ~f:(fun x -> Pv x))) + ; Let + ( return + , Prim + ( Extern "caml_trampoline_return" + , [ Pv f; Pv new_args; Pc (Int Targetint.one) ] ) ) + ] + ; branch = Return return + } + + let wrapper_block inner ~args loc = + let args_arr = Code.Var.fresh () in + let result = Code.Var.fresh () in + { params = [] + ; body = + [ Event loc + ; Let + (args_arr, Prim (Extern "%js_array", List.map args ~f:(fun x -> Pv x))) + ; Let + ( result + , Prim + (Extern "caml_direct_trampoline", [ Pv inner; Pv args_arr ]) ) + ] + ; branch = Return result + } + + let wrapper_closure pc args cloc = Closure (args, (pc, []), cloc) + + let has_loop free_pc blocks closures_map all = + if debug_tc () + then ( + Format.eprintf + "Detect cycles (paired, double-translation) of size (%d).\n%!" + (List.length all); + Format.eprintf + "%a\n%!" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.pp_print_string fmt ", ") + Var.print) + all); + let all = List.map all ~f:(fun id -> Var.Map.find id closures_map) in + let blocks, free_pc, closures = + List.fold_left + all + ~init:(blocks, free_pc, []) + ~f:(fun (blocks, free_pc, closures) ci -> + let { direct_c; cps_c; cps_args; cps_cont } = + match ci.cps_pair with + | Some p -> p + | None -> assert false + in + if debug_tc () + then Format.eprintf "Rewriting (paired) for %a\n%!" Var.print ci.f_name; + let new_direct_c = Code.Var.fork direct_c in + let new_args = List.map ci.args ~f:Code.Var.fork in + let wrapper_pc = free_pc in + let free_pc = free_pc + 1 in + let start_loc = + let block = Addr.Map.find (fst ci.cont) blocks in + match block.body with + | Event loc :: _ -> loc + | _ -> Parse_info.zero + in + let wrapper_b = wrapper_block new_direct_c ~args:new_args start_loc in + let blocks = Addr.Map.add wrapper_pc wrapper_b blocks in + let wrapper_c = Code.Var.fresh_n "wrapper" in + let wrapper_code = + Let (wrapper_c, wrapper_closure wrapper_pc new_args ci.cloc) + in + let inner_direct_code = + Let (new_direct_c, Closure (ci.args, ci.cont, ci.cloc)) + in + let cps_code = Let (cps_c, Closure (cps_args, cps_cont, None)) in + let pair_code = + Let + ( ci.f_name + , Prim (Extern "caml_cps_closure", [ Pv wrapper_c; Pv cps_c ]) ) + in + let scc_callees = + List.fold_left all ~init:[] ~f:(fun acc ci2 -> + try + let pcs = Addr.Set.elements (Var.Map.find ci.f_name ci2.tc) in + pcs @ acc + with Not_found -> acc) + in + let blocks, free_pc = + List.fold_left + scc_callees + ~init:(blocks, free_pc) + ~f:(fun (blocks, free_pc) pc -> + if debug_tc () + then Format.eprintf "Rewriting tc (paired) in %d\n%!" pc; + let block = Addr.Map.find pc blocks in + let direct_call_pc = free_pc in + let bounce_call_pc = free_pc + 1 in + let free_pc = free_pc + 2 in + match List.rev block.body with + | Let (x, Apply { f; args; exact = true }) :: rem_rev -> + assert (Var.equal f ci.f_name); + let blocks = + Addr.Map.add + direct_call_pc + (direct_call_block ~x ~f:new_direct_c ~args) + blocks + in + let blocks = + Addr.Map.add + bounce_call_pc + (bounce_call_block ~x ~f:new_direct_c ~args) + blocks + in + let direct = Code.Var.fresh () in + let branch = + Cond (direct, (direct_call_pc, []), (bounce_call_pc, [])) + in + let last = + Let (direct, Prim (Extern "caml_stack_check_depth", [])) + in + let block = + { block with body = List.rev (last :: rem_rev); branch } + in + let blocks = Addr.Map.remove pc blocks in + Addr.Map.add pc block blocks, free_pc + | _ -> assert false) + in + ( blocks + , free_pc + , Paired_wrapper + { name = ci.f_name + ; inner_direct_code + ; wrapper_code + ; cps_code + ; pair_code + } + :: closures )) + in + free_pc, blocks, closures +end + +let dispatch_component free_pc blocks closures_map component = + match component with + | SCC.No_loop id -> + let ci = Var.Map.find id closures_map in + (match ci.cps_pair with + | None -> + let instr = Let (ci.f_name, Closure (ci.args, ci.cont, ci.cloc)) in + free_pc, blocks, [ One { name = ci.f_name; code = instr } ] + | Some { direct_c; cps_c; cps_args; cps_cont } -> + let direct_code = Let (direct_c, Closure (ci.args, ci.cont, ci.cloc)) in + let cps_code = Let (cps_c, Closure (cps_args, cps_cont, None)) in + let pair_code = + Let + ( ci.f_name + , Prim (Extern "caml_cps_closure", [ Pv direct_c; Pv cps_c ]) ) + in + ( free_pc + , blocks + , [ Paired_one { name = ci.f_name; direct_code; cps_code; pair_code } ] + )) + | SCC.Has_loop all -> + let all_paired = + List.for_all all ~f:(fun id -> + Option.is_some (Var.Map.find id closures_map).cps_pair) + in + let any_paired = + List.exists all ~f:(fun id -> + Option.is_some (Var.Map.find id closures_map).cps_pair) + in + assert (Bool.equal any_paired all_paired); + if all_paired + then Trampoline_dt.has_loop free_pc blocks closures_map all + else if not (Poly.equal (Config.effects ()) `Disabled) + then + (* Under --effects=cps/double-translation/jspi, unpaired SCCs are + rare and the classic [Trampoline] transformation is not safe to + run on them (its bounce mechanism doesn't compose with the CPS + call-gen). Emit the closures back unchanged; deep recursion in + these will still risk overflow, but they should generally not + occur because partial_cps_analysis promotes mutually recursive + functions to cps_needed. *) + let closures = List.map all ~f:(fun id -> - ( (if tailcall_max_depth = 0 then None else Some (Code.Var.fresh_n "counter")) - , Var.Map.find id closures_map )) - in - let blocks, free_pc, closures = - List.fold_left - all - ~init:(blocks, free_pc, []) - ~f:(fun (blocks, free_pc, closures) (counter, ci) -> - if debug_tc () - then Format.eprintf "Rewriting for %a\n%!" Var.print ci.f_name; - let new_f = Code.Var.fork ci.f_name in - let new_args = List.map ci.args ~f:Code.Var.fork in - let wrapper_pc = free_pc in - let free_pc = free_pc + 1 in - let new_counter = Option.map counter ~f:Code.Var.fork in - let start_loc = - let block = Addr.Map.find (fst ci.cont) blocks in - match block.body with - | Event loc :: _ -> loc - | _ -> Parse_info.zero - in - let wrapper_block = - wrapper_block new_f ~args:new_args ~counter:new_counter start_loc - in - let blocks = Addr.Map.add wrapper_pc wrapper_block blocks in - let instr_wrapper = - Let (ci.f_name, wrapper_closure wrapper_pc new_args ci.cloc) - in - let instr_real = - match counter with - | None -> Let (new_f, Closure (ci.args, ci.cont, ci.cloc)) - | Some counter -> - Let (new_f, Closure (counter :: ci.args, ci.cont, ci.cloc)) - in - let counter_and_pc = - List.fold_left all ~init:[] ~f:(fun acc (counter, ci2) -> - try - let pcs = Addr.Set.elements (Var.Map.find ci.f_name ci2.tc) in - List.map pcs ~f:(fun x -> counter, x) @ acc - with Not_found -> acc) - in - let blocks, free_pc = - List.fold_left - counter_and_pc - ~init:(blocks, free_pc) - ~f:(fun (blocks, free_pc) (counter, pc) -> - if debug_tc () then Format.eprintf "Rewriting tc in %d\n%!" pc; - let block = Addr.Map.find pc blocks in - let direct_call_pc = free_pc in - let bounce_call_pc = free_pc + 1 in - let free_pc = free_pc + 2 in - match List.rev block.body with - | Let (x, Apply { f; args; exact = true }) :: rem_rev -> - assert (Var.equal f ci.f_name); - let blocks = - Addr.Map.add - direct_call_pc - (direct_call_block ~counter ~x ~f:new_f ~args) - blocks - in - let blocks = - Addr.Map.add - bounce_call_pc - (bounce_call_block ~x ~f:new_f ~args) - blocks - in - let block = - match counter with - | None -> - let branch = Branch (bounce_call_pc, []) in - { block with body = List.rev rem_rev; branch } - | Some counter -> - let direct = Code.Var.fresh () in - let branch = - Cond (direct, (direct_call_pc, []), (bounce_call_pc, [])) - in - let last = - Let - ( direct - , Prim - ( Lt - , [ Pv counter - ; Pc - (Int (Targetint.of_int_exn tailcall_max_depth)) - ] ) ) - in - { block with body = List.rev (last :: rem_rev); branch } - in - let blocks = Addr.Map.remove pc blocks in - Addr.Map.add pc block blocks, free_pc - | _ -> assert false) - in - ( blocks - , free_pc - , Wrapper { name = ci.f_name; code = instr_real; wrapper = instr_wrapper } - :: closures )) + let ci = Var.Map.find id closures_map in + One + { name = ci.f_name + ; code = Let (ci.f_name, Closure (ci.args, ci.cont, ci.cloc)) + }) in free_pc, blocks, closures -end + else Trampoline.has_loop free_pc blocks closures_map all let rec rewrite_closures free_pc blocks body : int * _ * _ list = match body with @@ -294,7 +556,7 @@ let rec rewrite_closures free_pc blocks body : int * _ * _ list = ~init:(free_pc, blocks, []) ~f:(fun (free_pc, blocks, acc) component -> let free_pc, blocks, closures = - Trampoline.f free_pc blocks closures_map component + dispatch_component free_pc blocks closures_map component in let intrs = closures :: acc in free_pc, blocks, intrs) @@ -304,12 +566,19 @@ let rec rewrite_closures free_pc blocks body : int * _ * _ list = let pos = function | One { name; _ } -> pos_of_var name | Wrapper { name; _ } -> pos_of_var name + | Paired_one { name; _ } -> pos_of_var name + | Paired_wrapper { name; _ } -> pos_of_var name in List.flatten closures |> List.sort ~cmp:(fun a b -> compare (pos a) (pos b)) |> List.concat_map ~f:(function | One { code; _ } -> [ code ] - | Wrapper { code; wrapper; _ } -> [ code; wrapper ]) + | Wrapper { code; wrapper; _ } -> [ code; wrapper ] + | Paired_one { direct_code; cps_code; pair_code; _ } -> + [ direct_code; cps_code; pair_code ] + | Paired_wrapper + { inner_direct_code; wrapper_code; cps_code; pair_code; _ } -> + [ inner_direct_code; wrapper_code; cps_code; pair_code ]) in let free_pc, blocks, rem = rewrite_closures free_pc blocks rem in free_pc, blocks, closures @ rem @@ -337,8 +606,8 @@ let f p : Code.program = let f p = assert ( match Config.effects () with - | `Disabled | `Jspi -> true - | `Cps | `Double_translation -> false); + | `Disabled | `Jspi | `Double_translation -> true + | `Cps -> false); let open Config.Param in match tailcall_optim () with | TcNone -> p diff --git a/compiler/tests-jsoo/lib-effects/mutual_recursion_deep.ml b/compiler/tests-jsoo/lib-effects/mutual_recursion_deep.ml new file mode 100644 index 0000000000..f8c9978b22 --- /dev/null +++ b/compiler/tests-jsoo/lib-effects/mutual_recursion_deep.ml @@ -0,0 +1,24 @@ +(* Regression test for stack overflow on deep mutually recursive direct-style + functions under --effects=double-translation. + + Under --effects=cps every recursive call is CPS and bounces via + caml_stack_check_depth, so deep recursion is safe. Native and bytecode also + handle mutual tail recursion. Historically, under + --effects=double-translation the direct version of each function called + its siblings directly without any depth guard, so a sufficiently deep + mutual recursion blew the JS stack. The fix wraps cps_needed mutually + recursive direct closures with a depth-guarded trampoline that bounces + into the CPS partner when the JS stack budget is exhausted; this test + guards the fix. *) + +let rec ping n acc = if n = 0 then acc else pong (n - 1) (acc + 1) +and pong n acc = if n = 0 then acc else ping (n - 1) (acc + 1) + +let run n = + match ping n 0 with + | v -> Printf.sprintf "ok: %d" v + | exception Stack_overflow -> "stack overflow" + +let%expect_test "deep mutual recursion under double-translation" = + print_endline (run 1_000_000); + [%expect {| ok: 1000000 |}] diff --git a/runtime/js/jslib.js b/runtime/js/jslib.js index 2f95c36779..7b650af586 100644 --- a/runtime/js/jslib.js +++ b/runtime/js/jslib.js @@ -76,6 +76,33 @@ function caml_stack_check_depth() { return --caml_stack_depth > 0; } +//Provides: caml_direct_trampoline +//If: effects +//If: doubletranslate +//Requires: caml_stack_depth +// Entry trampoline for the direct version of a cps_needed mutually +// recursive function under --effects=double-translation. Sets up a stack +// budget, runs the inner direct body, then loops on caml_trampoline_return +// bounce objects until a plain value comes back. Mirrors the CPS-side +// trampoline (see caml_callback / caml_resume): each iteration starts with +// a fresh stack budget and dispatches the bounce target directly (the +// joo_tramp is always the inner direct closure of an SCC member, never a +// paired wrapper, so plain apply is correct). +function caml_direct_trampoline(f, args) { + var saved = caml_stack_depth; + try { + caml_stack_depth = 40; + var res = f.apply(null, args); + while (res?.joo_tramp) { + caml_stack_depth = 40; + res = res.joo_tramp.apply(null, res.joo_args); + } + return res; + } finally { + caml_stack_depth = saved; + } +} + //Provides: caml_callback //If: !effects //Requires:caml_call_gen