diff --git a/lib/firebird/compiler/ir_gen.ex b/lib/firebird/compiler/ir_gen.ex index 24562c5..e1b9ab9 100644 --- a/lib/firebird/compiler/ir_gen.ex +++ b/lib/firebird/compiler/ir_gen.ex @@ -23,6 +23,7 @@ defmodule Firebird.Compiler.IRGen do - Enum on ranges: `Enum.reduce/3`, `Enum.sum/1`, `Enum.product/1`, `Enum.count/1`, `Enum.count/2`, `Enum.min/1`, `Enum.max/1`, `Enum.any?/2`, `Enum.all?/2`, `Enum.find/2`, `Enum.find/3`, `Enum.find_index/2` over integer ranges (compiled to efficient WASM — closed-form where possible, early-termination loops for predicates and search) - Enum.member?/2: `Enum.member?(start..stop, value)`, `Enum.member?(start..stop//step, value)` — O(1) membership test on integer ranges (closed-form bounds check + divisibility, no loop) - Enum.at/2 and Enum.at/3: `Enum.at(start..stop, index)`, `Enum.at(start..stop//step, index, default)` — O(1) element access on integer ranges with bounds check and negative index support + - Enum.reduce_while/3: `Enum.reduce_while(start..stop, init, fn elem, acc -> {:halt, val} | {:cont, val} end)` — reduction with early termination (compiled to tail-recursive loop with break) - Enum.zip_reduce/4: `Enum.zip_reduce(range1, range2, acc, fn e1, e2, acc -> body end)` — pairwise reduction of two ranges (dot products, weighted sums, etc.) - Parity tests: `is_even/1`, `is_odd/1` (also as guards and `Integer.is_even/1`, `Integer.is_odd/1`) - Integer module: `Integer.floor_div/2`, `Integer.mod/2`, `Integer.pow/2`, `Integer.is_even/1`, `Integer.is_odd/1` @@ -1230,6 +1231,49 @@ defmodule Firebird.Compiler.IRGen do compile_enum_find_index(start, stop, step, iter, body) end + # --- Enum.reduce_while/3 on integer ranges --- + # Enum.reduce_while(start..stop, init, fn elem, acc -> ... end) + # Enum.reduce_while(start..stop//step, init, fn elem, acc -> ... end) + # Compiles to a tail-recursive loop with early termination: + # __rw_N__(i, stop, acc) = + # if i > stop do acc (range exhausted → return accumulator) + # else + # evaluate body with {:halt, val} → val (return immediately) + # and {:cont, val} → __rw_N__(i + step, stop, val) (continue) + # + # The reducer function must return {:halt, value} or {:cont, value}. + # {:halt, value} stops iteration and returns value immediately. + # {:cont, value} uses value as the new accumulator and continues. + # + # Common usage patterns: + # Enum.reduce_while(1..100, 0, fn i, acc -> + # if acc + i > 100, do: {:halt, acc}, else: {:cont, acc + i} + # end) + + # Enum.reduce_while/3 with plain range + def expr_to_ir( + {{:., _, [{:__aliases__, _, [:Enum]}, :reduce_while]}, _, + [ + {:.., _, [start, stop]}, + init, + {:fn, _, [{:->, _, [[iter, acc], body]}]} + ]} + ) do + compile_enum_reduce_while(start, stop, 1, init, iter, acc, body) + end + + # Enum.reduce_while/3 with stepped range + def expr_to_ir( + {{:., _, [{:__aliases__, _, [:Enum]}, :reduce_while]}, _, + [ + {:"..//", _, [start, stop, step]}, + init, + {:fn, _, [{:->, _, [[iter, acc], body]}]} + ]} + ) do + compile_enum_reduce_while(start, stop, step, init, iter, acc, body) + end + # --- Enum.zip_reduce/4 on two integer ranges --- # Enum.zip_reduce(start1..stop1, start2..stop2, init, fn e1, e2, acc -> body end) # Enum.zip_reduce(start1..stop1//step1, start2..stop2//step2, init, fn e1, e2, acc -> body end) @@ -2155,6 +2199,302 @@ defmodule Firebird.Compiler.IRGen do end end + # --- Enum.reduce_while/3 compilation --- + # Compiles Enum.reduce_while(start..stop//step, init, fn elem, acc -> body end) + # to a tail-recursive helper with early termination: + # + # __rw_N__(i, stop, acc, ...captured) = + # if i > stop do acc + # else + # // body with {:halt, val} → val, {:cont, val} → __rw_N__(i+step, stop, val, ...captured) + # + # The body is transformed at the AST level: {:halt, val} becomes a direct + # return, {:cont, val} becomes a recursive call to the helper. + # Free variables (outer function parameters referenced in the body) are + # automatically captured and threaded through as extra helper parameters. + defp compile_enum_reduce_while( + start_ast, + stop_ast, + step_ast, + init_ast, + iter_var_ast, + acc_var_ast, + body_ast + ) do + start_ir = expr_to_ir(start_ast) + stop_ir = expr_to_ir(stop_ast) + step_ir = expr_to_ir(step_ast) |> normalize_constant_step() + + iter_var = extract_var_name(iter_var_ast) + acc_var = extract_var_name(acc_var_ast) + init_ir = expr_to_ir(init_ast) + + # Generate a unique helper function name + counter = Process.get(:__irgen_for_counter__, 0) + helper_name = String.to_atom("__rw_#{counter}__") + Process.put(:__irgen_for_counter__, counter + 1) + + # Transform the body AST: + # - {:halt, val} → val (return immediately, no recursion) + # - {:cont, val} → {:__rw_continue__, [], [val]} (marker for recursion) + transformed_body = transform_rw_body(body_ast) + + # Compile the transformed body to IR + body_ir = expr_to_ir(transformed_body) + body_ir = substitute_vars(body_ir, %{iter_var => :p0, acc_var => :p2}) + + # Detect free variables in the body (outer scope references, not iter/acc/let-bound). + # These need to be passed as extra parameters to the helper function. + bound_vars = MapSet.new([:p0, :p1, :p2, :p3]) + body_let_vars = collect_let_bound_vars(body_ir) + bound_vars = MapSet.union(bound_vars, body_let_vars) + + free_vars = + collect_ir_vars(body_ir) |> MapSet.difference(bound_vars) |> MapSet.to_list() |> Enum.sort() + + # Map free variables to parameter positions: p_base, p_base+1, ... + # Base parameter count: 3 (i, stop, acc) or 4 (i, stop, acc, step) + base_arity = + case step_ir do + {:literal, s} when is_integer(s) and s != 0 -> 3 + _ -> 4 + end + + free_var_params = + free_vars + |> Enum.with_index(base_arity) + |> Enum.map(fn {_var, idx} -> String.to_atom("p#{idx}") end) + + free_var_map = Enum.zip(free_vars, free_var_params) |> Map.new() + + # Substitute free variables with their parameter names + body_ir = substitute_vars(body_ir, free_var_map) + + # Determine step handling: constant vs dynamic + {cmp_op, core_params, recurse_core_args_fn} = + case step_ir do + {:literal, s} when is_integer(s) and s > 0 -> + {:gt_s, [:p0, :p1, :p2], + fn val_ir -> + [{:binop, :add, {:var, :p0}, step_ir}, {:var, :p1}, val_ir] + end} + + {:literal, s} when is_integer(s) and s < 0 -> + {:lt_s, [:p0, :p1, :p2], + fn val_ir -> + [{:binop, :add, {:var, :p0}, step_ir}, {:var, :p1}, val_ir] + end} + + _ -> + {:dynamic, [:p0, :p1, :p2, :p3], + fn val_ir -> + [{:binop, :add, {:var, :p0}, {:var, :p3}}, {:var, :p1}, val_ir, {:var, :p3}] + end} + end + + all_params = core_params ++ free_var_params + total_arity = length(all_params) + + # Build recurse_args_fn that includes free var pass-through + free_var_pass = Enum.map(free_var_params, fn p -> {:var, p} end) + + recurse_args_fn = fn val_ir -> + recurse_core_args_fn.(val_ir) ++ free_var_pass + end + + # Replace {:call, :__rw_continue__, [val_ir]} markers in the body IR + # with actual recursive calls to the helper function + body_ir = replace_rw_continue(body_ir, helper_name, recurse_args_fn) + + # Build the helper body with range termination check + helper_body = + case cmp_op do + :dynamic -> + pos_done = + {:binop, :and_, {:binop, :gt_s, {:var, :p3}, {:literal, 0}}, + {:binop, :gt_s, {:var, :p0}, {:var, :p1}}} + + neg_done = + {:binop, :and_, {:binop, :lt_s, {:var, :p3}, {:literal, 0}}, + {:binop, :lt_s, {:var, :p0}, {:var, :p1}}} + + done_cond = {:binop, :or_, pos_done, neg_done} + {:if, done_cond, {:var, :p2}, body_ir} + + _ -> + {:if, {:binop, cmp_op, {:var, :p0}, {:var, :p1}}, {:var, :p2}, body_ir} + end + + helper_func = %IR.Function{ + name: helper_name, + arity: total_arity, + params: all_params, + body: helper_body, + clauses: [], + type: nil + } + + helpers = Process.get(:__irgen_for_helpers__, []) + Process.put(:__irgen_for_helpers__, [helper_func | helpers]) + + # Build the call arguments: core args + free variable values from outer scope + free_var_args = Enum.map(free_vars, fn var -> {:var, var} end) + + core_call_args = + case cmp_op do + :dynamic -> [start_ir, stop_ir, init_ir, step_ir] + _ -> [start_ir, stop_ir, init_ir] + end + + {:call, helper_name, core_call_args ++ free_var_args} + end + + # Transform the body AST of reduce_while: replace {:halt, val} and {:cont, val} + # with appropriate IR-level constructs. + # {:halt, val} → val (just return the value, no recursion) + # {:cont, val} → {:__rw_continue__, [], [val]} (sentinel for recursive call) + defp transform_rw_body({:halt, val}), do: val + + defp transform_rw_body({:cont, val}), do: {:__rw_continue__, [], [val]} + + # Walk through common AST structures to find halt/cont in nested positions + defp transform_rw_body({:if, meta, [cond, [do: then_body, else: else_body]]}) do + {:if, meta, [cond, [do: transform_rw_body(then_body), else: transform_rw_body(else_body)]]} + end + + defp transform_rw_body({:if, meta, [cond, [do: then_body]]}) do + {:if, meta, [cond, [do: transform_rw_body(then_body)]]} + end + + defp transform_rw_body({:case, meta, [subject, [do: clauses]]}) do + transformed_clauses = + Enum.map(clauses, fn {:->, clause_meta, [patterns, body]} -> + {:->, clause_meta, [patterns, transform_rw_body(body)]} + end) + + {:case, meta, [subject, [do: transformed_clauses]]} + end + + defp transform_rw_body({:cond, meta, [[do: clauses]]}) do + transformed_clauses = + Enum.map(clauses, fn {:->, clause_meta, [cond, body]} -> + {:->, clause_meta, [cond, transform_rw_body(body)]} + end) + + {:cond, meta, [[do: transformed_clauses]]} + end + + defp transform_rw_body({:__block__, meta, exprs}) do + # Only the last expression in a block can be halt/cont + {init, [last]} = Enum.split(exprs, -1) + {:__block__, meta, init ++ [transform_rw_body(last)]} + end + + defp transform_rw_body(other), do: other + + # Replace {:call, :__rw_continue__, [val_ir]} markers in compiled IR + # with actual recursive calls to the helper function. + defp replace_rw_continue({:call, :__rw_continue__, [val_ir]}, helper_name, args_fn) do + val_ir = replace_rw_continue(val_ir, helper_name, args_fn) + {:call, helper_name, args_fn.(val_ir)} + end + + defp replace_rw_continue({:if, cond_ir, then_ir, else_ir}, helper_name, args_fn) do + {:if, replace_rw_continue(cond_ir, helper_name, args_fn), + replace_rw_continue(then_ir, helper_name, args_fn), + replace_rw_continue(else_ir, helper_name, args_fn)} + end + + defp replace_rw_continue({:binop, op, left, right}, helper_name, args_fn) do + {:binop, op, replace_rw_continue(left, helper_name, args_fn), + replace_rw_continue(right, helper_name, args_fn)} + end + + defp replace_rw_continue({:unaryop, op, expr}, helper_name, args_fn) do + {:unaryop, op, replace_rw_continue(expr, helper_name, args_fn)} + end + + defp replace_rw_continue({:call, name, args}, helper_name, args_fn) do + {:call, name, Enum.map(args, &replace_rw_continue(&1, helper_name, args_fn))} + end + + defp replace_rw_continue({:case, subject, clauses}, helper_name, args_fn) do + {:case, replace_rw_continue(subject, helper_name, args_fn), + Enum.map(clauses, fn {pat, guard, body} -> + {pat, if(guard, do: replace_rw_continue(guard, helper_name, args_fn), else: nil), + replace_rw_continue(body, helper_name, args_fn)} + end)} + end + + defp replace_rw_continue({:block, exprs}, helper_name, args_fn) do + {:block, Enum.map(exprs, &replace_rw_continue(&1, helper_name, args_fn))} + end + + defp replace_rw_continue({:let, name, value}, helper_name, args_fn) do + {:let, name, replace_rw_continue(value, helper_name, args_fn)} + end + + defp replace_rw_continue(other, _helper_name, _args_fn), do: other + + # Collect all variable names referenced in an IR expression. + defp collect_ir_vars({:var, name}), do: MapSet.new([name]) + + defp collect_ir_vars({:binop, _, left, right}), + do: MapSet.union(collect_ir_vars(left), collect_ir_vars(right)) + + defp collect_ir_vars({:unaryop, _, expr}), do: collect_ir_vars(expr) + + defp collect_ir_vars({:call, _, args}), + do: Enum.reduce(args, MapSet.new(), fn a, acc -> MapSet.union(acc, collect_ir_vars(a)) end) + + defp collect_ir_vars({:if, c, t, e}), + do: MapSet.union(collect_ir_vars(c), MapSet.union(collect_ir_vars(t), collect_ir_vars(e))) + + defp collect_ir_vars({:case, subject, clauses}) do + clause_vars = + Enum.reduce(clauses, MapSet.new(), fn {_pat, guard, body}, acc -> + guard_vars = if guard, do: collect_ir_vars(guard), else: MapSet.new() + MapSet.union(acc, MapSet.union(guard_vars, collect_ir_vars(body))) + end) + + MapSet.union(collect_ir_vars(subject), clause_vars) + end + + defp collect_ir_vars({:block, exprs}), + do: Enum.reduce(exprs, MapSet.new(), fn e, acc -> MapSet.union(acc, collect_ir_vars(e)) end) + + defp collect_ir_vars({:let, _name, value}), do: collect_ir_vars(value) + defp collect_ir_vars(_), do: MapSet.new() + + # Collect all variable names bound by let expressions in an IR expression. + defp collect_let_bound_vars({:let, name, value}), + do: MapSet.put(collect_let_bound_vars(value), name) + + defp collect_let_bound_vars({:block, exprs}), + do: + Enum.reduce(exprs, MapSet.new(), fn e, acc -> + MapSet.union(acc, collect_let_bound_vars(e)) + end) + + defp collect_let_bound_vars({:if, c, t, e}), + do: + MapSet.union( + collect_let_bound_vars(c), + MapSet.union(collect_let_bound_vars(t), collect_let_bound_vars(e)) + ) + + defp collect_let_bound_vars({:case, subject, clauses}) do + clause_vars = + Enum.reduce(clauses, MapSet.new(), fn {_pat, guard, body}, acc -> + guard_vars = if guard, do: collect_let_bound_vars(guard), else: MapSet.new() + MapSet.union(acc, MapSet.union(guard_vars, collect_let_bound_vars(body))) + end) + + MapSet.union(collect_let_bound_vars(subject), clause_vars) + end + + defp collect_let_bound_vars(_), do: MapSet.new() + # --- Enum.zip_reduce/4 compilation --- # Compiles Enum.zip_reduce(range1, range2, acc, fn e1, e2, acc -> body end) # to a tail-recursive helper that iterates both ranges simultaneously: diff --git a/test/compiler/enum_reduce_while_test.exs b/test/compiler/enum_reduce_while_test.exs new file mode 100644 index 0000000..bad0900 --- /dev/null +++ b/test/compiler/enum_reduce_while_test.exs @@ -0,0 +1,258 @@ +defmodule Firebird.Compiler.EnumReduceWhileTest do + use ExUnit.Case, async: true + + alias Firebird.Compiler + + describe "Enum.reduce_while/3 on plain ranges" do + test "basic reduce_while with halt condition" do + # Sum numbers until sum exceeds 10 + source = """ + defmodule ReduceWhileBasic do + @wasm true + def sum_until_exceeds(n) do + Enum.reduce_while(1..n, 0, fn i, acc -> + new_acc = acc + i + if new_acc > 10, do: {:halt, new_acc}, else: {:cont, new_acc} + end) + end + end + """ + + {:ok, result} = Compiler.compile_source(source, optimize: true, tco: true) + assert result.wat != nil + + if Compiler.wat2wasm_available?() do + {:ok, wasm} = Compiler.wat_to_wasm(result.wat) + assert is_binary(wasm) + end + end + + test "reduce_while that never halts (exhausts range)" do + # All elements satisfy {:cont, _}, so we get the final accumulator + source = """ + defmodule ReduceWhileNoHalt do + @wasm true + def sum_all(n) do + Enum.reduce_while(1..n, 0, fn i, acc -> + {:cont, acc + i} + end) + end + end + """ + + {:ok, result} = Compiler.compile_source(source, optimize: true, tco: true) + assert result.wat != nil + end + + test "reduce_while that immediately halts" do + # First element triggers halt + source = """ + defmodule ReduceWhileImmediateHalt do + @wasm true + def first_if_positive(n) do + Enum.reduce_while(1..n, 0, fn i, acc -> + if i > 0, do: {:halt, i}, else: {:cont, acc} + end) + end + end + """ + + {:ok, result} = Compiler.compile_source(source, optimize: true, tco: true) + assert result.wat != nil + end + + test "reduce_while with accumulator transformation" do + # Multiply accumulator until it exceeds threshold + source = """ + defmodule ReduceWhileMul do + @wasm true + def grow_until(n, threshold) do + Enum.reduce_while(1..n, 1, fn i, acc -> + new_acc = acc * i + if new_acc > threshold, do: {:halt, new_acc}, else: {:cont, new_acc} + end) + end + end + """ + + {:ok, result} = Compiler.compile_source(source, optimize: true, tco: true) + assert result.wat != nil + end + end + + describe "Enum.reduce_while/3 on stepped ranges" do + test "reduce_while with positive step" do + # Sum even numbers (step 2) until sum exceeds threshold + source = """ + defmodule ReduceWhileStep do + @wasm true + def sum_evens_until(n, threshold) do + Enum.reduce_while(0..n//2, 0, fn i, acc -> + new_acc = acc + i + if new_acc > threshold, do: {:halt, new_acc}, else: {:cont, new_acc} + end) + end + end + """ + + {:ok, result} = Compiler.compile_source(source, optimize: true, tco: true) + assert result.wat != nil + end + + test "reduce_while with negative step" do + # Count down from n, halting when accumulator exceeds threshold + source = """ + defmodule ReduceWhileNegStep do + @wasm true + def countdown_sum(n, threshold) do + Enum.reduce_while(n..1//-1, 0, fn i, acc -> + new_acc = acc + i + if new_acc > threshold, do: {:halt, new_acc}, else: {:cont, new_acc} + end) + end + end + """ + + {:ok, result} = Compiler.compile_source(source, optimize: true, tco: true) + assert result.wat != nil + end + end + + describe "Enum.reduce_while/3 end-to-end execution" do + @tag :wat2wasm_required + test "sum_until_exceeds returns correct value" do + source = """ + defmodule ReduceWhileE2E do + @wasm true + def sum_until_exceeds(n) do + Enum.reduce_while(1..n, 0, fn i, acc -> + new_acc = acc + i + if new_acc > 10, do: {:halt, new_acc}, else: {:cont, new_acc} + end) + end + + @wasm true + def sum_all_within(n) do + Enum.reduce_while(1..n, 0, fn i, acc -> + {:cont, acc + i} + end) + end + + @wasm true + def find_first_above(n, threshold) do + Enum.reduce_while(1..n, 0, fn i, _acc -> + if i > threshold, do: {:halt, i}, else: {:cont, 0} + end) + end + end + """ + + {:ok, result} = Compiler.compile_source(source, optimize: true, tco: true) + {:ok, wasm} = Compiler.wat_to_wasm(result.wat) + + {:ok, instance} = Wasmex.start_link(%{bytes: wasm}) + + # sum_until_exceeds(100): 1+2+3+4+5 = 15 > 10, so returns 15 + {:ok, [value]} = Wasmex.call_function(instance, "sum_until_exceeds", [100]) + assert value == 15 + + # sum_until_exceeds(4): 1+2+3+4 = 10, not > 10. 1+2+3+4+5 is beyond range. + # Range is 1..4, sum = 10, never exceeds, returns final acc = 10 + {:ok, [value]} = Wasmex.call_function(instance, "sum_until_exceeds", [4]) + assert value == 10 + + # sum_all_within(5): just sums 1..5 = 15 (never halts) + {:ok, [value]} = Wasmex.call_function(instance, "sum_all_within", [5]) + assert value == 15 + + # find_first_above(10, 3): first i > 3 is 4 + {:ok, [value]} = Wasmex.call_function(instance, "find_first_above", [10, 3]) + assert value == 4 + end + + @tag :wat2wasm_required + test "reduce_while with stepped range execution" do + source = """ + defmodule ReduceWhileStepE2E do + @wasm true + def sum_evens_until(n, threshold) do + Enum.reduce_while(0..n//2, 0, fn i, acc -> + new_acc = acc + i + if new_acc > threshold, do: {:halt, new_acc}, else: {:cont, new_acc} + end) + end + end + """ + + {:ok, result} = Compiler.compile_source(source, optimize: true, tco: true) + {:ok, wasm} = Compiler.wat_to_wasm(result.wat) + + {:ok, instance} = Wasmex.start_link(%{bytes: wasm}) + + # 0 + 2 = 2, 2 + 4 = 6, 6 + 6 = 12 > 10, halts with 12 + {:ok, [value]} = Wasmex.call_function(instance, "sum_evens_until", [100, 10]) + assert value == 12 + + # With large threshold, sum all: 0+2+4+6+8+10 = 30 + {:ok, [value]} = Wasmex.call_function(instance, "sum_evens_until", [10, 100]) + assert value == 30 + end + end + + describe "Enum.reduce_while/3 IR generation" do + test "generates helper function with __rw_ prefix" do + source = """ + defmodule ReduceWhileIR do + @wasm true + def test_rw(n) do + Enum.reduce_while(1..n, 0, fn i, acc -> + if acc > 10, do: {:halt, acc}, else: {:cont, acc + i} + end) + end + end + """ + + {:ok, result} = Compiler.compile_source(source, wat_only: true) + # The WAT should contain our helper function + assert result.wat =~ "__rw_" + end + + test "compiles with optimization passes" do + source = """ + defmodule ReduceWhileOpt do + @wasm true + def optimized_rw(n) do + Enum.reduce_while(1..n, 0, fn i, acc -> + new_acc = acc + i + if new_acc > 100, do: {:halt, new_acc}, else: {:cont, new_acc} + end) + end + end + """ + + {:ok, result} = Compiler.compile_source(source, optimize: true, tco: true, wat_only: true) + assert result.wat != nil + # Should have TCO applied (loop/br_table structure) + assert result.wat =~ "loop" or result.wat =~ "call" + end + end + + describe "Enum.reduce_while/3 with block body" do + test "block body with let bindings" do + source = """ + defmodule ReduceWhileBlock do + @wasm true + def factorial_until(n, limit) do + Enum.reduce_while(1..n, 1, fn i, acc -> + result = acc * i + if result > limit, do: {:halt, acc}, else: {:cont, result} + end) + end + end + """ + + {:ok, result} = Compiler.compile_source(source, optimize: true, tco: true) + assert result.wat != nil + end + end +end