Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion act/front_end/spec_creator_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,9 @@ def _validate_output_spec_shape(
(is_valid, error_message) tuple
"""
if spec.kind in [OutKind.MARGIN_ROBUST, OutKind.TOP1_ROBUST]:
if not (0 <= spec.y_true < num_classes):
y_true_valid_class = (0 <= spec.y_true).logical_and(spec.y_true < num_classes)

if not y_true_valid_class.all():
return False, f"Class label {spec.y_true} out of range [0, {num_classes})"

elif spec.kind == OutKind.LINEAR_LE:
Expand Down
6 changes: 3 additions & 3 deletions act/front_end/torchvision_loader/create_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def _validate_and_filter_specs(

for input_spec, output_spec in spec_pairs:
try:
is_valid = self.validate_spec_pair_with_model(
is_valid, errors = self.validate_spec_pair_with_model(
input_spec,
output_spec,
pytorch_model,
Expand All @@ -475,8 +475,8 @@ def _validate_and_filter_specs(
valid_pairs.append((input_spec, output_spec))
else:
logger.debug(
f"Spec pair validation failed: "
f"{input_spec.kind}, {output_spec.kind}"
f"Spec pair validation failed: {input_spec.kind}, {output_spec.kind}, with errors:\n"
f"{"\n".join(errors)}"
)

except Exception as e:
Expand Down
6 changes: 3 additions & 3 deletions act/front_end/vnnlib_loader/create_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def _validate_and_filter_specs(

for input_spec, output_spec in spec_pairs:
try:
is_valid = self.validate_spec_pair_with_model(
is_valid, errors = self.validate_spec_pair_with_model(
input_spec,
output_spec,
pytorch_model,
Expand All @@ -292,8 +292,8 @@ def _validate_and_filter_specs(
valid_pairs.append((input_spec, output_spec))
else:
logger.debug(
f"Spec pair validation failed: "
f"{input_spec.kind}, {output_spec.kind}"
f"Spec pair validation failed: {input_spec.kind}, {output_spec.kind}, with errors:\n"
f"{"\n".join(errors)}"
)

except Exception as e:
Expand Down
Loading