feat: Add public utility for per-sample gradient validation (#484) #810

Open
chidoziemanagwu wants to merge 3 commits into meta-pytorch:main from chidoziemanagwu:feature/issue-484-per-sample-gradient-diagnostics

Conversation

@chidoziemanagwu

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Docs change / refactoring / dependency upgrade

Motivation and Context / Related issue

Fixes #484: Utility to test correctness of per sample gradients

Context:
Currently, users building custom privacy models or extending Opacus with custom grad_samplers have no public-facing mechanism to verify that Opacus computes their per-sample gradients correctly. The internal validation method check_per_sample_gradients_are_correct discards all diagnostic data and returns only a boolean, so engineers cannot tell why a layer's gradient check failed.

The Approach:
I extracted the core comparison logic between Opacus's reference micro-batch computation and the optimized hook computation into a newly exported public API: get_per_sample_gradient_diagnostics.

  • Rich Diagnostics: This function returns comparison metrics (L1 loss, mean squared error, L2 norms, and shape mismatches) for every named trainable parameter processed by Opacus, for both the mean and sum loss reductions.
    # Example output dictionary per parameter
    'weight': {
        'passed': True,
        'shape_match': True,
        'opacus_shape': (4, 5, 10),
        'microbatch_shape': (4, 5, 10),
        'opacus_l2_norm': 2.45,
        'microbatch_l2_norm': 2.45,
        'mse': 1.2e-15,
        'l1_loss': 3.4e-8
    }
  • Backward Compatibility: I deliberately left the original internal assertions intact so that existing tests for upstream Opacus developers do not break. The new APIs are documented and exported via opacus/utils/__init__.py.
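
To make the comparison concrete, here is a minimal sketch of how one per-parameter entry like the one above could be produced. It is not the code from this PR: `microbatch_grad_samples` and `compare_grad_samples` are hypothetical helpers, the tolerances are assumed, and a hand-written einsum stands in for the Opacus hook path so that the example runs without Opacus installed.

```python
import torch
import torch.nn as nn


def microbatch_grad_samples(model, x):
    """Reference path: backprop one sample at a time and stack the gradients."""
    per_param = {name: [] for name, _ in model.named_parameters()}
    for sample in x:
        model.zero_grad()
        # Toy sum-reduced loss; a real check would use the user's loss.
        model(sample.unsqueeze(0)).sum().backward()
        for name, p in model.named_parameters():
            per_param[name].append(p.grad.detach().clone())
    return {name: torch.stack(grads) for name, grads in per_param.items()}


def compare_grad_samples(candidate, reference, atol=1e-6, rtol=1e-4):
    """Hypothetical helper: build one diagnostics entry (tolerances are assumed)."""
    shape_match = candidate.shape == reference.shape
    entry = {
        "shape_match": shape_match,
        "opacus_shape": tuple(candidate.shape),
        "microbatch_shape": tuple(reference.shape),
        "opacus_l2_norm": torch.linalg.vector_norm(candidate).item(),
        "microbatch_l2_norm": torch.linalg.vector_norm(reference).item(),
        "mse": None,      # only defined when the shapes agree
        "l1_loss": None,  # only defined when the shapes agree
    }
    if shape_match:
        diff = candidate - reference
        entry["mse"] = diff.pow(2).mean().item()
        entry["l1_loss"] = diff.abs().mean().item()
    entry["passed"] = bool(
        shape_match and torch.allclose(candidate, reference, atol=atol, rtol=rtol)
    )
    return entry


torch.manual_seed(0)
model = nn.Linear(10, 5)
x = torch.randn(4, 10)

reference = microbatch_grad_samples(model, x)

# Stand-in for the vectorized path (an assumption, not Opacus code): for a
# Linear layer with loss = output.sum(), sample i's weight gradient is the
# outer product of a ones vector (the output gradient) with x[i].
grad_out = torch.ones(4, 5)
candidate = torch.einsum("no,ni->noi", grad_out, x)

entry = compare_grad_samples(candidate, reference["weight"])
print(entry["passed"], entry["opacus_shape"], entry["microbatch_shape"])
```

Reporting the shapes and norms alongside the pass/fail flag is what makes a failure actionable: a shape mismatch points at a wrongly laid-out grad sampler, while matching shapes with a large MSE point at the arithmetic.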

Community Impact (Scalability & Architecture):
By exposing this tool, external researchers and security engineers can independently verify their custom mechanisms against the Opacus framework. This lowers the technical barrier to entry for developing novel Differential Privacy architectures, directly shifting the validation and debugging burden away from project maintainers.

How Has This Been Tested (if it applies)

I expanded the test suite in opacus/tests/per_sample_gradients_utils_test.py. Alongside the Conv1d and Linear tests, I added coverage for LayerNorm, validated the exact structure of the dictionary returned by the new diagnostic tool, and tested the public import path. I also updated README.md to document the new public API.

Usage Example:

import torch
import torch.nn as nn
from opacus.utils import get_per_sample_gradient_diagnostics

model = nn.Linear(10, 5)
x = torch.randn(4, 10)

report = get_per_sample_gradient_diagnostics(x, model)
if report["passed"]:
    print("All per-sample gradients are mathematically correct.")
else:
    for name, p in report["reductions"]["mean"]["parameters"].items():
        if not p["passed"]:
            print(f"FAILED LAYER: {name} (MSE: {p['mse']:.2e}, Shape Match: {p['shape_match']})")
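
The usage example inspects only the "mean" reduction. A small sketch of walking every reduction in the report might look like the following. It assumes the report layout implied by the example (`report["reductions"][<reduction>]["parameters"][<name>]`); `failing_parameters` is a hypothetical helper and the report below is hand-written with made-up values, not output from Opacus.

```python
def failing_parameters(report):
    """Return (reduction, name, entry) for every parameter that did not pass."""
    failures = []
    for reduction, result in report.get("reductions", {}).items():
        for name, entry in result.get("parameters", {}).items():
            if not entry.get("passed", False):
                failures.append((reduction, name, entry))
    return failures


# Hand-written report mimicking the assumed structure; all values are made up.
example_report = {
    "passed": False,
    "reductions": {
        "mean": {
            "parameters": {
                "weight": {"passed": True, "shape_match": True, "mse": 1.2e-15},
                "bias": {"passed": True, "shape_match": True, "mse": 0.0},
            }
        },
        "sum": {
            "parameters": {
                "weight": {"passed": True, "shape_match": True, "mse": 2.0e-15},
                "bias": {"passed": False, "shape_match": True, "mse": 3.1e-2},
            }
        },
    },
}

for reduction, name, entry in failing_parameters(example_report):
    print(f"[{reduction}] {name}: mse={entry['mse']:.2e}, shape_match={entry['shape_match']}")
```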

Test Execution & Visual Proof (Terminal Log):

$ python -m unittest opacus.tests.per_sample_gradients_utils_test
UserWarning: Full backward hook is firing when gradients are computed with respect to module outputs since no inputs require gradients.
  loss.backward()
......
----------------------------------------------------------------------
Ran 6 tests in 11.077s

OK

Checklist

  • The documentation is up-to-date with the changes I made.
  • I have read the CONTRIBUTING document and completed the CLA (see CONTRIBUTING).
  • All tests passed, and additional code has been covered with new tests.

@meta-cla

meta-cla Bot commented Mar 18, 2026

Hi @chidoziemanagwu!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@meta-cla

meta-cla Bot commented Mar 18, 2026

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@meta-cla meta-cla Bot added the CLA Signed label Mar 18, 2026
@chidoziemanagwu
Author

Hi @iden-kalemaj @alexandresablayrolles, the CLA has been signed and all import checks are passing. This PR adds a public diagnostic utility for per-sample gradient verification (Issue #484). Would you be able to approve the workflows and review when you have a moment? Thank you.

@meta-codesync

meta-codesync Bot commented Mar 25, 2026

@facebook-github-bot has imported this pull request. If you are a Meta employee, you can view this in D98158224. (Because this pull request was imported automatically, there will not be any future comments.)

@HuanyuZhang HuanyuZhang self-assigned this Mar 26, 2026
@chidoziemanagwu
Author

Hi @HuanyuZhang, I noticed this PR was imported into Meta's internal system (D98158224) and assigned a few weeks ago.

Just checking in to see if there is any feedback from the internal review or if there are any additional changes needed on my end to move this toward a merge. This utility will be a great help for anyone building custom layers in Opacus!

Thanks for your time.

@HuanyuZhang
Contributor

Thx @chidoziemanagwu for the reminder. Just left some comments.

Comment thread opacus/utils/__init__.py

__all__ = [
"check_per_sample_gradients_are_correct",
"get_per_sample_gradient_diagnostics",
Contributor

Any reason why we need to expose check_per_sample_gradients_are_correct? I thought get_per_sample_gradient_diagnostics should have contained all the information needed.

Comment thread README.md Outdated
```

The simpler `check_per_sample_gradients_are_correct` function is also available
if you only need a boolean pass/fail result.
Contributor

Let us hide check_per_sample_gradients_are_correct for simplicity.

        report = get_per_sample_gradient_diagnostics(x, model)
        self.assertTrue(report["passed"])

    def test_public_import_path(self):
Contributor

Any chance we could add a test diagnosing mismatched gradients (i.e., assertFalse rather than assertTrue)?

Author

Sure, we can.
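
For illustration, a negative-path test of the kind requested could be sketched as below. It does not call the PR's `get_per_sample_gradient_diagnostics` and is not the test that was later committed. Instead it builds the micro-batch reference by hand and compares it against deliberately wrong gradients; `per_sample_weight_grads`, the class name, and the tolerance are all assumptions.

```python
import unittest

import torch
import torch.nn as nn


def per_sample_weight_grads(model, x):
    """Micro-batch reference: one backward pass per sample (sum-reduced toy loss)."""
    grads = []
    for sample in x:
        model.zero_grad()
        model(sample.unsqueeze(0)).sum().backward()
        grads.append(model.weight.grad.detach().clone())
    return torch.stack(grads)


class MismatchedGradientTest(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(0)
        self.model = nn.Linear(10, 5)
        self.x = torch.randn(4, 10)
        self.reference = per_sample_weight_grads(self.model, self.x)

    def test_wrong_values_are_flagged(self):
        # Simulated bug: the batch-summed gradient is copied to every sample
        # instead of keeping each sample's own gradient.
        buggy = self.reference.sum(dim=0, keepdim=True).expand_as(self.reference)
        self.assertEqual(buggy.shape, self.reference.shape)
        self.assertFalse(torch.allclose(buggy, self.reference, atol=1e-6))

    def test_wrong_shape_is_flagged(self):
        # Simulated bug: the per-sample (batch) dimension was reduced away.
        buggy = self.reference.sum(dim=0)
        self.assertNotEqual(buggy.shape, self.reference.shape)
```

Saved as a module, this can be run with `python -m unittest`, the same way as the suite in the PR description.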

chidoziemanagwu added a commit to chidoziemanagwu/opacus that referenced this pull request May 22, 2026
…gradients_are_correct

- Remove check_per_sample_gradients_are_correct from public opacus.utils API
- Drop README mention of the boolean helper for simplicity
- Add diagnostics test exercising the mismatched-gradient (failing) path

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@chidoziemanagwu chidoziemanagwu force-pushed the feature/issue-484-per-sample-gradient-diagnostics branch from fcaf563 to 69253de Compare May 22, 2026 21:06
@chidoziemanagwu
Author

Hello @HuanyuZhang, I made an update. Please review :)

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Utility to test the correctness of per sample gradients

2 participants