Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 69 additions & 64 deletions evosax/problems/rl/brax.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,63 +115,78 @@ def eval(
) -> tuple[Fitness, State, Metrics]:
"""Evaluate a population of policies."""
keys = jax.random.split(key, self.num_rollouts)
fitness, env_states = self._eval(keys, solutions, state)
fitness, all_stats = self._eval(keys, solutions, state)

# Update running statistics
if self.use_normalize_obs:
state = self.update_stats(env_states.obs, state)
# Scan over rollouts, applying update_stats sequentially
def _update_with_rollout(rollout_idx, carry_state):
rollout_stats = jax.tree.map(lambda a: a[0, rollout_idx], all_stats)
return self.update_stats(rollout_stats, carry_state)

state = jax.lax.fori_loop(0, self.num_rollouts, _update_with_rollout, state)

return (
jnp.mean(fitness, axis=-1),
state.replace(counter=state.counter + 1),
{"env_states": env_states},
{},
)

def _rollout(
self, key: jax.Array, policy_params: PyTree, state: State
) -> tuple[jax.Array, PyTree]:
def _rollout(self, key: jax.Array, policy_params: PyTree, state: State):
"""Perform a single rollout in the environment."""
key_reset, key_scan = jax.random.split(key)

# Reset environment
env_state = self.env.reset(key_reset)

def _step(carry, key):
env_state, cum_reward, valid = carry
def _cond(carry):
_, _, done, t, _ = carry
return ~done & (t < self.episode_length)

def _step(carry):
env_state, cum_reward, _, t, stats = carry
t = t + 1

key_action = jax.random.fold_in(key_scan, t)

# Normalize observations
obs = self.normalize_obs(env_state.obs, state)
obs = env_state.obs
if self.use_normalize_obs:
obs = self.normalize_obs(obs, state)

# Sample action from policy
action = self.policy.apply(policy_params, obs, key)
action = self.policy.apply(policy_params, obs, key_action)

# Step environment
env_state = self.env.step(env_state, action)

# Update cumulative reward and valid mask
cum_reward = cum_reward + env_state.reward * valid
valid = valid * (1 - env_state.done)
carry = (
env_state,
cum_reward,
valid,
)
return carry, env_state

# Rollout
keys = jax.random.split(key_scan, self.episode_length)
carry, env_states = jax.lax.scan(
_step,
(
env_state,
jnp.array(0.0),
jnp.array(1.0),
),
xs=keys,
)
# Update stats
if self.use_normalize_obs:
mean, var_sum = stats

def _update_leaf(leaf_obs, leaf_mean, leaf_var_sum):
diff = leaf_obs - leaf_mean
new_mean = leaf_mean + diff / t
new_var_sum = leaf_var_sum + diff * (leaf_obs - new_mean)
return new_mean, new_var_sum

mean, var_sum = jax.tree.map(_update_leaf, env_state.obs, mean, var_sum)
stats = (mean, var_sum)

cum_reward = cum_reward + env_state.reward
return (env_state, cum_reward, env_state.done.astype(bool), t, stats)

# Return the sum of rewards accumulated by agent in episode rollout and states
return carry[1], env_states
# Initialize per-rollout stats
ph = jax.tree.map(lambda x: jnp.zeros_like(x), env_state.obs)
stats = (ph, ph) if self.use_normalize_obs else None

# While loop rollout
carry = (env_state, 0.0, False, 0, stats)
carry = jax.lax.while_loop(_cond, _step, carry)

# Return the sum of rewards accumulated by agent in episode rollout and stats
_, cum_reward, _, t, (mean, var_sum) = carry
return cum_reward, (mean, var_sum, t)

def normalize_obs(self, obs: PyTree, state: State) -> PyTree:
"""Normalize observations using running statistics."""
Expand All @@ -182,47 +197,37 @@ def normalize_obs(self, obs: PyTree, state: State) -> PyTree:
state.obs_std,
)

def update_stats(self, obs: PyTree, state: State) -> State:
"""Update running statistics for observations using Welford's online algorithm.
def update_stats(self, all_stats: tuple, state: State) -> State:
"""Update running statistics using parallel reduction.

This method implements a numerically stable algorithm for computing
running mean and variance statistics across episodes [2].
This method combines per-rollout statistics using parallel reduction
formulas to update global running statistics.

Args:
obs: PyTree containing observations with shape
(population_size, num_rollouts, episode_length, ...)
all_stats: Tuple of (mean_arrays, var_sum_arrays, count_arrays),
each with shape (population_size, num_rollouts, ...)
state: Current state containing running statistics

Returns:
Updated state with new observation statistics

"""
# Batch dimensions are (population_size, num_rollouts, episode_length)
batch_size = obs.shape[0] * obs.shape[1] * obs.shape[2]
new_obs_counter = state.obs_counter + batch_size

# Function to update statistics for each leaf in the PyTree
def _update_leaf_stats(leaf_obs, leaf_mean, leaf_var_sum):
# Compute the new mean
diff_to_old_mean = leaf_obs - leaf_mean
new_obs_mean = (
leaf_mean + jnp.sum(diff_to_old_mean, axis=(0, 1, 2)) / new_obs_counter
)

# Compute new variance
diff_to_new_mean = leaf_obs - new_obs_mean
new_obs_var_sum = leaf_var_sum + jnp.sum(
diff_to_old_mean * diff_to_new_mean, axis=(0, 1, 2)
)

return new_obs_mean, new_obs_var_sum

# Apply the update function to each leaf in the observation PyTree
obs_mean, obs_var_sum = jax.tree.map(
lambda obs, mean, var: _update_leaf_stats(obs, mean, var),
obs,
mean, var_sum, count = all_stats

# Combine with global stats using parallel reduction
new_obs_counter = state.obs_counter + count
obs_mean = jax.tree.map(
lambda m, gm: (state.obs_counter * gm + count * m) / new_obs_counter,
mean,
state.obs_mean,
state.obs_var_sum,
)

def _combine_stats(v, gv, m, gm):
factor = state.obs_counter * count / new_obs_counter
return gv + v + factor * (gm - m) * (gm - m)

obs_var_sum = jax.tree.map(
_combine_stats, var_sum, state.obs_var_sum, mean, state.obs_mean
)

obs_var_sum = jnp.maximum(obs_var_sum, 0)
Expand Down
125 changes: 64 additions & 61 deletions evosax/problems/rl/gymnax.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,31 +126,43 @@ def eval(
) -> tuple[Fitness, State, Metrics]:
"""Evaluate a population of policies."""
keys = jax.random.split(key, self.num_rollouts)
fitness, env_states = self._eval(keys, solutions, state)
fitness, all_stats = self._eval(keys, solutions, state)

# Update running statistics
if self.use_normalize_obs:
state = self.update_stats(env_states[0], state)
# Scan over rollouts, applying update_stats sequentially
def _update_with_rollout(rollout_idx, carry_state):
rollout_stats = jax.tree.map(lambda a: a[0, rollout_idx], all_stats)
return self.update_stats(rollout_stats, carry_state)

state = jax.lax.fori_loop(0, self.num_rollouts, _update_with_rollout, state)

return (
jnp.mean(fitness, axis=-1),
state.replace(counter=state.counter + 1),
{"env_states": env_states},
{},
)

def _rollout(self, key: jax.Array, policy_params: PyTree, state: State):
"""Perform a single rollout in the environment."""
key_reset, key_scan = jax.random.split(key)

# Reset environment
obs, env_state = self.env.reset(key_reset, self.env_params)

def _step(carry, key):
obs, env_state, cum_reward, valid = carry
def _cond(carry):
_, _, _, done, t, _ = carry
return ~done & (t < self.episode_length)

def _step(carry):
obs, env_state, cum_reward, _, t, stats = carry
t = t + 1

key_action, key_step = jax.random.split(key)
key_action, key_step = jax.random.split(jax.random.fold_in(key_scan, t))

# Normalize observations
obs = self.normalize_obs(obs, state) if self.use_normalize_obs else obs
if self.use_normalize_obs:
obs = self.normalize_obs(obs, state)

# Sample action from policy
action = self.policy.apply(policy_params, obs, key_action)
Expand All @@ -160,32 +172,33 @@ def _step(carry, key):
key_step, env_state, action, self.env_params
)

# Update cumulative reward and valid mask
cum_reward = cum_reward + reward * valid
valid = valid * (1 - done)
carry = (
obs,
env_state,
cum_reward,
valid,
)
return carry, (obs, env_state)

# Rollout
keys = jax.random.split(key_scan, self.episode_length)
carry, env_states = jax.lax.scan(
_step,
(
obs,
env_state,
jnp.array(0.0),
jnp.array(1.0),
),
xs=keys,
)
# Update stats
if self.use_normalize_obs:
mean, var_sum = stats

def _update_leaf(leaf_obs, leaf_mean, leaf_var_sum):
diff = leaf_obs - leaf_mean
new_mean = leaf_mean + diff / t
new_var_sum = leaf_var_sum + diff * (leaf_obs - new_mean)
return new_mean, new_var_sum

mean, var_sum = jax.tree.map(_update_leaf, obs, mean, var_sum)
stats = (mean, var_sum)

cum_reward = cum_reward + reward
return (obs, env_state, cum_reward, done.astype(bool), t, stats)

# Initialize per-rollout stats
ph = jax.tree.map(lambda x: jnp.zeros_like(x), obs)
stats = (ph, ph) if self.use_normalize_obs else None

# Return the sum of rewards accumulated by agent in episode rollout and states
return carry[2], env_states
# While loop rollout
carry = (obs, env_state, 0.0, False, 0, stats)
carry = jax.lax.while_loop(_cond, _step, carry)

# Return the sum of rewards accumulated by agent in episode rollout and stats
_, _, cum_reward, _, t, (mean, var_sum) = carry
return cum_reward, (mean, var_sum, t)

def normalize_obs(self, obs: PyTree, state: State) -> PyTree:
"""Normalize observations using running statistics."""
Expand All @@ -196,47 +209,37 @@ def normalize_obs(self, obs: PyTree, state: State) -> PyTree:
state.obs_std,
)

def update_stats(self, obs: PyTree, state: State) -> State:
"""Update running statistics for observations using Welford's online algorithm.
def update_stats(self, all_stats: tuple, state: State) -> State:
"""Update running statistics using parallel reduction.

This method implements a numerically stable algorithm for computing
running mean and variance statistics across episodes.
This method combines per-rollout statistics using parallel reduction
formulas to update global running statistics.

Args:
obs: PyTree containing observations with shape
(population_size, num_rollouts, episode_length, ...)
all_stats: Tuple of (mean_arrays, var_sum_arrays, count_arrays),
each with shape (population_size, num_rollouts, ...)
state: Current state containing running statistics

Returns:
Updated state with new observation statistics

"""
# Batch dimensions are (population_size, num_rollouts, episode_length)
batch_size = obs.shape[0] * obs.shape[1] * obs.shape[2]
new_obs_counter = state.obs_counter + batch_size

# Function to update statistics for each leaf in the PyTree
def _update_leaf_stats(leaf_obs, leaf_mean, leaf_var_sum):
# Compute the new mean
diff_to_old_mean = leaf_obs - leaf_mean
new_obs_mean = (
leaf_mean + jnp.sum(diff_to_old_mean, axis=(0, 1, 2)) / new_obs_counter
)
mean, var_sum, count = all_stats

# Compute new variance
diff_to_new_mean = leaf_obs - new_obs_mean
new_obs_var_sum = leaf_var_sum + jnp.sum(
diff_to_old_mean * diff_to_new_mean, axis=(0, 1, 2)
)
# Combine with global stats using parallel reduction
new_obs_counter = state.obs_counter + count
obs_mean = jax.tree.map(
lambda m, gm: (state.obs_counter * gm + count * m) / new_obs_counter,
mean,
state.obs_mean,
)

return new_obs_mean, new_obs_var_sum
def _combine_stats(v, gv, m, gm):
factor = state.obs_counter * count / new_obs_counter
return gv + v + factor * (gm - m) * (gm - m)

# Apply the update function to each leaf in the observation PyTree
obs_mean, obs_var_sum = jax.tree.map(
lambda obs, mean, var: _update_leaf_stats(obs, mean, var),
obs,
state.obs_mean,
state.obs_var_sum,
obs_var_sum = jax.tree.map(
_combine_stats, var_sum, state.obs_var_sum, mean, state.obs_mean
)

obs_var_sum = jnp.maximum(obs_var_sum, 0)
Expand Down
Loading
Loading