feat: add ModelPredictor class for model predictions #6

Open
dhalmazna wants to merge 1 commit into feat/data-pipeline from feat/model-predictor

Conversation

@dhalmazna (Collaborator)

Context:
This PR introduces the model/ module.

What was changed:

  • model/predictor.py: Added the ModelPredictor class. It serves as a clean wrapper around the PyTorch model, providing standardized methods for inference, probability extraction, and top-k predictions.

Related Task:
XAI-29

@dhalmazna dhalmazna self-assigned this Mar 5, 2026
Copilot AI review requested due to automatic review settings March 5, 2026 10:23
@coderabbitai

coderabbitai bot commented Mar 5, 2026

Important

Review skipped

Auto reviews are disabled on base/target branches other than the default branch.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.


@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request establishes a new model/ module within the ciao package, aiming to centralize and standardize model prediction functionalities. The core of this change is the introduction of the ModelPredictor class, which encapsulates PyTorch model inference, providing a consistent interface for obtaining predictions, probabilities, and top-k results. This enhances code organization and simplifies interaction with trained models.

Highlights

  • New model/ module introduced: The pull request introduces a new model/ module within the ciao package, designed to house model-related utilities and prediction logic.
  • ModelPredictor class added: A new ModelPredictor class has been added to ciao/model/predictor.py. This class provides a standardized wrapper for PyTorch models, offering methods for inference, probability extraction, and top-k predictions.
Changelog
  • ciao/model/__init__.py
    • Created the __init__.py file to define the ciao.model package.
    • Exported the ModelPredictor class, allowing it to be imported directly from ciao.model.
  • ciao/model/predictor.py
    • Created the predictor.py file to house the ModelPredictor class.
    • Implemented the ModelPredictor class with an initializer that takes a PyTorch model and class names.
    • Added the get_predictions method to return class probabilities from model outputs.
    • Included the predict_image method to retrieve top-k predictions (class index, name, and probability) for a given input.
    • Provided the get_class_logit_batch method to extract logits for a specific target class from a batch of inputs.
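
Pieced together from the changelog above, the class described might look roughly like the following sketch. This is an illustration of the summarized API, not the PR's actual code; details such as device handling and the exact docstrings are assumptions.

```python
import torch


class ModelPredictor:
    """Sketch of the predictor summarized above (method names from the changelog; bodies assumed)."""

    def __init__(self, model: torch.nn.Module, class_names: list[str]) -> None:
        self.model = model
        self.class_names = class_names

    def get_predictions(self, input_batch: torch.Tensor) -> torch.Tensor:
        """Return per-class probabilities for a batch of inputs."""
        with torch.no_grad():
            outputs = self.model(input_batch)
        return torch.nn.functional.softmax(outputs, dim=1)

    def predict_image(
        self, input_batch: torch.Tensor, top_k: int = 5
    ) -> list[tuple[int, str, float]]:
        """Return top-k (class index, class name, probability) for the first image."""
        probabilities = self.get_predictions(input_batch)
        top_probs, top_indices = torch.topk(probabilities[0], top_k)
        results = []
        for p, i in zip(top_probs, top_indices):
            idx = int(i)
            name = self.class_names[idx] if idx < len(self.class_names) else f"class_{idx}"
            results.append((idx, name, float(p)))
        return results

    def get_class_logit_batch(
        self, input_batch: torch.Tensor, target_class_idx: int
    ) -> torch.Tensor:
        """Return the raw logit of one target class for every item in the batch."""
        with torch.no_grad():
            outputs = self.model(input_batch)
        return outputs[:, target_class_idx]
```

The review comments below point out several sharp edges in this shape of API (device handling, eval mode, batch semantics).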


Copilot AI left a comment


Pull request overview

This PR introduces a new ciao.model module that provides a ModelPredictor wrapper around a PyTorch model to standardize inference, probability extraction, and top‑k predictions.

Changes:

  • Added ModelPredictor with helper methods for predictions, top‑k class selection, and extracting per-class logits.
  • Added ciao.model package init exporting ModelPredictor.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 7 comments.

File Description
ciao/model/predictor.py: Implements the new ModelPredictor inference wrapper.
ciao/model/__init__.py: Exposes ModelPredictor as the public API of ciao.model.


Comment on lines +38 to +44
def get_class_logit_batch(
    self, input_batch: torch.Tensor, target_class_idx: int
) -> torch.Tensor:
    """Get logits for a batch of images - optimized for batched inference (directly from model outputs)."""
    with torch.no_grad():
        outputs = self.model(input_batch)
        return outputs[:, target_class_idx]

Copilot AI Mar 5, 2026


get_class_logit_batch() has the same potential device mismatch as get_predictions() (input may not be on the model’s device). Consider moving/validating input_batch to self.device here as well, so callers don’t have to remember per-method requirements.
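
A minimal sketch of that suggestion, assuming the predictor stores the model's device in `self.device` as the `__init__` shown later in this review does (illustrative wrapper, not the PR's actual code):

```python
import torch


class DeviceAwarePredictor:
    """Illustrative wrapper showing the reviewer's device-handling suggestion."""

    def __init__(self, model: torch.nn.Module) -> None:
        self.model = model
        self.device = next(model.parameters()).device

    def get_class_logit_batch(
        self, input_batch: torch.Tensor, target_class_idx: int
    ) -> torch.Tensor:
        # .to() is a no-op when the tensor is already on the right device,
        # so callers no longer need to remember per-method requirements.
        input_batch = input_batch.to(self.device)
        with torch.no_grad():
            outputs = self.model(input_batch)
            return outputs[:, target_class_idx]
```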

Comment on lines +12 to +16
def get_predictions(self, input_batch: torch.Tensor) -> torch.Tensor:
    """Get model predictions (returns probabilities)."""
    with torch.no_grad():
        outputs = self.model(input_batch)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)

Copilot AI Mar 5, 2026


For inference wrappers like this, it’s easy to get nondeterministic outputs if the model is left in training mode (dropout/BatchNorm). Consider switching the model to eval() inside the predictor (or clearly documenting that callers must set eval mode) before running predictions.
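
The suggestion amounts to a one-line change; a sketch, assuming the wrapper is free to flip the model's mode (`make_inference_model` is a hypothetical helper, not from the PR):

```python
import torch


def make_inference_model(model: torch.nn.Module) -> torch.nn.Module:
    """Put a model into eval mode so Dropout/BatchNorm act deterministically."""
    # .eval() flips module.training to False (recursively) and returns the module itself.
    return model.eval()
```

In eval mode, Dropout becomes the identity and BatchNorm uses its running statistics, so repeated predictions on the same input agree.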

Comment on lines +26 to +35
results = []
for i in range(top_k):
    class_idx = int(top_indices[i].item())
    prob = float(top_probs[i].item())
    class_name = (
        self.class_names[class_idx]
        if class_idx < len(self.class_names)
        else f"class_{class_idx}"
    )
    results.append((class_idx, class_name, prob))

Copilot AI Mar 5, 2026


In predict_image(), calling .item() inside the loop can cause repeated device synchronization when tensors are on GPU. Consider converting top_indices/top_probs to CPU (or Python lists) once before the loop to avoid per-element sync overhead.
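
One way to batch the synchronization, sketched as a hypothetical helper (`topk_to_python` is not in the PR):

```python
import torch


def topk_to_python(
    top_probs: torch.Tensor, top_indices: torch.Tensor
) -> list[tuple[int, float]]:
    """Convert top-k tensors to Python scalars with one transfer per tensor,
    instead of one device sync per .item() call inside a loop."""
    # .tolist() moves the whole tensor to the host in a single step.
    return list(zip(top_indices.tolist(), top_probs.tolist()))
```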

def __init__(self, model: torch.nn.Module, class_names: list[str]) -> None:
    self.model = model
    self.class_names = class_names
    self.device = next(model.parameters()).device

Copilot AI Mar 5, 2026


next(model.parameters()).device will raise StopIteration for parameterless modules (e.g., wrappers or functional models). Consider handling the empty-parameter case explicitly (fallback device, or derive device from buffers), or avoid storing self.device if it isn’t reliable/used.

Suggested change
self.device = next(model.parameters()).device
params = list(model.parameters())
if params:
    self.device = params[0].device
else:
    buffers = list(model.buffers())
    if buffers:
        self.device = buffers[0].device
    else:
        self.device = torch.device("cpu")

Comment on lines +12 to +17
def get_predictions(self, input_batch: torch.Tensor) -> torch.Tensor:
    """Get model predictions (returns probabilities)."""
    with torch.no_grad():
        outputs = self.model(input_batch)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
    return probabilities

Copilot AI Mar 5, 2026


get_predictions() forwards input_batch as-is. If the model is on GPU and the caller passes a CPU tensor (or vice versa), this will raise a device mismatch error. Consider moving input_batch to self.device (or validating and raising a clear error) before calling the model.
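
The "validate and raise" alternative might look like this hypothetical helper (`prepare_batch` is not part of the PR):

```python
import torch


def prepare_batch(input_batch: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Move a batch to the model's device, failing early with a clear message."""
    if not isinstance(input_batch, torch.Tensor):
        raise TypeError(f"expected torch.Tensor, got {type(input_batch).__name__}")
    # .to() copies only when devices differ; otherwise it returns the same tensor.
    return input_batch.to(device)
```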

def predict_image(
    self, input_batch: torch.Tensor, top_k: int = 5
) -> list[tuple[int, str, float]]:
    """Get top-k predictions for an image."""

Copilot AI Mar 5, 2026


predict_image() assumes a batch with at least one element and always uses probabilities[0], silently ignoring the rest of the batch. Either enforce/validate batch_size == 1 (with a clear error) or change the API to accept a single image tensor (and add an unsqueeze(0) path for 3D inputs) to match the method name.

Suggested change
"""Get top-k predictions for an image."""
"""Get top-k predictions for a single image.

Expects `input_batch` to contain exactly one image (batch size 1).
"""
# Validate that the input is a non-empty batch with exactly one element.
if input_batch.dim() < 1:
    raise ValueError(
        f"predict_image expects a batched tensor with batch size 1; got tensor with shape {tuple(input_batch.shape)}"
    )
if input_batch.size(0) != 1:
    raise ValueError(
        f"predict_image expects a batch size of 1, but got batch size {input_batch.size(0)}."
    )

    self, input_batch: torch.Tensor, top_k: int = 5
) -> list[tuple[int, str, float]]:
    """Get top-k predictions for an image."""
    probabilities = self.get_predictions(input_batch)

Copilot AI Mar 5, 2026


torch.topk(probabilities[0], top_k) will throw if top_k is greater than the number of classes. Consider clamping top_k to probabilities.shape[1] or raising a ValueError with a helpful message when top_k is out of range.

Suggested change
probabilities = self.get_predictions(input_batch)
probabilities = self.get_predictions(input_batch)
num_classes = probabilities.shape[1]
if top_k <= 0:
    raise ValueError(f"top_k must be a positive integer, got {top_k}.")
if top_k > num_classes:
    raise ValueError(
        f"top_k ({top_k}) cannot be greater than the number of classes ({num_classes})."
    )


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a ModelPredictor class, which is a good addition for wrapping model inference logic. My review focuses on improving its robustness and correctness. I've identified a high-severity bug in the predict_image method where it incorrectly handles batch inputs, and I've provided a fix to process batches correctly. Additionally, I've suggested a medium-severity improvement to ensure the model is always in evaluation mode during prediction, which is a crucial best practice.

Comment on lines +19 to +36
def predict_image(
    self, input_batch: torch.Tensor, top_k: int = 5
) -> list[tuple[int, str, float]]:
    """Get top-k predictions for an image."""
    probabilities = self.get_predictions(input_batch)
    top_probs, top_indices = torch.topk(probabilities[0], top_k)

    results = []
    for i in range(top_k):
        class_idx = int(top_indices[i].item())
        prob = float(top_probs[i].item())
        class_name = (
            self.class_names[class_idx]
            if class_idx < len(self.class_names)
            else f"class_{class_idx}"
        )
        results.append((class_idx, class_name, prob))
    return results


Severity: high

The method predict_image is misleading and has a bug. The parameter name input_batch implies it can handle multiple images, but the implementation probabilities[0] only processes the first image in the batch, silently discarding the rest. This will lead to incorrect results when a batch with more than one image is provided.

To fix this, the method should be updated to correctly process the entire batch. I suggest renaming it to predict_batch to make its behavior clear, and modifying the implementation to handle all images in the input tensor. The return type should also be updated to a list of lists, where each inner list contains the predictions for one image.

    def predict_batch(
        self, input_batch: torch.Tensor, top_k: int = 5
    ) -> list[list[tuple[int, str, float]]]:
        """Get top-k predictions for a batch of images."""
        probabilities = self.get_predictions(input_batch)
        top_probs, top_indices = torch.topk(probabilities, top_k, dim=1)

        batch_results = []
        for i in range(top_probs.shape[0]):
            image_results = []
            for j in range(top_k):
                class_idx = top_indices[i, j].item()
                prob = top_probs[i, j].item()
                class_name = (
                    self.class_names[class_idx]
                    if class_idx < len(self.class_names)
                    else f"class_{class_idx}"
                )
                image_results.append((class_idx, class_name, prob))
            batch_results.append(image_results)
        return batch_results

"""Handles model predictions and class information."""

def __init__(self, model: torch.nn.Module, class_names: list[str]) -> None:
    self.model = model


Severity: medium

It's a good practice to set the model to evaluation mode within the predictor's __init__. This ensures that layers like Dropout or BatchNorm behave correctly for inference, making the class more robust and preventing potential inconsistencies in predictions. The caller of this class should not have to remember to call .eval() on the model beforehand.

Suggested change
self.model = model
self.model = model.eval()
