feat: add ModelPredictor class for model predictions #6
base: feat/data-pipeline
New file (hunk @@ -0,0 +1,6 @@):

```python
"""Model prediction utilities for CIAO."""

from ciao.model.predictor import ModelPredictor

__all__ = ["ModelPredictor"]
```
New file (hunk @@ -0,0 +1,44 @@), excerpt:

```python
import torch


class ModelPredictor:
    """Handles model predictions and class information."""

    def __init__(self, model: torch.nn.Module, class_names: list[str]) -> None:
        self.model = model
        self.class_names = class_names
        self.device = next(model.parameters()).device
```
Suggested change:

```diff
-        self.device = next(model.parameters()).device
+        params = list(model.parameters())
+        if params:
+            self.device = params[0].device
+        else:
+            buffers = list(model.buffers())
+            if buffers:
+                self.device = buffers[0].device
+            else:
+                self.device = torch.device("cpu")
```
Copilot AI · Mar 5, 2026:
For inference wrappers like this, it’s easy to get nondeterministic outputs if the model is left in training mode (dropout/BatchNorm). Consider switching the model to eval() inside the predictor (or clearly documenting that callers must set eval mode) before running predictions.
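A minimal sketch of the first option (calling `eval()` in the constructor). This is not code from the PR — the class body is a hypothetical reduction that mirrors the PR's `ModelPredictor` signature:

```python
import torch

# Hypothetical sketch: switch the wrapped model to eval mode inside
# __init__ so Dropout/BatchNorm layers behave deterministically during
# inference, instead of relying on the caller to remember .eval().
class ModelPredictor:
    def __init__(self, model: torch.nn.Module, class_names: list[str]) -> None:
        # nn.Module.eval() returns the module itself, so this both stores
        # the model and flips it out of training mode in one step.
        self.model = model.eval()
        self.class_names = class_names
        self.device = next(model.parameters()).device
```

Calling `eval()` only toggles layer behavior; it does not disable gradient tracking, so a `torch.no_grad()` (or `torch.inference_mode()`) context is still worthwhile around the forward pass.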
Copilot AI · Mar 5, 2026:
get_predictions() forwards input_batch as-is. If the model is on GPU and the caller passes a CPU tensor (or vice versa), this will raise a device mismatch error. Consider moving input_batch to self.device (or validating and raising a clear error) before calling the model.
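A sketch of the device-move option. The `get_predictions` body here is an assumption (the PR's implementation is not shown in full; the softmax is inferred from the review's mention of probabilities):

```python
import torch

# Hypothetical sketch: move the incoming batch onto the model's device
# before the forward pass, so a CPU tensor works against a GPU model and
# vice versa.
class ModelPredictor:
    def __init__(self, model: torch.nn.Module, class_names: list[str]) -> None:
        self.model = model.eval()
        self.class_names = class_names
        self.device = next(model.parameters()).device

    @torch.no_grad()
    def get_predictions(self, input_batch: torch.Tensor) -> torch.Tensor:
        # Tensor.to() is a no-op copy when the tensor is already on
        # self.device, so this is safe to do unconditionally.
        input_batch = input_batch.to(self.device)
        logits = self.model(input_batch)
        return torch.softmax(logits, dim=1)
```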
Copilot AI · Mar 5, 2026:
predict_image() assumes a batch with at least one element and always uses probabilities[0], silently ignoring the rest of the batch. Either enforce/validate batch_size == 1 (with a clear error) or change the API to accept a single image tensor (and add an unsqueeze(0) path for 3D inputs) to match the method name.
Suggested change:

```diff
-        """Get top-k predictions for an image."""
+        """Get top-k predictions for a single image.
+
+        Expects `input_batch` to contain exactly one image (batch size 1).
+        """
+        # Validate that the input is a non-empty batch with exactly one element.
+        if input_batch.dim() < 1:
+            raise ValueError(
+                f"predict_image expects a batched tensor with batch size 1; got tensor with shape {tuple(input_batch.shape)}"
+            )
+        if input_batch.size(0) != 1:
+            raise ValueError(
+                f"predict_image expects a batch size of 1, but got batch size {input_batch.size(0)}."
+            )
```
Copilot AI · Mar 5, 2026:
torch.topk(probabilities[0], top_k) will throw if top_k is greater than the number of classes. Consider clamping top_k to probabilities.shape[1] or raising a ValueError with a helpful message when top_k is out of range.
Suggested change:

```diff
         probabilities = self.get_predictions(input_batch)
+        num_classes = probabilities.shape[1]
+        if top_k <= 0:
+            raise ValueError(f"top_k must be a positive integer, got {top_k}.")
+        if top_k > num_classes:
+            raise ValueError(
+                f"top_k ({top_k}) cannot be greater than the number of classes ({num_classes})."
+            )
```
Copilot AI · Mar 5, 2026:
In predict_image(), calling .item() inside the loop can cause repeated device synchronization when tensors are on GPU. Consider converting top_indices/top_probs to CPU (or Python lists) once before the loop to avoid per-element sync overhead.
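A sketch of the single-transfer pattern. The helper below is hypothetical (the PR's loop lives inside `predict_image`); it shows the `.tolist()` approach, which performs one host transfer per tensor instead of one synchronization per `.item()` call:

```python
import torch

# Hypothetical helper: pull the top-k results to the CPU once with
# .tolist(), then iterate over plain Python values. On CUDA tensors each
# per-element .item() forces a device synchronization; .tolist() pays
# that cost exactly once per tensor.
def top_k_labels(
    probabilities: torch.Tensor, class_names: list[str], top_k: int = 5
) -> list[tuple[int, str, float]]:
    top_probs, top_indices = torch.topk(probabilities[0], top_k)
    probs = top_probs.tolist()      # single transfer, now Python floats
    indices = top_indices.tolist()  # single transfer, now Python ints
    return [
        (idx, class_names[idx] if idx < len(class_names) else f"class_{idx}", p)
        for idx, p in zip(indices, probs)
    ]
```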
The method predict_image is misleading and has a bug. The parameter name input_batch implies it can handle multiple images, but the implementation probabilities[0] only processes the first image in the batch, silently discarding the rest. This will lead to incorrect results when a batch with more than one image is provided.
To fix this, the method should be updated to correctly process the entire batch. I suggest renaming it to predict_batch to make its behavior clear, and modifying the implementation to handle all images in the input tensor. The return type should also be updated to a list of lists, where each inner list contains the predictions for one image.
```python
def predict_batch(
    self, input_batch: torch.Tensor, top_k: int = 5
) -> list[list[tuple[int, str, float]]]:
    """Get top-k predictions for a batch of images."""
    probabilities = self.get_predictions(input_batch)
    top_probs, top_indices = torch.topk(probabilities, top_k, dim=1)
    batch_results = []
    for i in range(top_probs.shape[0]):
        image_results = []
        for j in range(top_k):
            class_idx = top_indices[i, j].item()
            prob = top_probs[i, j].item()
            class_name = (
                self.class_names[class_idx]
                if class_idx < len(self.class_names)
                else f"class_{class_idx}"
            )
            image_results.append((class_idx, class_name, prob))
        batch_results.append(image_results)
    return batch_results
```
Copilot AI · Mar 5, 2026:
get_class_logit_batch() has the same potential device mismatch as get_predictions() (input may not be on the model’s device). Consider moving/validating input_batch to self.device here as well, so callers don’t have to remember per-method requirements.
It's a good practice to set the model to evaluation mode within the predictor's __init__. This ensures that layers like Dropout or BatchNorm behave correctly for inference, making the class more robust and preventing potential inconsistencies in predictions. The caller of this class should not have to remember to call .eval() on the model beforehand.