-
Notifications
You must be signed in to change notification settings - Fork 2
add QTE support for covariate adaptive randomization #107
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1,9 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import numpy as np | ||
| from typing import Tuple, Any | ||
| from typing import Optional, Tuple, Any | ||
| from copy import deepcopy | ||
| from scipy.stats import norm | ||
| from tqdm.auto import tqdm | ||
| from dte_adj.base import DistributionEstimatorBase | ||
| from dte_adj.util import ArrayLike, _convert_to_ndarray | ||
|
|
@@ -153,6 +154,77 @@ def _compute_interval_probability( | |
| conditional_prediction[:, 1:] - conditional_prediction[:, :-1], | ||
| ) | ||
|
|
||
| def predict_qte( | ||
| self, | ||
| target_treatment_arm: int, | ||
| control_treatment_arm: int, | ||
| quantiles: Optional[np.ndarray] = None, | ||
| alpha: float = 0.05, | ||
| n_bootstrap=500, | ||
| display_progress: bool = True, | ||
| ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: | ||
| """ | ||
| Compute Quantile Treatment Effects (QTE) using stratified bootstrap. | ||
|
|
||
| Uses stratified bootstrap (resampling independently within each stratum) to | ||
| correctly estimate variance under covariate adaptive randomization (CAR). | ||
|
|
||
| Args: | ||
| target_treatment_arm (int): The index of the treatment arm of the treatment group. | ||
| control_treatment_arm (int): The index of the treatment arm of the control group. | ||
| quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1, 0.2, ..., 0.9]. | ||
| alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05. | ||
| n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500. | ||
| display_progress (bool, optional): Whether to display a progress bar. Defaults to True. | ||
|
|
||
| Returns: | ||
| Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing: | ||
| - Expected QTEs (np.ndarray): Treatment effect estimates at each quantile | ||
| - Lower bounds (np.ndarray): Lower confidence interval bounds | ||
| - Upper bounds (np.ndarray): Upper confidence interval bounds | ||
| """ | ||
| qte = self._compute_qtes( | ||
| target_treatment_arm, | ||
| control_treatment_arm, | ||
| quantiles, | ||
| self.covariates, | ||
| self.treatment_arms, | ||
| self.outcomes, | ||
| self.strata, | ||
| ) | ||
|
|
||
| # Precompute stratum indices for stratified bootstrap | ||
| unique_strata = np.unique(self.strata) | ||
| strata_indices = {s: np.where(self.strata == s)[0] for s in unique_strata} | ||
|
|
||
| qtes = np.zeros((n_bootstrap, qte.shape[0])) | ||
| bootstrap_iter = range(n_bootstrap) | ||
| if display_progress: | ||
| bootstrap_iter = tqdm(bootstrap_iter, desc="Bootstrap QTE") | ||
| for b in bootstrap_iter: | ||
| # Stratified bootstrap: resample within each stratum independently | ||
| bootstrap_indexes = np.concatenate([ | ||
| np.random.choice(idx, size=len(idx), replace=True) | ||
| for idx in strata_indices.values() | ||
| ]) | ||
|
Comment on lines
+196
to
+209
|
||
|
|
||
| qtes[b] = self._compute_qtes( | ||
| target_treatment_arm, | ||
| control_treatment_arm, | ||
| quantiles, | ||
| self.covariates[bootstrap_indexes], | ||
| self.treatment_arms[bootstrap_indexes], | ||
| self.outcomes[bootstrap_indexes], | ||
| self.strata[bootstrap_indexes], | ||
| ) | ||
|
Comment on lines
+205
to
+219
|
||
|
|
||
| qte_var = qtes.var(axis=0) | ||
|
|
||
| qte_lower = qte + norm.ppf(alpha / 2) * np.sqrt(qte_var) | ||
| qte_upper = qte + norm.ppf(1 - alpha / 2) * np.sqrt(qte_var) | ||
|
|
||
| return qte, qte_lower, qte_upper | ||
|
|
||
|
|
||
| class AdjustedStratifiedDistributionEstimator(DistributionEstimatorBase): | ||
| """A class is for estimating the adjusted distribution function and computing the Distributional parameters for CAR.""" | ||
|
|
@@ -405,6 +477,77 @@ def _compute_interval_probability( | |
|
|
||
| return prediction.mean(axis=0), prediction, superset_prediction | ||
|
|
||
| def predict_qte( | ||
| self, | ||
| target_treatment_arm: int, | ||
| control_treatment_arm: int, | ||
| quantiles: Optional[np.ndarray] = None, | ||
| alpha: float = 0.05, | ||
| n_bootstrap=500, | ||
| display_progress: bool = True, | ||
| ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: | ||
| """ | ||
| Compute Quantile Treatment Effects (QTE) using stratified bootstrap. | ||
|
|
||
| Uses stratified bootstrap (resampling independently within each stratum) to | ||
| correctly estimate variance under covariate adaptive randomization (CAR). | ||
|
|
||
| Args: | ||
| target_treatment_arm (int): The index of the treatment arm of the treatment group. | ||
| control_treatment_arm (int): The index of the treatment arm of the control group. | ||
| quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1, 0.2, ..., 0.9]. | ||
| alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05. | ||
| n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500. | ||
| display_progress (bool, optional): Whether to display a progress bar. Defaults to True. | ||
|
|
||
| Returns: | ||
| Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing: | ||
| - Expected QTEs (np.ndarray): Treatment effect estimates at each quantile | ||
| - Lower bounds (np.ndarray): Lower confidence interval bounds | ||
| - Upper bounds (np.ndarray): Upper confidence interval bounds | ||
| """ | ||
| qte = self._compute_qtes( | ||
| target_treatment_arm, | ||
| control_treatment_arm, | ||
| quantiles, | ||
| self.covariates, | ||
| self.treatment_arms, | ||
| self.outcomes, | ||
| self.strata, | ||
| ) | ||
|
Comment on lines
+484
to
+517
|
||
|
|
||
| # Precompute stratum indices for stratified bootstrap | ||
| unique_strata = np.unique(self.strata) | ||
| strata_indices = {s: np.where(self.strata == s)[0] for s in unique_strata} | ||
|
|
||
| qtes = np.zeros((n_bootstrap, qte.shape[0])) | ||
| bootstrap_iter = range(n_bootstrap) | ||
| if display_progress: | ||
| bootstrap_iter = tqdm(bootstrap_iter, desc="Bootstrap QTE") | ||
| for b in bootstrap_iter: | ||
| # Stratified bootstrap: resample within each stratum independently | ||
| bootstrap_indexes = np.concatenate([ | ||
| np.random.choice(idx, size=len(idx), replace=True) | ||
| for idx in strata_indices.values() | ||
| ]) | ||
|
|
||
| qtes[b] = self._compute_qtes( | ||
| target_treatment_arm, | ||
| control_treatment_arm, | ||
| quantiles, | ||
| self.covariates[bootstrap_indexes], | ||
| self.treatment_arms[bootstrap_indexes], | ||
| self.outcomes[bootstrap_indexes], | ||
| self.strata[bootstrap_indexes], | ||
| ) | ||
|
|
||
| qte_var = qtes.var(axis=0) | ||
|
|
||
| qte_lower = qte + norm.ppf(alpha / 2) * np.sqrt(qte_var) | ||
| qte_upper = qte + norm.ppf(1 - alpha / 2) * np.sqrt(qte_var) | ||
|
|
||
| return qte, qte_lower, qte_upper | ||
|
|
||
| def _compute_model_prediction(self, model, covariates: np.ndarray) -> np.ndarray: | ||
| if hasattr(model, "predict_proba"): | ||
| if self.is_multi_task: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
quantilesis documented as optional with a default ([0.1, …, 0.9]) but it’s passed straight into_compute_qtes. If the caller leavesquantiles=None,_compute_qteswill error when accessingquantiles.shape. Please set a default array whenquantiles is None(and ideally validate they’re in (0,1)).