Skip to content

Replace shape-based empty batch handling inside DPDataLoader with structure-aware approach#806

Closed
david-stan wants to merge 13 commits into
meta-pytorch:mainfrom
JetBrains-Research:david-stan/collate-empty-batch
Closed

Replace shape-based empty batch handling inside DPDataLoader with structure-aware approach#806
david-stan wants to merge 13 commits into
meta-pytorch:mainfrom
JetBrains-Research:david-stan/collate-empty-batch

Conversation

@david-stan
Copy link
Copy Markdown
Contributor

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Docs change / refactoring / dependency upgrade

Motivation and Context / Related issue

Replaces unstable shape-based empty batch handling with a stateful approach that learns and replicates the actual output structure from collate_fn. This fixes a critical bug where custom collate functions returning non-list structures (dicts, custom classes) were incompatible with Poisson sampling.

The old implementation inspected dataset[0] to pre-compute shapes, then hardcoded empty batches as lists:

def collate(batch, collate_fn, sample_empty_shapes, dtypes):
    if len(batch) > 0:
        return collate_fn(batch)  # Could return dict, custom class, etc.
    else:
        return [torch.zeros(shape, dtype=dtype) for ...]  # Always list!

Bug -> if collate_fn returns a dict, non-empty batches are dicts but empty batches are lists -> type mismatch crash

Existing, related issue: #534

Solution:

New CollateFnWithEmpty learns the structure from the first non-empty batch:

class CollateFnWithEmpty:
    def __call__(self, batch):
        if len(batch) > 0:
            output = self.wrapped_collator_fn(batch)
            if self.first_batch is None:
                self.first_batch = copy.deepcopy(output)  # Learn structure
        else:
            output = self._make_empty_batch(self.first_batch)  # Replicate structure
        return output

Now empty batches match the structure of non-empty batches, regardless of what collate_fn returns.

If the first non-empty batch is actually the first batch, then it returns an error:

if self.first_batch is None:
    raise ValueError(
        "First sampled batch cannot be empty. Please ensure your dataset "
        "has sufficient samples or increase sample_rate."
    )

Key Changes

  • Removed: shape_safe(), dtype_safe(), hardcoded list return
  • Added: CollateFnWithEmpty class with recursive structure replication
  • Changed: wrap_collate_with_empty() signature: (collate_fn, sample_empty_shapes, dtype) -> (collate_fn, batch_first, rand_on_empty)

It is compatible with existing API.
A small disclosure: for small percentage of users who hacked around empty batches handling, it might cause problems but in majority of cases it should be compatible.

How Has This Been Tested (if it applies)

  • We used this approach to fine-tune Qwen 7B model using trl library for model alignment
  • Tested on Mellum 5B parameter model fine-tuning

Checklist

  • The documentation is up-to-date with the changes I made.
  • [] I have read the CONTRIBUTING document and completed the CLA (see CONTRIBUTING).
  • All tests passed, and additional code has been covered with new tests.

…sistent batch structure

Mark tests incompatible with new empty batch handling as skipped
…rove documentation, and add extensive test coverage
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 26, 2026
@meta-codesync
Copy link
Copy Markdown

meta-codesync Bot commented Jan 26, 2026

@facebook-github-bot has imported this pull request. If you are a Meta employee, you can view this in D91500466. (Because this pull request was imported automatically, there will not be any future comments.)

@coveralls
Copy link
Copy Markdown

Pull Request Test Coverage Report for Build 21371492613

Details

  • 160 of 162 (98.77%) changed or added relevant lines in 4 files are covered.
  • 27 unchanged lines in 4 files lost coverage.
  • Overall coverage increased (+0.02%) to 78.194%

Changes Missing Coverage Covered Lines Changed/Added Lines %
opacus/data_loader.py 33 34 97.06%
opacus/tests/dpdataloader_test.py 123 124 99.19%
Files with Coverage Reduction New Missed Lines %
opacus/optimizers/optimizer.py 1 87.78%
opacus/utils/batch_memory_manager.py 3 83.33%
opacus/tests/privacy_engine_test.py 4 94.53%
opacus/tests/batch_memory_manager_test.py 19 81.55%
Totals Coverage Status
Change from base Build 21792510836: 0.02%
Covered Lines: 5784
Relevant Lines: 7397

💛 - Coveralls

Comment thread opacus/data_loader.py
Comment thread opacus/data_loader.py Outdated
return type(sample)(converted)

# base case
return sample
Copy link
Copy Markdown
Contributor

@iden-kalemaj iden-kalemaj Feb 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@david-stan am I understanding correctly that if the return of the collate_fn does not follow any of the 3 listed instances, you always return the first batch instead of the empty batch? This breaks the DP guarantee because it violates the assumption that each sample is used in training with a certain probability.

Let's raise an error describing what the supported output types are together with a note to either raise an issue or provide a PR if there's a need for a different output type.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right, we should raise an error here.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment thread opacus/tests/dpdataloader_test.py Outdated

dataset = TensorDataset(x, y)
data_loader = DPDataLoader(dataset, sample_rate=1e-5)
# Use moderate sample rate to get non-empty batches
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@david-stan were you able to check that with this sampling rate there are indeed some empty batches produced?

Copy link
Copy Markdown
Contributor Author

@david-stan david-stan Feb 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the test to deterministically produce first batch non-empty, and lowered sample rate to consistently generate empty batches after that.

return SampleConvNet()


@pytest.mark.skip(("Incompatible with the new empty batch handling"))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's delete this test instead of skipping.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But it could be useful to maintain some of the old behavior, per one of my comments.

Comment thread opacus/data_loader.py Outdated
self.first_batch = copy.deepcopy(output)
else:
if self.first_batch is None:
raise ValueError(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@david-stan when first_batch is empty, how about we maintain the old behavior of using lists, so that we still offer some support for more basic collate functions for the case when sampling rate is small and first batch is empty. We can raise a warning here that lists are used. Open to your opinion here as well.

return [
            torch.zeros(shape, dtype=dtype)
            for shape, dtype in zip(sample_empty_shapes, dtypes)
        ]

Copy link
Copy Markdown
Contributor Author

@david-stan david-stan Feb 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having default behavior on random seems like a bigger concern. Also, having an extra parameter for this scenario is also debatable. Generally, interesting idea. What would you suggest?

@iden-kalemaj
Copy link
Copy Markdown
Contributor

Hi @david-stan, thank you for this change and the overall approach looks good to me. Could you please address the comments and also see the failed lint test. Please ping me when ready, so I can re-run the tests.

@iden-kalemaj iden-kalemaj self-assigned this Feb 16, 2026
Comment thread opacus/data_loader.py Outdated
f"CollateFnWithEmpty only supports batches containing torch.Tensor, "
f"dict (Mapping), list, or tuple types. "
f"If you need support for a different output type, please open an issue at "
f"https://github.com/JetBrains-Research/opacus/issues or submit a PR."
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets remove the link and just say ... please open an issue on Opacus or submit a PR.

@iden-kalemaj
Copy link
Copy Markdown
Contributor

@david-stan please also see the failing lint test.

@david-stan
Copy link
Copy Markdown
Contributor Author

david-stan commented Feb 23, 2026

@david-stan when first_batch is empty, how about we maintain the old behavior of using lists, so that we still offer some support for more basic collate functions for the case when sampling rate is small and first batch is empty. We can raise a warning here that lists are used. Open to your opinion here as well.

return [
            torch.zeros(shape, dtype=dtype)
            for shape, dtype in zip(sample_empty_shapes, dtypes)
        ]

This one is last it seems, what is your decision on this one? Are we sticking to lists at the end

@iden-kalemaj
Copy link
Copy Markdown
Contributor

This one is last it seems, what is your decision on this one? Are we sticking to lists at the end

Yes for backward compatibility if self.first_batch is None let's return an empty list and raise a Warning that says that 'First batch is empty. We are using a list of zero-valued tensors as a batch. This may causes issues if the model expects a different batch format. To fix, use more data, increase epsilon, or increase sampling rate'.

Also please see failing lint (and test code with black and isort as well to make sure those pass too).

@david-stan
Copy link
Copy Markdown
Contributor Author

This one is last it seems, what is your decision on this one? Are we sticking to lists at the end

Yes for backward compatibility if self.first_batch is None let's return an empty list and raise a Warning that says that 'First batch is empty. We are using a list of zero-valued tensors as a batch. This may causes issues if the model expects a different batch format. To fix, use more data, increase epsilon, or increase sampling rate'.

Also please see failing lint (and test code with black and isort as well to make sure those pass too).

Changed to return empty list. Potential problem is if you explicitly wanted list of zero-valued tensors instead. In that case I will need to reintroduce sample_empty_shapes and dtypes, which would require additional API changes.

@iden-kalemaj
Copy link
Copy Markdown
Contributor

@david-stan apologies, the behavior I intended was to return a list of zero valued tensors using sample_empty_shapes, i.e., reverting to the original behavior.

We can either:

  1. Raise a warning if first batch is empty (i.e., revert your last commit)
  2. If you have the time, implement returning a list of zero valued tensors.

Please let me know which one you would prefer.

Reintroduce sample_empty_shapes and dtypes from dataset[0] so that
when the first Poisson-sampled batch is empty, CollateFnWithEmpty
returns properly shaped zero tensors instead of an empty list.
Add thorough tests with deterministic seeds for the empty first batch
path and the transition to learned batch structure.
@david-stan
Copy link
Copy Markdown
Contributor Author

@david-stan apologies, the behavior I intended was to return a list of zero valued tensors using sample_empty_shapes, i.e., reverting to the original behavior.

We can either:

  1. Raise a warning if first batch is empty (i.e., revert your last commit)
  2. If you have the time, implement returning a list of zero valued tensors.

Please let me know which one you would prefer.

Committed, please review!

@david-stan
Copy link
Copy Markdown
Contributor Author

Just saw the lint error, fixed

@iden-kalemaj
Copy link
Copy Markdown
Contributor

@david-stan please see another lint failure. Just curious if you tried all the linting tests from our contribution guide before submitting and if those passed?

@david-stan
Copy link
Copy Markdown
Contributor Author

Should be fine

iden-kalemaj pushed a commit to iden-kalemaj/opacus that referenced this pull request Mar 26, 2026
…tructure-aware approach (meta-pytorch#806)

Summary:
## Types of changes

- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [x] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Docs change / refactoring / dependency upgrade

## Motivation and Context / Related issue

Replaces unstable shape-based empty batch handling with a stateful approach that learns and replicates the actual output structure from `collate_fn`. This fixes a critical bug where custom collate functions returning non-list structures (dicts, custom classes) were incompatible with Poisson sampling.

The old implementation inspected `dataset[0]` to pre-compute shapes, then hardcoded empty batches as lists:
```python
def collate(batch, collate_fn, sample_empty_shapes, dtypes):
    if len(batch) > 0:
        return collate_fn(batch)  # Could return dict, custom class, etc.
    else:
        return [torch.zeros(shape, dtype=dtype) for ...]  # Always list!
```
Bug -> if `collate_fn` returns a dict, non-empty batches are dicts but empty batches are lists -> type mismatch crash

Existing, related issue: meta-pytorch#534

### Solution:
New `CollateFnWithEmpty` learns the structure from the first non-empty batch:
```python
class CollateFnWithEmpty:
    def __call__(self, batch):
        if len(batch) > 0:
            output = self.wrapped_collator_fn(batch)
            if self.first_batch is None:
                self.first_batch = copy.deepcopy(output)  # Learn structure
        else:
            output = self._make_empty_batch(self.first_batch)  # Replicate structure
        return output

```
Now empty batches match the structure of non-empty batches, regardless of what `collate_fn` returns.

If the first non-empty batch is actually the first batch, then it returns an error:
```python
if self.first_batch is None:
    raise ValueError(
        "First sampled batch cannot be empty. Please ensure your dataset "
        "has sufficient samples or increase sample_rate."
    )
```

### Key Changes

- Removed: `shape_safe()`, `dtype_safe()`, hardcoded list return
- Added: `CollateFnWithEmpty` class with recursive structure replication
- Changed: `wrap_collate_with_empty()` signature: `(collate_fn, sample_empty_shapes, dtype)` -> `(collate_fn, batch_first, rand_on_empty)`

It is compatible with existing API.
A small disclosure: for small percentage of users who hacked around empty batches handling, it might cause problems but in majority of cases it should be compatible.

## How Has This Been Tested (if it applies)

-  We used this approach to fine-tune `Qwen 7B` model using `trl` library for model alignment
-  Tested on `Mellum` 5B parameter model fine-tuning

## Checklist

- [ ] The documentation is up-to-date with the changes I made.
- [] I have read the **CONTRIBUTING** document and completed the CLA (see **CONTRIBUTING**).
- [x] All tests passed, and additional code has been covered with new tests.

Pull Request resolved: meta-pytorch#806

Differential Revision: D91500466
iden-kalemaj pushed a commit to iden-kalemaj/opacus that referenced this pull request Mar 26, 2026
…tructure-aware approach (meta-pytorch#806)

Summary:
## Types of changes



- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [x] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Docs change / refactoring / dependency upgrade

## Motivation and Context / Related issue


Replaces unstable shape-based empty batch handling with a stateful approach that learns and replicates the actual output structure from `collate_fn`. This fixes a critical bug where custom collate functions returning non-list structures (dicts, custom classes) were incompatible with Poisson sampling.

The old implementation inspected `dataset[0]` to pre-compute shapes, then hardcoded empty batches as lists:
```python
def collate(batch, collate_fn, sample_empty_shapes, dtypes):
    if len(batch) > 0:
        return collate_fn(batch)  # Could return dict, custom class, etc.
    else:
        return [torch.zeros(shape, dtype=dtype) for ...]  # Always list!
```
Bug -> if `collate_fn` returns a dict, non-empty batches are dicts but empty batches are lists -> type mismatch crash

Existing, related issue: meta-pytorch#534


### Solution:
New `CollateFnWithEmpty` learns the structure from the first non-empty batch:
```python
class CollateFnWithEmpty:
    def __call__(self, batch):
        if len(batch) > 0:
            output = self.wrapped_collator_fn(batch)
            if self.first_batch is None:
                self.first_batch = copy.deepcopy(output)  # Learn structure
        else:
            output = self._make_empty_batch(self.first_batch)  # Replicate structure
        return output

```
Now empty batches match the structure of non-empty batches, regardless of what `collate_fn` returns.

If the first non-empty batch is actually the first batch, then it returns an error:
```python
if self.first_batch is None:
    raise ValueError(
        "First sampled batch cannot be empty. Please ensure your dataset "
        "has sufficient samples or increase sample_rate."
    )
```

### Key Changes

- Removed: `shape_safe()`, `dtype_safe()`, hardcoded list return
- Added: `CollateFnWithEmpty` class with recursive structure replication
- Changed: `wrap_collate_with_empty()` signature: `(collate_fn, sample_empty_shapes, dtype)` -> `(collate_fn, batch_first, rand_on_empty)`

It is compatible with existing API.
A small disclosure: for small percentage of users who hacked around empty batches handling, it might cause problems but in majority of cases it should be compatible.

## How Has This Been Tested (if it applies)



-  We used this approach to fine-tune `Qwen 7B` model using `trl` library for model alignment
-  Tested on `Mellum` 5B parameter model fine-tuning

## Checklist




- [ ] The documentation is up-to-date with the changes I made.
- [] I have read the **CONTRIBUTING** document and completed the CLA (see **CONTRIBUTING**).
- [x] All tests passed, and additional code has been covered with new tests.


Test Plan:
Imported from GitHub, without a `Test Plan:` line.
Unit tests

Differential Revision: D98312879

Pulled By: iden-kalemaj
@meta-codesync meta-codesync Bot closed this in 6dc0a27 Mar 26, 2026
@meta-codesync
Copy link
Copy Markdown

meta-codesync Bot commented Mar 26, 2026

@iden-kalemaj merged this pull request in 6dc0a27.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. Merged

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants