feat: add ModelPredictor class for model predictions#6
feat: add ModelPredictor class for model predictions#6dhalmazna wants to merge 1 commit intofeat/data-pipelinefrom
Conversation
|
Important Review skippedAuto reviews are disabled on base/target branches other than the default branch. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: defaults Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request establishes a new Highlights
Changelog
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Pull request overview
This PR introduces a new ciao.model module that provides a ModelPredictor wrapper around a PyTorch model to standardize inference, probability extraction, and top‑k predictions.
Changes:
- Added
ModelPredictorwith helper methods for predictions, top‑k class selection, and extracting per-class logits. - Added
ciao.modelpackage init exportingModelPredictor.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
ciao/model/predictor.py |
Implements the new ModelPredictor inference wrapper. |
ciao/model/__init__.py |
Exposes ModelPredictor as the public API of ciao.model. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
You can also share your feedback on Copilot code review. Take the survey.
| def get_class_logit_batch( | ||
| self, input_batch: torch.Tensor, target_class_idx: int | ||
| ) -> torch.Tensor: | ||
| """Get logits for a batch of images - optimized for batched inference (directly from model outputs).""" | ||
| with torch.no_grad(): | ||
| outputs = self.model(input_batch) | ||
| return outputs[:, target_class_idx] |
There was a problem hiding this comment.
get_class_logit_batch() has the same potential device mismatch as get_predictions() (input may not be on the model’s device). Consider moving/validating input_batch to self.device here as well, so callers don’t have to remember per-method requirements.
| def get_predictions(self, input_batch: torch.Tensor) -> torch.Tensor: | ||
| """Get model predictions (returns probabilities).""" | ||
| with torch.no_grad(): | ||
| outputs = self.model(input_batch) | ||
| probabilities = torch.nn.functional.softmax(outputs, dim=1) |
There was a problem hiding this comment.
For inference wrappers like this, it’s easy to get nondeterministic outputs if the model is left in training mode (dropout/BatchNorm). Consider switching the model to eval() inside the predictor (or clearly documenting that callers must set eval mode) before running predictions.
| results = [] | ||
| for i in range(top_k): | ||
| class_idx = int(top_indices[i].item()) | ||
| prob = float(top_probs[i].item()) | ||
| class_name = ( | ||
| self.class_names[class_idx] | ||
| if class_idx < len(self.class_names) | ||
| else f"class_{class_idx}" | ||
| ) | ||
| results.append((class_idx, class_name, prob)) |
There was a problem hiding this comment.
In predict_image(), calling .item() inside the loop can cause repeated device synchronization when tensors are on GPU. Consider converting top_indices/top_probs to CPU (or Python lists) once before the loop to avoid per-element sync overhead.
| def __init__(self, model: torch.nn.Module, class_names: list[str]) -> None: | ||
| self.model = model | ||
| self.class_names = class_names | ||
| self.device = next(model.parameters()).device |
There was a problem hiding this comment.
next(model.parameters()).device will raise StopIteration for parameterless modules (e.g., wrappers or functional models). Consider handling the empty-parameter case explicitly (fallback device, or derive device from buffers), or avoid storing self.device if it isn’t reliable/used.
| self.device = next(model.parameters()).device | |
| params = list(model.parameters()) | |
| if params: | |
| self.device = params[0].device | |
| else: | |
| buffers = list(model.buffers()) | |
| if buffers: | |
| self.device = buffers[0].device | |
| else: | |
| self.device = torch.device("cpu") |
| def get_predictions(self, input_batch: torch.Tensor) -> torch.Tensor: | ||
| """Get model predictions (returns probabilities).""" | ||
| with torch.no_grad(): | ||
| outputs = self.model(input_batch) | ||
| probabilities = torch.nn.functional.softmax(outputs, dim=1) | ||
| return probabilities |
There was a problem hiding this comment.
get_predictions() forwards input_batch as-is. If the model is on GPU and the caller passes a CPU tensor (or vice versa), this will raise a device mismatch error. Consider moving input_batch to self.device (or validating and raising a clear error) before calling the model.
| def predict_image( | ||
| self, input_batch: torch.Tensor, top_k: int = 5 | ||
| ) -> list[tuple[int, str, float]]: | ||
| """Get top-k predictions for an image.""" |
There was a problem hiding this comment.
predict_image() assumes a batch with at least one element and always uses probabilities[0], silently ignoring the rest of the batch. Either enforce/validate batch_size == 1 (with a clear error) or change the API to accept a single image tensor (and add an unsqueeze(0) path for 3D inputs) to match the method name.
| """Get top-k predictions for an image.""" | |
| """Get top-k predictions for a single image. | |
| Expects `input_batch` to contain exactly one image (batch size 1). | |
| """ | |
| # Validate that the input is a non-empty batch with exactly one element. | |
| if input_batch.dim() < 1: | |
| raise ValueError( | |
| f"predict_image expects a batched tensor with batch size 1; got tensor with shape {tuple(input_batch.shape)}" | |
| ) | |
| if input_batch.size(0) != 1: | |
| raise ValueError( | |
| f"predict_image expects a batch size of 1, but got batch size {input_batch.size(0)}." | |
| ) |
| self, input_batch: torch.Tensor, top_k: int = 5 | ||
| ) -> list[tuple[int, str, float]]: | ||
| """Get top-k predictions for an image.""" | ||
| probabilities = self.get_predictions(input_batch) |
There was a problem hiding this comment.
torch.topk(probabilities[0], top_k) will throw if top_k is greater than the number of classes. Consider clamping top_k to probabilities.shape[1] or raising a ValueError with a helpful message when top_k is out of range.
| probabilities = self.get_predictions(input_batch) | |
| probabilities = self.get_predictions(input_batch) | |
| num_classes = probabilities.shape[1] | |
| if top_k <= 0: | |
| raise ValueError(f"top_k must be a positive integer, got {top_k}.") | |
| if top_k > num_classes: | |
| raise ValueError( | |
| f"top_k ({top_k}) cannot be greater than the number of classes ({num_classes})." | |
| ) |
There was a problem hiding this comment.
Code Review
This pull request introduces a ModelPredictor class, which is a good addition for wrapping model inference logic. My review focuses on improving its robustness and correctness. I've identified a high-severity bug in the predict_image method where it incorrectly handles batch inputs, and I've provided a fix to process batches correctly. Additionally, I've suggested a medium-severity improvement to ensure the model is always in evaluation mode during prediction, which is a crucial best practice.
| def predict_image( | ||
| self, input_batch: torch.Tensor, top_k: int = 5 | ||
| ) -> list[tuple[int, str, float]]: | ||
| """Get top-k predictions for an image.""" | ||
| probabilities = self.get_predictions(input_batch) | ||
| top_probs, top_indices = torch.topk(probabilities[0], top_k) | ||
|
|
||
| results = [] | ||
| for i in range(top_k): | ||
| class_idx = int(top_indices[i].item()) | ||
| prob = float(top_probs[i].item()) | ||
| class_name = ( | ||
| self.class_names[class_idx] | ||
| if class_idx < len(self.class_names) | ||
| else f"class_{class_idx}" | ||
| ) | ||
| results.append((class_idx, class_name, prob)) | ||
| return results |
There was a problem hiding this comment.
The method predict_image is misleading and has a bug. The parameter name input_batch implies it can handle multiple images, but the implementation probabilities[0] only processes the first image in the batch, silently discarding the rest. This will lead to incorrect results when a batch with more than one image is provided.
To fix this, the method should be updated to correctly process the entire batch. I suggest renaming it to predict_batch to make its behavior clear, and modifying the implementation to handle all images in the input tensor. The return type should also be updated to a list of lists, where each inner list contains the predictions for one image.
def predict_batch(
self, input_batch: torch.Tensor, top_k: int = 5
) -> list[list[tuple[int, str, float]]]:
"""Get top-k predictions for a batch of images."""
probabilities = self.get_predictions(input_batch)
top_probs, top_indices = torch.topk(probabilities, top_k, dim=1)
batch_results = []
for i in range(top_probs.shape[0]):
image_results = []
for j in range(top_k):
class_idx = top_indices[i, j].item()
prob = top_probs[i, j].item()
class_name = (
self.class_names[class_idx]
if class_idx < len(self.class_names)
else f"class_{class_idx}"
)
image_results.append((class_idx, class_name, prob))
batch_results.append(image_results)
return batch_results| """Handles model predictions and class information.""" | ||
|
|
||
| def __init__(self, model: torch.nn.Module, class_names: list[str]) -> None: | ||
| self.model = model |
There was a problem hiding this comment.
It's a good practice to set the model to evaluation mode within the predictor's __init__. This ensures that layers like Dropout or BatchNorm behave correctly for inference, making the class more robust and preventing potential inconsistencies in predictions. The caller of this class should not have to remember to call .eval() on the model beforehand.
| self.model = model | |
| self.model = model.eval() |
Context:
This PR introduces the
model/module.What was changed:
model/predictor.py: Added theModelPredictorclass. It serves as a clean wrapper around the PyTorch model, providing standardized methods for inference, probability extraction, and top-k predictions.Related Task:
XAI-29