From 212bf2212ee56735db9a8457127df26e40cfa0e3 Mon Sep 17 00:00:00 2001 From: chidoziemanagwu Date: Wed, 18 Mar 2026 17:55:36 +0100 Subject: [PATCH 1/3] feat: Add public utility for per-sample gradient validation (#484) --- README.md | 27 +++ .../tests/per_sample_gradients_utils_test.py | 74 +++++++ opacus/utils/__init__.py | 25 +++ opacus/utils/per_sample_gradients_utils.py | 191 +++++++++++++++++- 4 files changed, 316 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 69866ef69..cdfa3c405 100644 --- a/README.md +++ b/README.md @@ -131,6 +131,33 @@ of blogposts and talks: Check out the [FAQ](https://opacus.ai/docs/faq) page for answers to some of the most frequently asked questions about differential privacy and Opacus. +## Verifying Per-Sample Gradients + +If you implement custom layers or grad samplers, you can use +`get_per_sample_gradient_diagnostics` to verify that Opacus computes +per-sample gradients correctly for your model. It compares Opacus's optimized +computation against a reliable (but slow) micro-batch reference and returns a +detailed per-parameter report. + +```python +import torch +from opacus.utils import get_per_sample_gradient_diagnostics + +model = MyCustomModel() +x = torch.randn(8, 16) # sample batch + +report = get_per_sample_gradient_diagnostics(x, model) +if report["passed"]: + print("All per-sample gradients are correct!") +else: + for name, p in report["reductions"]["mean"]["parameters"].items(): + if not p["passed"]: + print(f"FAIL {name}: MSE={p['mse']:.2e}, L1={p['l1_loss']:.2e}") +``` + +The simpler `check_per_sample_gradients_are_correct` function is also available +if you only need a boolean pass/fail result. 
+ ## Contributing See the diff --git a/opacus/tests/per_sample_gradients_utils_test.py b/opacus/tests/per_sample_gradients_utils_test.py index 522a3a209..549d18026 100644 --- a/opacus/tests/per_sample_gradients_utils_test.py +++ b/opacus/tests/per_sample_gradients_utils_test.py @@ -22,6 +22,7 @@ from opacus.utils.per_sample_gradients_utils import ( check_per_sample_gradients_are_correct, get_grad_sample_modes, + get_per_sample_gradient_diagnostics, ) from torch import nn @@ -152,3 +153,76 @@ def test_linear( return self.per_sample_grads_utils_test(x, linear, grad_sample_mode, N == 0) + + +class DiagnosticsUtilsTest(unittest.TestCase): + """Tests for the public get_per_sample_gradient_diagnostics API.""" + + def test_diagnostics_output_structure(self): + """Verify the diagnostics report has the expected dictionary structure.""" + model = nn.Linear(8, 4) + x = torch.randn(2, 8) + + report = get_per_sample_gradient_diagnostics(x, model) + + # Top-level keys + self.assertIn("passed", report) + self.assertIn("num_parameters", report) + self.assertIn("reductions", report) + self.assertIsInstance(report["passed"], bool) + self.assertIsInstance(report["num_parameters"], int) + + # Reduction-level keys + for reduction in ["sum", "mean"]: + self.assertIn(reduction, report["reductions"]) + red_report = report["reductions"][reduction] + self.assertIn("passed", red_report) + self.assertIn("parameters", red_report) + + # Parameter-level keys + for param_name, param_report in red_report["parameters"].items(): + self.assertIn("passed", param_report) + self.assertIn("shape_match", param_report) + self.assertIn("opacus_shape", param_report) + self.assertIn("microbatch_shape", param_report) + self.assertIn("opacus_l2_norm", param_report) + self.assertIn("microbatch_l2_norm", param_report) + self.assertIn("mse", param_report) + self.assertIn("l1_loss", param_report) + + def test_diagnostics_linear_passes(self): + """Verify that a standard nn.Linear model passes diagnostics.""" + model 
= nn.Linear(10, 5, bias=True) + x = torch.randn(4, 10) + + report = get_per_sample_gradient_diagnostics(x, model) + + self.assertTrue(report["passed"]) + self.assertEqual(report["num_parameters"], 2) # weight + bias + for reduction in ["sum", "mean"]: + red_report = report["reductions"][reduction] + self.assertTrue(red_report["passed"]) + for param_name, param_report in red_report["parameters"].items(): + self.assertTrue(param_report["passed"]) + self.assertTrue(param_report["shape_match"]) + + def test_layer_norm(self): + """Verify diagnostics pass for nn.LayerNorm.""" + W = 8 + model = nn.LayerNorm(W, elementwise_affine=True) + x = torch.randn(4, 6, W) + + report = get_per_sample_gradient_diagnostics(x, model) + self.assertTrue(report["passed"]) + + def test_public_import_path(self): + """Verify the public import from opacus.utils works.""" + from opacus.utils import ( + check_per_sample_gradients_are_correct as check_fn, + ) + from opacus.utils import ( + get_per_sample_gradient_diagnostics as diag_fn, + ) + + self.assertTrue(callable(check_fn)) + self.assertTrue(callable(diag_fn)) diff --git a/opacus/utils/__init__.py b/opacus/utils/__init__.py index e69de29bb..e7da71eb5 100644 --- a/opacus/utils/__init__.py +++ b/opacus/utils/__init__.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .per_sample_gradients_utils import ( + check_per_sample_gradients_are_correct, + get_per_sample_gradient_diagnostics, +) + + +__all__ = [ + "check_per_sample_gradients_are_correct", + "get_per_sample_gradient_diagnostics", +] diff --git a/opacus/utils/per_sample_gradients_utils.py b/opacus/utils/per_sample_gradients_utils.py index 7703b77cf..27a043e88 100644 --- a/opacus/utils/per_sample_gradients_utils.py +++ b/opacus/utils/per_sample_gradients_utils.py @@ -14,7 +14,7 @@ # limitations under the License. import io -from typing import Callable, Dict, Iterable, List, Union +from typing import Any, Callable, Dict, Iterable, List, Union import numpy as np import torch @@ -276,6 +276,195 @@ def check_per_sample_gradients_are_correct( return True +def get_per_sample_gradient_diagnostics( + x: Union[torch.Tensor, PackedSequence], + module: nn.Module, + *, + batch_first: bool = True, + atol: float = 10e-6, + rtol: float = 10e-5, + grad_sample_mode: str = "hooks", +) -> Dict[str, Any]: + """ + Computes detailed diagnostics for per-sample gradient correctness. + + This utility helps users verify that Opacus computes per-sample gradients + correctly for their specific model. It compares the result of the slow but + reliable micro-batch method (computing gradients one sample at a time) with + Opacus's optimized method, and returns a detailed report for each trainable + parameter. + + This is particularly useful when: + - Implementing custom ``grad_samplers`` for new layer types. + - Using ``functorch``-based gradient computation with complex architectures. + - Validating that a model's per-sample gradients are correct before + deploying it for privacy-sensitive training. + + Args: + x: Sample input batch to run through the model. + module: The ``nn.Module`` to check. Should be the raw model + (not wrapped with ``GradSampleModule``). + batch_first: Whether batch size is the first dimension (as opposed + to the second). Defaults to True. 
+ atol: The absolute tolerance parameter for ``torch.allclose``. + Defaults to 10e-6. + rtol: The relative tolerance parameter for ``torch.allclose``. + Defaults to 10e-5. + grad_sample_mode: The Opacus grad sample mode to use. + One of ``"hooks"``, ``"functorch"``, or ``"ew"``. + Defaults to ``"hooks"``. + + Returns: + A dictionary with the following structure:: + + { + "passed": bool, # True if all parameters pass for all reductions + "num_parameters": int, + "reductions": { + "mean": { + "passed": bool, + "parameters": { + "<param_name>": { + "passed": bool, + "shape_match": bool, + "opacus_shape": tuple, + "microbatch_shape": tuple, + "opacus_l2_norm": float, + "microbatch_l2_norm": float, + "mse": float, # Mean Squared Error + "l1_loss": float, # Mean Absolute Error + }, + ... + } + }, + "sum": { ... } # same structure as "mean" + } + } + + Raises: + RuntimeError: If ``grad_sample_mode="ew"`` and ``batch_first`` is False + or the torch version is incompatible. + RuntimeError: If the input batch ``x`` is empty. + + Example: + >>> import torch + >>> import torch.nn as nn + >>> from opacus.utils.per_sample_gradients_utils import ( + ... get_per_sample_gradient_diagnostics, + ... 
) + >>> model = nn.Linear(10, 5) + >>> x = torch.randn(4, 10) + >>> report = get_per_sample_gradient_diagnostics(x, model) + >>> report["passed"] + True + """ + reductions = ["sum", "mean"] + if grad_sample_mode == "ew": + if not batch_first: + raise RuntimeError("Batch should be first dimension.") + if not check_torch_version_for_ew_sample(): + raise RuntimeError(f"Unsupported torch version: {torch.__version__}.") + + all_passed = True + reduction_results: Dict[str, Any] = {} + + for loss_reduction in reductions: + reduction_report = _get_diagnostics_for_reduction( + x, + module, + batch_first=batch_first, + loss_reduction=loss_reduction, + atol=atol, + rtol=rtol, + grad_sample_mode=grad_sample_mode, + ) + if not reduction_report["passed"]: + all_passed = False + reduction_results[loss_reduction] = reduction_report + + return { + "passed": all_passed, + "num_parameters": len( + next(iter(reduction_results.values()))["parameters"] + ), + "reductions": reduction_results, + } + + +def _get_diagnostics_for_reduction( + x: Union[torch.Tensor, PackedSequence], + module: nn.Module, + batch_first: bool = True, + loss_reduction: str = "mean", + atol: float = 10e-6, + rtol: float = 10e-5, + grad_sample_mode: str = "hooks", +) -> Dict[str, Any]: + """ + Internal helper: computes per-parameter diagnostics for a single + loss reduction mode (``"mean"`` or ``"sum"``). 
+ """ + ( + microbatch_grad_samples, + opacus_grad_samples, + ) = compute_grad_samples_microbatch_and_opacus( + x, + module, + batch_first=batch_first, + loss_reduction=loss_reduction, + grad_sample_mode=grad_sample_mode, + ) + + all_passed = True + parameters: Dict[str, Dict[str, Any]] = {} + + for name, opacus_grad_sample in opacus_grad_samples.items(): + microbatch_grad_sample = microbatch_grad_samples[name] + + shape_match = opacus_grad_sample.shape == microbatch_grad_sample.shape + + if shape_match: + values_close = bool( + torch.allclose( + microbatch_grad_sample, opacus_grad_sample, atol, rtol + ) + ) + mse = float( + torch.nn.functional.mse_loss( + opacus_grad_sample, microbatch_grad_sample + ) + ) + l1_loss = float( + torch.nn.functional.l1_loss( + opacus_grad_sample, microbatch_grad_sample + ) + ) + else: + values_close = False + mse = float("inf") + l1_loss = float("inf") + + param_passed = shape_match and values_close + if not param_passed: + all_passed = False + + parameters[name] = { + "passed": param_passed, + "shape_match": shape_match, + "opacus_shape": tuple(opacus_grad_sample.shape), + "microbatch_shape": tuple(microbatch_grad_sample.shape), + "opacus_l2_norm": float(opacus_grad_sample.norm(2)), + "microbatch_l2_norm": float(microbatch_grad_sample.norm(2)), + "mse": mse, + "l1_loss": l1_loss, + } + + return { + "passed": all_passed, + "parameters": parameters, + } + + def compute_microbatch_grad_sample_tensor_or_seq( x: Union[torch.Tensor, PackedSequence], module: nn.Module, From 6476e8f5de19e6a6d600a0242894436220f200b6 Mon Sep 17 00:00:00 2001 From: chidoziemanagwu Date: Wed, 25 Mar 2026 17:47:37 +0100 Subject: [PATCH 2/3] style: apply black formatting --- opacus/utils/per_sample_gradients_utils.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/opacus/utils/per_sample_gradients_utils.py b/opacus/utils/per_sample_gradients_utils.py index 27a043e88..009b6788f 100644 --- 
a/opacus/utils/per_sample_gradients_utils.py +++ b/opacus/utils/per_sample_gradients_utils.py @@ -384,9 +384,7 @@ def get_per_sample_gradient_diagnostics( return { "passed": all_passed, - "num_parameters": len( - next(iter(reduction_results.values()))["parameters"] - ), + "num_parameters": len(next(iter(reduction_results.values()))["parameters"]), "reductions": reduction_results, } @@ -425,19 +423,13 @@ def _get_diagnostics_for_reduction( if shape_match: values_close = bool( - torch.allclose( - microbatch_grad_sample, opacus_grad_sample, atol, rtol - ) + torch.allclose(microbatch_grad_sample, opacus_grad_sample, atol, rtol) ) mse = float( - torch.nn.functional.mse_loss( - opacus_grad_sample, microbatch_grad_sample - ) + torch.nn.functional.mse_loss(opacus_grad_sample, microbatch_grad_sample) ) l1_loss = float( - torch.nn.functional.l1_loss( - opacus_grad_sample, microbatch_grad_sample - ) + torch.nn.functional.l1_loss(opacus_grad_sample, microbatch_grad_sample) ) else: values_close = False From 69253de8f292300075b6e0315e08bfab73ed4a47 Mon Sep 17 00:00:00 2001 From: chidoziemanagwu Date: Fri, 22 May 2026 21:57:11 +0100 Subject: [PATCH 3/3] refactor: address PR #810 review - hide check_per_sample_gradients_are_correct - Remove check_per_sample_gradients_are_correct from public opacus.utils API - Drop README mention of the boolean helper for simplicity - Add diagnostics test exercising the mismatched-gradient (failing) path --- README.md | 3 -- .../tests/per_sample_gradients_utils_test.py | 32 ++++++++++++++++--- opacus/utils/__init__.py | 6 +--- 3 files changed, 29 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index cdfa3c405..1bf590237 100644 --- a/README.md +++ b/README.md @@ -155,9 +155,6 @@ else: print(f"FAIL {name}: MSE={p['mse']:.2e}, L1={p['l1_loss']:.2e}") ``` -The simpler `check_per_sample_gradients_are_correct` function is also available -if you only need a boolean pass/fail result. 
- ## Contributing See the diff --git a/opacus/tests/per_sample_gradients_utils_test.py b/opacus/tests/per_sample_gradients_utils_test.py index 549d18026..bbfb26741 100644 --- a/opacus/tests/per_sample_gradients_utils_test.py +++ b/opacus/tests/per_sample_gradients_utils_test.py @@ -15,12 +15,14 @@ import unittest from typing import Callable +from unittest.mock import patch import hypothesis.strategies as st import torch from hypothesis import given, settings from opacus.utils.per_sample_gradients_utils import ( check_per_sample_gradients_are_correct, + compute_opacus_grad_sample, get_grad_sample_modes, get_per_sample_gradient_diagnostics, ) @@ -217,12 +219,34 @@ def test_layer_norm(self): def test_public_import_path(self): """Verify the public import from opacus.utils works.""" - from opacus.utils import ( - check_per_sample_gradients_are_correct as check_fn, - ) from opacus.utils import ( get_per_sample_gradient_diagnostics as diag_fn, ) - self.assertTrue(callable(check_fn)) self.assertTrue(callable(diag_fn)) + + def test_diagnostics_reports_mismatch(self): + """Verify the diagnostics report flags failure when per-sample gradients + do not match the micro-batch reference.""" + model = nn.Linear(10, 5, bias=True) + x = torch.randn(4, 10) + + def perturbed_compute_opacus_grad_sample(*args, **kwargs): + result = compute_opacus_grad_sample(*args, **kwargs) + return {name: g + 1.0 for name, g in result.items()} + + with patch( + "opacus.utils.per_sample_gradients_utils.compute_opacus_grad_sample", + side_effect=perturbed_compute_opacus_grad_sample, + ): + report = get_per_sample_gradient_diagnostics(x, model) + + self.assertFalse(report["passed"]) + for reduction in ["sum", "mean"]: + red_report = report["reductions"][reduction] + self.assertFalse(red_report["passed"]) + for param_report in red_report["parameters"].values(): + self.assertFalse(param_report["passed"]) + self.assertTrue(param_report["shape_match"]) + self.assertGreater(param_report["mse"], 0.0) + 
self.assertGreater(param_report["l1_loss"], 0.0) diff --git a/opacus/utils/__init__.py b/opacus/utils/__init__.py index e7da71eb5..7ed0954c6 100644 --- a/opacus/utils/__init__.py +++ b/opacus/utils/__init__.py @@ -13,13 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .per_sample_gradients_utils import ( - check_per_sample_gradients_are_correct, - get_per_sample_gradient_diagnostics, -) +from .per_sample_gradients_utils import get_per_sample_gradient_diagnostics __all__ = [ - "check_per_sample_gradients_are_correct", "get_per_sample_gradient_diagnostics", ]