Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cfrx/algorithms/cfr/cfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from cfrx.envs import Env
from cfrx.policy import TabularPolicy
from cfrx.tree import Tree
-from cfrx.tree.traverse_old import instantiate_tree_from_root, traverse_tree_cfr
-from cfrx.tree.tree_old import Root
+from cfrx.tree.traverse import instantiate_tree_from_root, traverse_tree_cfr
+from cfrx.tree.tree import Root
from cfrx.utils import regret_matching


Expand Down
7 changes: 3 additions & 4 deletions cfrx/tree/traverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ def select_new_node_and_play(
action = tree.action_from_parent[child_index]

parent_state = jax.tree_map(lambda x: x[parent_index], tree.states)
-print(parent_state.legal_action_mask.shape)
new_state = env.step(parent_state, action)

return new_state, parent_index, child_index, action
Expand Down Expand Up @@ -318,8 +317,8 @@ def loop_fn(val: Tuple) -> Tuple:
use_behavior_policy=jnp.bool_(False),
)

-chance_strategy = env.get_chance_probs(parent_state)[action]
-# jax.debug.breakpoint()
+chance_probs = env.get_chance_probs(parent_state)
+chance_strategy = chance_probs[action]

action_prob = jnp.where(
parent_state.chance_node, chance_strategy, strategy[action]
Expand Down Expand Up @@ -348,7 +347,7 @@ def loop_fn(val: Tuple) -> Tuple:
new_state.rewards
),
children_prior_logits=tree.children_prior_logits.at[parent_index].set(
-jnp.where(parent_state.chance_node, chance_strategy, strategy)
+jnp.where(parent_state.chance_node, chance_probs, strategy)
),
parents=tree.parents.at[child_index].set(parent_index),
action_from_parent=tree.action_from_parent.at[child_index].set(action),
Expand Down