Fix SquareCBExploration.act for batched values (broadcast crash + per-row probability sum) #133
SquareCBExploration.act reshapes values to (batch_size, action_count) and then computes the empirical gaps as `max_val - values`, where max_val has shape (batch_size,). A no-op `max_val.repeat(1, action_space.n)` (whose result was discarded) was clearly meant to broadcast max_val across actions but never did. As a result, `(batch_size,) - (batch_size, action_count)` raises a RuntimeError whenever batch_size is neither 1 nor equal to action_count, so act cannot run on general batched input. A second issue: the greedy action's residual probability used `torch.sum(prob_policy)` over the whole (batch_size, action_count) tensor instead of the current row, so per-row distributions did not sum to 1 and the greedy action was mis-weighted for batches with more than one row. Both happened to be masked when batch_size == 1 (the crash also when batch_size == action_count), which is why single-state inference worked and the bugs went unnoticed. Fix the gap with `max_val.unsqueeze(1) - values` and sum the current row only. Adds a regression test exercising batch_size != action_count.
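The shape mismatch can be reproduced outside Pearl. The following is a minimal sketch assuming only PyTorch; the tensor contents are made up for illustration, and it does not call `SquareCBExploration` itself.

```python
import torch

# Illustrative values: batch_size=2, action_count=3 (not taken from Pearl).
action_count = 3
values = torch.tensor([[1.0, 3.0, 2.0],
                       [0.5, 0.2, 0.9]])
max_val, max_indices = torch.max(values, dim=1)  # max_val has shape (2,)

# Tensor.repeat returns a new tensor; discarding the result leaves max_val 1-D.
max_val.repeat(1, action_count)

# (2,) - (2, 3): trailing dimensions 2 and 3 cannot be broadcast together.
try:
    max_val - values
    raised = False
except RuntimeError:
    raised = True

# unsqueeze(1) gives shape (2, 1), which broadcasts across the action axis.
gaps = max_val.unsqueeze(1) - values  # shape (2, 3)
```

Each row of `gaps` holds the distance of every action's value from that row's own maximum, which is what the gap computation needs.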
Pull request overview
This PR fixes SquareCBExploration.act to correctly handle batched values tensors by repairing the empirical gap broadcasting and ensuring per-row probability normalization, and adds unit tests to prevent regressions in batched and single-state paths.
Changes:
- Fix empirical gap computation for batched inputs via `max_val.unsqueeze(1)` broadcasting.
- Fix greedy-action residual probability computation to sum only the current batch row.
- Add `with_pytorch` unit tests covering batched non-square shapes, per-row distribution validity, and single-state behavior.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `pearl/policy_learners/exploration_modules/contextual_bandits/squarecb_exploration.py` | Fixes batched gap broadcasting and corrects per-row probability normalization (also impacts derived exploration modules). |
| `test/unit/with_pytorch/test_squarecb_exploration.py` | Adds tests to catch the original batched crash and validate per-row probability distributions and single-state behavior. |
    # Reconstruct the distribution act() builds (no randomness involved).
    max_val, max_indices = torch.max(values, dim=1)
    empirical_gaps = max_val.unsqueeze(1) - values
    prob = torch.div(1.0, action_space.n + gamma * empirical_gaps)
    for b in range(values.size(0)):
Good catch, fixed. The test now builds the expected distribution by calling the module's own `get_unnormalize_prob(...)` per row instead of duplicating the formula, which both removes the unused `exploration` local (flake8 F841) and keeps the test aligned with the implementation.
    # Construct probability distribution over actions and sample from it
    selected_actions = torch.zeros((values.size(dim=0),), dtype=torch.int)
    prob_policy = self.get_unnormalize_prob(empirical_gaps, max_val, action_space.n)
    for batch_ind in range(values.size(dim=0)):
        # Get sum of all the probabilities besides the maximum
Agreed, and thanks; this is exactly right. `FastCBExploration` inherits this `act` and overrides `get_unnormalize_prob` with `if max_val <= self.reward_lb:`, which raises `RuntimeError: Boolean value of Tensor with more than one value is ambiguous` on a batched `max_val` (confirmed by calling it directly). I refactored `act` to build the policy inside the per-row loop, passing the scalar row maximum `max_val[batch_ind]` and the row gaps `empirical_gaps[batch_ind, :]`. Results are identical for `SquareCBExploration`, and `FastCBExploration` now supports batched input too. Added a `FastCBExploration` batched regression test alongside the SquareCB ones.
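The ambiguous-truth-value error is a general PyTorch behavior and can be shown in isolation. This is a small sketch, not Pearl code: `reward_lb` here is a hypothetical stand-in for `self.reward_lb`, and the tensor values are arbitrary.

```python
import torch

reward_lb = 0.0  # hypothetical lower bound, standing in for self.reward_lb
batched_max = torch.tensor([0.5, 0.9])  # one row maximum per batch element

# Using a multi-element comparison result in `if` calls bool() on it,
# which PyTorch rejects because the answer differs per element.
try:
    if batched_max <= reward_lb:
        pass
    ambiguous = False
except RuntimeError:
    ambiguous = True

# Indexing one row gives a 0-dim tensor, whose truth value is well defined.
row_is_below = bool(batched_max[0] <= reward_lb)
```

Passing `max_val[batch_ind]` into the overridden method corresponds to the second form, which is why building the policy per row sidesteps the error.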
Address review feedback: compute the unnormalized policy inside the per-row loop, passing the scalar row maximum to get_unnormalize_prob. This keeps the shared act() compatible with FastCBExploration (which inherits act and overrides get_unnormalize_prob with `if max_val <= self.reward_lb:`, raising an ambiguous-truth-value error on a batched max_val tensor). Results are identical for SquareCBExploration. Also use the module's own get_unnormalize_prob in the per-row distribution test instead of duplicating the formula (removes an unused local / flake8 F841), and add a FastCBExploration batched regression test.
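The per-row construction can be sketched in plain Python to make the arithmetic easy to follow. This is an illustration under assumptions, not the module's code: the function name `squarecb_row_distribution` is hypothetical, the non-greedy weight follows the `1 / (action_count + gamma * gap)` formula quoted in the test excerpt, and the input values are invented.

```python
def squarecb_row_distribution(row_values, gamma):
    """Build one row's action distribution (hypothetical helper, not Pearl's API)."""
    action_count = len(row_values)
    max_val = max(row_values)            # scalar maximum of this row only
    greedy = row_values.index(max_val)   # first index attaining the maximum

    # Unnormalized weight for every action, based on its gap to the row maximum.
    probs = [1.0 / (action_count + gamma * (max_val - v)) for v in row_values]

    # The greedy action receives whatever mass the other actions in THIS row
    # leave over; summing over other rows would break normalization.
    residual = sum(p for i, p in enumerate(probs) if i != greedy)
    probs[greedy] = 1.0 - residual
    return probs


batch = [[1.0, 3.0, 2.0], [0.5, 0.2, 0.9]]
dists = [squarecb_row_distribution(row, gamma=10.0) for row in batch]
```

Because each non-greedy weight is at most `1 / action_count`, the residual is always below 1, so the greedy probability stays positive and every row sums to 1.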
Summary
`SquareCBExploration.act` cannot run on batched `values`.

After reshaping `values` to `(batch_size, action_count)`, it computes the empirical gaps as `max_val - values`, preceded by a bare `max_val.repeat(1, action_space.n)` call.

The `max_val.repeat(...)` was evidently meant to broadcast `max_val` across actions, but its result is never assigned, so `max_val` stays 1-D. `(batch_size,) - (batch_size, action_count)` then only broadcasts when `batch_size == action_count` (or `== 1`); otherwise it raises a `RuntimeError`.

There is a second, subtler bug just below: `torch.sum(prob_policy)` sums the entire `(batch_size, action_count)` tensor rather than the current row, so for multi-row batches the greedy action's residual probability is computed from other rows' probabilities. The per-row distributions don't sum to 1 and the greedy action is mis-weighted.

Both issues are masked when `batch_size == 1`, which is why single-state inference worked and they went unnoticed.

Fix

Compute the gaps as `max_val.unsqueeze(1) - values` and sum only the current row when computing the greedy action's residual probability.

Validation

On `main`, the new `test_act_batched_states_does_not_crash` (`batch_size=2, action_count=3`) fails with the `RuntimeError` above at line 83.

Adds `test/unit/with_pytorch/test_squarecb_exploration.py` covering batched input (`batch_size != action_count`), per-row distribution validity, and the single-state path.