From d8fcfe5bb8141734adfcfba8693579a697f736e2 Mon Sep 17 00:00:00 2001 From: "Vangalla, Rohith" Date: Fri, 22 May 2026 10:54:17 -0500 Subject: [PATCH] Include module name and type in validation error messages MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously, validation errors from ModuleValidator.validate() did not indicate which specific sub-module caused the failure. For models with many layers, this made debugging DP compatibility issues difficult — users had to manually inspect each layer to find the offending module. This change prepends the module path and type to each validation error message. For example, instead of: 'BatchNorm cannot support DP' users now see: 'encoder.layer.3.norm (BatchNorm2d): BatchNorm cannot support DP' This makes it immediately clear which layer needs to be replaced, especially in large models with hundreds of sub-modules. Addresses the TODO: 'use module name here - it's useful part of error message' --- opacus/validators/module_validator.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/opacus/validators/module_validator.py b/opacus/validators/module_validator.py index b3d5fb58a..c08bcfc7d 100644 --- a/opacus/validators/module_validator.py +++ b/opacus/validators/module_validator.py @@ -59,11 +59,17 @@ def validate( IllegalModuleConfigurationError("Model needs to be in training mode") ) # 2. perform module specific validations for trainable modules. - # TODO: use module name here - it's useful part of error message - for _, sub_module in trainable_modules(module): + for module_name, sub_module in trainable_modules(module): if type(sub_module) in ModuleValidator.VALIDATORS: sub_module_validator = ModuleValidator.VALIDATORS[type(sub_module)] - errors.extend(sub_module_validator(sub_module)) + sub_errors = sub_module_validator(sub_module) + for err in sub_errors: + # Prepend module name to error message for easier debugging + err.args = ( + f"{module_name} ({type(sub_module).__name__}): {err.args[0]}", + *err.args[1:], + ) + errors.extend(sub_errors) # raise/return as needed if strict and len(errors) > 0: raise UnsupportedModuleError(errors)