From f3fc42b9a257ed2e04db67545a6aedd5ec09f701 Mon Sep 17 00:00:00 2001 From: Jordan Gunn Date: Thu, 23 Apr 2026 15:35:32 +0100 Subject: [PATCH] Add DE mutation dithering --- .../differential_evolution.py | 16 +++- .../test_algorithms/test_population_based.py | 81 +++++++++++++++++++ 2 files changed, 96 insertions(+), 1 deletion(-) diff --git a/evosax/algorithms/population_based/differential_evolution.py b/evosax/algorithms/population_based/differential_evolution.py index 327f63c..d42b73b 100755 --- a/evosax/algorithms/population_based/differential_evolution.py +++ b/evosax/algorithms/population_based/differential_evolution.py @@ -31,6 +31,8 @@ class Params(BaseParams): elitism: bool # If elitism, base vector is best member else random crossover_rate: float # [0, 1] differential_weight: float # [0, 2] + differential_weight_min: float # [0, 2] + differential_weight_max: float # [0, 2] class DifferentialEvolution(PopulationBasedAlgorithm): @@ -56,6 +58,8 @@ def _default_params(self) -> Params: elitism=True, crossover_rate=0.9, differential_weight=0.8, + differential_weight_min=0.8, + differential_weight_max=0.8, ) def _init(self, key: jax.Array, params: Params) -> State: @@ -77,6 +81,16 @@ def _ask( keys = jax.random.split(key, self.population_size) member_ids = jnp.arange(self.population_size) best_index = jnp.argmin(state.fitness) + differential_weight = jnp.where( + params.differential_weight_min < params.differential_weight_max, + jax.random.uniform( + jax.random.fold_in(key, state.generation_counter), + (), + minval=params.differential_weight_min, + maxval=params.differential_weight_max, + ), + params.differential_weight, + ) def _ask_member(key, member_id): x = state.population[member_id] @@ -107,7 +121,7 @@ def _ask_member(key, member_id): subkey, state.population, (2,), replace=False, p=p ) - a = jnp.where(mask, a + params.differential_weight * (b - c), x) + a = jnp.where(mask, a + differential_weight * (b - c), x) return a diff --git a/tests/test_algorithms/test_population_based.py b/tests/test_algorithms/test_population_based.py index f4ec1a6..4affca5 100644 --- a/tests/test_algorithms/test_population_based.py +++ b/tests/test_algorithms/test_population_based.py @@ -2,6 +2,7 @@ import jax import jax.numpy as jnp +from evosax.algorithms import DifferentialEvolution from evosax.algorithms.population_based import population_based_algorithms @@ -149,3 +150,83 @@ def test_base_api(population_based_algorithm_name, key, num_dims, population_siz assert "best_solution" in metrics assert "best_fitness_in_generation" in metrics assert "best_solution_in_generation" in metrics + + +def test_differential_evolution_default_dithering_inactive(key): + """Test default dithering params preserve fixed-weight behavior.""" + solution = jnp.zeros((2,)) + algo = DifferentialEvolution(population_size=6, solution=solution) + params = algo.default_params.replace(crossover_rate=1.0) + + population = jnp.arange(12, dtype=float).reshape(6, 2) + fitness = jnp.arange(6, dtype=float) + + key, key_init, key_ask = jax.random.split(key, 3) + state = algo.init(key_init, population, fitness, params) + + fixed_params = params.replace( + differential_weight_min=0.0, + differential_weight_max=0.0, + ) + population_default, _ = algo.ask(key_ask, state, params) + population_fixed, _ = algo.ask(key_ask, state, fixed_params) + + assert jnp.allclose(population_default, population_fixed) + + +def test_differential_evolution_dithering_changes_population(key): + """Test active dithering range affects generated candidates.""" + solution = jnp.zeros((2,)) + algo = DifferentialEvolution(population_size=6, solution=solution) + params = algo.default_params.replace( + crossover_rate=1.0, + differential_weight=0.0, + differential_weight_min=0.0, + differential_weight_max=0.0, + ) + dither_params = params.replace( + differential_weight_min=1.0, + differential_weight_max=1.000001, + ) + + population = jnp.arange(12, dtype=float).reshape(6, 2) + fitness = jnp.arange(6, dtype=float) + + key, key_init, key_ask = jax.random.split(key, 3) + state = algo.init(key_init, population, fitness, params) + + population_fixed, _ = algo.ask(key_ask, state, params) + population_dithered, _ = algo.ask(key_ask, state, dither_params) + + assert population_dithered.shape == population.shape + assert jnp.all(jnp.isfinite(population_dithered)) + assert not jnp.allclose(population_dithered, population_fixed) + + +def test_differential_evolution_dithering_scan(key): + """Test dithered DifferentialEvolution inside a scan loop.""" + solution = jnp.zeros((2,)) + algo = DifferentialEvolution(population_size=6, solution=solution) + params = algo.default_params.replace( + differential_weight_min=0.5, + differential_weight_max=1.0, + ) + + population = jnp.arange(12, dtype=float).reshape(6, 2) / 10 + fitness = jnp.sum(jnp.square(population), axis=-1) + + key, key_init = jax.random.split(key) + state = algo.init(key_init, population, fitness, params) + + def step(carry, _): + key, state = carry + key, key_ask, key_tell = jax.random.split(key, 3) + population, state = algo.ask(key_ask, state, params) + fitness = jnp.sum(jnp.square(population), axis=-1) + state, metrics = algo.tell(key_tell, population, fitness, state, params) + return (key, state), metrics["best_fitness"] + + _, fitness_log = jax.lax.scan(step, (key, state), jnp.zeros(4)) + + assert fitness_log.shape == (4,) + assert jnp.all(jnp.isfinite(fitness_log))