Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,30 @@ of blogposts and talks:
Check out the [FAQ](https://opacus.ai/docs/faq) page for answers to some of the
most frequently asked questions about differential privacy and Opacus.

## Verifying Per-Sample Gradients

If you implement custom layers or grad samplers, you can use
`get_per_sample_gradient_diagnostics` to verify that Opacus computes
per-sample gradients correctly for your model. It compares Opacus's optimized
computation against a reliable (but slow) micro-batch reference and returns a
detailed per-parameter report.

```python
import torch
from opacus.utils import get_per_sample_gradient_diagnostics

model = MyCustomModel()
x = torch.randn(8, 16) # sample batch

report = get_per_sample_gradient_diagnostics(x, model)
if report["passed"]:
print("All per-sample gradients are correct!")
else:
for name, p in report["reductions"]["mean"]["parameters"].items():
if not p["passed"]:
print(f"FAIL {name}: MSE={p['mse']:.2e}, L1={p['l1_loss']:.2e}")
```

## Contributing

See the
Expand Down
98 changes: 98 additions & 0 deletions opacus/tests/per_sample_gradients_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@

import unittest
from typing import Callable
from unittest.mock import patch

import hypothesis.strategies as st
import torch
from hypothesis import given, settings
from opacus.utils.per_sample_gradients_utils import (
check_per_sample_gradients_are_correct,
compute_opacus_grad_sample,
get_grad_sample_modes,
get_per_sample_gradient_diagnostics,
)
from torch import nn

Expand Down Expand Up @@ -152,3 +155,98 @@ def test_linear(
return

self.per_sample_grads_utils_test(x, linear, grad_sample_mode, N == 0)


class DiagnosticsUtilsTest(unittest.TestCase):
"""Tests for the public get_per_sample_gradient_diagnostics API."""

def test_diagnostics_output_structure(self):
    """Verify the diagnostics report has the expected dictionary structure.

    Checks the keys present at the top level, at the per-reduction level
    and at the per-parameter level, plus the types of the top-level
    summary fields.
    """
    model = nn.Linear(8, 4)
    x = torch.randn(2, 8)

    report = get_per_sample_gradient_diagnostics(x, model)

    # Top-level keys
    for key in ("passed", "num_parameters", "reductions"):
        self.assertIn(key, report)
    self.assertIsInstance(report["passed"], bool)
    self.assertIsInstance(report["num_parameters"], int)

    expected_param_keys = (
        "passed",
        "shape_match",
        "opacus_shape",
        "microbatch_shape",
        "opacus_l2_norm",
        "microbatch_l2_norm",
        "mse",
        "l1_loss",
    )

    # Reduction-level keys
    for reduction in ("sum", "mean"):
        self.assertIn(reduction, report["reductions"])
        red_report = report["reductions"][reduction]
        self.assertIn("passed", red_report)
        self.assertIn("parameters", red_report)

        # Without this guard an empty "parameters" dict would let the
        # per-parameter key checks below pass without checking anything.
        self.assertTrue(red_report["parameters"])

        # Parameter-level keys (the parameter name itself is not needed)
        for param_report in red_report["parameters"].values():
            for key in expected_param_keys:
                self.assertIn(key, param_report)

def test_diagnostics_linear_passes(self):
    """Check that a plain ``nn.Linear`` layer is reported as fully passing.

    Every level of the report (overall, per reduction, per parameter) must
    be marked as passed, and the gradient shapes must agree.
    """
    linear = nn.Linear(10, 5, bias=True)
    batch = torch.randn(4, 10)

    report = get_per_sample_gradient_diagnostics(batch, linear)

    self.assertTrue(report["passed"])
    self.assertEqual(report["num_parameters"], 2)  # weight + bias
    for reduction in ("sum", "mean"):
        per_reduction = report["reductions"][reduction]
        self.assertTrue(per_reduction["passed"])
        for entry in per_reduction["parameters"].values():
            self.assertTrue(entry["passed"])
            self.assertTrue(entry["shape_match"])

def test_layer_norm(self):
    """Check that an affine ``nn.LayerNorm`` passes the diagnostics."""
    width = 8
    layer = nn.LayerNorm(width, elementwise_affine=True)
    batch = torch.randn(4, 6, width)

    outcome = get_per_sample_gradient_diagnostics(batch, layer)

    self.assertTrue(outcome["passed"])

def test_public_import_path(self):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any chance we could add a test diagnosing mismatched gradients (i.e., assertFalse rather than assertTrue)?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure we can

"""Verify the public import from opacus.utils works."""
from opacus.utils import (
get_per_sample_gradient_diagnostics as diag_fn,
)

self.assertTrue(callable(diag_fn))

def test_diagnostics_reports_mismatch(self):
    """Check that a gradient mismatch is surfaced as a failure.

    The Opacus-side computation is patched so that every per-sample
    gradient is shifted by a constant. Shapes are unchanged, so the report
    must fail on values only, with strictly positive error metrics.
    """
    linear = nn.Linear(10, 5, bias=True)
    batch = torch.randn(4, 10)

    def shifted_grad_sample(*args, **kwargs):
        # Call the real implementation (the name bound in this test module)
        # and offset each gradient by 1.0.
        genuine = compute_opacus_grad_sample(*args, **kwargs)
        return {key: grad + 1.0 for key, grad in genuine.items()}

    patcher = patch(
        "opacus.utils.per_sample_gradients_utils.compute_opacus_grad_sample",
        side_effect=shifted_grad_sample,
    )
    with patcher:
        report = get_per_sample_gradient_diagnostics(batch, linear)

    self.assertFalse(report["passed"])
    for reduction in ("sum", "mean"):
        per_reduction = report["reductions"][reduction]
        self.assertFalse(per_reduction["passed"])
        for entry in per_reduction["parameters"].values():
            self.assertFalse(entry["passed"])
            self.assertTrue(entry["shape_match"])
            self.assertGreater(entry["mse"], 0.0)
            self.assertGreater(entry["l1_loss"], 0.0)
21 changes: 21 additions & 0 deletions opacus/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .per_sample_gradients_utils import get_per_sample_gradient_diagnostics


__all__ = [
"get_per_sample_gradient_diagnostics",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason why we need to expose check_per_sample_gradients_are_correct? I thought get_per_sample_gradient_diagnostics should have contained all the information needed.

]
183 changes: 182 additions & 1 deletion opacus/utils/per_sample_gradients_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

import io
from typing import Callable, Dict, Iterable, List, Union
from typing import Any, Callable, Dict, Iterable, List, Union

import numpy as np
import torch
Expand Down Expand Up @@ -276,6 +276,187 @@ def check_per_sample_gradients_are_correct(
return True


def get_per_sample_gradient_diagnostics(
    x: Union[torch.Tensor, PackedSequence],
    module: nn.Module,
    *,
    batch_first: bool = True,
    atol: float = 10e-6,
    rtol: float = 10e-5,
    grad_sample_mode: str = "hooks",
) -> Dict[str, Any]:
    """
    Builds a per-parameter report on per-sample gradient correctness.

    For each loss reduction (``"sum"`` and ``"mean"``), the per-sample
    gradients produced by Opacus are compared against those from the
    micro-batch reference (one sample at a time), and the outcome is
    recorded for every parameter.

    Typical uses:
        - Checking a custom ``grad_sampler`` written for a new layer type.
        - Checking ``functorch``-based gradient computation on a complex
          architecture.
        - Validating a model before using it for privacy-sensitive training.

    Args:
        x: Sample input batch to run through the model.
        module: The ``nn.Module`` to check. Should be the raw model
            (not wrapped with ``GradSampleModule``).
        batch_first: Whether batch size is the first dimension (as opposed
            to the second). Defaults to True.
        atol: Absolute tolerance forwarded to the per-reduction comparison.
            Defaults to 10e-6.
        rtol: Relative tolerance forwarded to the per-reduction comparison.
            Defaults to 10e-5.
        grad_sample_mode: The Opacus grad sample mode to use.
            One of ``"hooks"``, ``"functorch"``, or ``"ew"``.
            Defaults to ``"hooks"``.

    Returns:
        A dictionary with the following structure::

            {
                "passed": bool,  # True only if every reduction passed
                "num_parameters": int,  # entries in the first reduction
                "reductions": {
                    "sum": {
                        "passed": bool,
                        "parameters": {
                            "<param_name>": {
                                "passed": bool,
                                "shape_match": bool,
                                "opacus_shape": tuple,
                                "microbatch_shape": tuple,
                                "opacus_l2_norm": float,
                                "microbatch_l2_norm": float,
                                "mse": float,  # Mean Squared Error
                                "l1_loss": float,  # Mean Absolute Error
                            },
                            ...
                        }
                    },
                    "mean": { ... }  # same structure as "sum"
                }
            }

    Raises:
        RuntimeError: If ``grad_sample_mode="ew"`` and ``batch_first`` is False
            or the torch version is incompatible.

    NOTE(review): earlier documentation also listed a ``RuntimeError`` for an
    empty input batch. This function does not perform that check itself;
    presumably it comes from the underlying micro-batch helper -- confirm.

    Example:
        >>> import torch
        >>> import torch.nn as nn
        >>> from opacus.utils.per_sample_gradients_utils import (
        ...     get_per_sample_gradient_diagnostics,
        ... )
        >>> model = nn.Linear(10, 5)
        >>> x = torch.randn(4, 10)
        >>> report = get_per_sample_gradient_diagnostics(x, model)
        >>> report["passed"]
        True
    """
    # The "ew" mode has preconditions that are checked before any gradient
    # computation is attempted.
    if grad_sample_mode == "ew":
        if not batch_first:
            raise RuntimeError("Batch should be first dimension.")
        if not check_torch_version_for_ew_sample():
            raise RuntimeError(f"Unsupported torch version: {torch.__version__}.")

    # One report per loss reduction, evaluated in the order "sum", "mean".
    per_reduction: Dict[str, Any] = {
        loss_reduction: _get_diagnostics_for_reduction(
            x,
            module,
            batch_first=batch_first,
            loss_reduction=loss_reduction,
            atol=atol,
            rtol=rtol,
            grad_sample_mode=grad_sample_mode,
        )
        for loss_reduction in ("sum", "mean")
    }

    first_report = next(iter(per_reduction.values()))
    return {
        "passed": all(entry["passed"] for entry in per_reduction.values()),
        "num_parameters": len(first_report["parameters"]),
        "reductions": per_reduction,
    }


def _get_diagnostics_for_reduction(
    x: Union[torch.Tensor, PackedSequence],
    module: nn.Module,
    batch_first: bool = True,
    loss_reduction: str = "mean",
    atol: float = 10e-6,
    rtol: float = 10e-5,
    grad_sample_mode: str = "hooks",
) -> Dict[str, Any]:
    """
    Internal helper: computes per-parameter diagnostics for a single
    loss reduction mode (``"mean"`` or ``"sum"``).

    Args:
        x: Sample input batch to run through the model.
        module: The ``nn.Module`` to check.
        batch_first: Forwarded to the gradient computation helper.
        loss_reduction: Loss reduction mode to evaluate.
        atol: Absolute tolerance for ``torch.allclose``.
        rtol: Relative tolerance for ``torch.allclose``.
        grad_sample_mode: The Opacus grad sample mode to use.

    Returns:
        A dictionary ``{"passed": bool, "parameters": {...}}`` where
        ``"parameters"`` maps each parameter name to its shapes, L2 norms,
        ``"mse"``, ``"l1_loss"``, ``"shape_match"`` and ``"passed"`` flags.
        When the two gradient shapes differ, ``"mse"`` and ``"l1_loss"`` are
        reported as ``inf`` and the parameter is marked as failed.
    """
    (
        microbatch_grad_samples,
        opacus_grad_samples,
    ) = compute_grad_samples_microbatch_and_opacus(
        x,
        module,
        batch_first=batch_first,
        loss_reduction=loss_reduction,
        grad_sample_mode=grad_sample_mode,
    )

    all_passed = True
    parameters: Dict[str, Dict[str, Any]] = {}

    for name, opacus_grad_sample in opacus_grad_samples.items():
        microbatch_grad_sample = microbatch_grad_samples[name]

        shape_match = opacus_grad_sample.shape == microbatch_grad_sample.shape

        if shape_match:
            # torch.allclose takes (input, other, rtol, atol): the tolerances
            # must be passed by keyword, otherwise atol and rtol are swapped.
            values_close = bool(
                torch.allclose(
                    microbatch_grad_sample,
                    opacus_grad_sample,
                    rtol=rtol,
                    atol=atol,
                )
            )
            mse = float(
                torch.nn.functional.mse_loss(opacus_grad_sample, microbatch_grad_sample)
            )
            l1_loss = float(
                torch.nn.functional.l1_loss(opacus_grad_sample, microbatch_grad_sample)
            )
        else:
            # Element-wise metrics are undefined for mismatched shapes.
            values_close = False
            mse = float("inf")
            l1_loss = float("inf")

        param_passed = shape_match and values_close
        if not param_passed:
            all_passed = False

        parameters[name] = {
            "passed": param_passed,
            "shape_match": shape_match,
            "opacus_shape": tuple(opacus_grad_sample.shape),
            "microbatch_shape": tuple(microbatch_grad_sample.shape),
            "opacus_l2_norm": float(opacus_grad_sample.norm(2)),
            "microbatch_l2_norm": float(microbatch_grad_sample.norm(2)),
            "mse": mse,
            "l1_loss": l1_loss,
        }

    return {
        "passed": all_passed,
        "parameters": parameters,
    }


def compute_microbatch_grad_sample_tensor_or_seq(
x: Union[torch.Tensor, PackedSequence],
module: nn.Module,
Expand Down