3 changes: 2 additions & 1 deletion .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
strict = True
ignore_missing_imports = True
disallow_untyped_calls = False
disable_error_code = no-any-return
disable_error_code = no-any-return
explicit_package_bases = True
136 changes: 116 additions & 20 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,29 +1,125 @@
# Python Project Template
# CIAO: Contextual Importance Assessment via Obfuscation

This project template serves as a robust foundation for Python projects, promoting best practices and streamlining development workflows. It comes pre-configured with essential tools and features to enhance the development experience.
An implementation of explainable AI techniques for image classification. CIAO identifies influential image regions by systematically segmenting images, obfuscating segments, and using search algorithms to find important regions (hyperpixels).

## Tools Included
## Overview

- [uv](https://docs.astral.sh/uv/) for efficient dependency management.
- [Ruff](https://docs.astral.sh/ruff) for comprehensive linting and code formatting.
- [Pytest](https://docs.pytest.org) for running tests and ensuring code reliability.
- [GitLab CI/CD](https://docs.gitlab.com/ee/ci) for continuous integration.
- [Pydocstyle](https://www.pydocstyle.org) for validating docstring styles, also following the [Google style](https://google.github.io/styleguide/pyguide.html#s3.8-comments-and-docstrings).
- [Mypy](https://mypy-lang.org) for static type checking.
CIAO explains which regions of an image contribute to a neural network's classification decision. The method:

1. Segments the image into small regions
2. Obfuscates each segment and measures impact on model predictions
3. Uses search algorithms to group adjacent important segments into hyperpixels
4. Generates explanations showing which regions influenced the prediction

## Usage
## Quick Start

Key commands for effective project management:
### Installation

- `uv sync` - Installs all project dependencies.
- `uv add <package>` - Adds a new dependency to the project.
- `uv run ruff check` - Runs linting.
- `uv run ruff format` - Runs formatting.
- `uv run mypy .` - Runs mypy.
- `uv run pytest tests` - Executes tests located in the tests directory.
- `uv run <command>` - Runs the specified command within the virtual environment.
```bash
# Clone the repository
git clone https://github.com/RationAI/ciao.git
cd ciao

## CI/CD
# Install dependencies using uv
uv sync
```

The project uses our [GitLab CI/CD templates](https://gitlab.ics.muni.cz/rationai/digital-pathology/templates/ci-templates) to automate the linting and testing processes. The pipeline is triggered on every merge request and push to the default branch.
### Basic Usage

Explain a single image with default settings:

```bash
uv run ciao
```

Customize the explanation using Hydra configuration overrides:

```bash
uv run ciao data.image_path=./my_image.jpg explanation.method=mcts explanation.segment_size=8
```

Alternatively, run as a module:

```bash
uv run python -m ciao
```

### Development Commands

- `uv sync` - Install all dependencies
- `uv add <package>` - Add a new dependency
- `uv run ruff check` - Run linting
- `uv run ruff format` - Format code
- `uv run mypy .` - Run type checking
- `uv run ciao` - Run CIAO with default configuration
- `uv run pytest tests` - Execute tests

## Method Details

### How CIAO Works

1. **Segmentation**: The input image is divided into small regions (segments) using hexagonal or square grids
2. **Score Calculation**: Each segment is obfuscated (replaced) and the model is queried to measure how much that segment affects the prediction, which yields an importance score for each segment
3. **Hyperpixel Search**: A search algorithm finds groups of adjacent segments with high importance scores, creating "hyperpixels" that represent influential image regions
4. **Explanation**: The top hyperpixels are visualized to show which regions most influenced the model's prediction
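Steps 1–2 above can be sketched in a few lines. This is a minimal illustration only, not the repo's actual API: `model`, `segments`, and `replacement` are placeholder inputs, and the real scoring logic lives in `ciao/evaluation/surrogate.py`.

```python
import torch


def score_segments(model, image, segments, replacement, target_class):
    """Score each segment by obfuscating it and measuring the prediction drop.

    Args:
        model: Callable mapping a [B, 3, H, W] batch to class logits.
        image: Preprocessed image tensor [3, H, W].
        segments: Iterable of boolean masks [H, W], one per segment.
        replacement: Replacement image tensor [3, H, W].
        target_class: Index of the class being explained.
    """
    base = torch.softmax(model(image.unsqueeze(0)), dim=1)[0, target_class]
    scores = []
    for seg_mask in segments:
        obfuscated = image.clone()
        # Overwrite the segment's pixels with the replacement image
        obfuscated[:, seg_mask] = replacement[:, seg_mask]
        prob = torch.softmax(model(obfuscated.unsqueeze(0)), dim=1)[0, target_class]
        scores.append((base - prob).item())  # larger drop => more important
    return scores
```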

### Search Algorithms

- **MCTS (Monte Carlo Tree Search)**: Tree-based search with UCB exploration
- **MC-RAVE**: MCTS with Rapid Action Value Estimation
- **MCGS (Monte Carlo Graph Search)**: Graph-based variant allowing revisiting of states
- **MCGS-RAVE**: MCGS with RAVE enhancements
- **Lookahead**: Greedy search with lookahead using efficient bitset operations
- **Potential**: Potential field-guided sequential search
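The UCB exploration used by the tree-based variants balances a node's mean value against how often it has been visited. A sketch of standard UCB1 follows; the exact constant and variant used in this repo may differ.

```python
import math


def ucb_score(child_value_sum, child_visits, parent_visits, c=1.4):
    """UCB1: mean value plus an exploration bonus that shrinks with visits."""
    if child_visits == 0:
        return float("inf")  # unvisited children are always explored first
    exploitation = child_value_sum / child_visits
    exploration = c * math.sqrt(math.log(parent_visits) / child_visits)
    return exploitation + exploration
```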

### Segmentation Methods

- **Hexagonal Grid**: Divides image into hexagonal cells for better spatial coverage
- **Square Grid**: Simple square grid segmentation
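For intuition, square-grid segmentation amounts to assigning every pixel a segment id based on its cell coordinates. This is an illustrative sketch only; the actual implementation lives in `ciao/data/segmentation.py`.

```python
import numpy as np


def square_grid_labels(height, width, segment_size):
    """Assign each pixel a segment id on a square grid of the given cell size."""
    rows = np.arange(height) // segment_size  # cell row per pixel row
    cols = np.arange(width) // segment_size   # cell column per pixel column
    n_cols = -(-width // segment_size)        # ceil division: cells per row
    return rows[:, None] * n_cols + cols[None, :]  # [H, W] label map
```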

### Replacement Methods

- **Mean Color**: Replace masked regions with the image's mean color (normalized)
- **Blur**: Gaussian blur applied to masked regions
- **Interlacing**: Interlaced pattern replacement
- **Solid Color**: Replace with a specified solid color (RGB)
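These strategies are implemented by `get_replacement_image` in `ciao/data/replacement.py`. As an example of the solid-color case, the 0-255 RGB tuple is scaled to [0, 1] and then ImageNet-normalized, roughly as follows (a standalone sketch mirroring that branch, not an import of the package):

```python
import torch

IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])


def solid_color_image(color, height, width):
    """Build an ImageNet-normalized [3, H, W] image filled with one RGB color."""
    rgb = torch.tensor(color, dtype=torch.float32) / 255.0  # 0-255 -> 0-1
    normalized = (rgb.view(3, 1, 1) - IMAGENET_MEAN.view(3, 1, 1)) / IMAGENET_STD.view(3, 1, 1)
    return normalized.expand(-1, height, width)
```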

## Proposed Project Structure

```
ciao/
├── ciao/                        # Main package
│   ├── algorithm/               # Search algorithms and data structures
│   │   ├── mcts.py              # Monte Carlo Tree Search
│   │   ├── mcgs.py              # Monte Carlo Graph Search
│   │   ├── lookahead_bitset.py  # Greedy lookahead with bitsets
│   │   ├── potential.py         # Potential-based search
│   │   ├── bitmask_graph.py     # Bitset operations for hyperpixels
│   │   ├── nodes.py             # Node classes for tree/graph search
│   │   └── search_helpers.py    # Shared MCTS/MCGS helper functions
│   ├── data/                    # Data loading and preprocessing
│   │   ├── loader.py            # Image loaders
│   │   ├── preprocessing.py     # Image preprocessing utilities
│   │   ├── replacement.py       # Replacement strategies for masking
│   │   └── segmentation.py      # Segmentation utilities (hex/square grids)
│   ├── evaluation/              # Scoring and evaluation
│   │   ├── surrogate.py         # Surrogate dataset creation and segment scoring
│   │   └── hyperpixel.py        # Hyperpixel evaluation and selection
│   ├── explainer/               # Core explainer implementation
│   │   └── ciao_explainer.py    # Main CIAO explainer class
│   ├── model/                   # Model inference and predictions
│   │   └── predictor.py         # ModelPredictor class for inference
│   ├── visualization/           # Visualization tools
│   │   ├── visualization.py     # Interactive visualizations
│   │   └── visualize_tree.py    # Tree/graph visualization utilities
│   └── __main__.py              # CLI entry point
├── configs/                     # Hydra configuration files
│   ├── ciao.yaml                # Main entry point
│   ├── base.yaml                # Base configuration
│   ├── data/                    # Data configurations
│   │   └── default.yaml
│   ├── explanation/             # Explanation method configs
│   │   └── ciao_default.yaml    # Default CIAO parameters
│   ├── hydra/                   # Hydra settings
│   └── logger/                  # Logger configurations
└── pyproject.toml               # Project metadata and dependencies
```
20 changes: 20 additions & 0 deletions ciao/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Data loading utilities for CIAO."""

from ciao.data.loader import get_image_loader
from ciao.data.preprocessing import load_and_preprocess_image
from ciao.data.replacement import (
calculate_image_mean_color,
get_replacement_image,
plot_image_mean_color,
)
from ciao.data.segmentation import create_segmentation


__all__ = [
"calculate_image_mean_color",
"create_segmentation",
"get_image_loader",
"get_replacement_image",
"load_and_preprocess_image",
"plot_image_mean_color",
]
51 changes: 51 additions & 0 deletions ciao/data/loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Simple image path loading utilities."""

from collections.abc import Iterator
from pathlib import Path

from omegaconf import DictConfig


# Supported image formats
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".webp")


def get_image_loader(config: DictConfig) -> Iterator[Path]:
"""Create image loader based on configuration.

Args:
config: Hydra config object

Returns:
Iterator of Path objects

Raises:
ValueError: If neither image_path nor batch_path is specified
FileNotFoundError: If single image_path does not exist
"""
if config.data.get("image_path"):
# Single image mode - validate file exists
image_path = Path(config.data.image_path)
if not image_path.is_file():
raise FileNotFoundError(
f"image_path must be a valid file, got: {image_path}. "
"Check for typos or incorrect path configuration."
)
yield image_path

elif config.data.get("batch_path"):
# Directory mode - find all images with supported extensions
directory = Path(config.data.batch_path)
if not directory.is_dir():
raise ValueError(
f"batch_path must be a valid directory, got: {directory}. "
"Check for typos or incorrect path configuration."
)

# Single rglob pass with suffix filtering
for path in directory.rglob("*"):
if path.suffix.lower() in IMAGE_EXTENSIONS:
yield path

else:
raise ValueError("Must specify either image_path or batch_path in config")
41 changes: 41 additions & 0 deletions ciao/data/preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Image loading and preprocessing for model input."""

from pathlib import Path
from typing import cast

import torch
import torchvision.transforms as transforms
from PIL import Image


# ImageNet preprocessing transforms
preprocess = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)


def load_and_preprocess_image(
image_path: str | Path, device: torch.device | None = None
) -> torch.Tensor:
"""Load and preprocess an image for the model.

Args:
image_path: Path to image file
device: Device to place tensor on (defaults to cuda if available, else cpu)

Returns:
Preprocessed image tensor [3, 224, 224] on specified device
"""
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Use context manager to prevent file descriptor leaks
with Image.open(image_path) as img:
image = img.convert("RGB")
tensor = cast(torch.Tensor, preprocess(image)) # (3, 224, 224)
input_tensor = tensor.to(device)

return input_tensor
119 changes: 119 additions & 0 deletions ciao/data/replacement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""Image replacement strategies for masking operations."""

import torch
import torchvision.transforms.functional as TF
from matplotlib import pyplot as plt


# ImageNet normalization constants
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])


def calculate_image_mean_color(input_tensor: torch.Tensor) -> torch.Tensor:
"""Calculate image mean color using ImageNet normalization constants.

Args:
input_tensor: Input tensor [3, H, W] or [1, 3, H, W] (ImageNet normalized)

Returns:
Mean color tensor [3, 1, 1] (ImageNet normalized)
"""
device = input_tensor.device

# Add batch dimension if needed
if input_tensor.dim() == 3:
input_tensor = input_tensor.unsqueeze(0)

# Move normalization constants to same device
imagenet_mean = IMAGENET_MEAN.view(1, 3, 1, 1).to(device)
imagenet_std = IMAGENET_STD.view(1, 3, 1, 1).to(device)

# Unnormalize, calculate mean, then re-normalize
unnormalized = (input_tensor * imagenet_std) + imagenet_mean
mean_color = unnormalized.mean(dim=(2, 3), keepdim=True)
normalized_mean = (mean_color - imagenet_mean) / imagenet_std

return normalized_mean.squeeze(0) # Remove batch dimension


def get_replacement_image(
input_tensor: torch.Tensor,
replacement: str = "mean_color",
color: tuple[int, int, int] = (0, 0, 0),
) -> torch.Tensor:
"""Generate replacement image for masking operations.

Args:
input_tensor: Input tensor [3, H, W] (ImageNet normalized)
replacement: Strategy - "mean_color", "interlacing", "blur", or "solid_color"
color: For solid_color mode, RGB tuple (0-255). Defaults to black (0, 0, 0)

Returns:
replacement_image: torch tensor [3, H, W] on same device
"""
device = input_tensor.device

# Extract spatial dimensions from input tensor
_, height, width = input_tensor.shape

if replacement == "mean_color":
# Fill entire image with mean color
mean_color = calculate_image_mean_color(input_tensor) # [3, 1, 1]
replacement_image = mean_color.expand(-1, height, width) # [3, H, W]

elif replacement == "interlacing":
# Create interlaced pattern: even columns flipped vertically, then even rows flipped horizontally
replacement_image = input_tensor.clone()
even_row_indices = torch.arange(0, height, 2) # Even row indices
even_col_indices = torch.arange(0, width, 2) # Even column indices

# Step 1: Flip even columns vertically (upside down)
replacement_image[:, :, even_col_indices] = torch.flip(
replacement_image[:, :, even_col_indices], dims=[1]
)

# Step 2: Flip even rows horizontally (left-right)
replacement_image[:, even_row_indices, :] = torch.flip(
replacement_image[:, even_row_indices, :], dims=[2]
)

elif replacement == "blur":
# Apply Gaussian blur using torchvision functional API
input_batch = input_tensor.unsqueeze(0) # [1, 3, H, W]
replacement_image = TF.gaussian_blur(
input_batch, kernel_size=[7, 7], sigma=[1.5, 1.5]
).squeeze(0) # [3, H, W]

elif replacement == "solid_color":
# Fill with specified solid color (expects RGB values in 0-255 range)
# Convert color to torch tensor
color_tensor = torch.tensor(color, dtype=torch.float32, device=device)

# Convert from 0-255 range to 0-1 range
color_tensor = color_tensor / 255.0

# Apply ImageNet normalization
mean = IMAGENET_MEAN.view(3, 1, 1).to(device)
std = IMAGENET_STD.view(3, 1, 1).to(device)
normalized_color = (color_tensor.view(3, 1, 1) - mean) / std
replacement_image = normalized_color.expand(-1, height, width) # [3, H, W]

else:
raise ValueError(f"Unknown replacement strategy: {replacement}")

return replacement_image


def plot_image_mean_color(input_tensor: torch.Tensor) -> None:
"""Display the mean color of the image.

Args:
input_tensor: Input tensor [3, H, W] (ImageNet normalized)

Note:
The visualization shows the normalized tensor (ImageNet normalization).
"""
normalized_mean = calculate_image_mean_color(input_tensor).unsqueeze(0)
plt.imshow(normalized_mean[0].permute(1, 2, 0))
plt.show()