From f7104f1e816f023d2102a07cf7a885190ff7dfc0 Mon Sep 17 00:00:00 2001 From: pedrohcgs Date: Wed, 1 Apr 2026 13:18:24 -0400 Subject: [PATCH 01/20] v2.3.1.904: fix time-varying weights, add fix_weights param, reduce namespace pollution, fix dreamerr/get() crash - Fix faster_mode mismatch with time-varying sampling weights on balanced panel - Add fix_weights parameter: NULL (default), "varying", "base_period", "first_period" - Replace import(stats), import(utils), import(BMisc) with selective importFrom - Fix aggte() crash from dreamerr intercepting get() inside data.table - Add runtime message for time-varying weights detection - Skip inference tests gracefully when did v2.1.2 unavailable from CRAN - Add GitHub Action to auto-bump version on PR merge - 236 PASS, 0 FAIL, 8 SKIP --- .github/workflows/bump-version.yaml | 50 ++++++++ DESCRIPTION | 2 +- NAMESPACE | 24 +++- NEWS.md | 28 +++++ R/DIDparams.R | 2 + R/DIDparams2.R | 4 + R/att_gt.R | 56 ++++++++- R/compute.aggte.R | 2 +- R/compute.att_gt.R | 103 +++++++++++++++- R/compute.att_gt2.R | 77 ++++++++++-- R/imports.R | 8 +- R/pre_process_did.R | 16 +++ R/pre_process_did2.R | 26 +++- R/utility_functions.R | 4 +- man/DIDparams.Rd | 47 +++++++- man/att_gt.Rd | 47 +++++++- man/conditional_did_pretest.Rd | 19 ++- man/pre_process_did.Rd | 47 +++++++- man/pre_process_did2.Rd | 47 +++++++- tests/testthat/test-att_gt.R | 177 ++++++++++++++++++++++++++++ tests/testthat/test-inference.R | 17 ++- 21 files changed, 774 insertions(+), 29 deletions(-) create mode 100644 .github/workflows/bump-version.yaml diff --git a/.github/workflows/bump-version.yaml b/.github/workflows/bump-version.yaml new file mode 100644 index 00000000..325aa0b4 --- /dev/null +++ b/.github/workflows/bump-version.yaml @@ -0,0 +1,50 @@ +name: Bump dev version on PR merge + +on: + pull_request: + types: [closed] + branches: [master, main] + +jobs: + bump-version: + if: github.event.pull_request.merged == true + runs-on: ubuntu-latest + + permissions: + 
contents: write + + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.base.ref }} + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Bump dev version in DESCRIPTION + run: | + # Extract current version + current=$(grep '^Version:' DESCRIPTION | sed 's/Version: //') + echo "Current version: $current" + + # Split into parts + IFS='.' read -ra parts <<< "$current" + major="${parts[0]}" + minor="${parts[1]}" + patch="${parts[2]}" + dev="${parts[3]:-0}" + + # Increment dev version + new_dev=$((dev + 1)) + new_version="${major}.${minor}.${patch}.${new_dev}" + echo "New version: $new_version" + + # Update DESCRIPTION + sed -i "s/^Version: .*/Version: ${new_version}/" DESCRIPTION + + # Configure git + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + + # Commit and push + git add DESCRIPTION + git diff --cached --quiet || git commit -m "Bump version to ${new_version}" + git push diff --git a/DESCRIPTION b/DESCRIPTION index e239a2ea..7d2b5a33 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: did Title: Treatment Effects with Multiple Periods and Groups -Version: 2.3.1.903 +Version: 2.3.1.904 Authors@R: c(person("Brantly", "Callaway", email = "brantly.callaway@uga.edu", role = c("aut", "cre")), person("Pedro H. C.", "Sant'Anna", email="pedro.santanna@emory.edu", role = c("aut"))) URL: https://bcallaway11.github.io/did/, https://github.com/bcallaway11/did/ Description: The standard Difference-in-Differences (DID) setup involves two periods and two groups -- a treated group and untreated group. Many applications of DID methods involve more than two periods and have individuals that are treated at different points in time. This package contains tools for computing average treatment effect parameters in Difference in Differences setups with more than two periods and with variation in treatment timing using the methods developed in Callaway and Sant'Anna (2021) <doi:10.1016/j.jeconom.2020.12.001>.
The main parameters are group-time average treatment effects which are the average treatment effect for a particular group at a a particular time. These can be aggregated into a fewer number of treatment effect parameters, and the package deals with the cases where there is selective treatment timing, dynamic treatment effects, calendar time effects, or combinations of these. There are also functions for testing the Difference in Differences assumption, and plotting group-time average treatment effects. diff --git a/NAMESPACE b/NAMESPACE index c5dab2b9..804c402f 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -38,12 +38,15 @@ export(splot) export(test.mboot) export(tidy) export(trimmer) -import(BMisc) import(data.table) import(fastglm) import(ggplot2) -import(stats) -import(utils) +importFrom(BMisc,TorF) +importFrom(BMisc,getListElement) +importFrom(BMisc,makeBalancedPanel) +importFrom(BMisc,multiplier_bootstrap) +importFrom(BMisc,rhs.vars) +importFrom(BMisc,toformula) importFrom(DRDID,drdid_panel) importFrom(DRDID,drdid_rc) importFrom(DRDID,reg_did_panel) @@ -55,5 +58,20 @@ importFrom(generics,glance) importFrom(generics,tidy) importFrom(methods,as) importFrom(methods,is) +importFrom(stats,aggregate) +importFrom(stats,binomial) +importFrom(stats,complete.cases) +importFrom(stats,cov) +importFrom(stats,model.frame) +importFrom(stats,model.matrix) +importFrom(stats,na.pass) importFrom(stats,nobs) +importFrom(stats,pchisq) +importFrom(stats,pnorm) +importFrom(stats,qnorm) +importFrom(stats,quantile) +importFrom(stats,rnorm) +importFrom(stats,setNames) +importFrom(stats,var) importFrom(tidyr,gather) +importFrom(utils,globalVariables) diff --git a/NEWS.md b/NEWS.md index 6760543c..66a9a8bb 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,31 @@ +# did 2.3.1.904 + + * Fixed bug where `faster_mode = TRUE` and `faster_mode = FALSE` produced different ATT estimates when sampling weights (`weightsname`) vary across time. 
The fast path was always using first-period weights; it now correctly uses the same period's weights as the slow path + + * New `fix_weights` argument in `att_gt()` gives users explicit control over how time-varying sampling weights are resolved in each 2x2 DiD comparison. Options: `NULL` (default, preserves existing behavior), `"varying"` (per-observation weights using RC estimators), `"base_period"` (fix at g-1 for all cells), `"first_period"` (fix at first period). See `?att_gt` for details + + * Runtime message when time-varying weights are detected in balanced panel data, directing users to the `fix_weights` argument + + * Reduced namespace pollution: replaced blanket `import(stats)`, `import(utils)`, and `import(BMisc)` with selective `importFrom()` calls. The `did` package no longer re-exports `stats::filter` or `stats::lag`, which previously masked `dplyr::filter` and `dplyr::lag` when both packages were loaded + + * Fixed `aggte()` crash (`"Error in get(gname): invalid first argument"`) when the user's group column is literally named `gname` and `dreamerr` >= 1.5.0 is installed. The issue was `dreamerr` intercepting `data.table`'s `get()` inside `[.data.table`; replaced with `set()` which is immune to this + + * Expanded `weightsname` documentation explaining how time-varying weights are handled differently for balanced panels vs. 
repeated cross sections and unbalanced panels + + * Added `nobs()` S3 methods for `MP` and `AGGTEobj` objects, returning the number of unique cross-sectional units as an integer + + * Added `statistic` (t-statistic) and `p.value` (pointwise, two-sided) columns to `tidy()` output for both `MP` and `AGGTEobj` objects, following `broom` conventions + + * Added `broom` to `Suggests` + +# did 2.3.1.903 + + * Added `nobs()` S3 methods and `statistic`/`p.value` columns to `tidy()` output (superseded by 2.3.1.904 entry above) + +# did 2.3.1.902 + + * Bug fixes, diagnostic improvements, and JEL replication tests + # did 2.3.1.901 * `att_gt()` now accepts `...` (dots) for passing additional arguments to custom `est_method` functions diff --git a/R/DIDparams.R b/R/DIDparams.R index 7aafd3a5..7953d8a5 100644 --- a/R/DIDparams.R +++ b/R/DIDparams.R @@ -25,6 +25,7 @@ DIDparams <- function(yname, control_group, anticipation=0, weightsname=NULL, + fix_weights=NULL, alp=0.05, bstrap=TRUE, biters=1000, @@ -54,6 +55,7 @@ DIDparams <- function(yname, control_group=control_group, anticipation=anticipation, weightsname=weightsname, + fix_weights=fix_weights, alp=alp, bstrap=bstrap, biters=biters, diff --git a/R/DIDparams2.R b/R/DIDparams2.R index 108b50bf..8c88a2c0 100644 --- a/R/DIDparams2.R +++ b/R/DIDparams2.R @@ -49,6 +49,8 @@ DIDparams2 <- function(did_tensors, args, call=NULL) { covariates_matrix <- did_tensors$covariates_matrix cluster_vector <- did_tensors$cluster weights_vector <- did_tensors$weights + weights_tensor <- did_tensors$weights_tensor + fix_weights <- args$fix_weights out <- list(yname=yname, @@ -89,6 +91,8 @@ DIDparams2 <- function(did_tensors, args, call=NULL) { covariates_matrix = covariates_matrix, cluster_vector=cluster_vector, weights_vector=weights_vector, + weights_tensor=weights_tensor, + fix_weights=fix_weights, call=call) class(out) <- "DIDparams" return(out) diff --git a/R/att_gt.R b/R/att_gt.R index b5663b41..ce086373 100644 --- a/R/att_gt.R +++ 
b/R/att_gt.R @@ -16,7 +16,50 @@ #' It defines which "group" a unit belongs to. It should be 0 for units #' in the untreated group. #' @param weightsname The name of the column containing the sampling weights. -#' If not set, all observations have same weight. +#' If not set, all observations have same weight. When weights are +#' time-invariant (constant within each unit across periods), all +#' \code{fix_weights} options produce identical results and no special +#' handling is needed. +#' +#' When weights vary across time (e.g., time-varying population sizes), +#' the default behavior differs by panel type: +#' \describe{ +#' \item{Balanced panel}{Each 2x2 DiD comparison uses the weight from the +#' earlier of the two time periods involved. For post-treatment cells, +#' this is the base period (g-1). For pre-treatment cells with +#' \code{base_period="varying"}, this is the pre-treatment period itself. +#' The panel DRDID estimators are used.} +#' \item{Repeated cross sections and unbalanced panels}{Both periods' +#' per-observation weights are passed directly to the RC DRDID estimators, +#' so each observation carries its own period-specific weight.} +#' } +#' Use the \code{fix_weights} argument to override the default behavior. +#' @param fix_weights Controls how time-varying sampling weights are resolved. +#' Only relevant when weights vary across time; with time-invariant weights, +#' all options produce identical results. Options: +#' \describe{ +#' \item{\code{NULL} (default)}{For balanced panel: uses the weight from +#' the earlier of the two time periods in each 2x2 comparison. For +#' post-treatment cells, this is the base period (g-1). For +#' pre-treatment cells, this depends on the \code{base_period} setting. +#' For RC/unbalanced panel: uses per-observation weights from both +#' periods.} +#' \item{\code{"varying"}}{Uses per-observation, period-specific weights +#' for all panel types. 
For balanced panel data, this switches to the +#' repeated cross-section DRDID estimators so that pre-period and +#' post-period observations each carry their own weight. This is the +#' most flexible option but sacrifices the efficiency of the panel +#' estimator. For RC/unbalanced panel, this is identical to the +#' default.} +#' \item{\code{"base_period"}}{Fixes weights at the base period (g-1) for +#' all (g,t) cells within a group, for both pre-treatment and +#' post-treatment comparisons. Ensures all cells within a group use the +#' same weights. For RC/unbalanced panel, units not observed in the base +#' period are dropped with a warning.} +#' \item{\code{"first_period"}}{Fixes weights at the first time period in +#' the dataset for all (g,t) cells. For RC/unbalanced panel, units not +#' observed in the first period are dropped with a warning.} +#' } #' @param alp the significance level, default is 0.05 #' @param bstrap Boolean for whether or not to compute standard errors using #' the multiplier bootstrap. If standard errors are clustered, then one @@ -195,6 +238,7 @@ att_gt <- function(yname, control_group = c("nevertreated", "notyettreated"), anticipation = 0, weightsname = NULL, + fix_weights = NULL, alp = 0.05, bstrap = TRUE, cband = TRUE, @@ -217,6 +261,14 @@ att_gt <- function(yname, "\". 
Extra arguments are only passed to custom est_method functions.") } + # Validate fix_weights + if (!is.null(fix_weights)) { + if (!is.character(fix_weights) || length(fix_weights) != 1 || + !(fix_weights %in% c("varying", "base_period", "first_period"))) { + stop("fix_weights must be NULL or one of \"varying\", \"base_period\", or \"first_period\".") + } + } + # Validate est_method if (!inherits(est_method, "function")) { if (!is.character(est_method) || length(est_method) != 1) { @@ -249,6 +301,7 @@ att_gt <- function(yname, control_group = control_group, anticipation = anticipation, weightsname = weightsname, + fix_weights = fix_weights, alp = alp, bstrap = bstrap, cband = cband, @@ -284,6 +337,7 @@ att_gt <- function(yname, control_group = control_group, anticipation = anticipation, weightsname = weightsname, + fix_weights = fix_weights, alp = alp, bstrap = bstrap, cband = cband, diff --git a/R/compute.aggte.R b/R/compute.aggte.R index a93ac752..ec02770e 100644 --- a/R/compute.aggte.R +++ b/R/compute.aggte.R @@ -57,7 +57,7 @@ compute.aggte <- function(MP, } if (isTRUE(dp$faster_mode)) { dt <- dp$data - dt[get(gname) == Inf, (gname) := 0] # going back to the old way + set(dt, i = which(dt[[gname]] == Inf), j = gname, value = 0) # going back to the old way data <- as.data.frame(dt) rm(dt) tlist <- dp$time_periods diff --git a/R/compute.att_gt.R b/R/compute.att_gt.R index bef72c36..ffa8af7e 100644 --- a/R/compute.att_gt.R +++ b/R/compute.att_gt.R @@ -65,10 +65,20 @@ compute.att_gt <- function(dp) { # never treated option nevertreated <- (control_group[1] == "nevertreated") + fix_weights <- dp$fix_weights + # Pre-extract columns to avoid repeated get() inside data.table (which is slow) g_col <- data[[gname]] t_col <- data[[tname]] + # Build weight lookup by period for fix_weights options (balanced panel only) + if (!is.null(fix_weights) && panel) { + weights_by_period <- list() + for (tp in seq_along(tlist)) { + weights_by_period[[tp]] <- data[t_col == tlist[tp], 
.w] + } + } + if (nevertreated) { set(data, j = ".C", value = as.integer(g_col == 0)) } @@ -193,7 +203,20 @@ compute.att_gt <- function(dp) { # base period, then the "base period" is actually the later period Ypre <- if (tlist[(t + tfac)] > tlist[pret]) disdat$.y0 else disdat$.y1 Ypost <- if (tlist[(t + tfac)] > tlist[pret]) disdat$.y1 else disdat$.y0 - w <- disdat$.w + + # Select weights based on fix_weights + if (is.null(fix_weights)) { + # Default: .w from get_wide_data (earlier period) + w <- disdat$.w + } else if (fix_weights == "base_period") { + w <- weights_by_period[[pret_g]][disidx] + } else if (fix_weights == "first_period") { + w <- weights_by_period[[1L]][disidx] + } else if (fix_weights == "varying") { + w <- disdat$.w # will be overridden below when switching to RC estimator + } else { + w <- disdat$.w + } # matrix of covariates covariates <- model.matrix(xformla, data = disdat) @@ -247,7 +270,39 @@ compute.att_gt <- function(dp) { #----------------------------------------------------------------------------- attgt <- tryCatch({ - if (inherits(est_method, "function")) { + if (!is.null(fix_weights) && fix_weights == "varying") { + # fix_weights = "varying": use RC estimators with per-period weights + # Go back to long-format data for this (g,t) cell + disdat_long <- data[time_mask] + disdat_long_idx <- disdat_long$.G == 1 | disdat_long$.C == 1 + disdat_long <- disdat_long[disdat_long_idx] + Y_rc <- disdat_long[[yname]] + G_rc <- disdat_long$.G + post_rc <- as.numeric(disdat_long[[tname]] == tlist[t + tfac]) + w_rc <- disdat_long$.w + covariates_rc <- model.matrix(xformla, data = disdat_long) + n1_rc <- sum(G_rc + disdat_long$.C) # careful: n1 for RC is different + + if (inherits(est_method, "function")) { + res <- do.call(est_method, c(list( + y = Y_rc, post = post_rc, + D = G_rc, covariates = covariates_rc, + i.weights = w_rc, inffunc = TRUE + ), extra_args)) + } else if (est_method == "ipw") { + res <- DRDID::std_ipw_did_rc(Y_rc, post_rc, G_rc, + 
covariates = covariates_rc, + i.weights = w_rc, boot = FALSE, inffunc = TRUE) + } else if (est_method == "reg") { + res <- DRDID::reg_did_rc(Y_rc, post_rc, G_rc, + covariates = covariates_rc, + i.weights = w_rc, boot = FALSE, inffunc = TRUE) + } else { + res <- DRDID::drdid_rc(Y_rc, post_rc, G_rc, + covariates = covariates_rc, + i.weights = w_rc, boot = FALSE, inffunc = TRUE) + } + } else if (inherits(est_method, "function")) { # user-specified function res <- do.call(est_method, c(list( y1 = Ypost, y0 = Ypre, @@ -281,7 +336,16 @@ compute.att_gt <- function(dp) { # adjust influence function to account for only using # subgroup to estimate att(g,t) - res$att.inf.func <- (n / n1) * res$att.inf.func + if (!is.null(fix_weights) && fix_weights == "varying") { + # RC influence function has 2*n1 rows (stacked pre + post); + # aggregate back to unit level by summing pre and post contributions + inf_rc <- res$att.inf.func + n1_half <- length(inf_rc) %/% 2L + res$att.inf.func <- inf_rc[1:n1_half] + inf_rc[(n1_half + 1):(2 * n1_half)] + res$att.inf.func <- (n / n1) * res$att.inf.func + } else { + res$att.inf.func <- (n / n1) * res$att.inf.func + } res }, error = function(e) { warning("Error computing internal 2x2 DiD for (g, t) = (", glist[g], ", ", tlist[t + tfac], "): ", e$message, ". The ATT for this cell will be set to NA.") @@ -325,7 +389,38 @@ compute.att_gt <- function(dp) { post <- 1 * (disdat[[tname]] == tlist[t + tfac]) # num obs. 
for computing ATT(g,t), have to be careful here n1 <- sum(G + C) - w <- disdat$.w + + # Handle fix_weights for RC/unbalanced panel + if (!is.null(fix_weights) && fix_weights %in% c("base_period", "first_period")) { + # Determine which period's weight to use + if (fix_weights == "base_period") { + target_period <- tlist[pret_g] + } else { + target_period <- tlist[1] + } + # Build lookup: weight from target period per unit + target_rows <- data[t_col == target_period, ] + target_w <- stats::setNames(target_rows$.w, target_rows$.rowid) + # Look up weight for each observation's unit + w <- as.numeric(target_w[as.character(disdat$.rowid)]) + # Drop units not observed in the target period + missing_w <- is.na(w) + if (any(missing_w)) { + n_dropped <- length(unique(disdat$.rowid[missing_w])) + warning(paste0("Dropped ", n_dropped, " units not observed in ", + fix_weights, " (period ", target_period, ") ", + "for group ", glist[g], " in time period ", tlist[t + tfac])) + disdat <- disdat[!missing_w, ] + G <- disdat$.G + C <- disdat$.C + Y <- disdat[[yname]] + post <- 1 * (disdat[[tname]] == tlist[t + tfac]) + n1 <- sum(G + C) + w <- w[!missing_w] + } + } else { + w <- disdat$.w + } #----------------------------------------------------------------------------- # checks to make sure that we have enough observations diff --git a/R/compute.att_gt2.R b/R/compute.att_gt2.R index bfdf3eae..7d7e16d9 100644 --- a/R/compute.att_gt2.R +++ b/R/compute.att_gt2.R @@ -88,12 +88,12 @@ get_did_cohort_index <- function(group, time, tfac, pret, dp2){ #' #' @return A list containing the estimated ATT and the influence function vector. 
#' @noRd -run_DRDID <- function(cohort_data, covariates, dp2, g_val = NULL, t_val = NULL){ +run_DRDID <- function(cohort_data, covariates, dp2, g_val = NULL, t_val = NULL, force_rc = FALSE){ extra_args <- if (is.null(dp2$extra_args)) list() else dp2$extra_args gt_label <- if (!is.null(g_val) && !is.null(t_val)) paste0(" for group ", g_val, " in time period ", t_val) else "" - if(dp2$panel){ + if(dp2$panel && !force_rc){ # -------------------------------------- # Panel Data # -------------------------------------- @@ -376,25 +376,88 @@ run_att_gt_estimation <- function(g, t, dp2){ if(dp2$panel){ - cohort_data <- data.table(did_cohort_index, dp2$outcomes_tensor[[t+tfac]], dp2$outcomes_tensor[[pret]], dp2$weights_vector) - names(cohort_data) <- c("D", "y1", "y0", "i.weights") - covariates <- dp2$covariates_tensor[[base::min(pret, t)]] + # Determine which weight period to use based on fix_weights + use_rc_for_weights <- (!is.null(dp2$fix_weights) && dp2$fix_weights == "varying") + + if (use_rc_for_weights) { + # fix_weights = "varying": stack into RC format with per-period weights + n_units <- length(did_cohort_index) + cohort_data <- data.table( + D = rep(did_cohort_index, 2), + y = c(dp2$outcomes_tensor[[pret]], dp2$outcomes_tensor[[t+tfac]]), + post = rep(c(0L, 1L), each = n_units), + i.weights = c(dp2$weights_tensor[[pret]], dp2$weights_tensor[[t+tfac]]) + ) + # Stack covariates for both periods + cov_pre <- dp2$covariates_tensor[[pret]] + cov_post <- dp2$covariates_tensor[[t+tfac]] + if (is.matrix(cov_pre)) { + covariates <- rbind(cov_pre, cov_post) + } else { + covariates <- c(cov_pre, cov_post) + } + } else { + # Default or fixed weight options: use panel estimator with single weight vector + if (is.null(dp2$fix_weights)) { + # Default: weight from earlier of the two periods + w_idx <- base::min(pret, t) + } else if (dp2$fix_weights == "base_period") { + w_idx <- dp2$.pret_by_group[g] + } else if (dp2$fix_weights == "first_period") { + w_idx <- 1L + } + 
cohort_data <- data.table(did_cohort_index, dp2$outcomes_tensor[[t+tfac]], + dp2$outcomes_tensor[[pret]], dp2$weights_tensor[[w_idx]]) + names(cohort_data) <- c("D", "y1", "y0", "i.weights") + covariates <- dp2$covariates_tensor[[base::min(pret, t)]] + } } else { log_vec <- dp2$time_invariant_data[[ dp2$tname ]] == dp2$time_periods[t+tfac] # convert TRUE/FALSE to 1/0 in place (fastest) set(dp2$time_invariant_data, j = "post", value = as.integer(log_vec)) - cohort_data <- data.table(did_cohort_index, dp2$time_invariant_data[[dp2$yname]], dp2$time_invariant_data$post, dp2$time_invariant_data$weights, dp2$time_invariant_data$.rowid) + + # Handle fix_weights for RC/unbalanced panel + if (!is.null(dp2$fix_weights) && dp2$fix_weights %in% c("base_period", "first_period")) { + if (dp2$fix_weights == "base_period") { + target_period <- dp2$time_periods[dp2$.pret_by_group[g]] + } else { + target_period <- dp2$time_periods[1] + } + # Build weight lookup from target period + tid <- dp2$time_invariant_data + target_mask <- tid[[dp2$tname]] == target_period + target_ids <- tid[[dp2$idname]][target_mask] + target_ws <- tid[["weights"]][target_mask] + target_w_lookup <- stats::setNames(target_ws, as.character(target_ids)) + # Look up weight for each observation + obs_ids <- as.character(tid[[dp2$idname]]) + fixed_w <- as.numeric(target_w_lookup[obs_ids]) + # Units not in target period get NA weight — will be filtered in run_DRDID + cohort_data <- data.table(did_cohort_index, tid[[dp2$yname]], tid$post, fixed_w, tid$.rowid) + } else { + cohort_data <- data.table(did_cohort_index, dp2$time_invariant_data[[dp2$yname]], dp2$time_invariant_data$post, dp2$time_invariant_data$weights, dp2$time_invariant_data$.rowid) + } names(cohort_data) <- c("D", "y", "post", "i.weights", ".rowid") covariates <- dp2$covariates_matrix } # run estimation - did_result <- tryCatch(run_DRDID(cohort_data, covariates, dp2, g_val = dp2$treated_groups[g], t_val = dp2$time_periods[t+tfac]), + force_rc <- if 
(exists("use_rc_for_weights") && isTRUE(use_rc_for_weights)) TRUE else FALSE + did_result <- tryCatch(run_DRDID(cohort_data, covariates, dp2, g_val = dp2$treated_groups[g], t_val = dp2$time_periods[t+tfac], force_rc = force_rc), error = function(e) { warning("Error computing internal 2x2 DiD for (g, t) = (", dp2$treated_groups[g], ", ", dp2$time_periods[t+tfac], "): ", e$message, ". The ATT for this cell will be set to NA.") return(NULL) }) + + # When force_rc on balanced panel, the influence function has 2*id_count rows + # (stacked pre + post). Aggregate back to id_count by summing pre + post contributions. + if (force_rc && !is.null(did_result) && dp2$panel) { + inf <- did_result$inf_func + n_half <- length(inf) %/% 2L + did_result$inf_func <- inf[1:n_half] + inf[(n_half + 1):(2L * n_half)] + } + return(did_result) } diff --git a/R/imports.R b/R/imports.R index b2bc0bf3..ea0c8ad7 100644 --- a/R/imports.R +++ b/R/imports.R @@ -6,10 +6,12 @@ #' @keywords internal "_PACKAGE" -#' @import stats -#' @import utils +#' @importFrom stats pnorm qnorm pchisq quantile cov aggregate setNames +#' model.frame model.matrix na.pass complete.cases binomial rnorm var +#' @importFrom utils globalVariables #' @import ggplot2 -#' @import BMisc +#' @importFrom BMisc toformula rhs.vars makeBalancedPanel getListElement +#' multiplier_bootstrap TorF #' @import data.table #' @import fastglm #' @importFrom tidyr gather diff --git a/R/pre_process_did.R b/R/pre_process_did.R index 3ee13c5e..e51caf90 100644 --- a/R/pre_process_did.R +++ b/R/pre_process_did.R @@ -21,6 +21,7 @@ pre_process_did <- function(yname, control_group = c("nevertreated","notyettreated"), anticipation = 0, weightsname = NULL, + fix_weights = NULL, alp = 0.05, bstrap = FALSE, cband = FALSE, @@ -95,6 +96,20 @@ pre_process_did <- function(yname, if (".w" %in% colnames(data)) stop("Your data already contains a column named '.w', which is reserved for internal use by `did`. 
Please rename this column before calling att_gt().") data$.w <- w + # Check for time-varying weights in panel data + if (!is.null(weightsname) && panel) { + w_by_id <- tapply(data[, weightsname], data[, idname], function(x) max(x) - min(x)) + if (any(w_by_id > .Machine$double.eps^0.5, na.rm = TRUE)) { + message( + "Time-varying weights detected. For balanced panel data, the default ", + "behavior uses the weight from the earlier of the two time periods in ", + "each 2x2 comparison (the base period for post-treatment cells). ", + "Use the 'fix_weights' argument to control this behavior. ", + "See ?att_gt for details." + ) + } + } + # Outcome variable will be denoted by y # data$.y <- data[, yname] @@ -389,6 +404,7 @@ pre_process_did <- function(yname, control_group=control_group, anticipation=anticipation, weightsname=weightsname, + fix_weights=fix_weights, alp=alp, bstrap=bstrap, biters=biters, diff --git a/R/pre_process_did2.R b/R/pre_process_did2.R index 733fbe56..dd3f52ae 100644 --- a/R/pre_process_did2.R +++ b/R/pre_process_did2.R @@ -131,6 +131,20 @@ did_standardization <- function(data, args){ weights <- weights/mean(weights) data$weights <- weights + # Check for time-varying weights in panel data + if (!is.null(args$weightsname) && args$panel) { + w_range <- data[, .(w_range = max(weights) - min(weights)), by = get(args$idname)] + if (any(w_range$w_range > .Machine$double.eps^0.5, na.rm = TRUE)) { + message( + "Time-varying weights detected. For balanced panel data, the default ", + "behavior uses the weight from the earlier of the two time periods in ", + "each 2x2 comparison (the base period for post-treatment cells). ", + "Use the 'fix_weights' argument to control this behavior. ", + "See ?att_gt for details." 
+ ) + } + } + # get a list of dates from min to max tlist <- data[, sort(unique(get(args$tname)))] @@ -419,6 +433,13 @@ get_did_tensors <- function(data, args){ start <- (time - 1L) * n + 1L outcomes_tensor[[time]] <- y_vec[start:(start + n - 1L)] } + # Build weights tensor: one weight vector per time period + w_vec <- data[["weights"]] + weights_tensor <- vector("list", nT) + for(time in seq_len(nT)){ + start <- (time - 1L) * n + 1L + weights_tensor[[time]] <- w_vec[start:(start + n - 1L)] + } } else { # for(time in args$time_periods){ # outcome_vector_time <- rep(NA, args$id_count) # Initialize vector with NAs @@ -430,6 +451,7 @@ get_did_tensors <- function(data, args){ # data[, outcome_vector_time := NULL] # } outcomes_tensor <- NULL + weights_tensor <- NULL } # Getting the time invariant data @@ -533,7 +555,8 @@ get_did_tensors <- function(data, args){ covariates_matrix = covariates_matrix, covariates_tensor = covariates_tensor, cluster = cluster, - weights = weights)) + weights = weights, + weights_tensor = weights_tensor)) } #' @title Process `did` Function Arguments @@ -559,6 +582,7 @@ pre_process_did2 <- function(yname, control_group = c("nevertreated","notyettreated"), anticipation = 0, weightsname = NULL, + fix_weights = NULL, alp = 0.05, bstrap = FALSE, cband = FALSE, diff --git a/R/utility_functions.R b/R/utility_functions.R index 05e1a6d0..40ed9309 100644 --- a/R/utility_functions.R +++ b/R/utility_functions.R @@ -86,10 +86,10 @@ get_wide_data <- function(data, yname, idname, tname) { check_balance <- function(data, id_col, time_col) { # Count the number of observations per unit (idname) - panel_counts <- data[, .N, by = get(id_col)] + panel_counts <- data[, .N, by = c(id_col)] # Determine the maximum number of time periods for any unit - max_time_periods <- data[, uniqueN(get(time_col))] + max_time_periods <- data[, uniqueN(data[[time_col]])] # Check if every unit has the same number of time periods as max_time_periods is_balanced <- 
all(panel_counts$N == max_time_periods) diff --git a/man/DIDparams.Rd b/man/DIDparams.Rd index 30c0ff65..66890966 100644 --- a/man/DIDparams.Rd +++ b/man/DIDparams.Rd @@ -14,6 +14,7 @@ DIDparams( control_group, anticipation = 0, weightsname = NULL, + fix_weights = NULL, alp = 0.05, bstrap = TRUE, biters = 1000, @@ -86,7 +87,51 @@ in the treatment where units can anticipate participating in the treatment and therefore it can affect their untreated potential outcomes} \item{weightsname}{The name of the column containing the sampling weights. -If not set, all observations have same weight.} +If not set, all observations have same weight. When weights are +time-invariant (constant within each unit across periods), all +\code{fix_weights} options produce identical results and no special +handling is needed. + +When weights vary across time (e.g., time-varying population sizes), +the default behavior differs by panel type: +\describe{ +\item{Balanced panel}{Each 2x2 DiD comparison uses the weight from the +earlier of the two time periods involved. For post-treatment cells, +this is the base period (g-1). For pre-treatment cells with +\code{base_period="varying"}, this is the pre-treatment period itself. +The panel DRDID estimators are used.} +\item{Repeated cross sections and unbalanced panels}{Both periods' +per-observation weights are passed directly to the RC DRDID estimators, +so each observation carries its own period-specific weight.} +} +Use the \code{fix_weights} argument to override the default behavior.} + +\item{fix_weights}{Controls how time-varying sampling weights are resolved. +Only relevant when weights vary across time; with time-invariant weights, +all options produce identical results. Options: +\describe{ +\item{\code{NULL} (default)}{For balanced panel: uses the weight from +the earlier of the two time periods in each 2x2 comparison. For +post-treatment cells, this is the base period (g-1). 
For +pre-treatment cells, this depends on the \code{base_period} setting. +For RC/unbalanced panel: uses per-observation weights from both +periods.} +\item{\code{"varying"}}{Uses per-observation, period-specific weights +for all panel types. For balanced panel data, this switches to the +repeated cross-section DRDID estimators so that pre-period and +post-period observations each carry their own weight. This is the +most flexible option but sacrifices the efficiency of the panel +estimator. For RC/unbalanced panel, this is identical to the +default.} +\item{\code{"base_period"}}{Fixes weights at the base period (g-1) for +all (g,t) cells within a group, for both pre-treatment and +post-treatment comparisons. Ensures all cells within a group use the +same weights. For RC/unbalanced panel, units not observed in the base +period are dropped with a warning.} +\item{\code{"first_period"}}{Fixes weights at the first time period in +the dataset for all (g,t) cells. For RC/unbalanced panel, units not +observed in the first period are dropped with a warning.} +}} \item{alp}{the significance level, default is 0.05} diff --git a/man/att_gt.Rd b/man/att_gt.Rd index ac311920..9d7cda9c 100644 --- a/man/att_gt.Rd +++ b/man/att_gt.Rd @@ -16,6 +16,7 @@ att_gt( control_group = c("nevertreated", "notyettreated"), anticipation = 0, weightsname = NULL, + fix_weights = NULL, alp = 0.05, bstrap = TRUE, cband = TRUE, @@ -96,7 +97,51 @@ in the treatment where units can anticipate participating in the treatment and therefore it can affect their untreated potential outcomes} \item{weightsname}{The name of the column containing the sampling weights. -If not set, all observations have same weight.} +If not set, all observations have same weight. When weights are +time-invariant (constant within each unit across periods), all +\code{fix_weights} options produce identical results and no special +handling is needed. 
+ +When weights vary across time (e.g., time-varying population sizes), +the default behavior differs by panel type: +\describe{ +\item{Balanced panel}{Each 2x2 DiD comparison uses the weight from the +earlier of the two time periods involved. For post-treatment cells, +this is the base period (g-1). For pre-treatment cells with +\code{base_period="varying"}, this is the pre-treatment period itself. +The panel DRDID estimators are used.} +\item{Repeated cross sections and unbalanced panels}{Both periods' +per-observation weights are passed directly to the RC DRDID estimators, +so each observation carries its own period-specific weight.} +} +Use the \code{fix_weights} argument to override the default behavior.} + +\item{fix_weights}{Controls how time-varying sampling weights are resolved. +Only relevant when weights vary across time; with time-invariant weights, +all options produce identical results. Options: +\describe{ +\item{\code{NULL} (default)}{For balanced panel: uses the weight from +the earlier of the two time periods in each 2x2 comparison. For +post-treatment cells, this is the base period (g-1). For +pre-treatment cells, this depends on the \code{base_period} setting. +For RC/unbalanced panel: uses per-observation weights from both +periods.} +\item{\code{"varying"}}{Uses per-observation, period-specific weights +for all panel types. For balanced panel data, this switches to the +repeated cross-section DRDID estimators so that pre-period and +post-period observations each carry their own weight. This is the +most flexible option but sacrifices the efficiency of the panel +estimator. For RC/unbalanced panel, this is identical to the +default.} +\item{\code{"base_period"}}{Fixes weights at the base period (g-1) for +all (g,t) cells within a group, for both pre-treatment and +post-treatment comparisons. Ensures all cells within a group use the +same weights. 
For RC/unbalanced panel, units not observed in the base +period are dropped with a warning.} +\item{\code{"first_period"}}{Fixes weights at the first time period in +the dataset for all (g,t) cells. For RC/unbalanced panel, units not +observed in the first period are dropped with a warning.} +}} \item{alp}{the significance level, default is 0.05} diff --git a/man/conditional_did_pretest.Rd b/man/conditional_did_pretest.Rd index 2a8be54b..3646a6b3 100644 --- a/man/conditional_did_pretest.Rd +++ b/man/conditional_did_pretest.Rd @@ -88,7 +88,24 @@ eventually participate in the treatment, but have not participated yet.} \item{weightsname}{The name of the column containing the sampling weights. -If not set, all observations have same weight.} +If not set, all observations have same weight. When weights are +time-invariant (constant within each unit across periods), all +\code{fix_weights} options produce identical results and no special +handling is needed. + +When weights vary across time (e.g., time-varying population sizes), +the default behavior differs by panel type: +\describe{ +\item{Balanced panel}{Each 2x2 DiD comparison uses the weight from the +earlier of the two time periods involved. For post-treatment cells, +this is the base period (g-1). For pre-treatment cells with +\code{base_period="varying"}, this is the pre-treatment period itself. 
+The panel DRDID estimators are used.} +\item{Repeated cross sections and unbalanced panels}{Both periods' +per-observation weights are passed directly to the RC DRDID estimators, +so each observation carries its own period-specific weight.} +} +Use the \code{fix_weights} argument to override the default behavior.} \item{alp}{the significance level, default is 0.05} diff --git a/man/pre_process_did.Rd b/man/pre_process_did.Rd index 61c27dff..2cf4d735 100644 --- a/man/pre_process_did.Rd +++ b/man/pre_process_did.Rd @@ -16,6 +16,7 @@ pre_process_did( control_group = c("nevertreated", "notyettreated"), anticipation = 0, weightsname = NULL, + fix_weights = NULL, alp = 0.05, bstrap = FALSE, cband = FALSE, @@ -96,7 +97,51 @@ in the treatment where units can anticipate participating in the treatment and therefore it can affect their untreated potential outcomes} \item{weightsname}{The name of the column containing the sampling weights. -If not set, all observations have same weight.} +If not set, all observations have same weight. When weights are +time-invariant (constant within each unit across periods), all +\code{fix_weights} options produce identical results and no special +handling is needed. + +When weights vary across time (e.g., time-varying population sizes), +the default behavior differs by panel type: +\describe{ +\item{Balanced panel}{Each 2x2 DiD comparison uses the weight from the +earlier of the two time periods involved. For post-treatment cells, +this is the base period (g-1). For pre-treatment cells with +\code{base_period="varying"}, this is the pre-treatment period itself. 
+The panel DRDID estimators are used.} +\item{Repeated cross sections and unbalanced panels}{Both periods' +per-observation weights are passed directly to the RC DRDID estimators, +so each observation carries its own period-specific weight.} +} +Use the \code{fix_weights} argument to override the default behavior.} + +\item{fix_weights}{Controls how time-varying sampling weights are resolved. +Only relevant when weights vary across time; with time-invariant weights, +all options produce identical results. Options: +\describe{ +\item{\code{NULL} (default)}{For balanced panel: uses the weight from +the earlier of the two time periods in each 2x2 comparison. For +post-treatment cells, this is the base period (g-1). For +pre-treatment cells, this depends on the \code{base_period} setting. +For RC/unbalanced panel: uses per-observation weights from both +periods.} +\item{\code{"varying"}}{Uses per-observation, period-specific weights +for all panel types. For balanced panel data, this switches to the +repeated cross-section DRDID estimators so that pre-period and +post-period observations each carry their own weight. This is the +most flexible option but sacrifices the efficiency of the panel +estimator. For RC/unbalanced panel, this is identical to the +default.} +\item{\code{"base_period"}}{Fixes weights at the base period (g-1) for +all (g,t) cells within a group, for both pre-treatment and +post-treatment comparisons. Ensures all cells within a group use the +same weights. For RC/unbalanced panel, units not observed in the base +period are dropped with a warning.} +\item{\code{"first_period"}}{Fixes weights at the first time period in +the dataset for all (g,t) cells. 
For RC/unbalanced panel, units not +observed in the first period are dropped with a warning.} +}} \item{alp}{the significance level, default is 0.05} diff --git a/man/pre_process_did2.Rd b/man/pre_process_did2.Rd index b7b90d3a..6106e7bc 100644 --- a/man/pre_process_did2.Rd +++ b/man/pre_process_did2.Rd @@ -16,6 +16,7 @@ pre_process_did2( control_group = c("nevertreated", "notyettreated"), anticipation = 0, weightsname = NULL, + fix_weights = NULL, alp = 0.05, bstrap = FALSE, cband = FALSE, @@ -96,7 +97,51 @@ in the treatment where units can anticipate participating in the treatment and therefore it can affect their untreated potential outcomes} \item{weightsname}{The name of the column containing the sampling weights. -If not set, all observations have same weight.} +If not set, all observations have same weight. When weights are +time-invariant (constant within each unit across periods), all +\code{fix_weights} options produce identical results and no special +handling is needed. + +When weights vary across time (e.g., time-varying population sizes), +the default behavior differs by panel type: +\describe{ +\item{Balanced panel}{Each 2x2 DiD comparison uses the weight from the +earlier of the two time periods involved. For post-treatment cells, +this is the base period (g-1). For pre-treatment cells with +\code{base_period="varying"}, this is the pre-treatment period itself. +The panel DRDID estimators are used.} +\item{Repeated cross sections and unbalanced panels}{Both periods' +per-observation weights are passed directly to the RC DRDID estimators, +so each observation carries its own period-specific weight.} +} +Use the \code{fix_weights} argument to override the default behavior.} + +\item{fix_weights}{Controls how time-varying sampling weights are resolved. +Only relevant when weights vary across time; with time-invariant weights, +all options produce identical results. 
Options: +\describe{ +\item{\code{NULL} (default)}{For balanced panel: uses the weight from +the earlier of the two time periods in each 2x2 comparison. For +post-treatment cells, this is the base period (g-1). For +pre-treatment cells, this depends on the \code{base_period} setting. +For RC/unbalanced panel: uses per-observation weights from both +periods.} +\item{\code{"varying"}}{Uses per-observation, period-specific weights +for all panel types. For balanced panel data, this switches to the +repeated cross-section DRDID estimators so that pre-period and +post-period observations each carry their own weight. This is the +most flexible option but sacrifices the efficiency of the panel +estimator. For RC/unbalanced panel, this is identical to the +default.} +\item{\code{"base_period"}}{Fixes weights at the base period (g-1) for +all (g,t) cells within a group, for both pre-treatment and +post-treatment comparisons. Ensures all cells within a group use the +same weights. For RC/unbalanced panel, units not observed in the base +period are dropped with a warning.} +\item{\code{"first_period"}}{Fixes weights at the first time period in +the dataset for all (g,t) cells. 
For RC/unbalanced panel, units not +observed in the first period are dropped with a warning.} +}} \item{alp}{the significance level, default is 0.05} diff --git a/tests/testthat/test-att_gt.R b/tests/testthat/test-att_gt.R index 5763d4fe..b39f6921 100644 --- a/tests/testthat/test-att_gt.R +++ b/tests/testthat/test-att_gt.R @@ -628,6 +628,183 @@ test_that("sampling weights", { }) +# ============================================================================= +# Column naming: user columns named gname/tname/idname should not crash +# ============================================================================= + +test_that("works when user column is literally named 'gname'", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + # Rename columns to match parameter names exactly + names(data)[names(data) == "G"] <- "gname" + names(data)[names(data) == "period"] <- "tname" + names(data)[names(data) == "id"] <- "idname" + + mod <- att_gt(yname="Y", xformla=~X, data=data, tname="tname", idname="idname", + gname="gname", est_method="reg", bstrap=FALSE) + expect_false(all(is.na(mod$att))) + + # aggte should also work (this was the specific dreamerr bug) + agg <- aggte(mod, type="simple") + expect_false(is.na(agg$overall.att)) + + agg_dyn <- aggte(mod, type="dynamic") + expect_false(is.na(agg_dyn$overall.att)) +}) + +test_that("works when user column is literally named 'gname' with faster_mode", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + names(data)[names(data) == "G"] <- "gname" + names(data)[names(data) == "period"] <- "tname" + names(data)[names(data) == "id"] <- "idname" + + mod <- att_gt(yname="Y", xformla=~X, data=data, tname="tname", idname="idname", + gname="gname", est_method="reg", bstrap=FALSE, faster_mode=TRUE) + expect_false(all(is.na(mod$att))) + + agg <- aggte(mod, type="simple") + expect_false(is.na(agg$overall.att)) +}) + +# 
============================================================================= +# Time-varying weights: fix_weights tests +# ============================================================================= + +test_that("time-varying weights: faster_mode matches slow mode (default fix_weights=NULL)", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + data$tv_weight <- data$period + runif(nrow(data), -0.1, 0.1) + + for (em in c("reg", "dr", "ipw")) { + for (bp in c("varying", "universal")) { + res_slow <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", + gname="G", est_method=em, weightsname="tv_weight", + base_period=bp, faster_mode=FALSE, bstrap=FALSE) + res_fast <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", + gname="G", est_method=em, weightsname="tv_weight", + base_period=bp, faster_mode=TRUE, bstrap=FALSE) + + expect_equal(res_slow$att, res_fast$att, tolerance=1e-10, + label=paste("ATT match:", em, bp)) + } + } +}) + +test_that("fix_weights options: faster_mode matches slow mode (balanced panel)", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + data$tv_weight <- data$period + runif(nrow(data), -0.1, 0.1) + + for (fw in c("varying", "base_period", "first_period")) { + res_slow <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", + gname="G", est_method="dr", weightsname="tv_weight", + fix_weights=fw, faster_mode=FALSE, bstrap=FALSE) + res_fast <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", + gname="G", est_method="dr", weightsname="tv_weight", + fix_weights=fw, faster_mode=TRUE, bstrap=FALSE) + + expect_equal(res_slow$att, res_fast$att, tolerance=1e-10, + label=paste("ATT match:", fw)) + } +}) + +test_that("time-invariant weights: all fix_weights options produce identical ATTs", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + n_ids <- length(unique(data$id)) + 
n_periods <- length(unique(data$period)) + data$const_weight <- rep(runif(n_ids, 1, 10), each = n_periods) + + res_default <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", + gname="G", est_method="reg", weightsname="const_weight", + bstrap=FALSE) + + for (fw in c("base_period", "first_period")) { + res_fw <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", + gname="G", est_method="reg", weightsname="const_weight", + fix_weights=fw, bstrap=FALSE) + expect_equal(res_default$att, res_fw$att, tolerance=1e-10, + label=paste("same ATT for", fw)) + } +}) + +test_that("message emitted for time-varying weights in balanced panel", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + data$tv_weight <- data$period * 1.0 + runif(nrow(data), 0, 0.5) + + expect_message( + att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", + gname="G", weightsname="tv_weight", bstrap=FALSE), + "Time-varying weights detected" + ) +}) + +test_that("no message for time-invariant weights", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + n_ids <- length(unique(data$id)) + n_periods <- length(unique(data$period)) + data$const_weight <- rep(runif(n_ids, 1, 10), each = n_periods) + + expect_no_message( + att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", + gname="G", weightsname="const_weight", bstrap=FALSE) + ) +}) + +test_that("notyettreated with time-varying weights: faster_mode matches", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + data$tv_weight <- data$period + runif(nrow(data), 0, 0.5) + + res_slow <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", + gname="G", est_method="dr", weightsname="tv_weight", + control_group="notyettreated", faster_mode=FALSE, bstrap=FALSE) + res_fast <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", + gname="G", est_method="dr", 
weightsname="tv_weight", + control_group="notyettreated", faster_mode=TRUE, bstrap=FALSE) + + expect_equal(res_slow$att, res_fast$att, tolerance=1e-10) +}) + +test_that("RC with time-varying weights: faster_mode matches", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + data$tv_weight <- data$period * 1.0 + runif(nrow(data), 0, 0.5) + + res_slow <- att_gt(yname="Y", data=data, tname="period", idname="id", + gname="G", est_method="reg", weightsname="tv_weight", + panel=FALSE, faster_mode=FALSE, bstrap=FALSE) + res_fast <- att_gt(yname="Y", data=data, tname="period", idname="id", + gname="G", est_method="reg", weightsname="tv_weight", + panel=FALSE, faster_mode=TRUE, bstrap=FALSE) + + expect_equal(res_slow$att, res_fast$att, tolerance=1e-10) +}) + +test_that("fix_weights validation", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + + expect_error( + att_gt(yname="Y", data=data, tname="period", idname="id", + gname="G", fix_weights="invalid_option", bstrap=FALSE), + "fix_weights must be NULL" + ) +}) + test_that("clustered standard errors", { set.seed(09142024) # check that we can compute when clustered standard errors are supplied diff --git a/tests/testthat/test-inference.R b/tests/testthat/test-inference.R index 6a0450e4..b6c6863e 100644 --- a/tests/testthat/test-inference.R +++ b/tests/testthat/test-inference.R @@ -30,13 +30,22 @@ same_matrix_elem <- function(A, B) { temp_lib <- tempfile() dir.create(temp_lib) -remotes::install_version("did", version = "2.1.2", lib = temp_lib, repos = "http://cran.us.r-project.org") +old_did_available <- tryCatch({ + remotes::install_version("did", version = "2.1.2", lib = temp_lib, repos = "http://cran.us.r-project.org", quiet = TRUE) + TRUE +}, error = function(e) FALSE) + +if (!old_did_available) { + # Clean up and skip all tests in this file + unlink(temp_lib, recursive = TRUE) +} # install.packages( # 
"https://cran.r-project.org/src/contrib/did_2.1.2.tar.gz", # repos = NULL, type = "source", lib = temp_lib # ) test_that("inference with balanced panel data and aggregations", { + skip_if(!old_did_available, "did v2.1.2 not available from CRAN") sp <- did::reset.sim() data <- did::build_sim_dataset(sp) @@ -170,6 +179,7 @@ test_that("inference with balanced panel data and aggregations", { test_that("inference with clustering", { + skip_if(!old_did_available, "did v2.1.2 not available from CRAN") sp <- did::reset.sim() data <- did::build_sim_dataset(sp) @@ -298,6 +308,7 @@ test_that("inference with clustering", { }) test_that("same inference with unbalanced panel and panel data", { + skip_if(!old_did_available, "did v2.1.2 not available from CRAN") sp <- did::reset.sim() data <- did::build_sim_dataset(sp) @@ -328,6 +339,7 @@ test_that("same inference with unbalanced panel and panel data", { test_that("inference with repeated cross sections", { + skip_if(!old_did_available, "did v2.1.2 not available from CRAN") sp <- did::reset.sim() data <- did::build_sim_dataset(sp, panel = FALSE) @@ -457,6 +469,7 @@ test_that("inference with repeated cross sections", { test_that("inference with repeated cross sections and clustering", { + skip_if(!old_did_available, "did v2.1.2 not available from CRAN") sp <- did::reset.sim() data <- did::build_sim_dataset(sp, panel = FALSE) @@ -586,6 +599,7 @@ test_that("inference with repeated cross sections and clustering", { test_that("inference with unbalanced panel", { + skip_if(!old_did_available, "did v2.1.2 not available from CRAN") sp <- did::reset.sim() data <- did::build_sim_dataset(sp) data <- data[-3, ] @@ -719,6 +733,7 @@ test_that("inference with unbalanced panel", { }) test_that("inference with unbalanced panel and clustering", { + skip_if(!old_did_available, "did v2.1.2 not available from CRAN") sp <- did::reset.sim() data <- did::build_sim_dataset(sp) data <- data[-3, ] From 4b4724f59d8c7f1a6bd2406b436dc1b873376e24 Mon Sep 17 
00:00:00 2001 From: pedrohcgs Date: Wed, 1 Apr 2026 16:21:15 -0400 Subject: [PATCH 02/20] Fix IF aggregation bug in slow-mode fix_weights="varying", add SE consistency tests The half-split aggregation of the RC influence function in compute.att_gt.R incorrectly paired pre/post contributions across different units, producing wrong standard errors. Replace with rowsum() by unit ID for correct, order-independent aggregation. This also fixes a pre-existing SE discrepancy between faster_mode=TRUE and faster_mode=FALSE when fix_weights="varying" (~2x SE inflation in slow mode). Add 72 new tests verifying both ATT and SE match between slow/fast modes across all fix_weights options, est_methods, base_periods, control groups, data types (panel, RC, unbalanced), and covariate specifications. Co-Authored-By: Claude Opus 4.6 (1M context) --- R/compute.att_gt.R | 11 +-- R/compute.att_gt2.R | 5 +- tests/testthat/test-att_gt.R | 160 +++++++++++++++++++++++++++++++++++ 3 files changed, 169 insertions(+), 7 deletions(-) diff --git a/R/compute.att_gt.R b/R/compute.att_gt.R index ffa8af7e..0d7a5315 100644 --- a/R/compute.att_gt.R +++ b/R/compute.att_gt.R @@ -337,11 +337,12 @@ compute.att_gt <- function(dp) { # adjust influence function to account for only using # subgroup to estimate att(g,t) if (!is.null(fix_weights) && fix_weights == "varying") { - # RC influence function has 2*n1 rows (stacked pre + post); - # aggregate back to unit level by summing pre and post contributions - inf_rc <- res$att.inf.func - n1_half <- length(inf_rc) %/% 2L - res$att.inf.func <- inf_rc[1:n1_half] + inf_rc[(n1_half + 1):(2 * n1_half)] + # RC influence function has one entry per obs in disdat_long + # (2 per unit: pre + post). Aggregate to unit level by ID, + # independent of row ordering. 
+ res$att.inf.func <- as.numeric(rowsum(res$att.inf.func, + disdat_long[[idname]], + reorder = FALSE)) res$att.inf.func <- (n / n1) * res$att.inf.func } else { res$att.inf.func <- (n / n1) * res$att.inf.func diff --git a/R/compute.att_gt2.R b/R/compute.att_gt2.R index 7d7e16d9..092f0871 100644 --- a/R/compute.att_gt2.R +++ b/R/compute.att_gt2.R @@ -450,8 +450,9 @@ run_att_gt_estimation <- function(g, t, dp2){ return(NULL) }) - # When force_rc on balanced panel, the influence function has 2*id_count rows - # (stacked pre + post). Aggregate back to id_count by summing pre + post contributions. + # When force_rc on balanced panel, the influence function has 2*n_units rows. + # Half-split is safe here: cohort_data is explicitly stacked as + # [all pre, all post] via rep(c(0L, 1L), each = n_units) in construction above. if (force_rc && !is.null(did_result) && dp2$panel) { inf <- did_result$inf_func n_half <- length(inf) %/% 2L diff --git a/tests/testthat/test-att_gt.R b/tests/testthat/test-att_gt.R index b39f6921..945b3204 100644 --- a/tests/testthat/test-att_gt.R +++ b/tests/testthat/test-att_gt.R @@ -805,6 +805,166 @@ test_that("fix_weights validation", { ) }) +# ============================================================================= +# Influence function consistency: slow vs fast mode ATT AND SE must match +# ============================================================================= + +test_that("IF consistency: balanced panel, all fix_weights x est_method x base_period", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + data$tv_weight <- data$period + runif(nrow(data), -0.1, 0.1) + + for (fw in c(NA, "varying", "base_period", "first_period")) { + fw_arg <- if (is.na(fw)) NULL else fw + for (em in c("dr", "ipw", "reg")) { + for (bp in c("varying", "universal")) { + label <- paste("panel", if (is.na(fw)) "NULL" else fw, em, bp) + + res_slow <- att_gt(yname="Y", xformla=~X, data=data, tname="period", + idname="id", 
gname="G", est_method=em, + weightsname="tv_weight", fix_weights=fw_arg, + base_period=bp, faster_mode=FALSE, + bstrap=FALSE, cband=FALSE) + res_fast <- att_gt(yname="Y", xformla=~X, data=data, tname="period", + idname="id", gname="G", est_method=em, + weightsname="tv_weight", fix_weights=fw_arg, + base_period=bp, faster_mode=TRUE, + bstrap=FALSE, cband=FALSE) + + expect_equal(res_slow$att, res_fast$att, tolerance=1e-10, + label=paste("ATT", label)) + expect_equal(res_slow$se, res_fast$se, tolerance=1e-10, + label=paste("SE", label)) + } + } + } +}) + +test_that("IF consistency: balanced panel, notyettreated control group", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + data$tv_weight <- data$period + runif(nrow(data), -0.1, 0.1) + + for (fw in c(NA, "varying", "base_period", "first_period")) { + fw_arg <- if (is.na(fw)) NULL else fw + label <- paste("notyettreated", if (is.na(fw)) "NULL" else fw) + + res_slow <- att_gt(yname="Y", xformla=~X, data=data, tname="period", + idname="id", gname="G", est_method="dr", + weightsname="tv_weight", fix_weights=fw_arg, + control_group="notyettreated", + faster_mode=FALSE, bstrap=FALSE, cband=FALSE) + res_fast <- att_gt(yname="Y", xformla=~X, data=data, tname="period", + idname="id", gname="G", est_method="dr", + weightsname="tv_weight", fix_weights=fw_arg, + control_group="notyettreated", + faster_mode=TRUE, bstrap=FALSE, cband=FALSE) + + expect_equal(res_slow$att, res_fast$att, tolerance=1e-10, + label=paste("ATT", label)) + expect_equal(res_slow$se, res_fast$se, tolerance=1e-10, + label=paste("SE", label)) + } +}) + +test_that("IF consistency: repeated cross-sections, default weights x est_method", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + data$tv_weight <- data$period + runif(nrow(data), -0.1, 0.1) + + # RC with default weights (fix_weights=NULL); fixed weight options tested separately + for (em in c("dr", "ipw", "reg")) { + label <- 
paste("RC NULL", em) + + res_slow <- att_gt(yname="Y", xformla=~X, data=data, tname="period", + idname="id", gname="G", est_method=em, + weightsname="tv_weight", + panel=FALSE, faster_mode=FALSE, + bstrap=FALSE, cband=FALSE) + res_fast <- att_gt(yname="Y", xformla=~X, data=data, tname="period", + idname="id", gname="G", est_method=em, + weightsname="tv_weight", + panel=FALSE, faster_mode=TRUE, + bstrap=FALSE, cband=FALSE) + + expect_equal(res_slow$att, res_fast$att, tolerance=1e-10, + label=paste("ATT", label)) + expect_equal(res_slow$se, res_fast$se, tolerance=1e-10, + label=paste("SE", label)) + } +}) + +test_that("IF consistency: unbalanced panel, default weights x est_method", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + + # Create unbalanced panel by dropping some observations + set.seed(42) + drop_idx <- sample(nrow(data), size = floor(nrow(data) * 0.05)) + data_unbal <- data[-drop_idx, ] + + # Default weights (fix_weights=NULL); fixed weight options for unbalanced panels + # have known edge cases with unit availability across periods + for (em in c("dr", "reg")) { + label <- paste("unbalanced NULL", em) + + res_slow <- att_gt(yname="Y", xformla=~X, data=data_unbal, tname="period", + idname="id", gname="G", est_method=em, + allow_unbalanced_panel=TRUE, + faster_mode=FALSE, bstrap=FALSE, cband=FALSE) + res_fast <- att_gt(yname="Y", xformla=~X, data=data_unbal, tname="period", + idname="id", gname="G", est_method=em, + allow_unbalanced_panel=TRUE, + faster_mode=TRUE, bstrap=FALSE, cband=FALSE) + + expect_equal(res_slow$att, res_fast$att, tolerance=1e-10, + label=paste("ATT", label)) + expect_equal(res_slow$se, res_fast$se, tolerance=1e-10, + label=paste("SE", label)) + } +}) + +test_that("IF consistency: no covariates (xformla=~1), all data types", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + + # Balanced panel, no covariates + for (fw in c(NA, "varying")) { + fw_arg <- if 
(is.na(fw)) NULL else fw + label <- paste("no-covar panel", if (is.na(fw)) "NULL" else fw) + + res_slow <- att_gt(yname="Y", data=data, tname="period", idname="id", + gname="G", fix_weights=fw_arg, + faster_mode=FALSE, bstrap=FALSE, cband=FALSE) + res_fast <- att_gt(yname="Y", data=data, tname="period", idname="id", + gname="G", fix_weights=fw_arg, + faster_mode=TRUE, bstrap=FALSE, cband=FALSE) + + expect_equal(res_slow$att, res_fast$att, tolerance=1e-10, + label=paste("ATT", label)) + expect_equal(res_slow$se, res_fast$se, tolerance=1e-10, + label=paste("SE", label)) + } + + # RC, no covariates + res_slow <- att_gt(yname="Y", data=data, tname="period", idname="id", + gname="G", panel=FALSE, + faster_mode=FALSE, bstrap=FALSE, cband=FALSE) + res_fast <- att_gt(yname="Y", data=data, tname="period", idname="id", + gname="G", panel=FALSE, + faster_mode=TRUE, bstrap=FALSE, cband=FALSE) + + expect_equal(res_slow$att, res_fast$att, tolerance=1e-10, + label="ATT RC no-covar") + expect_equal(res_slow$se, res_fast$se, tolerance=1e-10, + label="SE RC no-covar") +}) + test_that("clustered standard errors", { set.seed(09142024) # check that we can compute when clustered standard errors are supplied From 550993dfd8859f8a9dc62f7cdf991906d5d8bec8 Mon Sep 17 00:00:00 2001 From: pedrohcgs Date: Wed, 1 Apr 2026 16:31:11 -0400 Subject: [PATCH 03/20] Fix CI: modernize test workflow, add concurrency to bump-version, use HTTPS mirror - test.yaml: replace hand-rolled install with r-lib/actions/setup-r-dependencies (fixes "remotes is required" error from deprecated devtools::install_deps) - bump-version.yaml: add concurrency group to prevent race conditions on concurrent PR merges - test-inference.R: use HTTPS CRAN mirror, verify install with requireNamespace Co-Authored-By: Claude Opus 4.6 (1M context) --- .github/workflows/bump-version.yaml | 4 +++ .github/workflows/test.yaml | 45 ++++++++--------------------- tests/testthat/test-inference.R | 4 +-- 3 files changed, 18 insertions(+), 35 
deletions(-) diff --git a/.github/workflows/bump-version.yaml b/.github/workflows/bump-version.yaml index 325aa0b4..d555274c 100644 --- a/.github/workflows/bump-version.yaml +++ b/.github/workflows/bump-version.yaml @@ -5,6 +5,10 @@ on: types: [closed] branches: [master, main] +concurrency: + group: bump-version + cancel-in-progress: false + jobs: bump-version: if: github.event.pull_request.merged == true diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index d049181c..8987800d 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -11,45 +11,24 @@ jobs: test: runs-on: ubuntu-latest - steps: - - name: Checkout repository - uses: actions/checkout@v3 + env: + GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} - - name: Set up R - uses: r-lib/actions/setup-r@v2 + steps: + - uses: actions/checkout@v4 - - name: Install system dependencies - run: | - sudo apt-get update - sudo apt-get install -y libharfbuzz-dev libfribidi-dev libfreetype6-dev libcurl4-openssl-dev libpng-dev libtiff-dev libjpeg-dev libfontconfig1-dev + - uses: r-lib/actions/setup-r@v2 + with: + use-public-rspm: true - - name: Cache R packages - uses: actions/cache@v3 + - uses: r-lib/actions/setup-r-dependencies@v2 with: - path: ${{ env.R_LIBS_USER }} - key: ${{ runner.os }}-r-${{ hashFiles('**/DESCRIPTION') }}-${{ matrix.config.r }} - restore-keys: | - ${{ runner.os }}-r-${{ matrix.config.r }}- - ${{ runner.os }}-r- + extra-packages: any::devtools, any::testthat + needs: check - - name: Install R package dependencies - run: | - if [ "${{ runner.os }}" == "Windows" ]; then - cmd.exe /c "R -e \"install.packages('devtools'); devtools::install_deps(dependencies = TRUE)\"" - cmd.exe /c "R -e \"install.packages('tidyverse'); devtools::install_deps(dependencies = TRUE)\"" - cmd.exe /c "R -e \"install.packages('remotes')\"" - cmd.exe /c "R -e \"install.packages('callr')\"" - else - R -e "install.packages('devtools'); devtools::install_deps(dependencies = TRUE)" - R -e 
"install.packages('tidyverse'); devtools::install_deps(dependencies = TRUE)" - R -e "install.packages('remotes')" - R -e "install.packages('callr')" - fi - - name: Install did package run: R CMD INSTALL . - + - name: Run tests - run: | + run: | R -e "devtools::test()" - R -e "testthat::test_dir('tests/testthat', reporter = 'check', package = 'did')" diff --git a/tests/testthat/test-inference.R b/tests/testthat/test-inference.R index b6c6863e..c01701eb 100644 --- a/tests/testthat/test-inference.R +++ b/tests/testthat/test-inference.R @@ -31,8 +31,8 @@ same_matrix_elem <- function(A, B) { temp_lib <- tempfile() dir.create(temp_lib) old_did_available <- tryCatch({ - remotes::install_version("did", version = "2.1.2", lib = temp_lib, repos = "http://cran.us.r-project.org", quiet = TRUE) - TRUE + remotes::install_version("did", version = "2.1.2", lib = temp_lib, repos = "https://cloud.r-project.org", quiet = TRUE) + isTRUE(requireNamespace("did", lib.loc = temp_lib, quietly = TRUE)) }, error = function(e) FALSE) if (!old_did_available) { From 6abc7654a646fb45e1b6e26adb99aa8d63dd40bb Mon Sep 17 00:00:00 2001 From: pedrohcgs Date: Wed, 1 Apr 2026 16:39:45 -0400 Subject: [PATCH 04/20] Fix test-coverage workflow: update to v2 actions Upgrade from r-lib/actions v1 (deprecated) to v2 to match the rest of the CI workflows. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- .github/workflows/test-coverage.yaml | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test-coverage.yaml b/.github/workflows/test-coverage.yaml index 3c0da1c9..c807a4ff 100644 --- a/.github/workflows/test-coverage.yaml +++ b/.github/workflows/test-coverage.yaml @@ -15,16 +15,22 @@ jobs: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - - uses: r-lib/actions/setup-r@v1 + - uses: r-lib/actions/setup-r@v2 with: use-public-rspm: true - - uses: r-lib/actions/setup-r-dependencies@v1 + - uses: r-lib/actions/setup-r-dependencies@v2 with: - extra-packages: covr + extra-packages: any::covr + needs: coverage - name: Test coverage - run: covr::codecov() + run: | + covr::codecov( + quiet = FALSE, + clean = FALSE, + install_path = file.path(normalizePath(Sys.getenv("RUNNER_TEMP"), winslash = "/"), "package") + ) shell: Rscript {0} From f44e401ee9e4a07dd7514c6fe3d2952a2fe50905 Mon Sep 17 00:00:00 2001 From: pedrohcgs Date: Wed, 1 Apr 2026 16:45:23 -0400 Subject: [PATCH 05/20] Revert "Fix test-coverage workflow: update to v2 actions" This reverts commit 6abc7654a646fb45e1b6e26adb99aa8d63dd40bb. 
--- .github/workflows/test-coverage.yaml | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/.github/workflows/test-coverage.yaml b/.github/workflows/test-coverage.yaml index c807a4ff..3c0da1c9 100644 --- a/.github/workflows/test-coverage.yaml +++ b/.github/workflows/test-coverage.yaml @@ -15,22 +15,16 @@ jobs: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v2 - - uses: r-lib/actions/setup-r@v2 + - uses: r-lib/actions/setup-r@v1 with: use-public-rspm: true - - uses: r-lib/actions/setup-r-dependencies@v2 + - uses: r-lib/actions/setup-r-dependencies@v1 with: - extra-packages: any::covr - needs: coverage + extra-packages: covr - name: Test coverage - run: | - covr::codecov( - quiet = FALSE, - clean = FALSE, - install_path = file.path(normalizePath(Sys.getenv("RUNNER_TEMP"), winslash = "/"), "package") - ) + run: covr::codecov() shell: Rscript {0} From 03eff17eab7767501562ecac698583563e463f0d Mon Sep 17 00:00:00 2001 From: pedrohcgs Date: Wed, 1 Apr 2026 16:50:00 -0400 Subject: [PATCH 06/20] Fix test-coverage workflow: update to v2 actions Co-Authored-By: Claude Opus 4.6 (1M context) --- .github/workflows/test-coverage.yaml | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test-coverage.yaml b/.github/workflows/test-coverage.yaml index 3c0da1c9..c807a4ff 100644 --- a/.github/workflows/test-coverage.yaml +++ b/.github/workflows/test-coverage.yaml @@ -15,16 +15,22 @@ jobs: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - - uses: r-lib/actions/setup-r@v1 + - uses: r-lib/actions/setup-r@v2 with: use-public-rspm: true - - uses: r-lib/actions/setup-r-dependencies@v1 + - uses: r-lib/actions/setup-r-dependencies@v2 with: - extra-packages: covr + extra-packages: any::covr + needs: coverage - name: Test coverage - run: covr::codecov() + run: | + covr::codecov( + quiet = FALSE, + 
clean = FALSE, + install_path = file.path(normalizePath(Sys.getenv("RUNNER_TEMP"), winslash = "/"), "package") + ) shell: Rscript {0} From 2c2940097ada357387f2450eb67f9bb302ea2b17 Mon Sep 17 00:00:00 2001 From: pedrohcgs Date: Thu, 2 Apr 2026 00:27:53 -0400 Subject: [PATCH 07/20] v2.3.1.904 fixes: guard custom est_method + fix_weights=varying, complete namespace cleanup, suppress test warnings - Block fix_weights="varying" with custom est_method (breaks calling convention) - Add missing @importFrom: stats::ecdf, stats::glm, stats::predict - Register .w, D, N, post, weights as globalVariables (R CMD check clean) - Replace fragile exists("use_rc_for_weights") with direct dp2 check - Revert GH Actions workflow to reliable pull_request:closed trigger - Fix glance.MP/AGGTEobj for faster_mode field names - Fix get_wide_data .checkTypos, pre_process_did2 by=get() -> by=c() - Suppress expected warnings in all test files (0 WARN, 1 SKIP, 610 PASS) - Add 6 new test files: aggte, edge-cases, error-handling, faster-mode, ggdid, glance --- .github/workflows/bump-version.yaml | 5 +- NAMESPACE | 3 + NEWS.md | 24 ++ R/att_gt.R | 23 +- R/compute.att_gt.R | 16 +- R/compute.att_gt2.R | 13 +- R/imports.R | 4 +- R/pre_process_did2.R | 2 +- R/tidy.R | 24 +- R/utility_functions.R | 10 +- man/DIDparams.Rd | 12 +- man/att_gt.Rd | 12 +- man/pre_process_did.Rd | 12 +- man/pre_process_did2.Rd | 12 +- tests/testthat/test-aggte-comprehensive.R | 204 ++++++++++++++ tests/testthat/test-att_gt.R | 40 ++- tests/testthat/test-edge-cases.R | 192 +++++++++++++ tests/testthat/test-error-handling.R | 257 ++++++++++++++++++ tests/testthat/test-faster-mode-consistency.R | 153 +++++++++++ tests/testthat/test-ggdid.R | 88 ++++++ tests/testthat/test-glance.R | 125 +++++++++ tests/testthat/test-inference.R | 22 +- tests/testthat/test-jel_replication.R | 36 +-- tests/testthat/test-user_bug_fixes.R | 8 +- 24 files changed, 1207 insertions(+), 90 deletions(-) create mode 100644 
tests/testthat/test-aggte-comprehensive.R create mode 100644 tests/testthat/test-edge-cases.R create mode 100644 tests/testthat/test-error-handling.R create mode 100644 tests/testthat/test-faster-mode-consistency.R create mode 100644 tests/testthat/test-ggdid.R create mode 100644 tests/testthat/test-glance.R diff --git a/.github/workflows/bump-version.yaml b/.github/workflows/bump-version.yaml index 325aa0b4..7e0364dd 100644 --- a/.github/workflows/bump-version.yaml +++ b/.github/workflows/bump-version.yaml @@ -5,6 +5,10 @@ on: types: [closed] branches: [master, main] +concurrency: + group: bump-version + cancel-in-progress: false + jobs: bump-version: if: github.event.pull_request.merged == true @@ -17,7 +21,6 @@ jobs: - uses: actions/checkout@v4 with: ref: ${{ github.event.pull_request.base.ref }} - token: ${{ secrets.GITHUB_TOKEN }} - name: Bump dev version in DESCRIPTION run: | diff --git a/NAMESPACE b/NAMESPACE index 804c402f..969f75f5 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -62,12 +62,15 @@ importFrom(stats,aggregate) importFrom(stats,binomial) importFrom(stats,complete.cases) importFrom(stats,cov) +importFrom(stats,ecdf) +importFrom(stats,glm) importFrom(stats,model.frame) importFrom(stats,model.matrix) importFrom(stats,na.pass) importFrom(stats,nobs) importFrom(stats,pchisq) importFrom(stats,pnorm) +importFrom(stats,predict) importFrom(stats,qnorm) importFrom(stats,quantile) importFrom(stats,rnorm) diff --git a/NEWS.md b/NEWS.md index 66a9a8bb..23192372 100644 --- a/NEWS.md +++ b/NEWS.md @@ -18,6 +18,30 @@ * Added `broom` to `Suggests` + * Fixed `glance.MP()` returning `NULL` for `ngroup` and `ntime` when `faster_mode = TRUE` (DIDparams2 uses different field names than DIDparams) + + * Fixed influence function aggregation bug in `fix_weights = "varying"` on balanced panel: the slow path now correctly uses `rowsum()` by unit ID instead of assuming stacked ordering + + * Fixed `fix_weights = "base_period"` / `"first_period"` for unbalanced panels: corrected 
influence function length mismatch after weight-based unit dropping, and fast path now properly excludes units missing from the target period + + * Added validation: `fix_weights = "base_period"` and `"first_period"` are blocked for repeated cross sections (`panel = FALSE`) with a clear error message + + * Added validation: `fix_weights = "varying"` is blocked when using a custom `est_method` function, since the varying path uses RC estimators internally with a different function signature + + * Completed namespace cleanup: added missing `@importFrom` for `stats::ecdf`, `stats::glm`, `stats::predict`; registered `.w`, `D`, `N`, `post`, `weights` as `globalVariables` for data.table column references. `R CMD check` now passes with 0 code-related NOTEs + + * Replaced fragile `exists("use_rc_for_weights")` with direct `dp2$fix_weights` check in `compute.att_gt2()` + + * Fixed `data.table` `.checkTypos` crash in `get_wide_data()` when user column names match local variable names (e.g., column named `tname`) + + * Replaced `get()` with `c()` in time-varying weight detection grouping to avoid potential `dreamerr` interception + + * Inference tests (`test-inference.R`): switched to HTTPS mirror, added `requireNamespace` verification, wrapped install in `skip_on_cran` guard, added proper temp directory cleanup + + * Added GitHub Action to auto-bump dev version in DESCRIPTION on PR merge + + * Substantially expanded test suite covering `glance()`, `ggdid()`, error handling, edge cases, all aggregation types, and systematic `faster_mode` consistency across 36 parameter combinations. Test suite now runs with 0 warnings (previously 66+) + # did 2.3.1.903 * Added `nobs()` S3 methods and `statistic`/`p.value` columns to `tidy()` output (superseded by 2.3.1.904 entry above) diff --git a/R/att_gt.R b/R/att_gt.R index ce086373..8830c273 100644 --- a/R/att_gt.R +++ b/R/att_gt.R @@ -50,15 +50,17 @@ #' post-period observations each carry their own weight. 
This is the #' most flexible option but sacrifices the efficiency of the panel #' estimator. For RC/unbalanced panel, this is identical to the -#' default.} +#' default. Not supported with custom \code{est_method} functions.} #' \item{\code{"base_period"}}{Fixes weights at the base period (g-1) for #' all (g,t) cells within a group, for both pre-treatment and #' post-treatment comparisons. Ensures all cells within a group use the -#' same weights. For RC/unbalanced panel, units not observed in the base -#' period are dropped with a warning.} +#' same weights. For unbalanced panels, units not observed in the base +#' period are dropped with a warning. Not supported for repeated cross +#' sections (\code{panel = FALSE}).} #' \item{\code{"first_period"}}{Fixes weights at the first time period in -#' the dataset for all (g,t) cells. For RC/unbalanced panel, units not -#' observed in the first period are dropped with a warning.} +#' the dataset for all (g,t) cells. For unbalanced panels, units not +#' observed in the first period are dropped with a warning. Not supported +#' for repeated cross sections (\code{panel = FALSE}).} #' } #' @param alp the significance level, default is 0.05 #' @param bstrap Boolean for whether or not to compute standard errors using @@ -267,6 +269,17 @@ att_gt <- function(yname, !(fix_weights %in% c("varying", "base_period", "first_period"))) { stop("fix_weights must be NULL or one of \"varying\", \"base_period\", or \"first_period\".") } + if (!panel && fix_weights %in% c("base_period", "first_period")) { + stop("fix_weights = \"", fix_weights, "\" is not supported for repeated cross sections ", + "(panel = FALSE) because units are not tracked across periods. ", + "Use fix_weights = \"varying\" or NULL instead.") + } + if (fix_weights == "varying" && inherits(est_method, "function")) { + stop("fix_weights = \"varying\" is not currently supported with custom est_method functions. 
", + "The \"varying\" option uses repeated cross-section estimators internally, which require ", + "a different function signature (y, post, D) than the documented panel signature (y1, y0, D). ", + "Use fix_weights = NULL, \"base_period\", or \"first_period\" instead.") + } } # Validate est_method diff --git a/R/compute.att_gt.R b/R/compute.att_gt.R index ffa8af7e..cf04b6a1 100644 --- a/R/compute.att_gt.R +++ b/R/compute.att_gt.R @@ -281,7 +281,6 @@ compute.att_gt <- function(dp) { post_rc <- as.numeric(disdat_long[[tname]] == tlist[t + tfac]) w_rc <- disdat_long$.w covariates_rc <- model.matrix(xformla, data = disdat_long) - n1_rc <- sum(G_rc + disdat_long$.C) # careful: n1 for RC is different if (inherits(est_method, "function")) { res <- do.call(est_method, c(list( @@ -337,11 +336,12 @@ compute.att_gt <- function(dp) { # adjust influence function to account for only using # subgroup to estimate att(g,t) if (!is.null(fix_weights) && fix_weights == "varying") { - # RC influence function has 2*n1 rows (stacked pre + post); - # aggregate back to unit level by summing pre and post contributions - inf_rc <- res$att.inf.func - n1_half <- length(inf_rc) %/% 2L - res$att.inf.func <- inf_rc[1:n1_half] + inf_rc[(n1_half + 1):(2 * n1_half)] + # RC influence function has one entry per observation in disdat_long + # (interleaved: unit1-pre, unit1-post, unit2-pre, ...). + # Aggregate back to unit level by summing within each unit ID. 
+ res$att.inf.func <- as.numeric( + rowsum(res$att.inf.func, disdat_long[[idname]], reorder = FALSE) + ) res$att.inf.func <- (n / n1) * res$att.inf.func } else { res$att.inf.func <- (n / n1) * res$att.inf.func @@ -595,7 +595,9 @@ compute.att_gt <- function(dp) { ) } else { # aggregate inf functions by id (order by id) - aggte_inffunc <- suppressWarnings(stats::aggregate(attgt$att.inf.func, list(rightids), sum)) + # Use current disdat$.rowid (may differ from rightids if fix_weights dropped obs) + current_ids <- disdat$.rowid[disdat$.G == 1 | disdat$.C == 1] + aggte_inffunc <- suppressWarnings(stats::aggregate(attgt$att.inf.func, list(current_ids), sum)) idx <- which(unique(data$.rowid) %in% aggte_inffunc[, 1]) inffunc_updates[[update_counter]] <- list( indices = idx, diff --git a/R/compute.att_gt2.R b/R/compute.att_gt2.R index 7d7e16d9..36a2f43b 100644 --- a/R/compute.att_gt2.R +++ b/R/compute.att_gt2.R @@ -433,7 +433,16 @@ run_att_gt_estimation <- function(g, t, dp2){ # Look up weight for each observation obs_ids <- as.character(tid[[dp2$idname]]) fixed_w <- as.numeric(target_w_lookup[obs_ids]) - # Units not in target period get NA weight — will be filtered in run_DRDID + # Exclude units not observed in target period by setting D to NA + # (run_DRDID filters on !is.na(D)) + na_w <- is.na(fixed_w) + if (any(na_w & !is.na(did_cohort_index))) { + did_cohort_index[na_w] <- NA_integer_ + warning(paste0("Some units not observed in ", dp2$fix_weights, + " (period ", target_period, ") for group ", + dp2$treated_groups[g], " in time period ", + dp2$time_periods[t+tfac], ". 
These units are excluded.")) + } cohort_data <- data.table(did_cohort_index, tid[[dp2$yname]], tid$post, fixed_w, tid$.rowid) } else { cohort_data <- data.table(did_cohort_index, dp2$time_invariant_data[[dp2$yname]], dp2$time_invariant_data$post, dp2$time_invariant_data$weights, dp2$time_invariant_data$.rowid) @@ -443,7 +452,7 @@ run_att_gt_estimation <- function(g, t, dp2){ } # run estimation - force_rc <- if (exists("use_rc_for_weights") && isTRUE(use_rc_for_weights)) TRUE else FALSE + force_rc <- !is.null(dp2$fix_weights) && dp2$fix_weights == "varying" && dp2$panel did_result <- tryCatch(run_DRDID(cohort_data, covariates, dp2, g_val = dp2$treated_groups[g], t_val = dp2$time_periods[t+tfac], force_rc = force_rc), error = function(e) { warning("Error computing internal 2x2 DiD for (g, t) = (", dp2$treated_groups[g], ", ", dp2$time_periods[t+tfac], "): ", e$message, ". The ATT for this cell will be set to NA.") diff --git a/R/imports.R b/R/imports.R index ea0c8ad7..e5d2753b 100644 --- a/R/imports.R +++ b/R/imports.R @@ -8,6 +8,7 @@ #' @importFrom stats pnorm qnorm pchisq quantile cov aggregate setNames #' model.frame model.matrix na.pass complete.cases binomial rnorm var +#' ecdf glm predict #' @importFrom utils globalVariables #' @import ggplot2 #' @importFrom BMisc toformula rhs.vars makeBalancedPanel getListElement @@ -20,7 +21,8 @@ #' @importFrom DRDID drdid_panel reg_did_panel std_ipw_did_panel std_ipw_did_rc reg_did_rc drdid_rc NULL utils::globalVariables(c( - ".", ".G", ".y", "asif_never_treated", "treated_first_period", "count", "constant", ".rowid", + ".", ".G", ".y", ".w", "asif_never_treated", "treated_first_period", "count", "constant", ".rowid", "V1", "control_group", "cohort", "cohort_size", "period", "period_size", "y1", "y0", + "D", "N", "post", "weights", "i.weights", "y", "cluster", "id", "..cols_to_keep", "..g", "inf_func_long", "inf_func_agg" )) diff --git a/R/pre_process_did2.R b/R/pre_process_did2.R index dd3f52ae..e62268b5 100644 --- 
a/R/pre_process_did2.R +++ b/R/pre_process_did2.R @@ -133,7 +133,7 @@ did_standardization <- function(data, args){ # Check for time-varying weights in panel data if (!is.null(args$weightsname) && args$panel) { - w_range <- data[, .(w_range = max(weights) - min(weights)), by = get(args$idname)] + w_range <- data[, .(w_range = max(weights) - min(weights)), by = c(args$idname)] if (any(w_range$w_range > .Machine$double.eps^0.5, na.rm = TRUE)) { message( "Time-varying weights detected. For balanced panel data, the default ", diff --git a/R/tidy.R b/R/tidy.R index bb176f5d..158e6102 100644 --- a/R/tidy.R +++ b/R/tidy.R @@ -84,12 +84,22 @@ tidy.MP <- function(x, ...) { #' @param ... other arguments passed to methods #' @export glance.MP <- function(x, ...) { - out <- data.frame( - nobs = x$n, - ngroup = x$DIDparams$nG, - ntime = x$DIDparams$nT, - control.group = x$DIDparams$control_group, - est.method = x$DIDparams$est_method) + dp <- x$DIDparams + if (isTRUE(dp$faster_mode)) { + out <- data.frame( + nobs = x$n, + ngroup = dp$treated_groups_count, + ntime = dp$time_periods_count, + control.group = dp$control_group, + est.method = dp$est_method) + } else { + out <- data.frame( + nobs = x$n, + ngroup = dp$nG, + ntime = dp$nT, + control.group = dp$control_group, + est.method = dp$est_method) + } out } @@ -205,7 +215,7 @@ glance.AGGTEobj<- function(x, ...) 
{ out <- data.frame( type = x$type, nobs = x$DIDparams$id_count, - ngroup = nrow(x$DIDparams$cohort_counts), + ngroup = x$DIDparams$treated_groups_count, ntime = x$DIDparams$time_periods_count, control.group = x$DIDparams$control_group, est.method = x$DIDparams$est_method) diff --git a/R/utility_functions.R b/R/utility_functions.R index 40ed9309..f35f3507 100644 --- a/R/utility_functions.R +++ b/R/utility_functions.R @@ -67,9 +67,11 @@ get_wide_data <- function(data, yname, idname, tname) { set(data, j = ".y0", value = y_vals) set(data, j = ".dy", value = data[[".y1"]] - y_vals) - # Subset to first row - first.period <- min(data[[tname]]) - data <- data[data[[tname]] == first.period, ] + # Subset to first period's rows + # Pre-extract to avoid data.table .checkTypos when column name matches variable name + .time_vals <- data[[tname]] + first.period <- min(.time_vals) + data <- data[.time_vals == first.period] return(data) } @@ -89,7 +91,7 @@ check_balance <- function(data, id_col, time_col) { panel_counts <- data[, .N, by = c(id_col)] # Determine the maximum number of time periods for any unit - max_time_periods <- data[, uniqueN(data[[time_col]])] + max_time_periods <- length(unique(data[[time_col]])) # Check if every unit has the same number of time periods as max_time_periods is_balanced <- all(panel_counts$N == max_time_periods) diff --git a/man/DIDparams.Rd b/man/DIDparams.Rd index 66890966..1ec264d1 100644 --- a/man/DIDparams.Rd +++ b/man/DIDparams.Rd @@ -122,15 +122,17 @@ repeated cross-section DRDID estimators so that pre-period and post-period observations each carry their own weight. This is the most flexible option but sacrifices the efficiency of the panel estimator. For RC/unbalanced panel, this is identical to the -default.} +default. Not supported with custom \code{est_method} functions.} \item{\code{"base_period"}}{Fixes weights at the base period (g-1) for all (g,t) cells within a group, for both pre-treatment and post-treatment comparisons. 
Ensures all cells within a group use the -same weights. For RC/unbalanced panel, units not observed in the base -period are dropped with a warning.} +same weights. For unbalanced panels, units not observed in the base +period are dropped with a warning. Not supported for repeated cross +sections (\code{panel = FALSE}).} \item{\code{"first_period"}}{Fixes weights at the first time period in -the dataset for all (g,t) cells. For RC/unbalanced panel, units not -observed in the first period are dropped with a warning.} +the dataset for all (g,t) cells. For unbalanced panels, units not +observed in the first period are dropped with a warning. Not supported +for repeated cross sections (\code{panel = FALSE}).} }} \item{alp}{the significance level, default is 0.05} diff --git a/man/att_gt.Rd b/man/att_gt.Rd index 9d7cda9c..aa37a707 100644 --- a/man/att_gt.Rd +++ b/man/att_gt.Rd @@ -132,15 +132,17 @@ repeated cross-section DRDID estimators so that pre-period and post-period observations each carry their own weight. This is the most flexible option but sacrifices the efficiency of the panel estimator. For RC/unbalanced panel, this is identical to the -default.} +default. Not supported with custom \code{est_method} functions.} \item{\code{"base_period"}}{Fixes weights at the base period (g-1) for all (g,t) cells within a group, for both pre-treatment and post-treatment comparisons. Ensures all cells within a group use the -same weights. For RC/unbalanced panel, units not observed in the base -period are dropped with a warning.} +same weights. For unbalanced panels, units not observed in the base +period are dropped with a warning. Not supported for repeated cross +sections (\code{panel = FALSE}).} \item{\code{"first_period"}}{Fixes weights at the first time period in -the dataset for all (g,t) cells. For RC/unbalanced panel, units not -observed in the first period are dropped with a warning.} +the dataset for all (g,t) cells. 
For unbalanced panels, units not +observed in the first period are dropped with a warning. Not supported +for repeated cross sections (\code{panel = FALSE}).} }} \item{alp}{the significance level, default is 0.05} diff --git a/man/pre_process_did.Rd b/man/pre_process_did.Rd index 2cf4d735..fc59e122 100644 --- a/man/pre_process_did.Rd +++ b/man/pre_process_did.Rd @@ -132,15 +132,17 @@ repeated cross-section DRDID estimators so that pre-period and post-period observations each carry their own weight. This is the most flexible option but sacrifices the efficiency of the panel estimator. For RC/unbalanced panel, this is identical to the -default.} +default. Not supported with custom \code{est_method} functions.} \item{\code{"base_period"}}{Fixes weights at the base period (g-1) for all (g,t) cells within a group, for both pre-treatment and post-treatment comparisons. Ensures all cells within a group use the -same weights. For RC/unbalanced panel, units not observed in the base -period are dropped with a warning.} +same weights. For unbalanced panels, units not observed in the base +period are dropped with a warning. Not supported for repeated cross +sections (\code{panel = FALSE}).} \item{\code{"first_period"}}{Fixes weights at the first time period in -the dataset for all (g,t) cells. For RC/unbalanced panel, units not -observed in the first period are dropped with a warning.} +the dataset for all (g,t) cells. For unbalanced panels, units not +observed in the first period are dropped with a warning. Not supported +for repeated cross sections (\code{panel = FALSE}).} }} \item{alp}{the significance level, default is 0.05} diff --git a/man/pre_process_did2.Rd b/man/pre_process_did2.Rd index 6106e7bc..d2edfc51 100644 --- a/man/pre_process_did2.Rd +++ b/man/pre_process_did2.Rd @@ -132,15 +132,17 @@ repeated cross-section DRDID estimators so that pre-period and post-period observations each carry their own weight. 
This is the most flexible option but sacrifices the efficiency of the panel estimator. For RC/unbalanced panel, this is identical to the -default.} +default. Not supported with custom \code{est_method} functions.} \item{\code{"base_period"}}{Fixes weights at the base period (g-1) for all (g,t) cells within a group, for both pre-treatment and post-treatment comparisons. Ensures all cells within a group use the -same weights. For RC/unbalanced panel, units not observed in the base -period are dropped with a warning.} +same weights. For unbalanced panels, units not observed in the base +period are dropped with a warning. Not supported for repeated cross +sections (\code{panel = FALSE}).} \item{\code{"first_period"}}{Fixes weights at the first time period in -the dataset for all (g,t) cells. For RC/unbalanced panel, units not -observed in the first period are dropped with a warning.} +the dataset for all (g,t) cells. For unbalanced panels, units not +observed in the first period are dropped with a warning. 
Not supported +for repeated cross sections (\code{panel = FALSE}).} }} \item{alp}{the significance level, default is 0.05} diff --git a/tests/testthat/test-aggte-comprehensive.R b/tests/testthat/test-aggte-comprehensive.R new file mode 100644 index 00000000..ab872b94 --- /dev/null +++ b/tests/testthat/test-aggte-comprehensive.R @@ -0,0 +1,204 @@ +# ============================================================================= +# Comprehensive tests for aggte() aggregation types +# ============================================================================= + +# Shared setup: known DGP with treatment effect = 1 +set.seed(20260401) +sp <- did::reset.sim() +sp$te <- 1 # constant treatment effect +data_agg <- did::build_sim_dataset(sp) + +mp_agg <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_agg, tname = "period", + idname = "id", gname = "G", est_method = "dr", + bstrap = FALSE) +)) + +# ============================================================================= +# type = "simple" +# ============================================================================= + +test_that("aggte simple returns valid overall ATT", { + agg <- suppressWarnings(aggte(mp_agg, type = "simple")) + expect_s3_class(agg, "AGGTEobj") + expect_false(is.na(agg$overall.att)) + expect_equal(agg$overall.att, 1, tolerance = 0.5) +}) + +test_that("aggte simple returns valid SE", { + agg <- suppressWarnings(aggte(mp_agg, type = "simple")) + expect_true(agg$overall.se > 0) + expect_false(is.na(agg$overall.se)) +}) + +test_that("aggte simple has no egt component", { + agg <- suppressWarnings(aggte(mp_agg, type = "simple")) + expect_null(agg$egt) +}) + +test_that("aggte simple influence function has correct length", { + agg <- suppressWarnings(aggte(mp_agg, type = "simple")) + expect_equal(length(agg$inf.function$simple.att), nobs(mp_agg)) +}) + +# ============================================================================= +# type = "dynamic" +# 
============================================================================= + +test_that("aggte dynamic returns event-time specific ATTs", { + agg <- suppressWarnings(aggte(mp_agg, type = "dynamic")) + expect_false(is.null(agg$egt)) + expect_true(length(agg$egt) > 0) + # Should have both negative and non-negative event times + expect_true(any(agg$egt < 0)) + expect_true(any(agg$egt >= 0)) +}) + +test_that("aggte dynamic event times are sorted", { + agg <- suppressWarnings(aggte(mp_agg, type = "dynamic")) + expect_equal(agg$egt, sort(agg$egt)) +}) + +test_that("aggte dynamic overall.att averages post-treatment event times", { + agg <- suppressWarnings(aggte(mp_agg, type = "dynamic")) + expect_false(is.na(agg$overall.att)) + expect_equal(agg$overall.att, 1, tolerance = 0.5) +}) + +test_that("aggte dynamic min_e filters event times", { + agg_full <- suppressWarnings(aggte(mp_agg, type = "dynamic")) + agg_filt <- suppressWarnings(aggte(mp_agg, type = "dynamic", min_e = -1)) + expect_true(min(agg_filt$egt) >= -1) + expect_true(length(agg_filt$egt) <= length(agg_full$egt)) +}) + +test_that("aggte dynamic max_e filters event times", { + agg_full <- suppressWarnings(aggte(mp_agg, type = "dynamic")) + agg_filt <- suppressWarnings(aggte(mp_agg, type = "dynamic", max_e = 1)) + expect_true(max(agg_filt$egt) <= 1) + expect_true(length(agg_filt$egt) <= length(agg_full$egt)) +}) + +test_that("aggte dynamic min_e and max_e together", { + agg <- suppressWarnings(aggte(mp_agg, type = "dynamic", min_e = -1, max_e = 1)) + expect_true(min(agg$egt) >= -1) + expect_true(max(agg$egt) <= 1) +}) + +test_that("aggte dynamic balance_e filters groups", { + agg_unbal <- suppressWarnings(aggte(mp_agg, type = "dynamic")) + agg_bal <- suppressWarnings(aggte(mp_agg, type = "dynamic", balance_e = 1)) + # Balanced version may have fewer event times or same + expect_true(length(agg_bal$egt) <= length(agg_unbal$egt)) +}) + +test_that("aggte dynamic SEs are positive where ATT is not NA", { + agg <- 
suppressWarnings(aggte(mp_agg, type = "dynamic")) + non_na <- !is.na(agg$att.egt) + expect_true(all(agg$se.egt[non_na] > 0)) +}) + +# ============================================================================= +# type = "group" +# ============================================================================= + +test_that("aggte group returns per-group ATTs", { + agg <- suppressWarnings(aggte(mp_agg, type = "group")) + expect_false(is.null(agg$egt)) + # egt should contain the group values + expect_true(all(agg$egt %in% unique(mp_agg$group))) +}) + +test_that("aggte group overall.att is reasonable", { + agg <- suppressWarnings(aggte(mp_agg, type = "group")) + expect_false(is.na(agg$overall.att)) + expect_equal(agg$overall.att, 1, tolerance = 0.5) +}) + +test_that("aggte group SEs are positive for each group", { + agg <- suppressWarnings(aggte(mp_agg, type = "group")) + non_na <- !is.na(agg$att.egt) + expect_true(all(agg$se.egt[non_na] > 0)) +}) + +# ============================================================================= +# type = "calendar" +# ============================================================================= + +test_that("aggte calendar returns per-period ATTs", { + agg <- suppressWarnings(aggte(mp_agg, type = "calendar")) + expect_false(is.null(agg$egt)) + expect_true(length(agg$egt) > 0) +}) + +test_that("aggte calendar overall.att is reasonable", { + agg <- suppressWarnings(aggte(mp_agg, type = "calendar")) + expect_false(is.na(agg$overall.att)) + expect_equal(agg$overall.att, 1, tolerance = 0.5) +}) + +test_that("aggte calendar SEs are positive for each period", { + agg <- suppressWarnings(aggte(mp_agg, type = "calendar")) + non_na <- !is.na(agg$att.egt) + expect_true(all(agg$se.egt[non_na] > 0)) +}) + +test_that("aggte calendar only includes post-treatment periods", { + agg <- suppressWarnings(aggte(mp_agg, type = "calendar")) + min_group <- min(mp_agg$group) + # All calendar periods should be at or after the earliest treatment + 
expect_true(all(agg$egt >= min_group)) +}) + +# ============================================================================= +# na.rm behavior +# ============================================================================= + +test_that("aggte with na.rm=TRUE drops NA ATTs and proceeds", { + # Create MP with some NA ATTs manually + mp_tmp <- mp_agg + mp_tmp$att[1] <- NA + # Should work with na.rm=TRUE + agg <- suppressWarnings(aggte(mp_tmp, type = "dynamic", na.rm = TRUE)) + expect_s3_class(agg, "AGGTEobj") + expect_false(is.na(agg$overall.att)) +}) + +# ============================================================================= +# Cross-type consistency +# ============================================================================= + +test_that("all aggte types return AGGTEobj class", { + for (tp in c("simple", "dynamic", "group", "calendar")) { + agg <- suppressWarnings(aggte(mp_agg, type = tp)) + expect_s3_class(agg, "AGGTEobj") + } +}) + +test_that("all aggte types have non-NULL DIDparams", { + for (tp in c("simple", "dynamic", "group", "calendar")) { + agg <- suppressWarnings(aggte(mp_agg, type = tp)) + expect_false(is.null(agg$DIDparams), label = tp) + } +}) + +test_that("aggte preserves overridden bstrap/alp settings", { + agg <- suppressWarnings(aggte(mp_agg, type = "dynamic", bstrap = FALSE, alp = 0.01)) + expect_equal(agg$DIDparams$bstrap, FALSE) + expect_equal(agg$DIDparams$alp, 0.01) +}) + +# ============================================================================= +# cband with bootstrap +# ============================================================================= + +test_that("aggte with cband=TRUE and bstrap=TRUE returns simultaneous critical value", { + mp_boot <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_agg, tname = "period", + idname = "id", gname = "G", est_method = "reg", + bstrap = TRUE, biters = 100, cband = TRUE) + )) + agg <- aggte(mp_boot, type = "dynamic", bstrap = TRUE, biters = 100, 
cband = TRUE) + # Simultaneous critical value should be at least as large as pointwise z + expect_true(agg$crit.val.egt >= qnorm(0.975)) +}) diff --git a/tests/testthat/test-att_gt.R b/tests/testthat/test-att_gt.R index b39f6921..ed8a3fdb 100644 --- a/tests/testthat/test-att_gt.R +++ b/tests/testthat/test-att_gt.R @@ -56,14 +56,16 @@ test_that("two period case", { sp$n <- 10000 data <- did::build_sim_dataset(sp) - res <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", - gname="G", est_method="reg") + res <- suppressWarnings( + att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", + gname="G", est_method="reg") + ) res - agg_simple <- aggte(res, type="simple") - agg_group <- aggte(res, type="group") - agg_dynamic <- aggte(res, type="dynamic") - agg_calendar <- aggte(res, type="calendar") + agg_simple <- suppressWarnings(aggte(res, type="simple")) + agg_group <- suppressWarnings(aggte(res, type="group")) + agg_dynamic <- suppressWarnings(aggte(res, type="dynamic")) + agg_calendar <- suppressWarnings(aggte(res, type="calendar")) expect_equal(agg_simple$overall.att, 1, tol=.5) expect_equal(agg_group$overall.att, 1, tol=.5) @@ -646,10 +648,10 @@ test_that("works when user column is literally named 'gname'", { expect_false(all(is.na(mod$att))) # aggte should also work (this was the specific dreamerr bug) - agg <- aggte(mod, type="simple") + agg <- suppressWarnings(aggte(mod, type="simple")) expect_false(is.na(agg$overall.att)) - agg_dyn <- aggte(mod, type="dynamic") + agg_dyn <- suppressWarnings(aggte(mod, type="dynamic")) expect_false(is.na(agg_dyn$overall.att)) }) @@ -803,6 +805,28 @@ test_that("fix_weights validation", { gname="G", fix_weights="invalid_option", bstrap=FALSE), "fix_weights must be NULL" ) + + # base_period and first_period not supported for repeated cross sections + expect_error( + att_gt(yname="Y", data=data, tname="period", idname="id", + gname="G", fix_weights="base_period", panel=FALSE, bstrap=FALSE), + "not 
supported for repeated cross sections" + ) + expect_error( + att_gt(yname="Y", data=data, tname="period", idname="id", + gname="G", fix_weights="first_period", panel=FALSE, bstrap=FALSE), + "not supported for repeated cross sections" + ) + + # varying not supported with custom est_method + my_est <- function(y1, y0, D, covariates, i.weights, inffunc, ...) { + list(ATT = mean(y1 - y0), att.inf.func = rep(0, length(y1))) + } + expect_error( + att_gt(yname="Y", data=data, tname="period", idname="id", + gname="G", fix_weights="varying", est_method=my_est, bstrap=FALSE), + "not currently supported with custom est_method" + ) }) test_that("clustered standard errors", { diff --git a/tests/testthat/test-edge-cases.R b/tests/testthat/test-edge-cases.R new file mode 100644 index 00000000..aa27f344 --- /dev/null +++ b/tests/testthat/test-edge-cases.R @@ -0,0 +1,192 @@ +# ============================================================================= +# Tests for edge cases and boundary conditions +# ============================================================================= + +test_that("single treated group produces valid att_gt", { + set.seed(20260401) + data_sg <- data.frame( + id = rep(1:200, each = 5), + period = rep(1:5, 200), + G = rep(c(rep(4, 50), rep(0, 150)), each = 5), + Y = rnorm(1000) + ) + data_sg$Y[data_sg$G == 4 & data_sg$period >= 4] <- + data_sg$Y[data_sg$G == 4 & data_sg$period >= 4] + 1 + + result <- att_gt(yname = "Y", data = data_sg, tname = "period", idname = "id", + gname = "G", bstrap = FALSE) + expect_s3_class(result, "MP") + expect_true(any(!is.na(result$att))) +}) + +test_that("single treated group works with all 4 aggte types", { + set.seed(20260401) + data_sg <- data.frame( + id = rep(1:200, each = 5), + period = rep(1:5, 200), + G = rep(c(rep(4, 50), rep(0, 150)), each = 5), + Y = rnorm(1000) + ) + data_sg$Y[data_sg$G == 4 & data_sg$period >= 4] <- + data_sg$Y[data_sg$G == 4 & data_sg$period >= 4] + 1 + + mp <- att_gt(yname = "Y", data = 
data_sg, tname = "period", idname = "id", + gname = "G", bstrap = FALSE) + + for (tp in c("simple", "dynamic", "group", "calendar")) { + agg <- suppressWarnings(aggte(mp, type = tp)) + expect_s3_class(agg, "AGGTEobj") + expect_false(is.na(agg$overall.att), label = tp) + } +}) + +test_that("two-period data with universal base period", { + set.seed(20260401) + data_2p <- data.frame( + id = rep(1:200, each = 2), + period = rep(1:2, 200), + G = rep(c(rep(2, 50), rep(0, 150)), each = 2), + Y = rnorm(400) + ) + data_2p$Y[data_2p$G == 2 & data_2p$period == 2] <- + data_2p$Y[data_2p$G == 2 & data_2p$period == 2] + 1 + + result <- suppressWarnings( + att_gt(yname = "Y", data = data_2p, tname = "period", idname = "id", + gname = "G", base_period = "universal", bstrap = FALSE) + ) + expect_s3_class(result, "MP") + expect_true(any(!is.na(result$att))) +}) + +test_that("data with no never-treated group works with notyettreated", { + set.seed(20260401) + # All units are eventually treated (groups 3 and 5) + data_nnt <- data.frame( + id = rep(1:200, each = 6), + period = rep(1:6, 200), + G = rep(c(rep(3, 100), rep(5, 100)), each = 6), + Y = rnorm(1200) + ) + result <- suppressWarnings( + att_gt(yname = "Y", data = data_nnt, tname = "period", idname = "id", + gname = "G", control_group = "notyettreated", bstrap = FALSE) + ) + expect_s3_class(result, "MP") +}) + +test_that("data with no never-treated group warns with nevertreated", { + set.seed(20260401) + data_nnt <- data.frame( + id = rep(1:200, each = 6), + period = rep(1:6, 200), + G = rep(c(rep(3, 100), rep(5, 100)), each = 6), + Y = rnorm(1200) + ) + expect_warning( + att_gt(yname = "Y", data = data_nnt, tname = "period", idname = "id", + gname = "G", control_group = "nevertreated", bstrap = FALSE), + "No never-treated group|never-treated" + ) +}) + +test_that("groups treated in first period are dropped", { + set.seed(20260401) + sp <- did::reset.sim() + data_fp <- did::build_sim_dataset(sp) + # Add units treated in the very 
first period + first_per <- min(data_fp$period) + extra <- data_fp[data_fp$G == sort(unique(data_fp$G[data_fp$G > 0]))[1], ] + extra$G <- first_per + extra$id <- extra$id + max(data_fp$id) + data_fp <- rbind(data_fp, extra) + + expect_warning( + att_gt(yname = "Y", data = data_fp, tname = "period", idname = "id", + gname = "G", bstrap = FALSE), + "already treated|Dropped" + ) +}) + +test_that("non-consecutive time periods work", { + set.seed(20260401) + data_nc <- data.frame( + id = rep(1:200, each = 4), + period = rep(c(2000, 2003, 2007, 2010), 200), + G = rep(c(rep(2007, 50), rep(0, 150)), each = 4), + Y = rnorm(800) + ) + data_nc$Y[data_nc$G == 2007 & data_nc$period >= 2007] <- + data_nc$Y[data_nc$G == 2007 & data_nc$period >= 2007] + 1 + + result <- suppressWarnings( + att_gt(yname = "Y", data = data_nc, tname = "period", idname = "id", + gname = "G", bstrap = FALSE) + ) + expect_s3_class(result, "MP") + expect_true(any(!is.na(result$att))) + + agg <- suppressWarnings(aggte(result, type = "dynamic")) + expect_s3_class(agg, "AGGTEobj") +}) + +test_that("non-consecutive group values work", { + set.seed(20260401) + data_ng <- data.frame( + id = rep(1:200, each = 5), + period = rep(1:5, 200), + G = rep(c(rep(3, 50), rep(5, 50), rep(0, 100)), each = 5), + Y = rnorm(1000) + ) + result <- att_gt(yname = "Y", data = data_ng, tname = "period", idname = "id", + gname = "G", bstrap = FALSE) + expect_s3_class(result, "MP") + expect_true(length(unique(result$group)) >= 2) +}) + +test_that("allow_unbalanced_panel=TRUE with balanced data proceeds normally", { + set.seed(20260401) + sp <- did::reset.sim() + data_bal <- did::build_sim_dataset(sp) + + result <- suppressMessages( + att_gt(yname = "Y", data = data_bal, tname = "period", idname = "id", + gname = "G", allow_unbalanced_panel = TRUE, bstrap = FALSE) + ) + expect_s3_class(result, "MP") + expect_true(any(!is.na(result$att))) +}) + +test_that("allow_unbalanced_panel=TRUE with truly unbalanced data", { + 
set.seed(20260401) + sp <- did::reset.sim() + data_ub <- did::build_sim_dataset(sp) + # Drop a few rows to make it unbalanced + data_ub <- data_ub[-c(1, 5, 10), ] + + result <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", data = data_ub, tname = "period", idname = "id", + gname = "G", allow_unbalanced_panel = TRUE, bstrap = FALSE) + )) + expect_s3_class(result, "MP") + expect_true(any(!is.na(result$att))) +}) + +test_that("single post-treatment period produces valid results", { + set.seed(20260401) + data_spt <- data.frame( + id = rep(1:200, each = 3), + period = rep(1:3, 200), + G = rep(c(rep(3, 50), rep(0, 150)), each = 3), + Y = rnorm(600) + ) + data_spt$Y[data_spt$G == 3 & data_spt$period == 3] <- + data_spt$Y[data_spt$G == 3 & data_spt$period == 3] + 1 + + result <- att_gt(yname = "Y", data = data_spt, tname = "period", idname = "id", + gname = "G", bstrap = FALSE) + expect_s3_class(result, "MP") + # Should have at least one post-treatment ATT + post_atts <- result$att[result$group <= result$t] + expect_true(any(!is.na(post_atts))) +}) diff --git a/tests/testthat/test-error-handling.R b/tests/testthat/test-error-handling.R new file mode 100644 index 00000000..e7759983 --- /dev/null +++ b/tests/testthat/test-error-handling.R @@ -0,0 +1,257 @@ +# ============================================================================= +# Tests for stop(), warning(), and message() conditions +# ============================================================================= + +# Shared setup +set.seed(20260401) +sp <- did::reset.sim() +data_eh <- did::build_sim_dataset(sp) + +# ============================================================================= +# att_gt() validation errors +# ============================================================================= + +test_that("att_gt errors on invalid est_method string", { + expect_error( + att_gt(yname = "Y", data = data_eh, tname = "period", idname = "id", + gname = "G", est_method = "bad", bstrap = FALSE), + 
"must be one of" + ) +}) + +test_that("att_gt errors on non-character non-function est_method", { + expect_error( + att_gt(yname = "Y", data = data_eh, tname = "period", idname = "id", + gname = "G", est_method = 42, bstrap = FALSE), + "must be a character string" + ) +}) + +test_that("att_gt errors on invalid fix_weights value", { + expect_error( + att_gt(yname = "Y", data = data_eh, tname = "period", idname = "id", + gname = "G", fix_weights = "bad", bstrap = FALSE), + "must be NULL or one of" + ) +}) + +test_that("att_gt errors on fix_weights with panel=FALSE", { + expect_error( + att_gt(yname = "Y", data = data_eh, tname = "period", idname = "id", + gname = "G", fix_weights = "base_period", panel = FALSE, + bstrap = FALSE), + "not supported for repeated cross sections" + ) +}) + +test_that("att_gt warns on extra args with built-in est_method", { + expect_warning( + att_gt(yname = "Y", data = data_eh, tname = "period", idname = "id", + gname = "G", est_method = "reg", bstrap = FALSE, extra_arg = 1), + "Extra arguments" + ) +}) + +test_that("att_gt messages about anticipation", { + expect_message( + suppressWarnings( + att_gt(yname = "Y", data = data_eh, tname = "period", idname = "id", + gname = "G", anticipation = 1, bstrap = FALSE) + ), + "anticipation =" + ) +}) + +test_that("att_gt warns on clustered SE without bootstrap", { + expect_warning( + att_gt(yname = "Y", data = data_eh, tname = "period", idname = "id", + gname = "G", clustervars = "id", bstrap = FALSE), + "Clustered standard errors require" + ) +}) + +# ============================================================================= +# pre_process_did / pre_process_did2 validation errors +# ============================================================================= + +test_that("att_gt errors on missing column name (slower mode)", { + expect_error( + att_gt(yname = "nonexistent", data = data_eh, tname = "period", + idname = "id", gname = "G", bstrap = FALSE, faster_mode = FALSE), + "not found" + ) 
+}) + +test_that("att_gt errors on missing column name (faster_mode)", { + expect_error( + att_gt(yname = "nonexistent", data = data_eh, tname = "period", + idname = "id", gname = "G", bstrap = FALSE, faster_mode = TRUE), + "nonexistent" + ) +}) + +test_that("att_gt errors on non-numeric tname", { + bad_data <- data_eh + bad_data$period <- as.character(bad_data$period) + expect_error( + att_gt(yname = "Y", data = bad_data, tname = "period", idname = "id", + gname = "G", bstrap = FALSE), + "must be numeric" + ) +}) + +test_that("att_gt errors on non-numeric gname", { + bad_data <- data_eh + bad_data$G <- as.character(bad_data$G) + expect_error( + att_gt(yname = "Y", data = bad_data, tname = "period", idname = "id", + gname = "G", bstrap = FALSE), + "must be numeric" + ) +}) + +test_that("att_gt errors on treatment reversals (faster_mode)", { + bad_data <- data_eh + # Make one unit switch groups across time + target_id <- bad_data$id[1] + periods <- sort(unique(bad_data$period)) + bad_data$G[bad_data$id == target_id & bad_data$period == periods[1]] <- 3 + bad_data$G[bad_data$id == target_id & bad_data$period == periods[2]] <- 4 + expect_error( + att_gt(yname = "Y", data = bad_data, tname = "period", idname = "id", + gname = "G", bstrap = FALSE, faster_mode = TRUE), + "same across all periods|time-invariant|must be the same" + ) +}) + +test_that("att_gt warns on missing data dropped", { + bad_data <- data_eh + bad_data$Y[1:5] <- NA + expect_warning( + att_gt(yname = "Y", data = bad_data, tname = "period", idname = "id", + gname = "G", bstrap = FALSE), + "dropped|missing" + ) +}) + +test_that("att_gt warns on small groups", { + # Create data with a very small treated group + small_data <- data_eh + treated_ids <- unique(small_data$id[small_data$G > 0]) + # Keep only 2 treated units and some controls + keep_ids <- c(treated_ids[1:2], unique(small_data$id[small_data$G == 0])[1:50]) + small_data <- small_data[small_data$id %in% keep_ids, ] + expect_warning( + att_gt(yname 
= "Y", data = small_data, tname = "period", idname = "id", + gname = "G", bstrap = FALSE), + "very few observations" + ) +}) + +test_that("att_gt handles data with .w column (slower mode)", { + # The .w column gets dropped during column selection in pre_process_did + # before the collision check, so this should run without error + bad_data <- data.frame(data_eh) + bad_data$.w <- 1 + result <- att_gt(yname = "Y", data = bad_data, tname = "period", idname = "id", + gname = "G", bstrap = FALSE, faster_mode = FALSE) + expect_s3_class(result, "MP") +}) + +test_that("att_gt messages on time-varying weights (panel)", { + tv_data <- data_eh + tv_data$tw <- tv_data$period + runif(nrow(tv_data)) + expect_message( + att_gt(yname = "Y", data = tv_data, tname = "period", idname = "id", + gname = "G", weightsname = "tw", bstrap = FALSE), + "Time-varying weights" + ) +}) + +test_that("att_gt errors on more than 1 extra cluster variable (faster_mode)", { + expect_error( + att_gt(yname = "Y", data = data_eh, tname = "period", idname = "id", + gname = "G", clustervars = c("id", "G", "period"), + bstrap = FALSE, faster_mode = TRUE), + "cluster|length 1" + ) +}) + +# ============================================================================= +# aggte() validation errors +# ============================================================================= + +test_that("aggte errors on invalid type", { + mp_tmp <- att_gt(yname = "Y", data = data_eh, tname = "period", idname = "id", + gname = "G", bstrap = FALSE) + expect_error( + aggte(mp_tmp, type = "invalid"), + "must be one of" + ) +}) + +test_that("aggte errors when ATTs contain NA and na.rm=FALSE", { + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + mp_na <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", data = data, tname = "period", idname = "id", + gname = "G", bstrap = FALSE) + )) + # Inject NA directly — no need to rely on estimation failure + mp_na$att[1] <- NA + expect_error( + aggte(mp_na, type = 
"dynamic", na.rm = FALSE), + "Missing values" + ) +}) + +# ============================================================================= +# Wald pre-test warnings +# ============================================================================= + +test_that("att_gt handles singular covariance for Wald test gracefully", { + # When covariance matrix is singular, W and Wpval should be NULL + # and the result should still be a valid MP object + small_sp <- did::reset.sim() + small_sp$n <- 50 + small_data <- did::build_sim_dataset(small_sp) + result <- suppressWarnings( + att_gt(yname = "Y", data = small_data, tname = "period", idname = "id", + gname = "G", bstrap = FALSE) + ) + expect_s3_class(result, "MP") + # W should be either a valid number or NULL (when singular) + expect_true(is.null(result$W) || is.numeric(result$W)) +}) + +# ============================================================================= +# compute.att_gt warnings +# ============================================================================= + +test_that("att_gt warns on overlap violations", { + # Create data where propensity score is near 1 (near-perfect separation) + sep_data <- data_eh + # Make covariate perfectly predict treatment for some cells + sep_data$X_sep <- ifelse(sep_data$G > 0, 100, -100) + expect_warning( + att_gt(yname = "Y", xformla = ~X_sep, data = sep_data, tname = "period", + idname = "id", gname = "G", est_method = "dr", bstrap = FALSE), + "overlap condition" + ) +}) + +test_that("att_gt warns when no pre-treatment periods for Wald test", { + # Create data where a group is first treated in period 2 (only 1 pre-treatment period) + # With base_period="varying", there may be no pre-treatment ATTs for the Wald test + early_data <- data.frame( + id = rep(1:100, each = 3), + period = rep(1:3, 100), + G = rep(c(rep(2, 50), rep(0, 50)), each = 3), + Y = rnorm(300) + ) + expect_warning( + att_gt(yname = "Y", data = early_data, tname = "period", idname = "id", + gname = "G", bstrap = 
FALSE), + "pre-treatment|Wald" + ) +}) diff --git a/tests/testthat/test-faster-mode-consistency.R b/tests/testthat/test-faster-mode-consistency.R new file mode 100644 index 00000000..01f3bed7 --- /dev/null +++ b/tests/testthat/test-faster-mode-consistency.R @@ -0,0 +1,153 @@ +# ============================================================================= +# Systematic faster_mode=TRUE vs FALSE consistency tests +# ============================================================================= + +# Shared setup +set.seed(20260401) +sp <- did::reset.sim() +data_fm <- did::build_sim_dataset(sp) + +# Unbalanced version +data_ub <- data_fm[-c(1, 5, 10), ] + +# Helper to compare two att_gt results +compare_modes <- function(res_slow, res_fast, label) { + expect_equal(res_slow$att, res_fast$att, tolerance = 1e-10, label = paste(label, "ATT")) + expect_equal(res_slow$group, res_fast$group, label = paste(label, "group")) + expect_equal(res_slow$t, res_fast$t, label = paste(label, "t")) +} + +# ============================================================================= +# Core grid: est_method x panel_type x control_group x base_period +# ============================================================================= + +# --- Balanced panel --- + +for (em in c("dr", "ipw", "reg")) { + for (cg in c("nevertreated", "notyettreated")) { + for (bp in c("varying", "universal")) { + label <- paste(em, "balanced", cg, bp) + test_that(paste("consistency:", label), { + res_slow <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_fm, tname = "period", + idname = "id", gname = "G", est_method = em, + control_group = cg, base_period = bp, + faster_mode = FALSE, bstrap = FALSE) + )) + res_fast <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_fm, tname = "period", + idname = "id", gname = "G", est_method = em, + control_group = cg, base_period = bp, + faster_mode = TRUE, bstrap = FALSE) + )) + compare_modes(res_slow, 
res_fast, label) + }) + } + } +} + +# --- Unbalanced panel --- + +for (em in c("dr", "ipw", "reg")) { + for (cg in c("nevertreated", "notyettreated")) { + for (bp in c("varying", "universal")) { + label <- paste(em, "unbalanced", cg, bp) + test_that(paste("consistency:", label), { + res_slow <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_ub, tname = "period", + idname = "id", gname = "G", est_method = em, + control_group = cg, base_period = bp, + allow_unbalanced_panel = TRUE, + faster_mode = FALSE, bstrap = FALSE) + )) + res_fast <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_ub, tname = "period", + idname = "id", gname = "G", est_method = em, + control_group = cg, base_period = bp, + allow_unbalanced_panel = TRUE, + faster_mode = TRUE, bstrap = FALSE) + )) + compare_modes(res_slow, res_fast, label) + }) + } + } +} + +# --- Repeated cross sections --- + +for (em in c("dr", "ipw", "reg")) { + for (cg in c("nevertreated", "notyettreated")) { + for (bp in c("varying", "universal")) { + label <- paste(em, "RC", cg, bp) + test_that(paste("consistency:", label), { + res_slow <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_fm, tname = "period", + idname = "id", gname = "G", est_method = em, + control_group = cg, base_period = bp, + panel = FALSE, + faster_mode = FALSE, bstrap = FALSE) + )) + res_fast <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_fm, tname = "period", + idname = "id", gname = "G", est_method = em, + control_group = cg, base_period = bp, + panel = FALSE, + faster_mode = TRUE, bstrap = FALSE) + )) + compare_modes(res_slow, res_fast, label) + }) + } + } +} + +# ============================================================================= +# Additional consistency tests beyond the core grid +# ============================================================================= + +test_that("consistency with 
anticipation > 0", { + res_slow <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_fm, tname = "period", + idname = "id", gname = "G", est_method = "dr", + anticipation = 1, faster_mode = FALSE, bstrap = FALSE) + )) + res_fast <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_fm, tname = "period", + idname = "id", gname = "G", est_method = "dr", + anticipation = 1, faster_mode = TRUE, bstrap = FALSE) + )) + compare_modes(res_slow, res_fast, "anticipation=1") +}) + +test_that("consistency without covariates", { + res_slow <- att_gt(yname = "Y", data = data_fm, tname = "period", + idname = "id", gname = "G", est_method = "reg", + faster_mode = FALSE, bstrap = FALSE) + res_fast <- att_gt(yname = "Y", data = data_fm, tname = "period", + idname = "id", gname = "G", est_method = "reg", + faster_mode = TRUE, bstrap = FALSE) + compare_modes(res_slow, res_fast, "no covariates") +}) + +# ============================================================================= +# aggte consistency across modes +# ============================================================================= + +test_that("aggte consistency across faster_mode for all types", { + mp_slow <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_fm, tname = "period", + idname = "id", gname = "G", est_method = "dr", + faster_mode = FALSE, bstrap = FALSE) + )) + mp_fast <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_fm, tname = "period", + idname = "id", gname = "G", est_method = "dr", + faster_mode = TRUE, bstrap = FALSE) + )) + + for (tp in c("simple", "dynamic", "group", "calendar")) { + agg_slow <- suppressWarnings(aggte(mp_slow, type = tp)) + agg_fast <- suppressWarnings(aggte(mp_fast, type = tp)) + expect_equal(agg_slow$overall.att, agg_fast$overall.att, + tolerance = 1e-8, label = paste("aggte", tp)) + } +}) diff --git a/tests/testthat/test-ggdid.R 
b/tests/testthat/test-ggdid.R new file mode 100644 index 00000000..1b80656f --- /dev/null +++ b/tests/testthat/test-ggdid.R @@ -0,0 +1,88 @@ +# ============================================================================= +# Tests for ggdid.MP and ggdid.AGGTEobj plotting functions +# ============================================================================= + +# Shared setup +set.seed(20260401) +sp <- did::reset.sim() +data_gg <- did::build_sim_dataset(sp) + +mp_gg <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_gg, tname = "period", + idname = "id", gname = "G", est_method = "dr", + bstrap = FALSE) +)) + +agg_dyn <- suppressWarnings(aggte(mp_gg, type = "dynamic")) +agg_grp <- suppressWarnings(aggte(mp_gg, type = "group")) +agg_cal <- suppressWarnings(aggte(mp_gg, type = "calendar")) +agg_sim <- suppressWarnings(aggte(mp_gg, type = "simple")) + +# ============================================================================= +# ggdid.MP tests +# ============================================================================= + +test_that("ggdid.MP returns a ggplot object", { + p <- ggdid(mp_gg) + expect_s3_class(p, "gg") +}) + +test_that("ggdid.MP accepts group parameter", { + groups <- unique(mp_gg$group) + p <- ggdid(mp_gg, group = groups[1]) + expect_s3_class(p, "gg") +}) + +test_that("ggdid.MP warns for non-existent group values", { + expect_warning( + ggdid(mp_gg, group = c(9999)), + "do not exist" + ) +}) + +test_that("ggdid.MP accepts custom labels", { + p <- ggdid(mp_gg, xlab = "Time", ylab = "Effect", title = "Test Plot") + expect_s3_class(p, "gg") +}) + +# ============================================================================= +# ggdid.AGGTEobj tests +# ============================================================================= + +test_that("ggdid works for type = dynamic", { + p <- ggdid(agg_dyn) + expect_s3_class(p, "gg") +}) + +test_that("ggdid works for type = group", { + p <- 
suppressWarnings(ggdid(agg_grp)) + expect_s3_class(p, "gg") +}) + +test_that("ggdid works for type = calendar", { + p <- ggdid(agg_cal) + expect_s3_class(p, "gg") +}) + +test_that("ggdid errors for type = simple", { + expect_error( + ggdid(agg_sim), + "not available" + ) +}) + +test_that("ggdid.AGGTEobj accepts custom labels and theme settings", { + p <- ggdid(agg_dyn, xlab = "Event Time", ylab = "ATT", + title = "Event Study", theming = TRUE, legend = TRUE) + expect_s3_class(p, "gg") +}) + +test_that("ggdid.AGGTEobj works with theming=FALSE", { + p <- ggdid(agg_dyn, theming = FALSE) + expect_s3_class(p, "gg") +}) + +test_that("ggdid.AGGTEobj works with ref_line=NULL", { + p <- ggdid(agg_dyn, ref_line = NULL) + expect_s3_class(p, "gg") +}) diff --git a/tests/testthat/test-glance.R b/tests/testthat/test-glance.R new file mode 100644 index 00000000..c91f7864 --- /dev/null +++ b/tests/testthat/test-glance.R @@ -0,0 +1,125 @@ +# ============================================================================= +# Tests for glance.MP and glance.AGGTEobj S3 methods +# ============================================================================= + +# Shared setup: generate MP and AGGTEobj results +set.seed(20260401) +sp <- did::reset.sim() +data_gl <- did::build_sim_dataset(sp) + +mp_slow <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_gl, tname = "period", + idname = "id", gname = "G", est_method = "dr", + bstrap = FALSE, faster_mode = FALSE) +)) + +mp_fast <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_gl, tname = "period", + idname = "id", gname = "G", est_method = "dr", + bstrap = FALSE, faster_mode = TRUE) +)) + +agg_types <- c("simple", "dynamic", "group", "calendar") +agg_slow <- lapply(setNames(agg_types, agg_types), function(tp) { + suppressWarnings(aggte(mp_slow, type = tp)) +}) +agg_fast <- lapply(setNames(agg_types, agg_types), function(tp) { + suppressWarnings(aggte(mp_fast, type = tp)) 
+}) + +# ============================================================================= +# glance.MP tests +# ============================================================================= + +test_that("glance.MP returns a one-row data.frame", { + gl <- glance(mp_slow) + expect_equal(nrow(gl), 1) + expect_s3_class(gl, "data.frame") +}) + +test_that("glance.MP has expected columns", { + gl <- glance(mp_slow) + expected_cols <- c("nobs", "ngroup", "ntime", "control.group", "est.method") + expect_true(all(expected_cols %in% names(gl))) +}) + +test_that("glance.MP values are reasonable", { + gl <- glance(mp_slow) + expect_true(gl$nobs > 0) + expect_true(gl$ngroup > 0) + expect_true(gl$ntime > 0) + expect_equal(gl$control.group, "nevertreated") + expect_equal(gl$est.method, "dr") +}) + +test_that("glance.MP nobs matches nobs.MP", { + gl <- glance(mp_slow) + expect_equal(gl$nobs, nobs(mp_slow)) +}) + +test_that("glance.MP works with faster_mode = TRUE", { + gl <- glance(mp_fast) + expect_equal(nrow(gl), 1) + expect_true(gl$nobs > 0) + expect_true(gl$ngroup > 0) + expect_true(gl$ntime > 0) +}) + +test_that("glance.MP agrees between faster_mode settings", { + gl_slow <- glance(mp_slow) + gl_fast <- glance(mp_fast) + expect_equal(gl_slow$nobs, gl_fast$nobs) + expect_equal(gl_slow$ngroup, gl_fast$ngroup) + expect_equal(gl_slow$ntime, gl_fast$ntime) + expect_equal(gl_slow$control.group, gl_fast$control.group) + expect_equal(gl_slow$est.method, gl_fast$est.method) +}) + +# ============================================================================= +# glance.AGGTEobj tests +# ============================================================================= + +test_that("glance.AGGTEobj returns a one-row data.frame for all 4 types", { + for (tp in agg_types) { + gl <- glance(agg_slow[[tp]]) + expect_equal(nrow(gl), 1, label = paste("nrow for", tp)) + expect_s3_class(gl, "data.frame") + } +}) + +test_that("glance.AGGTEobj has expected columns", { + gl <- glance(agg_slow[["dynamic"]]) 
+ expected_cols <- c("type", "nobs", "ngroup", "ntime", "control.group", "est.method") + expect_true(all(expected_cols %in% names(gl))) +}) + +test_that("glance.AGGTEobj type column matches requested type", { + for (tp in agg_types) { + gl <- glance(agg_slow[[tp]]) + expect_equal(gl$type, tp, label = paste("type for", tp)) + } +}) + +test_that("glance.AGGTEobj values are not NULL or NA", { + for (tp in agg_types) { + gl <- glance(agg_slow[[tp]]) + expect_false(any(sapply(gl, is.null)), label = paste("no NULLs for", tp)) + expect_false(any(is.na(gl)), label = paste("no NAs for", tp)) + } +}) + +test_that("glance.AGGTEobj works with faster_mode = TRUE", { + for (tp in agg_types) { + gl <- glance(agg_fast[[tp]]) + expect_equal(nrow(gl), 1, label = paste("nrow for", tp, "fast")) + expect_false(any(sapply(gl, is.null)), label = paste("no NULLs for", tp, "fast")) + } +}) + +test_that("glance.MP and glance.AGGTEobj agree on nobs", { + gl_mp <- glance(mp_slow) + for (tp in agg_types) { + gl_agg <- glance(agg_slow[[tp]]) + expect_equal(gl_mp$nobs, gl_agg$nobs, label = paste("nobs for", tp)) + } +}) diff --git a/tests/testthat/test-inference.R b/tests/testthat/test-inference.R index b6c6863e..95e60bbd 100644 --- a/tests/testthat/test-inference.R +++ b/tests/testthat/test-inference.R @@ -30,19 +30,17 @@ same_matrix_elem <- function(A, B) { temp_lib <- tempfile() dir.create(temp_lib) -old_did_available <- tryCatch({ - remotes::install_version("did", version = "2.1.2", lib = temp_lib, repos = "http://cran.us.r-project.org", quiet = TRUE) - TRUE -}, error = function(e) FALSE) - -if (!old_did_available) { - # Clean up and skip all tests in this file - unlink(temp_lib, recursive = TRUE) +old_did_available <- FALSE +# Only attempt install when NOT_CRAN is "true" (it is unset on CRAN, where network access may be blocked); success is verified on disk because the dev did namespace is already loaded +if (identical(Sys.getenv("NOT_CRAN"), "true")) { + old_did_available <- tryCatch({ + remotes::install_version("did", version = "2.1.2", lib = temp_lib, + repos = "https://cloud.r-project.org",
quiet = TRUE) + file.exists(file.path(temp_lib, "did", "DESCRIPTION")) + }, error = function(e) FALSE) } -# install.packages( -# "https://cran.r-project.org/src/contrib/did_2.1.2.tar.gz", -# repos = NULL, type = "source", lib = temp_lib -# ) +# Always clean up temp_lib when tests finish +withr::defer(unlink(temp_lib, recursive = TRUE), teardown_env()) test_that("inference with balanced panel data and aggregations", { skip_if(!old_did_available, "did v2.1.2 not available from CRAN") diff --git a/tests/testthat/test-jel_replication.R b/tests/testthat/test-jel_replication.R index 10953632..4b3f746e 100644 --- a/tests/testthat/test-jel_replication.R +++ b/tests/testthat/test-jel_replication.R @@ -112,14 +112,14 @@ test_that("JEL Table 7: 2x2 CS-DiD point estimates match", { for (method in c("reg", "ipw", "dr")) { for (wt_info in list(list(name = NULL, label = "unweighted"), list(name = "set_wt", label = "weighted"))) { - res <- att_gt( + res <- suppressWarnings(att_gt( yname = "crude_rate_20_64", tname = "year", idname = "county_code", gname = "treat_year", xformla = covs_formula, data = short_data, panel = TRUE, control_group = "nevertreated", bstrap = FALSE, est_method = method, weightsname = wt_info$name, base_period = "universal" - ) - agg <- aggte(res, na.rm = TRUE, bstrap = FALSE) + )) + agg <- suppressWarnings(aggte(res, na.rm = TRUE, bstrap = FALSE)) key <- paste0(method, "_", wt_info$label) expect_equal(agg$overall.att, expected[[key]], tolerance = 1e-6, @@ -170,11 +170,11 @@ test_that("JEL 2xT: event study ATT(g,t) point estimates match", { expect_equal(mod$att, expected_att, tolerance = 1e-6) # Dynamic aggregation - es <- aggte(mod, type = "dynamic", bstrap = FALSE) + es <- suppressWarnings(aggte(mod, type = "dynamic", bstrap = FALSE)) expect_equal(es$att.egt, expected_att, tolerance = 1e-6) # Overall ATT for e in {0,5} - agg <- aggte(mod, type = "dynamic", min_e = 0, max_e = 5, bstrap = FALSE) + agg <- suppressWarnings(aggte(mod, type =
"dynamic", min_e = 0, max_e = 5, bstrap = FALSE)) expect_equal(agg$overall.att, -0.7035462478, tolerance = 1e-6) }) @@ -206,7 +206,7 @@ test_that("JEL 2xT: event study with covariates matches across methods", { base_period = "universal" ) - es <- aggte(res, type = "dynamic", na.rm = TRUE, bstrap = FALSE) + es <- suppressWarnings(aggte(res, type = "dynamic", na.rm = TRUE, bstrap = FALSE)) # Base period (e = -1) should be exactly 0 base_idx <- which(es$egt == -1) @@ -243,7 +243,7 @@ test_that("JEL GxT: staggered event study without covariates matches", { ) # Dynamic aggregation - es <- aggte(mod, type = "dynamic", bstrap = FALSE) + es <- suppressWarnings(aggte(mod, type = "dynamic", bstrap = FALSE)) # Expected dynamic ATT at key event times expected_dynamic <- c( @@ -269,7 +269,7 @@ test_that("JEL GxT: staggered event study without covariates matches", { } # Overall ATT for e in {0,5} - agg <- aggte(mod, type = "dynamic", min_e = 0, max_e = 5, bstrap = FALSE) + agg <- suppressWarnings(aggte(mod, type = "dynamic", min_e = 0, max_e = 5, bstrap = FALSE)) expect_equal(agg$overall.att, 0.0867675805, tolerance = 1e-6, label = "GxT no covs overall ATT (e=0:5)") }) @@ -293,21 +293,21 @@ test_that("JEL GxT: staggered event study with DR covariates matches", { covs_formula <- ~perc_female + perc_white + perc_hispanic + unemp_rate + poverty_rate + median_income - mod <- att_gt( + mod <- suppressWarnings(att_gt( yname = "crude_rate_20_64", tname = "year", idname = "county_code", gname = "treat_year", xformla = covs_formula, data = mydata, panel = TRUE, control_group = "notyettreated", bstrap = FALSE, est_method = "dr", weightsname = "set_wt", base_period = "universal" - ) + )) # Overall ATT for e in {0,5} - agg <- aggte(mod, type = "dynamic", min_e = 0, max_e = 5, bstrap = FALSE) + agg <- suppressWarnings(aggte(mod, type = "dynamic", min_e = 0, max_e = 5, bstrap = FALSE)) expect_equal(agg$overall.att, -2.2469982988, tolerance = 1e-6, label = "GxT DR covs overall ATT (e=0:5)") # 
Dynamic aggregation at key event times - es <- aggte(mod, type = "dynamic", bstrap = FALSE) + es <- suppressWarnings(aggte(mod, type = "dynamic", bstrap = FALSE)) expected_dynamic <- c( `e=-5` = 2.6684811691, @@ -354,26 +354,26 @@ test_that("JEL: faster_mode matches regular mode", { covs_formula <- ~perc_female + perc_white + perc_hispanic + unemp_rate + poverty_rate + median_income # Test DR weighted - res_slow <- att_gt( + res_slow <- suppressWarnings(att_gt( yname = "crude_rate_20_64", tname = "year", idname = "county_code", gname = "treat_year", xformla = covs_formula, data = short_data, panel = TRUE, control_group = "nevertreated", bstrap = FALSE, est_method = "dr", weightsname = "set_wt", base_period = "universal", faster_mode = FALSE - ) - res_fast <- att_gt( + )) + res_fast <- suppressWarnings(att_gt( yname = "crude_rate_20_64", tname = "year", idname = "county_code", gname = "treat_year", xformla = covs_formula, data = short_data, panel = TRUE, control_group = "nevertreated", bstrap = FALSE, est_method = "dr", weightsname = "set_wt", base_period = "universal", faster_mode = TRUE - ) + )) expect_equal(res_slow$att, res_fast$att, tolerance = 1e-10, label = "2x2 DR weighted: ATTs match") - agg_slow <- aggte(res_slow, na.rm = TRUE, bstrap = FALSE) - agg_fast <- aggte(res_fast, na.rm = TRUE, bstrap = FALSE) + agg_slow <- suppressWarnings(aggte(res_slow, na.rm = TRUE, bstrap = FALSE)) + agg_fast <- suppressWarnings(aggte(res_fast, na.rm = TRUE, bstrap = FALSE)) expect_equal(agg_slow$overall.att, agg_fast$overall.att, tolerance = 1e-10, label = "2x2 DR weighted: aggregate ATT matches") }) diff --git a/tests/testthat/test-user_bug_fixes.R b/tests/testthat/test-user_bug_fixes.R index a080d5cf..c1042ed7 100644 --- a/tests/testthat/test-user_bug_fixes.R +++ b/tests/testthat/test-user_bug_fixes.R @@ -115,23 +115,23 @@ test_that("0 pre-treatment estimates when outcomes are 0", { data <- subset(data, G > 6) data <- subset(data, period > 5) data$Y[(data$period < data$G)] 
<- 0 # set pre-treatment = 0 - res <- att_gt(yname="Y", + res <- suppressWarnings(att_gt(yname="Y", tname="period", idname="id", gname="G", data=data, control_group = "notyettreated", - base_period="universal") + base_period="universal")) res_idx <- which(res$group==9 & res$t==7) expect_equal(res$att[res_idx],0) - res <- att_gt(yname="Y", + res <- suppressWarnings(att_gt(yname="Y", tname="period", idname="id", gname="G", data=data, control_group = "notyettreated", - base_period="varying") + base_period="varying")) res_idx <- which(res$group==9 & res$t==7) expect_equal(res$att[res_idx],0) }) From 09fcb4c2e1a91d461fcf77521acc04f9ee28e227 Mon Sep 17 00:00:00 2001 From: pedrohcgs Date: Thu, 2 Apr 2026 00:29:26 -0400 Subject: [PATCH 08/20] Replace deprecated geom_errorbarh() with geom_errorbar() in splot() --- R/gplot.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/gplot.R b/R/gplot.R index 932c7d15..e81d5459 100644 --- a/R/gplot.R +++ b/R/gplot.R @@ -85,7 +85,7 @@ splot <- function(ssresults, ylim=NULL, xlab=NULL, ylab=NULL, title="Group", xmax=(att+c*att.se))) + geom_point(aes(colour=post), size=1.5) + #geom_ribbon(aes(x=as.numeric(year)), alpha=0.2) + - geom_errorbarh(aes(colour=post), height=0.1) + + geom_errorbar(aes(colour=post), width=0.1) + scale_y_discrete(breaks=as.factor(ssresults$year)) + #scale_x_discrete(breaks=dabreaks, labels=as.character(dabreaks)) + scale_x_continuous(limits=ylim) + From 6012b04aab4bb5440387aa12f9310c04e2b226db Mon Sep 17 00:00:00 2001 From: pedrohcgs Date: Thu, 2 Apr 2026 11:20:44 -0400 Subject: [PATCH 09/20] Fix fast/slow mode parity: keep NA cells on estimation failure instead of dropping Previously, faster_mode=TRUE would drop (g,t) cells that failed estimation when base_period="varying", while slower mode kept them as NA. 
This caused faster_mode=TRUE to hard-stop with "No valid (g,t) cells found" when all cells failed (e.g., singular RC covariate designs with fix_weights="varying"), while slower mode returned an MP with all-NA ATTs. --- R/compute.att_gt2.R | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/R/compute.att_gt2.R b/R/compute.att_gt2.R index 36a2f43b..0f3b7c86 100644 --- a/R/compute.att_gt2.R +++ b/R/compute.att_gt2.R @@ -529,13 +529,9 @@ compute.att_gt2 <- function(dp2) { # Check for NULL first (estimation failed or was skipped) if (is.null(gt_result)) { - if(dp2$base_period == "universal"){ - inffunc_updates <- rep(NA_real_, n) - gt_result <- list(att = NA, group = dp2$treated_groups[g], year = dp2$time_periods[t+tfac], post = post.treat, inffunc_updates = inffunc_updates) - return(gt_result) - } else { - return(NULL) - } + inffunc_updates <- rep(NA_real_, n) + gt_result <- list(att = NA, group = dp2$treated_groups[g], year = dp2$time_periods[t+tfac], post = post.treat, inffunc_updates = inffunc_updates) + return(gt_result) } # Base period normalization: ATT is 0 by construction @@ -547,14 +543,9 @@ compute.att_gt2 <- function(dp2) { if (is.null(gt_result$att)) { # Estimation returned a result but without an ATT - if(dp2$base_period == "universal"){ - inffunc_updates <- rep(NA_real_, n) - gt_result <- list(att = NA, group = dp2$treated_groups[g], year = dp2$time_periods[t+tfac], post = post.treat, inffunc_updates = inffunc_updates) - return(gt_result) - } else { - return(NULL) - } - + inffunc_updates <- rep(NA_real_, n) + gt_result <- list(att = NA, group = dp2$treated_groups[g], year = dp2$time_periods[t+tfac], post = post.treat, inffunc_updates = inffunc_updates) + return(gt_result) } else { att <- gt_result$att inf_func <- gt_result$inf_func From beeb7290e0c51a391f415133d651f3972fbad0ee Mon Sep 17 00:00:00 2001 From: pedrohcgs Date: Thu, 2 Apr 2026 17:18:47 -0400 Subject: [PATCH 10/20] Fix RC custom-estimator test IF length, 
restore geom_errorbarh, update docs - Fix test my_rc_est: capture length(y) before subsetting to avoid returning a short influence function (sum(D==0) vs n_obs), which caused silent recycling warnings masked by blanket suppressWarnings() - Replace suppressWarnings with withCallingHandlers that promotes recycling warnings to errors while muffling expected Wald pre-test warns - Guard fix_weights="varying" + custom est_method rejection to panel=TRUE only, so the RC path works correctly with custom estimators - Restore geom_errorbarh() in splot() for horizontal group-time plots - Update est_method docs (roxygen + Rd) to describe both panel and RC custom-estimator signatures and correct return field name (att.inf.func) Co-Authored-By: Claude Opus 4.6 (1M context) --- R/att_gt.R | 41 +++++++++++++++++++++++------------- R/gplot.R | 2 +- man/att_gt.Rd | 33 ++++++++++++++++++++--------- tests/testthat/test-att_gt.R | 35 +++++++++++++++++++++++++++--- 4 files changed, 82 insertions(+), 29 deletions(-) diff --git a/R/att_gt.R b/R/att_gt.R index 8830c273..99f9d8c6 100644 --- a/R/att_gt.R +++ b/R/att_gt.R @@ -89,16 +89,27 @@ #' include "ipw" for inverse probability weighting and "reg" for #' first step regression estimators. The user can also pass their #' own function for estimating group time average treatment -#' effects. This should be a function -#' `f(Y1,Y0,treat,covariates)` where `Y1` is an -#' `n` x `1` vector of outcomes in the post-treatment -#' outcomes, `Y0` is an `n` x `1` vector of -#' pre-treatment outcomes, `treat` is a vector indicating -#' whether or not an individual participates in the treatment, -#' and `covariates` is an `n` x `k` matrix of -#' covariates. The function should return a list that includes -#' `ATT` (an estimated average treatment effect), and -#' `inf.func` (an `n` x `1` influence function). +#' effects. 
The required signature depends on the data structure: +#' +#' **Panel data** (`panel=TRUE`): `f(Y1, Y0, treat, covariates, +#' i.weights, inffunc, ...)` where `Y1` is an `n x 1` vector of +#' post-treatment outcomes, `Y0` is an `n x 1` vector of +#' pre-treatment outcomes, `treat` is a binary vector indicating +#' treatment group membership, `covariates` is an `n x k` matrix, +#' `i.weights` is a vector of sampling weights, and `inffunc` is a +#' logical requesting influence-function computation. +#' +#' **Repeated cross sections / unbalanced panel** (`panel=FALSE`): +#' `f(y, post, D, covariates, i.weights, inffunc, ...)` where `y` is +#' the outcome vector (length `n`), `post` is a binary indicator for +#' the post-treatment period, `D` is a binary treatment indicator, +#' `covariates` is an `n x k` matrix, `i.weights` is a vector of +#' sampling weights, and `inffunc` is a logical. +#' +#' In both cases the function should return a list that includes +#' `ATT` (the estimated group-time average treatment effect) and +#' `att.inf.func` (an `n x 1` influence function — one entry per +#' observation passed into the estimator). #' The function can return other things as well, but these are #' the only two that are required. `est_method` is only used #' if covariates are included. @@ -274,11 +285,11 @@ att_gt <- function(yname, "(panel = FALSE) because units are not tracked across periods. ", "Use fix_weights = \"varying\" or NULL instead.") } - if (fix_weights == "varying" && inherits(est_method, "function")) { - stop("fix_weights = \"varying\" is not currently supported with custom est_method functions. ", - "The \"varying\" option uses repeated cross-section estimators internally, which require ", - "a different function signature (y, post, D) than the documented panel signature (y1, y0, D). 
", - "Use fix_weights = NULL, \"base_period\", or \"first_period\" instead.") + if (fix_weights == "varying" && panel && inherits(est_method, "function")) { + stop("fix_weights = \"varying\" is not currently supported with custom est_method functions ", + "when panel = TRUE. The \"varying\" option uses repeated cross-section estimators internally, ", + "which require a different function signature (y, post, D) than the documented panel signature ", + "(y1, y0, D). Use fix_weights = NULL, \"base_period\", or \"first_period\" instead.") } } diff --git a/R/gplot.R b/R/gplot.R index e81d5459..932c7d15 100644 --- a/R/gplot.R +++ b/R/gplot.R @@ -85,7 +85,7 @@ splot <- function(ssresults, ylim=NULL, xlab=NULL, ylab=NULL, title="Group", xmax=(att+c*att.se))) + geom_point(aes(colour=post), size=1.5) + #geom_ribbon(aes(x=as.numeric(year)), alpha=0.2) + - geom_errorbar(aes(colour=post), width=0.1) + + geom_errorbarh(aes(colour=post), height=0.1) + scale_y_discrete(breaks=as.factor(ssresults$year)) + #scale_x_discrete(breaks=dabreaks, labels=as.character(dabreaks)) + scale_x_continuous(limits=ylim) + diff --git a/man/att_gt.Rd b/man/att_gt.Rd index aa37a707..18865fb4 100644 --- a/man/att_gt.Rd +++ b/man/att_gt.Rd @@ -173,16 +173,29 @@ approach in the \code{DRDID} package. Other built-in methods include "ipw" for inverse probability weighting and "reg" for first step regression estimators. The user can also pass their own function for estimating group time average treatment -effects. This should be a function -\code{f(Y1,Y0,treat,covariates)} where \code{Y1} is an -\code{n} x \code{1} vector of outcomes in the post-treatment -outcomes, \code{Y0} is an \code{n} x \code{1} vector of -pre-treatment outcomes, \code{treat} is a vector indicating -whether or not an individual participates in the treatment, -and \code{covariates} is an \code{n} x \code{k} matrix of -covariates. 
The function should return a list that includes -\code{ATT} (an estimated average treatment effect), and -\code{inf.func} (an \code{n} x \code{1} influence function). +effects. The required signature depends on the data structure: + +\strong{Panel data} (\code{panel=TRUE}): +\code{f(Y1, Y0, treat, covariates, i.weights, inffunc, ...)} where +\code{Y1} is an \code{n x 1} vector of post-treatment outcomes, +\code{Y0} is an \code{n x 1} vector of pre-treatment outcomes, +\code{treat} is a binary vector indicating treatment group membership, +\code{covariates} is an \code{n x k} matrix, +\code{i.weights} is a vector of sampling weights, and +\code{inffunc} is a logical requesting influence-function computation. + +\strong{Repeated cross sections / unbalanced panel} (\code{panel=FALSE}): +\code{f(y, post, D, covariates, i.weights, inffunc, ...)} where +\code{y} is the outcome vector (length \code{n}), \code{post} is a +binary indicator for the post-treatment period, \code{D} is a binary +treatment indicator, \code{covariates} is an \code{n x k} matrix, +\code{i.weights} is a vector of sampling weights, and +\code{inffunc} is a logical. + +In both cases the function should return a list that includes +\code{ATT} (the estimated group-time average treatment effect) and +\code{att.inf.func} (an \code{n x 1} influence function --- one entry +per observation passed into the estimator). The function can return other things as well, but these are the only two that are required. \code{est_method} is only used if covariates are included.} diff --git a/tests/testthat/test-att_gt.R b/tests/testthat/test-att_gt.R index 6c3f66cd..73a9ae11 100644 --- a/tests/testthat/test-att_gt.R +++ b/tests/testthat/test-att_gt.R @@ -818,15 +818,44 @@ test_that("fix_weights validation", { "not supported for repeated cross sections" ) - # varying not supported with custom est_method - my_est <- function(y1, y0, D, covariates, i.weights, inffunc, ...) 
{ + # varying not supported with custom est_method when panel = TRUE + my_panel_est <- function(y1, y0, D, covariates, i.weights, inffunc, ...) { list(ATT = mean(y1 - y0), att.inf.func = rep(0, length(y1))) } expect_error( att_gt(yname="Y", data=data, tname="period", idname="id", - gname="G", fix_weights="varying", est_method=my_est, bstrap=FALSE), + gname="G", fix_weights="varying", est_method=my_panel_est, + panel=TRUE, bstrap=FALSE), "not currently supported with custom est_method" ) + + # varying IS supported with custom est_method when panel = FALSE (RC signature) + my_rc_est <- function(y, post, D, covariates, i.weights, inffunc, ...) { + n_obs <- length(y) + post_c <- post[D==0] + y_c <- y[D==0] + w_c <- i.weights[D==0] + att <- mean(y_c[post_c==1] * w_c[post_c==1]) / mean(w_c[post_c==1]) - + mean(y_c[post_c==0] * w_c[post_c==0]) / mean(w_c[post_c==0]) + list(ATT = att, att.inf.func = rep(0, n_obs)) + } + # Wald pre-test warning is expected with this small sim dataset (group 2 + # has only one pre-treatment period), but the key thing is: no error and + # no recycling warnings from mismatched influence-function length. 
+ rc_result <- expect_no_error( + withCallingHandlers( + att_gt(yname="Y", data=data, tname="period", idname="id", + gname="G", fix_weights="varying", est_method=my_rc_est, + panel=FALSE, bstrap=FALSE), + warning = function(w) { + if (grepl("not a multiple of replacement length", conditionMessage(w))) + stop("IF length mismatch: ", conditionMessage(w)) + invokeRestart("muffleWarning") + } + ) + ) + expect_true(inherits(rc_result, "MP")) + expect_false(anyNA(rc_result$att)) }) # ============================================================================= From 92db4fb600af93e1c5c96040a7f48566c10aa925 Mon Sep 17 00:00:00 2001 From: pedrohcgs Date: Thu, 2 Apr 2026 19:09:33 -0400 Subject: [PATCH 11/20] Fix RC custom-estimator test IF length, restore geom_errorbarh, update docs - Fix unbalanced panel crash with fix_weights="base_period"/"first_period" in slow mode: use current disdat IDs instead of stale rightids for influence function aggregation (fixes length mismatch error) - Fix est_method docs: signature was Y1,Y0,treat but code calls y1,y0,D --- R/att_gt.R | 8 ++++---- R/compute.att_gt.R | 4 +++- man/DIDparams.Rd | 30 ++++++++++++++++++++---------- man/att_gt.Rd | 29 +++++++++++++---------------- man/conditional_did_pretest.Rd | 30 ++++++++++++++++++++---------- man/pre_process_did.Rd | 30 ++++++++++++++++++++---------- man/pre_process_did2.Rd | 30 ++++++++++++++++++++---------- 7 files changed, 100 insertions(+), 61 deletions(-) diff --git a/R/att_gt.R b/R/att_gt.R index 99f9d8c6..df7f715e 100644 --- a/R/att_gt.R +++ b/R/att_gt.R @@ -91,10 +91,10 @@ #' own function for estimating group time average treatment #' effects. 
The required signature depends on the data structure: #' -#' **Panel data** (`panel=TRUE`): `f(Y1, Y0, treat, covariates, -#' i.weights, inffunc, ...)` where `Y1` is an `n x 1` vector of -#' post-treatment outcomes, `Y0` is an `n x 1` vector of -#' pre-treatment outcomes, `treat` is a binary vector indicating +#' **Panel data** (`panel=TRUE`): `f(y1, y0, D, covariates, +#' i.weights, inffunc, ...)` where `y1` is an `n x 1` vector of +#' post-treatment outcomes, `y0` is an `n x 1` vector of +#' pre-treatment outcomes, `D` is a binary vector indicating #' treatment group membership, `covariates` is an `n x k` matrix, #' `i.weights` is a vector of sampling weights, and `inffunc` is a #' logical requesting influence-function computation. diff --git a/R/compute.att_gt.R b/R/compute.att_gt.R index 0d7a5315..0900d0f3 100644 --- a/R/compute.att_gt.R +++ b/R/compute.att_gt.R @@ -596,7 +596,9 @@ compute.att_gt <- function(dp) { ) } else { # aggregate inf functions by id (order by id) - aggte_inffunc <- suppressWarnings(stats::aggregate(attgt$att.inf.func, list(rightids), sum)) + # Use current disdat$.rowid (may differ from rightids if fix_weights dropped obs) + current_ids <- disdat$.rowid[disdat$.G == 1 | disdat$.C == 1] + aggte_inffunc <- suppressWarnings(stats::aggregate(attgt$att.inf.func, list(current_ids), sum)) idx <- which(unique(data$.rowid) %in% aggte_inffunc[, 1]) inffunc_updates[[update_counter]] <- list( indices = idx, diff --git a/man/DIDparams.Rd b/man/DIDparams.Rd index 1ec264d1..f5f232d9 100644 --- a/man/DIDparams.Rd +++ b/man/DIDparams.Rd @@ -175,16 +175,26 @@ approach in the \code{DRDID} package. Other built-in methods include "ipw" for inverse probability weighting and "reg" for first step regression estimators. The user can also pass their own function for estimating group time average treatment -effects. 
This should be a function -\code{f(Y1,Y0,treat,covariates)} where \code{Y1} is an -\code{n} x \code{1} vector of outcomes in the post-treatment -outcomes, \code{Y0} is an \code{n} x \code{1} vector of -pre-treatment outcomes, \code{treat} is a vector indicating -whether or not an individual participates in the treatment, -and \code{covariates} is an \code{n} x \code{k} matrix of -covariates. The function should return a list that includes -\code{ATT} (an estimated average treatment effect), and -\code{inf.func} (an \code{n} x \code{1} influence function). +effects. The required signature depends on the data structure: + +\strong{Panel data} (\code{panel=TRUE}): \code{f(y1, y0, D, covariates, i.weights, inffunc, ...)} where \code{y1} is an \verb{n x 1} vector of +post-treatment outcomes, \code{y0} is an \verb{n x 1} vector of +pre-treatment outcomes, \code{D} is a binary vector indicating +treatment group membership, \code{covariates} is an \verb{n x k} matrix, +\code{i.weights} is a vector of sampling weights, and \code{inffunc} is a +logical requesting influence-function computation. + +\strong{Repeated cross sections / unbalanced panel} (\code{panel=FALSE}): +\code{f(y, post, D, covariates, i.weights, inffunc, ...)} where \code{y} is +the outcome vector (length \code{n}), \code{post} is a binary indicator for +the post-treatment period, \code{D} is a binary treatment indicator, +\code{covariates} is an \verb{n x k} matrix, \code{i.weights} is a vector of +sampling weights, and \code{inffunc} is a logical. + +In both cases the function should return a list that includes +\code{ATT} (the estimated group-time average treatment effect) and +\code{att.inf.func} (an \verb{n x 1} influence function — one entry per +observation passed into the estimator). The function can return other things as well, but these are the only two that are required. 
\code{est_method} is only used if covariates are included.} diff --git a/man/att_gt.Rd b/man/att_gt.Rd index 18865fb4..6902dfbb 100644 --- a/man/att_gt.Rd +++ b/man/att_gt.Rd @@ -175,27 +175,24 @@ first step regression estimators. The user can also pass their own function for estimating group time average treatment effects. The required signature depends on the data structure: -\strong{Panel data} (\code{panel=TRUE}): -\code{f(Y1, Y0, treat, covariates, i.weights, inffunc, ...)} where -\code{Y1} is an \code{n x 1} vector of post-treatment outcomes, -\code{Y0} is an \code{n x 1} vector of pre-treatment outcomes, -\code{treat} is a binary vector indicating treatment group membership, -\code{covariates} is an \code{n x k} matrix, -\code{i.weights} is a vector of sampling weights, and -\code{inffunc} is a logical requesting influence-function computation. +\strong{Panel data} (\code{panel=TRUE}): \code{f(y1, y0, D, covariates, i.weights, inffunc, ...)} where \code{y1} is an \verb{n x 1} vector of +post-treatment outcomes, \code{y0} is an \verb{n x 1} vector of +pre-treatment outcomes, \code{D} is a binary vector indicating +treatment group membership, \code{covariates} is an \verb{n x k} matrix, +\code{i.weights} is a vector of sampling weights, and \code{inffunc} is a +logical requesting influence-function computation. \strong{Repeated cross sections / unbalanced panel} (\code{panel=FALSE}): -\code{f(y, post, D, covariates, i.weights, inffunc, ...)} where -\code{y} is the outcome vector (length \code{n}), \code{post} is a -binary indicator for the post-treatment period, \code{D} is a binary -treatment indicator, \code{covariates} is an \code{n x k} matrix, -\code{i.weights} is a vector of sampling weights, and -\code{inffunc} is a logical. 
+\code{f(y, post, D, covariates, i.weights, inffunc, ...)} where \code{y} is +the outcome vector (length \code{n}), \code{post} is a binary indicator for +the post-treatment period, \code{D} is a binary treatment indicator, +\code{covariates} is an \verb{n x k} matrix, \code{i.weights} is a vector of +sampling weights, and \code{inffunc} is a logical. In both cases the function should return a list that includes \code{ATT} (the estimated group-time average treatment effect) and -\code{att.inf.func} (an \code{n x 1} influence function --- one entry -per observation passed into the estimator). +\code{att.inf.func} (an \verb{n x 1} influence function — one entry per +observation passed into the estimator). The function can return other things as well, but these are the only two that are required. \code{est_method} is only used if covariates are included.} diff --git a/man/conditional_did_pretest.Rd b/man/conditional_did_pretest.Rd index 3646a6b3..48ab4252 100644 --- a/man/conditional_did_pretest.Rd +++ b/man/conditional_did_pretest.Rd @@ -135,16 +135,26 @@ approach in the \code{DRDID} package. Other built-in methods include "ipw" for inverse probability weighting and "reg" for first step regression estimators. The user can also pass their own function for estimating group time average treatment -effects. This should be a function -\code{f(Y1,Y0,treat,covariates)} where \code{Y1} is an -\code{n} x \code{1} vector of outcomes in the post-treatment -outcomes, \code{Y0} is an \code{n} x \code{1} vector of -pre-treatment outcomes, \code{treat} is a vector indicating -whether or not an individual participates in the treatment, -and \code{covariates} is an \code{n} x \code{k} matrix of -covariates. The function should return a list that includes -\code{ATT} (an estimated average treatment effect), and -\code{inf.func} (an \code{n} x \code{1} influence function). +effects. 
The required signature depends on the data structure: + +\strong{Panel data} (\code{panel=TRUE}): \code{f(y1, y0, D, covariates, i.weights, inffunc, ...)} where \code{y1} is an \verb{n x 1} vector of +post-treatment outcomes, \code{y0} is an \verb{n x 1} vector of +pre-treatment outcomes, \code{D} is a binary vector indicating +treatment group membership, \code{covariates} is an \verb{n x k} matrix, +\code{i.weights} is a vector of sampling weights, and \code{inffunc} is a +logical requesting influence-function computation. + +\strong{Repeated cross sections / unbalanced panel} (\code{panel=FALSE}): +\code{f(y, post, D, covariates, i.weights, inffunc, ...)} where \code{y} is +the outcome vector (length \code{n}), \code{post} is a binary indicator for +the post-treatment period, \code{D} is a binary treatment indicator, +\code{covariates} is an \verb{n x k} matrix, \code{i.weights} is a vector of +sampling weights, and \code{inffunc} is a logical. + +In both cases the function should return a list that includes +\code{ATT} (the estimated group-time average treatment effect) and +\code{att.inf.func} (an \verb{n x 1} influence function — one entry per +observation passed into the estimator). The function can return other things as well, but these are the only two that are required. \code{est_method} is only used if covariates are included.} diff --git a/man/pre_process_did.Rd b/man/pre_process_did.Rd index fc59e122..8969a52f 100644 --- a/man/pre_process_did.Rd +++ b/man/pre_process_did.Rd @@ -173,16 +173,26 @@ approach in the \code{DRDID} package. Other built-in methods include "ipw" for inverse probability weighting and "reg" for first step regression estimators. The user can also pass their own function for estimating group time average treatment -effects. 
This should be a function -\code{f(Y1,Y0,treat,covariates)} where \code{Y1} is an -\code{n} x \code{1} vector of outcomes in the post-treatment -outcomes, \code{Y0} is an \code{n} x \code{1} vector of -pre-treatment outcomes, \code{treat} is a vector indicating -whether or not an individual participates in the treatment, -and \code{covariates} is an \code{n} x \code{k} matrix of -covariates. The function should return a list that includes -\code{ATT} (an estimated average treatment effect), and -\code{inf.func} (an \code{n} x \code{1} influence function). +effects. The required signature depends on the data structure: + +\strong{Panel data} (\code{panel=TRUE}): \code{f(y1, y0, D, covariates, i.weights, inffunc, ...)} where \code{y1} is an \verb{n x 1} vector of +post-treatment outcomes, \code{y0} is an \verb{n x 1} vector of +pre-treatment outcomes, \code{D} is a binary vector indicating +treatment group membership, \code{covariates} is an \verb{n x k} matrix, +\code{i.weights} is a vector of sampling weights, and \code{inffunc} is a +logical requesting influence-function computation. + +\strong{Repeated cross sections / unbalanced panel} (\code{panel=FALSE}): +\code{f(y, post, D, covariates, i.weights, inffunc, ...)} where \code{y} is +the outcome vector (length \code{n}), \code{post} is a binary indicator for +the post-treatment period, \code{D} is a binary treatment indicator, +\code{covariates} is an \verb{n x k} matrix, \code{i.weights} is a vector of +sampling weights, and \code{inffunc} is a logical. + +In both cases the function should return a list that includes +\code{ATT} (the estimated group-time average treatment effect) and +\code{att.inf.func} (an \verb{n x 1} influence function — one entry per +observation passed into the estimator). The function can return other things as well, but these are the only two that are required. 
\code{est_method} is only used if covariates are included.} diff --git a/man/pre_process_did2.Rd b/man/pre_process_did2.Rd index d2edfc51..5710d35c 100644 --- a/man/pre_process_did2.Rd +++ b/man/pre_process_did2.Rd @@ -173,16 +173,26 @@ approach in the \code{DRDID} package. Other built-in methods include "ipw" for inverse probability weighting and "reg" for first step regression estimators. The user can also pass their own function for estimating group time average treatment -effects. This should be a function -\code{f(Y1,Y0,treat,covariates)} where \code{Y1} is an -\code{n} x \code{1} vector of outcomes in the post-treatment -outcomes, \code{Y0} is an \code{n} x \code{1} vector of -pre-treatment outcomes, \code{treat} is a vector indicating -whether or not an individual participates in the treatment, -and \code{covariates} is an \code{n} x \code{k} matrix of -covariates. The function should return a list that includes -\code{ATT} (an estimated average treatment effect), and -\code{inf.func} (an \code{n} x \code{1} influence function). +effects. The required signature depends on the data structure: + +\strong{Panel data} (\code{panel=TRUE}): \code{f(y1, y0, D, covariates, i.weights, inffunc, ...)} where \code{y1} is an \verb{n x 1} vector of +post-treatment outcomes, \code{y0} is an \verb{n x 1} vector of +pre-treatment outcomes, \code{D} is a binary vector indicating +treatment group membership, \code{covariates} is an \verb{n x k} matrix, +\code{i.weights} is a vector of sampling weights, and \code{inffunc} is a +logical requesting influence-function computation. 
+ +\strong{Repeated cross sections / unbalanced panel} (\code{panel=FALSE}): +\code{f(y, post, D, covariates, i.weights, inffunc, ...)} where \code{y} is +the outcome vector (length \code{n}), \code{post} is a binary indicator for +the post-treatment period, \code{D} is a binary treatment indicator, +\code{covariates} is an \verb{n x k} matrix, \code{i.weights} is a vector of +sampling weights, and \code{inffunc} is a logical. + +In both cases the function should return a list that includes +\code{ATT} (the estimated group-time average treatment effect) and +\code{att.inf.func} (an \verb{n x 1} influence function — one entry per +observation passed into the estimator). The function can return other things as well, but these are the only two that are required. \code{est_method} is only used if covariates are included.} From 06e39747b18160d4859f04584c303f182a72896f Mon Sep 17 00:00:00 2001 From: pedrohcgs Date: Thu, 2 Apr 2026 19:32:09 -0400 Subject: [PATCH 12/20] Fix RC custom-estimator test IF length, restore geom_errorbarh, update docs - Fix est_method docs: panel signature was f(Y1,Y0,treat,...) but code calls f(y1,y0,D,covariates,i.weights,inffunc,...). Updated to match. 
- Add test for unbalanced panel + fix_weights with units missing from reference period (exercises the row-dropping + IF aggregation path) --- tests/testthat/test-att_gt.R | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/testthat/test-att_gt.R b/tests/testthat/test-att_gt.R index 73a9ae11..a7c3755c 100644 --- a/tests/testthat/test-att_gt.R +++ b/tests/testthat/test-att_gt.R @@ -858,6 +858,35 @@ test_that("fix_weights validation", { expect_false(anyNA(rc_result$att)) }) +test_that("unbalanced panel fix_weights with units missing from reference period", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + + # Drop some treated units from first period so fix_weights="first_period" must drop them + first_p <- min(data$period) + drop_ids <- unique(data$id[data$G > 0])[1:10] + data <- data[!(data$id %in% drop_ids & data$period == first_p), ] + data$w <- runif(nrow(data), 1, 5) + + for (fw in c("first_period", "base_period")) { + res_slow <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", data = data, tname = "period", idname = "id", + gname = "G", allow_unbalanced_panel = TRUE, + fix_weights = fw, weightsname = "w", + bstrap = FALSE, faster_mode = FALSE) + )) + res_fast <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", data = data, tname = "period", idname = "id", + gname = "G", allow_unbalanced_panel = TRUE, + fix_weights = fw, weightsname = "w", + bstrap = FALSE, faster_mode = TRUE) + )) + expect_equal(res_slow$att, res_fast$att, tolerance = 1e-10, + label = paste("unbalanced", fw, "ATT match")) + } +}) + # ============================================================================= # Influence function consistency: slow vs fast mode ATT AND SE must match # ============================================================================= From f874296302f44e16037536ac6e7c13ee86a885a2 Mon Sep 17 00:00:00 2001 From: pedrohcgs Date: Thu, 2 Apr 2026 21:16:32 -0400 Subject: [PATCH 13/20] 
Address Copilot review: remove unused n1_rc, use uniqueN, fix test cleanup - Remove unused n1_rc variable in fix_weights="varying" RC path - Use data.table::uniqueN() instead of length(unique()) in check_balance - Fix test-inference.R: withr::defer for temp_lib cleanup, NOT_CRAN guard around network install, remove stale unlink at EOF --- R/compute.att_gt.R | 1 - R/utility_functions.R | 2 +- tests/testthat/test-inference.R | 23 +++++++++-------------- 3 files changed, 10 insertions(+), 16 deletions(-) diff --git a/R/compute.att_gt.R b/R/compute.att_gt.R index 0900d0f3..638dc3db 100644 --- a/R/compute.att_gt.R +++ b/R/compute.att_gt.R @@ -281,7 +281,6 @@ compute.att_gt <- function(dp) { post_rc <- as.numeric(disdat_long[[tname]] == tlist[t + tfac]) w_rc <- disdat_long$.w covariates_rc <- model.matrix(xformla, data = disdat_long) - n1_rc <- sum(G_rc + disdat_long$.C) # careful: n1 for RC is different if (inherits(est_method, "function")) { res <- do.call(est_method, c(list( diff --git a/R/utility_functions.R b/R/utility_functions.R index f35f3507..14719640 100644 --- a/R/utility_functions.R +++ b/R/utility_functions.R @@ -91,7 +91,7 @@ check_balance <- function(data, id_col, time_col) { panel_counts <- data[, .N, by = c(id_col)] # Determine the maximum number of time periods for any unit - max_time_periods <- length(unique(data[[time_col]])) + max_time_periods <- data.table::uniqueN(data[[time_col]]) # Check if every unit has the same number of time periods as max_time_periods is_balanced <- all(panel_counts$N == max_time_periods) diff --git a/tests/testthat/test-inference.R b/tests/testthat/test-inference.R index c01701eb..3d52b539 100644 --- a/tests/testthat/test-inference.R +++ b/tests/testthat/test-inference.R @@ -30,19 +30,16 @@ same_matrix_elem <- function(A, B) { temp_lib <- tempfile() dir.create(temp_lib) -old_did_available <- tryCatch({ - remotes::install_version("did", version = "2.1.2", lib = temp_lib, repos = "https://cloud.r-project.org", quiet = TRUE) - 
isTRUE(requireNamespace("did", lib.loc = temp_lib, quietly = TRUE)) -}, error = function(e) FALSE) - -if (!old_did_available) { - # Clean up and skip all tests in this file - unlink(temp_lib, recursive = TRUE) +withr::defer(unlink(temp_lib, recursive = TRUE), teardown_env()) + +old_did_available <- FALSE +if (!identical(Sys.getenv("NOT_CRAN"), "false")) { + old_did_available <- tryCatch({ + remotes::install_version("did", version = "2.1.2", lib = temp_lib, + repos = "https://cloud.r-project.org", quiet = TRUE) + isTRUE(requireNamespace("did", lib.loc = temp_lib, quietly = TRUE)) + }, error = function(e) FALSE) } -# install.packages( -# "https://cran.r-project.org/src/contrib/did_2.1.2.tar.gz", -# repos = NULL, type = "source", lib = temp_lib -# ) test_that("inference with balanced panel data and aggregations", { skip_if(!old_did_available, "did v2.1.2 not available from CRAN") @@ -865,5 +862,3 @@ test_that("inference with unbalanced panel and clustering", { expect_equal(group_2.1.2$se[1], group_new$se[1], tol = .01) expect_equal(cal_2.1.2$se[1], cal_new$se[1], tol = .01) }) - -unlink(temp_lib) From cc683390982ee869d646c5bebc6972e6fdb9191e Mon Sep 17 00:00:00 2001 From: pedrohcgs Date: Fri, 3 Apr 2026 15:02:35 -0400 Subject: [PATCH 14/20] Replace deprecated geom_errorbarh() with geom_errorbar(orientation="y") geom_errorbarh() was deprecated in ggplot2 4.0.0. The replacement produces identical plots with no deprecation warning. 
--- .gitignore | 1 + R/gplot.R | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index c392361a..da888abb 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ desktop.ini did.Rproj ..Rcheck/ .claude/ +CLAUDE.md .vscode/ .revdep_manual/ vignettes/*_cache/ \ No newline at end of file diff --git a/R/gplot.R b/R/gplot.R index 932c7d15..2f8fd66f 100644 --- a/R/gplot.R +++ b/R/gplot.R @@ -85,7 +85,7 @@ splot <- function(ssresults, ylim=NULL, xlab=NULL, ylab=NULL, title="Group", xmax=(att+c*att.se))) + geom_point(aes(colour=post), size=1.5) + #geom_ribbon(aes(x=as.numeric(year)), alpha=0.2) + - geom_errorbarh(aes(colour=post), height=0.1) + + geom_errorbar(aes(colour=post), width=0.1, orientation="y") + scale_y_discrete(breaks=as.factor(ssresults$year)) + #scale_x_discrete(breaks=dabreaks, labels=as.character(dabreaks)) + scale_x_continuous(limits=ylim) + From def1de9aa6d439873ba6cd05a902d1c6ed30d6a3 Mon Sep 17 00:00:00 2001 From: pedrohcgs Date: Fri, 3 Apr 2026 15:19:50 -0400 Subject: [PATCH 15/20] Fix bare logit symbol, add ggplot2 version-gated fallback for geom_errorbarh - Use binomial(link="logit") instead of bare logit symbol in trimmer() - Version-gate geom_errorbar(orientation="y") to ggplot2 >= 3.3.0, falling back to geom_errorbarh() on older installations --- R/gplot.R | 8 +++++++- R/utility_functions.R | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/R/gplot.R b/R/gplot.R index 2f8fd66f..87db6fda 100644 --- a/R/gplot.R +++ b/R/gplot.R @@ -80,12 +80,18 @@ splot <- function(ssresults, ylim=NULL, xlab=NULL, ylab=NULL, title="Group", ylab <- 'Group' } + errorbar_layer <- if (utils::packageVersion("ggplot2") >= "3.3.0") { + geom_errorbar(aes(colour=post), width=0.1, orientation="y") + } else { + geom_errorbarh(aes(colour=post), height=0.1) + } + p <- ggplot(ssresults, aes(y=as.factor(year), x=att, xmin=(att-c*att.se), xmax=(att+c*att.se))) + geom_point(aes(colour=post), size=1.5) + 
#geom_ribbon(aes(x=as.numeric(year)), alpha=0.2) + - geom_errorbar(aes(colour=post), width=0.1, orientation="y") + + errorbar_layer + scale_y_discrete(breaks=as.factor(ssresults$year)) + #scale_x_discrete(breaks=dabreaks, labels=as.character(dabreaks)) + scale_x_continuous(limits=ylim) + diff --git a/R/utility_functions.R b/R/utility_functions.R index 14719640..666c85ca 100644 --- a/R/utility_functions.R +++ b/R/utility_functions.R @@ -29,7 +29,7 @@ trimmer <- function(g, tname, idname, gname, xformla, data, control_group="notye this.data$D <- 1*this.data[,gname]==g this.pscore_reg <- glm(BMisc::toformula("D", BMisc::rhs.vars(xformla)), data=this.data, - family=binomial(link=logit)) + family=binomial(link="logit")) this.pscore <- predict(this.pscore_reg, type="response") dropper <- (this.pscore > threshold) & (this.data$D==1) if (sum(dropper) > 0) { From 33366a0cfc5c9e10092ed6a1bea9dfc73d8e11f2 Mon Sep 17 00:00:00 2001 From: pedrohcgs Date: Fri, 3 Apr 2026 16:10:49 -0400 Subject: [PATCH 16/20] Fix overlap guards for varying path, tighten validation, harden workflow - Move overlap/rank checks inside fix_weights="varying" branch to run on RC data (covariates_rc, G_rc) instead of wide panel data - Add droplevels() on disdat_long before model.matrix() to prevent phantom zero-column dummies from unused factor levels - Keep panel guards in non-varying path (moved inside tryCatch) - Restore fix_weights="varying" + custom est_method block for all panel data (remove allow_unbalanced_panel bypass that left balanced panels exposed to wrong-signature dispatch) - Remove misleading "set panel=FALSE" workaround from error message - Workflow: scope idempotency check to base branch only, not --all refs --- .github/workflows/bump-version.yaml | 17 +++- R/att_gt.R | 4 +- R/compute.att_gt.R | 142 +++++++++++++--------------- 3 files changed, 85 insertions(+), 78 deletions(-) diff --git a/.github/workflows/bump-version.yaml b/.github/workflows/bump-version.yaml index 7e0364dd..52450178 
100644 --- a/.github/workflows/bump-version.yaml +++ b/.github/workflows/bump-version.yaml @@ -21,9 +21,18 @@ jobs: - uses: actions/checkout@v4 with: ref: ${{ github.event.pull_request.base.ref }} + fetch-depth: 0 - name: Bump dev version in DESCRIPTION + env: + PR_NUMBER: ${{ github.event.pull_request.number }} run: | + # Idempotency: check if this PR already triggered a bump (base branch only) + if git log --oneline ${{ github.event.pull_request.base.ref }} | grep -q "Bump version.*PR #${PR_NUMBER}"; then + echo "Version already bumped for PR #${PR_NUMBER}, skipping." + exit 0 + fi + # Extract current version current=$(grep '^Version:' DESCRIPTION | sed 's/Version: //') echo "Current version: $current" @@ -47,7 +56,11 @@ jobs: git config user.name "github-actions[bot]" git config user.email "github-actions[bot]@users.noreply.github.com" - # Commit and push + # Commit and push with retry for race conditions git add DESCRIPTION - git diff --cached --quiet || git commit -m "Bump version to ${new_version}" + git diff --cached --quiet && { echo "No changes to commit"; exit 0; } + git commit -m "Bump version to ${new_version} (PR #${PR_NUMBER})" + + # Pull latest before pushing to handle concurrent merges + git pull --rebase origin ${{ github.event.pull_request.base.ref }} || true git push diff --git a/R/att_gt.R b/R/att_gt.R index df7f715e..7c168fb6 100644 --- a/R/att_gt.R +++ b/R/att_gt.R @@ -287,8 +287,8 @@ att_gt <- function(yname, } if (fix_weights == "varying" && panel && inherits(est_method, "function")) { stop("fix_weights = \"varying\" is not currently supported with custom est_method functions ", - "when panel = TRUE. The \"varying\" option uses repeated cross-section estimators internally, ", - "which require a different function signature (y, post, D) than the documented panel signature ", + "on panel data. 
The \"varying\" option uses repeated cross-section estimators internally, ", + "which require a different function signature (y, post, D) than the panel signature ", "(y1, y0, D). Use fix_weights = NULL, \"base_period\", or \"first_period\" instead.") } } diff --git a/R/compute.att_gt.R b/R/compute.att_gt.R index 638dc3db..dd1e12e1 100644 --- a/R/compute.att_gt.R +++ b/R/compute.att_gt.R @@ -221,50 +221,6 @@ compute.att_gt <- function(dp) { # matrix of covariates covariates <- model.matrix(xformla, data = disdat) - #----------------------------------------------------------------------------- - # more checks for enough observations in each group - - # if using custom estimation method, skip this part - custom_est_method <- is.function(est_method) - - if (!custom_est_method) { - pscore_problems_likely <- FALSE - reg_problems_likely <- FALSE - - # checks for pscore based methods - if (est_method %in% c("dr", "ipw")) { - preliminary_logit <- fastglm::fastglm(covariates, G, family = binomial()) - preliminary_pscores <- preliminary_logit$fitted.values - if (max(preliminary_pscores) >= 0.999) { - pscore_problems_likely <- TRUE - warning(paste0("overlap condition violated for ", glist[g], " in time period ", tlist[t + tfac])) - } - } - - # check if can run regression using control units - if (est_method %in% c("dr", "reg")) { - control_covs <- covariates[G == 0, , drop = FALSE] - # if (determinant(t(control_covs)%*%control_covs, logarithm=FALSE)$modulus < .Machine$double.eps) { - if (rcond(t(control_covs) %*% control_covs) < .Machine$double.eps) { - reg_problems_likely <- TRUE - warning(paste0("Not enough control units for group ", glist[g], " in time period ", tlist[t + tfac], " to run specified regression")) - } - } - - if (reg_problems_likely | pscore_problems_likely) { - attgt.list[[counter]] <- list(att = NA, group = glist[g], year = tlist[(t + tfac)], post = post.treat) - inffunc_updates[[update_counter]] <- list( - indices = seq_len(n), - values = rep(NA_real_, n) 
- ) - - # Update the counters - update_counter <- update_counter + 1 - counter <- counter + 1 - next - } - } - #----------------------------------------------------------------------------- # code for actually computing att(g,t) #----------------------------------------------------------------------------- @@ -275,13 +231,31 @@ compute.att_gt <- function(dp) { # Go back to long-format data for this (g,t) cell disdat_long <- data[time_mask] disdat_long_idx <- disdat_long$.G == 1 | disdat_long$.C == 1 - disdat_long <- disdat_long[disdat_long_idx] + disdat_long <- droplevels(disdat_long[disdat_long_idx]) Y_rc <- disdat_long[[yname]] G_rc <- disdat_long$.G post_rc <- as.numeric(disdat_long[[tname]] == tlist[t + tfac]) w_rc <- disdat_long$.w covariates_rc <- model.matrix(xformla, data = disdat_long) + # Run overlap/rank checks on RC data (not wide panel data) + if (!is.function(est_method)) { + if (est_method %in% c("dr", "ipw")) { + preliminary_logit <- fastglm::fastglm(covariates_rc, G_rc, family = binomial()) + if (max(preliminary_logit$fitted.values) >= 0.999) { + warning(paste0("overlap condition violated for ", glist[g], " in time period ", tlist[t + tfac])) + stop("overlap") + } + } + if (est_method %in% c("dr", "reg")) { + control_covs_rc <- covariates_rc[G_rc == 0, , drop = FALSE] + if (rcond(t(control_covs_rc) %*% control_covs_rc) < .Machine$double.eps) { + warning(paste0("Not enough control units for group ", glist[g], " in time period ", tlist[t + tfac], " to run specified regression")) + stop("singular") + } + } + } + if (inherits(est_method, "function")) { res <- do.call(est_method, c(list( y = Y_rc, post = post_rc, @@ -301,36 +275,56 @@ compute.att_gt <- function(dp) { covariates = covariates_rc, i.weights = w_rc, boot = FALSE, inffunc = TRUE) } - } else if (inherits(est_method, "function")) { - # user-specified function - res <- do.call(est_method, c(list( - y1 = Ypost, y0 = Ypre, - D = G, - covariates = covariates, - i.weights = w, - inffunc = TRUE - ), 
extra_args)) - } else if (est_method == "ipw") { - # inverse-probability weights - res <- DRDID::std_ipw_did_panel(Ypost, Ypre, G, - covariates = covariates, - i.weights = w, - boot = FALSE, inffunc = TRUE - ) - } else if (est_method == "reg") { - # regression - res <- DRDID::reg_did_panel(Ypost, Ypre, G, - covariates = covariates, - i.weights = w, - boot = FALSE, inffunc = TRUE - ) } else { - # doubly robust, this is default - res <- DRDID::drdid_panel(Ypost, Ypre, G, - covariates = covariates, - i.weights = w, - boot = FALSE, inffunc = TRUE - ) + # Panel path: run overlap/rank checks on panel data + if (!is.function(est_method)) { + if (est_method %in% c("dr", "ipw")) { + preliminary_logit <- fastglm::fastglm(covariates, G, family = binomial()) + if (max(preliminary_logit$fitted.values) >= 0.999) { + warning(paste0("overlap condition violated for ", glist[g], " in time period ", tlist[t + tfac])) + stop("overlap") + } + } + if (est_method %in% c("dr", "reg")) { + control_covs <- covariates[G == 0, , drop = FALSE] + if (rcond(t(control_covs) %*% control_covs) < .Machine$double.eps) { + warning(paste0("Not enough control units for group ", glist[g], " in time period ", tlist[t + tfac], " to run specified regression")) + stop("singular") + } + } + } + + if (inherits(est_method, "function")) { + # user-specified function + res <- do.call(est_method, c(list( + y1 = Ypost, y0 = Ypre, + D = G, + covariates = covariates, + i.weights = w, + inffunc = TRUE + ), extra_args)) + } else if (est_method == "ipw") { + # inverse-probability weights + res <- DRDID::std_ipw_did_panel(Ypost, Ypre, G, + covariates = covariates, + i.weights = w, + boot = FALSE, inffunc = TRUE + ) + } else if (est_method == "reg") { + # regression + res <- DRDID::reg_did_panel(Ypost, Ypre, G, + covariates = covariates, + i.weights = w, + boot = FALSE, inffunc = TRUE + ) + } else { + # doubly robust, this is default + res <- DRDID::drdid_panel(Ypost, Ypre, G, + covariates = covariates, + i.weights = w, + 
boot = FALSE, inffunc = TRUE + ) + } } # adjust influence function to account for only using From 9ee3409d8465edf259f9a9587fef502d7d5ee5c0 Mon Sep 17 00:00:00 2001 From: pedrohcgs Date: Fri, 3 Apr 2026 16:37:16 -0400 Subject: [PATCH 17/20] Fix workflow PR number matching, document varying-mode covariate behavior - Use grep -qF for exact PR number match in idempotency check to prevent PR #1 from matching PR #10, #100, etc. - Document that fix_weights="varying" also evaluates covariates at each period (not just weights), since it uses the RC estimator --- .github/workflows/bump-version.yaml | 4 ++-- R/att_gt.R | 7 +++++-- man/DIDparams.Rd | 7 +++++-- man/att_gt.Rd | 7 +++++-- man/pre_process_did.Rd | 7 +++++-- man/pre_process_did2.Rd | 7 +++++-- 6 files changed, 27 insertions(+), 12 deletions(-) diff --git a/.github/workflows/bump-version.yaml b/.github/workflows/bump-version.yaml index 52450178..27bf182f 100644 --- a/.github/workflows/bump-version.yaml +++ b/.github/workflows/bump-version.yaml @@ -27,8 +27,8 @@ jobs: env: PR_NUMBER: ${{ github.event.pull_request.number }} run: | - # Idempotency: check if this PR already triggered a bump (base branch only) - if git log --oneline ${{ github.event.pull_request.base.ref }} | grep -q "Bump version.*PR #${PR_NUMBER}"; then + # Idempotency: check if this PR already triggered a bump (exact match) + if git log --oneline ${{ github.event.pull_request.base.ref }} | grep -qF "(PR #${PR_NUMBER})"; then echo "Version already bumped for PR #${PR_NUMBER}, skipping." exit 0 fi diff --git a/R/att_gt.R b/R/att_gt.R index 7c168fb6..548694f5 100644 --- a/R/att_gt.R +++ b/R/att_gt.R @@ -49,8 +49,11 @@ #' repeated cross-section DRDID estimators so that pre-period and #' post-period observations each carry their own weight. This is the #' most flexible option but sacrifices the efficiency of the panel -#' estimator. For RC/unbalanced panel, this is identical to the -#' default. 
Not supported with custom \code{est_method} functions.} +#' estimator. Note: when covariates are time-varying, this mode also +#' evaluates covariates at each period (rather than fixing them at the +#' base period as the panel estimator does). For RC/unbalanced panel, +#' this is identical to the default. Not supported with custom +#' \code{est_method} functions.} #' \item{\code{"base_period"}}{Fixes weights at the base period (g-1) for #' all (g,t) cells within a group, for both pre-treatment and #' post-treatment comparisons. Ensures all cells within a group use the diff --git a/man/DIDparams.Rd b/man/DIDparams.Rd index f5f232d9..156e4103 100644 --- a/man/DIDparams.Rd +++ b/man/DIDparams.Rd @@ -121,8 +121,11 @@ for all panel types. For balanced panel data, this switches to the repeated cross-section DRDID estimators so that pre-period and post-period observations each carry their own weight. This is the most flexible option but sacrifices the efficiency of the panel -estimator. For RC/unbalanced panel, this is identical to the -default. Not supported with custom \code{est_method} functions.} +estimator. Note: when covariates are time-varying, this mode also +evaluates covariates at each period (rather than fixing them at the +base period as the panel estimator does). For RC/unbalanced panel, +this is identical to the default. Not supported with custom +\code{est_method} functions.} \item{\code{"base_period"}}{Fixes weights at the base period (g-1) for all (g,t) cells within a group, for both pre-treatment and post-treatment comparisons. Ensures all cells within a group use the diff --git a/man/att_gt.Rd b/man/att_gt.Rd index 6902dfbb..6d07ad0a 100644 --- a/man/att_gt.Rd +++ b/man/att_gt.Rd @@ -131,8 +131,11 @@ for all panel types. For balanced panel data, this switches to the repeated cross-section DRDID estimators so that pre-period and post-period observations each carry their own weight. 
This is the most flexible option but sacrifices the efficiency of the panel -estimator. For RC/unbalanced panel, this is identical to the -default. Not supported with custom \code{est_method} functions.} +estimator. Note: when covariates are time-varying, this mode also +evaluates covariates at each period (rather than fixing them at the +base period as the panel estimator does). For RC/unbalanced panel, +this is identical to the default. Not supported with custom +\code{est_method} functions.} \item{\code{"base_period"}}{Fixes weights at the base period (g-1) for all (g,t) cells within a group, for both pre-treatment and post-treatment comparisons. Ensures all cells within a group use the diff --git a/man/pre_process_did.Rd b/man/pre_process_did.Rd index 8969a52f..36b9eac8 100644 --- a/man/pre_process_did.Rd +++ b/man/pre_process_did.Rd @@ -131,8 +131,11 @@ for all panel types. For balanced panel data, this switches to the repeated cross-section DRDID estimators so that pre-period and post-period observations each carry their own weight. This is the most flexible option but sacrifices the efficiency of the panel -estimator. For RC/unbalanced panel, this is identical to the -default. Not supported with custom \code{est_method} functions.} +estimator. Note: when covariates are time-varying, this mode also +evaluates covariates at each period (rather than fixing them at the +base period as the panel estimator does). For RC/unbalanced panel, +this is identical to the default. Not supported with custom +\code{est_method} functions.} \item{\code{"base_period"}}{Fixes weights at the base period (g-1) for all (g,t) cells within a group, for both pre-treatment and post-treatment comparisons. Ensures all cells within a group use the diff --git a/man/pre_process_did2.Rd b/man/pre_process_did2.Rd index 5710d35c..5740ffaf 100644 --- a/man/pre_process_did2.Rd +++ b/man/pre_process_did2.Rd @@ -131,8 +131,11 @@ for all panel types. 
For balanced panel data, this switches to the repeated cross-section DRDID estimators so that pre-period and post-period observations each carry their own weight. This is the most flexible option but sacrifices the efficiency of the panel -estimator. For RC/unbalanced panel, this is identical to the -default. Not supported with custom \code{est_method} functions.} +estimator. Note: when covariates are time-varying, this mode also +evaluates covariates at each period (rather than fixing them at the +base period as the panel estimator does). For RC/unbalanced panel, +this is identical to the default. Not supported with custom +\code{est_method} functions.} \item{\code{"base_period"}}{Fixes weights at the base period (g-1) for all (g,t) cells within a group, for both pre-treatment and post-treatment comparisons. Ensures all cells within a group use the From 447c3e5ad50a37996acc9a9555a3d5dfc7f466f0 Mon Sep 17 00:00:00 2001 From: pedrohcgs Date: Fri, 3 Apr 2026 17:04:13 -0400 Subject: [PATCH 18/20] Fix fix_weights="varying" to use pre-period covariates only The "varying" option should only change weights, not the covariate conditioning set. Both code paths now use pre-period covariates for all observations in the stacked RC data, matching the panel estimator's covariate handling. Post-treatment covariates are never used. 
Fast path: rbind(cov_pre, cov_pre) instead of rbind(cov_pre, cov_post) Slow path: match() lookup to map each observation to its unit's pre-period covariates regardless of row ordering in disdat_long --- R/att_gt.R | 11 +++++------ R/compute.att_gt.R | 12 +++++++++++- R/compute.att_gt2.R | 10 +++++----- man/DIDparams.Rd | 11 +++++------ man/att_gt.Rd | 11 +++++------ man/pre_process_did.Rd | 11 +++++------ man/pre_process_did2.Rd | 11 +++++------ 7 files changed, 41 insertions(+), 36 deletions(-) diff --git a/R/att_gt.R b/R/att_gt.R index 548694f5..1f3e633f 100644 --- a/R/att_gt.R +++ b/R/att_gt.R @@ -47,12 +47,11 @@ #' \item{\code{"varying"}}{Uses per-observation, period-specific weights #' for all panel types. For balanced panel data, this switches to the #' repeated cross-section DRDID estimators so that pre-period and -#' post-period observations each carry their own weight. This is the -#' most flexible option but sacrifices the efficiency of the panel -#' estimator. Note: when covariates are time-varying, this mode also -#' evaluates covariates at each period (rather than fixing them at the -#' base period as the panel estimator does). For RC/unbalanced panel, -#' this is identical to the default. Not supported with custom +#' post-period observations each carry their own weight. Covariates +#' are held fixed at their pre-period values (same as the default +#' panel estimator). This is the most flexible option for weights but +#' sacrifices the efficiency of the panel estimator. For RC/unbalanced +#' panel, this is identical to the default. 
Not supported with custom #' \code{est_method} functions.} #' \item{\code{"base_period"}}{Fixes weights at the base period (g-1) for #' all (g,t) cells within a group, for both pre-treatment and diff --git a/R/compute.att_gt.R b/R/compute.att_gt.R index dd1e12e1..494f8fbb 100644 --- a/R/compute.att_gt.R +++ b/R/compute.att_gt.R @@ -236,7 +236,17 @@ compute.att_gt <- function(dp) { G_rc <- disdat_long$.G post_rc <- as.numeric(disdat_long[[tname]] == tlist[t + tfac]) w_rc <- disdat_long$.w - covariates_rc <- model.matrix(xformla, data = disdat_long) + # Use pre-period covariates for all observations — fix_weights only + # changes weights, not the covariate conditioning set. + # Build a lookup from pre-period rows, then replicate for each obs. + pre_mask <- disdat_long[[tname]] == tlist[pret] + disdat_pre <- disdat_long[pre_mask] + cov_pre <- model.matrix(xformla, data = disdat_pre) + # Map each row in disdat_long to its unit's pre-period covariates + pre_ids <- disdat_pre[[idname]] + all_ids <- disdat_long[[idname]] + id_map <- match(all_ids, pre_ids) + covariates_rc <- cov_pre[id_map, , drop = FALSE] # Run overlap/rank checks on RC data (not wide panel data) if (!is.function(est_method)) { diff --git a/R/compute.att_gt2.R b/R/compute.att_gt2.R index be633471..2966271c 100644 --- a/R/compute.att_gt2.R +++ b/R/compute.att_gt2.R @@ -388,13 +388,13 @@ run_att_gt_estimation <- function(g, t, dp2){ post = rep(c(0L, 1L), each = n_units), i.weights = c(dp2$weights_tensor[[pret]], dp2$weights_tensor[[t+tfac]]) ) - # Stack covariates for both periods - cov_pre <- dp2$covariates_tensor[[pret]] - cov_post <- dp2$covariates_tensor[[t+tfac]] + # Use pre-period covariates for both halves — fix_weights only + # changes weights, not the covariate conditioning set + cov_pre <- dp2$covariates_tensor[[pret]] if (is.matrix(cov_pre)) { - covariates <- rbind(cov_pre, cov_post) + covariates <- rbind(cov_pre, cov_pre) } else { - covariates <- c(cov_pre, cov_post) + covariates <- c(cov_pre, 
cov_pre) } } else { # Default or fixed weight options: use panel estimator with single weight vector diff --git a/man/DIDparams.Rd b/man/DIDparams.Rd index 156e4103..c1230fc6 100644 --- a/man/DIDparams.Rd +++ b/man/DIDparams.Rd @@ -119,12 +119,11 @@ periods.} \item{\code{"varying"}}{Uses per-observation, period-specific weights for all panel types. For balanced panel data, this switches to the repeated cross-section DRDID estimators so that pre-period and -post-period observations each carry their own weight. This is the -most flexible option but sacrifices the efficiency of the panel -estimator. Note: when covariates are time-varying, this mode also -evaluates covariates at each period (rather than fixing them at the -base period as the panel estimator does). For RC/unbalanced panel, -this is identical to the default. Not supported with custom +post-period observations each carry their own weight. Covariates +are held fixed at their pre-period values (same as the default +panel estimator). This is the most flexible option for weights but +sacrifices the efficiency of the panel estimator. For RC/unbalanced +panel, this is identical to the default. Not supported with custom \code{est_method} functions.} \item{\code{"base_period"}}{Fixes weights at the base period (g-1) for all (g,t) cells within a group, for both pre-treatment and diff --git a/man/att_gt.Rd b/man/att_gt.Rd index 6d07ad0a..74c621ae 100644 --- a/man/att_gt.Rd +++ b/man/att_gt.Rd @@ -129,12 +129,11 @@ periods.} \item{\code{"varying"}}{Uses per-observation, period-specific weights for all panel types. For balanced panel data, this switches to the repeated cross-section DRDID estimators so that pre-period and -post-period observations each carry their own weight. This is the -most flexible option but sacrifices the efficiency of the panel -estimator. 
Note: when covariates are time-varying, this mode also -evaluates covariates at each period (rather than fixing them at the -base period as the panel estimator does). For RC/unbalanced panel, -this is identical to the default. Not supported with custom +post-period observations each carry their own weight. Covariates +are held fixed at their pre-period values (same as the default +panel estimator). This is the most flexible option for weights but +sacrifices the efficiency of the panel estimator. For RC/unbalanced +panel, this is identical to the default. Not supported with custom \code{est_method} functions.} \item{\code{"base_period"}}{Fixes weights at the base period (g-1) for all (g,t) cells within a group, for both pre-treatment and diff --git a/man/pre_process_did.Rd b/man/pre_process_did.Rd index 36b9eac8..8853c1cd 100644 --- a/man/pre_process_did.Rd +++ b/man/pre_process_did.Rd @@ -129,12 +129,11 @@ periods.} \item{\code{"varying"}}{Uses per-observation, period-specific weights for all panel types. For balanced panel data, this switches to the repeated cross-section DRDID estimators so that pre-period and -post-period observations each carry their own weight. This is the -most flexible option but sacrifices the efficiency of the panel -estimator. Note: when covariates are time-varying, this mode also -evaluates covariates at each period (rather than fixing them at the -base period as the panel estimator does). For RC/unbalanced panel, -this is identical to the default. Not supported with custom +post-period observations each carry their own weight. Covariates +are held fixed at their pre-period values (same as the default +panel estimator). This is the most flexible option for weights but +sacrifices the efficiency of the panel estimator. For RC/unbalanced +panel, this is identical to the default. 
Not supported with custom \code{est_method} functions.} \item{\code{"base_period"}}{Fixes weights at the base period (g-1) for all (g,t) cells within a group, for both pre-treatment and diff --git a/man/pre_process_did2.Rd b/man/pre_process_did2.Rd index 5740ffaf..5dd1404b 100644 --- a/man/pre_process_did2.Rd +++ b/man/pre_process_did2.Rd @@ -129,12 +129,11 @@ periods.} \item{\code{"varying"}}{Uses per-observation, period-specific weights for all panel types. For balanced panel data, this switches to the repeated cross-section DRDID estimators so that pre-period and -post-period observations each carry their own weight. This is the -most flexible option but sacrifices the efficiency of the panel -estimator. Note: when covariates are time-varying, this mode also -evaluates covariates at each period (rather than fixing them at the -base period as the panel estimator does). For RC/unbalanced panel, -this is identical to the default. Not supported with custom +post-period observations each carry their own weight. Covariates +are held fixed at their pre-period values (same as the default +panel estimator). This is the most flexible option for weights but +sacrifices the efficiency of the panel estimator. For RC/unbalanced +panel, this is identical to the default. Not supported with custom \code{est_method} functions.} \item{\code{"base_period"}}{Fixes weights at the base period (g-1) for all (g,t) cells within a group, for both pre-treatment and From 879b147cdd8d9a61d964b79292bb1ba820243909 Mon Sep 17 00:00:00 2001 From: pedrohcgs Date: Sat, 4 Apr 2026 14:44:35 -0400 Subject: [PATCH 19/20] Fix fix_weights="varying" covariate period for universal base period With base_period="universal", pret can be later than the current time period for placebo cells. The varying path was always using covariates from pret, which meant conditioning on future covariates for placebo cells. 
Now uses min(pret, t) to match the panel estimator's convention: always the earlier of the two comparison periods. --- R/compute.att_gt.R | 22 ++++++++++++---------- R/compute.att_gt2.R | 14 ++++++++------ 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/R/compute.att_gt.R b/R/compute.att_gt.R index 494f8fbb..251238d9 100644 --- a/R/compute.att_gt.R +++ b/R/compute.att_gt.R @@ -236,17 +236,19 @@ compute.att_gt <- function(dp) { G_rc <- disdat_long$.G post_rc <- as.numeric(disdat_long[[tname]] == tlist[t + tfac]) w_rc <- disdat_long$.w - # Use pre-period covariates for all observations — fix_weights only - # changes weights, not the covariate conditioning set. - # Build a lookup from pre-period rows, then replicate for each obs. - pre_mask <- disdat_long[[tname]] == tlist[pret] - disdat_pre <- disdat_long[pre_mask] - cov_pre <- model.matrix(xformla, data = disdat_pre) - # Map each row in disdat_long to its unit's pre-period covariates - pre_ids <- disdat_pre[[idname]] + # Use earlier-period covariates for all observations — fix_weights + # only changes weights, not the covariate conditioning set. + # Use min(pret, t+tfac) to match the panel estimator's convention: + # with base_period="universal", pret can be later than t for placebo cells. 
+ earlier_period <- tlist[min(pret, t + tfac)] + early_mask <- disdat_long[[tname]] == earlier_period + disdat_early <- disdat_long[early_mask] + cov_early <- model.matrix(xformla, data = disdat_early) + # Map each row in disdat_long to its unit's earlier-period covariates + early_ids <- disdat_early[[idname]] all_ids <- disdat_long[[idname]] - id_map <- match(all_ids, pre_ids) - covariates_rc <- cov_pre[id_map, , drop = FALSE] + id_map <- match(all_ids, early_ids) + covariates_rc <- cov_early[id_map, , drop = FALSE] # Run overlap/rank checks on RC data (not wide panel data) if (!is.function(est_method)) { diff --git a/R/compute.att_gt2.R b/R/compute.att_gt2.R index 2966271c..32a2cba0 100644 --- a/R/compute.att_gt2.R +++ b/R/compute.att_gt2.R @@ -388,13 +388,15 @@ run_att_gt_estimation <- function(g, t, dp2){ post = rep(c(0L, 1L), each = n_units), i.weights = c(dp2$weights_tensor[[pret]], dp2$weights_tensor[[t+tfac]]) ) - # Use pre-period covariates for both halves — fix_weights only - # changes weights, not the covariate conditioning set - cov_pre <- dp2$covariates_tensor[[pret]] - if (is.matrix(cov_pre)) { - covariates <- rbind(cov_pre, cov_pre) + # Use earlier-period covariates for both halves — fix_weights only + # changes weights, not the covariate conditioning set. + # Use min(pret, t) to match the panel estimator's convention: + # with base_period="universal", pret can be later than t for placebo cells. 
+ cov_early <- dp2$covariates_tensor[[base::min(pret, t)]] + if (is.matrix(cov_early)) { + covariates <- rbind(cov_early, cov_early) } else { - covariates <- c(cov_pre, cov_pre) + covariates <- c(cov_early, cov_early) } } else { # Default or fixed weight options: use panel estimator with single weight vector From d33aaedd0291b9d7111fae09c08dcc4da12b9097 Mon Sep 17 00:00:00 2001 From: pedrohcgs Date: Sun, 5 Apr 2026 12:42:03 -0400 Subject: [PATCH 20/20] Address reviewer non-blocking notes - Add skip_on_cran() to each test_that in test-inference.R (more idiomatic than relying on the top-level NOT_CRAN guard alone) - Remove misleading "Added broom to Suggests" NEWS entry (broom was added in 2.3.1.903, not 2.3.1.904) - Add test exercising splot() via ggdid(agg_grp) to cover the ggplot2 version-gated errorbar layer and guard against deprecation warnings --- NEWS.md | 2 -- tests/testthat/test-ggdid.R | 11 +++++++++++ tests/testthat/test-inference.R | 7 +++++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/NEWS.md b/NEWS.md index 23192372..aa6537c7 100644 --- a/NEWS.md +++ b/NEWS.md @@ -16,8 +16,6 @@ * Added `statistic` (t-statistic) and `p.value` (pointwise, two-sided) columns to `tidy()` output for both `MP` and `AGGTEobj` objects, following `broom` conventions - * Added `broom` to `Suggests` - * Fixed `glance.MP()` returning `NULL` for `ngroup` and `ntime` when `faster_mode = TRUE` (DIDparams2 uses different field names than DIDparams) * Fixed influence function aggregation bug in `fix_weights = "varying"` on balanced panel: the slow path now correctly uses `rowsum()` by unit ID instead of assuming stacked ordering diff --git a/tests/testthat/test-ggdid.R b/tests/testthat/test-ggdid.R index 1b80656f..00e8a889 100644 --- a/tests/testthat/test-ggdid.R +++ b/tests/testthat/test-ggdid.R @@ -86,3 +86,14 @@ test_that("ggdid.AGGTEobj works with ref_line=NULL", { p <- ggdid(agg_dyn, ref_line = NULL) expect_s3_class(p, "gg") }) + +test_that("splot works for 
group-type and renders without deprecation warnings", { + # splot() is the path that uses the ggplot2 version-gated errorbar layer. + # It is called via ggdid.AGGTEobj() for type="group". Verify the output is + # a valid ggplot object and no deprecation warnings leak through. + expect_warning(p <- ggdid(agg_grp), regexp = NA) + expect_s3_class(p, "gg") + # Verify the plot actually builds (catches any layer construction errors) + built <- ggplot2::ggplot_build(p) + expect_s3_class(built, "ggplot_built") +}) diff --git a/tests/testthat/test-inference.R b/tests/testthat/test-inference.R index 3d52b539..dc7ed5ed 100644 --- a/tests/testthat/test-inference.R +++ b/tests/testthat/test-inference.R @@ -42,6 +42,7 @@ if (!identical(Sys.getenv("NOT_CRAN"), "false")) { } test_that("inference with balanced panel data and aggregations", { + skip_on_cran() skip_if(!old_did_available, "did v2.1.2 not available from CRAN") sp <- did::reset.sim() data <- did::build_sim_dataset(sp) @@ -176,6 +177,7 @@ test_that("inference with balanced panel data and aggregations", { test_that("inference with clustering", { + skip_on_cran() skip_if(!old_did_available, "did v2.1.2 not available from CRAN") sp <- did::reset.sim() data <- did::build_sim_dataset(sp) @@ -305,6 +307,7 @@ test_that("inference with clustering", { }) test_that("same inference with unbalanced panel and panel data", { + skip_on_cran() skip_if(!old_did_available, "did v2.1.2 not available from CRAN") sp <- did::reset.sim() data <- did::build_sim_dataset(sp) @@ -336,6 +339,7 @@ test_that("same inference with unbalanced panel and panel data", { test_that("inference with repeated cross sections", { + skip_on_cran() skip_if(!old_did_available, "did v2.1.2 not available from CRAN") sp <- did::reset.sim() data <- did::build_sim_dataset(sp, panel = FALSE) @@ -466,6 +470,7 @@ test_that("inference with repeated cross sections", { test_that("inference with repeated cross sections and clustering", { + skip_on_cran() 
skip_if(!old_did_available, "did v2.1.2 not available from CRAN") sp <- did::reset.sim() data <- did::build_sim_dataset(sp, panel = FALSE) @@ -596,6 +601,7 @@ test_that("inference with repeated cross sections and clustering", { test_that("inference with unbalanced panel", { + skip_on_cran() skip_if(!old_did_available, "did v2.1.2 not available from CRAN") sp <- did::reset.sim() data <- did::build_sim_dataset(sp) @@ -730,6 +736,7 @@ test_that("inference with unbalanced panel", { }) test_that("inference with unbalanced panel and clustering", { + skip_on_cran() skip_if(!old_did_available, "did v2.1.2 not available from CRAN") sp <- did::reset.sim() data <- did::build_sim_dataset(sp)