From 4396f9a5272e6143003bed6fbdfa8f093fdfefed Mon Sep 17 00:00:00 2001 From: Sam Tucker <49291748+svtuck@users.noreply.github.com> Date: Sun, 22 Feb 2026 21:22:57 -0600 Subject: [PATCH] 1. Update outdated imports. 2. Fix a bug in tabular CFR that prevented it from converging to the correct CFR solution for Leduc --- cfrx/algorithms/cfr/cfr.py | 4 ++-- cfrx/tree/traverse.py | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/cfrx/algorithms/cfr/cfr.py b/cfrx/algorithms/cfr/cfr.py index 3dd56ed..8f4e774 100644 --- a/cfrx/algorithms/cfr/cfr.py +++ b/cfrx/algorithms/cfr/cfr.py @@ -10,8 +10,8 @@ from cfrx.envs import Env from cfrx.policy import TabularPolicy from cfrx.tree import Tree -from cfrx.tree.traverse_old import instantiate_tree_from_root, traverse_tree_cfr -from cfrx.tree.tree_old import Root +from cfrx.tree.traverse import instantiate_tree_from_root, traverse_tree_cfr +from cfrx.tree.tree import Root from cfrx.utils import regret_matching diff --git a/cfrx/tree/traverse.py b/cfrx/tree/traverse.py index 19a9676..76a63dd 100644 --- a/cfrx/tree/traverse.py +++ b/cfrx/tree/traverse.py @@ -167,7 +167,6 @@ def select_new_node_and_play( action = tree.action_from_parent[child_index] parent_state = jax.tree_map(lambda x: x[parent_index], tree.states) - print(parent_state.legal_action_mask.shape) new_state = env.step(parent_state, action) return new_state, parent_index, child_index, action @@ -318,8 +317,8 @@ def loop_fn(val: Tuple) -> Tuple: use_behavior_policy=jnp.bool_(False), ) - chance_strategy = env.get_chance_probs(parent_state)[action] - # jax.debug.breakpoint() + chance_probs = env.get_chance_probs(parent_state) + chance_strategy = chance_probs[action] action_prob = jnp.where( parent_state.chance_node, chance_strategy, strategy[action] @@ -348,7 +347,7 @@ def loop_fn(val: Tuple) -> Tuple: new_state.rewards ), children_prior_logits=tree.children_prior_logits.at[parent_index].set( - jnp.where(parent_state.chance_node, chance_strategy, strategy) + 
jnp.where(parent_state.chance_node, chance_probs, strategy) ), parents=tree.parents.at[child_index].set(parent_index), action_from_parent=tree.action_from_parent.at[child_index].set(action),