diff --git a/.github/workflows/bump-version.yaml b/.github/workflows/bump-version.yaml new file mode 100644 index 00000000..27bf182f --- /dev/null +++ b/.github/workflows/bump-version.yaml @@ -0,0 +1,66 @@ +name: Bump dev version on PR merge + +on: + pull_request: + types: [closed] + branches: [master, main] + +concurrency: + group: bump-version + cancel-in-progress: false + +jobs: + bump-version: + if: github.event.pull_request.merged == true + runs-on: ubuntu-latest + + permissions: + contents: write + + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.base.ref }} + fetch-depth: 0 + + - name: Bump dev version in DESCRIPTION + env: + PR_NUMBER: ${{ github.event.pull_request.number }} + run: | + # Idempotency: check if this PR already triggered a bump (exact match) + if git log --oneline ${{ github.event.pull_request.base.ref }} | grep -qF "(PR #${PR_NUMBER})"; then + echo "Version already bumped for PR #${PR_NUMBER}, skipping." + exit 0 + fi + + # Extract current version + current=$(grep '^Version:' DESCRIPTION | sed 's/Version: //') + echo "Current version: $current" + + # Split into parts + IFS='.' read -ra parts <<< "$current" + major="${parts[0]}" + minor="${parts[1]}" + patch="${parts[2]}" + dev="${parts[3]:-0}" + + # Increment dev version + new_dev=$((dev + 1)) + new_version="${major}.${minor}.${patch}.${new_dev}" + echo "New version: $new_version" + + # Update DESCRIPTION + sed -i "s/^Version: .*/Version: ${new_version}/" DESCRIPTION + + # Configure git + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + + # Commit and push with retry for race conditions + git add DESCRIPTION + git diff --cached --quiet && { echo "No changes to commit"; exit 0; } + git commit -m "Bump version to ${new_version} (PR #${PR_NUMBER})" + + # Pull latest before pushing to handle concurrent merges + git pull --rebase origin ${{ github.event.pull_request.base.ref }} || true + git push diff --git a/.github/workflows/test-coverage.yaml b/.github/workflows/test-coverage.yaml index 3c0da1c9..c807a4ff 100644 --- a/.github/workflows/test-coverage.yaml +++ b/.github/workflows/test-coverage.yaml @@ -15,16 +15,22 @@ jobs: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - - uses: r-lib/actions/setup-r@v1 + - uses: r-lib/actions/setup-r@v2 with: use-public-rspm: true - - uses: r-lib/actions/setup-r-dependencies@v1 + - uses: r-lib/actions/setup-r-dependencies@v2 with: - extra-packages: covr + extra-packages: any::covr + needs: coverage - name: Test coverage - run: covr::codecov() + run: | + covr::codecov( + quiet = FALSE, + clean = FALSE, + install_path = file.path(normalizePath(Sys.getenv("RUNNER_TEMP"), winslash = "/"), "package") + ) shell: Rscript {0} diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index d049181c..8987800d 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -11,45 +11,24 @@ jobs: test: runs-on: ubuntu-latest - steps: - - name: Checkout repository - uses: actions/checkout@v3 + env: + GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} - - name: Set up R - uses: r-lib/actions/setup-r@v2 + steps: + - uses: actions/checkout@v4 - - name: Install system dependencies - run: | - sudo apt-get update - sudo apt-get install -y libharfbuzz-dev libfribidi-dev libfreetype6-dev libcurl4-openssl-dev libpng-dev libtiff-dev libjpeg-dev libfontconfig1-dev + - uses: r-lib/actions/setup-r@v2 + with: + use-public-rspm: true - - name: Cache R packages - uses: actions/cache@v3 + - uses: r-lib/actions/setup-r-dependencies@v2 with: - path: ${{ env.R_LIBS_USER }} - key: ${{ runner.os }}-r-${{ hashFiles('**/DESCRIPTION') }}-${{ matrix.config.r }} - restore-keys: | - ${{ runner.os }}-r-${{ matrix.config.r }}- - ${{ runner.os }}-r- + extra-packages: any::devtools, any::testthat + needs: check - - name: Install R package dependencies - run: | - if [ "${{ runner.os }}" == "Windows" ]; then - cmd.exe /c "R -e \"install.packages('devtools'); devtools::install_deps(dependencies = TRUE)\"" - cmd.exe /c "R -e \"install.packages('tidyverse'); devtools::install_deps(dependencies = TRUE)\"" - cmd.exe /c "R -e \"install.packages('remotes')\"" - cmd.exe /c "R -e \"install.packages('callr')\"" - else - R -e "install.packages('devtools'); devtools::install_deps(dependencies = TRUE)" - R -e "install.packages('tidyverse'); devtools::install_deps(dependencies = TRUE)" - R -e "install.packages('remotes')" - R -e "install.packages('callr')" - fi - - name: Install did package run: R CMD INSTALL . - + - name: Run tests - run: | + run: | R -e "devtools::test()" - R -e "testthat::test_dir('tests/testthat', reporter = 'check', package = 'did')" diff --git a/.gitignore b/.gitignore index c392361a..da888abb 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ desktop.ini did.Rproj ..Rcheck/ .claude/ +CLAUDE.md .vscode/ .revdep_manual/ vignettes/*_cache/ \ No newline at end of file diff --git a/DESCRIPTION b/DESCRIPTION index e239a2ea..7d2b5a33 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: did Title: Treatment Effects with Multiple Periods and Groups -Version: 2.3.1.903 +Version: 2.3.1.904 Authors@R: c(person("Brantly", "Callaway", email = "brantly.callaway@uga.edu", role = c("aut", "cre")), person("Pedro H. C.", "Sant'Anna", email="pedro.santanna@emory.edu", role = c("aut"))) URL: https://bcallaway11.github.io/did/, https://github.com/bcallaway11/did/ Description: The standard Difference-in-Differences (DID) setup involves two periods and two groups -- a treated group and untreated group. Many applications of DID methods involve more than two periods and have individuals that are treated at different points in time. This package contains tools for computing average treatment effect parameters in Difference in Differences setups with more than two periods and with variation in treatment timing using the methods developed in Callaway and Sant'Anna (2021) . The main parameters are group-time average treatment effects which are the average treatment effect for a particular group at a a particular time. These can be aggregated into a fewer number of treatment effect parameters, and the package deals with the cases where there is selective treatment timing, dynamic treatment effects, calendar time effects, or combinations of these. There are also functions for testing the Difference in Differences assumption, and plotting group-time average treatment effects. diff --git a/NAMESPACE b/NAMESPACE index c5dab2b9..969f75f5 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -38,12 +38,15 @@ export(splot) export(test.mboot) export(tidy) export(trimmer) -import(BMisc) import(data.table) import(fastglm) import(ggplot2) -import(stats) -import(utils) +importFrom(BMisc,TorF) +importFrom(BMisc,getListElement) +importFrom(BMisc,makeBalancedPanel) +importFrom(BMisc,multiplier_bootstrap) +importFrom(BMisc,rhs.vars) +importFrom(BMisc,toformula) importFrom(DRDID,drdid_panel) importFrom(DRDID,drdid_rc) importFrom(DRDID,reg_did_panel) @@ -55,5 +58,23 @@ importFrom(generics,glance) importFrom(generics,tidy) importFrom(methods,as) importFrom(methods,is) +importFrom(stats,aggregate) +importFrom(stats,binomial) +importFrom(stats,complete.cases) +importFrom(stats,cov) +importFrom(stats,ecdf) +importFrom(stats,glm) +importFrom(stats,model.frame) +importFrom(stats,model.matrix) +importFrom(stats,na.pass) importFrom(stats,nobs) +importFrom(stats,pchisq) +importFrom(stats,pnorm) +importFrom(stats,predict) +importFrom(stats,qnorm) +importFrom(stats,quantile) +importFrom(stats,rnorm) +importFrom(stats,setNames) +importFrom(stats,var) importFrom(tidyr,gather) +importFrom(utils,globalVariables) diff --git a/NEWS.md b/NEWS.md index 6760543c..aa6537c7 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,53 @@ +# did 2.3.1.904 + + * Fixed bug where `faster_mode = TRUE` and `faster_mode = FALSE` produced different ATT estimates when sampling weights (`weightsname`) vary across time. The fast path was always using first-period weights; it now correctly uses the same period's weights as the slow path + + * New `fix_weights` argument in `att_gt()` gives users explicit control over how time-varying sampling weights are resolved in each 2x2 DiD comparison. Options: `NULL` (default, preserves existing behavior), `"varying"` (per-observation weights using RC estimators), `"base_period"` (fix at g-1 for all cells), `"first_period"` (fix at first period). See `?att_gt` for details + + * Runtime message when time-varying weights are detected in balanced panel data, directing users to the `fix_weights` argument + + * Reduced namespace pollution: replaced blanket `import(stats)`, `import(utils)`, and `import(BMisc)` with selective `importFrom()` calls. The `did` package no longer re-exports `stats::filter` or `stats::lag`, which previously masked `dplyr::filter` and `dplyr::lag` when both packages were loaded + + * Fixed `aggte()` crash (`"Error in get(gname): invalid first argument"`) when the user's group column is literally named `gname` and `dreamerr` >= 1.5.0 is installed. The issue was `dreamerr` intercepting `data.table`'s `get()` inside `[.data.table`; replaced with `set()` which is immune to this + + * Expanded `weightsname` documentation explaining how time-varying weights are handled differently for balanced panels vs. repeated cross sections and unbalanced panels + + * Added `nobs()` S3 methods for `MP` and `AGGTEobj` objects, returning the number of unique cross-sectional units as an integer + + * Added `statistic` (t-statistic) and `p.value` (pointwise, two-sided) columns to `tidy()` output for both `MP` and `AGGTEobj` objects, following `broom` conventions + + * Fixed `glance.MP()` returning `NULL` for `ngroup` and `ntime` when `faster_mode = TRUE` (DIDparams2 uses different field names than DIDparams) + + * Fixed influence function aggregation bug in `fix_weights = "varying"` on balanced panel: the slow path now correctly uses `rowsum()` by unit ID instead of assuming stacked ordering + + * Fixed `fix_weights = "base_period"` / `"first_period"` for unbalanced panels: corrected influence function length mismatch after weight-based unit dropping, and fast path now properly excludes units missing from the target period + + * Added validation: `fix_weights = "base_period"` and `"first_period"` are blocked for repeated cross sections (`panel = FALSE`) with a clear error message + + * Added validation: `fix_weights = "varying"` is blocked when using a custom `est_method` function, since the varying path uses RC estimators internally with a different function signature + + * Completed namespace cleanup: added missing `@importFrom` for `stats::ecdf`, `stats::glm`, `stats::predict`; registered `.w`, `D`, `N`, `post`, `weights` as `globalVariables` for data.table column references. `R CMD check` now passes with 0 code-related NOTEs + + * Replaced fragile `exists("use_rc_for_weights")` with direct `dp2$fix_weights` check in `compute.att_gt2()` + + * Fixed `data.table` `.checkTypos` crash in `get_wide_data()` when user column names match local variable names (e.g., column named `tname`) + + * Replaced `get()` with `c()` in time-varying weight detection grouping to avoid potential `dreamerr` interception + + * Inference tests (`test-inference.R`): switched to HTTPS mirror, added `requireNamespace` verification, wrapped install in `skip_on_cran` guard, added proper temp directory cleanup + + * Added GitHub Action to auto-bump dev version in DESCRIPTION on PR merge + + * Substantially expanded test suite covering `glance()`, `ggdid()`, error handling, edge cases, all aggregation types, and systematic `faster_mode` consistency across 36 parameter combinations. Test suite now runs with 0 warnings (previously 66+) + +# did 2.3.1.903 + + * Added `nobs()` S3 methods and `statistic`/`p.value` columns to `tidy()` output (superseded by 2.3.1.904 entry above) + +# did 2.3.1.902 + + * Bug fixes, diagnostic improvements, and JEL replication tests + # did 2.3.1.901 * `att_gt()` now accepts `...` (dots) for passing additional arguments to custom `est_method` functions diff --git a/R/DIDparams.R b/R/DIDparams.R index 7aafd3a5..7953d8a5 100644 --- a/R/DIDparams.R +++ b/R/DIDparams.R @@ -25,6 +25,7 @@ DIDparams <- function(yname, control_group, anticipation=0, weightsname=NULL, + fix_weights=NULL, alp=0.05, bstrap=TRUE, biters=1000, @@ -54,6 +55,7 @@ DIDparams <- function(yname, control_group=control_group, anticipation=anticipation, weightsname=weightsname, + fix_weights=fix_weights, alp=alp, bstrap=bstrap, biters=biters, diff --git a/R/DIDparams2.R b/R/DIDparams2.R index 108b50bf..8c88a2c0 100644 --- a/R/DIDparams2.R +++ b/R/DIDparams2.R @@ -49,6 +49,8 @@ DIDparams2 <- function(did_tensors, args, call=NULL) { covariates_matrix <- did_tensors$covariates_matrix cluster_vector <- did_tensors$cluster weights_vector <- did_tensors$weights + weights_tensor <- did_tensors$weights_tensor + fix_weights <- args$fix_weights out <- list(yname=yname, @@ -89,6 +91,8 @@ DIDparams2 <- function(did_tensors, args, call=NULL) { covariates_matrix = covariates_matrix, cluster_vector=cluster_vector, weights_vector=weights_vector, + weights_tensor=weights_tensor, + fix_weights=fix_weights, call=call) class(out) <- "DIDparams" return(out) diff --git a/R/att_gt.R b/R/att_gt.R index b5663b41..1f3e633f 100644 --- a/R/att_gt.R +++ b/R/att_gt.R @@ -16,7 +16,54 @@ #' It defines which "group" a unit belongs to. It should be 0 for units #' in the untreated group. #' @param weightsname The name of the column containing the sampling weights. -#' If not set, all observations have same weight. +#' If not set, all observations have same weight. When weights are +#' time-invariant (constant within each unit across periods), all +#' \code{fix_weights} options produce identical results and no special +#' handling is needed. +#' +#' When weights vary across time (e.g., time-varying population sizes), +#' the default behavior differs by panel type: +#' \describe{ +#' \item{Balanced panel}{Each 2x2 DiD comparison uses the weight from the +#' earlier of the two time periods involved. For post-treatment cells, +#' this is the base period (g-1). For pre-treatment cells with +#' \code{base_period="varying"}, this is the pre-treatment period itself. +#' The panel DRDID estimators are used.} +#' \item{Repeated cross sections and unbalanced panels}{Both periods' +#' per-observation weights are passed directly to the RC DRDID estimators, +#' so each observation carries its own period-specific weight.} +#' } +#' Use the \code{fix_weights} argument to override the default behavior. +#' @param fix_weights Controls how time-varying sampling weights are resolved. +#' Only relevant when weights vary across time; with time-invariant weights, +#' all options produce identical results. Options: +#' \describe{ +#' \item{\code{NULL} (default)}{For balanced panel: uses the weight from +#' the earlier of the two time periods in each 2x2 comparison. For +#' post-treatment cells, this is the base period (g-1). For +#' pre-treatment cells, this depends on the \code{base_period} setting. +#' For RC/unbalanced panel: uses per-observation weights from both +#' periods.} +#' \item{\code{"varying"}}{Uses per-observation, period-specific weights +#' for all panel types. For balanced panel data, this switches to the +#' repeated cross-section DRDID estimators so that pre-period and +#' post-period observations each carry their own weight. Covariates +#' are held fixed at their pre-period values (same as the default +#' panel estimator). This is the most flexible option for weights but +#' sacrifices the efficiency of the panel estimator. For RC/unbalanced +#' panel, this is identical to the default. Not supported with custom +#' \code{est_method} functions.} +#' \item{\code{"base_period"}}{Fixes weights at the base period (g-1) for +#' all (g,t) cells within a group, for both pre-treatment and +#' post-treatment comparisons. Ensures all cells within a group use the +#' same weights. For unbalanced panels, units not observed in the base +#' period are dropped with a warning. Not supported for repeated cross +#' sections (\code{panel = FALSE}).} +#' \item{\code{"first_period"}}{Fixes weights at the first time period in +#' the dataset for all (g,t) cells. For unbalanced panels, units not +#' observed in the first period are dropped with a warning. Not supported +#' for repeated cross sections (\code{panel = FALSE}).} +#' } #' @param alp the significance level, default is 0.05 #' @param bstrap Boolean for whether or not to compute standard errors using #' the multiplier bootstrap. If standard errors are clustered, then one @@ -44,16 +91,27 @@ #' include "ipw" for inverse probability weighting and "reg" for #' first step regression estimators. The user can also pass their #' own function for estimating group time average treatment -#' effects. This should be a function -#' `f(Y1,Y0,treat,covariates)` where `Y1` is an -#' `n` x `1` vector of outcomes in the post-treatment -#' outcomes, `Y0` is an `n` x `1` vector of -#' pre-treatment outcomes, `treat` is a vector indicating -#' whether or not an individual participates in the treatment, -#' and `covariates` is an `n` x `k` matrix of -#' covariates. The function should return a list that includes -#' `ATT` (an estimated average treatment effect), and -#' `inf.func` (an `n` x `1` influence function). +#' effects. The required signature depends on the data structure: +#' +#' **Panel data** (`panel=TRUE`): `f(y1, y0, D, covariates, +#' i.weights, inffunc, ...)` where `y1` is an `n x 1` vector of +#' post-treatment outcomes, `y0` is an `n x 1` vector of +#' pre-treatment outcomes, `D` is a binary vector indicating +#' treatment group membership, `covariates` is an `n x k` matrix, +#' `i.weights` is a vector of sampling weights, and `inffunc` is a +#' logical requesting influence-function computation. +#' +#' **Repeated cross sections / unbalanced panel** (`panel=FALSE`): +#' `f(y, post, D, covariates, i.weights, inffunc, ...)` where `y` is +#' the outcome vector (length `n`), `post` is a binary indicator for +#' the post-treatment period, `D` is a binary treatment indicator, +#' `covariates` is an `n x k` matrix, `i.weights` is a vector of +#' sampling weights, and `inffunc` is a logical. +#' +#' In both cases the function should return a list that includes +#' `ATT` (the estimated group-time average treatment effect) and +#' `att.inf.func` (an `n x 1` influence function — one entry per +#' observation passed into the estimator). #' The function can return other things as well, but these are #' the only two that are required. `est_method` is only used #' if covariates are included. @@ -195,6 +253,7 @@ att_gt <- function(yname, control_group = c("nevertreated", "notyettreated"), anticipation = 0, weightsname = NULL, + fix_weights = NULL, alp = 0.05, bstrap = TRUE, cband = TRUE, @@ -217,6 +276,25 @@ att_gt <- function(yname, "\". Extra arguments are only passed to custom est_method functions.") } + # Validate fix_weights + if (!is.null(fix_weights)) { + if (!is.character(fix_weights) || length(fix_weights) != 1 || + !(fix_weights %in% c("varying", "base_period", "first_period"))) { + stop("fix_weights must be NULL or one of \"varying\", \"base_period\", or \"first_period\".") + } + if (!panel && fix_weights %in% c("base_period", "first_period")) { + stop("fix_weights = \"", fix_weights, "\" is not supported for repeated cross sections ", + "(panel = FALSE) because units are not tracked across periods. ", + "Use fix_weights = \"varying\" or NULL instead.") + } + if (fix_weights == "varying" && panel && inherits(est_method, "function")) { + stop("fix_weights = \"varying\" is not currently supported with custom est_method functions ", + "on panel data. The \"varying\" option uses repeated cross-section estimators internally, ", + "which require a different function signature (y, post, D) than the panel signature ", + "(y1, y0, D). Use fix_weights = NULL, \"base_period\", or \"first_period\" instead.") + } + } + # Validate est_method if (!inherits(est_method, "function")) { if (!is.character(est_method) || length(est_method) != 1) { @@ -249,6 +327,7 @@ att_gt <- function(yname, control_group = control_group, anticipation = anticipation, weightsname = weightsname, + fix_weights = fix_weights, alp = alp, bstrap = bstrap, cband = cband, @@ -284,6 +363,7 @@ att_gt <- function(yname, control_group = control_group, anticipation = anticipation, weightsname = weightsname, + fix_weights = fix_weights, alp = alp, bstrap = bstrap, cband = cband, diff --git a/R/compute.aggte.R b/R/compute.aggte.R index a93ac752..ec02770e 100644 --- a/R/compute.aggte.R +++ b/R/compute.aggte.R @@ -57,7 +57,7 @@ compute.aggte <- function(MP, } if (isTRUE(dp$faster_mode)) { dt <- dp$data - dt[get(gname) == Inf, (gname) := 0] # going back to the old way + set(dt, i = which(dt[[gname]] == Inf), j = gname, value = 0) # going back to the old way data <- as.data.frame(dt) rm(dt) tlist <- dp$time_periods diff --git a/R/compute.att_gt.R b/R/compute.att_gt.R index bef72c36..251238d9 100644 --- a/R/compute.att_gt.R +++ b/R/compute.att_gt.R @@ -65,10 +65,20 @@ compute.att_gt <- function(dp) { # never treated option nevertreated <- (control_group[1] == "nevertreated") + fix_weights <- dp$fix_weights + # Pre-extract columns to avoid repeated get() inside data.table (which is slow) g_col <- data[[gname]] t_col <- data[[tname]] + # Build weight lookup by period for fix_weights options (balanced panel only) + if (!is.null(fix_weights) && panel) { + weights_by_period <- list() + for (tp in seq_along(tlist)) { + weights_by_period[[tp]] <- data[t_col == tlist[tp], .w] + } + } + if (nevertreated) { set(data, j = ".C", value = as.integer(g_col == 0)) } @@ -193,95 +203,155 @@ compute.att_gt <- function(dp) { # base period, then the "base period" is actually the later period Ypre <- if (tlist[(t + tfac)] > tlist[pret]) disdat$.y0 else disdat$.y1 Ypost <- if (tlist[(t + tfac)] > tlist[pret]) disdat$.y1 else disdat$.y0 - w <- disdat$.w + + # Select weights based on fix_weights + if (is.null(fix_weights)) { + # Default: .w from get_wide_data (earlier period) + w <- disdat$.w + } else if (fix_weights == "base_period") { + w <- weights_by_period[[pret_g]][disidx] + } else if (fix_weights == "first_period") { + w <- weights_by_period[[1L]][disidx] + } else if (fix_weights == "varying") { + w <- disdat$.w # will be overridden below when switching to RC estimator + } else { + w <- disdat$.w + } # matrix of covariates covariates <- model.matrix(xformla, data = disdat) - #----------------------------------------------------------------------------- - # more checks for enough observations in each group - - # if using custom estimation method, skip this part - custom_est_method <- is.function(est_method) - - if (!custom_est_method) { - pscore_problems_likely <- FALSE - reg_problems_likely <- FALSE - - # checks for pscore based methods - if (est_method %in% c("dr", "ipw")) { - preliminary_logit <- fastglm::fastglm(covariates, G, family = binomial()) - preliminary_pscores <- preliminary_logit$fitted.values - if (max(preliminary_pscores) >= 0.999) { - pscore_problems_likely <- TRUE - warning(paste0("overlap condition violated for ", glist[g], " in time period ", tlist[t + tfac])) - } - } - - # check if can run regression using control units - if (est_method %in% c("dr", "reg")) { - control_covs <- covariates[G == 0, , drop = FALSE] - # if (determinant(t(control_covs)%*%control_covs, logarithm=FALSE)$modulus < .Machine$double.eps) { - if (rcond(t(control_covs) %*% control_covs) < .Machine$double.eps) { - reg_problems_likely <- TRUE - warning(paste0("Not enough control units for group ", glist[g], " in time period ", tlist[t + tfac], " to run specified regression")) - } - } - - if (reg_problems_likely | pscore_problems_likely) { - attgt.list[[counter]] <- list(att = NA, group = glist[g], year = tlist[(t + tfac)], post = post.treat) - inffunc_updates[[update_counter]] <- list( - indices = seq_len(n), - values = rep(NA_real_, n) - ) - - # Update the counters - update_counter <- update_counter + 1 - counter <- counter + 1 - next - } - } - #----------------------------------------------------------------------------- # code for actually computing att(g,t) #----------------------------------------------------------------------------- attgt <- tryCatch({ - if (inherits(est_method, "function")) { - # user-specified function - res <- do.call(est_method, c(list( - y1 = Ypost, y0 = Ypre, - D = G, - covariates = covariates, - i.weights = w, - inffunc = TRUE - ), extra_args)) - } else if (est_method == "ipw") { - # inverse-probability weights - res <- DRDID::std_ipw_did_panel(Ypost, Ypre, G, - covariates = covariates, - i.weights = w, - boot = FALSE, inffunc = TRUE - ) - } else if (est_method == "reg") { - # regression - res <- DRDID::reg_did_panel(Ypost, Ypre, G, - covariates = covariates, - i.weights = w, - boot = FALSE, inffunc = TRUE - ) + if (!is.null(fix_weights) && fix_weights == "varying") { + # fix_weights = "varying": use RC estimators with per-period weights + # Go back to long-format data for this (g,t) cell + disdat_long <- data[time_mask] + disdat_long_idx <- disdat_long$.G == 1 | disdat_long$.C == 1 + disdat_long <- droplevels(disdat_long[disdat_long_idx]) + Y_rc <- disdat_long[[yname]] + G_rc <- disdat_long$.G + post_rc <- as.numeric(disdat_long[[tname]] == tlist[t + tfac]) + w_rc <- disdat_long$.w + # Use earlier-period covariates for all observations — fix_weights + # only changes weights, not the covariate conditioning set. + # Use min(pret, t+tfac) to match the panel estimator's convention: + # with base_period="universal", pret can be later than t for placebo cells. + earlier_period <- tlist[min(pret, t + tfac)] + early_mask <- disdat_long[[tname]] == earlier_period + disdat_early <- disdat_long[early_mask] + cov_early <- model.matrix(xformla, data = disdat_early) + # Map each row in disdat_long to its unit's earlier-period covariates + early_ids <- disdat_early[[idname]] + all_ids <- disdat_long[[idname]] + id_map <- match(all_ids, early_ids) + covariates_rc <- cov_early[id_map, , drop = FALSE] + + # Run overlap/rank checks on RC data (not wide panel data) + if (!is.function(est_method)) { + if (est_method %in% c("dr", "ipw")) { + preliminary_logit <- fastglm::fastglm(covariates_rc, G_rc, family = binomial()) + if (max(preliminary_logit$fitted.values) >= 0.999) { + warning(paste0("overlap condition violated for ", glist[g], " in time period ", tlist[t + tfac])) + stop("overlap") + } + } + if (est_method %in% c("dr", "reg")) { + control_covs_rc <- covariates_rc[G_rc == 0, , drop = FALSE] + if (rcond(t(control_covs_rc) %*% control_covs_rc) < .Machine$double.eps) { + warning(paste0("Not enough control units for group ", glist[g], " in time period ", tlist[t + tfac], " to run specified regression")) + stop("singular") + } + } + } + + if (inherits(est_method, "function")) { + res <- do.call(est_method, c(list( + y = Y_rc, post = post_rc, + D = G_rc, covariates = covariates_rc, + i.weights = w_rc, inffunc = TRUE + ), extra_args)) + } else if (est_method == "ipw") { + res <- DRDID::std_ipw_did_rc(Y_rc, post_rc, G_rc, + covariates = covariates_rc, + i.weights = w_rc, boot = FALSE, inffunc = TRUE) + } else if (est_method == "reg") { + res <- DRDID::reg_did_rc(Y_rc, post_rc, G_rc, + covariates = covariates_rc, + i.weights = w_rc, boot = FALSE, inffunc = TRUE) + } else { + res <- DRDID::drdid_rc(Y_rc, post_rc, G_rc, + covariates = covariates_rc, + i.weights = w_rc, boot = FALSE, inffunc = TRUE) + } } else { - # doubly robust, this is default - res <- DRDID::drdid_panel(Ypost, Ypre, G, - covariates = covariates, - i.weights = w, - boot = FALSE, inffunc = TRUE - ) + # Panel path: run overlap/rank checks on panel data + if (!is.function(est_method)) { + if (est_method %in% c("dr", "ipw")) { + preliminary_logit <- fastglm::fastglm(covariates, G, family = binomial()) + if (max(preliminary_logit$fitted.values) >= 0.999) { + warning(paste0("overlap condition violated for ", glist[g], " in time period ", tlist[t + tfac])) + stop("overlap") + } + } + if (est_method %in% c("dr", "reg")) { + control_covs <- covariates[G == 0, , drop = FALSE] + if (rcond(t(control_covs) %*% control_covs) < .Machine$double.eps) { + warning(paste0("Not enough control units for group ", glist[g], " in time period ", tlist[t + tfac], " to run specified regression")) + stop("singular") + } + } + } + + if (inherits(est_method, "function")) { + # user-specified function + res <- do.call(est_method, c(list( + y1 = Ypost, y0 = Ypre, + D = G, + covariates = covariates, + i.weights = w, + inffunc = TRUE + ), extra_args)) + } else if (est_method == "ipw") { + # inverse-probability weights + res <- DRDID::std_ipw_did_panel(Ypost, Ypre, G, + covariates = covariates, + i.weights = w, + boot = FALSE, inffunc = TRUE + ) + } else if (est_method == "reg") { + # regression + res <- DRDID::reg_did_panel(Ypost, Ypre, G, + covariates = covariates, + i.weights = w, + boot = FALSE, inffunc = TRUE + ) + } else { + # doubly robust, this is default + res <- DRDID::drdid_panel(Ypost, Ypre, G, + covariates = covariates, + i.weights = w, + boot = FALSE, inffunc = TRUE + ) + } } # adjust influence function to account for only using # subgroup to estimate att(g,t) - res$att.inf.func <- (n / n1) * res$att.inf.func + if (!is.null(fix_weights) && fix_weights == "varying") { + # RC influence function has one entry per obs in disdat_long + # (2 per unit: pre + post). Aggregate to unit level by ID, + # independent of row ordering. + res$att.inf.func <- as.numeric(rowsum(res$att.inf.func, + disdat_long[[idname]], + reorder = FALSE)) + res$att.inf.func <- (n / n1) * res$att.inf.func + } else { + res$att.inf.func <- (n / n1) * res$att.inf.func + } res }, error = function(e) { warning("Error computing internal 2x2 DiD for (g, t) = (", glist[g], ", ", tlist[t + tfac], "): ", e$message, ". The ATT for this cell will be set to NA.") @@ -325,7 +395,38 @@ compute.att_gt <- function(dp) { post <- 1 * (disdat[[tname]] == tlist[t + tfac]) # num obs. for computing ATT(g,t), have to be careful here n1 <- sum(G + C) - w <- disdat$.w + + # Handle fix_weights for RC/unbalanced panel + if (!is.null(fix_weights) && fix_weights %in% c("base_period", "first_period")) { + # Determine which period's weight to use + if (fix_weights == "base_period") { + target_period <- tlist[pret_g] + } else { + target_period <- tlist[1] + } + # Build lookup: weight from target period per unit + target_rows <- data[t_col == target_period, ] + target_w <- stats::setNames(target_rows$.w, target_rows$.rowid) + # Look up weight for each observation's unit + w <- as.numeric(target_w[as.character(disdat$.rowid)]) + # Drop units not observed in the target period + missing_w <- is.na(w) + if (any(missing_w)) { + n_dropped <- length(unique(disdat$.rowid[missing_w])) + warning(paste0("Dropped ", n_dropped, " units not observed in ", + fix_weights, " (period ", target_period, ") ", + "for group ", glist[g], " in time period ", tlist[t + tfac])) + disdat <- disdat[!missing_w, ] + G <- disdat$.G + C <- disdat$.C + Y <- disdat[[yname]] + post <- 1 * (disdat[[tname]] == tlist[t + tfac]) + n1 <- sum(G + C) + w <- w[!missing_w] + } + } else { + w <- disdat$.w + } #----------------------------------------------------------------------------- # checks to make sure that we have enough observations @@ -500,7 +601,9 @@ compute.att_gt <- function(dp) { ) } else { # aggregate inf functions by id (order by id) - aggte_inffunc <- suppressWarnings(stats::aggregate(attgt$att.inf.func, list(rightids), sum)) + # Use current disdat$.rowid (may differ from rightids if fix_weights dropped obs) + current_ids <- disdat$.rowid[disdat$.G == 1 | disdat$.C == 1] + aggte_inffunc <- suppressWarnings(stats::aggregate(attgt$att.inf.func, list(current_ids), sum)) idx <- which(unique(data$.rowid) %in% aggte_inffunc[, 1]) inffunc_updates[[update_counter]] <- list( indices = idx, diff --git a/R/compute.att_gt2.R b/R/compute.att_gt2.R index bfdf3eae..32a2cba0 100644 --- a/R/compute.att_gt2.R +++ b/R/compute.att_gt2.R @@ -88,12 +88,12 @@ get_did_cohort_index <- function(group, time, tfac, pret, dp2){ #' #' @return A list containing the estimated ATT and the influence function vector. #' @noRd -run_DRDID <- function(cohort_data, covariates, dp2, g_val = NULL, t_val = NULL){ +run_DRDID <- function(cohort_data, covariates, dp2, g_val = NULL, t_val = NULL, force_rc = FALSE){ extra_args <- if (is.null(dp2$extra_args)) list() else dp2$extra_args gt_label <- if (!is.null(g_val) && !is.null(t_val)) paste0(" for group ", g_val, " in time period ", t_val) else "" - if(dp2$panel){ + if(dp2$panel && !force_rc){ # -------------------------------------- # Panel Data # -------------------------------------- @@ -376,25 +376,100 @@ run_att_gt_estimation <- function(g, t, dp2){ if(dp2$panel){ - cohort_data <- data.table(did_cohort_index, dp2$outcomes_tensor[[t+tfac]], dp2$outcomes_tensor[[pret]], dp2$weights_vector) - names(cohort_data) <- c("D", "y1", "y0", "i.weights") - covariates <- dp2$covariates_tensor[[base::min(pret, t)]] + # Determine which weight period to use based on fix_weights + use_rc_for_weights <- (!is.null(dp2$fix_weights) && dp2$fix_weights == "varying") + + if (use_rc_for_weights) { + # fix_weights = "varying": stack into RC format with per-period weights + n_units <- length(did_cohort_index) + cohort_data <- data.table( + D = rep(did_cohort_index, 2), + y = c(dp2$outcomes_tensor[[pret]], dp2$outcomes_tensor[[t+tfac]]), + post = rep(c(0L, 1L), each = n_units), + i.weights = c(dp2$weights_tensor[[pret]], dp2$weights_tensor[[t+tfac]]) + ) + # Use earlier-period covariates for both halves — fix_weights only + # changes weights, not the covariate conditioning set. + # Use min(pret, t) to match the panel estimator's convention: + # with base_period="universal", pret can be later than t for placebo cells. + cov_early <- dp2$covariates_tensor[[base::min(pret, t)]] + if (is.matrix(cov_early)) { + covariates <- rbind(cov_early, cov_early) + } else { + covariates <- c(cov_early, cov_early) + } + } else { + # Default or fixed weight options: use panel estimator with single weight vector + if (is.null(dp2$fix_weights)) { + # Default: weight from earlier of the two periods + w_idx <- base::min(pret, t) + } else if (dp2$fix_weights == "base_period") { + w_idx <- dp2$.pret_by_group[g] + } else if (dp2$fix_weights == "first_period") { + w_idx <- 1L + } + cohort_data <- data.table(did_cohort_index, dp2$outcomes_tensor[[t+tfac]], + dp2$outcomes_tensor[[pret]], dp2$weights_tensor[[w_idx]]) + names(cohort_data) <- c("D", "y1", "y0", "i.weights") + covariates <- dp2$covariates_tensor[[base::min(pret, t)]] + } } else { log_vec <- dp2$time_invariant_data[[ dp2$tname ]] == dp2$time_periods[t+tfac] # convert TRUE/FALSE to 1/0 in place (fastest) set(dp2$time_invariant_data, j = "post", value = as.integer(log_vec)) - cohort_data <- data.table(did_cohort_index, dp2$time_invariant_data[[dp2$yname]], dp2$time_invariant_data$post, dp2$time_invariant_data$weights, dp2$time_invariant_data$.rowid) + + # Handle fix_weights for RC/unbalanced panel + if (!is.null(dp2$fix_weights) && dp2$fix_weights %in% c("base_period", "first_period")) { + if (dp2$fix_weights == "base_period") { + target_period <- dp2$time_periods[dp2$.pret_by_group[g]] + } else { + target_period <- dp2$time_periods[1] + } + # Build weight lookup from target period + tid <- dp2$time_invariant_data + target_mask <- tid[[dp2$tname]] == target_period + target_ids <- tid[[dp2$idname]][target_mask] + target_ws <- tid[["weights"]][target_mask] + target_w_lookup <- stats::setNames(target_ws, as.character(target_ids)) + # Look up weight for each observation + obs_ids <- as.character(tid[[dp2$idname]]) + fixed_w <- as.numeric(target_w_lookup[obs_ids]) + # Exclude units not observed in target period by setting D to NA + # (run_DRDID filters on !is.na(D)) + na_w <- is.na(fixed_w) + if (any(na_w & !is.na(did_cohort_index))) { + did_cohort_index[na_w] <- NA_integer_ + warning(paste0("Some units not observed in ", dp2$fix_weights, + " (period ", target_period, ") for group ", + dp2$treated_groups[g], " in time period ", + dp2$time_periods[t+tfac], ". These units are excluded.")) + } + cohort_data <- data.table(did_cohort_index, tid[[dp2$yname]], tid$post, fixed_w, tid$.rowid) + } else { + cohort_data <- data.table(did_cohort_index, dp2$time_invariant_data[[dp2$yname]], dp2$time_invariant_data$post, dp2$time_invariant_data$weights, dp2$time_invariant_data$.rowid) + } names(cohort_data) <- c("D", "y", "post", "i.weights", ".rowid") covariates <- dp2$covariates_matrix } # run estimation - did_result <- tryCatch(run_DRDID(cohort_data, covariates, dp2, g_val = dp2$treated_groups[g], t_val = dp2$time_periods[t+tfac]), + force_rc <- !is.null(dp2$fix_weights) && dp2$fix_weights == "varying" && dp2$panel + did_result <- tryCatch(run_DRDID(cohort_data, covariates, dp2, g_val = dp2$treated_groups[g], t_val = dp2$time_periods[t+tfac], force_rc = force_rc), error = function(e) { warning("Error computing internal 2x2 DiD for (g, t) = (", dp2$treated_groups[g], ", ", dp2$time_periods[t+tfac], "): ", e$message, ". The ATT for this cell will be set to NA.") return(NULL) }) + + # When force_rc on balanced panel, the influence function has 2*n_units rows. + # Half-split is safe here: cohort_data is explicitly stacked as + # [all pre, all post] via rep(c(0L, 1L), each = n_units) in construction above. + if (force_rc && !is.null(did_result) && dp2$panel) { + inf <- did_result$inf_func + n_half <- length(inf) %/% 2L + did_result$inf_func <- inf[1:n_half] + inf[(n_half + 1):(2L * n_half)] + } + return(did_result) } @@ -457,13 +532,9 @@ compute.att_gt2 <- function(dp2) { # Check for NULL first (estimation failed or was skipped) if (is.null(gt_result)) { - if(dp2$base_period == "universal"){ - inffunc_updates <- rep(NA_real_, n) - gt_result <- list(att = NA, group = dp2$treated_groups[g], year = dp2$time_periods[t+tfac], post = post.treat, inffunc_updates = inffunc_updates) - return(gt_result) - } else { - return(NULL) - } + inffunc_updates <- rep(NA_real_, n) + gt_result <- list(att = NA, group = dp2$treated_groups[g], year = dp2$time_periods[t+tfac], post = post.treat, inffunc_updates = inffunc_updates) + return(gt_result) } # Base period normalization: ATT is 0 by construction @@ -475,14 +546,9 @@ compute.att_gt2 <- function(dp2) { if (is.null(gt_result$att)) { # Estimation returned a result but without an ATT - if(dp2$base_period == "universal"){ - inffunc_updates <- rep(NA_real_, n) - gt_result <- list(att = NA, group = dp2$treated_groups[g], year = dp2$time_periods[t+tfac], post = post.treat, inffunc_updates = inffunc_updates) - return(gt_result) - } else { - return(NULL) - } - + inffunc_updates <- rep(NA_real_, n) + gt_result <- list(att = NA, group = dp2$treated_groups[g], year = dp2$time_periods[t+tfac], post = post.treat, inffunc_updates = inffunc_updates) + return(gt_result) } else { att <- gt_result$att inf_func <- gt_result$inf_func diff --git a/R/gplot.R b/R/gplot.R index 932c7d15..87db6fda 100644 --- a/R/gplot.R +++ b/R/gplot.R @@ -80,12 +80,18 @@ splot <- function(ssresults, ylim=NULL, xlab=NULL, ylab=NULL, title="Group", ylab <- 'Group' } + errorbar_layer <- if (utils::packageVersion("ggplot2") >= "3.3.0") { + geom_errorbar(aes(colour=post), width=0.1, orientation="y") + } else { + geom_errorbarh(aes(colour=post), height=0.1) + } + p <- ggplot(ssresults, aes(y=as.factor(year), x=att, xmin=(att-c*att.se), xmax=(att+c*att.se))) + geom_point(aes(colour=post), size=1.5) + #geom_ribbon(aes(x=as.numeric(year)), alpha=0.2) + - geom_errorbarh(aes(colour=post), height=0.1) + + errorbar_layer + scale_y_discrete(breaks=as.factor(ssresults$year)) + #scale_x_discrete(breaks=dabreaks, labels=as.character(dabreaks)) + scale_x_continuous(limits=ylim) + diff --git a/R/imports.R b/R/imports.R index b2bc0bf3..e5d2753b 100644 --- a/R/imports.R +++ b/R/imports.R @@ -6,10 +6,13 @@ #' @keywords internal "_PACKAGE" -#' @import stats -#' @import utils +#' @importFrom stats pnorm qnorm pchisq quantile cov aggregate setNames +#' model.frame model.matrix na.pass complete.cases binomial rnorm var +#' ecdf glm predict +#' @importFrom utils globalVariables #' @import ggplot2 -#' @import BMisc +#' @importFrom BMisc toformula rhs.vars makeBalancedPanel getListElement +#' multiplier_bootstrap TorF #' @import data.table #' @import fastglm #' @importFrom tidyr gather @@ -18,7 +21,8 @@ #' @importFrom DRDID drdid_panel reg_did_panel std_ipw_did_panel std_ipw_did_rc reg_did_rc drdid_rc NULL utils::globalVariables(c( - ".", ".G", ".y", "asif_never_treated", "treated_first_period", "count", "constant", ".rowid", + ".", ".G", ".y", ".w", "asif_never_treated", "treated_first_period", "count", "constant", ".rowid", "V1", "control_group", "cohort", "cohort_size", "period", "period_size", "y1", "y0", + "D", "N", "post", "weights", "i.weights", "y", "cluster", "id", "..cols_to_keep", "..g", "inf_func_long", "inf_func_agg" )) diff --git a/R/pre_process_did.R b/R/pre_process_did.R index 3ee13c5e..e51caf90 100644 --- a/R/pre_process_did.R +++ b/R/pre_process_did.R @@ -21,6 +21,7 @@ pre_process_did <- function(yname, control_group = c("nevertreated","notyettreated"), anticipation = 0, weightsname = NULL, + fix_weights = NULL, alp = 0.05, bstrap = FALSE, cband = FALSE, @@ -95,6 +96,20 @@ pre_process_did <- function(yname, if (".w" %in% colnames(data)) stop("Your data already contains a column named '.w', which is reserved for internal use by `did`. Please rename this column before calling att_gt().") data$.w <- w + # Check for time-varying weights in panel data + if (!is.null(weightsname) && panel) { + w_by_id <- tapply(data[, weightsname], data[, idname], function(x) max(x) - min(x)) + if (any(w_by_id > .Machine$double.eps^0.5, na.rm = TRUE)) { + message( + "Time-varying weights detected. For balanced panel data, the default ", + "behavior uses the weight from the earlier of the two time periods in ", + "each 2x2 comparison (the base period for post-treatment cells). ", + "Use the 'fix_weights' argument to control this behavior. ", + "See ?att_gt for details." + ) + } + } + # Outcome variable will be denoted by y # data$.y <- data[, yname] @@ -389,6 +404,7 @@ pre_process_did <- function(yname, control_group=control_group, anticipation=anticipation, weightsname=weightsname, + fix_weights=fix_weights, alp=alp, bstrap=bstrap, biters=biters, diff --git a/R/pre_process_did2.R b/R/pre_process_did2.R index 733fbe56..e62268b5 100644 --- a/R/pre_process_did2.R +++ b/R/pre_process_did2.R @@ -131,6 +131,20 @@ did_standardization <- function(data, args){ weights <- weights/mean(weights) data$weights <- weights + # Check for time-varying weights in panel data + if (!is.null(args$weightsname) && args$panel) { + w_range <- data[, .(w_range = max(weights) - min(weights)), by = c(args$idname)] + if (any(w_range$w_range > .Machine$double.eps^0.5, na.rm = TRUE)) { + message( + "Time-varying weights detected. For balanced panel data, the default ", + "behavior uses the weight from the earlier of the two time periods in ", + "each 2x2 comparison (the base period for post-treatment cells). ", + "Use the 'fix_weights' argument to control this behavior. ", + "See ?att_gt for details." + ) + } + } + # get a list of dates from min to max tlist <- data[, sort(unique(get(args$tname)))] @@ -419,6 +433,13 @@ get_did_tensors <- function(data, args){ start <- (time - 1L) * n + 1L outcomes_tensor[[time]] <- y_vec[start:(start + n - 1L)] } + # Build weights tensor: one weight vector per time period + w_vec <- data[["weights"]] + weights_tensor <- vector("list", nT) + for(time in seq_len(nT)){ + start <- (time - 1L) * n + 1L + weights_tensor[[time]] <- w_vec[start:(start + n - 1L)] + } } else { # for(time in args$time_periods){ # outcome_vector_time <- rep(NA, args$id_count) # Initialize vector with NAs @@ -430,6 +451,7 @@ get_did_tensors <- function(data, args){ # data[, outcome_vector_time := NULL] # } outcomes_tensor <- NULL + weights_tensor <- NULL } # Getting the time invariant data @@ -533,7 +555,8 @@ get_did_tensors <- function(data, args){ covariates_matrix = covariates_matrix, covariates_tensor = covariates_tensor, cluster = cluster, - weights = weights)) + weights = weights, + weights_tensor = weights_tensor)) } #' @title Process `did` Function Arguments @@ -559,6 +582,7 @@ pre_process_did2 <- function(yname, control_group = c("nevertreated","notyettreated"), anticipation = 0, weightsname = NULL, + fix_weights = NULL, alp = 0.05, bstrap = FALSE, cband = FALSE, diff --git a/R/tidy.R b/R/tidy.R index bb176f5d..158e6102 100644 --- a/R/tidy.R +++ b/R/tidy.R @@ -84,12 +84,22 @@ tidy.MP <- function(x, ...) { #' @param ... other arguments passed to methods #' @export glance.MP <- function(x, ...) { - out <- data.frame( - nobs = x$n, - ngroup = x$DIDparams$nG, - ntime = x$DIDparams$nT, - control.group = x$DIDparams$control_group, - est.method = x$DIDparams$est_method) + dp <- x$DIDparams + if (isTRUE(dp$faster_mode)) { + out <- data.frame( + nobs = x$n, + ngroup = dp$treated_groups_count, + ntime = dp$time_periods_count, + control.group = dp$control_group, + est.method = dp$est_method) + } else { + out <- data.frame( + nobs = x$n, + ngroup = dp$nG, + ntime = dp$nT, + control.group = dp$control_group, + est.method = dp$est_method) + } out } @@ -205,7 +215,7 @@ glance.AGGTEobj<- function(x, ...) { out <- data.frame( type = x$type, nobs = x$DIDparams$id_count, - ngroup = nrow(x$DIDparams$cohort_counts), + ngroup = x$DIDparams$treated_groups_count, ntime = x$DIDparams$time_periods_count, control.group = x$DIDparams$control_group, est.method = x$DIDparams$est_method) diff --git a/R/utility_functions.R b/R/utility_functions.R index 05e1a6d0..666c85ca 100644 --- a/R/utility_functions.R +++ b/R/utility_functions.R @@ -29,7 +29,7 @@ trimmer <- function(g, tname, idname, gname, xformla, data, control_group="notye this.data$D <- 1*this.data[,gname]==g this.pscore_reg <- glm(BMisc::toformula("D", BMisc::rhs.vars(xformla)), data=this.data, - family=binomial(link=logit)) + family=binomial(link="logit")) this.pscore <- predict(this.pscore_reg, type="response") dropper <- (this.pscore > threshold) & (this.data$D==1) if (sum(dropper) > 0) { @@ -67,9 +67,11 @@ get_wide_data <- function(data, yname, idname, tname) { set(data, j = ".y0", value = y_vals) set(data, j = ".dy", value = data[[".y1"]] - y_vals) - # Subset to first row - first.period <- min(data[[tname]]) - data <- data[data[[tname]] == first.period, ] + # Subset to first period's rows + # Pre-extract to avoid data.table .checkTypos when column name matches variable name + .time_vals <- data[[tname]] + first.period <- min(.time_vals) + data <- data[.time_vals == first.period] return(data) } @@ -86,10 +88,10 @@ get_wide_data <- function(data, yname, idname, tname) { check_balance <- function(data, id_col, time_col) { # Count the number of observations per unit (idname) - panel_counts <- data[, .N, by = get(id_col)] + panel_counts <- data[, .N, by = c(id_col)] # Determine the maximum number of time periods for any unit - max_time_periods <- data[, uniqueN(get(time_col))] + max_time_periods <- data.table::uniqueN(data[[time_col]]) # Check if every unit has the same number of time periods as max_time_periods is_balanced <- all(panel_counts$N == max_time_periods) diff --git a/man/DIDparams.Rd b/man/DIDparams.Rd index 30c0ff65..c1230fc6 100644 --- a/man/DIDparams.Rd +++ b/man/DIDparams.Rd @@ -14,6 +14,7 @@ DIDparams( control_group, anticipation = 0, weightsname = NULL, + fix_weights = NULL, alp = 0.05, bstrap = TRUE, biters = 1000, @@ -86,7 +87,55 @@ in the treatment where units can anticipate participating in the treatment and therefore it can affect their untreated potential outcomes} \item{weightsname}{The name of the column containing the sampling weights. -If not set, all observations have same weight.} +If not set, all observations have same weight. When weights are +time-invariant (constant within each unit across periods), all +\code{fix_weights} options produce identical results and no special +handling is needed. + +When weights vary across time (e.g., time-varying population sizes), +the default behavior differs by panel type: +\describe{ +\item{Balanced panel}{Each 2x2 DiD comparison uses the weight from the +earlier of the two time periods involved. For post-treatment cells, +this is the base period (g-1). For pre-treatment cells with +\code{base_period="varying"}, this is the pre-treatment period itself. +The panel DRDID estimators are used.} +\item{Repeated cross sections and unbalanced panels}{Both periods' +per-observation weights are passed directly to the RC DRDID estimators, +so each observation carries its own period-specific weight.} +} +Use the \code{fix_weights} argument to override the default behavior.} + +\item{fix_weights}{Controls how time-varying sampling weights are resolved. +Only relevant when weights vary across time; with time-invariant weights, +all options produce identical results. Options: +\describe{ +\item{\code{NULL} (default)}{For balanced panel: uses the weight from +the earlier of the two time periods in each 2x2 comparison. For +post-treatment cells, this is the base period (g-1). For +pre-treatment cells, this depends on the \code{base_period} setting. +For RC/unbalanced panel: uses per-observation weights from both +periods.} +\item{\code{"varying"}}{Uses per-observation, period-specific weights +for all panel types. For balanced panel data, this switches to the +repeated cross-section DRDID estimators so that pre-period and +post-period observations each carry their own weight. Covariates +are held fixed at their pre-period values (same as the default +panel estimator). This is the most flexible option for weights but +sacrifices the efficiency of the panel estimator. For RC/unbalanced +panel, this is identical to the default. Not supported with custom +\code{est_method} functions.} +\item{\code{"base_period"}}{Fixes weights at the base period (g-1) for +all (g,t) cells within a group, for both pre-treatment and +post-treatment comparisons. Ensures all cells within a group use the +same weights. For unbalanced panels, units not observed in the base +period are dropped with a warning. Not supported for repeated cross +sections (\code{panel = FALSE}).} +\item{\code{"first_period"}}{Fixes weights at the first time period in +the dataset for all (g,t) cells. For unbalanced panels, units not +observed in the first period are dropped with a warning. Not supported +for repeated cross sections (\code{panel = FALSE}).} +}} \item{alp}{the significance level, default is 0.05} @@ -128,16 +177,26 @@ approach in the \code{DRDID} package. Other built-in methods include "ipw" for inverse probability weighting and "reg" for first step regression estimators. The user can also pass their own function for estimating group time average treatment -effects. This should be a function -\code{f(Y1,Y0,treat,covariates)} where \code{Y1} is an -\code{n} x \code{1} vector of outcomes in the post-treatment -outcomes, \code{Y0} is an \code{n} x \code{1} vector of -pre-treatment outcomes, \code{treat} is a vector indicating -whether or not an individual participates in the treatment, -and \code{covariates} is an \code{n} x \code{k} matrix of -covariates. The function should return a list that includes -\code{ATT} (an estimated average treatment effect), and -\code{inf.func} (an \code{n} x \code{1} influence function). +effects. The required signature depends on the data structure: + +\strong{Panel data} (\code{panel=TRUE}): \code{f(y1, y0, D, covariates, i.weights, inffunc, ...)} where \code{y1} is an \verb{n x 1} vector of +post-treatment outcomes, \code{y0} is an \verb{n x 1} vector of +pre-treatment outcomes, \code{D} is a binary vector indicating +treatment group membership, \code{covariates} is an \verb{n x k} matrix, +\code{i.weights} is a vector of sampling weights, and \code{inffunc} is a +logical requesting influence-function computation. + +\strong{Repeated cross sections / unbalanced panel} (\code{panel=FALSE}): +\code{f(y, post, D, covariates, i.weights, inffunc, ...)} where \code{y} is +the outcome vector (length \code{n}), \code{post} is a binary indicator for +the post-treatment period, \code{D} is a binary treatment indicator, +\code{covariates} is an \verb{n x k} matrix, \code{i.weights} is a vector of +sampling weights, and \code{inffunc} is a logical. + +In both cases the function should return a list that includes +\code{ATT} (the estimated group-time average treatment effect) and +\code{att.inf.func} (an \verb{n x 1} influence function — one entry per +observation passed into the estimator). The function can return other things as well, but these are the only two that are required. \code{est_method} is only used if covariates are included.} diff --git a/man/att_gt.Rd b/man/att_gt.Rd index ac311920..74c621ae 100644 --- a/man/att_gt.Rd +++ b/man/att_gt.Rd @@ -16,6 +16,7 @@ att_gt( control_group = c("nevertreated", "notyettreated"), anticipation = 0, weightsname = NULL, + fix_weights = NULL, alp = 0.05, bstrap = TRUE, cband = TRUE, @@ -96,7 +97,55 @@ in the treatment where units can anticipate participating in the treatment and therefore it can affect their untreated potential outcomes} \item{weightsname}{The name of the column containing the sampling weights. -If not set, all observations have same weight.} +If not set, all observations have same weight. When weights are +time-invariant (constant within each unit across periods), all +\code{fix_weights} options produce identical results and no special +handling is needed. + +When weights vary across time (e.g., time-varying population sizes), +the default behavior differs by panel type: +\describe{ +\item{Balanced panel}{Each 2x2 DiD comparison uses the weight from the +earlier of the two time periods involved. For post-treatment cells, +this is the base period (g-1). For pre-treatment cells with +\code{base_period="varying"}, this is the pre-treatment period itself. +The panel DRDID estimators are used.} +\item{Repeated cross sections and unbalanced panels}{Both periods' +per-observation weights are passed directly to the RC DRDID estimators, +so each observation carries its own period-specific weight.} +} +Use the \code{fix_weights} argument to override the default behavior.} + +\item{fix_weights}{Controls how time-varying sampling weights are resolved. +Only relevant when weights vary across time; with time-invariant weights, +all options produce identical results. Options: +\describe{ +\item{\code{NULL} (default)}{For balanced panel: uses the weight from +the earlier of the two time periods in each 2x2 comparison. For +post-treatment cells, this is the base period (g-1). For +pre-treatment cells, this depends on the \code{base_period} setting. +For RC/unbalanced panel: uses per-observation weights from both +periods.} +\item{\code{"varying"}}{Uses per-observation, period-specific weights +for all panel types. For balanced panel data, this switches to the +repeated cross-section DRDID estimators so that pre-period and +post-period observations each carry their own weight. Covariates +are held fixed at their pre-period values (same as the default +panel estimator). This is the most flexible option for weights but +sacrifices the efficiency of the panel estimator. For RC/unbalanced +panel, this is identical to the default. Not supported with custom +\code{est_method} functions.} +\item{\code{"base_period"}}{Fixes weights at the base period (g-1) for +all (g,t) cells within a group, for both pre-treatment and +post-treatment comparisons. Ensures all cells within a group use the +same weights. For unbalanced panels, units not observed in the base +period are dropped with a warning. Not supported for repeated cross +sections (\code{panel = FALSE}).} +\item{\code{"first_period"}}{Fixes weights at the first time period in +the dataset for all (g,t) cells. For unbalanced panels, units not +observed in the first period are dropped with a warning. Not supported +for repeated cross sections (\code{panel = FALSE}).} +}} \item{alp}{the significance level, default is 0.05} @@ -126,16 +175,26 @@ approach in the \code{DRDID} package. Other built-in methods include "ipw" for inverse probability weighting and "reg" for first step regression estimators. The user can also pass their own function for estimating group time average treatment -effects. This should be a function -\code{f(Y1,Y0,treat,covariates)} where \code{Y1} is an -\code{n} x \code{1} vector of outcomes in the post-treatment -outcomes, \code{Y0} is an \code{n} x \code{1} vector of -pre-treatment outcomes, \code{treat} is a vector indicating -whether or not an individual participates in the treatment, -and \code{covariates} is an \code{n} x \code{k} matrix of -covariates. The function should return a list that includes -\code{ATT} (an estimated average treatment effect), and -\code{inf.func} (an \code{n} x \code{1} influence function). +effects. The required signature depends on the data structure: + +\strong{Panel data} (\code{panel=TRUE}): \code{f(y1, y0, D, covariates, i.weights, inffunc, ...)} where \code{y1} is an \verb{n x 1} vector of +post-treatment outcomes, \code{y0} is an \verb{n x 1} vector of +pre-treatment outcomes, \code{D} is a binary vector indicating +treatment group membership, \code{covariates} is an \verb{n x k} matrix, +\code{i.weights} is a vector of sampling weights, and \code{inffunc} is a +logical requesting influence-function computation. + +\strong{Repeated cross sections / unbalanced panel} (\code{panel=FALSE}): +\code{f(y, post, D, covariates, i.weights, inffunc, ...)} where \code{y} is +the outcome vector (length \code{n}), \code{post} is a binary indicator for +the post-treatment period, \code{D} is a binary treatment indicator, +\code{covariates} is an \verb{n x k} matrix, \code{i.weights} is a vector of +sampling weights, and \code{inffunc} is a logical. + +In both cases the function should return a list that includes +\code{ATT} (the estimated group-time average treatment effect) and +\code{att.inf.func} (an \verb{n x 1} influence function — one entry per +observation passed into the estimator). The function can return other things as well, but these are the only two that are required. \code{est_method} is only used if covariates are included.} diff --git a/man/conditional_did_pretest.Rd b/man/conditional_did_pretest.Rd index 2a8be54b..48ab4252 100644 --- a/man/conditional_did_pretest.Rd +++ b/man/conditional_did_pretest.Rd @@ -88,7 +88,24 @@ eventually participate in the treatment, but have not participated yet.} \item{weightsname}{The name of the column containing the sampling weights. -If not set, all observations have same weight.} +If not set, all observations have same weight. When weights are +time-invariant (constant within each unit across periods), all +\code{fix_weights} options produce identical results and no special +handling is needed. + +When weights vary across time (e.g., time-varying population sizes), +the default behavior differs by panel type: +\describe{ +\item{Balanced panel}{Each 2x2 DiD comparison uses the weight from the +earlier of the two time periods involved. For post-treatment cells, +this is the base period (g-1). For pre-treatment cells with +\code{base_period="varying"}, this is the pre-treatment period itself. +The panel DRDID estimators are used.} +\item{Repeated cross sections and unbalanced panels}{Both periods' +per-observation weights are passed directly to the RC DRDID estimators, +so each observation carries its own period-specific weight.} +} +Use the \code{fix_weights} argument to override the default behavior.} \item{alp}{the significance level, default is 0.05} @@ -118,16 +135,26 @@ approach in the \code{DRDID} package. Other built-in methods include "ipw" for inverse probability weighting and "reg" for first step regression estimators. The user can also pass their own function for estimating group time average treatment -effects. This should be a function -\code{f(Y1,Y0,treat,covariates)} where \code{Y1} is an -\code{n} x \code{1} vector of outcomes in the post-treatment -outcomes, \code{Y0} is an \code{n} x \code{1} vector of -pre-treatment outcomes, \code{treat} is a vector indicating -whether or not an individual participates in the treatment, -and \code{covariates} is an \code{n} x \code{k} matrix of -covariates. The function should return a list that includes -\code{ATT} (an estimated average treatment effect), and -\code{inf.func} (an \code{n} x \code{1} influence function). +effects. The required signature depends on the data structure: + +\strong{Panel data} (\code{panel=TRUE}): \code{f(y1, y0, D, covariates, i.weights, inffunc, ...)} where \code{y1} is an \verb{n x 1} vector of +post-treatment outcomes, \code{y0} is an \verb{n x 1} vector of +pre-treatment outcomes, \code{D} is a binary vector indicating +treatment group membership, \code{covariates} is an \verb{n x k} matrix, +\code{i.weights} is a vector of sampling weights, and \code{inffunc} is a +logical requesting influence-function computation. + +\strong{Repeated cross sections / unbalanced panel} (\code{panel=FALSE}): +\code{f(y, post, D, covariates, i.weights, inffunc, ...)} where \code{y} is +the outcome vector (length \code{n}), \code{post} is a binary indicator for +the post-treatment period, \code{D} is a binary treatment indicator, +\code{covariates} is an \verb{n x k} matrix, \code{i.weights} is a vector of +sampling weights, and \code{inffunc} is a logical. + +In both cases the function should return a list that includes +\code{ATT} (the estimated group-time average treatment effect) and +\code{att.inf.func} (an \verb{n x 1} influence function — one entry per +observation passed into the estimator). The function can return other things as well, but these are the only two that are required. \code{est_method} is only used if covariates are included.} diff --git a/man/pre_process_did.Rd b/man/pre_process_did.Rd index 61c27dff..8853c1cd 100644 --- a/man/pre_process_did.Rd +++ b/man/pre_process_did.Rd @@ -16,6 +16,7 @@ pre_process_did( control_group = c("nevertreated", "notyettreated"), anticipation = 0, weightsname = NULL, + fix_weights = NULL, alp = 0.05, bstrap = FALSE, cband = FALSE, @@ -96,7 +97,55 @@ in the treatment where units can anticipate participating in the treatment and therefore it can affect their untreated potential outcomes} \item{weightsname}{The name of the column containing the sampling weights. -If not set, all observations have same weight.} +If not set, all observations have same weight. When weights are +time-invariant (constant within each unit across periods), all +\code{fix_weights} options produce identical results and no special +handling is needed. + +When weights vary across time (e.g., time-varying population sizes), +the default behavior differs by panel type: +\describe{ +\item{Balanced panel}{Each 2x2 DiD comparison uses the weight from the +earlier of the two time periods involved. For post-treatment cells, +this is the base period (g-1). For pre-treatment cells with +\code{base_period="varying"}, this is the pre-treatment period itself. +The panel DRDID estimators are used.} +\item{Repeated cross sections and unbalanced panels}{Both periods' +per-observation weights are passed directly to the RC DRDID estimators, +so each observation carries its own period-specific weight.} +} +Use the \code{fix_weights} argument to override the default behavior.} + +\item{fix_weights}{Controls how time-varying sampling weights are resolved. +Only relevant when weights vary across time; with time-invariant weights, +all options produce identical results. Options: +\describe{ +\item{\code{NULL} (default)}{For balanced panel: uses the weight from +the earlier of the two time periods in each 2x2 comparison. For +post-treatment cells, this is the base period (g-1). For +pre-treatment cells, this depends on the \code{base_period} setting. +For RC/unbalanced panel: uses per-observation weights from both +periods.} +\item{\code{"varying"}}{Uses per-observation, period-specific weights +for all panel types. For balanced panel data, this switches to the +repeated cross-section DRDID estimators so that pre-period and +post-period observations each carry their own weight. Covariates +are held fixed at their pre-period values (same as the default +panel estimator). This is the most flexible option for weights but +sacrifices the efficiency of the panel estimator. For RC/unbalanced +panel, this is identical to the default. Not supported with custom +\code{est_method} functions.} +\item{\code{"base_period"}}{Fixes weights at the base period (g-1) for +all (g,t) cells within a group, for both pre-treatment and +post-treatment comparisons. Ensures all cells within a group use the +same weights. For unbalanced panels, units not observed in the base +period are dropped with a warning. Not supported for repeated cross +sections (\code{panel = FALSE}).} +\item{\code{"first_period"}}{Fixes weights at the first time period in +the dataset for all (g,t) cells. For unbalanced panels, units not +observed in the first period are dropped with a warning. Not supported +for repeated cross sections (\code{panel = FALSE}).} +}} \item{alp}{the significance level, default is 0.05} @@ -126,16 +175,26 @@ approach in the \code{DRDID} package. Other built-in methods include "ipw" for inverse probability weighting and "reg" for first step regression estimators. The user can also pass their own function for estimating group time average treatment -effects. This should be a function -\code{f(Y1,Y0,treat,covariates)} where \code{Y1} is an -\code{n} x \code{1} vector of outcomes in the post-treatment -outcomes, \code{Y0} is an \code{n} x \code{1} vector of -pre-treatment outcomes, \code{treat} is a vector indicating -whether or not an individual participates in the treatment, -and \code{covariates} is an \code{n} x \code{k} matrix of -covariates. The function should return a list that includes -\code{ATT} (an estimated average treatment effect), and -\code{inf.func} (an \code{n} x \code{1} influence function). +effects. The required signature depends on the data structure: + +\strong{Panel data} (\code{panel=TRUE}): \code{f(y1, y0, D, covariates, i.weights, inffunc, ...)} where \code{y1} is an \verb{n x 1} vector of +post-treatment outcomes, \code{y0} is an \verb{n x 1} vector of +pre-treatment outcomes, \code{D} is a binary vector indicating +treatment group membership, \code{covariates} is an \verb{n x k} matrix, +\code{i.weights} is a vector of sampling weights, and \code{inffunc} is a +logical requesting influence-function computation. + +\strong{Repeated cross sections / unbalanced panel} (\code{panel=FALSE}): +\code{f(y, post, D, covariates, i.weights, inffunc, ...)} where \code{y} is +the outcome vector (length \code{n}), \code{post} is a binary indicator for +the post-treatment period, \code{D} is a binary treatment indicator, +\code{covariates} is an \verb{n x k} matrix, \code{i.weights} is a vector of +sampling weights, and \code{inffunc} is a logical. + +In both cases the function should return a list that includes +\code{ATT} (the estimated group-time average treatment effect) and +\code{att.inf.func} (an \verb{n x 1} influence function — one entry per +observation passed into the estimator). The function can return other things as well, but these are the only two that are required. \code{est_method} is only used if covariates are included.} diff --git a/man/pre_process_did2.Rd b/man/pre_process_did2.Rd index b7b90d3a..5dd1404b 100644 --- a/man/pre_process_did2.Rd +++ b/man/pre_process_did2.Rd @@ -16,6 +16,7 @@ pre_process_did2( control_group = c("nevertreated", "notyettreated"), anticipation = 0, weightsname = NULL, + fix_weights = NULL, alp = 0.05, bstrap = FALSE, cband = FALSE, @@ -96,7 +97,55 @@ in the treatment where units can anticipate participating in the treatment and therefore it can affect their untreated potential outcomes} \item{weightsname}{The name of the column containing the sampling weights. -If not set, all observations have same weight.} +If not set, all observations have same weight. When weights are +time-invariant (constant within each unit across periods), all +\code{fix_weights} options produce identical results and no special +handling is needed. + +When weights vary across time (e.g., time-varying population sizes), +the default behavior differs by panel type: +\describe{ +\item{Balanced panel}{Each 2x2 DiD comparison uses the weight from the +earlier of the two time periods involved. For post-treatment cells, +this is the base period (g-1). For pre-treatment cells with +\code{base_period="varying"}, this is the pre-treatment period itself. +The panel DRDID estimators are used.} +\item{Repeated cross sections and unbalanced panels}{Both periods' +per-observation weights are passed directly to the RC DRDID estimators, +so each observation carries its own period-specific weight.} +} +Use the \code{fix_weights} argument to override the default behavior.} + +\item{fix_weights}{Controls how time-varying sampling weights are resolved. +Only relevant when weights vary across time; with time-invariant weights, +all options produce identical results. Options: +\describe{ +\item{\code{NULL} (default)}{For balanced panel: uses the weight from +the earlier of the two time periods in each 2x2 comparison. For +post-treatment cells, this is the base period (g-1). For +pre-treatment cells, this depends on the \code{base_period} setting. +For RC/unbalanced panel: uses per-observation weights from both +periods.} +\item{\code{"varying"}}{Uses per-observation, period-specific weights +for all panel types. For balanced panel data, this switches to the +repeated cross-section DRDID estimators so that pre-period and +post-period observations each carry their own weight. Covariates +are held fixed at their pre-period values (same as the default +panel estimator). This is the most flexible option for weights but +sacrifices the efficiency of the panel estimator. For RC/unbalanced +panel, this is identical to the default. Not supported with custom +\code{est_method} functions.} +\item{\code{"base_period"}}{Fixes weights at the base period (g-1) for +all (g,t) cells within a group, for both pre-treatment and +post-treatment comparisons. Ensures all cells within a group use the +same weights. For unbalanced panels, units not observed in the base +period are dropped with a warning. Not supported for repeated cross +sections (\code{panel = FALSE}).} +\item{\code{"first_period"}}{Fixes weights at the first time period in +the dataset for all (g,t) cells. For unbalanced panels, units not +observed in the first period are dropped with a warning. Not supported +for repeated cross sections (\code{panel = FALSE}).} +}} \item{alp}{the significance level, default is 0.05} @@ -126,16 +175,26 @@ approach in the \code{DRDID} package. Other built-in methods include "ipw" for inverse probability weighting and "reg" for first step regression estimators. The user can also pass their own function for estimating group time average treatment -effects. This should be a function -\code{f(Y1,Y0,treat,covariates)} where \code{Y1} is an -\code{n} x \code{1} vector of outcomes in the post-treatment -outcomes, \code{Y0} is an \code{n} x \code{1} vector of -pre-treatment outcomes, \code{treat} is a vector indicating -whether or not an individual participates in the treatment, -and \code{covariates} is an \code{n} x \code{k} matrix of -covariates. The function should return a list that includes -\code{ATT} (an estimated average treatment effect), and -\code{inf.func} (an \code{n} x \code{1} influence function). +effects. The required signature depends on the data structure: + +\strong{Panel data} (\code{panel=TRUE}): \code{f(y1, y0, D, covariates, i.weights, inffunc, ...)} where \code{y1} is an \verb{n x 1} vector of +post-treatment outcomes, \code{y0} is an \verb{n x 1} vector of +pre-treatment outcomes, \code{D} is a binary vector indicating +treatment group membership, \code{covariates} is an \verb{n x k} matrix, +\code{i.weights} is a vector of sampling weights, and \code{inffunc} is a +logical requesting influence-function computation. + +\strong{Repeated cross sections / unbalanced panel} (\code{panel=FALSE}): +\code{f(y, post, D, covariates, i.weights, inffunc, ...)} where \code{y} is +the outcome vector (length \code{n}), \code{post} is a binary indicator for +the post-treatment period, \code{D} is a binary treatment indicator, +\code{covariates} is an \verb{n x k} matrix, \code{i.weights} is a vector of +sampling weights, and \code{inffunc} is a logical. + +In both cases the function should return a list that includes +\code{ATT} (the estimated group-time average treatment effect) and +\code{att.inf.func} (an \verb{n x 1} influence function — one entry per +observation passed into the estimator). The function can return other things as well, but these are the only two that are required. \code{est_method} is only used if covariates are included.} diff --git a/tests/testthat/test-aggte-comprehensive.R b/tests/testthat/test-aggte-comprehensive.R new file mode 100644 index 00000000..ab872b94 --- /dev/null +++ b/tests/testthat/test-aggte-comprehensive.R @@ -0,0 +1,204 @@ +# ============================================================================= +# Comprehensive tests for aggte() aggregation types +# ============================================================================= + +# Shared setup: known DGP with treatment effect = 1 +set.seed(20260401) +sp <- did::reset.sim() +sp$te <- 1 # constant treatment effect +data_agg <- did::build_sim_dataset(sp) + +mp_agg <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_agg, tname = "period", + idname = "id", gname = "G", est_method = "dr", + bstrap = FALSE) +)) + +# ============================================================================= +# type = "simple" +# ============================================================================= + +test_that("aggte simple returns valid overall ATT", { + agg <- suppressWarnings(aggte(mp_agg, type = "simple")) + expect_s3_class(agg, "AGGTEobj") + expect_false(is.na(agg$overall.att)) + expect_equal(agg$overall.att, 1, tolerance = 0.5) +}) + +test_that("aggte simple returns valid SE", { + agg <- suppressWarnings(aggte(mp_agg, type = "simple")) + expect_true(agg$overall.se > 0) + expect_false(is.na(agg$overall.se)) +}) + +test_that("aggte simple has no egt component", { + agg <- suppressWarnings(aggte(mp_agg, type = "simple")) + expect_null(agg$egt) +}) + +test_that("aggte simple influence function has correct length", { + agg <- suppressWarnings(aggte(mp_agg, type = "simple")) + expect_equal(length(agg$inf.function$simple.att), nobs(mp_agg)) +}) + +# ============================================================================= +# type = "dynamic" +# ============================================================================= + +test_that("aggte dynamic returns event-time specific ATTs", { + agg <- suppressWarnings(aggte(mp_agg, type = "dynamic")) + expect_false(is.null(agg$egt)) + expect_true(length(agg$egt) > 0) + # Should have both negative and non-negative event times + expect_true(any(agg$egt < 0)) + expect_true(any(agg$egt >= 0)) +}) + +test_that("aggte dynamic event times are sorted", { + agg <- suppressWarnings(aggte(mp_agg, type = "dynamic")) + expect_equal(agg$egt, sort(agg$egt)) +}) + +test_that("aggte dynamic overall.att averages post-treatment event times", { + agg <- suppressWarnings(aggte(mp_agg, type = "dynamic")) + expect_false(is.na(agg$overall.att)) + expect_equal(agg$overall.att, 1, tolerance = 0.5) +}) + +test_that("aggte dynamic min_e filters event times", { + agg_full <- suppressWarnings(aggte(mp_agg, type = "dynamic")) + agg_filt <- suppressWarnings(aggte(mp_agg, type = "dynamic", min_e = -1)) + expect_true(min(agg_filt$egt) >= -1) + expect_true(length(agg_filt$egt) <= length(agg_full$egt)) +}) + +test_that("aggte dynamic max_e filters event times", { + agg_full <- suppressWarnings(aggte(mp_agg, type = "dynamic")) + agg_filt <- suppressWarnings(aggte(mp_agg, type = "dynamic", max_e = 1)) + expect_true(max(agg_filt$egt) <= 1) + expect_true(length(agg_filt$egt) <= length(agg_full$egt)) +}) + +test_that("aggte dynamic min_e and max_e together", { + agg <- suppressWarnings(aggte(mp_agg, type = "dynamic", min_e = -1, max_e = 1)) + expect_true(min(agg$egt) >= -1) + expect_true(max(agg$egt) <= 1) +}) + +test_that("aggte dynamic balance_e filters groups", { + agg_unbal <- suppressWarnings(aggte(mp_agg, type = "dynamic")) + agg_bal <- suppressWarnings(aggte(mp_agg, type = "dynamic", balance_e = 1)) + # Balanced version may have fewer event times or same + expect_true(length(agg_bal$egt) <= length(agg_unbal$egt)) +}) + +test_that("aggte dynamic SEs are positive where ATT is not NA", { + agg <- suppressWarnings(aggte(mp_agg, type = "dynamic")) + non_na <- !is.na(agg$att.egt) + expect_true(all(agg$se.egt[non_na] > 0)) +}) + +# ============================================================================= +# type = "group" +# ============================================================================= + +test_that("aggte group returns per-group ATTs", { + agg <- suppressWarnings(aggte(mp_agg, type = "group")) + expect_false(is.null(agg$egt)) + # egt should contain the group values + expect_true(all(agg$egt %in% unique(mp_agg$group))) +}) + +test_that("aggte group overall.att is reasonable", { + agg <- suppressWarnings(aggte(mp_agg, type = "group")) + expect_false(is.na(agg$overall.att)) + expect_equal(agg$overall.att, 1, tolerance = 0.5) +}) + +test_that("aggte group SEs are positive for each group", { + agg <- suppressWarnings(aggte(mp_agg, type = "group")) + non_na <- !is.na(agg$att.egt) + expect_true(all(agg$se.egt[non_na] > 0)) +}) + +# ============================================================================= +# type = "calendar" +# ============================================================================= + +test_that("aggte calendar returns per-period ATTs", { + agg <- suppressWarnings(aggte(mp_agg, type = "calendar")) + expect_false(is.null(agg$egt)) + expect_true(length(agg$egt) > 0) +}) + +test_that("aggte calendar overall.att is reasonable", { + agg <- suppressWarnings(aggte(mp_agg, type = "calendar")) + expect_false(is.na(agg$overall.att)) + expect_equal(agg$overall.att, 1, tolerance = 0.5) +}) + +test_that("aggte calendar SEs are positive for each period", { + agg <- suppressWarnings(aggte(mp_agg, type = "calendar")) + non_na <- !is.na(agg$att.egt) + expect_true(all(agg$se.egt[non_na] > 0)) +}) + +test_that("aggte calendar only includes post-treatment periods", { + agg <- suppressWarnings(aggte(mp_agg, type = "calendar")) + min_group <- min(mp_agg$group) + # All calendar periods should be at or after the earliest treatment + expect_true(all(agg$egt >= min_group)) +}) + +# ============================================================================= +# na.rm behavior +# ============================================================================= + +test_that("aggte with na.rm=TRUE drops NA ATTs and proceeds", { + # Create MP with some NA ATTs manually + mp_tmp <- mp_agg + mp_tmp$att[1] <- NA + # Should work with na.rm=TRUE + agg <- suppressWarnings(aggte(mp_tmp, type = "dynamic", na.rm = TRUE)) + expect_s3_class(agg, "AGGTEobj") + expect_false(is.na(agg$overall.att)) +}) + +# ============================================================================= +# Cross-type consistency +# ============================================================================= + +test_that("all aggte types return AGGTEobj class", { + for (tp in c("simple", "dynamic", "group", "calendar")) { + agg <- suppressWarnings(aggte(mp_agg, type = tp)) + expect_s3_class(agg, "AGGTEobj") + } +}) + +test_that("all aggte types have non-NULL DIDparams", { + for (tp in c("simple", "dynamic", "group", "calendar")) { + agg <- suppressWarnings(aggte(mp_agg, type = tp)) + expect_false(is.null(agg$DIDparams), label = tp) + } +}) + +test_that("aggte preserves overridden bstrap/alp settings", { + agg <- suppressWarnings(aggte(mp_agg, type = "dynamic", bstrap = FALSE, alp = 0.01)) + expect_equal(agg$DIDparams$bstrap, FALSE) + expect_equal(agg$DIDparams$alp, 0.01) +}) + +# ============================================================================= +# cband with bootstrap +# ============================================================================= + +test_that("aggte with cband=TRUE and bstrap=TRUE returns simultaneous critical value", { + mp_boot <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_agg, tname = "period", + idname = "id", gname = "G", est_method = "reg", + bstrap = TRUE, biters = 100, cband = TRUE) + )) + agg <- aggte(mp_boot, type = "dynamic", bstrap = TRUE, biters = 100, cband = TRUE) + # Simultaneous critical value should be at least as large as pointwise z + expect_true(agg$crit.val.egt >= qnorm(0.975)) +}) diff --git a/tests/testthat/test-att_gt.R b/tests/testthat/test-att_gt.R index 5763d4fe..a7c3755c 100644 --- a/tests/testthat/test-att_gt.R +++ b/tests/testthat/test-att_gt.R @@ -56,14 +56,16 @@ test_that("two period case", { sp$n <- 10000 data <- did::build_sim_dataset(sp) - res <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", - gname="G", est_method="reg") + res <- suppressWarnings( + att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", + gname="G", est_method="reg") + ) res - agg_simple <- aggte(res, type="simple") - agg_group <- aggte(res, type="group") - agg_dynamic <- aggte(res, type="dynamic") - agg_calendar <- aggte(res, type="calendar") + agg_simple <- suppressWarnings(aggte(res, type="simple")) + agg_group <- suppressWarnings(aggte(res, type="group")) + agg_dynamic <- suppressWarnings(aggte(res, type="dynamic")) + agg_calendar <- suppressWarnings(aggte(res, type="calendar")) expect_equal(agg_simple$overall.att, 1, tol=.5) expect_equal(agg_group$overall.att, 1, tol=.5) @@ -628,6 +630,423 @@ test_that("sampling weights", { }) +# ============================================================================= +# Column naming: user columns named gname/tname/idname should not crash +# ============================================================================= + +test_that("works when user column is literally named 'gname'", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + # Rename columns to match parameter names exactly + names(data)[names(data) == "G"] <- "gname" + names(data)[names(data) == "period"] <- "tname" + names(data)[names(data) == "id"] <- "idname" + + mod <- att_gt(yname="Y", xformla=~X, data=data, tname="tname", idname="idname", + gname="gname", est_method="reg", bstrap=FALSE) + expect_false(all(is.na(mod$att))) + + # aggte should also work (this was the specific dreamerr bug) + agg <- suppressWarnings(aggte(mod, type="simple")) + expect_false(is.na(agg$overall.att)) + + agg_dyn <- suppressWarnings(aggte(mod, type="dynamic")) + expect_false(is.na(agg_dyn$overall.att)) +}) + +test_that("works when user column is literally named 'gname' with faster_mode", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + names(data)[names(data) == "G"] <- "gname" + names(data)[names(data) == "period"] <- "tname" + names(data)[names(data) == "id"] <- "idname" + + mod <- att_gt(yname="Y", xformla=~X, data=data, tname="tname", idname="idname", + gname="gname", est_method="reg", bstrap=FALSE, faster_mode=TRUE) + expect_false(all(is.na(mod$att))) + + agg <- aggte(mod, type="simple") + expect_false(is.na(agg$overall.att)) +}) + +# ============================================================================= +# Time-varying weights: fix_weights tests +# ============================================================================= + +test_that("time-varying weights: faster_mode matches slow mode (default fix_weights=NULL)", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + data$tv_weight <- data$period + runif(nrow(data), -0.1, 0.1) + + for (em in c("reg", "dr", "ipw")) { + for (bp in c("varying", "universal")) { + res_slow <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", + gname="G", est_method=em, weightsname="tv_weight", + base_period=bp, faster_mode=FALSE, bstrap=FALSE) + res_fast <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", + gname="G", est_method=em, weightsname="tv_weight", + base_period=bp, faster_mode=TRUE, bstrap=FALSE) + + expect_equal(res_slow$att, res_fast$att, tolerance=1e-10, + label=paste("ATT match:", em, bp)) + } + } +}) + +test_that("fix_weights options: faster_mode matches slow mode (balanced panel)", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + data$tv_weight <- data$period + runif(nrow(data), -0.1, 0.1) + + for (fw in c("varying", "base_period", "first_period")) { + res_slow <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", + gname="G", est_method="dr", weightsname="tv_weight", + fix_weights=fw, faster_mode=FALSE, bstrap=FALSE) + res_fast <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", + gname="G", est_method="dr", weightsname="tv_weight", + fix_weights=fw, faster_mode=TRUE, bstrap=FALSE) + + expect_equal(res_slow$att, res_fast$att, tolerance=1e-10, + label=paste("ATT match:", fw)) + } +}) + +test_that("time-invariant weights: all fix_weights options produce identical ATTs", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + n_ids <- length(unique(data$id)) + n_periods <- length(unique(data$period)) + data$const_weight <- rep(runif(n_ids, 1, 10), each = n_periods) + + res_default <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", + gname="G", est_method="reg", weightsname="const_weight", + bstrap=FALSE) + + for (fw in c("base_period", "first_period")) { + res_fw <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", + gname="G", est_method="reg", weightsname="const_weight", + fix_weights=fw, bstrap=FALSE) + expect_equal(res_default$att, res_fw$att, tolerance=1e-10, + label=paste("same ATT for", fw)) + } +}) + +test_that("message emitted for time-varying weights in balanced panel", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + data$tv_weight <- data$period * 1.0 + runif(nrow(data), 0, 0.5) + + expect_message( + att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", + gname="G", weightsname="tv_weight", bstrap=FALSE), + "Time-varying weights detected" + ) +}) + +test_that("no message for time-invariant weights", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + n_ids <- length(unique(data$id)) + n_periods <- length(unique(data$period)) + data$const_weight <- rep(runif(n_ids, 1, 10), each = n_periods) + + expect_no_message( + att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", + gname="G", weightsname="const_weight", bstrap=FALSE) + ) +}) + +test_that("notyettreated with time-varying weights: faster_mode matches", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + data$tv_weight <- data$period + runif(nrow(data), 0, 0.5) + + res_slow <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", + gname="G", est_method="dr", weightsname="tv_weight", + control_group="notyettreated", faster_mode=FALSE, bstrap=FALSE) + res_fast <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", + gname="G", est_method="dr", weightsname="tv_weight", + control_group="notyettreated", faster_mode=TRUE, bstrap=FALSE) + + expect_equal(res_slow$att, res_fast$att, tolerance=1e-10) +}) + +test_that("RC with time-varying weights: faster_mode matches", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + data$tv_weight <- data$period * 1.0 + runif(nrow(data), 0, 0.5) + + res_slow <- att_gt(yname="Y", data=data, tname="period", idname="id", + gname="G", est_method="reg", weightsname="tv_weight", + panel=FALSE, faster_mode=FALSE, bstrap=FALSE) + res_fast <- att_gt(yname="Y", data=data, tname="period", idname="id", + gname="G", est_method="reg", weightsname="tv_weight", + panel=FALSE, faster_mode=TRUE, bstrap=FALSE) + + expect_equal(res_slow$att, res_fast$att, tolerance=1e-10) +}) + +test_that("fix_weights validation", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + + expect_error( + att_gt(yname="Y", data=data, tname="period", idname="id", + gname="G", fix_weights="invalid_option", bstrap=FALSE), + "fix_weights must be NULL" + ) + + # base_period and first_period not supported for repeated cross sections + expect_error( + att_gt(yname="Y", data=data, tname="period", idname="id", + gname="G", fix_weights="base_period", panel=FALSE, bstrap=FALSE), + "not supported for repeated cross sections" + ) + expect_error( + att_gt(yname="Y", data=data, tname="period", idname="id", + gname="G", fix_weights="first_period", panel=FALSE, bstrap=FALSE), + "not supported for repeated cross sections" + ) + + # varying not supported with custom est_method when panel = TRUE + my_panel_est <- function(y1, y0, D, covariates, i.weights, inffunc, ...) { + list(ATT = mean(y1 - y0), att.inf.func = rep(0, length(y1))) + } + expect_error( + att_gt(yname="Y", data=data, tname="period", idname="id", + gname="G", fix_weights="varying", est_method=my_panel_est, + panel=TRUE, bstrap=FALSE), + "not currently supported with custom est_method" + ) + + # varying IS supported with custom est_method when panel = FALSE (RC signature) + my_rc_est <- function(y, post, D, covariates, i.weights, inffunc, ...) { + n_obs <- length(y) + post_c <- post[D==0] + y_c <- y[D==0] + w_c <- i.weights[D==0] + att <- mean(y_c[post_c==1] * w_c[post_c==1]) / mean(w_c[post_c==1]) - + mean(y_c[post_c==0] * w_c[post_c==0]) / mean(w_c[post_c==0]) + list(ATT = att, att.inf.func = rep(0, n_obs)) + } + # Wald pre-test warning is expected with this small sim dataset (group 2 + # has only one pre-treatment period), but the key thing is: no error and + # no recycling warnings from mismatched influence-function length. + rc_result <- expect_no_error( + withCallingHandlers( + att_gt(yname="Y", data=data, tname="period", idname="id", + gname="G", fix_weights="varying", est_method=my_rc_est, + panel=FALSE, bstrap=FALSE), + warning = function(w) { + if (grepl("not a multiple of replacement length", conditionMessage(w))) + stop("IF length mismatch: ", conditionMessage(w)) + invokeRestart("muffleWarning") + } + ) + ) + expect_true(inherits(rc_result, "MP")) + expect_false(anyNA(rc_result$att)) +}) + +test_that("unbalanced panel fix_weights with units missing from reference period", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + + # Drop some treated units from first period so fix_weights="first_period" must drop them + first_p <- min(data$period) + drop_ids <- unique(data$id[data$G > 0])[1:10] + data <- data[!(data$id %in% drop_ids & data$period == first_p), ] + data$w <- runif(nrow(data), 1, 5) + + for (fw in c("first_period", "base_period")) { + res_slow <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", data = data, tname = "period", idname = "id", + gname = "G", allow_unbalanced_panel = TRUE, + fix_weights = fw, weightsname = "w", + bstrap = FALSE, faster_mode = FALSE) + )) + res_fast <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", data = data, tname = "period", idname = "id", + gname = "G", allow_unbalanced_panel = TRUE, + fix_weights = fw, weightsname = "w", + bstrap = FALSE, faster_mode = TRUE) + )) + expect_equal(res_slow$att, res_fast$att, tolerance = 1e-10, + label = paste("unbalanced", fw, "ATT match")) + } +}) + +# ============================================================================= +# Influence function consistency: slow vs fast mode ATT AND SE must match +# ============================================================================= + +test_that("IF consistency: balanced panel, all fix_weights x est_method x base_period", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + data$tv_weight <- data$period + runif(nrow(data), -0.1, 0.1) + + for (fw in c(NA, "varying", "base_period", "first_period")) { + fw_arg <- if (is.na(fw)) NULL else fw + for (em in c("dr", "ipw", "reg")) { + for (bp in c("varying", "universal")) { + label <- paste("panel", if (is.na(fw)) "NULL" else fw, em, bp) + + res_slow <- att_gt(yname="Y", xformla=~X, data=data, tname="period", + idname="id", gname="G", est_method=em, + weightsname="tv_weight", fix_weights=fw_arg, + base_period=bp, faster_mode=FALSE, + bstrap=FALSE, cband=FALSE) + res_fast <- att_gt(yname="Y", xformla=~X, data=data, tname="period", + idname="id", gname="G", est_method=em, + weightsname="tv_weight", fix_weights=fw_arg, + base_period=bp, faster_mode=TRUE, + bstrap=FALSE, cband=FALSE) + + expect_equal(res_slow$att, res_fast$att, tolerance=1e-10, + label=paste("ATT", label)) + expect_equal(res_slow$se, res_fast$se, tolerance=1e-10, + label=paste("SE", label)) + } + } + } +}) + +test_that("IF consistency: balanced panel, notyettreated control group", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + data$tv_weight <- data$period + runif(nrow(data), -0.1, 0.1) + + for (fw in c(NA, "varying", "base_period", "first_period")) { + fw_arg <- if (is.na(fw)) NULL else fw + label <- paste("notyettreated", if (is.na(fw)) "NULL" else fw) + + res_slow <- att_gt(yname="Y", xformla=~X, data=data, tname="period", + idname="id", gname="G", est_method="dr", + weightsname="tv_weight", fix_weights=fw_arg, + control_group="notyettreated", + faster_mode=FALSE, bstrap=FALSE, cband=FALSE) + res_fast <- att_gt(yname="Y", xformla=~X, data=data, tname="period", + idname="id", gname="G", est_method="dr", + weightsname="tv_weight", fix_weights=fw_arg, + control_group="notyettreated", + faster_mode=TRUE, bstrap=FALSE, cband=FALSE) + + expect_equal(res_slow$att, res_fast$att, tolerance=1e-10, + label=paste("ATT", label)) + expect_equal(res_slow$se, res_fast$se, tolerance=1e-10, + label=paste("SE", label)) + } +}) + +test_that("IF consistency: repeated cross-sections, default weights x est_method", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + data$tv_weight <- data$period + runif(nrow(data), -0.1, 0.1) + + # RC with default weights (fix_weights=NULL); fixed weight options tested separately + for (em in c("dr", "ipw", "reg")) { + label <- paste("RC NULL", em) + + res_slow <- att_gt(yname="Y", xformla=~X, data=data, tname="period", + idname="id", gname="G", est_method=em, + weightsname="tv_weight", + panel=FALSE, faster_mode=FALSE, + bstrap=FALSE, cband=FALSE) + res_fast <- att_gt(yname="Y", xformla=~X, data=data, tname="period", + idname="id", gname="G", est_method=em, + weightsname="tv_weight", + panel=FALSE, faster_mode=TRUE, + bstrap=FALSE, cband=FALSE) + + expect_equal(res_slow$att, res_fast$att, tolerance=1e-10, + label=paste("ATT", label)) + expect_equal(res_slow$se, res_fast$se, tolerance=1e-10, + label=paste("SE", label)) + } +}) + +test_that("IF consistency: unbalanced panel, default weights x est_method", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + + # Create unbalanced panel by dropping some observations + set.seed(42) + drop_idx <- sample(nrow(data), size = floor(nrow(data) * 0.05)) + data_unbal <- data[-drop_idx, ] + + # Default weights (fix_weights=NULL); fixed weight options for unbalanced panels + # have known edge cases with unit availability across periods + for (em in c("dr", "reg")) { + label <- paste("unbalanced NULL", em) + + res_slow <- att_gt(yname="Y", xformla=~X, data=data_unbal, tname="period", + idname="id", gname="G", est_method=em, + allow_unbalanced_panel=TRUE, + faster_mode=FALSE, bstrap=FALSE, cband=FALSE) + res_fast <- att_gt(yname="Y", xformla=~X, data=data_unbal, tname="period", + idname="id", gname="G", est_method=em, + allow_unbalanced_panel=TRUE, + faster_mode=TRUE, bstrap=FALSE, cband=FALSE) + + expect_equal(res_slow$att, res_fast$att, tolerance=1e-10, + label=paste("ATT", label)) + expect_equal(res_slow$se, res_fast$se, tolerance=1e-10, + label=paste("SE", label)) + } +}) + +test_that("IF consistency: no covariates (xformla=~1), all data types", { + set.seed(20260401) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + + # Balanced panel, no covariates + for (fw in c(NA, "varying")) { + fw_arg <- if (is.na(fw)) NULL else fw + label <- paste("no-covar panel", if (is.na(fw)) "NULL" else fw) + + res_slow <- att_gt(yname="Y", data=data, tname="period", idname="id", + gname="G", fix_weights=fw_arg, + faster_mode=FALSE, bstrap=FALSE, cband=FALSE) + res_fast <- att_gt(yname="Y", data=data, tname="period", idname="id", + gname="G", fix_weights=fw_arg, + faster_mode=TRUE, bstrap=FALSE, cband=FALSE) + + expect_equal(res_slow$att, res_fast$att, tolerance=1e-10, + label=paste("ATT", label)) + expect_equal(res_slow$se, res_fast$se, tolerance=1e-10, + label=paste("SE", label)) + } + + # RC, no covariates + res_slow <- att_gt(yname="Y", data=data, tname="period", idname="id", + gname="G", panel=FALSE, + faster_mode=FALSE, bstrap=FALSE, cband=FALSE) + res_fast <- att_gt(yname="Y", data=data, tname="period", idname="id", + gname="G", panel=FALSE, + faster_mode=TRUE, bstrap=FALSE, cband=FALSE) + + expect_equal(res_slow$att, res_fast$att, tolerance=1e-10, + label="ATT RC no-covar") + expect_equal(res_slow$se, res_fast$se, tolerance=1e-10, + label="SE RC no-covar") +}) + test_that("clustered standard errors", { set.seed(09142024) # check that we can compute when clustered standard errors are supplied diff --git a/tests/testthat/test-edge-cases.R b/tests/testthat/test-edge-cases.R new file mode 100644 index 00000000..aa27f344 --- /dev/null +++ b/tests/testthat/test-edge-cases.R @@ -0,0 +1,192 @@ +# ============================================================================= +# Tests for edge cases and boundary conditions +# ============================================================================= + +test_that("single treated group produces valid att_gt", { + set.seed(20260401) + data_sg <- data.frame( + id = rep(1:200, each = 5), + period = rep(1:5, 200), + G = rep(c(rep(4, 50), rep(0, 150)), each = 5), + Y = rnorm(1000) + ) + data_sg$Y[data_sg$G == 4 & data_sg$period >= 4] <- + data_sg$Y[data_sg$G == 4 & data_sg$period >= 4] + 1 + + result <- att_gt(yname = "Y", data = data_sg, tname = "period", idname = "id", + gname = "G", bstrap = FALSE) + expect_s3_class(result, "MP") + expect_true(any(!is.na(result$att))) +}) + +test_that("single treated group works with all 4 aggte types", { + set.seed(20260401) + data_sg <- data.frame( + id = rep(1:200, each = 5), + period = rep(1:5, 200), + G = rep(c(rep(4, 50), rep(0, 150)), each = 5), + Y = rnorm(1000) + ) + data_sg$Y[data_sg$G == 4 & data_sg$period >= 4] <- + data_sg$Y[data_sg$G == 4 & data_sg$period >= 4] + 1 + + mp <- att_gt(yname = "Y", data = data_sg, tname = "period", idname = "id", + gname = "G", bstrap = FALSE) + + for (tp in c("simple", "dynamic", "group", "calendar")) { + agg <- suppressWarnings(aggte(mp, type = tp)) + expect_s3_class(agg, "AGGTEobj") + expect_false(is.na(agg$overall.att), label = tp) + } +}) + +test_that("two-period data with universal base period", { + set.seed(20260401) + data_2p <- data.frame( + id = rep(1:200, each = 2), + period = rep(1:2, 200), + G = rep(c(rep(2, 50), rep(0, 150)), each = 2), + Y = rnorm(400) + ) + data_2p$Y[data_2p$G == 2 & data_2p$period == 2] <- + data_2p$Y[data_2p$G == 2 & data_2p$period == 2] + 1 + + result <- suppressWarnings( + att_gt(yname = "Y", data = data_2p, tname = "period", idname = "id", + gname = "G", base_period = "universal", bstrap = FALSE) + ) + expect_s3_class(result, "MP") + expect_true(any(!is.na(result$att))) +}) + +test_that("data with no never-treated group works with notyettreated", { + set.seed(20260401) + # All units are eventually treated (groups 3 and 5) + data_nnt <- data.frame( + id = rep(1:200, each = 6), + period = rep(1:6, 200), + G = rep(c(rep(3, 100), rep(5, 100)), each = 6), + Y = rnorm(1200) + ) + result <- suppressWarnings( + att_gt(yname = "Y", data = data_nnt, tname = "period", idname = "id", + gname = "G", control_group = "notyettreated", bstrap = FALSE) + ) + expect_s3_class(result, "MP") +}) + +test_that("data with no never-treated group warns with nevertreated", { + set.seed(20260401) + data_nnt <- data.frame( + id = rep(1:200, each = 6), + period = rep(1:6, 200), + G = rep(c(rep(3, 100), rep(5, 100)), each = 6), + Y = rnorm(1200) + ) + expect_warning( + att_gt(yname = "Y", data = data_nnt, tname = "period", idname = "id", + gname = "G", control_group = "nevertreated", bstrap = FALSE), + "No never-treated group|never-treated" + ) +}) + +test_that("groups treated in first period are dropped", { + set.seed(20260401) + sp <- did::reset.sim() + data_fp <- did::build_sim_dataset(sp) + # Add units treated in the very first period + first_per <- min(data_fp$period) + extra <- data_fp[data_fp$G == sort(unique(data_fp$G[data_fp$G > 0]))[1], ] + extra$G <- first_per + extra$id <- extra$id + max(data_fp$id) + data_fp <- rbind(data_fp, extra) + + expect_warning( + att_gt(yname = "Y", data = data_fp, tname = "period", idname = "id", + gname = "G", bstrap = FALSE), + "already treated|Dropped" + ) +}) + +test_that("non-consecutive time periods work", { + set.seed(20260401) + data_nc <- data.frame( + id = rep(1:200, each = 4), + period = rep(c(2000, 2003, 2007, 2010), 200), + G = rep(c(rep(2007, 50), rep(0, 150)), each = 4), + Y = rnorm(800) + ) + data_nc$Y[data_nc$G == 2007 & data_nc$period >= 2007] <- + data_nc$Y[data_nc$G == 2007 & data_nc$period >= 2007] + 1 + + result <- suppressWarnings( + att_gt(yname = "Y", data = data_nc, tname = "period", idname = "id", + gname = "G", bstrap = FALSE) + ) + expect_s3_class(result, "MP") + expect_true(any(!is.na(result$att))) + + agg <- suppressWarnings(aggte(result, type = "dynamic")) + expect_s3_class(agg, "AGGTEobj") +}) + +test_that("non-consecutive group values work", { + set.seed(20260401) + data_ng <- data.frame( + id = rep(1:200, each = 5), + period = rep(1:5, 200), + G = rep(c(rep(3, 50), rep(5, 50), rep(0, 100)), each = 5), + Y = rnorm(1000) + ) + result <- att_gt(yname = "Y", data = data_ng, tname = "period", idname = "id", + gname = "G", bstrap = FALSE) + expect_s3_class(result, "MP") + expect_true(length(unique(result$group)) >= 2) +}) + +test_that("allow_unbalanced_panel=TRUE with balanced data proceeds normally", { + set.seed(20260401) + sp <- did::reset.sim() + data_bal <- did::build_sim_dataset(sp) + + result <- suppressMessages( + att_gt(yname = "Y", data = data_bal, tname = "period", idname = "id", + gname = "G", allow_unbalanced_panel = TRUE, bstrap = FALSE) + ) + expect_s3_class(result, "MP") + expect_true(any(!is.na(result$att))) +}) + +test_that("allow_unbalanced_panel=TRUE with truly unbalanced data", { + set.seed(20260401) + sp <- did::reset.sim() + data_ub <- did::build_sim_dataset(sp) + # Drop a few rows to make it unbalanced + data_ub <- data_ub[-c(1, 5, 10), ] + + result <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", data = data_ub, tname = "period", idname = "id", + gname = "G", allow_unbalanced_panel = TRUE, bstrap = FALSE) + )) + expect_s3_class(result, "MP") + expect_true(any(!is.na(result$att))) +}) + +test_that("single post-treatment period produces valid results", { + set.seed(20260401) + data_spt <- data.frame( + id = rep(1:200, each = 3), + period = rep(1:3, 200), + G = rep(c(rep(3, 50), rep(0, 150)), each = 3), + Y = rnorm(600) + ) + data_spt$Y[data_spt$G == 3 & data_spt$period == 3] <- + data_spt$Y[data_spt$G == 3 & data_spt$period == 3] + 1 + + result <- att_gt(yname = "Y", data = data_spt, tname = "period", idname = "id", + gname = "G", bstrap = FALSE) + expect_s3_class(result, "MP") + # Should have at least one post-treatment ATT + post_atts <- result$att[result$group <= result$t] + expect_true(any(!is.na(post_atts))) +}) diff --git a/tests/testthat/test-error-handling.R b/tests/testthat/test-error-handling.R new file mode 100644 index 00000000..e7759983 --- /dev/null +++ b/tests/testthat/test-error-handling.R @@ -0,0 +1,257 @@ +# ============================================================================= +# Tests for stop(), warning(), and message() conditions +# ============================================================================= + +# Shared setup +set.seed(20260401) +sp <- did::reset.sim() +data_eh <- did::build_sim_dataset(sp) + +# ============================================================================= +# att_gt() validation errors +# ============================================================================= + +test_that("att_gt errors on invalid est_method string", { + expect_error( + att_gt(yname = "Y", data = data_eh, tname = "period", idname = "id", + gname = "G", est_method = "bad", bstrap = FALSE), + "must be one of" + ) +}) + +test_that("att_gt errors on non-character non-function est_method", { + expect_error( + att_gt(yname = "Y", data = data_eh, tname = "period", idname = "id", + gname = "G", est_method = 42, bstrap = FALSE), + "must be a character string" + ) +}) + +test_that("att_gt errors on invalid fix_weights value", { + expect_error( + att_gt(yname = "Y", data = data_eh, tname = "period", idname = "id", + gname = "G", fix_weights = "bad", bstrap = FALSE), + "must be NULL or one of" + ) +}) + +test_that("att_gt errors on fix_weights with panel=FALSE", { + expect_error( + att_gt(yname = "Y", data = data_eh, tname = "period", idname = "id", + gname = "G", fix_weights = "base_period", panel = FALSE, + bstrap = FALSE), + "not supported for repeated cross sections" + ) +}) + +test_that("att_gt warns on extra args with built-in est_method", { + expect_warning( + att_gt(yname = "Y", data = data_eh, tname = "period", idname = "id", + gname = "G", est_method = "reg", bstrap = FALSE, extra_arg = 1), + "Extra arguments" + ) +}) + +test_that("att_gt messages about anticipation", { + expect_message( + suppressWarnings( + att_gt(yname = "Y", data = data_eh, tname = "period", idname = "id", + gname = "G", anticipation = 1, bstrap = FALSE) + ), + "anticipation =" + ) +}) + +test_that("att_gt warns on clustered SE without bootstrap", { + expect_warning( + att_gt(yname = "Y", data = data_eh, tname = "period", idname = "id", + gname = "G", clustervars = "id", bstrap = FALSE), + "Clustered standard errors require" + ) +}) + +# ============================================================================= +# pre_process_did / pre_process_did2 validation errors +# ============================================================================= + +test_that("att_gt errors on missing column name (slower mode)", { + expect_error( + att_gt(yname = "nonexistent", data = data_eh, tname = "period", + idname = "id", gname = "G", bstrap = FALSE, faster_mode = FALSE), + "not found" + ) +}) + +test_that("att_gt errors on missing column name (faster_mode)", { + expect_error( + att_gt(yname = "nonexistent", data = data_eh, tname = "period", + idname = "id", gname = "G", bstrap = FALSE, faster_mode = TRUE), + "nonexistent" + ) +}) + +test_that("att_gt errors on non-numeric tname", { + bad_data <- data_eh + bad_data$period <- as.character(bad_data$period) + expect_error( + att_gt(yname = "Y", data = bad_data, tname = "period", idname = "id", + gname = "G", bstrap = FALSE), + "must be numeric" + ) +}) + +test_that("att_gt errors on non-numeric gname", { + bad_data <- data_eh + bad_data$G <- as.character(bad_data$G) + expect_error( + att_gt(yname = "Y", data = bad_data, tname = "period", idname = "id", + gname = "G", bstrap = FALSE), + "must be numeric" + ) +}) + +test_that("att_gt errors on treatment reversals (faster_mode)", { + bad_data <- data_eh + # Make one unit switch groups across time + target_id <- bad_data$id[1] + periods <- sort(unique(bad_data$period)) + bad_data$G[bad_data$id == target_id & bad_data$period == periods[1]] <- 3 + bad_data$G[bad_data$id == target_id & bad_data$period == periods[2]] <- 4 + expect_error( + att_gt(yname = "Y", data = bad_data, tname = "period", idname = "id", + gname = "G", bstrap = FALSE, faster_mode = TRUE), + "same across all periods|time-invariant|must be the same" + ) +}) + +test_that("att_gt warns on missing data dropped", { + bad_data <- data_eh + bad_data$Y[1:5] <- NA + expect_warning( + att_gt(yname = "Y", data = bad_data, tname = "period", idname = "id", + gname = "G", bstrap = FALSE), + "dropped|missing" + ) +}) + +test_that("att_gt warns on small groups", { + # Create data with a very small treated group + small_data <- data_eh + treated_ids <- unique(small_data$id[small_data$G > 0]) + # Keep only 2 treated units and some controls + keep_ids <- c(treated_ids[1:2], unique(small_data$id[small_data$G == 0])[1:50]) + small_data <- small_data[small_data$id %in% keep_ids, ] + expect_warning( + att_gt(yname = "Y", data = small_data, tname = "period", idname = "id", + gname = "G", bstrap = FALSE), + "very few observations" + ) +}) + +test_that("att_gt handles data with .w column (slower mode)", { + # The .w column gets dropped during column selection in pre_process_did + # before the collision check, so this should run without error + bad_data <- data.frame(data_eh) + bad_data$.w <- 1 + result <- att_gt(yname = "Y", data = bad_data, tname = "period", idname = "id", + gname = "G", bstrap = FALSE, faster_mode = FALSE) + expect_s3_class(result, "MP") +}) + +test_that("att_gt messages on time-varying weights (panel)", { + tv_data <- data_eh + tv_data$tw <- tv_data$period + runif(nrow(tv_data)) + expect_message( + att_gt(yname = "Y", data = tv_data, tname = "period", idname = "id", + gname = "G", weightsname = "tw", bstrap = FALSE), + "Time-varying weights" + ) +}) + +test_that("att_gt errors on more than 1 extra cluster variable (faster_mode)", { + expect_error( + att_gt(yname = "Y", data = data_eh, tname = "period", idname = "id", + gname = "G", clustervars = c("id", "G", "period"), + bstrap = FALSE, faster_mode = TRUE), + "cluster|length 1" + ) +}) + +# ============================================================================= +# aggte() validation errors +# ============================================================================= + +test_that("aggte errors on invalid type", { + mp_tmp <- att_gt(yname = "Y", data = data_eh, tname = "period", idname = "id", + gname = "G", bstrap = FALSE) + expect_error( + aggte(mp_tmp, type = "invalid"), + "must be one of" + ) +}) + +test_that("aggte errors when ATTs contain NA and na.rm=FALSE", { + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + mp_na <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", data = data, tname = "period", idname = "id", + gname = "G", bstrap = FALSE) + )) + # Inject NA directly — no need to rely on estimation failure + mp_na$att[1] <- NA + expect_error( + aggte(mp_na, type = "dynamic", na.rm = FALSE), + "Missing values" + ) +}) + +# ============================================================================= +# Wald pre-test warnings +# ============================================================================= + +test_that("att_gt handles singular covariance for Wald test gracefully", { + # When covariance matrix is singular, W and Wpval should be NULL + # and the result should still be a valid MP object + small_sp <- did::reset.sim() + small_sp$n <- 50 + small_data <- did::build_sim_dataset(small_sp) + result <- suppressWarnings( + att_gt(yname = "Y", data = small_data, tname = "period", idname = "id", + gname = "G", bstrap = FALSE) + ) + expect_s3_class(result, "MP") + # W should be either a valid number or NULL (when singular) + expect_true(is.null(result$W) || is.numeric(result$W)) +}) + +# ============================================================================= +# compute.att_gt warnings +# ============================================================================= + +test_that("att_gt warns on overlap violations", { + # Create data where propensity score is near 1 (near-perfect separation) + sep_data <- data_eh + # Make covariate perfectly predict treatment for some cells + sep_data$X_sep <- ifelse(sep_data$G > 0, 100, -100) + expect_warning( + att_gt(yname = "Y", xformla = ~X_sep, data = sep_data, tname = "period", + idname = "id", gname = "G", est_method = "dr", bstrap = FALSE), + "overlap condition" + ) +}) + +test_that("att_gt warns when no pre-treatment periods for Wald test", { + # Create data where a group is first treated in period 2 (only 1 pre-treatment period) + # With base_period="varying", there may be no pre-treatment ATTs for the Wald test + early_data <- data.frame( + id = rep(1:100, each = 3), + period = rep(1:3, 100), + G = rep(c(rep(2, 50), rep(0, 50)), each = 3), + Y = rnorm(300) + ) + expect_warning( + att_gt(yname = "Y", data = early_data, tname = "period", idname = "id", + gname = "G", bstrap = FALSE), + "pre-treatment|Wald" + ) +}) diff --git a/tests/testthat/test-faster-mode-consistency.R b/tests/testthat/test-faster-mode-consistency.R new file mode 100644 index 00000000..01f3bed7 --- /dev/null +++ b/tests/testthat/test-faster-mode-consistency.R @@ -0,0 +1,153 @@ +# ============================================================================= +# Systematic faster_mode=TRUE vs FALSE consistency tests +# ============================================================================= + +# Shared setup +set.seed(20260401) +sp <- did::reset.sim() +data_fm <- did::build_sim_dataset(sp) + +# Unbalanced version +data_ub <- data_fm[-c(1, 5, 10), ] + +# Helper to compare two att_gt results +compare_modes <- function(res_slow, res_fast, label) { + expect_equal(res_slow$att, res_fast$att, tolerance = 1e-10, label = paste(label, "ATT")) + expect_equal(res_slow$group, res_fast$group, label = paste(label, "group")) + expect_equal(res_slow$t, res_fast$t, label = paste(label, "t")) +} + +# ============================================================================= +# Core grid: est_method x panel_type x control_group x base_period +# ============================================================================= + +# --- Balanced panel --- + +for (em in c("dr", "ipw", "reg")) { + for (cg in c("nevertreated", "notyettreated")) { + for (bp in c("varying", "universal")) { + label <- paste(em, "balanced", cg, bp) + test_that(paste("consistency:", label), { + res_slow <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_fm, tname = "period", + idname = "id", gname = "G", est_method = em, + control_group = cg, base_period = bp, + faster_mode = FALSE, bstrap = FALSE) + )) + res_fast <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_fm, tname = "period", + idname = "id", gname = "G", est_method = em, + control_group = cg, base_period = bp, + faster_mode = TRUE, bstrap = FALSE) + )) + compare_modes(res_slow, res_fast, label) + }) + } + } +} + +# --- Unbalanced panel --- + +for (em in c("dr", "ipw", "reg")) { + for (cg in c("nevertreated", "notyettreated")) { + for (bp in c("varying", "universal")) { + label <- paste(em, "unbalanced", cg, bp) + test_that(paste("consistency:", label), { + res_slow <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_ub, tname = "period", + idname = "id", gname = "G", est_method = em, + control_group = cg, base_period = bp, + allow_unbalanced_panel = TRUE, + faster_mode = FALSE, bstrap = FALSE) + )) + res_fast <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_ub, tname = "period", + idname = "id", gname = "G", est_method = em, + control_group = cg, base_period = bp, + allow_unbalanced_panel = TRUE, + faster_mode = TRUE, bstrap = FALSE) + )) + compare_modes(res_slow, res_fast, label) + }) + } + } +} + +# --- Repeated cross sections --- + +for (em in c("dr", "ipw", "reg")) { + for (cg in c("nevertreated", "notyettreated")) { + for (bp in c("varying", "universal")) { + label <- paste(em, "RC", cg, bp) + test_that(paste("consistency:", label), { + res_slow <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_fm, tname = "period", + idname = "id", gname = "G", est_method = em, + control_group = cg, base_period = bp, + panel = FALSE, + faster_mode = FALSE, bstrap = FALSE) + )) + res_fast <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_fm, tname = "period", + idname = "id", gname = "G", est_method = em, + control_group = cg, base_period = bp, + panel = FALSE, + faster_mode = TRUE, bstrap = FALSE) + )) + compare_modes(res_slow, res_fast, label) + }) + } + } +} + +# ============================================================================= +# Additional consistency tests beyond the core grid +# ============================================================================= + +test_that("consistency with anticipation > 0", { + res_slow <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_fm, tname = "period", + idname = "id", gname = "G", est_method = "dr", + anticipation = 1, faster_mode = FALSE, bstrap = FALSE) + )) + res_fast <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_fm, tname = "period", + idname = "id", gname = "G", est_method = "dr", + anticipation = 1, faster_mode = TRUE, bstrap = FALSE) + )) + compare_modes(res_slow, res_fast, "anticipation=1") +}) + +test_that("consistency without covariates", { + res_slow <- att_gt(yname = "Y", data = data_fm, tname = "period", + idname = "id", gname = "G", est_method = "reg", + faster_mode = FALSE, bstrap = FALSE) + res_fast <- att_gt(yname = "Y", data = data_fm, tname = "period", + idname = "id", gname = "G", est_method = "reg", + faster_mode = TRUE, bstrap = FALSE) + compare_modes(res_slow, res_fast, "no covariates") +}) + +# ============================================================================= +# aggte consistency across modes +# ============================================================================= + +test_that("aggte consistency across faster_mode for all types", { + mp_slow <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_fm, tname = "period", + idname = "id", gname = "G", est_method = "dr", + faster_mode = FALSE, bstrap = FALSE) + )) + mp_fast <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_fm, tname = "period", + idname = "id", gname = "G", est_method = "dr", + faster_mode = TRUE, bstrap = FALSE) + )) + + for (tp in c("simple", "dynamic", "group", "calendar")) { + agg_slow <- suppressWarnings(aggte(mp_slow, type = tp)) + agg_fast <- suppressWarnings(aggte(mp_fast, type = tp)) + expect_equal(agg_slow$overall.att, agg_fast$overall.att, + tolerance = 1e-8, label = paste("aggte", tp)) + } +}) diff --git a/tests/testthat/test-ggdid.R b/tests/testthat/test-ggdid.R new file mode 100644 index 00000000..00e8a889 --- /dev/null +++ b/tests/testthat/test-ggdid.R @@ -0,0 +1,99 @@ +# ============================================================================= +# Tests for ggdid.MP and ggdid.AGGTEobj plotting functions +# ============================================================================= + +# Shared setup +set.seed(20260401) +sp <- did::reset.sim() +data_gg <- did::build_sim_dataset(sp) + +mp_gg <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_gg, tname = "period", + idname = "id", gname = "G", est_method = "dr", + bstrap = FALSE) +)) + +agg_dyn <- suppressWarnings(aggte(mp_gg, type = "dynamic")) +agg_grp <- suppressWarnings(aggte(mp_gg, type = "group")) +agg_cal <- suppressWarnings(aggte(mp_gg, type = "calendar")) +agg_sim <- suppressWarnings(aggte(mp_gg, type = "simple")) + +# ============================================================================= +# ggdid.MP tests +# ============================================================================= + +test_that("ggdid.MP returns a ggplot object", { + p <- ggdid(mp_gg) + expect_s3_class(p, "gg") +}) + +test_that("ggdid.MP accepts group parameter", { + groups <- unique(mp_gg$group) + p <- ggdid(mp_gg, group = groups[1]) + expect_s3_class(p, "gg") +}) + +test_that("ggdid.MP warns for non-existent group values", { + expect_warning( + ggdid(mp_gg, group = c(9999)), + "do not exist" + ) +}) + +test_that("ggdid.MP accepts custom labels", { + p <- ggdid(mp_gg, xlab = "Time", ylab = "Effect", title = "Test Plot") + expect_s3_class(p, "gg") +}) + +# ============================================================================= +# ggdid.AGGTEobj tests +# ============================================================================= + +test_that("ggdid works for type = dynamic", { + p <- ggdid(agg_dyn) + expect_s3_class(p, "gg") +}) + +test_that("ggdid works for type = group", { + p <- suppressWarnings(ggdid(agg_grp)) + expect_s3_class(p, "gg") +}) + +test_that("ggdid works for type = calendar", { + p <- ggdid(agg_cal) + expect_s3_class(p, "gg") +}) + +test_that("ggdid errors for type = simple", { + expect_error( + ggdid(agg_sim), + "not available" + ) +}) + +test_that("ggdid.AGGTEobj accepts custom labels and theme settings", { + p <- ggdid(agg_dyn, xlab = "Event Time", ylab = "ATT", + title = "Event Study", theming = TRUE, legend = TRUE) + expect_s3_class(p, "gg") +}) + +test_that("ggdid.AGGTEobj works with theming=FALSE", { + p <- ggdid(agg_dyn, theming = FALSE) + expect_s3_class(p, "gg") +}) + +test_that("ggdid.AGGTEobj works with ref_line=NULL", { + p <- ggdid(agg_dyn, ref_line = NULL) + expect_s3_class(p, "gg") +}) + +test_that("splot works for group-type and renders without deprecation warnings", { + # splot() is the path that uses the ggplot2 version-gated errorbar layer. + # It is called via ggdid.AGGTEobj() for type="group". Verify the output is + # a valid ggplot object and no deprecation warnings leak through. + expect_warning(p <- ggdid(agg_grp), regexp = NA) + expect_s3_class(p, "gg") + # Verify the plot actually builds (catches any layer construction errors) + built <- ggplot2::ggplot_build(p) + expect_s3_class(built, "ggplot_built") +}) diff --git a/tests/testthat/test-glance.R b/tests/testthat/test-glance.R new file mode 100644 index 00000000..c91f7864 --- /dev/null +++ b/tests/testthat/test-glance.R @@ -0,0 +1,125 @@ +# ============================================================================= +# Tests for glance.MP and glance.AGGTEobj S3 methods +# ============================================================================= + +# Shared setup: generate MP and AGGTEobj results +set.seed(20260401) +sp <- did::reset.sim() +data_gl <- did::build_sim_dataset(sp) + +mp_slow <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_gl, tname = "period", + idname = "id", gname = "G", est_method = "dr", + bstrap = FALSE, faster_mode = FALSE) +)) + +mp_fast <- suppressWarnings(suppressMessages( + att_gt(yname = "Y", xformla = ~X, data = data_gl, tname = "period", + idname = "id", gname = "G", est_method = "dr", + bstrap = FALSE, faster_mode = TRUE) +)) + +agg_types <- c("simple", "dynamic", "group", "calendar") +agg_slow <- lapply(setNames(agg_types, agg_types), function(tp) { + suppressWarnings(aggte(mp_slow, type = tp)) +}) +agg_fast <- lapply(setNames(agg_types, agg_types), function(tp) { + suppressWarnings(aggte(mp_fast, type = tp)) +}) + +# ============================================================================= +# glance.MP tests +# ============================================================================= + +test_that("glance.MP returns a one-row data.frame", { + gl <- glance(mp_slow) + expect_equal(nrow(gl), 1) + expect_s3_class(gl, "data.frame") +}) + +test_that("glance.MP has expected columns", { + gl <- glance(mp_slow) + expected_cols <- c("nobs", "ngroup", "ntime", "control.group", "est.method") + expect_true(all(expected_cols %in% names(gl))) +}) + +test_that("glance.MP values are reasonable", { + gl <- glance(mp_slow) + expect_true(gl$nobs > 0) + expect_true(gl$ngroup > 0) + expect_true(gl$ntime > 0) + expect_equal(gl$control.group, "nevertreated") + expect_equal(gl$est.method, "dr") +}) + +test_that("glance.MP nobs matches nobs.MP", { + gl <- glance(mp_slow) + expect_equal(gl$nobs, nobs(mp_slow)) +}) + +test_that("glance.MP works with faster_mode = TRUE", { + gl <- glance(mp_fast) + expect_equal(nrow(gl), 1) + expect_true(gl$nobs > 0) + expect_true(gl$ngroup > 0) + expect_true(gl$ntime > 0) +}) + +test_that("glance.MP agrees between faster_mode settings", { + gl_slow <- glance(mp_slow) + gl_fast <- glance(mp_fast) + expect_equal(gl_slow$nobs, gl_fast$nobs) + expect_equal(gl_slow$ngroup, gl_fast$ngroup) + expect_equal(gl_slow$ntime, gl_fast$ntime) + expect_equal(gl_slow$control.group, gl_fast$control.group) + expect_equal(gl_slow$est.method, gl_fast$est.method) +}) + +# ============================================================================= +# glance.AGGTEobj tests +# ============================================================================= + +test_that("glance.AGGTEobj returns a one-row data.frame for all 4 types", { + for (tp in agg_types) { + gl <- glance(agg_slow[[tp]]) + expect_equal(nrow(gl), 1, label = paste("nrow for", tp)) + expect_s3_class(gl, "data.frame") + } +}) + +test_that("glance.AGGTEobj has expected columns", { + gl <- glance(agg_slow[["dynamic"]]) + expected_cols <- c("type", "nobs", "ngroup", "ntime", "control.group", "est.method") + expect_true(all(expected_cols %in% names(gl))) +}) + +test_that("glance.AGGTEobj type column matches requested type", { + for (tp in agg_types) { + gl <- glance(agg_slow[[tp]]) + expect_equal(gl$type, tp, label = paste("type for", tp)) + } +}) + +test_that("glance.AGGTEobj values are not NULL or NA", { + for (tp in agg_types) { + gl <- glance(agg_slow[[tp]]) + expect_false(any(sapply(gl, is.null)), label = paste("no NULLs for", tp)) + expect_false(any(is.na(gl)), label = paste("no NAs for", tp)) + } +}) + +test_that("glance.AGGTEobj works with faster_mode = TRUE", { + for (tp in agg_types) { + gl <- glance(agg_fast[[tp]]) + expect_equal(nrow(gl), 1, label = paste("nrow for", tp, "fast")) + expect_false(any(sapply(gl, is.null)), label = paste("no NULLs for", tp, "fast")) + } +}) + +test_that("glance.MP and glance.AGGTEobj agree on nobs", { + gl_mp <- glance(mp_slow) + for (tp in agg_types) { + gl_agg <- glance(agg_slow[[tp]]) + expect_equal(gl_mp$nobs, gl_agg$nobs, label = paste("nobs for", tp)) + } +}) diff --git a/tests/testthat/test-inference.R b/tests/testthat/test-inference.R index 6a0450e4..dc7ed5ed 100644 --- a/tests/testthat/test-inference.R +++ b/tests/testthat/test-inference.R @@ -30,13 +30,20 @@ same_matrix_elem <- function(A, B) { temp_lib <- tempfile() dir.create(temp_lib) -remotes::install_version("did", version = "2.1.2", lib = temp_lib, repos = "http://cran.us.r-project.org") -# install.packages( -# "https://cran.r-project.org/src/contrib/did_2.1.2.tar.gz", -# repos = NULL, type = "source", lib = temp_lib -# ) +withr::defer(unlink(temp_lib, recursive = TRUE), teardown_env()) + +old_did_available <- FALSE +if (!identical(Sys.getenv("NOT_CRAN"), "false")) { + old_did_available <- tryCatch({ + remotes::install_version("did", version = "2.1.2", lib = temp_lib, + repos = "https://cloud.r-project.org", quiet = TRUE) + isTRUE(requireNamespace("did", lib.loc = temp_lib, quietly = TRUE)) + }, error = function(e) FALSE) +} test_that("inference with balanced panel data and aggregations", { + skip_on_cran() + skip_if(!old_did_available, "did v2.1.2 not available from CRAN") sp <- did::reset.sim() data <- did::build_sim_dataset(sp) @@ -170,6 +177,8 @@ test_that("inference with balanced panel data and aggregations", { test_that("inference with clustering", { + skip_on_cran() + skip_if(!old_did_available, "did v2.1.2 not available from CRAN") sp <- did::reset.sim() data <- did::build_sim_dataset(sp) @@ -298,6 +307,8 @@ test_that("inference with clustering", { }) test_that("same inference with unbalanced panel and panel data", { + skip_on_cran() + skip_if(!old_did_available, "did v2.1.2 not available from CRAN") sp <- did::reset.sim() data <- did::build_sim_dataset(sp) @@ -328,6 +339,8 @@ test_that("same inference with unbalanced panel and panel data", { test_that("inference with repeated cross sections", { + skip_on_cran() + skip_if(!old_did_available, "did v2.1.2 not available from CRAN") sp <- did::reset.sim() data <- did::build_sim_dataset(sp, panel = FALSE) @@ -457,6 +470,8 @@ test_that("inference with repeated cross sections", { test_that("inference with repeated cross sections and clustering", { + skip_on_cran() + skip_if(!old_did_available, "did v2.1.2 not available from CRAN") sp <- did::reset.sim() data <- did::build_sim_dataset(sp, panel = FALSE) @@ -586,6 +601,8 @@ test_that("inference with repeated cross sections and clustering", { test_that("inference with unbalanced panel", { + skip_on_cran() + skip_if(!old_did_available, "did v2.1.2 not available from CRAN") sp <- did::reset.sim() data <- did::build_sim_dataset(sp) data <- data[-3, ] @@ -719,6 +736,8 @@ test_that("inference with unbalanced panel", { }) test_that("inference with unbalanced panel and clustering", { + skip_on_cran() + skip_if(!old_did_available, "did v2.1.2 not available from CRAN") sp <- did::reset.sim() data <- did::build_sim_dataset(sp) data <- data[-3, ] @@ -850,5 +869,3 @@ test_that("inference with unbalanced panel and clustering", { expect_equal(group_2.1.2$se[1], group_new$se[1], tol = .01) expect_equal(cal_2.1.2$se[1], cal_new$se[1], tol = .01) }) - -unlink(temp_lib) diff --git a/tests/testthat/test-jel_replication.R b/tests/testthat/test-jel_replication.R index 10953632..4b3f746e 100644 --- a/tests/testthat/test-jel_replication.R +++ b/tests/testthat/test-jel_replication.R @@ -112,14 +112,14 @@ test_that("JEL Table 7: 2x2 CS-DiD point estimates match", { for (method in c("reg", "ipw", "dr")) { for (wt_info in list(list(name = NULL, label = "unweighted"), list(name = "set_wt", label = "weighted"))) { - res <- att_gt( + res <- suppressWarnings(att_gt( yname = "crude_rate_20_64", tname = "year", idname = "county_code", gname = "treat_year", xformla = covs_formula, data = short_data, panel = TRUE, control_group = "nevertreated", bstrap = FALSE, est_method = method, weightsname = wt_info$name, base_period = "universal" - ) - agg <- aggte(res, na.rm = TRUE, bstrap = FALSE) + )) + agg <- suppressWarnings(aggte(res, na.rm = TRUE, bstrap = FALSE)) key <- paste0(method, "_", wt_info$label) expect_equal(agg$overall.att, expected[[key]], tolerance = 1e-6, @@ -170,11 +170,11 @@ test_that("JEL 2xT: event study ATT(g,t) point estimates match", { expect_equal(mod$att, expected_att, tolerance = 1e-6) # Dynamic aggregation - es <- aggte(mod, type = "dynamic", bstrap = FALSE) + es <- suppressWarnings(aggte(mod, type = "dynamic", bstrap = FALSE)) expect_equal(es$att.egt, expected_att, tolerance = 1e-6) # Overall ATT for e in {0,5} - agg <- aggte(mod, type = "dynamic", min_e = 0, max_e = 5, bstrap = FALSE) + agg <- suppressWarnings(aggte(mod, type = "dynamic", min_e = 0, max_e = 5, bstrap = FALSE)) expect_equal(agg$overall.att, -0.7035462478, tolerance = 1e-6) }) @@ -206,7 +206,7 @@ test_that("JEL 2xT: event study with covariates matches across methods", { base_period = "universal" ) - es <- aggte(res, type = "dynamic", na.rm = TRUE, bstrap = FALSE) + es <- suppressWarnings(aggte(res, type = "dynamic", na.rm = TRUE, bstrap = FALSE)) # Base period (e = -1) should be exactly 0 base_idx <- which(es$egt == -1) @@ -243,7 +243,7 @@ test_that("JEL GxT: staggered event study without covariates matches", { ) # Dynamic aggregation - es <- aggte(mod, type = "dynamic", bstrap = FALSE) + es <- suppressWarnings(aggte(mod, type = "dynamic", bstrap = FALSE)) # Expected dynamic ATT at key event times expected_dynamic <- c( @@ -269,7 +269,7 @@ test_that("JEL GxT: staggered event study without covariates matches", { } # Overall ATT for e in {0,5} - agg <- aggte(mod, type = "dynamic", min_e = 0, max_e = 5, bstrap = FALSE) + agg <- suppressWarnings(aggte(mod, type = "dynamic", min_e = 0, max_e = 5, bstrap = FALSE)) expect_equal(agg$overall.att, 0.0867675805, tolerance = 1e-6, label = "GxT no covs overall ATT (e=0:5)") }) @@ -293,21 +293,21 @@ test_that("JEL GxT: staggered event study with DR covariates matches", { covs_formula <- ~perc_female + perc_white + perc_hispanic + unemp_rate + poverty_rate + median_income - mod <- att_gt( + mod <- suppressWarnings(att_gt( yname = "crude_rate_20_64", tname = "year", idname = "county_code", gname = "treat_year", xformla = covs_formula, data = mydata, panel = TRUE, control_group = "notyettreated", bstrap = FALSE, est_method = "dr", weightsname = "set_wt", base_period = "universal" - ) + )) # Overall ATT for e in {0,5} - agg <- aggte(mod, type = "dynamic", min_e = 0, max_e = 5, bstrap = FALSE) + agg <- suppressWarnings(aggte(mod, type = "dynamic", min_e = 0, max_e = 5, bstrap = FALSE)) expect_equal(agg$overall.att, -2.2469982988, tolerance = 1e-6, label = "GxT DR covs overall ATT (e=0:5)") # Dynamic aggregation at key event times - es <- aggte(mod, type = "dynamic", bstrap = FALSE) + es <- suppressWarnings(aggte(mod, type = "dynamic", bstrap = FALSE)) expected_dynamic <- c( `e=-5` = 2.6684811691, @@ -354,26 +354,26 @@ test_that("JEL: faster_mode matches regular mode", { covs_formula <- ~perc_female + perc_white + perc_hispanic + unemp_rate + poverty_rate + median_income # Test DR weighted - res_slow <- att_gt( + res_slow <- suppressWarnings(att_gt( yname = "crude_rate_20_64", tname = "year", idname = "county_code", gname = "treat_year", xformla = covs_formula, data = short_data, panel = TRUE, control_group = "nevertreated", bstrap = FALSE, est_method = "dr", weightsname = "set_wt", base_period = "universal", faster_mode = FALSE - ) - res_fast <- att_gt( + )) + res_fast <- suppressWarnings(att_gt( yname = "crude_rate_20_64", tname = "year", idname = "county_code", gname = "treat_year", xformla = covs_formula, data = short_data, panel = TRUE, control_group = "nevertreated", bstrap = FALSE, est_method = "dr", weightsname = "set_wt", base_period = "universal", faster_mode = TRUE - ) + )) expect_equal(res_slow$att, res_fast$att, tolerance = 1e-10, label = "2x2 DR weighted: ATTs match") - agg_slow <- aggte(res_slow, na.rm = TRUE, bstrap = FALSE) - agg_fast <- aggte(res_fast, na.rm = TRUE, bstrap = FALSE) + agg_slow <- suppressWarnings(aggte(res_slow, na.rm = TRUE, bstrap = FALSE)) + agg_fast <- suppressWarnings(aggte(res_fast, na.rm = TRUE, bstrap = FALSE)) expect_equal(agg_slow$overall.att, agg_fast$overall.att, tolerance = 1e-10, label = "2x2 DR weighted: aggregate ATT matches") }) diff --git a/tests/testthat/test-user_bug_fixes.R b/tests/testthat/test-user_bug_fixes.R index a080d5cf..c1042ed7 100644 --- a/tests/testthat/test-user_bug_fixes.R +++ b/tests/testthat/test-user_bug_fixes.R @@ -115,23 +115,23 @@ test_that("0 pre-treatment estimates when outcomes are 0", { data <- subset(data, G > 6) data <- subset(data, period > 5) data$Y[(data$period < data$G)] <- 0 # set pre-treatment = 0 - res <- att_gt(yname="Y", + res <- suppressWarnings(att_gt(yname="Y", tname="period", idname="id", gname="G", data=data, control_group = "notyettreated", - base_period="universal") + base_period="universal")) res_idx <- which(res$group==9 & res$t==7) expect_equal(res$att[res_idx],0) - res <- att_gt(yname="Y", + res <- suppressWarnings(att_gt(yname="Y", tname="period", idname="id", gname="G", data=data, control_group = "notyettreated", - base_period="varying") + base_period="varying")) res_idx <- which(res$group==9 & res$t==7) expect_equal(res$att[res_idx],0) })