6 changes: 6 additions & 0 deletions ciao/model/__init__.py

```python
"""Model prediction utilities for CIAO."""

from ciao.model.predictor import ModelPredictor


__all__ = ["ModelPredictor"]
```
44 changes: 44 additions & 0 deletions ciao/model/predictor.py

```python
import torch


class ModelPredictor:
    """Handles model predictions and class information."""

    def __init__(self, model: torch.nn.Module, class_names: list[str]) -> None:
        self.model = model
```

Review comment (medium):

It is good practice to set the model to evaluation mode within the predictor's `__init__`. This ensures that layers like `Dropout` or `BatchNorm` behave correctly for inference, making the class more robust and preventing potential inconsistencies in predictions. The caller of this class should not have to remember to call `.eval()` on the model beforehand.

Suggested change:

```diff
-        self.model = model
+        self.model = model.eval()
```
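As a quick standalone sketch (not code from this PR) of why eval mode matters for a dropout-bearing model:

```python
import torch

# Toy model with dropout: in train mode, repeated forward passes on the
# same input generally differ; in eval mode, dropout is disabled and the
# output is deterministic.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(p=0.5))
x = torch.ones(1, 4)

model.eval()  # what the suggested change guarantees for ModelPredictor
out1 = model(x)
out2 = model(x)
assert torch.equal(out1, out2)  # deterministic in eval mode
```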

```python
        self.class_names = class_names
        self.device = next(model.parameters()).device
```
Copilot AI (Mar 5, 2026):

`next(model.parameters()).device` will raise `StopIteration` for parameterless modules (e.g., wrappers or functional models). Consider handling the empty-parameter case explicitly (fallback device, or derive the device from buffers), or avoid storing `self.device` if it isn't reliable/used.

Suggested change:

```diff
-        self.device = next(model.parameters()).device
+        params = list(model.parameters())
+        if params:
+            self.device = params[0].device
+        else:
+            buffers = list(model.buffers())
+            if buffers:
+                self.device = buffers[0].device
+            else:
+                self.device = torch.device("cpu")
```
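That fallback chain (parameters, then buffers, then CPU) can be checked in isolation; the helper name `infer_device` below is hypothetical, introduced only for this sketch:

```python
import torch

def infer_device(model: torch.nn.Module) -> torch.device:
    # Mirrors the suggested fallback: parameters, then buffers, then CPU.
    params = list(model.parameters())
    if params:
        return params[0].device
    buffers = list(model.buffers())
    if buffers:
        return buffers[0].device
    return torch.device("cpu")

assert infer_device(torch.nn.Linear(2, 2)) == torch.device("cpu")
assert infer_device(torch.nn.ReLU()) == torch.device("cpu")  # no params, no buffers
bn = torch.nn.BatchNorm1d(3, affine=False)  # no params, but running-stat buffers
assert infer_device(bn) == torch.device("cpu")
```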

```python
    def get_predictions(self, input_batch: torch.Tensor) -> torch.Tensor:
        """Get model predictions (returns probabilities)."""
        with torch.no_grad():
            outputs = self.model(input_batch)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
```
Copilot AI (Mar 5, 2026), on lines +12 to +16:

For inference wrappers like this, it's easy to get nondeterministic outputs if the model is left in training mode (dropout/BatchNorm). Consider switching the model to `eval()` inside the predictor (or clearly documenting that callers must set eval mode) before running predictions.
```python
        return probabilities
```
Copilot AI (Mar 5, 2026), on lines +12 to +17:

`get_predictions()` forwards `input_batch` as-is. If the model is on GPU and the caller passes a CPU tensor (or vice versa), this will raise a device mismatch error. Consider moving `input_batch` to `self.device` (or validating and raising a clear error) before calling the model.
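A minimal sketch of the move-to-device approach, assuming `self.device` is populated as in `__init__` (the class name `DeviceSafePredictor` is hypothetical, not code from this PR):

```python
import torch

class DeviceSafePredictor:
    """Minimal sketch of a predictor that moves inputs to the model's device."""

    def __init__(self, model: torch.nn.Module) -> None:
        self.model = model.eval()
        self.device = next(model.parameters()).device

    def get_predictions(self, input_batch: torch.Tensor) -> torch.Tensor:
        # Move the batch to the model's device before the forward pass, so a
        # CPU tensor works against a GPU model and vice versa.
        input_batch = input_batch.to(self.device)
        with torch.no_grad():
            return torch.nn.functional.softmax(self.model(input_batch), dim=1)

predictor = DeviceSafePredictor(torch.nn.Linear(3, 2))
probs = predictor.get_predictions(torch.randn(4, 3))
assert probs.shape == (4, 2)
assert torch.allclose(probs.sum(dim=1), torch.ones(4))
```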

```python
    def predict_image(
        self, input_batch: torch.Tensor, top_k: int = 5
    ) -> list[tuple[int, str, float]]:
        """Get top-k predictions for an image."""
```
Copilot AI (Mar 5, 2026):

`predict_image()` assumes a batch with at least one element and always uses `probabilities[0]`, silently ignoring the rest of the batch. Either enforce/validate `batch_size == 1` (with a clear error) or change the API to accept a single image tensor (and add an `unsqueeze(0)` path for 3D inputs) to match the method name.

Suggested change:

```diff
-        """Get top-k predictions for an image."""
+        """Get top-k predictions for a single image.
+
+        Expects `input_batch` to contain exactly one image (batch size 1).
+        """
+        # Validate that the input is a non-empty batch with exactly one element.
+        if input_batch.dim() < 1:
+            raise ValueError(
+                f"predict_image expects a batched tensor with batch size 1; got tensor with shape {tuple(input_batch.shape)}"
+            )
+        if input_batch.size(0) != 1:
+            raise ValueError(
+                f"predict_image expects a batch size of 1, but got batch size {input_batch.size(0)}."
+            )
```
```python
        probabilities = self.get_predictions(input_batch)
```
Copilot AI (Mar 5, 2026):

`torch.topk(probabilities[0], top_k)` will throw if `top_k` is greater than the number of classes. Consider clamping `top_k` to `probabilities.shape[1]`, or raising a `ValueError` with a helpful message when `top_k` is out of range.

Suggested change:

```diff
         probabilities = self.get_predictions(input_batch)
+        num_classes = probabilities.shape[1]
+        if top_k <= 0:
+            raise ValueError(f"top_k must be a positive integer, got {top_k}.")
+        if top_k > num_classes:
+            raise ValueError(
+                f"top_k ({top_k}) cannot be greater than the number of classes ({num_classes})."
+            )
```
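The clamping alternative mentioned in that comment might look like this (a sketch under the comment's assumptions, not the PR's code):

```python
import torch

probabilities = torch.softmax(torch.randn(1, 3), dim=1)  # 3 classes
top_k = 5
k = min(top_k, probabilities.shape[1])  # clamp instead of raising
top_probs, top_indices = torch.topk(probabilities[0], k)
assert top_probs.shape == (3,)  # silently clamped to num_classes
```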
```python
        top_probs, top_indices = torch.topk(probabilities[0], top_k)

        results = []
        for i in range(top_k):
            class_idx = int(top_indices[i].item())
            prob = float(top_probs[i].item())
            class_name = (
                self.class_names[class_idx]
                if class_idx < len(self.class_names)
                else f"class_{class_idx}"
            )
            results.append((class_idx, class_name, prob))
```
Copilot AI (Mar 5, 2026), on lines +26 to +35:

In `predict_image()`, calling `.item()` inside the loop can cause repeated device synchronization when tensors are on GPU. Consider converting `top_indices`/`top_probs` to CPU (or Python lists) once before the loop to avoid per-element sync overhead.
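A sketch of the single-conversion approach (variable names are assumed from the code above; this is illustrative, not the PR's code):

```python
import torch

class_names = ["cat", "dog"]
probabilities = torch.tensor([[0.1, 0.7, 0.2]])
top_probs, top_indices = torch.topk(probabilities[0], 3)

# One transfer to host via .tolist(), then plain Python iteration
# (no per-element device synchronization).
probs = top_probs.tolist()
indices = top_indices.tolist()
results = [
    (idx, class_names[idx] if idx < len(class_names) else f"class_{idx}", p)
    for idx, p in zip(indices, probs)
]
assert results[0][0] == 1 and results[0][1] == "dog"
assert results[1][1] == "class_2"  # index past class_names hits the fallback
```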
```python
        return results
```
Review comment (high), on lines +19 to +36:

The method `predict_image` is misleading and has a bug. The parameter name `input_batch` implies it can handle multiple images, but the implementation `probabilities[0]` only processes the first image in the batch, silently discarding the rest. This will lead to incorrect results when a batch with more than one image is provided.

To fix this, the method should be updated to correctly process the entire batch. I suggest renaming it to `predict_batch` to make its behavior clear, and modifying the implementation to handle all images in the input tensor. The return type should also be updated to a list of lists, where each inner list contains the predictions for one image.

```python
    def predict_batch(
        self, input_batch: torch.Tensor, top_k: int = 5
    ) -> list[list[tuple[int, str, float]]]:
        """Get top-k predictions for a batch of images."""
        probabilities = self.get_predictions(input_batch)
        top_probs, top_indices = torch.topk(probabilities, top_k, dim=1)

        batch_results = []
        for i in range(top_probs.shape[0]):
            image_results = []
            for j in range(top_k):
                class_idx = top_indices[i, j].item()
                prob = top_probs[i, j].item()
                class_name = (
                    self.class_names[class_idx]
                    if class_idx < len(self.class_names)
                    else f"class_{class_idx}"
                )
                image_results.append((class_idx, class_name, prob))
            batch_results.append(image_results)
        return batch_results
```
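To illustrate the shape of that proposal, here is a standalone version of the suggested batched loop operating directly on a probability tensor (hypothetical free function, extracted for the sketch only):

```python
import torch

def predict_batch(probabilities, class_names, top_k=2):
    # Standalone version of the suggested batched top-k loop.
    top_probs, top_indices = torch.topk(probabilities, top_k, dim=1)
    batch_results = []
    for i in range(top_probs.shape[0]):
        image_results = []
        for j in range(top_k):
            class_idx = top_indices[i, j].item()
            prob = top_probs[i, j].item()
            name = (
                class_names[class_idx]
                if class_idx < len(class_names)
                else f"class_{class_idx}"
            )
            image_results.append((class_idx, name, prob))
        batch_results.append(image_results)
    return batch_results

probs = torch.tensor([[0.6, 0.3, 0.1], [0.1, 0.2, 0.7]])
results = predict_batch(probs, ["cat", "dog", "bird"])
assert len(results) == 2                 # one result list per image
assert results[0][0][:2] == (0, "cat")   # top-1 for first image
assert results[1][0][:2] == (2, "bird")  # top-1 for second image
```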


```python
    def get_class_logit_batch(
        self, input_batch: torch.Tensor, target_class_idx: int
    ) -> torch.Tensor:
        """Get logits for a batch of images - optimized for batched inference (directly from model outputs)."""
        with torch.no_grad():
            outputs = self.model(input_batch)
            return outputs[:, target_class_idx]
```
Copilot AI (Mar 5, 2026), on lines +38 to +44:

`get_class_logit_batch()` has the same potential device mismatch as `get_predictions()` (input may not be on the model's device). Consider moving/validating `input_batch` to `self.device` here as well, so callers don't have to remember per-method requirements.
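For reference, the `outputs[:, target_class_idx]` indexing in that method selects one logit per image across the batch:

```python
import torch

# Two images, three classes: each row holds one image's logits.
outputs = torch.tensor([[1.0, 2.0, 3.0],
                        [4.0, 5.0, 6.0]])
class_1_logits = outputs[:, 1]  # column slice: one value per image
assert torch.equal(class_1_logits, torch.tensor([2.0, 5.0]))
```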