Skip to content

Fix SquareCBExploration.act for batched values (broadcast crash + per-row probability sum)#133

Open
Osamaali313 wants to merge 2 commits into
facebookresearch:mainfrom
Osamaali313:fix/squarecb-batched-act
Open

Fix SquareCBExploration.act for batched values (broadcast crash + per-row probability sum)#133
Osamaali313 wants to merge 2 commits into
facebookresearch:mainfrom
Osamaali313:fix/squarecb-batched-act

Conversation

@Osamaali313

Copy link
Copy Markdown

Summary

SquareCBExploration.act cannot run on batched values.

After reshaping values to (batch_size, action_count), it computes the empirical gaps as:

max_val, max_indices = torch.max(values, dim=1)   # (batch_size,)
max_val.repeat(1, action_space.n)                  # result DISCARDED
empirical_gaps = max_val - values                  # (batch_size,) - (batch_size, n)

The max_val.repeat(...) was evidently meant to broadcast max_val across actions, but its result is never assigned, so max_val stays 1-D. (batch_size,) - (batch_size, action_count) then only broadcasts when batch_size == action_count (or == 1); otherwise it raises:

RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1

There is a second, subtler bug just below:

prob_policy[batch_ind, max_indices[batch_ind]] = 0.0
complementary_sum = torch.sum(prob_policy)         # sums the WHOLE tensor
prob_policy[batch_ind, max_indices[batch_ind]] = 1.0 - complementary_sum

torch.sum(prob_policy) sums the entire (batch_size, action_count) tensor rather than the current row, so for multi-row batches the greedy action's residual probability is computed from other rows' probabilities — the per-row distributions don't sum to 1 and the greedy action is mis-weighted.

Both issues are masked when batch_size == 1, which is why single-state inference worked and they went unnoticed.

Fix

-        max_val, max_indices = torch.max(values, dim=1)
-        max_val.repeat(1, action_space.n)
-        empirical_gaps = max_val - values
+        max_val, max_indices = torch.max(values, dim=1)
+        empirical_gaps = max_val.unsqueeze(1) - values
-            complementary_sum = torch.sum(prob_policy)
+            complementary_sum = torch.sum(prob_policy[batch_ind, :])

Validation

  • RED: on main, the new test_act_batched_states_does_not_crash (batch_size=2, action_count=3) fails with the RuntimeError above at line 83.
  • GREEN: with the fix, all three new tests pass. Per-row distributions sum to 1.0 and the greedy action is the most probable in every row; single-state behavior is unchanged.
$ python -m pytest test/unit/with_pytorch/test_squarecb_exploration.py -q
...                                                                      [100%]
3 passed

Adds test/unit/with_pytorch/test_squarecb_exploration.py covering batched input (batch_size != action_count), per-row distribution validity, and the single-state path.

SquareCBExploration.act reshapes values to (batch_size, action_count) and
then computed the empirical gaps as `max_val - values`, where max_val has
shape (batch_size,). A no-op `max_val.repeat(1, action_space.n)` (whose
result was discarded) was clearly meant to broadcast max_val across actions
but never did. As a result `(batch_size,) - (batch_size, action_count)`
raises a RuntimeError whenever batch_size != action_count, so act cannot run
on batched input at all.

A second issue: the greedy action's residual probability used
`torch.sum(prob_policy)` over the whole (batch_size, action_count) tensor
instead of the current row, so per-row distributions did not sum to 1 and the
greedy action was mis-weighted for batches with more than one row.

Both happened to be masked when batch_size == 1 (or == action_count), which
is why single-state inference worked and the bug went unnoticed.

Fix the gap with `max_val.unsqueeze(1) - values` and sum the current row only.
Adds a regression test exercising batch_size != action_count.
Copilot AI review requested due to automatic review settings June 20, 2026 20:33
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 20, 2026

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes SquareCBExploration.act to correctly handle batched values tensors by repairing the empirical gap broadcasting and ensuring per-row probability normalization, and adds unit tests to prevent regressions in batched and single-state paths.

Changes:

  • Fix empirical gap computation for batched inputs via max_val.unsqueeze(1) broadcasting.
  • Fix greedy-action residual probability computation to sum only the current batch row.
  • Add with_pytorch unit tests covering batched non-square shapes, per-row distribution validity, and single-state behavior.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File Description
pearl/policy_learners/exploration_modules/contextual_bandits/squarecb_exploration.py Fixes batched gap broadcasting and corrects per-row probability normalization (also impacts derived exploration modules).
test/unit/with_pytorch/test_squarecb_exploration.py Adds tests to catch the original batched crash and validate per-row probability distributions and single-state behavior.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +52 to +56
# Reconstruct the distribution act() builds (no randomness involved).
max_val, max_indices = torch.max(values, dim=1)
empirical_gaps = max_val.unsqueeze(1) - values
prob = torch.div(1.0, action_space.n + gamma * empirical_gaps)
for b in range(values.size(0)):

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — fixed. The test now builds the expected distribution by calling the module's own get_unnormalize_prob(...) per row instead of duplicating the formula, which both removes the unused exploration local (flake8 F841) and keeps the test aligned with the implementation.

Comment on lines 84 to 88
# Construct probability distribution over actions and sample from it
selected_actions = torch.zeros((values.size(dim=0),), dtype=torch.int)
prob_policy = self.get_unnormalize_prob(empirical_gaps, max_val, action_space.n)
for batch_ind in range(values.size(dim=0)):
# Get sum of all the probabilities besides the maximum

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, and thanks — this is exactly right. FastCBExploration inherits this act and overrides get_unnormalize_prob with if max_val <= self.reward_lb:, which raises RuntimeError: Boolean value of Tensor with more than one value is ambiguous on a batched max_val (confirmed by calling it directly). I refactored act to build the policy inside the per-row loop, passing the scalar row maximum max_val[batch_ind] and the row gaps empirical_gaps[batch_ind, :]. Results are identical for SquareCBExploration, and FastCBExploration now supports batched input too. Added a FastCBExploration batched regression test alongside the SquareCB ones.

Address review feedback: compute the unnormalized policy inside the per-row
loop, passing the scalar row maximum to get_unnormalize_prob. This keeps the
shared act() compatible with FastCBExploration (which inherits act and
overrides get_unnormalize_prob with `if max_val <= self.reward_lb:`, raising
an ambiguous-truth-value error on a batched max_val tensor). Results are
identical for SquareCBExploration.

Also use the module's own get_unnormalize_prob in the per-row distribution
test instead of duplicating the formula (removes an unused local / flake8
F841), and add a FastCBExploration batched regression test.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants