Skip to content

Fix integer overflow in LM_MA_ES#95

Merged
RobertTLange merged 1 commit intoRobertTLange:mainfrom
XingyuQu:main
Aug 18, 2025
Merged

Fix integer overflow in LM_MA_ES#95
RobertTLange merged 1 commit intoRobertTLange:mainfrom
XingyuQu:main

Conversation

@XingyuQu
Copy link
Copy Markdown
Contributor

This commit replaces the integer constant 4 with the float 4.0 when creating the c_c array in the LM_MA_ES algorithm. The original code

    @property
    def _default_params(self) -> Params:
        # Calculate m for LM-MA-ES
        self.m = int(4 + jnp.floor(3 * jnp.log(self.num_dims)))

        # Get parent class parameters
        parent_params = super()._default_params

        # Override or add LM-MA-ES specific parameters
        c_d = 1 / jnp.power(1.5, jnp.arange(self.m)) / self.num_dims
        c_c = self.population_size / jnp.power(4, jnp.arange(self.m)) / self.num_dims

        # Set c_1 and c_mu to 0 as they're not used in LM-MA-ES
        return Params(
            std_init=parent_params.std_init,
            std_min=parent_params.std_min,
            std_max=parent_params.std_max,
            weights=parent_params.weights,
            mu_eff=parent_params.mu_eff,
            c_mean=parent_params.c_mean,
            c_std=parent_params.c_std,
            d_std=parent_params.d_std,
            c_c=c_c,
            c_1=0.0,  # Not used in LM-MA-ES
            c_mu=0.0,  # Not used in LM-MA-ES
            chi_n=parent_params.chi_n,
            c_d=c_d,
        )

will create an int-typed array, leading to an integer overflow when the gene dimension self.num_dims, and thus the memory size self.m is large. An example is as below:

import jax.numpy as jnp

population_size = 5
num_dims = jnp.ceil(jnp.exp(5)).astype(int)
m = int(4 + jnp.floor(3 * jnp.log(num_dims)))

c_c = population_size / jnp.power(4, jnp.arange(m)) / num_dims

print("Dimensions:", num_dims)
print("c_c:", c_c)

Outputs:

Dimensions: 149
c_c: [3.3557046e-02 8.3892616e-03 2.0973154e-03 5.2432885e-04 1.3108221e-04
 3.2770553e-05 8.1926382e-06 2.0481596e-06 5.1203989e-07 1.2800997e-07
 3.2002493e-08 8.0006233e-09 2.0001558e-09 5.0003895e-10 1.2500974e-10
 3.1252435e-11           inf           inf           inf]

A code snippet to show the error when running the algorithm:

import jax
import jax.numpy as jnp
from evosax.algorithms import LM_MA_ES

key = jax.random.PRNGKey(0)
d = jnp.ceil(jnp.exp(5)).astype(int)             
pop_size = 5

mean = jnp.zeros((d,), jnp.float32)
es = LM_MA_ES(population_size=pop_size, solution=mean)
params = es.default_params
key, sub = jax.random.split(key)
state = es.init(sub, mean=mean, params=params)

def es_step(key, state):
    key, ask_key, tell_key = jax.random.split(key, 3)
    pop, state = es.ask(ask_key, state, params)
    loss = jnp.zeros((pop_size,), jnp.float32)
    state, metrics = es.tell(tell_key, pop, loss, state, params)
    has_nan = jnp.any(~jnp.isfinite(pop))
    return key, state, bool(has_nan)

for gen in range(20):
    key, state, has_nan = es_step(key, state)
    print(f'gen {gen:02d} has_nan_in_population = {has_nan}')

Outputs:

gen 00 has_nan_in_population = False
gen 01 has_nan_in_population = False
gen 02 has_nan_in_population = False
gen 03 has_nan_in_population = False
gen 04 has_nan_in_population = False
gen 05 has_nan_in_population = False
gen 06 has_nan_in_population = False
gen 07 has_nan_in_population = False
gen 08 has_nan_in_population = False
gen 09 has_nan_in_population = False
gen 10 has_nan_in_population = False
gen 11 has_nan_in_population = False
gen 12 has_nan_in_population = False
gen 13 has_nan_in_population = False
gen 14 has_nan_in_population = False
gen 15 has_nan_in_population = False
gen 16 has_nan_in_population = False
gen 17 has_nan_in_population = True
gen 18 has_nan_in_population = True
gen 19 has_nan_in_population = True

@RobertTLange RobertTLange merged commit a213f27 into RobertTLange:main Aug 18, 2025
0 of 9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants