Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions DENOISING_DIFFUSION/src/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .metrics import mse, psnr, ssim, compare_denoisers

__all__ = ["mse", "psnr", "ssim", "compare_denoisers"]
105 changes: 105 additions & 0 deletions DENOISING_DIFFUSION/src/utils/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Image quality metrics and baseline denoiser comparison utility."""

from __future__ import annotations

from typing import Dict

import numpy as np
from skimage.metrics import structural_similarity


def mse(pred: np.ndarray, target: np.ndarray) -> float:
"""Mean Squared Error between pred and target."""
return float(np.mean((pred.astype(np.float64) - target.astype(np.float64)) ** 2))


def psnr(pred: np.ndarray, target: np.ndarray, data_range: float = 1.0) -> float:
"""
Peak Signal-to-Noise Ratio in dB.

Args:
pred: Predicted image, values in [0, data_range]
target: Ground-truth image, values in [0, data_range]
data_range: Value range of the images (default 1.0)

Returns:
PSNR in dB, or inf if pred == target exactly.
"""
err = mse(pred, target)
if err == 0.0:
return float("inf")
return float(10.0 * np.log10((data_range ** 2) / err))


def ssim(pred: np.ndarray, target: np.ndarray, data_range: float = 1.0) -> float:
"""
Structural Similarity Index (SSIM).

Args:
pred: Predicted image
target: Ground-truth image
data_range: Value range of the images (default 1.0)

Returns:
SSIM score in [-1, 1].
"""
p = pred.astype(np.float64)
t = target.astype(np.float64)

# skimage expects (H, W) or (H, W, C)
if p.ndim == 4: # (B, C, H, W) -> average over batch
scores = [
structural_similarity(
p[i].transpose(1, 2, 0),
t[i].transpose(1, 2, 0),
data_range=data_range,
channel_axis=-1,
)
for i in range(p.shape[0])
]
return float(np.mean(scores))

if p.ndim == 3: # (C, H, W) -> (H, W, C)
p = p.transpose(1, 2, 0)
t = t.transpose(1, 2, 0)
return float(structural_similarity(p, t, data_range=data_range, channel_axis=-1))

# (H, W)
return float(structural_similarity(p, t, data_range=data_range))


def compare_denoisers(
noisy: np.ndarray,
target: np.ndarray,
outputs: Dict[str, np.ndarray],
data_range: float = 1.0,
) -> Dict[str, Dict[str, float]]:
"""
Compare multiple denoised outputs against a clean target.

Args:
noisy: Noisy input image (used as the baseline entry "noisy")
target: Clean ground-truth image
outputs: Mapping of denoiser name -> denoised image
data_range: Value range of the images (default 1.0)

Returns:
Dict mapping each name (plus "noisy" baseline) to
{"psnr": float, "ssim": float, "mse": float}.

Example:
>>> results = compare_denoisers(noisy, clean, {
... "gaussian": gaussian_filtered,
... "ddpm": ddpm_output,
... })
>>> print(results["ddpm"]["psnr"])
"""
all_outputs = {"noisy": noisy, **outputs}
return {
name: {
"psnr": psnr(img, target, data_range=data_range),
"ssim": ssim(img, target, data_range=data_range),
"mse": mse(img, target),
}
for name, img in all_outputs.items()
}
188 changes: 188 additions & 0 deletions DENOISING_DIFFUSION/tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
"""Tests for metrics: mse, psnr, ssim, compare_denoisers."""

import math

import numpy as np
import pytest

from src.utils.metrics import mse, psnr, ssim, compare_denoisers


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

RNG = np.random.default_rng(42)

def clean_hw():
return RNG.random((64, 64), dtype=np.float32)

def noisy_hw(img, sigma=0.1):
return np.clip(img + RNG.normal(0, sigma, img.shape).astype(np.float32), 0, 1)


# ---------------------------------------------------------------------------
# MSE
# ---------------------------------------------------------------------------

def test_mse_identical_images_is_zero():
img = clean_hw()
assert mse(img, img) == 0.0


def test_mse_positive_for_different_images():
img = clean_hw()
assert mse(noisy_hw(img), img) > 0.0


def test_mse_symmetric():
a = clean_hw()
b = noisy_hw(a)
assert math.isclose(mse(a, b), mse(b, a), rel_tol=1e-6)


def test_mse_returns_float():
img = clean_hw()
assert isinstance(mse(img, img), float)


def test_mse_known_value():
a = np.zeros((4, 4), dtype=np.float32)
b = np.ones((4, 4), dtype=np.float32)
assert math.isclose(mse(a, b), 1.0)


# ---------------------------------------------------------------------------
# PSNR
# ---------------------------------------------------------------------------

def test_psnr_identical_images_is_inf():
img = clean_hw()
assert math.isinf(psnr(img, img))


def test_psnr_noisy_is_finite_positive():
img = clean_hw()
val = psnr(noisy_hw(img), img)
assert math.isfinite(val) and val > 0


def test_psnr_higher_for_less_noise():
img = clean_hw()
low_noise = psnr(noisy_hw(img, sigma=0.01), img)
high_noise = psnr(noisy_hw(img, sigma=0.2), img)
assert low_noise > high_noise


def test_psnr_returns_float():
img = clean_hw()
assert isinstance(psnr(noisy_hw(img), img), float)


def test_psnr_data_range_affects_result():
img = clean_hw()
noisy = noisy_hw(img)
assert psnr(noisy, img, data_range=1.0) != psnr(noisy, img, data_range=255.0)


def test_psnr_chw_input():
img = RNG.random((1, 32, 32), dtype=np.float32)
noisy = np.clip(img + 0.05, 0, 1).astype(np.float32)
val = psnr(noisy, img)
assert math.isfinite(val)


# ---------------------------------------------------------------------------
# SSIM
# ---------------------------------------------------------------------------

def test_ssim_identical_images_is_one():
img = clean_hw()
assert math.isclose(ssim(img, img), 1.0, abs_tol=1e-5)


def test_ssim_noisy_less_than_one():
img = clean_hw()
assert ssim(noisy_hw(img), img) < 1.0


def test_ssim_in_valid_range():
img = clean_hw()
val = ssim(noisy_hw(img), img)
assert -1.0 <= val <= 1.0


def test_ssim_returns_float():
img = clean_hw()
assert isinstance(ssim(img, img), float)


def test_ssim_higher_for_less_noise():
img = clean_hw()
low = ssim(noisy_hw(img, sigma=0.01), img)
high = ssim(noisy_hw(img, sigma=0.3), img)
assert low > high


def test_ssim_chw_input():
img = RNG.random((1, 32, 32), dtype=np.float32)
val = ssim(img, img)
assert math.isclose(val, 1.0, abs_tol=1e-5)


def test_ssim_bchw_input():
img = RNG.random((2, 1, 32, 32), dtype=np.float32)
val = ssim(img, img)
assert math.isclose(val, 1.0, abs_tol=1e-5)


# ---------------------------------------------------------------------------
# compare_denoisers
# ---------------------------------------------------------------------------

def test_compare_denoisers_returns_dict():
img = clean_hw()
noisy = noisy_hw(img)
result = compare_denoisers(noisy, img, {"method_a": noisy_hw(img, 0.05)})
assert isinstance(result, dict)


def test_compare_denoisers_includes_noisy_baseline():
img = clean_hw()
noisy = noisy_hw(img)
result = compare_denoisers(noisy, img, {})
assert "noisy" in result


def test_compare_denoisers_includes_all_keys():
img = clean_hw()
noisy = noisy_hw(img)
result = compare_denoisers(noisy, img, {"gaussian": noisy_hw(img, 0.05), "ddpm": noisy_hw(img, 0.02)})
assert "gaussian" in result and "ddpm" in result


def test_compare_denoisers_metric_keys():
img = clean_hw()
noisy = noisy_hw(img)
result = compare_denoisers(noisy, img, {"a": noisy_hw(img)})
for entry in result.values():
assert set(entry.keys()) == {"psnr", "ssim", "mse"}


def test_compare_denoisers_perfect_output():
img = clean_hw()
noisy = noisy_hw(img)
result = compare_denoisers(noisy, img, {"perfect": img})
assert math.isinf(result["perfect"]["psnr"])
assert math.isclose(result["perfect"]["ssim"], 1.0, abs_tol=1e-5)
assert result["perfect"]["mse"] == 0.0


def test_compare_denoisers_better_method_higher_psnr():
img = clean_hw()
noisy = noisy_hw(img, sigma=0.2)
result = compare_denoisers(noisy, img, {
"weak": noisy_hw(img, sigma=0.15),
"strong": noisy_hw(img, sigma=0.01),
})
assert result["strong"]["psnr"] > result["weak"]["psnr"]