diff --git a/crates/observers/src/traits.rs b/crates/observers/src/traits.rs index 4f83d3d..96309ad 100644 --- a/crates/observers/src/traits.rs +++ b/crates/observers/src/traits.rs @@ -78,9 +78,11 @@ where P: EquationProblem<1, Input = M::Input, Output = M::Output>, { fn residual(&self) -> f64 { - match self.result() { - Ok(eval) => eval.residuals[0], - Err(_) => f64::NAN, + match self { + bisection::Event::Evaluated { point, .. } => point.residual, + bisection::Event::ModelFailed { .. } | bisection::Event::ProblemFailed { .. } => { + f64::NAN + } } } } @@ -209,6 +211,22 @@ mod tests { } } + struct FailingEqProblem; + + impl EquationProblem<1> for FailingEqProblem { + type Input = f64; + type Output = f64; + type Error = Failure; + + fn input(&self, x: &[f64; 1]) -> Result { + Ok(x[0]) + } + + fn residuals(&self, _: &f64, _: &f64) -> Result<[f64; 1], Failure> { + Err(Failure) + } + } + struct FailingOptProblem; impl OptimizationProblem<1> for FailingOptProblem { @@ -227,46 +245,52 @@ mod tests { // --- HasResidual for bisection::Event --- + fn test_bracket() -> bisection::Bracket { + bisection::Bracket::new( + (0.0, bisection::Sign::Negative), + (1.0, bisection::Sign::Positive), + ) + .unwrap() + } + #[test] - fn bisection_residual_ok() { - // Drive the solver one step to get a real event with a valid residual. - // LinearProblem: residual = output = input = x, so residual ≠ NAN. - let model = Identity; - let problem = LinearProblem; - let mut residual_seen = None; - let _ = bisection::solve( - &model, - &problem, - [-1.0, 1.0], - &bisection::Config::default(), - |event: &bisection::Event<'_, Identity, LinearProblem>| { - if residual_seen.is_none() { - residual_seen = Some(event.residual()); - } - None - }, - ); - let r = residual_seen.expect("at least one event emitted"); - assert!(r.is_finite(), "expected finite residual, got {r}"); + fn bisection_residual_evaluated() { + let input = 1.0_f64; + let output = 1.0_f64; + let bracket = test_bracket(); + let event: bisection::Event<'_, Identity, LinearProblem> = bisection::Event::Evaluated { + point: bisection::Point::new(1.0, 0.5), + input: &input, + output: &output, + bracket: &bracket, + }; + assert_relative_eq!(event.residual(), 0.5); + } + + #[test] + fn bisection_residual_nan_on_model_failed() { + let error = Failure; + let bracket = test_bracket(); + let event: bisection::Event<'_, FailingModel, LinearProblem> = + bisection::Event::ModelFailed { + x: 0.5, + error: &error, + bracket: &bracket, + }; + assert!(event.residual().is_nan()); } #[test] - fn bisection_residual_nan_on_model_error() { - // FailingModel always errors, so every event result is Err → NAN. - let model = FailingModel; - let problem = LinearProblem; - let mut got_nan = false; - let _ = bisection::solve( - &model, - &problem, - [-1.0, 1.0], - &bisection::Config::default(), - |event: &bisection::Event<'_, FailingModel, LinearProblem>| { - got_nan = event.residual().is_nan(); - Some(bisection::Action::StopEarly) - }, - ); - assert!(got_nan); + fn bisection_residual_nan_on_problem_failed() { + let error = Failure; + let bracket = test_bracket(); + let event: bisection::Event<'_, Identity, FailingEqProblem> = + bisection::Event::ProblemFailed { + x: 0.5, + error: &error, + bracket: &bracket, + }; + assert!(event.residual().is_nan()); } // --- HasObjective for golden_section::Event --- diff --git a/crates/solvers/src/equation.rs b/crates/solvers/src/equation.rs index 5fda786..20e2fda 100644 --- a/crates/solvers/src/equation.rs +++ b/crates/solvers/src/equation.rs @@ -10,8 +10,11 @@ //! //! [`EquationProblem`]: twine_core::EquationProblem +mod best; mod evaluate; -pub use evaluate::{EvalError, EvaluateResult, Evaluation, evaluate}; - pub mod bisection; +pub mod bracket; +pub mod solution; + +pub use evaluate::{EvalError, Evaluation, evaluate}; diff --git a/crates/solvers/src/equation/bisection/best.rs b/crates/solvers/src/equation/best.rs similarity index 70% rename from crates/solvers/src/equation/bisection/best.rs rename to crates/solvers/src/equation/best.rs index f15a773..de47326 100644 --- a/crates/solvers/src/equation/bisection/best.rs +++ b/crates/solvers/src/equation/best.rs @@ -1,23 +1,24 @@ use crate::equation::Evaluation; -use super::{Error, Solution, Status}; +use super::solution::{Solution, Status}; /// Tracks the best evaluation encountered so far. /// /// The best evaluation is defined by minimum residual magnitude. -/// The `Option` lets us represent the state before any successful evaluation. -pub(super) struct Best { +/// The bracket can shrink without any successful evaluation (via observer +/// recovery), so `None` is a normal operating state — not an error condition. +pub(crate) struct Best { eval: Option>, } impl Best { /// Creates an empty best tracker. - pub(super) fn empty() -> Self { + pub(crate) fn empty() -> Self { Self { eval: None } } /// Updates the best evaluation if the residual magnitude improves. - pub(super) fn update(&mut self, eval: Evaluation) { + pub(crate) fn update(&mut self, eval: Evaluation) { if let Some(best) = self.eval.as_ref() && eval.residuals[0].abs() >= best.residuals[0].abs() { @@ -27,20 +28,18 @@ impl Best { } /// Returns true if the best residual meets the tolerance. - pub(super) fn is_residual_converged(&self, residual_tol: f64) -> bool { + pub(crate) fn is_residual_converged(&self, residual_tol: f64) -> bool { self.eval .as_ref() .is_some_and(|eval| eval.residuals[0].abs() <= residual_tol) } - /// Finalizes the solver using the best available evaluation. + /// Builds a solution from the best evaluation. /// - /// # Errors - /// - /// Returns `Error::NoSuccessfulEvaluation` if no successful evaluation is stored. - pub(super) fn finish(self, status: Status, iters: usize) -> Result, Error> { - let eval = self.eval.ok_or(Error::NoSuccessfulEvaluation)?; - Ok(Solution { + /// Returns `None` if no successful evaluation has been recorded. + pub(crate) fn into_solution(self, status: Status, iters: usize) -> Option> { + let eval = self.eval?; + Some(Solution { status, x: eval.x[0], residual: eval.residuals[0], @@ -74,7 +73,7 @@ mod tests { best.update(eval(3.0, 1.0)); let solution = best - .finish(Status::StoppedByObserver, 0) + .into_solution(Status::StoppedByObserver, 0) .expect("best eval"); assert_relative_eq!(solution.x, 3.0); @@ -88,7 +87,7 @@ mod tests { best.update(eval(2.0, 2.0)); let solution = best - .finish(Status::StoppedByObserver, 0) + .into_solution(Status::StoppedByObserver, 0) .expect("best eval"); assert_relative_eq!(solution.x, 1.0); @@ -96,7 +95,7 @@ mod tests { } #[test] - fn residual_converged_requires_best() { + fn residual_converged_requires_eval() { let best: Best<(), ()> = Best::empty(); assert!(!best.is_residual_converged(1e-3)); } @@ -111,18 +110,17 @@ mod tests { } #[test] - fn finish_errors_without_eval() { + fn into_solution_returns_none_without_eval() { let best: Best<(), ()> = Best::empty(); - let err = best.finish(Status::StoppedByObserver, 0); - assert!(matches!(err, Err(Error::NoSuccessfulEvaluation))); + assert!(best.into_solution(Status::StoppedByObserver, 0).is_none()); } #[test] - fn finish_builds_solution() { + fn into_solution_builds_solution() { let mut best = Best::empty(); best.update(eval(2.0, -1.25)); - let solution = best.finish(Status::Converged, 4).expect("best eval"); + let solution = best.into_solution(Status::Converged, 4).expect("best eval"); assert_eq!(solution.status, Status::Converged); assert_eq!(solution.iters, 4); diff --git a/crates/solvers/src/equation/bisection.rs b/crates/solvers/src/equation/bisection.rs index 7f108b2..e2fbecd 100644 --- a/crates/solvers/src/equation/bisection.rs +++ b/crates/solvers/src/equation/bisection.rs @@ -1,64 +1,56 @@ mod action; -mod best; -mod bracket; mod config; mod decision; mod error; mod eval_context; mod event; -mod solution; +mod point; pub use action::Action; -pub use bracket::{Bracket, BracketError, Sign}; pub use config::{Config, ConfigError}; pub use error::Error; pub use event::Event; -pub use solution::{Solution, Status}; +pub use point::Point; + +pub use crate::equation::{ + bracket::{Bracket, BracketError, Sign}, + solution::{Solution, Status}, +}; use twine_core::{EquationProblem, Model, Observer}; -use best::Best; -use bracket::Bounds; +use crate::equation::{best::Best, bracket::Bounds, evaluate}; + use decision::Decision; use eval_context::EvalContext; /// Finds a root of the equation using the bisection method. /// -/// # Algorithm -/// -/// 1. Evaluate the left and right endpoints. -/// 2. Validate that the endpoints bracket a root using residual signs. -/// 3. Iterate: evaluate the midpoint, shrink the bracket, and update the best evaluation. +/// Evaluates both endpoints to establish a bracket, then iterates by +/// evaluating midpoints and shrinking the bracket. /// -/// Convergence is reported when either: -/// - The best residual magnitude is within `config.residual_tol` (absolute only), or -/// - The bracket width satisfies `x_abs_tol + x_rel_tol * |mid|`. +/// Endpoint evaluation failures are hard errors — if either endpoint fails, +/// the solver returns immediately with [`Error::Model`] or [`Error::Problem`]. +/// For control over endpoint evaluation (e.g., domain-specific error +/// recovery or noise filtering), evaluate endpoints yourself and use +/// [`solve_from_bracket`]. /// /// # Observer /// -/// The observer receives an [`Event`] for each evaluation and may: -/// - Return `Action::StopEarly` to stop and return the best evaluation so far. -/// - Return `Action::AssumeResidualSign(Sign)` to recover from evaluation -/// failures by providing a residual sign for bracket updates. -/// When this action is used on a successful evaluation, that evaluation is -/// not considered for the best solution. -/// -/// # Notes -/// -/// The returned [`Solution`] always reflects the best successful evaluation -/// seen so far (by residual magnitude). -/// Iteration counts correspond to the number of midpoint evaluations performed. +/// The observer receives an [`Event`] for each **midpoint** evaluation. +/// Endpoint evaluations are not observed. +/// See [`solve_from_bracket`] for details on observer actions. /// /// # Errors /// /// Returns an error if the bracket is invalid, the config is invalid, -/// or the model or problem returns an unrecovered error during evaluation. +/// or an endpoint or midpoint evaluation fails without observer recovery. pub fn solve( model: &M, problem: &P, bracket: [f64; 2], config: &Config, - mut observer: Obs, + observer: Obs, ) -> Result, Error> where M: Model, @@ -74,44 +66,128 @@ where let [left, right] = bounds.as_array(); let mut best = Best::empty(); - let mut ctx = EvalContext::new(model, problem, &mut observer); - // Resolve left endpoint. - let (left_eval, left_decision) = ctx.left_endpoint(left); - if let Some(eval) = left_eval { - best.update(eval); - } - let left_sign = match left_decision { - Decision::Continue(sign) => sign, - Decision::StopEarly => return best.finish(Status::StoppedByObserver, 0), - Decision::Error(error) => return Err(error), + // Evaluate left endpoint. + let left_sign = match evaluate(model, problem, [left]) { + Ok(eval) => { + let sign = Sign::of(eval.residuals[0]); + best.update(eval); + sign + } + Err(error) => return Err(error.into()), }; - // Resolve right endpoint. - let (right_eval, right_decision) = ctx.right_endpoint(right); - if let Some(eval) = right_eval { - best.update(eval); - } - let right_sign = match right_decision { - Decision::Continue(sign) => sign, - Decision::StopEarly => return best.finish(Status::StoppedByObserver, 0), - Decision::Error(error) => return Err(error), + // Evaluate right endpoint. + let right_sign = match evaluate(model, problem, [right]) { + Ok(eval) => { + let sign = Sign::of(eval.residuals[0]); + best.update(eval); + sign + } + Err(error) => return Err(error.into()), }; - // Validate bracket signs now that both endpoints are known. - let mut bracket = Bracket::new(bounds, left_sign, right_sign)?; + // Validate bracket signs. + let bracket = Bracket::from_bounds(bounds, left_sign, right_sign)?; + + solve_from_bracket_inner(model, problem, bracket, config, best, observer) +} +/// Runs bisection without observation. +/// +/// # Errors +/// +/// Returns an error if the bracket is invalid, the config is invalid, +/// or the model or problem returns an error during evaluation. +pub fn solve_unobserved( + model: &M, + problem: &P, + bracket: [f64; 2], + config: &Config, +) -> Result, Error> +where + M: Model, + M::Input: Clone, + M::Output: Clone, + P: EquationProblem<1, Input = M::Input, Output = M::Output>, +{ + solve(model, problem, bracket, config, ()) +} + +/// Finds a root using bisection with a pre-validated bracket. +/// +/// This skips endpoint evaluation — the caller is responsible for evaluating +/// the endpoints and constructing a valid [`Bracket`] with known residual +/// signs. +/// This is useful when endpoint evaluation requires domain-specific handling +/// (e.g., error recovery, noise filtering) that the solver's observer protocol +/// doesn't cover. +/// +/// # Observer +/// +/// The observer receives an [`Event`] for each midpoint evaluation and may: +/// - Return [`Action::StopEarly`] to stop and return the best evaluation so far. +/// - Return [`Action::AssumeResidualSign`] to recover from evaluation failures +/// by providing a residual sign for bracket updates. +/// When this action is used on a successful evaluation, that evaluation is +/// not considered for the best solution. +/// +/// # Notes +/// +/// The returned [`Solution`] reflects the best successful midpoint evaluation +/// seen during the solve. +/// Endpoint evaluations are not tracked — if no midpoint succeeds, this +/// returns [`Error::NoSuccessfulEvaluation`]. +/// +/// # Errors +/// +/// Returns an error if the config is invalid or the model or problem returns +/// an unrecovered error during evaluation. +pub fn solve_from_bracket( + model: &M, + problem: &P, + bracket: Bracket, + config: &Config, + observer: Obs, +) -> Result, Error> +where + M: Model, + M::Input: Clone, + M::Output: Clone, + P: EquationProblem<1, Input = M::Input, Output = M::Output>, + Obs: for<'a> Observer, Action>, +{ + config.validate()?; + solve_from_bracket_inner(model, problem, bracket, config, Best::empty(), observer) +} + +/// Core midpoint loop shared by `solve` and `solve_from_bracket`. +fn solve_from_bracket_inner( + model: &M, + problem: &P, + mut bracket: Bracket, + config: &Config, + mut best: Best, + mut observer: Obs, +) -> Result, Error> +where + M: Model, + M::Input: Clone, + M::Output: Clone, + P: EquationProblem<1, Input = M::Input, Output = M::Output>, + Obs: for<'a> Observer, Action>, +{ if best.is_residual_converged(config.residual_tol) { - return best.finish(Status::Converged, 0); + return finish(best, Status::Converged, 0); } - // Iterate by shrinking the bracket with midpoint evaluations. + let mut ctx = EvalContext::new(model, problem, &mut observer); + for iter in 1..=config.max_iters { if bracket.is_x_converged(config.x_abs_tol, config.x_rel_tol) { - return best.finish(Status::Converged, iter - 1); + return finish(best, Status::Converged, iter - 1); } - // Evaluate the midpoint and update the bracket. let mid = bracket.midpoint(); let (mid_eval, mid_decision) = ctx.midpoint(mid, &bracket); if let Some(eval) = mid_eval { @@ -120,38 +196,23 @@ where match mid_decision { Decision::Continue(sign) => bracket.shrink(mid, sign), Decision::StopEarly => { - return best.finish(Status::StoppedByObserver, iter); + return finish(best, Status::StoppedByObserver, iter); } Decision::Error(error) => return Err(error), } if best.is_residual_converged(config.residual_tol) { - return best.finish(Status::Converged, iter); + return finish(best, Status::Converged, iter); } } - best.finish(Status::MaxIters, config.max_iters) + finish(best, Status::MaxIters, config.max_iters) } -/// Runs bisection without observation. -/// -/// # Errors -/// -/// Returns an error if the bracket is invalid, the config is invalid, -/// or the model or problem returns an error during evaluation. -pub fn solve_unobserved( - model: &M, - problem: &P, - bracket: [f64; 2], - config: &Config, -) -> Result, Error> -where - M: Model, - M::Input: Clone, - M::Output: Clone, - P: EquationProblem<1, Input = M::Input, Output = M::Output>, -{ - solve(model, problem, bracket, config, ()) +/// Converts a best tracker into a solution or a "no successful evaluation" error. +fn finish(best: Best, status: Status, iters: usize) -> Result, Error> { + best.into_solution(status, iters) + .ok_or(Error::NoSuccessfulEvaluation) } #[cfg(test)] @@ -233,6 +294,8 @@ mod tests { } } + // --- solve tests (full lifecycle) --- + #[test] fn finds_square_root() { let model = SquareModel; @@ -260,27 +323,14 @@ mod tests { } #[test] - fn observer_can_stop_iteration() { - let model = SquareModel; + fn solve_errors_on_endpoint_failure() { + // Model fails everywhere — endpoints can't be evaluated. + let model = ThresholdModel { threshold: -1.0 }; let problem = TargetOutputProblem { target: 9.0 }; - let mut midpoint_count = 0usize; - let observer = |event: &Event<'_, _, _>| { - if matches!(event, Event::Midpoint { .. }) { - midpoint_count += 1; - if midpoint_count >= 3 { - return Some(Action::StopEarly); - } - } - None - }; + let result = solve_unobserved(&model, &problem, [0.0, 10.0], &Config::default()); - let solution = solve(&model, &problem, [0.0, 10.0], &Config::default(), observer) - .expect("should stop cleanly"); - - assert_eq!(solution.status, Status::StoppedByObserver); - assert_eq!(solution.iters, 3); - assert_eq!(midpoint_count, 3); + assert!(matches!(result, Err(Error::Model(_)))); } #[test] @@ -298,59 +348,93 @@ mod tests { assert_eq!(solution.status, Status::MaxIters); assert_eq!(solution.iters, 0); // x=2 gives residual |4-9|=5, x=10 gives |100-9|=91 - // So best endpoint should be x=2 assert_relative_eq!(solution.x, 2.0); } #[test] - fn observer_can_recover_from_eval_failure() { - // Model fails above x=7, root is at x=3 (for target=9) - let model = ThresholdModel { threshold: 7.0 }; + fn converges_on_small_bracket_width() { + let model = SquareModel; + let problem = TargetOutputProblem { target: 9.0 }; + + let config = Config { + max_iters: 10, + x_abs_tol: 1.0, + ..Config::default() + }; + + let solution = solve_unobserved(&model, &problem, [2.9, 3.1], &config) + .expect("should converge on x tolerance"); + + assert_eq!(solution.status, Status::Converged); + assert_eq!(solution.iters, 0); + } + + // --- solve_from_bracket tests (midpoint loop only) --- + + #[test] + fn from_bracket_finds_root() { + let model = SquareModel; let problem = TargetOutputProblem { target: 9.0 }; - // Initial bracket [0, 10] would fail at right endpoint (x=10 > threshold=7) - // Observer tells solver to use a positive residual for failed points - // (points above threshold would have large positive residuals: x^2 - 9 > 0) - let observer = |event: &Event<'_, _, _>| { - let is_err = event.result().is_err(); - if is_err { - // Failed points are above threshold, so residual would be positive - Some(Action::assume_positive()) + // x=0: residual = 0-9 = -9 (negative) + // x=10: residual = 100-9 = 91 (positive) + let bracket = + Bracket::new((0.0, Sign::Negative), (10.0, Sign::Positive)).expect("valid bracket"); + + let solution = solve_from_bracket(&model, &problem, bracket, &Config::default(), ()) + .expect("should solve"); + + assert_eq!(solution.status, Status::Converged); + assert_relative_eq!(solution.x, 3.0, epsilon = 1e-10); + } + + #[test] + fn observer_can_stop_iteration() { + let model = SquareModel; + let problem = TargetOutputProblem { target: 9.0 }; + + let bracket = + Bracket::new((0.0, Sign::Negative), (10.0, Sign::Positive)).expect("valid bracket"); + + let mut eval_count = 0usize; + let observer = |_event: &Event<'_, _, _>| { + eval_count += 1; + if eval_count >= 3 { + Some(Action::StopEarly) } else { None } }; - let solution = solve(&model, &problem, [0.0, 10.0], &Config::default(), observer) - .expect("should recover and solve"); + let solution = solve_from_bracket(&model, &problem, bracket, &Config::default(), observer) + .expect("should stop cleanly"); - assert_eq!(solution.status, Status::Converged); - assert_relative_eq!(solution.x, 3.0, epsilon = 1e-10); + assert_eq!(solution.status, Status::StoppedByObserver); + assert_eq!(solution.iters, 3); + assert_eq!(eval_count, 3); } #[test] fn midpoint_failure_assumes_sign() { // Model fails above x=3.5, root is at x=3 (for target=9) - // Initial bracket [0, 3.5] is valid, midpoint=1.75 is valid - // But as bisection homes in from the left, midpoints > 3.5 will fail let model = ThresholdModel { threshold: 3.5 }; let problem = TargetOutputProblem { target: 9.0 }; + // x=0: residual = -9 (negative), x=3.5: residual = 3.25 (positive) + let bracket = + Bracket::new((0.0, Sign::Negative), (3.5, Sign::Positive)).expect("valid bracket"); + let mut recovery_count = 0usize; let observer = |event: &Event<'_, _, _>| { - let is_err = event.result().is_err(); - if is_err { + if matches!(event, Event::ModelFailed { .. }) { recovery_count += 1; - // Failed points are above threshold, so residual would be positive Some(Action::assume_positive()) } else { None } }; - // Bracket: left residual at x=0 is 0-9=-9, right residual at x=3.5 is 12.25-9=3.25 - // Different signs, so valid bracket - let solution = solve(&model, &problem, [0.0, 3.5], &Config::default(), observer) + let solution = solve_from_bracket(&model, &problem, bracket, &Config::default(), observer) .expect("should recover and solve"); assert_eq!(solution.status, Status::Converged); @@ -362,21 +446,22 @@ mod tests { let model = SquareModel; let problem = TargetOutputProblem { target: 9.0 }; - let observer = |event: &Event<'_, _, _>| match event { - Event::Left { .. } => Some(Action::assume_negative()), - Event::Right { .. } | Event::Midpoint { .. } => None, - }; + // x=2: residual = -5 (negative), x=10: residual = 91 (positive) + let bracket = + Bracket::new((2.0, Sign::Negative), (10.0, Sign::Positive)).expect("valid bracket"); + + // Assume positive on every midpoint — always shrink from the right. + // The actual eval is discarded, so best stays empty. + let observer = |_event: &Event<'_, _, _>| Some(Action::assume_positive()); let config = Config { - max_iters: 0, + max_iters: 3, ..Config::default() }; - let solution = solve(&model, &problem, [2.0, 10.0], &config, observer) - .expect("should return best endpoint"); + let result = solve_from_bracket(&model, &problem, bracket, &config, observer); - assert_eq!(solution.status, Status::MaxIters); - assert_relative_eq!(solution.x, 10.0); + assert!(matches!(result, Err(Error::NoSuccessfulEvaluation))); } #[test] @@ -384,31 +469,14 @@ mod tests { let model = ThresholdModel { threshold: -1.0 }; let problem = TargetOutputProblem { target: 9.0 }; - let observer = |event: &Event<'_, _, _>| match event { - Event::Left { .. } => Some(Action::assume_negative()), - Event::Right { .. } | Event::Midpoint { .. } => Some(Action::assume_positive()), - }; - - let result = solve(&model, &problem, [0.0, 10.0], &Config::default(), observer); - - assert!(matches!(result, Err(Error::NoSuccessfulEvaluation))); - } - - #[test] - fn converges_on_small_bracket_width() { - let model = SquareModel; - let problem = TargetOutputProblem { target: 9.0 }; + let bracket = + Bracket::new((0.0, Sign::Negative), (10.0, Sign::Positive)).expect("valid bracket"); - let config = Config { - max_iters: 10, - x_abs_tol: 1.0, - ..Config::default() - }; + // Model fails everywhere, observer assumes signs to keep going. + let observer = |_event: &Event<'_, _, _>| Some(Action::assume_positive()); - let solution = solve_unobserved(&model, &problem, [2.9, 3.1], &config) - .expect("should converge on x tolerance"); + let result = solve_from_bracket(&model, &problem, bracket, &Config::default(), observer); - assert_eq!(solution.status, Status::Converged); - assert_eq!(solution.iters, 0); + assert!(matches!(result, Err(Error::NoSuccessfulEvaluation))); } } diff --git a/crates/solvers/src/equation/bisection/action.rs b/crates/solvers/src/equation/bisection/action.rs index a9a3d0e..f17161d 100644 --- a/crates/solvers/src/equation/bisection/action.rs +++ b/crates/solvers/src/equation/bisection/action.rs @@ -1,4 +1,4 @@ -use super::bracket::Sign; +use crate::equation::bracket::Sign; /// Control actions supported by the bisection solver. #[derive(Debug, Clone, Copy, PartialEq)] diff --git a/crates/solvers/src/equation/bisection/decision.rs b/crates/solvers/src/equation/bisection/decision.rs index 3a9273b..69579f0 100644 --- a/crates/solvers/src/equation/bisection/decision.rs +++ b/crates/solvers/src/equation/bisection/decision.rs @@ -1,4 +1,6 @@ -use super::{Action, Error, Sign}; +use crate::equation::bracket::Sign; + +use super::{Action, Error}; /// Control flow outcomes for a single evaluation. #[derive(Debug)] diff --git a/crates/solvers/src/equation/bisection/error.rs b/crates/solvers/src/equation/bisection/error.rs index 93dff98..61efa32 100644 --- a/crates/solvers/src/equation/bisection/error.rs +++ b/crates/solvers/src/equation/bisection/error.rs @@ -2,9 +2,9 @@ use std::error::Error as StdError; use thiserror::Error; -use crate::equation::EvalError; +use crate::equation::{EvalError, bracket::BracketError}; -use super::{bracket::BracketError, config::ConfigError}; +use super::config::ConfigError; /// Errors that can occur during bisection solving. #[derive(Debug, Error)] diff --git a/crates/solvers/src/equation/bisection/eval_context.rs b/crates/solvers/src/equation/bisection/eval_context.rs index 44bda61..9eacc48 100644 --- a/crates/solvers/src/equation/bisection/eval_context.rs +++ b/crates/solvers/src/equation/bisection/eval_context.rs @@ -1,8 +1,8 @@ use twine_core::{EquationProblem, Model, Observer}; -use crate::equation::{Evaluation, evaluate}; +use crate::equation::{EvalError, Evaluation, bracket::Bracket, evaluate}; -use super::{Action, Bracket, Decision, Event}; +use super::{Action, Decision, Event, Point}; type EvalOutcome = (Option>, Decision); @@ -38,68 +38,64 @@ where } } - /// Evaluates the left endpoint and returns the observer decision. - pub(crate) fn left_endpoint(&mut self, x: f64) -> EvalOutcome { - let result = evaluate(self.model, self.problem, [x]); - let action = self.observer.observe(&Event::Left { x, result: &result }); - - let (residual, mut eval) = match result { - Ok(eval) => (Ok(eval.residuals[0]), Some(eval)), - Err(error) => (Err(error.into()), None), - }; - - let decision = Decision::new(action, residual); - - if matches!(action, Some(Action::AssumeResidualSign(_))) { - eval = None; - } - - (eval, decision) - } - - /// Evaluates the right endpoint and returns the observer decision. - pub(crate) fn right_endpoint(&mut self, x: f64) -> EvalOutcome { - let result = evaluate(self.model, self.problem, [x]); - let action = self.observer.observe(&Event::Right { x, result: &result }); - - let (residual, mut eval) = match result { - Ok(eval) => (Ok(eval.residuals[0]), Some(eval)), - Err(error) => (Err(error.into()), None), - }; - - let decision = Decision::new(action, residual); - - if matches!(action, Some(Action::AssumeResidualSign(_))) { - eval = None; - } - - (eval, decision) - } - - /// Evaluates the midpoint and returns the observer decision. + /// Evaluates the midpoint, emits an event, and returns the outcome. pub(crate) fn midpoint( &mut self, x: f64, bracket: &Bracket, ) -> EvalOutcome { - let result = evaluate(self.model, self.problem, [x]); - let action = self.observer.observe(&Event::Midpoint { - x, - bracket, - result: &result, - }); - - let (residual, mut eval) = match result { - Ok(eval) => (Ok(eval.residuals[0]), Some(eval)), - Err(error) => (Err(error.into()), None), - }; - - let decision = Decision::new(action, residual); - - if matches!(action, Some(Action::AssumeResidualSign(_))) { - eval = None; + match evaluate(self.model, self.problem, [x]) { + Ok(eval) => { + let point = Point::from(&eval); + let event = Event::Evaluated { + point, + input: &eval.snapshot.input, + output: &eval.snapshot.output, + bracket, + }; + let action = self.observer.observe(&event); + let decision = Decision::new(action, Ok(point.residual)); + + let kept_eval = if matches!(action, Some(Action::AssumeResidualSign(_))) { + None + } else { + Some(eval) + }; + + (kept_eval, decision) + } + Err(error) => { + let action = Self::observe_failure(x, bracket, &error, self.observer); + let decision = Decision::new(action, Err(error.into())); + (None, decision) + } } + } - (eval, decision) + /// Emits a failure event and returns the observer's action. + fn observe_failure( + x: f64, + bracket: &Bracket, + error: &EvalError, + observer: &mut Obs, + ) -> Option { + match error { + EvalError::Model(e) => { + let event = Event::ModelFailed { + x, + error: e, + bracket, + }; + observer.observe(&event) + } + EvalError::Problem(e) => { + let event = Event::ProblemFailed { + x, + error: e, + bracket, + }; + observer.observe(&event) + } + } } } diff --git a/crates/solvers/src/equation/bisection/event.rs b/crates/solvers/src/equation/bisection/event.rs index d68dc92..5533627 100644 --- a/crates/solvers/src/equation/bisection/event.rs +++ b/crates/solvers/src/equation/bisection/event.rs @@ -1,59 +1,81 @@ use twine_core::{EquationProblem, Model}; -use crate::equation::EvaluateResult; +use crate::equation::bracket::Bracket; -use super::Bracket; +use super::Point; -/// Event emitted by the bisection solver for each evaluation. +/// Events emitted by the bisection solver during the midpoint loop. +/// +/// Each event provides the evaluation outcome and a reference to the current +/// bracket. +/// Observers can pattern-match on the outcome (success vs. failure) to steer +/// the search. pub enum Event<'a, M, P> where M: Model, P: EquationProblem<1, Input = M::Input, Output = M::Output>, { - /// Left bracket endpoint evaluation. - Left { - /// The x value that was evaluated. - x: f64, - /// The result of the evaluation. - result: &'a EvaluateResult, + /// Successful evaluation. + Evaluated { + /// The evaluated point (x and residual). + point: Point, + + /// The model input at this point. + input: &'a M::Input, + + /// The model output at this point. + output: &'a M::Output, + + /// The current search bracket. + bracket: &'a Bracket, }, - /// Right bracket endpoint evaluation. - Right { - /// The x value that was evaluated. + + /// Model evaluation failed. + ModelFailed { + /// The x value where evaluation failed. x: f64, - /// The result of the evaluation. - result: &'a EvaluateResult, + + /// The model error. + error: &'a M::Error, + + /// The current search bracket. + bracket: &'a Bracket, }, - /// Midpoint evaluation with a validated bracket. - Midpoint { - /// The x value that was evaluated. + + /// Problem method failed (input construction or residual computation). + ProblemFailed { + /// The x value where evaluation failed. x: f64, - /// Current search bracket. + + /// The problem error. + error: &'a P::Error, + + /// The current search bracket. bracket: &'a Bracket, - /// The result of the evaluation. - result: &'a EvaluateResult, }, } -impl<'a, M, P> Event<'a, M, P> +impl Event<'_, M, P> where M: Model, P: EquationProblem<1, Input = M::Input, Output = M::Output>, { - /// Returns the evaluated x value. + /// Returns the x value that was evaluated (or attempted). #[must_use] pub fn x(&self) -> f64 { match self { - Event::Left { x, .. } | Event::Right { x, .. } | Event::Midpoint { x, .. } => *x, + Self::Evaluated { point, .. } => point.x, + Self::ModelFailed { x, .. } | Self::ProblemFailed { x, .. } => *x, } } - /// Returns the evaluation result. - pub fn result(&self) -> &'a EvaluateResult { + /// Returns the current search bracket. + #[must_use] + pub fn bracket(&self) -> &Bracket { match self { - Event::Left { result, .. } - | Event::Right { result, .. } - | Event::Midpoint { result, .. } => result, + Self::Evaluated { bracket, .. } + | Self::ModelFailed { bracket, .. } + | Self::ProblemFailed { bracket, .. } => bracket, } } } diff --git a/crates/solvers/src/equation/bisection/point.rs b/crates/solvers/src/equation/bisection/point.rs new file mode 100644 index 0000000..ac7e5e8 --- /dev/null +++ b/crates/solvers/src/equation/bisection/point.rs @@ -0,0 +1,25 @@ +use crate::equation::Evaluation; + +/// A point with its evaluated residual value. +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct Point { + /// The x value. + pub x: f64, + + /// The residual at x. + pub residual: f64, +} + +impl Point { + /// Creates a new point. + #[must_use] + pub fn new(x: f64, residual: f64) -> Self { + Self { x, residual } + } +} + +impl From<&Evaluation> for Point { + fn from(eval: &Evaluation) -> Self { + Self::new(eval.x[0], eval.residuals[0]) + } +} diff --git a/crates/solvers/src/equation/bisection/bracket.rs b/crates/solvers/src/equation/bracket.rs similarity index 61% rename from crates/solvers/src/equation/bisection/bracket.rs rename to crates/solvers/src/equation/bracket.rs index b490e46..f69a74d 100644 --- a/crates/solvers/src/equation/bisection/bracket.rs +++ b/crates/solvers/src/equation/bracket.rs @@ -9,6 +9,9 @@ pub enum BracketError { /// Endpoints are equal, giving zero width. #[error("zero width")] ZeroWidth, + /// Left endpoint is greater than right endpoint. + #[error("left > right")] + Inverted, /// Residual signs do not bracket a root. #[error("no sign change")] NoSignChange, @@ -24,12 +27,51 @@ pub struct Bracket { } impl Bracket { - /// Creates a validated bracket with known residual signs. + /// Creates a validated bracket from left and right endpoint–sign pairs. + /// + /// Left must be strictly less than right. /// /// # Errors /// - /// Returns `BracketError::NoSignChange` if the signs do not bracket a root. - pub(super) fn new( + /// Returns [`BracketError::NonFinite`] if either endpoint is non-finite, + /// [`BracketError::ZeroWidth`] if the endpoints are equal, + /// [`BracketError::Inverted`] if left > right, or + /// [`BracketError::NoSignChange`] if the signs are the same. + pub fn new(left: (f64, Sign), right: (f64, Sign)) -> Result { + let (left, left_sign) = left; + let (right, right_sign) = right; + + if !left.is_finite() || !right.is_finite() { + return Err(BracketError::NonFinite); + } + + // Exact equality is intentional — zero-width brackets are invalid. + #[allow(clippy::float_cmp)] + if left == right { + return Err(BracketError::ZeroWidth); + } + + if left > right { + return Err(BracketError::Inverted); + } + + if left_sign == right_sign { + return Err(BracketError::NoSignChange); + } + + Ok(Self { + left, + right, + left_sign, + right_sign, + }) + } + + /// Creates a bracket from pre-validated, pre-ordered bounds and signs. + /// + /// This skips endpoint validation (finiteness, ordering) since `Bounds` + /// already enforces those invariants. Only validates sign opposition. + pub(crate) fn from_bounds( bounds: Bounds, left_sign: Sign, right_sign: Sign, @@ -72,7 +114,7 @@ impl Bracket { } /// Shrinks the bracket using a new endpoint and its residual sign. - pub(super) fn shrink(&mut self, x: f64, sign: Sign) { + pub(crate) fn shrink(&mut self, x: f64, sign: Sign) { if self.left_sign == sign { self.left = x; self.left_sign = sign; @@ -104,9 +146,9 @@ impl Sign { } } -/// Ordered finite bounds for a bisection bracket. +/// Ordered finite bounds for a bracket. #[derive(Debug, Clone, Copy, PartialEq)] -pub(super) struct Bounds { +pub(crate) struct Bounds { left: f64, right: f64, } @@ -117,13 +159,14 @@ impl Bounds { /// # Errors /// /// Returns `BracketError` if endpoints are non-finite or zero width. - pub(super) fn new(bracket: [f64; 2]) -> Result { + pub(crate) fn new(bracket: [f64; 2]) -> Result { let [left, right] = bracket; if !left.is_finite() || !right.is_finite() { return Err(BracketError::NonFinite); } + // Exact equality is intentional — zero-width brackets are invalid. #[allow(clippy::float_cmp)] if left == right { return Err(BracketError::ZeroWidth); @@ -140,7 +183,7 @@ impl Bounds { } /// Returns the bounds as an array. - pub(super) fn as_array(&self) -> [f64; 2] { + pub(crate) fn as_array(&self) -> [f64; 2] { [self.left, self.right] } } @@ -179,20 +222,41 @@ mod tests { } #[test] - fn new_bracket_rejects_no_sign_change() { - let bounds = Bounds::new([0.0, 1.0]).expect("valid bounds"); - let err = Bracket::new(bounds, Sign::Positive, Sign::Positive); - assert!(matches!(err, Err(BracketError::NoSignChange))); + fn new_rejects_non_finite() { + assert!(matches!( + Bracket::new((f64::NAN, Sign::Negative), (1.0, Sign::Positive)), + Err(BracketError::NonFinite) + )); + } + + #[test] + fn new_rejects_zero_width() { + assert!(matches!( + Bracket::new((2.0, Sign::Negative), (2.0, Sign::Positive)), + Err(BracketError::ZeroWidth) + )); + } + + #[test] + fn new_rejects_inverted() { + assert!(matches!( + Bracket::new((10.0, Sign::Negative), (0.0, Sign::Positive)), + Err(BracketError::Inverted) + )); + } + + #[test] + fn new_rejects_no_sign_change() { + assert!(matches!( + Bracket::new((0.0, Sign::Positive), (1.0, Sign::Positive)), + Err(BracketError::NoSignChange) + )); } #[test] fn shrink_shifts_bounds() { - let mut bracket = Bracket::new( - Bounds::new([0.0, 2.0]).expect("valid bounds"), - Sign::Negative, - Sign::Positive, - ) - .expect("valid bracket"); + let mut bracket = + Bracket::new((0.0, Sign::Negative), (2.0, Sign::Positive)).expect("valid bracket"); bracket.shrink(1.0, Sign::Negative); let [left, right] = bracket.as_array(); diff --git a/crates/solvers/src/equation/evaluate.rs b/crates/solvers/src/equation/evaluate.rs index 1587cd7..2a201fc 100644 --- a/crates/solvers/src/equation/evaluate.rs +++ b/crates/solvers/src/equation/evaluate.rs @@ -21,8 +21,8 @@ pub enum EvalError { Problem(#[source] PE), } -/// Type alias for the result of [`evaluate`]. -pub type EvaluateResult = Result< +/// Result type for [`evaluate`], reducing signature complexity. +type EvaluateResult = Result< Evaluation<::Input, ::Output, N>, EvalError<::Error,

>::Error>, >; diff --git a/crates/solvers/src/equation/bisection/solution.rs b/crates/solvers/src/equation/solution.rs similarity index 95% rename from crates/solvers/src/equation/bisection/solution.rs rename to crates/solvers/src/equation/solution.rs index 183f335..85c2f62 100644 --- a/crates/solvers/src/equation/bisection/solution.rs +++ b/crates/solvers/src/equation/solution.rs @@ -11,7 +11,7 @@ pub enum Status { StoppedByObserver, } -/// The result of a bisection solve. +/// The result of an equation solve. #[derive(Debug, Clone)] pub struct Solution { /// Final solver status. diff --git a/crates/solvers/src/optimization.rs b/crates/solvers/src/optimization.rs index 227e8a0..c72087d 100644 --- a/crates/solvers/src/optimization.rs +++ b/crates/solvers/src/optimization.rs @@ -13,6 +13,6 @@ mod evaluate; -pub use evaluate::{EvalError, EvaluateResult, Evaluation, evaluate}; +pub use evaluate::{EvalError, Evaluation, evaluate}; pub mod golden_section; diff --git a/crates/solvers/src/optimization/evaluate.rs b/crates/solvers/src/optimization/evaluate.rs index 0cd2f80..3592894 100644 --- a/crates/solvers/src/optimization/evaluate.rs +++ b/crates/solvers/src/optimization/evaluate.rs @@ -24,8 +24,8 @@ pub enum EvalError { Problem(#[source] PE), } -/// Type alias for the result of [`evaluate`]. -pub type EvaluateResult = Result< +/// Result type for [`evaluate`], reducing signature complexity. +type EvaluateResult = Result< Evaluation<::Input, ::Output, N>, EvalError<::Error,

>::Error>, >;