Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 144 additions & 1 deletion dte_adj/stratified.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import numpy as np
from typing import Tuple, Any
from typing import Optional, Tuple, Any
from copy import deepcopy
from scipy.stats import norm
from tqdm.auto import tqdm
from dte_adj.base import DistributionEstimatorBase
from dte_adj.util import ArrayLike, _convert_to_ndarray
Expand Down Expand Up @@ -153,6 +154,77 @@ def _compute_interval_probability(
conditional_prediction[:, 1:] - conditional_prediction[:, :-1],
)

def predict_qte(
    self,
    target_treatment_arm: int,
    control_treatment_arm: int,
    quantiles: Optional[np.ndarray] = None,
    alpha: float = 0.05,
    n_bootstrap: int = 500,
    display_progress: bool = True,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute Quantile Treatment Effects (QTE) using stratified bootstrap.

    Uses stratified bootstrap (resampling independently within each stratum) to
    correctly estimate variance under covariate adaptive randomization (CAR).
    The point estimate comes from the full sample; the bootstrap replicates are
    only used to estimate its variance, and the confidence bounds are built
    from a normal approximation around the point estimate.

    Resampling draws from NumPy's global random state (``np.random.choice``),
    so call ``np.random.seed`` beforehand for reproducible bounds.

    Args:
        target_treatment_arm (int): The index of the treatment arm of the treatment group.
        control_treatment_arm (int): The index of the treatment arm of the control group.
        quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1, 0.2, ..., 0.9].
        alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.
        n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
        display_progress (bool, optional): Whether to display a progress bar. Defaults to True.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
            - Expected QTEs (np.ndarray): Treatment effect estimates at each quantile
            - Lower bounds (np.ndarray): Lower confidence interval bounds
            - Upper bounds (np.ndarray): Upper confidence interval bounds
    """
    # Honour the documented default: without this, ``None`` would be handed
    # straight to ``_compute_qtes``, which expects an array of quantiles.
    if quantiles is None:
        quantiles = np.arange(1, 10) / 10
    else:
        quantiles = np.asarray(quantiles)

    # Point estimate on the full, un-resampled data.
    qte = self._compute_qtes(
        target_treatment_arm,
        control_treatment_arm,
        quantiles,
        self.covariates,
        self.treatment_arms,
        self.outcomes,
        self.strata,
    )

    # Precompute the row indices of every stratum once; they do not change
    # between bootstrap replicates.
    # NOTE(review): this stratified-bootstrap logic is duplicated in the other
    # stratified estimator's predict_qte; consider a shared private helper.
    # TODO(review): add a unit test that fails if the resampling is not
    # stratified (e.g. per-stratum counts must be preserved in each replicate).
    strata_indices = [
        np.where(self.strata == stratum)[0] for stratum in np.unique(self.strata)
    ]

    qtes = np.zeros((n_bootstrap, qte.shape[0]))
    bootstrap_iter = range(n_bootstrap)
    if display_progress:
        bootstrap_iter = tqdm(bootstrap_iter, desc="Bootstrap QTE")
    for b in bootstrap_iter:
        # Stratified bootstrap: resample within each stratum independently, so
        # every replicate keeps the original per-stratum sample sizes.
        bootstrap_indexes = np.concatenate(
            [
                np.random.choice(idx, size=len(idx), replace=True)
                for idx in strata_indices
            ]
        )

        qtes[b] = self._compute_qtes(
            target_treatment_arm,
            control_treatment_arm,
            quantiles,
            self.covariates[bootstrap_indexes],
            self.treatment_arms[bootstrap_indexes],
            self.outcomes[bootstrap_indexes],
            self.strata[bootstrap_indexes],
        )

    # Normal-approximation interval using the bootstrap standard error.
    qte_se = np.sqrt(qtes.var(axis=0))

    qte_lower = qte + norm.ppf(alpha / 2) * qte_se
    qte_upper = qte + norm.ppf(1 - alpha / 2) * qte_se

    return qte, qte_lower, qte_upper


class AdjustedStratifiedDistributionEstimator(DistributionEstimatorBase):
"""A class is for estimating the adjusted distribution function and computing the Distributional parameters for CAR."""
Expand Down Expand Up @@ -405,6 +477,77 @@ def _compute_interval_probability(

return prediction.mean(axis=0), prediction, superset_prediction

def predict_qte(
    self,
    target_treatment_arm: int,
    control_treatment_arm: int,
    quantiles: Optional[np.ndarray] = None,
    alpha: float = 0.05,
    n_bootstrap: int = 500,
    display_progress: bool = True,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute Quantile Treatment Effects (QTE) using stratified bootstrap.

    Uses stratified bootstrap (resampling independently within each stratum) to
    correctly estimate variance under covariate adaptive randomization (CAR).
    The point estimate comes from the full sample; the bootstrap replicates are
    only used to estimate its variance, and the confidence bounds are built
    from a normal approximation around the point estimate.

    Resampling draws from NumPy's global random state (``np.random.choice``),
    so call ``np.random.seed`` beforehand for reproducible bounds.

    Args:
        target_treatment_arm (int): The index of the treatment arm of the treatment group.
        control_treatment_arm (int): The index of the treatment arm of the control group.
        quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1, 0.2, ..., 0.9].
        alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.
        n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
        display_progress (bool, optional): Whether to display a progress bar. Defaults to True.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
            - Expected QTEs (np.ndarray): Treatment effect estimates at each quantile
            - Lower bounds (np.ndarray): Lower confidence interval bounds
            - Upper bounds (np.ndarray): Upper confidence interval bounds
    """
    # Honour the documented default: without this, ``None`` would be handed
    # straight to ``_compute_qtes``, which expects an array of quantiles.
    if quantiles is None:
        quantiles = np.arange(1, 10) / 10
    else:
        quantiles = np.asarray(quantiles)

    # Point estimate on the full, un-resampled data.
    qte = self._compute_qtes(
        target_treatment_arm,
        control_treatment_arm,
        quantiles,
        self.covariates,
        self.treatment_arms,
        self.outcomes,
        self.strata,
    )

    # Precompute the row indices of every stratum once; they do not change
    # between bootstrap replicates.
    # NOTE(review): this stratified-bootstrap logic is duplicated in the other
    # stratified estimator's predict_qte; consider a shared private helper.
    strata_indices = [
        np.where(self.strata == stratum)[0] for stratum in np.unique(self.strata)
    ]

    qtes = np.zeros((n_bootstrap, qte.shape[0]))
    bootstrap_iter = range(n_bootstrap)
    if display_progress:
        bootstrap_iter = tqdm(bootstrap_iter, desc="Bootstrap QTE")
    for b in bootstrap_iter:
        # Stratified bootstrap: resample within each stratum independently, so
        # every replicate keeps the original per-stratum sample sizes.
        bootstrap_indexes = np.concatenate(
            [
                np.random.choice(idx, size=len(idx), replace=True)
                for idx in strata_indices
            ]
        )

        # NOTE(review): a reviewer reported that the cumulative-distribution
        # step behind _compute_qtes redraws its cross-fitting folds on every
        # call, which would add Monte Carlo noise to the bootstrap variance.
        # Confirm, and if so fix the folds (or seed an RNG) once up front.
        qtes[b] = self._compute_qtes(
            target_treatment_arm,
            control_treatment_arm,
            quantiles,
            self.covariates[bootstrap_indexes],
            self.treatment_arms[bootstrap_indexes],
            self.outcomes[bootstrap_indexes],
            self.strata[bootstrap_indexes],
        )

    # Normal-approximation interval using the bootstrap standard error.
    qte_se = np.sqrt(qtes.var(axis=0))

    qte_lower = qte + norm.ppf(alpha / 2) * qte_se
    qte_upper = qte + norm.ppf(1 - alpha / 2) * qte_se

    return qte, qte_lower, qte_upper

def _compute_model_prediction(self, model, covariates: np.ndarray) -> np.ndarray:
if hasattr(model, "predict_proba"):
if self.is_multi_task:
Expand Down
Loading