Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 63 additions & 39 deletions crates/observers/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,11 @@ where
P: EquationProblem<1, Input = M::Input, Output = M::Output>,
{
fn residual(&self) -> f64 {
match self.result() {
Ok(eval) => eval.residuals[0],
Err(_) => f64::NAN,
match self {
bisection::Event::Evaluated { point, .. } => point.residual,
bisection::Event::ModelFailed { .. } | bisection::Event::ProblemFailed { .. } => {
f64::NAN
}
}
}
}
Expand Down Expand Up @@ -209,6 +211,22 @@ mod tests {
}
}

struct FailingEqProblem;

impl EquationProblem<1> for FailingEqProblem {
type Input = f64;
type Output = f64;
type Error = Failure;

fn input(&self, x: &[f64; 1]) -> Result<f64, Failure> {
Ok(x[0])
}

fn residuals(&self, _: &f64, _: &f64) -> Result<[f64; 1], Failure> {
Err(Failure)
}
}

struct FailingOptProblem;

impl OptimizationProblem<1> for FailingOptProblem {
Expand All @@ -227,46 +245,52 @@ mod tests {

// --- HasResidual for bisection::Event ---

fn test_bracket() -> bisection::Bracket {
bisection::Bracket::new(
(0.0, bisection::Sign::Negative),
(1.0, bisection::Sign::Positive),
)
.unwrap()
}

#[test]
fn bisection_residual_ok() {
// Drive the solver one step to get a real event with a valid residual.
// LinearProblem: residual = output = input = x, so residual ≠ NAN.
let model = Identity;
let problem = LinearProblem;
let mut residual_seen = None;
let _ = bisection::solve(
&model,
&problem,
[-1.0, 1.0],
&bisection::Config::default(),
|event: &bisection::Event<'_, Identity, LinearProblem>| {
if residual_seen.is_none() {
residual_seen = Some(event.residual());
}
None
},
);
let r = residual_seen.expect("at least one event emitted");
assert!(r.is_finite(), "expected finite residual, got {r}");
fn bisection_residual_evaluated() {
let input = 1.0_f64;
let output = 1.0_f64;
let bracket = test_bracket();
let event: bisection::Event<'_, Identity, LinearProblem> = bisection::Event::Evaluated {
point: bisection::Point::new(1.0, 0.5),
input: &input,
output: &output,
bracket: &bracket,
};
assert_relative_eq!(event.residual(), 0.5);
}

#[test]
fn bisection_residual_nan_on_model_failed() {
let error = Failure;
let bracket = test_bracket();
let event: bisection::Event<'_, FailingModel, LinearProblem> =
bisection::Event::ModelFailed {
x: 0.5,
error: &error,
bracket: &bracket,
};
assert!(event.residual().is_nan());
}

#[test]
fn bisection_residual_nan_on_model_error() {
// FailingModel always errors, so every event result is Err → NAN.
let model = FailingModel;
let problem = LinearProblem;
let mut got_nan = false;
let _ = bisection::solve(
&model,
&problem,
[-1.0, 1.0],
&bisection::Config::default(),
|event: &bisection::Event<'_, FailingModel, LinearProblem>| {
got_nan = event.residual().is_nan();
Some(bisection::Action::StopEarly)
},
);
assert!(got_nan);
fn bisection_residual_nan_on_problem_failed() {
let error = Failure;
let bracket = test_bracket();
let event: bisection::Event<'_, Identity, FailingEqProblem> =
bisection::Event::ProblemFailed {
x: 0.5,
error: &error,
bracket: &bracket,
};
assert!(event.residual().is_nan());
}

// --- HasObjective for golden_section::Event ---
Expand Down
7 changes: 5 additions & 2 deletions crates/solvers/src/equation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
//!
//! [`EquationProblem`]: twine_core::EquationProblem

mod best;
mod evaluate;

pub use evaluate::{EvalError, EvaluateResult, Evaluation, evaluate};

pub mod bisection;
pub mod bracket;
pub mod solution;

pub use evaluate::{EvalError, Evaluation, evaluate};
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
use crate::equation::Evaluation;

use super::{Error, Solution, Status};
use super::solution::{Solution, Status};

/// Tracks the best evaluation encountered so far.
///
/// The best evaluation is defined by minimum residual magnitude.
/// The `Option` lets us represent the state before any successful evaluation.
pub(super) struct Best<I, O> {
/// The bracket can shrink without any successful evaluation (via observer
/// recovery), so `None` is a normal operating state — not an error condition.
pub(crate) struct Best<I, O> {
eval: Option<Evaluation<I, O, 1>>,
}

impl<I, O> Best<I, O> {
/// Creates an empty best tracker.
pub(super) fn empty() -> Self {
pub(crate) fn empty() -> Self {
Self { eval: None }
}

/// Updates the best evaluation if the residual magnitude improves.
pub(super) fn update(&mut self, eval: Evaluation<I, O, 1>) {
pub(crate) fn update(&mut self, eval: Evaluation<I, O, 1>) {
if let Some(best) = self.eval.as_ref()
&& eval.residuals[0].abs() >= best.residuals[0].abs()
{
Expand All @@ -27,20 +28,18 @@ impl<I, O> Best<I, O> {
}

/// Returns true if the best residual meets the tolerance.
pub(super) fn is_residual_converged(&self, residual_tol: f64) -> bool {
pub(crate) fn is_residual_converged(&self, residual_tol: f64) -> bool {
self.eval
.as_ref()
.is_some_and(|eval| eval.residuals[0].abs() <= residual_tol)
}

/// Finalizes the solver using the best available evaluation.
/// Builds a solution from the best evaluation.
///
/// # Errors
///
/// Returns `Error::NoSuccessfulEvaluation` if no successful evaluation is stored.
pub(super) fn finish(self, status: Status, iters: usize) -> Result<Solution<I, O>, Error> {
let eval = self.eval.ok_or(Error::NoSuccessfulEvaluation)?;
Ok(Solution {
/// Returns `None` if no successful evaluation has been recorded.
pub(crate) fn into_solution(self, status: Status, iters: usize) -> Option<Solution<I, O>> {
let eval = self.eval?;
Some(Solution {
status,
x: eval.x[0],
residual: eval.residuals[0],
Expand Down Expand Up @@ -74,7 +73,7 @@ mod tests {
best.update(eval(3.0, 1.0));

let solution = best
.finish(Status::StoppedByObserver, 0)
.into_solution(Status::StoppedByObserver, 0)
.expect("best eval");

assert_relative_eq!(solution.x, 3.0);
Expand All @@ -88,15 +87,15 @@ mod tests {
best.update(eval(2.0, 2.0));

let solution = best
.finish(Status::StoppedByObserver, 0)
.into_solution(Status::StoppedByObserver, 0)
.expect("best eval");

assert_relative_eq!(solution.x, 1.0);
assert_relative_eq!(solution.residual, -0.5);
}

#[test]
fn residual_converged_requires_best() {
fn residual_converged_requires_eval() {
let best: Best<(), ()> = Best::empty();
assert!(!best.is_residual_converged(1e-3));
}
Expand All @@ -111,18 +110,17 @@ mod tests {
}

#[test]
fn finish_errors_without_eval() {
fn into_solution_returns_none_without_eval() {
let best: Best<(), ()> = Best::empty();
let err = best.finish(Status::StoppedByObserver, 0);
assert!(matches!(err, Err(Error::NoSuccessfulEvaluation)));
assert!(best.into_solution(Status::StoppedByObserver, 0).is_none());
}

#[test]
fn finish_builds_solution() {
fn into_solution_builds_solution() {
let mut best = Best::empty();
best.update(eval(2.0, -1.25));

let solution = best.finish(Status::Converged, 4).expect("best eval");
let solution = best.into_solution(Status::Converged, 4).expect("best eval");

assert_eq!(solution.status, Status::Converged);
assert_eq!(solution.iters, 4);
Expand Down
Loading