From 2bc2ba4e5ec61f5e1dfa1f64c8dcf0876d96e643 Mon Sep 17 00:00:00 2001 From: thomaswiemann Date: Tue, 17 Mar 2026 10:28:54 -0500 Subject: [PATCH 1/5] est_method_vars + extra putput --- .gitignore | 4 +- R/DIDparams.R | 2 + R/DIDparams2.R | 2 + R/att_gt.R | 35 +++++++- R/compute.att_gt.R | 27 ++++-- R/compute.att_gt2.R | 53 ++++++++---- R/pre_process_did.R | 4 +- R/pre_process_did2.R | 3 +- tests/testthat/test-att_gt.R | 155 +++++++++++++++++++++++++++++++++++ 9 files changed, 258 insertions(+), 27 deletions(-) diff --git a/.gitignore b/.gitignore index c392361a..a8372b22 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,6 @@ did.Rproj .claude/ .vscode/ .revdep_manual/ -vignettes/*_cache/ \ No newline at end of file +vignettes/*_cache/ +/.discussions +GEMINI.md diff --git a/R/DIDparams.R b/R/DIDparams.R index 7aafd3a5..6e6a08d0 100644 --- a/R/DIDparams.R +++ b/R/DIDparams.R @@ -43,6 +43,7 @@ DIDparams <- function(yname, nT=NULL, tlist=NULL, glist=NULL, + est_method_vars=NULL, call=NULL) { out <- list(yname=yname, @@ -64,6 +65,7 @@ DIDparams <- function(yname, pl=pl, cores=cores, est_method=est_method, + est_method_vars=est_method_vars, base_period=base_period, panel=panel, true_repeated_cross_sections=true_repeated_cross_sections, diff --git a/R/DIDparams2.R b/R/DIDparams2.R index 108b50bf..5bb778e6 100644 --- a/R/DIDparams2.R +++ b/R/DIDparams2.R @@ -16,6 +16,7 @@ DIDparams2 <- function(did_tensors, args, call=NULL) { xformla <- args$xformla # formula of covariates panel <- args$panel est_method <- args$est_method + est_method_vars <- args$est_method_vars bstrap <- args$bstrap biters <- args$biters cband <- args$cband @@ -58,6 +59,7 @@ DIDparams2 <- function(did_tensors, args, call=NULL) { xformla=xformla, panel=panel, est_method=est_method, + est_method_vars=est_method_vars, bstrap=bstrap, biters=biters, cband=cband, diff --git a/R/att_gt.R b/R/att_gt.R index b5663b41..c58bbf09 100644 --- a/R/att_gt.R +++ b/R/att_gt.R @@ -131,6 +131,11 @@ #' the user allows for anticipation) to be equal to 0, but one #' extra estimate in an earlier period. #' +#' @param est_method_vars Optional character vector of column names from `data` +#' to pass through to a custom `est_method` function. These columns are +#' preserved through preprocessing, subsetted to match each (g,t) partition, +#' and passed to `est_method` as an additional `data` argument (a data.frame). +#' Ignored when using built-in estimation methods. Default is `NULL`. #' @param ... Additional arguments to be passed to a custom `est_method` #' function. These are ignored when using built-in estimation methods #' (`"dr"`, `"ipw"`, `"reg"`). @@ -206,6 +211,7 @@ att_gt <- function(yname, print_details = FALSE, pl = FALSE, cores = 1, + est_method_vars = NULL, ...) { # Capture extra arguments for custom est_method extra_args <- list(...) @@ -217,6 +223,22 @@ att_gt <- function(yname, "\". Extra arguments are only passed to custom est_method functions.") } + # Validate est_method_vars + if (!is.null(est_method_vars)) { + if (!is.character(est_method_vars)) { + stop("est_method_vars must be a character vector of column names from data.") + } + missing_emv <- setdiff(est_method_vars, colnames(data)) + if (length(missing_emv) > 0) { + stop("The following est_method_vars are not found in data: ", + paste(missing_emv, collapse = ", "), ".") + } + if (!inherits(est_method, "function")) { + warning("est_method_vars is specified but est_method is not a custom function. ", + "est_method_vars will be ignored.") + } + } + # Validate est_method if (!inherits(est_method, "function")) { if (!is.character(est_method) || length(est_method) != 1) { @@ -255,6 +277,7 @@ att_gt <- function(yname, biters = biters, clustervars = clustervars, est_method = est_method, + est_method_vars = est_method_vars, base_period = base_period, print_details = print_details, faster_mode = faster_mode, @@ -290,6 +313,7 @@ att_gt <- function(yname, biters = biters, clustervars = clustervars, est_method = est_method, + est_method_vars = est_method_vars, base_period = base_period, print_details = print_details, pl = pl, @@ -485,6 +509,13 @@ att_gt <- function(yname, } - # Return this list - return(MP(group = group, t = tt, att = att, V_analytical = V, se = se, c = cval, inffunc = inffunc, n = n, W = W, Wpval = Wpval, alp = alp, DIDparams = dp)) + # Build the MP object + out <- MP(group = group, t = tt, att = att, V_analytical = V, se = se, c = cval, inffunc = inffunc, n = n, W = W, Wpval = Wpval, alp = alp, DIDparams = dp) + + # Attach per-(g,t) extra outputs from custom est_method (if any) + extras <- lapply(attgt.list, function(x) x$extra) + has_extras <- any(!vapply(extras, is.null, logical(1))) + if (has_extras) out$extra_gt <- extras + + return(out) } diff --git a/R/compute.att_gt.R b/R/compute.att_gt.R index bef72c36..aa8eeac4 100644 --- a/R/compute.att_gt.R +++ b/R/compute.att_gt.R @@ -26,6 +26,7 @@ compute.att_gt <- function(dp) { xformla <- dp$xformla weightsname <- dp$weightsname est_method <- dp$est_method + est_method_vars <- dp$est_method_vars extra_args <- if (is.null(dp$extra_args)) list() else dp$extra_args base_period <- dp$base_period panel <- dp$panel @@ -249,13 +250,18 @@ compute.att_gt <- function(dp) { attgt <- tryCatch({ if (inherits(est_method, "function")) { # user-specified function - res <- do.call(est_method, c(list( + base_args <- list( y1 = Ypost, y0 = Ypre, D = G, covariates = covariates, i.weights = w, inffunc = TRUE - ), extra_args)) + ) + # add passthrough variables if specified + if (!is.null(est_method_vars)) { + base_args$data <- disdat[, est_method_vars, with = FALSE] + } + res <- do.call(est_method, c(base_args, extra_args)) } else if (est_method == "ipw") { # inverse-probability weights res <- DRDID::std_ipw_did_panel(Ypost, Ypre, G, @@ -416,14 +422,19 @@ compute.att_gt <- function(dp) { attgt <- tryCatch({ if (inherits(est_method, "function")) { # user-specified function - res <- do.call(est_method, c(list( + base_args <- list( y = Y, post = post, D = G, covariates = covariates, i.weights = w, inffunc = TRUE - ), extra_args)) + ) + # add passthrough variables if specified + if (!is.null(est_method_vars)) { + base_args$data <- disdat[, est_method_vars, with = FALSE] + } + res <- do.call(est_method, c(base_args, extra_args)) } else if (est_method == "ipw") { # inverse-probability weights res <- DRDID::std_ipw_did_rc( @@ -485,9 +496,15 @@ compute.att_gt <- function(dp) { } # end panel if # save results for this att(g,t) - attgt.list[[counter]] <- list( + attgt_entry <- list( att = attgt$ATT, group = glist[g], year = tlist[(t + tfac)], post = post.treat ) + # preserve extra fields from custom est_method only + if (custom_est_method) { + extra <- attgt[!names(attgt) %in% c("ATT", "att.inf.func")] + if (length(extra) > 0) attgt_entry$extra <- extra + } + attgt.list[[counter]] <- attgt_entry # populate the influence function in the right places diff --git a/R/compute.att_gt2.R b/R/compute.att_gt2.R index bfdf3eae..100ae9b6 100644 --- a/R/compute.att_gt2.R +++ b/R/compute.att_gt2.R @@ -91,6 +91,7 @@ get_did_cohort_index <- function(group, time, tfac, pret, dp2){ run_DRDID <- function(cohort_data, covariates, dp2, g_val = NULL, t_val = NULL){ extra_args <- if (is.null(dp2$extra_args)) list() else dp2$extra_args + est_method_vars <- dp2$est_method_vars gt_label <- if (!is.null(g_val) && !is.null(t_val)) paste0(" for group ", g_val, " in time period ", t_val) else "" if(dp2$panel){ @@ -150,13 +151,18 @@ run_DRDID <- function(cohort_data, covariates, dp2, g_val = NULL, t_val = NULL){ if (inherits(dp2$est_method, "function")) { # user-specified function - attgt <- do.call(dp2$est_method, c(list( - y1=cohort_data[, y1], - y0=cohort_data[, y0], - D=cohort_data[, D], - covariates=covariates, - i.weights=cohort_data[, i.weights], - inffunc=TRUE), extra_args)) + base_args <- list( + y1=cohort_data[, y1], + y0=cohort_data[, y0], + D=cohort_data[, D], + covariates=covariates, + i.weights=cohort_data[, i.weights], + inffunc=TRUE) + # add passthrough variables if specified + if (!is.null(est_method_vars)) { + base_args$data <- dp2$time_invariant_data[valid_obs, est_method_vars, with = FALSE] + } + attgt <- do.call(dp2$est_method, c(base_args, extra_args)) } else if (dp2$est_method == "ipw") { # inverse-probability weights attgt <- std_ipw_did_panel(y1=cohort_data[, y1], @@ -252,13 +258,18 @@ run_DRDID <- function(cohort_data, covariates, dp2, g_val = NULL, t_val = NULL){ if (inherits(dp2$est_method, "function")) { # user-specified function - attgt <- do.call(dp2$est_method, c(list( - y=cohort_data[, y], - post=cohort_data[, post], - D=cohort_data[, D], - covariates=covariates, - i.weights=cohort_data[, i.weights], - inffunc=TRUE), extra_args)) + base_args <- list( + y=cohort_data[, y], + post=cohort_data[, post], + D=cohort_data[, D], + covariates=covariates, + i.weights=cohort_data[, i.weights], + inffunc=TRUE) + # add passthrough variables if specified + if (!is.null(est_method_vars)) { + base_args$data <- dp2$time_invariant_data[valid_obs, est_method_vars, with = FALSE] + } + attgt <- do.call(dp2$est_method, c(base_args, extra_args)) } else if (dp2$est_method == "ipw") { # inverse-probability weights attgt <- std_ipw_did_rc(y=cohort_data[, y], @@ -306,7 +317,13 @@ run_DRDID <- function(cohort_data, covariates, dp2, g_val = NULL, t_val = NULL){ } - return(list(att = attgt$ATT, inf_func = inf_func_vector)) + result <- list(att = attgt$ATT, inf_func = inf_func_vector) + # forward extra fields from custom est_method (if any) + if (custom_est_method) { + extra <- attgt[!names(attgt) %in% c("ATT", "att.inf.func")] + if (length(extra) > 0) result$extra <- extra + } + return(result) } @@ -495,8 +512,10 @@ compute.att_gt2 <- function(dp2) { # Save ATT and influence function inffunc_updates <- inf_func - gt_result <- list(att = att, group = dp2$treated_groups[g], year = dp2$time_periods[t+tfac], post = post.treat, inffunc_updates = inffunc_updates) - return(gt_result) + gt_result_out <- list(att = att, group = dp2$treated_groups[g], year = dp2$time_periods[t+tfac], post = post.treat, inffunc_updates = inffunc_updates) + # preserve extra outputs from custom est_method + if (!is.null(gt_result$extra)) gt_result_out$extra <- gt_result$extra + return(gt_result_out) } } diff --git a/R/pre_process_did.R b/R/pre_process_did.R index 3ee13c5e..e9c0f3a9 100644 --- a/R/pre_process_did.R +++ b/R/pre_process_did.R @@ -32,6 +32,7 @@ pre_process_did <- function(yname, faster_mode = FALSE, pl = FALSE, cores = 1, + est_method_vars = NULL, call = NULL) { #----------------------------------------------------------------------------- # Data pre-processing and error checking @@ -79,7 +80,7 @@ pre_process_did <- function(yname, } # drop irrelevant columns from data - data <- cbind.data.frame(data[,c(idname, tname, yname, gname, weightsname, clustervars)], model.frame(xformla, data=data, na.action=na.pass)) + data <- cbind.data.frame(data[,c(idname, tname, yname, gname, weightsname, clustervars, est_method_vars)], model.frame(xformla, data=data, na.action=na.pass)) # check if any covariates were missing n_orig <- nrow(data) @@ -399,6 +400,7 @@ pre_process_did <- function(yname, pl=pl, cores=cores, est_method=est_method, + est_method_vars=est_method_vars, base_period=base_period, panel=panel, true_repeated_cross_sections=true_repeated_cross_sections, diff --git a/R/pre_process_did2.R b/R/pre_process_did2.R index 733fbe56..b589f760 100644 --- a/R/pre_process_did2.R +++ b/R/pre_process_did2.R @@ -107,7 +107,7 @@ validate_args <- function(args, data){ #' @noRd did_standardization <- function(data, args){ # keep relevant columns in data - cols_to_keep <- c(args$idname, args$tname, args$gname, args$yname, args$weightsname, args$clustervars) + cols_to_keep <- c(args$idname, args$tname, args$gname, args$yname, args$weightsname, args$clustervars, args$est_method_vars) model_frame <- model.frame(args$xformla, data = data, na.action = na.pass) # Subset the dataset to keep only the relevant columns @@ -570,6 +570,7 @@ pre_process_did2 <- function(yname, faster_mode=FALSE, pl = FALSE, cores = 1, + est_method_vars = NULL, call = NULL) { diff --git a/tests/testthat/test-att_gt.R b/tests/testthat/test-att_gt.R index 5763d4fe..ea1f87df 100644 --- a/tests/testthat/test-att_gt.R +++ b/tests/testthat/test-att_gt.R @@ -1275,3 +1275,158 @@ test_that("faster_mode time indexing with universal base period", { expect_equal(res_slow$att, res_fast$att) expect_equal(res_slow$se, as.numeric(res_fast$se)) }) + +test_that("est_method_vars passes through variables to custom est_method", { + set.seed(09142024) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + + # Add a column that we want to pass through + data$fold_id <- sample(1:3, nrow(data), replace = TRUE) + + # Custom est_method that checks for the data argument + my_est <- function(y1, y0, D, covariates, i.weights, inffunc, data) { + # Verify data is a data.frame with the right column + stopifnot(is.data.frame(data)) + stopifnot("fold_id" %in% names(data)) + stopifnot(nrow(data) == length(y1)) + + # Use DRDID to compute the actual estimate + DRDID::drdid_imp_panel(y1 = y1, y0 = y0, D = D, + covariates = covariates, + i.weights = i.weights, + inffunc = inffunc) + } + + # faster_mode = TRUE (panel) + res_fast <- att_gt(yname = "Y", xformla = ~X, data = data, + tname = "period", idname = "id", gname = "G", + est_method = my_est, est_method_vars = c("fold_id"), + bstrap = FALSE, cband = FALSE, faster_mode = TRUE) + expect_equal(res_fast$att[1], 1, tol = .5) + + # faster_mode = FALSE (panel) + res_slow <- att_gt(yname = "Y", xformla = ~X, data = data, + tname = "period", idname = "id", gname = "G", + est_method = my_est, est_method_vars = c("fold_id"), + bstrap = FALSE, cband = FALSE, faster_mode = FALSE) + expect_equal(res_slow$att[1], 1, tol = .5) + + # ATTs should be the same across modes + expect_equal(res_fast$att, res_slow$att) +}) + +test_that("est_method_vars works with multiple variables", { + set.seed(09142024) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + data$fold_id <- sample(1:3, nrow(data), replace = TRUE) + data$stratum <- sample(letters[1:5], nrow(data), replace = TRUE) + + my_est <- function(y1, y0, D, covariates, i.weights, inffunc, data) { + stopifnot(all(c("fold_id", "stratum") %in% names(data))) + DRDID::drdid_imp_panel(y1 = y1, y0 = y0, D = D, + covariates = covariates, + i.weights = i.weights, inffunc = inffunc) + } + + res <- att_gt(yname = "Y", xformla = ~X, data = data, + tname = "period", idname = "id", gname = "G", + est_method = my_est, + est_method_vars = c("fold_id", "stratum"), + bstrap = FALSE, cband = FALSE) + expect_equal(res$att[1], 1, tol = .5) +}) + +test_that("est_method_vars validation errors", { + set.seed(09142024) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + + # Non-existent column + expect_error( + att_gt(yname = "Y", xformla = ~X, data = data, + tname = "period", idname = "id", gname = "G", + est_method = function(...) NULL, + est_method_vars = c("nonexistent_col")), + "not found in data" + ) + + # Non-character input + expect_error( + att_gt(yname = "Y", xformla = ~X, data = data, + tname = "period", idname = "id", gname = "G", + est_method = function(...) NULL, + est_method_vars = 42), + "character vector" + ) + + # Warning when used with built-in est_method + expect_warning( + att_gt(yname = "Y", xformla = ~X, data = data, + tname = "period", idname = "id", gname = "G", + est_method = "dr", + est_method_vars = c("X")), + "not a custom function" + ) +}) + +test_that("extra_gt captures additional outputs from custom est_method", { + set.seed(09142024) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + + # Custom est_method that returns extra fields + my_est_extra <- function(y1, y0, D, covariates, i.weights, inffunc) { + res <- DRDID::drdid_imp_panel(y1 = y1, y0 = y0, D = D, + covariates = covariates, + i.weights = i.weights, + inffunc = inffunc) + # Add extra fields + res$n_treated <- sum(D) + res$n_control <- sum(1 - D) + res$my_diagnostic <- "ok" + res + } + + # faster_mode = TRUE + res_fast <- att_gt(yname = "Y", xformla = ~X, data = data, + tname = "period", idname = "id", gname = "G", + est_method = my_est_extra, + bstrap = FALSE, cband = FALSE, faster_mode = TRUE) + + # extra_gt should exist and be a list + expect_true(!is.null(res_fast$extra_gt)) + expect_true(is.list(res_fast$extra_gt)) + expect_equal(length(res_fast$extra_gt), length(res_fast$att)) + + # Each entry should have the extra fields + first_extra <- res_fast$extra_gt[[1]] + expect_true("n_treated" %in% names(first_extra)) + expect_true("n_control" %in% names(first_extra)) + expect_true("my_diagnostic" %in% names(first_extra)) + expect_equal(first_extra$my_diagnostic, "ok") + + # faster_mode = FALSE + res_slow <- att_gt(yname = "Y", xformla = ~X, data = data, + tname = "period", idname = "id", gname = "G", + est_method = my_est_extra, + bstrap = FALSE, cband = FALSE, faster_mode = FALSE) + + expect_true(!is.null(res_slow$extra_gt)) + expect_equal(length(res_slow$extra_gt), length(res_slow$att)) + expect_true("n_treated" %in% names(res_slow$extra_gt[[1]])) +}) + +test_that("extra_gt is NULL for built-in est_method", { + set.seed(09142024) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + + res <- att_gt(yname = "Y", xformla = ~X, data = data, + tname = "period", idname = "id", gname = "G", + est_method = "dr", bstrap = FALSE, cband = FALSE) + + # Built-in methods should not produce extra_gt + expect_null(res$extra_gt) +}) From 94b75ea69400247641f63fba29ed44f9dbddeb31 Mon Sep 17 00:00:00 2001 From: thomaswiemann Date: Tue, 17 Mar 2026 13:27:15 -0500 Subject: [PATCH 2/5] pass g.t to est_method + docs --- R/compute.att_gt.R | 12 ++++++++++++ R/compute.att_gt2.R | 12 ++++++++++++ man/DIDparams.Rd | 7 +++++++ man/att_gt.Rd | 7 +++++++ man/pre_process_did.Rd | 7 +++++++ man/pre_process_did2.Rd | 7 +++++++ 6 files changed, 52 insertions(+) diff --git a/R/compute.att_gt.R b/R/compute.att_gt.R index aa8eeac4..15481de3 100644 --- a/R/compute.att_gt.R +++ b/R/compute.att_gt.R @@ -257,6 +257,12 @@ compute.att_gt <- function(dp) { i.weights = w, inffunc = TRUE ) + # forward cell identity if est_method can accept it + fmls <- names(formals(est_method)) + if ("g" %in% fmls) { + base_args$g <- glist[g] + base_args$t <- tlist[(t + tfac)] + } # add passthrough variables if specified if (!is.null(est_method_vars)) { base_args$data <- disdat[, est_method_vars, with = FALSE] @@ -430,6 +436,12 @@ compute.att_gt <- function(dp) { i.weights = w, inffunc = TRUE ) + # forward cell identity if est_method can accept it + fmls <- names(formals(est_method)) + if ("g" %in% fmls) { + base_args$g <- glist[g] + base_args$t <- tlist[(t + tfac)] + } # add passthrough variables if specified if (!is.null(est_method_vars)) { base_args$data <- disdat[, est_method_vars, with = FALSE] diff --git a/R/compute.att_gt2.R b/R/compute.att_gt2.R index 100ae9b6..eeffa1cf 100644 --- a/R/compute.att_gt2.R +++ b/R/compute.att_gt2.R @@ -158,6 +158,12 @@ run_DRDID <- function(cohort_data, covariates, dp2, g_val = NULL, t_val = NULL){ covariates=covariates, i.weights=cohort_data[, i.weights], inffunc=TRUE) + # forward cell identity if est_method can accept it + fmls <- names(formals(dp2$est_method)) + if ("g" %in% fmls) { + base_args$g <- g_val + base_args$t <- t_val + } # add passthrough variables if specified if (!is.null(est_method_vars)) { base_args$data <- dp2$time_invariant_data[valid_obs, est_method_vars, with = FALSE] @@ -265,6 +271,12 @@ run_DRDID <- function(cohort_data, covariates, dp2, g_val = NULL, t_val = NULL){ covariates=covariates, i.weights=cohort_data[, i.weights], inffunc=TRUE) + # forward cell identity if est_method can accept it + fmls <- names(formals(dp2$est_method)) + if ("g" %in% fmls) { + base_args$g <- g_val + base_args$t <- t_val + } # add passthrough variables if specified if (!is.null(est_method_vars)) { base_args$data <- dp2$time_invariant_data[valid_obs, est_method_vars, with = FALSE] diff --git a/man/DIDparams.Rd b/man/DIDparams.Rd index 30c0ff65..c22e4ac4 100644 --- a/man/DIDparams.Rd +++ b/man/DIDparams.Rd @@ -32,6 +32,7 @@ DIDparams( nT = NULL, tlist = NULL, glist = NULL, + est_method_vars = NULL, call = NULL ) } @@ -191,6 +192,12 @@ of rows in a panel dataset).} \item{glist}{a vector containing each group} +\item{est_method_vars}{Optional character vector of column names from \code{data} +to pass through to a custom \code{est_method} function. These columns are +preserved through preprocessing, subsetted to match each (g,t) partition, +and passed to \code{est_method} as an additional \code{data} argument (a data.frame). +Ignored when using built-in estimation methods. Default is \code{NULL}.} + \item{call}{Function call to att_gt} } \description{ diff --git a/man/att_gt.Rd b/man/att_gt.Rd index ac311920..d076c13b 100644 --- a/man/att_gt.Rd +++ b/man/att_gt.Rd @@ -27,6 +27,7 @@ att_gt( print_details = FALSE, pl = FALSE, cores = 1, + est_method_vars = NULL, ... ) } @@ -177,6 +178,12 @@ Default is \code{FALSE}.} \item{cores}{The number of cores to use for parallel processing} +\item{est_method_vars}{Optional character vector of column names from \code{data} +to pass through to a custom \code{est_method} function. These columns are +preserved through preprocessing, subsetted to match each (g,t) partition, +and passed to \code{est_method} as an additional \code{data} argument (a data.frame). +Ignored when using built-in estimation methods. Default is \code{NULL}.} + \item{...}{Additional arguments to be passed to a custom \code{est_method} function. These are ignored when using built-in estimation methods (\code{"dr"}, \code{"ipw"}, \code{"reg"}).} diff --git a/man/pre_process_did.Rd b/man/pre_process_did.Rd index 61c27dff..4934e6b6 100644 --- a/man/pre_process_did.Rd +++ b/man/pre_process_did.Rd @@ -27,6 +27,7 @@ pre_process_did( faster_mode = FALSE, pl = FALSE, cores = 1, + est_method_vars = NULL, call = NULL ) } @@ -177,6 +178,12 @@ it is recommended for use with large datasets.} \item{cores}{The number of cores to use for parallel processing} +\item{est_method_vars}{Optional character vector of column names from \code{data} +to pass through to a custom \code{est_method} function. These columns are +preserved through preprocessing, subsetted to match each (g,t) partition, +and passed to \code{est_method} as an additional \code{data} argument (a data.frame). +Ignored when using built-in estimation methods. Default is \code{NULL}.} + \item{call}{Function call to att_gt} } \value{ diff --git a/man/pre_process_did2.Rd b/man/pre_process_did2.Rd index b7b90d3a..1cac978c 100644 --- a/man/pre_process_did2.Rd +++ b/man/pre_process_did2.Rd @@ -27,6 +27,7 @@ pre_process_did2( faster_mode = FALSE, pl = FALSE, cores = 1, + est_method_vars = NULL, call = NULL ) } @@ -177,6 +178,12 @@ it is recommended for use with large datasets.} \item{cores}{The number of cores to use for parallel processing} +\item{est_method_vars}{Optional character vector of column names from \code{data} +to pass through to a custom \code{est_method} function. These columns are +preserved through preprocessing, subsetted to match each (g,t) partition, +and passed to \code{est_method} as an additional \code{data} argument (a data.frame). +Ignored when using built-in estimation methods. Default is \code{NULL}.} + \item{call}{Function call to att_gt} } \value{ From 9a59579cd65bb86a73dd3cd19d4fef576e09c1ef Mon Sep 17 00:00:00 2001 From: thomaswiemann Date: Tue, 17 Mar 2026 13:38:34 -0500 Subject: [PATCH 3/5] cleanup extra_gt addition --- R/MP.R | 7 ++++--- R/att_gt.R | 14 +++++--------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/R/MP.R b/R/MP.R index ffeb0aee..a953045a 100644 --- a/R/MP.R +++ b/R/MP.R @@ -20,13 +20,14 @@ #' @param alp the significance level, default is 0.05 #' @param DIDparams a [`DIDparams`] object. A way to optionally return the parameters #' of the call to [att_gt()] or [conditional_did_pretest()]. +#' @param ... additional named elements to include in the MP object #' #' @return MP object #' @export -MP <- function(group, t, att, V_analytical, se, c, inffunc, n=NULL, W=NULL, Wpval=NULL, aggte=NULL, alp = 0.05, DIDparams=NULL) { - out <- list(group=group, t=t, att=att, V_analytical=V_analytical, se=se, c=c, +MP <- function(group, t, att, V_analytical, se, c, inffunc, n=NULL, W=NULL, Wpval=NULL, aggte=NULL, alp = 0.05, DIDparams=NULL, ...) { + out <- c(list(group=group, t=t, att=att, V_analytical=V_analytical, se=se, c=c, inffunc=inffunc, n=n, W=W, Wpval=Wpval, aggte=aggte, alp = alp, - DIDparams=DIDparams, call=DIDparams$call) + DIDparams=DIDparams, call=DIDparams$call), list(...)) class(out) <- "MP" out } diff --git a/R/att_gt.R b/R/att_gt.R index c58bbf09..2ed68457 100644 --- a/R/att_gt.R +++ b/R/att_gt.R @@ -509,13 +509,9 @@ att_gt <- function(yname, } - # Build the MP object - out <- MP(group = group, t = tt, att = att, V_analytical = V, se = se, c = cval, inffunc = inffunc, n = n, W = W, Wpval = Wpval, alp = alp, DIDparams = dp) - - # Attach per-(g,t) extra outputs from custom est_method (if any) - extras <- lapply(attgt.list, function(x) x$extra) - has_extras <- any(!vapply(extras, is.null, logical(1))) - if (has_extras) out$extra_gt <- extras - - return(out) + # Build the MP object, append extra_gt results form est_method calls if any + extra_gt <- Filter(Negate(is.null), lapply(attgt.list, function(x) x$extra)) + MP(group = group, t = tt, att = att, V_analytical = V, se = se, c = cval, + inffunc = inffunc, n = n, W = W, Wpval = Wpval, alp = alp, + DIDparams = dp, extra_gt = if (length(extra_gt)) extra_gt) } From 70b8df9a9d7596dba79888e3e659554e0063c157 Mon Sep 17 00:00:00 2001 From: thomaswiemann Date: Tue, 17 Mar 2026 13:57:18 -0500 Subject: [PATCH 4/5] cleanup extra results --- R/compute.att_gt.R | 12 ++++-------- R/compute.att_gt2.R | 6 ++---- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/R/compute.att_gt.R b/R/compute.att_gt.R index 15481de3..4374f498 100644 --- a/R/compute.att_gt.R +++ b/R/compute.att_gt.R @@ -508,15 +508,11 @@ compute.att_gt <- function(dp) { } # end panel if # save results for this att(g,t) - attgt_entry <- list( - att = attgt$ATT, group = glist[g], year = tlist[(t + tfac)], post = post.treat + extra <- if (custom_est_method) attgt[!names(attgt) %in% c("ATT", "att.inf.func")] else NULL + if (!length(extra)) extra <- NULL + attgt.list[[counter]] <- list( + att = attgt$ATT, group = glist[g], year = tlist[(t + tfac)], post = post.treat, extra = extra ) - # preserve extra fields from custom est_method only - if (custom_est_method) { - extra <- attgt[!names(attgt) %in% c("ATT", "att.inf.func")] - if (length(extra) > 0) attgt_entry$extra <- extra - } - attgt.list[[counter]] <- attgt_entry # populate the influence function in the right places diff --git a/R/compute.att_gt2.R b/R/compute.att_gt2.R index eeffa1cf..7f1eb777 100644 --- a/R/compute.att_gt2.R +++ b/R/compute.att_gt2.R @@ -524,10 +524,8 @@ compute.att_gt2 <- function(dp2) { # Save ATT and influence function inffunc_updates <- inf_func - gt_result_out <- list(att = att, group = dp2$treated_groups[g], year = dp2$time_periods[t+tfac], post = post.treat, inffunc_updates = inffunc_updates) - # preserve extra outputs from custom est_method - if (!is.null(gt_result$extra)) gt_result_out$extra <- gt_result$extra - return(gt_result_out) + gt_result <- list(att = att, group = dp2$treated_groups[g], year = dp2$time_periods[t+tfac], post = post.treat, inffunc_updates = inffunc_updates, extra = gt_result$extra) + return(gt_result) } } From 405e4d3054cbddce3c676bd4a4416ed0e0e77898 Mon Sep 17 00:00:00 2001 From: thomaswiemann Date: Tue, 17 Mar 2026 13:58:38 -0500 Subject: [PATCH 5/5] revert .gitignore --- .gitignore | 2 -- 1 file changed, 2 deletions(-) diff --git a/.gitignore b/.gitignore index a8372b22..a1e44ce9 100644 --- a/.gitignore +++ b/.gitignore @@ -17,5 +17,3 @@ did.Rproj .vscode/ .revdep_manual/ vignettes/*_cache/ -/.discussions -GEMINI.md