diff --git a/act/front_end/spec_creator_base.py b/act/front_end/spec_creator_base.py index 98df08aa4..bd0ff2872 100644 --- a/act/front_end/spec_creator_base.py +++ b/act/front_end/spec_creator_base.py @@ -327,7 +327,9 @@ def _validate_output_spec_shape( (is_valid, error_message) tuple """ if spec.kind in [OutKind.MARGIN_ROBUST, OutKind.TOP1_ROBUST]: - if not (0 <= spec.y_true < num_classes): + y_true_valid_class = (0 <= spec.y_true).logical_and(spec.y_true < num_classes) + + if not y_true_valid_class.all(): return False, f"Class label {spec.y_true} out of range [0, {num_classes})" elif spec.kind == OutKind.LINEAR_LE: diff --git a/act/front_end/torchvision_loader/create_specs.py b/act/front_end/torchvision_loader/create_specs.py index b55ce2af8..eb60c8283 100644 --- a/act/front_end/torchvision_loader/create_specs.py +++ b/act/front_end/torchvision_loader/create_specs.py @@ -464,7 +464,7 @@ def _validate_and_filter_specs( for input_spec, output_spec in spec_pairs: try: - is_valid = self.validate_spec_pair_with_model( + is_valid, errors = self.validate_spec_pair_with_model( input_spec, output_spec, pytorch_model, @@ -475,8 +475,8 @@ def _validate_and_filter_specs( valid_pairs.append((input_spec, output_spec)) else: logger.debug( - f"Spec pair validation failed: " - f"{input_spec.kind}, {output_spec.kind}" + f"Spec pair validation failed: {input_spec.kind}, {output_spec.kind}, with errors:\n" + f"{"\n".join(errors)}" ) except Exception as e: diff --git a/act/front_end/vnnlib_loader/create_specs.py b/act/front_end/vnnlib_loader/create_specs.py index 21dae0eac..f1036ba05 100644 --- a/act/front_end/vnnlib_loader/create_specs.py +++ b/act/front_end/vnnlib_loader/create_specs.py @@ -281,7 +281,7 @@ def _validate_and_filter_specs( for input_spec, output_spec in spec_pairs: try: - is_valid = self.validate_spec_pair_with_model( + is_valid, errors = self.validate_spec_pair_with_model( input_spec, output_spec, pytorch_model, @@ -292,8 +292,8 @@ def _validate_and_filter_specs( valid_pairs.append((input_spec, output_spec)) else: logger.debug( - f"Spec pair validation failed: " - f"{input_spec.kind}, {output_spec.kind}" + f"Spec pair validation failed: {input_spec.kind}, {output_spec.kind}, with errors:\n" + f"{"\n".join(errors)}" ) except Exception as e: