diff --git a/opacus/validators/module_validator.py b/opacus/validators/module_validator.py index b3d5fb58a..c08bcfc7d 100644 --- a/opacus/validators/module_validator.py +++ b/opacus/validators/module_validator.py @@ -59,11 +59,17 @@ def validate( IllegalModuleConfigurationError("Model needs to be in training mode") ) # 2. perform module specific validations for trainable modules. - # TODO: use module name here - it's useful part of error message - for _, sub_module in trainable_modules(module): + for module_name, sub_module in trainable_modules(module): if type(sub_module) in ModuleValidator.VALIDATORS: sub_module_validator = ModuleValidator.VALIDATORS[type(sub_module)] - errors.extend(sub_module_validator(sub_module)) + sub_errors = sub_module_validator(sub_module) + for err in sub_errors: + # Prepend module name to error message for easier debugging + err.args = ( + f"{module_name} ({type(sub_module).__name__}): {err.args[0]}", + *err.args[1:], + ) + errors.extend(sub_errors) # raise/return as needed if strict and len(errors) > 0: raise UnsupportedModuleError(errors)