Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/torch_measure/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from torch_measure.models.beta_twopl import BetaTwoPL
from torch_measure.models.bifactor import Bifactor
from torch_measure.models.bradley_terry import BradleyTerry
from torch_measure.models.doubly_robust import DoublyRobustModel
from torch_measure.models.ggm import GaussianGraphicalModel
from torch_measure.models.ising import IsingModel
from torch_measure.models.llm_judge import LLMJudge
Expand Down Expand Up @@ -50,4 +51,5 @@
"bifactor_rotation",
"NCF",
"LLMJudge",
"DoublyRobustModel",
]
179 changes: 179 additions & 0 deletions src/torch_measure/models/doubly_robust.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# Copyright (c) 2026 AIMS Foundations. MIT License.

"""Doubly robust predictor: learns a bias-correction on top of a frozen base model.

The correction is trained with inverse-propensity-weighted (IPW) loss so that
the combined predictor remains consistent under informative missingness (MNAR)
in sparse benchmark matrices.
"""

from __future__ import annotations

import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from torch import nn

from torch_measure.models._base import IRTModel


class DoublyRobustModel(IRTModel):
"""Residual IRT model trained with propensity-weighted loss.

Wraps a pre-trained base model and learns an additive correction:

P(correct | i, j) = clamp( base(i, j) + correction(i, j) )

where ``correction(i, j) = sigmoid(alpha_i - beta_j) - 0.5``, a centered
residual Rasch layer. During fitting, the loss for each observed cell is
weighted by ``1 / e(i, j)`` where ``e`` is the estimated propensity
(probability of observation), making the estimator consistent even when
missingness depends on unobserved outcomes.

Parameters
----------
base_model : IRTModel
A fitted IRT model whose parameters will be frozen.
clip_propensity : tuple[float, float]
Clamp range for propensity scores to avoid extreme weights.
"""

def __init__(
self,
base_model: IRTModel,
clip_propensity: tuple[float, float] = (0.05, 0.95),
) -> None:
n_subjects = base_model.n_subjects
n_items = base_model.n_items
super().__init__(n_subjects, n_items, device=str(base_model.device))

self._base = base_model
for p in self._base.parameters():
p.requires_grad_(False)

self._clip_propensity = clip_propensity

self.correction_ability = nn.Parameter(torch.zeros(n_subjects, device=self._device))
self.correction_difficulty = nn.Parameter(torch.zeros(n_items, device=self._device))

self._propensity_weights: torch.Tensor | None = None

def predict(self, query: dict[str, torch.Tensor]) -> torch.Tensor:
"""P(correct) = clamp(base + correction)."""
s = query["subject_idx"]
i = query["item_idx"]

base_prob = self._base.predict(query).detach()
correction = torch.sigmoid(self.correction_ability[s] - self.correction_difficulty[i]) - 0.5

return (base_prob + correction).clamp(1e-7, 1 - 1e-7)

def fit(
self,
data: torch.Tensor,
mask: torch.Tensor | None = None,
method: str = "mle",
max_epochs: int = 500,
lr: float = 0.01,
verbose: bool = True,
**kwargs,
) -> dict:
"""Fit the correction layer with IPW-weighted loss.

Before running the optimizer, estimates propensity scores from the
observation pattern via logistic regression, then passes per-observation
weights (1/propensity) into the fitting loop.

Parameters
----------
data : torch.Tensor
Wide-form response matrix (n_subjects, n_items). NaN = unobserved.
mask : torch.Tensor | None
Boolean observation mask. Inferred from NaN if None.
method : str
Fitting backend (default ``"mle"``).
max_epochs : int
Optimization epochs for the correction layer.
lr : float
Learning rate.
verbose : bool
Show progress bar.

Returns
-------
dict
Training history.
"""
if mask is None:
mask = ~torch.isnan(data)

self._estimate_propensity(data, mask)

subject_idx, item_idx, response = self._normalize_fit_inputs(data, mask)

weights = self._get_observation_weights(subject_idx, item_idx)

def ipw_loss(predicted_probs: torch.Tensor, observed: torch.Tensor) -> torch.Tensor:
per_obs_nll = -observed * torch.log(predicted_probs) - (1 - observed) * torch.log(1 - predicted_probs)
return (per_obs_nll * weights).mean()

from torch_measure.fitting.mle import mle_fit

return mle_fit(
self,
subject_idx,
item_idx,
response,
max_epochs=max_epochs,
lr=lr,
verbose=verbose,
loss_fn=ipw_loss,
**kwargs,
)

def _estimate_propensity(self, data: torch.Tensor, mask: torch.Tensor) -> None:
"""Fit a logistic regression on observation indicators."""
n_s, n_i = data.shape
obs = mask.float()

row_rate = obs.mean(dim=1)
col_rate = obs.mean(dim=0)

features = torch.stack(
[
row_rate.repeat_interleave(n_i),
col_rate.repeat(n_s),
],
dim=1,
).numpy()

if hasattr(self._base, "ability") and hasattr(self._base, "difficulty"):
ability = self._base.ability.detach().cpu()
difficulty = self._base.difficulty.detach().cpu()
features = np.hstack(
[
features,
ability.repeat_interleave(n_i).numpy()[:, None],
difficulty.repeat(n_s).numpy()[:, None],
]
)

y = mask.reshape(-1).numpy().astype(np.int32)

if y.all() or not y.any():
self._propensity_weights = torch.ones(n_s, n_i, device=self._device)
return

lr = LogisticRegression(max_iter=1000, solver="lbfgs", random_state=0)
lr.fit(features, y)
prop_flat = lr.predict_proba(features)[:, 1]
propensity = torch.from_numpy(prop_flat).float().reshape(n_s, n_i)
propensity = propensity.clamp(self._clip_propensity[0], self._clip_propensity[1])

self._propensity_weights = (1.0 / propensity).to(self._device)

def _get_observation_weights(self, subject_idx: torch.Tensor, item_idx: torch.Tensor) -> torch.Tensor:
"""Look up per-observation IPW weights."""
if self._propensity_weights is None:
return torch.ones(subject_idx.shape[0], device=self._device)
return self._propensity_weights[subject_idx, item_idx]
160 changes: 160 additions & 0 deletions tests/test_models/test_doubly_robust.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright (c) 2026 AIMS Foundations. MIT License.

import torch

from torch_measure.models import DoublyRobustModel, Rasch
from torch_measure.models._predictor import predict_dense


def _make_sparse_rasch(
n_subjects: int = 40,
n_items: int = 30,
obs_rate: float = 0.6,
seed: int = 0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Generate a Rasch response matrix with MNAR missingness."""
rng = torch.Generator().manual_seed(seed)
ability = torch.randn(n_subjects, generator=rng)
difficulty = torch.randn(n_items, generator=rng)
logits = ability.unsqueeze(1) - difficulty.unsqueeze(0)
probs = torch.sigmoid(logits)
data_full = torch.bernoulli(probs, generator=rng)

# MNAR: high-ability subjects less likely observed on easy items
obs_logits = -0.4 * ability.unsqueeze(1) + 0.4 * difficulty.unsqueeze(0)
obs_probs = torch.sigmoid(obs_logits).clamp(0.2, 0.95)
obs_mask = torch.bernoulli(obs_probs, generator=rng).bool()
# guarantee at least one obs per subject
for i in range(n_subjects):
if not obs_mask[i].any():
obs_mask[i, 0] = True

data_sparse = data_full.clone()
data_sparse[~obs_mask] = float("nan")
return data_sparse, data_full, ability, difficulty


class TestDoublyRobustModel:
def test_init_freezes_base(self):
base = Rasch(n_subjects=10, n_items=20)
dr = DoublyRobustModel(base)
for p in dr._base.parameters():
assert not p.requires_grad

def test_init_correction_shape(self):
base = Rasch(n_subjects=15, n_items=25)
dr = DoublyRobustModel(base)
assert dr.correction_ability.shape == (15,)
assert dr.correction_difficulty.shape == (25,)
assert dr.n_subjects == 15
assert dr.n_items == 25

def test_predict_shape_and_range(self):
base = Rasch(n_subjects=10, n_items=20)
dr = DoublyRobustModel(base)
probs = predict_dense(dr)
assert probs.shape == (10, 20)
assert (probs > 0).all()
assert (probs < 1).all()

def test_zero_correction_matches_base(self):
base = Rasch(n_subjects=10, n_items=20)
dr = DoublyRobustModel(base)
# correction params initialized to zero → correction = sigmoid(0) - 0.5 = 0
base_probs = predict_dense(base)
dr_probs = predict_dense(dr)
torch.testing.assert_close(dr_probs, base_probs, atol=1e-6, rtol=0)

def test_fit_reduces_loss(self):
data_sparse, _, _, _ = _make_sparse_rasch(30, 20, seed=5)
base = Rasch(n_subjects=30, n_items=20)
base.fit(data_sparse, max_epochs=50, verbose=False)

dr = DoublyRobustModel(base)
history = dr.fit(data_sparse, max_epochs=50, verbose=False)
assert len(history["losses"]) > 1
assert history["losses"][-1] < history["losses"][0]

def test_fit_changes_correction_params(self):
data_sparse, _, _, _ = _make_sparse_rasch(30, 20, seed=7)
base = Rasch(n_subjects=30, n_items=20)
base.fit(data_sparse, max_epochs=50, verbose=False)

dr = DoublyRobustModel(base)
before_ability = dr.correction_ability.detach().clone()
dr.fit(data_sparse, max_epochs=50, verbose=False)
assert not torch.allclose(dr.correction_ability, before_ability)

def test_base_params_unchanged_after_fit(self):
data_sparse, _, _, _ = _make_sparse_rasch(30, 20, seed=9)
base = Rasch(n_subjects=30, n_items=20)
base.fit(data_sparse, max_epochs=50, verbose=False)

ability_before = base.ability.detach().clone()
difficulty_before = base.difficulty.detach().clone()

dr = DoublyRobustModel(base)
dr.fit(data_sparse, max_epochs=50, verbose=False)

torch.testing.assert_close(base.ability, ability_before)
torch.testing.assert_close(base.difficulty, difficulty_before)

def test_improves_prediction_on_sparse_data(self):
"""DR model should predict held-out cells better than base alone."""
torch.manual_seed(42)
data_sparse, data_full, ability, difficulty = _make_sparse_rasch(n_subjects=60, n_items=40, seed=11)

base = Rasch(n_subjects=60, n_items=40)
base.fit(data_sparse, max_epochs=200, verbose=False)

dr = DoublyRobustModel(base)
dr.fit(data_sparse, max_epochs=200, verbose=False)

# Evaluate on all cells
base_preds = predict_dense(base).detach()
dr_preds = predict_dense(dr).detach()

base_mse = ((base_preds - data_full) ** 2).mean().item()
dr_mse = ((dr_preds - data_full) ** 2).mean().item()

# DR should not be substantially worse
assert dr_mse < base_mse + 0.02, f"DR MSE {dr_mse:.4f} much worse than base {base_mse:.4f}"

def test_propensity_clipping(self):
data_sparse, _, _, _ = _make_sparse_rasch(20, 15, seed=13)
base = Rasch(n_subjects=20, n_items=15)
base.fit(data_sparse, max_epochs=30, verbose=False)

dr = DoublyRobustModel(base, clip_propensity=(0.1, 0.9))
dr.fit(data_sparse, max_epochs=30, verbose=False)

# Should not produce NaN/Inf
preds = predict_dense(dr)
assert torch.isfinite(preds).all()

def test_complete_data_correction_near_zero(self):
"""On fully observed data, correction should stay small."""
torch.manual_seed(99)
n_s, n_i = 20, 15
ability = torch.randn(n_s)
difficulty = torch.randn(n_i)
logits = ability.unsqueeze(1) - difficulty.unsqueeze(0)
data = torch.bernoulli(torch.sigmoid(logits))

base = Rasch(n_subjects=n_s, n_items=n_i)
base.fit(data, max_epochs=100, verbose=False)

dr = DoublyRobustModel(base)
dr.fit(data, max_epochs=100, verbose=False)

# Correction params should remain near zero since no missingness bias
assert dr.correction_ability.abs().mean().item() < 0.5
assert dr.correction_difficulty.abs().mean().item() < 0.5

def test_forward_equals_predict(self):
base = Rasch(n_subjects=10, n_items=20)
dr = DoublyRobustModel(base)
from torch_measure.models._predictor import cartesian_query

query = cartesian_query(10, 20)
torch.testing.assert_close(dr(query), dr.predict(query))
Loading
Loading