From b79ade97bda956df356531ecef9010d33d8a2aa2 Mon Sep 17 00:00:00 2001 From: Ali Date: Mon, 23 Jun 2025 17:23:42 +0200 Subject: [PATCH 01/10] comment out legacy code to avoid overlap of function definition --- R/case_cohort_backup.R | 392 ++++++++++++++++++++--------------------- 1 file changed, 196 insertions(+), 196 deletions(-) diff --git a/R/case_cohort_backup.R b/R/case_cohort_backup.R index d1966da..7e699a4 100644 --- a/R/case_cohort_backup.R +++ b/R/case_cohort_backup.R @@ -1,203 +1,203 @@ -## DO NOT USE: This is a legacy version of case_cohort functions, not compatible。 -## prepare calculation for case-cohort design at ONE site -# the purpose of this function is to "pre-calculate" the weight before calculating the log-likelihood -# this would accelerate the subsequent calculation of log-likelihood -# currently, we only provide Prentice weight; more options will be provided later -prepare_case_cohort <- function(ipdata, method, full_cohort_size){ - # for each site, pre-calculate the failure time points, the risk sets, and the respective weights - - covariate <- as.matrix(ipdata[, -c(1:3)]) - # find over which position lies the failure times - failure_position <- which(ipdata$status == 1) - # find failure times - failure_times <- ipdata$time[which(ipdata$status == 1)] - # the number of failures - failure_num <- length(failure_times) - # full_cohort_size = nrow(ipdata) # ? - Barlow_IPW <- full_cohort_size / sum(ipdata$subcohort) - - risk_size <- 0 - risk_sets <- as.list(rep(NA, failure_num )) - risk_set_weights <- as.list(rep(NA, failure_num )) - for(j in 1:failure_num){ - my_risk_set1 <- which((ipdata$subcohort == 1) & (ipdata$time >= failure_times[j])) - risk_size <- risk_size + length(my_risk_set1) - if(method == "Prentice"){ - my_weight1 <- rep(1, length(my_risk_set1)) - } - if(method == "Barlow"){ - my_weight1 <- rep(Barlow_IPW, length(my_risk_set1)) - } - if(ipdata$subcohort[which(ipdata$time == failure_times[j])] == 0){ - my_risk_set2 <- which(ipdata$time == failure_times[j]) - my_weight2 <- 1 - }else{ - my_risk_set2 <- c() - my_weight2 <- c() - if(method == "Barlow"){ - my_weight1[which(ipdata$time[my_risk_set1] == failure_times[j])] <- 1 - } - } - risk_sets[[j]] <- c(my_risk_set1, my_risk_set2) - risk_set_weights[[j]] <- c(my_weight1, my_weight2) - } - - return(list(# ipdata = ipdata, - full_cohort_size = full_cohort_size, - covariate = covariate, - failure_position = failure_position, - failure_num = failure_num, - risk_sets = risk_sets, - risk_set_weights = risk_set_weights )) -} - -# this function calculate the log pseudo-likelihood for ONE site -# cc_prep is the output of prepare_case_cohort() -log_plk <- function(beta, cc_prep){ - X = cc_prep$covariate # ipdata[,-c(1:2)] - failure_position = cc_prep$failure_position - failure_num = cc_prep$failure_num - risk_sets = cc_prep$risk_sets - risk_set_weights = cc_prep$risk_set_weights - - numerator_terms <- c(X[failure_position, ] %*% beta) - res <- sum(numerator_terms) - for(j in 1:failure_num){ - temp_term <- sum(c(exp(X[risk_sets[[j]], ] %*% beta )) * risk_set_weights[[j]]) - if(temp_term > 0){ - res <- res - log(temp_term) - }else{ - res <- res - numerator_terms[j] - } - } - return(res) -} +# ## DO NOT USE: This is a legacy version of case_cohort functions, not compatible。 +# ## prepare calculation for case-cohort design at ONE site +# # the purpose of this function is to "pre-calculate" the weight before calculating the log-likelihood +# # this would accelerate the subsequent calculation of log-likelihood +# # currently, we only provide Prentice weight; more options will be provided later +# prepare_case_cohort <- function(ipdata, method, full_cohort_size){ +# # for each site, pre-calculate the failure time points, the risk sets, and the respective weights + +# covariate <- as.matrix(ipdata[, -c(1:3)]) +# # find over which position lies the failure times +# failure_position <- which(ipdata$status == 1) +# # find failure times +# failure_times <- ipdata$time[which(ipdata$status == 1)] +# # the number of failures +# failure_num <- length(failure_times) +# # full_cohort_size = nrow(ipdata) # ? +# Barlow_IPW <- full_cohort_size / sum(ipdata$subcohort) + +# risk_size <- 0 +# risk_sets <- as.list(rep(NA, failure_num )) +# risk_set_weights <- as.list(rep(NA, failure_num )) +# for(j in 1:failure_num){ +# my_risk_set1 <- which((ipdata$subcohort == 1) & (ipdata$time >= failure_times[j])) +# risk_size <- risk_size + length(my_risk_set1) +# if(method == "Prentice"){ +# my_weight1 <- rep(1, length(my_risk_set1)) +# } +# if(method == "Barlow"){ +# my_weight1 <- rep(Barlow_IPW, length(my_risk_set1)) +# } +# if(ipdata$subcohort[which(ipdata$time == failure_times[j])] == 0){ +# my_risk_set2 <- which(ipdata$time == failure_times[j]) +# my_weight2 <- 1 +# }else{ +# my_risk_set2 <- c() +# my_weight2 <- c() +# if(method == "Barlow"){ +# my_weight1[which(ipdata$time[my_risk_set1] == failure_times[j])] <- 1 +# } +# } +# risk_sets[[j]] <- c(my_risk_set1, my_risk_set2) +# risk_set_weights[[j]] <- c(my_weight1, my_weight2) +# } + +# return(list(# ipdata = ipdata, +# full_cohort_size = full_cohort_size, +# covariate = covariate, +# failure_position = failure_position, +# failure_num = failure_num, +# risk_sets = risk_sets, +# risk_set_weights = risk_set_weights )) +# } + +# # this function calculate the log pseudo-likelihood for ONE site +# # cc_prep is the output of prepare_case_cohort() +# log_plk <- function(beta, cc_prep){ +# X = cc_prep$covariate # ipdata[,-c(1:2)] +# failure_position = cc_prep$failure_position +# failure_num = cc_prep$failure_num +# risk_sets = cc_prep$risk_sets +# risk_set_weights = cc_prep$risk_set_weights + +# numerator_terms <- c(X[failure_position, ] %*% beta) +# res <- sum(numerator_terms) +# for(j in 1:failure_num){ +# temp_term <- sum(c(exp(X[risk_sets[[j]], ] %*% beta )) * risk_set_weights[[j]]) +# if(temp_term > 0){ +# res <- res - log(temp_term) +# }else{ +# res <- res - numerator_terms[j] +# } +# } +# return(res) +# } -# this function calculate the gradient of log pseudo-likelihood for ONE site -# cc_prep is the output of prepare_case_cohort() -grad_plk <- function(beta, cc_prep){ - X = cc_prep$covariate # ipdata[,-c(1:2)] - failure_position = cc_prep$failure_position - failure_num = cc_prep$failure_num - risk_sets = cc_prep$risk_sets - risk_set_weights = cc_prep$risk_set_weights - - res = colSums(X[failure_position,]) - for(j in 1:failure_num){ - if(length(risk_sets[[j]]) > 1){ - temp_scalar <- c(exp(X[risk_sets[[j]], ] %*% beta )) * risk_set_weights[[j]] - res <- res - apply(sweep(X[risk_sets[[j]], ], 1, temp_scalar, "*"), 2, sum) / sum(temp_scalar) - }else{ - res <- res - X[risk_sets[[j]], ] * risk_set_weights[[j]] - } - } - return(res) -} +# # this function calculate the gradient of log pseudo-likelihood for ONE site +# # cc_prep is the output of prepare_case_cohort() +# grad_plk <- function(beta, cc_prep){ +# X = cc_prep$covariate # ipdata[,-c(1:2)] +# failure_position = cc_prep$failure_position +# failure_num = cc_prep$failure_num +# risk_sets = cc_prep$risk_sets +# risk_set_weights = cc_prep$risk_set_weights + +# res = colSums(X[failure_position,]) +# for(j in 1:failure_num){ +# if(length(risk_sets[[j]]) > 1){ +# temp_scalar <- c(exp(X[risk_sets[[j]], ] %*% beta )) * risk_set_weights[[j]] +# res <- res - apply(sweep(X[risk_sets[[j]], ], 1, temp_scalar, "*"), 2, sum) / sum(temp_scalar) +# }else{ +# res <- res - X[risk_sets[[j]], ] * risk_set_weights[[j]] +# } +# } +# return(res) +# } -# this function calculate the Hessian of log pseudo-likelihood for ONE site -# cc_prep is the output of prepare_case_cohort -hess_plk <- function(beta, cc_prep){ - X = cc_prep$covariate # ipdata[,-c(1:2)] - failure_position = cc_prep$failure_position - failure_num = cc_prep$failure_num - risk_sets = cc_prep$risk_sets - risk_set_weights = cc_prep$risk_set_weights - - d <- length(beta) - if(length(risk_sets[[1]]) > 1){ - temp_scalar <- c(exp(X[risk_sets[[1]], ] %*% beta )) * risk_set_weights[[1]] - temp_vec <- apply(sweep(X[risk_sets[[1]], ], 1, temp_scalar, "*"), 2, sum) - temp_mat <- sweep(X[risk_sets[[1]], ], 1, sqrt(temp_scalar), "*") - res <- temp_vec %*% t(temp_vec) / (sum(temp_scalar))^2 - crossprod(temp_mat) / sum(temp_scalar) - }else{ - res <- matrix(0, d, d) - } - if(failure_num > 1){ - for(j in 2:failure_num){ - if(length(risk_sets[[j]]) > 1){ - temp_scalar <- c(exp(X[risk_sets[[j]], ] %*% beta )) * risk_set_weights[[j]] - temp_vec <- apply(sweep(X[risk_sets[[j]], ], 1, temp_scalar, "*"), 2, sum) - temp_mat <- sweep(X[risk_sets[[j]], ], 1, sqrt(temp_scalar), "*") - res <- res - crossprod(temp_mat) / sum(temp_scalar) + temp_vec %*% t(temp_vec) / (sum(temp_scalar))^2 - } - } - } - - return(res) -} +# # this function calculate the Hessian of log pseudo-likelihood for ONE site +# # cc_prep is the output of prepare_case_cohort +# hess_plk <- function(beta, cc_prep){ +# X = cc_prep$covariate # ipdata[,-c(1:2)] +# failure_position = cc_prep$failure_position +# failure_num = cc_prep$failure_num +# risk_sets = cc_prep$risk_sets +# risk_set_weights = cc_prep$risk_set_weights + +# d <- length(beta) +# if(length(risk_sets[[1]]) > 1){ +# temp_scalar <- c(exp(X[risk_sets[[1]], ] %*% beta )) * risk_set_weights[[1]] +# temp_vec <- apply(sweep(X[risk_sets[[1]], ], 1, temp_scalar, "*"), 2, sum) +# temp_mat <- sweep(X[risk_sets[[1]], ], 1, sqrt(temp_scalar), "*") +# res <- temp_vec %*% t(temp_vec) / (sum(temp_scalar))^2 - crossprod(temp_mat) / sum(temp_scalar) +# }else{ +# res <- matrix(0, d, d) +# } +# if(failure_num > 1){ +# for(j in 2:failure_num){ +# if(length(risk_sets[[j]]) > 1){ +# temp_scalar <- c(exp(X[risk_sets[[j]], ] %*% beta )) * risk_set_weights[[j]] +# temp_vec <- apply(sweep(X[risk_sets[[j]], ], 1, temp_scalar, "*"), 2, sum) +# temp_mat <- sweep(X[risk_sets[[j]], ], 1, sqrt(temp_scalar), "*") +# res <- res - crossprod(temp_mat) / sum(temp_scalar) + temp_vec %*% t(temp_vec) / (sum(temp_scalar))^2 +# } +# } +# } + +# return(res) +# } -# this function fits Cox PH to case-cohort (survival::cch) with the pooled multi-site data -# notice this assumes varying baseline hazard functions across sites -#' @export -cch_pooled <- function(formula, data, subcoh='subcohort', site='site', variables_lev, - full_cohort_size, method = "Prentice", optim_method = "BFGS", - var_sandwich=T){ - n = nrow(data) - site_uniq = unique(data[,site]) - mf <- model.frame(formula, data, xlev=variables_lev) - - ipdata = data.table::data.table(site=data[,site], - time=as.numeric(model.response(mf))[1:n], - status=as.numeric(model.response(mf))[-c(1:n)], - subcohort = data[,subcoh], - model.matrix(formula, mf)[,-1]) - ipdata = data.table(data.frame(ipdata)) - risk_factor = colnames(ipdata)[-c(1:4)] - - # notice here we allow data degeneration (e.g. missing categories in some site) - px = ncol(ipdata)-4 - initial_beta = rep(0, px) - names(initial_beta) = names(ipdata)[-c(1:4)] - pool_fun <- function(beta) sum(sapply(site_uniq, function(site_id) - log_plk(beta, prepare_case_cohort(ipdata[site==site_id,-'site'], method, full_cohort_size[site_id])))) - - result <- optim(par = initial_beta, fn = pool_fun, - control = list(fnscale = -1), method = optim_method, hessian = T) - b_pooled = result$par - - # calculate sandwich var estimate, degenerated data columns are given 0 coefs - if(var_sandwich==T){ - block1 <- result$hessian - block2 <- NULL - # data_split <- split(data, data$site) - data_split <- split(ipdata, ipdata$site) - - for(i in 1:length(site_uniq)){ - site_id <- site_uniq[i] - ipdata_i = data_split[[i]] - col_deg = apply(ipdata_i[,-c(1:4)],2,var)==0 # degenerated X columns... - ipdata_i = ipdata_i[,-(which(col_deg)+4),with=F] - # use coxph(Surv(time_in, time, status)~.) to do cch... - precision <- min(diff(sort(ipdata_i$time))) / 2 # - ipdata_i$time_in = 0 - ipdata_i[ipdata_i$subcohort == 0, "time_in"] <- ipdata_i[ipdata_i$subcohort == 0, "time"] - precision - - # formula_i <- as.formula(paste("Surv(time, status) ~", paste(risk_factor[!col_deg], collapse = "+"))) - # cch_i <- survival::cch(formula_i, data = cbind(ID=1:nrow(ipdata_i), ipdata_i), - # subcoh = ~subcohort, id = ~ID, - # cohort.size = full_cohort_size[site_id], method = method) - formula_i <- as.formula(paste("Surv(time_in, time, status) ~", paste(risk_factor[!col_deg], collapse = "+"), '+ cluster(ID)')) - # cch_i <- tryCatch(coxph(formula_i, data=cbind(ID=1:nrow(ipdata_i), ipdata_i), robust=T), error=function(e) NULL) - cch_i <- tryCatch(coxph(formula_i, data=cbind(ID=1:nrow(ipdata_i), ipdata_i), init=b_pooled[!col_deg], iter=0), error=function(e) NULL) - score_resid <- resid(cch_i, type = "score") # n x p matrix - S_i = matrix(0, px, px) # this is the meat in sandwich var - S_i[!col_deg, !col_deg] <- crossprod(score_resid) - - # local_hess <- hess_plk(b_pooled[!col_deg], # cch_i$coefficients, - # prepare_case_cohort(ipdata_i[,-c('site','time_in')], - # method, full_cohort_size[site_id])) - # tmp = matrix(0, px, px) - # tmp[!col_deg,!col_deg] <- local_hess %*% cch_i$var %*% local_hess - block2[[i]] <- S_i - } - - var <- solve(block1) %*% Reduce("+", block2) %*% solve(block1) - result$var <- var # this is the output for variance estimates - } - - return(result) -} +# # this function fits Cox PH to case-cohort (survival::cch) with the pooled multi-site data +# # notice this assumes varying baseline hazard functions across sites +# #' @export +# cch_pooled <- function(formula, data, subcoh='subcohort', site='site', variables_lev, +# full_cohort_size, method = "Prentice", optim_method = "BFGS", +# var_sandwich=T){ +# n = nrow(data) +# site_uniq = unique(data[,site]) +# mf <- model.frame(formula, data, xlev=variables_lev) + +# ipdata = data.table::data.table(site=data[,site], +# time=as.numeric(model.response(mf))[1:n], +# status=as.numeric(model.response(mf))[-c(1:n)], +# subcohort = data[,subcoh], +# model.matrix(formula, mf)[,-1]) +# ipdata = data.table(data.frame(ipdata)) +# risk_factor = colnames(ipdata)[-c(1:4)] + +# # notice here we allow data degeneration (e.g. missing categories in some site) +# px = ncol(ipdata)-4 +# initial_beta = rep(0, px) +# names(initial_beta) = names(ipdata)[-c(1:4)] +# pool_fun <- function(beta) sum(sapply(site_uniq, function(site_id) +# log_plk(beta, prepare_case_cohort(ipdata[site==site_id,-'site'], method, full_cohort_size[site_id])))) + +# result <- optim(par = initial_beta, fn = pool_fun, +# control = list(fnscale = -1), method = optim_method, hessian = T) +# b_pooled = result$par + +# # calculate sandwich var estimate, degenerated data columns are given 0 coefs +# if(var_sandwich==T){ +# block1 <- result$hessian +# block2 <- NULL +# # data_split <- split(data, data$site) +# data_split <- split(ipdata, ipdata$site) + +# for(i in 1:length(site_uniq)){ +# site_id <- site_uniq[i] +# ipdata_i = data_split[[i]] +# col_deg = apply(ipdata_i[,-c(1:4)],2,var)==0 # degenerated X columns... +# ipdata_i = ipdata_i[,-(which(col_deg)+4),with=F] +# # use coxph(Surv(time_in, time, status)~.) to do cch... +# precision <- min(diff(sort(ipdata_i$time))) / 2 # +# ipdata_i$time_in = 0 +# ipdata_i[ipdata_i$subcohort == 0, "time_in"] <- ipdata_i[ipdata_i$subcohort == 0, "time"] - precision + +# # formula_i <- as.formula(paste("Surv(time, status) ~", paste(risk_factor[!col_deg], collapse = "+"))) +# # cch_i <- survival::cch(formula_i, data = cbind(ID=1:nrow(ipdata_i), ipdata_i), +# # subcoh = ~subcohort, id = ~ID, +# # cohort.size = full_cohort_size[site_id], method = method) +# formula_i <- as.formula(paste("Surv(time_in, time, status) ~", paste(risk_factor[!col_deg], collapse = "+"), '+ cluster(ID)')) +# # cch_i <- tryCatch(coxph(formula_i, data=cbind(ID=1:nrow(ipdata_i), ipdata_i), robust=T), error=function(e) NULL) +# cch_i <- tryCatch(coxph(formula_i, data=cbind(ID=1:nrow(ipdata_i), ipdata_i), init=b_pooled[!col_deg], iter=0), error=function(e) NULL) +# score_resid <- resid(cch_i, type = "score") # n x p matrix +# S_i = matrix(0, px, px) # this is the meat in sandwich var +# S_i[!col_deg, !col_deg] <- crossprod(score_resid) + +# # local_hess <- hess_plk(b_pooled[!col_deg], # cch_i$coefficients, +# # prepare_case_cohort(ipdata_i[,-c('site','time_in')], +# # method, full_cohort_size[site_id])) +# # tmp = matrix(0, px, px) +# # tmp[!col_deg,!col_deg] <- local_hess %*% cch_i$var %*% local_hess +# block2[[i]] <- S_i +# } + +# var <- solve(block1) %*% Reduce("+", block2) %*% solve(block1) +# result$var <- var # this is the output for variance estimates +# } + +# return(result) +# } \ No newline at end of file From 8e93c09795fdb0912dcfe8e1c0891a91008c8039 Mon Sep 17 00:00:00 2001 From: Ali Date: Mon, 23 Jun 2025 17:24:04 +0200 Subject: [PATCH 02/10] export cpp functions --- R/RcppExports.R | 16 ++++++++++ src/RcppExports.cpp | 77 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/R/RcppExports.R b/R/RcppExports.R index cfddead..a21c20c 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -29,3 +29,19 @@ rcpp_aggregate <- function(x, indices, simplify = TRUE, cumulative = FALSE, reve .Call('_pda_rcpp_aggregate', PACKAGE = 'pda', x, indices, simplify, cumulative, reversely) } +rcpp_cc_log_plk <- function(beta, covariate_list, failure_position, failure_num, risk_sets, risk_set_weights, site_num) { + .Call('_pda_rcpp_cc_log_plk', PACKAGE = 'pda', beta, covariate_list, failure_position, failure_num, risk_sets, risk_set_weights, site_num) +} + +rcpp_cc_pool_fun <- function(beta, covariate_list, failure_position, failure_num, risk_sets, risk_set_weights, K) { + .Call('_pda_rcpp_cc_pool_fun', PACKAGE = 'pda', beta, covariate_list, failure_position, failure_num, risk_sets, risk_set_weights, K) +} + +rcpp_cc_grad_plk <- function(beta, covariate_list, failure_position, failure_num, risk_sets, risk_set_weights, site_num) { + .Call('_pda_rcpp_cc_grad_plk', PACKAGE = 'pda', beta, covariate_list, failure_position, failure_num, risk_sets, risk_set_weights, site_num) +} + +rcpp_cc_hess_plk <- function(beta, covariate_list, failure_position, failure_num, risk_sets, risk_set_weights, site_num) { + .Call('_pda_rcpp_cc_hess_plk', PACKAGE = 'pda', beta, covariate_list, failure_position, failure_num, risk_sets, risk_set_weights, site_num) +} + diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 83ac372..16e5548 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -6,6 +6,11 @@ using namespace Rcpp; +#ifdef RCPP_USE_GLOBAL_ROSTREAM +Rcpp::Rostream& Rcpp::Rcout = Rcpp::Rcpp_cout_get(); +Rcpp::Rostream& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get(); +#endif + // rcpp_coxph_logL double rcpp_coxph_logL(const arma::vec& beta, const arma::vec& time, const arma::vec& event, const arma::mat& z); RcppExport SEXP _pda_rcpp_coxph_logL(SEXP betaSEXP, SEXP timeSEXP, SEXP eventSEXP, SEXP zSEXP) { @@ -108,6 +113,74 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// rcpp_cc_log_plk +double rcpp_cc_log_plk(NumericVector beta, List covariate_list, List failure_position, IntegerVector failure_num, List risk_sets, List risk_set_weights, int site_num); +RcppExport SEXP _pda_rcpp_cc_log_plk(SEXP betaSEXP, SEXP covariate_listSEXP, SEXP failure_positionSEXP, SEXP failure_numSEXP, SEXP risk_setsSEXP, SEXP risk_set_weightsSEXP, SEXP site_numSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< NumericVector >::type beta(betaSEXP); + Rcpp::traits::input_parameter< List >::type covariate_list(covariate_listSEXP); + Rcpp::traits::input_parameter< List >::type failure_position(failure_positionSEXP); + Rcpp::traits::input_parameter< IntegerVector >::type failure_num(failure_numSEXP); + Rcpp::traits::input_parameter< List >::type risk_sets(risk_setsSEXP); + Rcpp::traits::input_parameter< List >::type risk_set_weights(risk_set_weightsSEXP); + Rcpp::traits::input_parameter< int >::type site_num(site_numSEXP); + rcpp_result_gen = Rcpp::wrap(rcpp_cc_log_plk(beta, covariate_list, failure_position, failure_num, risk_sets, risk_set_weights, site_num)); + return rcpp_result_gen; +END_RCPP +} +// rcpp_cc_pool_fun +double rcpp_cc_pool_fun(NumericVector beta, List covariate_list, List failure_position, IntegerVector failure_num, List risk_sets, List risk_set_weights, int K); +RcppExport SEXP _pda_rcpp_cc_pool_fun(SEXP betaSEXP, SEXP covariate_listSEXP, SEXP failure_positionSEXP, SEXP failure_numSEXP, SEXP risk_setsSEXP, SEXP risk_set_weightsSEXP, SEXP KSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< NumericVector >::type beta(betaSEXP); + Rcpp::traits::input_parameter< List >::type covariate_list(covariate_listSEXP); + Rcpp::traits::input_parameter< List >::type failure_position(failure_positionSEXP); + Rcpp::traits::input_parameter< IntegerVector >::type failure_num(failure_numSEXP); + Rcpp::traits::input_parameter< List >::type risk_sets(risk_setsSEXP); + Rcpp::traits::input_parameter< List >::type risk_set_weights(risk_set_weightsSEXP); + Rcpp::traits::input_parameter< int >::type K(KSEXP); + rcpp_result_gen = Rcpp::wrap(rcpp_cc_pool_fun(beta, covariate_list, failure_position, failure_num, risk_sets, risk_set_weights, K)); + return rcpp_result_gen; +END_RCPP +} +// rcpp_cc_grad_plk +NumericVector rcpp_cc_grad_plk(NumericVector beta, List covariate_list, List failure_position, IntegerVector failure_num, List risk_sets, List risk_set_weights, int site_num); +RcppExport SEXP _pda_rcpp_cc_grad_plk(SEXP betaSEXP, SEXP covariate_listSEXP, SEXP failure_positionSEXP, SEXP failure_numSEXP, SEXP risk_setsSEXP, SEXP risk_set_weightsSEXP, SEXP site_numSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< NumericVector >::type beta(betaSEXP); + Rcpp::traits::input_parameter< List >::type covariate_list(covariate_listSEXP); + Rcpp::traits::input_parameter< List >::type failure_position(failure_positionSEXP); + Rcpp::traits::input_parameter< IntegerVector >::type failure_num(failure_numSEXP); + Rcpp::traits::input_parameter< List >::type risk_sets(risk_setsSEXP); + Rcpp::traits::input_parameter< List >::type risk_set_weights(risk_set_weightsSEXP); + Rcpp::traits::input_parameter< int >::type site_num(site_numSEXP); + rcpp_result_gen = Rcpp::wrap(rcpp_cc_grad_plk(beta, covariate_list, failure_position, failure_num, risk_sets, risk_set_weights, site_num)); + return rcpp_result_gen; +END_RCPP +} +// rcpp_cc_hess_plk +NumericMatrix rcpp_cc_hess_plk(NumericVector beta, List covariate_list, List failure_position, IntegerVector failure_num, List risk_sets, List risk_set_weights, int site_num); +RcppExport SEXP _pda_rcpp_cc_hess_plk(SEXP betaSEXP, SEXP covariate_listSEXP, SEXP failure_positionSEXP, SEXP failure_numSEXP, SEXP risk_setsSEXP, SEXP risk_set_weightsSEXP, SEXP site_numSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< NumericVector >::type beta(betaSEXP); + Rcpp::traits::input_parameter< List >::type covariate_list(covariate_listSEXP); + Rcpp::traits::input_parameter< List >::type failure_position(failure_positionSEXP); + Rcpp::traits::input_parameter< IntegerVector >::type failure_num(failure_numSEXP); + Rcpp::traits::input_parameter< List >::type risk_sets(risk_setsSEXP); + Rcpp::traits::input_parameter< List >::type risk_set_weights(risk_set_weightsSEXP); + Rcpp::traits::input_parameter< int >::type site_num(site_numSEXP); + rcpp_result_gen = Rcpp::wrap(rcpp_cc_hess_plk(beta, covariate_list, failure_position, failure_num, risk_sets, risk_set_weights, site_num)); + return rcpp_result_gen; +END_RCPP +} static const R_CallMethodDef CallEntries[] = { {"_pda_rcpp_coxph_logL", (DL_FUNC) &_pda_rcpp_coxph_logL, 4}, @@ -117,6 +190,10 @@ static const R_CallMethodDef CallEntries[] = { {"_pda_rcpp_coxph_logL_gradient_efron", (DL_FUNC) &_pda_rcpp_coxph_logL_gradient_efron, 4}, {"_pda_rcpp_coxph_logL_gradient_efron_dist", (DL_FUNC) &_pda_rcpp_coxph_logL_gradient_efron_dist, 7}, {"_pda_rcpp_aggregate", (DL_FUNC) &_pda_rcpp_aggregate, 5}, + {"_pda_rcpp_cc_log_plk", (DL_FUNC) &_pda_rcpp_cc_log_plk, 7}, + {"_pda_rcpp_cc_pool_fun", (DL_FUNC) &_pda_rcpp_cc_pool_fun, 7}, + {"_pda_rcpp_cc_grad_plk", (DL_FUNC) &_pda_rcpp_cc_grad_plk, 7}, + {"_pda_rcpp_cc_hess_plk", (DL_FUNC) &_pda_rcpp_cc_hess_plk, 7}, {NULL, NULL, 0} }; From d40a7c3daad67034f85753426a51f7d23b6b3953 Mon Sep 17 00:00:00 2001 From: Ali Date: Tue, 24 Jun 2025 12:06:50 +0200 Subject: [PATCH 03/10] increase precision to 8 --- R/pda.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/pda.R b/R/pda.R index 85b517c..c91c04c 100644 --- a/R/pda.R +++ b/R/pda.R @@ -30,7 +30,7 @@ #' @return NONE #' @seealso \code{pda} #' @export -pdaPut <- function(obj,name,config,upload_without_confirm=F,silent_message=F,digits=4){ +pdaPut <- function(obj,name,config,upload_without_confirm=F,silent_message=F,digits=8){ mymessage <- function(mes, silent=silent_message) if(silent==F) message(mes) obj_Json <- jsonlite::toJSON(obj, digits = digits) # RJSONIO::toJSON(tt) keep vec name? @@ -309,7 +309,7 @@ getCloudConfig <- function(site_id,dir=NULL,uri=NULL,secret=NULL,silent_message= #' @return control #' @export pda <- function(ipdata=NULL,site_id,control=NULL,dir=NULL,uri=NULL,secret=NULL, - upload_without_confirm=F, silent_message=F, digits=4, + upload_without_confirm=F, silent_message=F, digits=8, hosdata=NULL # for dGEM ){ config <- getCloudConfig(site_id,dir,uri,secret,silent_message) From 8ea0a657e8cf60989da4ff72107116dbfef66657 Mon Sep 17 00:00:00 2001 From: Ali Date: Mon, 30 Jun 2025 11:14:25 +0200 Subject: [PATCH 04/10] add survival namespace --- R/ODACH_CC.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/ODACH_CC.R b/R/ODACH_CC.R index 998ba01..9605ce2 100644 --- a/R/ODACH_CC.R +++ b/R/ODACH_CC.R @@ -69,7 +69,7 @@ ODACH_CC.initialize <- function(ipdata,control,config){ ipdata_i$time_in = 0 ipdata_i[ipdata_i$subcohort == 0, "time_in"] <- ipdata_i[ipdata_i$subcohort == 0, "time"] - precision formula_i <- as.formula(paste("Surv(time_in, time, status) ~", paste(control$risk_factor[!col_deg], collapse = "+"), '+ cluster(ID)')) - fit_i <- tryCatch(coxph(formula_i, data=ipdata_i, robust=T), error=function(e) NULL) + fit_i <- tryCatch(survival::blog()coxph(formula_i, data=ipdata_i, robust=T), error=function(e) NULL) if(!is.null(fit_i)){ # for degenerated X, coef=0, var=Inf From 2b9ff3e3b7e84299ec7d1e8498d9207e7602267c Mon Sep 17 00:00:00 2001 From: Ali Date: Mon, 30 Jun 2025 11:21:35 +0200 Subject: [PATCH 05/10] correct error in namespace --- R/ODACH_CC.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/ODACH_CC.R b/R/ODACH_CC.R index 9605ce2..da9eb62 100644 --- a/R/ODACH_CC.R +++ b/R/ODACH_CC.R @@ -69,7 +69,7 @@ ODACH_CC.initialize <- function(ipdata,control,config){ ipdata_i$time_in = 0 ipdata_i[ipdata_i$subcohort == 0, "time_in"] <- ipdata_i[ipdata_i$subcohort == 0, "time"] - precision formula_i <- as.formula(paste("Surv(time_in, time, status) ~", paste(control$risk_factor[!col_deg], collapse = "+"), '+ cluster(ID)')) - fit_i <- tryCatch(survival::blog()coxph(formula_i, data=ipdata_i, robust=T), error=function(e) NULL) + fit_i <- tryCatch(survival::coxph(formula_i, data=ipdata_i, robust=T), error=function(e) NULL) if(!is.null(fit_i)){ # for degenerated X, coef=0, var=Inf From b2c7a114e2da2bbb56d2f48d0f303a71b0719065 Mon Sep 17 00:00:00 2001 From: Ali Date: Tue, 11 Nov 2025 14:00:02 +0100 Subject: [PATCH 06/10] increase results precision to 16 dgits when writing outputs --- R/pda.R | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/R/pda.R b/R/pda.R index c91c04c..b7f83de 100644 --- a/R/pda.R +++ b/R/pda.R @@ -30,18 +30,18 @@ #' @return NONE #' @seealso \code{pda} #' @export -pdaPut <- function(obj,name,config,upload_without_confirm=F,silent_message=F,digits=8){ +pdaPut <- function(obj,name,config,upload_without_confirm=F,silent_message=F,digits=16){ mymessage <- function(mes, silent=silent_message) if(silent==F) message(mes) obj_Json <- jsonlite::toJSON(obj, digits = digits) # RJSONIO::toJSON(tt) keep vec name? file_name <- paste0(name, '.json') - if(!is.null(config$uri)){ - mymessage(paste("Put",file_name,"on public cloud:")) - }else{ - mymessage(paste("Put",file_name,"on local directory", config$dir, ':')) - } - mymessage(obj_Json) + # if(!is.null(config$uri)){ + # mymessage(paste("Put",file_name,"on public cloud:")) + # }else{ + # mymessage(paste("Put",file_name,"on local directory", config$dir, ':')) + # } + # mymessage(obj_Json) # if(interactive()) { if(upload_without_confirm==F) { @@ -151,7 +151,7 @@ getCloudConfig <- function(site_id,dir=NULL,uri=NULL,secret=NULL,silent_message= } else if (pda_uri!='') { config$uri = pda_uri } else{ - mymessage('no cloud uri found! ') + # mymessage('no cloud uri found! ') } if(!is.null(dir)) { @@ -309,13 +309,13 @@ getCloudConfig <- function(site_id,dir=NULL,uri=NULL,secret=NULL,silent_message= #' @return control #' @export pda <- function(ipdata=NULL,site_id,control=NULL,dir=NULL,uri=NULL,secret=NULL, - upload_without_confirm=F, silent_message=F, digits=8, + upload_without_confirm=F, silent_message=F, digits=16, hosdata=NULL # for dGEM - ){ + ){ config <- getCloudConfig(site_id,dir,uri,secret,silent_message) mymessage <- function(mes, silent=silent_message) if(silent==F) message(mes) files <- pdaList(config) - mymessage('You are performing Privacy-preserving Distributed Algorithm (PDA, https://github.com/Penncil/pda): ') + # mymessage('You are performing Privacy-preserving Distributed Algorithm (PDA, https://github.com/Penncil/pda): ') mymessage(paste0('your site = ', config$site_id)) # read in control, or lead site add a control file to the cloud if there is none @@ -629,7 +629,7 @@ pda <- function(ipdata=NULL,site_id,control=NULL,dir=NULL,uri=NULL,secret=NULL, #' @return control #' @seealso \code{pda} #' @export -pdaSync <- function(config,upload_without_confirm,silent_message=F, digits=4){ +pdaSync <- function(config,upload_without_confirm,silent_message=F, digits=16){ control = pdaGet('control',config) mymessage <- function(mes, silent=silent_message) if(silent==F) message(mes) @@ -868,7 +868,7 @@ pdaSync <- function(config,upload_without_confirm,silent_message=F, digits=4){ ## estimate for pda init: meta, or median, or lead est?... if(control$init_method == 'meta'){ - binit = apply(bhat/vbhat,2,function(x){sum(x, na.rm = TRUE)})/apply(1/vbhat,2,function(x){sum(x, na.rm = TRUE)}) + binit = apply(as.data.frame(bhat/vbhat),2,function(x){sum(x, na.rm = TRUE)})/apply(as.data.frame(1/vbhat),2,function(x){sum(x, na.rm = TRUE)}) # vinit = 1/apply(1/vbhat,2,function(x){sum(x, na.rm = TRUE)}) mymessage('meta (inv var weighted avg) as initial est:') } else if(control$init_method == 'median'){ From c084a25c4f0c5be406eb895492c66b439a3bb7fe Mon Sep 17 00:00:00 2001 From: Ali Date: Tue, 11 Nov 2025 14:01:00 +0100 Subject: [PATCH 07/10] when Cox fails: init bhat and Vhat with zeros instead of NAs --- R/ODACH_CC.R | 42 +++++++++++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/R/ODACH_CC.R b/R/ODACH_CC.R index da9eb62..f5120f2 100644 --- a/R/ODACH_CC.R +++ b/R/ODACH_CC.R @@ -70,14 +70,22 @@ ODACH_CC.initialize <- function(ipdata,control,config){ ipdata_i[ipdata_i$subcohort == 0, "time_in"] <- ipdata_i[ipdata_i$subcohort == 0, "time"] - precision formula_i <- as.formula(paste("Surv(time_in, time, status) ~", paste(control$risk_factor[!col_deg], collapse = "+"), '+ cluster(ID)')) fit_i <- tryCatch(survival::coxph(formula_i, data=ipdata_i, robust=T), error=function(e) NULL) + # if (sum(ipdata_i$time<=ipdata_i$time_in) > 0) { + # print(ipdata_i[time<=time_in, c("ID", "time", "time_in", "status")]) + # } + # print(fit_i) + # print(formula_i) + # w <- warnings() + # print(w) + # stop() if(!is.null(fit_i)){ # for degenerated X, coef=0, var=Inf bhat_i = rep(0,px) Vhat_i = rep(Inf,px) bhat_i[!col_deg] <- fit_i$coef Vhat_i[!col_deg] <- summary(fit_i)$coef[,"se(coef)"]^2 # dont's use robust var diag(fit_i$var) - + # Vhat_i[Vhat_i == 0] <- Inf # for caases when survival::coxph() returns coef=NA and se(coef)=0, which is not handled by summary(fit_i)$coef[,2]^2 init <- list(bhat_i = bhat_i, Vhat_i = Vhat_i, site = config$site_id, @@ -87,15 +95,14 @@ ODACH_CC.initialize <- function(ipdata,control,config){ # init$Vhat_i[init$Vhat_i==0] = NA # 20250106 } else{ warning('survival::coxph() failed!!!') - init <- list(bhat_i = NA, - Vhat_i = NA, + init <- list(bhat_i = rep(0,px), + Vhat_i = rep(Inf,px), S_i = NA, site = config$site_id, site_size = nrow(ipdata), full_cohort_size = full_cohort_size, method = control$method) } - return(init) } @@ -121,15 +128,28 @@ ODACH_CC.derive <- function(ipdata,control,config){ ipdata_i = ipdata[,-(which(col_deg)+3),with=F] ipdata_i$ID = 1:nrow(ipdata_i) # for running coxph/cch... precision <- min(diff(sort(ipdata_i$time))) / 2 # + # print("=======") + # print(precision) + # if (precision < 10^(-6)){ + # ipdata_i[["time"]] <- ipdata_i[["time"]] * 100 + # precision <- min(diff(sort(ipdata_i$time))) / 2 # + # print(precision) + # } ipdata_i$time_in = 0 ipdata_i[ipdata_i$subcohort == 0, "time_in"] <- ipdata_i[ipdata_i$subcohort == 0, "time"] - precision - + + # filter_<-ipdata_i[["time_in"]]>=ipdata_i[["time"]] + # print(str(filter_)) + # print(ipdata_i[filter_,c("time_in", "time")]) + # print(min(ipdata_i[["time"]]-ipdata_i[["time_in"]])) ## grad and hess bbar = control$beta_init full_cohort_size = control$full_cohort_size[control$sites==config$site_id] cc_prep = prepare_case_cohort(list(ipdata), control$method, full_cohort_size) # logL_D1 <- grad_plk(bbar, cc_prep) # logL_D2 <- hess_plk(bbar, cc_prep) + # if(config$site_id == "site11") print(bbar) + # if(config$site_id == "site11") print(cc_prep) logL_D1 <- rcpp_cc_grad_plk(beta = bbar, site_num = 1, covariate_list = cc_prep$covariate_list, failure_position = cc_prep$failure_position, @@ -142,11 +162,19 @@ ODACH_CC.derive <- function(ipdata,control,config){ failure_num = cc_prep$failure_num, risk_sets = cc_prep$risk_sets, risk_set_weights = cc_prep$risk_set_weights) - + # if(config$site_id == "site11") stop("Me after rcpp") ## get intermediate (sandwich meat) for robust variance est of ODACH_CC # fit_i <- tryCatch(coxph(formula_i, data=ipdata_i, robust=T), error=function(e) NULL) formula_i <- as.formula(paste("Surv(time_in, time, status) ~", paste(control$risk_factor[!col_deg], collapse = "+"), '+ cluster(ID)')) - fit_i <- tryCatch(coxph(formula_i, data=ipdata_i, robust=T, init=bbar[!col_deg], iter=0), error=function(e) NULL) # 20250326: init/iter trick + # print(formula_i) + # print(bbar[!col_deg]) + # fit_i <- survival::coxph(formula_i, data=ipdata_i, robust=T) + # print(fit_i) + # fit_i <- survival::coxph(formula_i, data=ipdata_i, robust=T, init=bbar[!col_deg]) + # print(fit_i) + fit_i <- tryCatch(survival::coxph(formula_i, data=ipdata_i, robust=T, init=bbar[!col_deg], iter=0), error=function(e) NULL) # 20250326: init/iter trick: coxPH will only calcualte the loglikelihood function for the input init point without iterating (iter == survival::coxph.control(max.iter=0)). Since the default value for `init` is avector of zeros, this will return a vector of zeros unless `init`` is initialised. + # print(fit_i) + score_resid <- resid(fit_i, type = "score") # n x p matrix S_i = matrix(0, px, px) # this is the meat in sandwich var S_i[!col_deg, !col_deg] <- crossprod(score_resid) From e698aba821dfe7a6d10c1f41f9bd15bba1cea57b Mon Sep 17 00:00:00 2001 From: Ali Date: Tue, 11 Nov 2025 18:16:30 +0100 Subject: [PATCH 08/10] add support for left truncated survival data --- R/ODACH_CC.R | 379 +++++++++++++++++++++++++++--------------------- R/case_cohort.R | 199 ++++++++++++++----------- R/pda.R | 21 ++- 3 files changed, 343 insertions(+), 256 deletions(-) diff --git a/R/ODACH_CC.R b/R/ODACH_CC.R index f5120f2..ce75ae5 100644 --- a/R/ODACH_CC.R +++ b/R/ODACH_CC.R @@ -1,13 +1,13 @@ # Copyright 2020 Penn Computing Inference Learning (PennCIL) lab # https://penncil.med.upenn.edu/team/ # This file is part of pda -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -25,9 +25,9 @@ # step = 'initialize', # sites = sites, # heterogeneity = T, -# model = 'ODACH_CC', -# method='Prentice', # -# full_cohort_size=NA, # +# model = 'ODACH_CC', +# method='Prentice', # +# full_cohort_size=NA, # # family = 'cox', # outcome = "status", # variables = c('age', 'sex'), @@ -37,247 +37,297 @@ #' @useDynLib pda #' @title ODACH_CC initialize -#' +#' #' @usage ODACH_CC.initialize(ipdata, control, config) #' @param ipdata individual participant data #' @param control pda control data #' @param config local site configuration -#' +#' #' @references Chongliang Luo, et al. "ODACH: A One-shot Distributed Algorithm for Cox model with Heterogeneous Multi-center Data". #' medRxiv, 2021, https://doi.org/10.1101/2021.04.18.21255694 #' @return list(bhat_i = fit_i$coef, Vhat_i = summary(fit_i)$coef[,2]^2, site=control$mysite, site_size= nrow(ipdata)) #' @keywords internal -ODACH_CC.initialize <- function(ipdata,control,config){ +ODACH_CC.initialize <- function(ipdata, control, config) { # coxph with case-cohort design - full_cohort_size = control$full_cohort_size[control$sites==config$site_id] - px = ncol(ipdata) - 3 - + full_cohort_size <- control$full_cohort_size[control$sites == config$site_id] + call_obj <- str2lang(control$outcome) + surv_cols <- length(as.list(call_obj)) - 1L # -1L for the function name + n_skip_cols <- surv_cols + 1L # +1L for the subcohort column + px <- ncol(ipdata) - n_skip_cols + # handle data degeneration (e.g. missing categories in some site). This could be in pda()? - col_deg = apply(ipdata[,-c(1:3)],2,var)==0 # degenerated X columns... - ipdata_i = ipdata[,-(which(col_deg)+3),with=F] - ipdata_i$ID = 1:nrow(ipdata_i) # for running coxph/cch... - - # formula_i <- as.formula(paste("Surv(time, status) ~", paste(control$risk_factor[!col_deg], collapse = "+"))) - # fit_i <- tryCatch(survival::cch(formula_i, data = ipdata_i, subcoh = ~subcohort, id = ~ID, - # cohort.size = full_cohort_size, method = control$method), error=function(e) NULL) + if (surv_cols == 3) { + col_deg <- apply(ipdata[, -c(1:4)], 2, var) == 0 # degenerated X columns... + } else if (surv_cols == 2) { + col_deg <- apply(ipdata[, -c(1:3)], 2, var) == 0 # degenerated X columns... + } + + ipdata_i <- ipdata[, -(which(col_deg) + n_skip_cols), with = F] + ipdata_i$ID <- 1:nrow(ipdata_i) # for running coxph/cch... + # formula_i <- as.formula(paste("Surv(time, status) ~", paste(control$risk_factor[!col_deg], collapse = "+"))) + # fit_i <- tryCatch(survival::cch(formula_i, data = ipdata_i, subcoh = ~subcohort, id = ~ID, + # cohort.size = full_cohort_size, method = control$method), error=function(e) NULL) + ## 3 ways to do local est: cch, coxph with a tweak of the formula, and cch_pooled # to avoid numerical error using cch() indicated by Ali, we use coxph with a tweak of the formula... # generally cch, coxph and cch_pooled will generate almost identical b and close var (for continuous X, coxph has smaller S.E. than the other two) # but coxph only works for Prentice wt, so will look into it later (and may revert to cch_pooled...) - precision <- min(diff(sort(ipdata_i$time))) / 2 # - ipdata_i$time_in = 0 - ipdata_i[ipdata_i$subcohort == 0, "time_in"] <- ipdata_i[ipdata_i$subcohort == 0, "time"] - precision - formula_i <- as.formula(paste("Surv(time_in, time, status) ~", paste(control$risk_factor[!col_deg], collapse = "+"), '+ cluster(ID)')) - fit_i <- tryCatch(survival::coxph(formula_i, data=ipdata_i, robust=T), error=function(e) NULL) + if (surv_cols == 3) { + times <- sort(unique(c(ipdata$time_in, ipdata$time))) + precision <- min(diff(times)) / 2 + ipdata_i[ipdata_i$subcohort == 0, "time_in"] <- ipdata_i[ipdata_i$subcohort == 0, "time"] - precision + } else if (surv_cols == 2) { + precision <- min(diff(sort(ipdata_i$time))) / 2 # + ipdata_i$time_in <- 0 + ipdata_i[ipdata_i$subcohort == 0, "time_in"] <- ipdata_i[ipdata_i$subcohort == 0, "time"] - precision + } + + formula_i <- as.formula(paste("Surv(time_in, time, status) ~", paste(control$risk_factor[!col_deg], collapse = "+"), "+ cluster(ID)")) + fit_i <- tryCatch(survival::coxph(formula_i, data = ipdata_i, robust = T), error = function(e) NULL) # if (sum(ipdata_i$time<=ipdata_i$time_in) > 0) { # print(ipdata_i[time<=time_in, c("ID", "time", "time_in", "status")]) # } - + # print(fit_i) # print(formula_i) # w <- warnings() # print(w) # stop() - if(!is.null(fit_i)){ + if (!is.null(fit_i)) { # for degenerated X, coef=0, var=Inf - bhat_i = rep(0,px) - Vhat_i = rep(Inf,px) + bhat_i <- rep(0, px) + Vhat_i <- rep(Inf, px) bhat_i[!col_deg] <- fit_i$coef - Vhat_i[!col_deg] <- summary(fit_i)$coef[,"se(coef)"]^2 # dont's use robust var diag(fit_i$var) + Vhat_i[!col_deg] <- summary(fit_i)$coef[, "se(coef)"]^2 # dont's use robust var diag(fit_i$var) # Vhat_i[Vhat_i == 0] <- Inf # for caases when survival::coxph() returns coef=NA and se(coef)=0, which is not handled by summary(fit_i)$coef[,2]^2 - init <- list(bhat_i = bhat_i, - Vhat_i = Vhat_i, - site = config$site_id, - site_size = nrow(ipdata), - full_cohort_size = full_cohort_size, - method = control$method) + init <- list( + bhat_i = bhat_i, + Vhat_i = Vhat_i, + site = config$site_id, + site_size = nrow(ipdata), + full_cohort_size = full_cohort_size, + method = control$method + ) # init$Vhat_i[init$Vhat_i==0] = NA # 20250106 - } else{ - warning('survival::coxph() failed!!!') - init <- list(bhat_i = rep(0,px), - Vhat_i = rep(Inf,px), - S_i = NA, - site = config$site_id, - site_size = nrow(ipdata), - full_cohort_size = full_cohort_size, - method = control$method) + } else { + warning("survival::coxph() failed!!!") + init <- list( + bhat_i = rep(0, px), + Vhat_i = rep(Inf, px), + S_i = NA, + site = config$site_id, + site_size = nrow(ipdata), + full_cohort_size = full_cohort_size, + method = control$method + ) } return(init) } - + #' @useDynLib pda #' @title Generate pda derivatives -#' +#' #' @usage ODACH_CC.derive(ipdata, control, config) #' @param ipdata individual participant data #' @param control pda control data #' @param config local site configuration -#' -#' @details Calculate and broadcast 1st and 2nd order derivative at initial bbar #' -#' @import Rcpp +#' @details Calculate and broadcast 1st and 2nd order derivative at initial bbar +#' +#' @import Rcpp #' @return list(bbar=bbar, site=control$mysite, site_size = nrow(ipdata), logL_D1=logL_D1, logL_D2=logL_D2) #' @keywords internal -ODACH_CC.derive <- function(ipdata,control,config){ - px <- ncol(ipdata) - 3 - +ODACH_CC.derive <- function(ipdata, control, config) { + call_obj <- str2lang(control$outcome) + surv_cols <- length(as.list(call_obj)) - 1L # -1L for the function name + n_skip_cols <- surv_cols + 1L # +1L for the subcohort column + px <- ncol(ipdata) - n_skip_cols + # handle data degeneration (e.g. missing categories in some site). This could be in pda()? - col_deg = apply(ipdata[,-c(1:3)],2,var)==0 # degenerated X columns... - ipdata_i = ipdata[,-(which(col_deg)+3),with=F] - ipdata_i$ID = 1:nrow(ipdata_i) # for running coxph/cch... - precision <- min(diff(sort(ipdata_i$time))) / 2 # - # print("=======") - # print(precision) - # if (precision < 10^(-6)){ - # ipdata_i[["time"]] <- ipdata_i[["time"]] * 100 - # precision <- min(diff(sort(ipdata_i$time))) / 2 # - # print(precision) - # } - ipdata_i$time_in = 0 - ipdata_i[ipdata_i$subcohort == 0, "time_in"] <- ipdata_i[ipdata_i$subcohort == 0, "time"] - precision + if (surv_cols == 3) { + col_deg <- apply(ipdata[, -c(1:4)], 2, var) == 0 # degenerated X columns... + } else if (surv_cols == 2) { + col_deg <- apply(ipdata[, -c(1:3)], 2, var) == 0 # degenerated X columns... + } + - # filter_<-ipdata_i[["time_in"]]>=ipdata_i[["time"]] - # print(str(filter_)) - # print(ipdata_i[filter_,c("time_in", "time")]) - # print(min(ipdata_i[["time"]]-ipdata_i[["time_in"]])) + ipdata_i <- ipdata[, -(which(col_deg) + n_skip_cols), with = F] + ipdata_i$ID <- 1:nrow(ipdata_i) # for running coxph/cch... + if (surv_cols == 3) { + times <- sort(unique(c(ipdata$time_in, ipdata$time))) + precision <- min(diff(times)) / 2 + ipdata_i[ipdata_i$subcohort == 0, "time_in"] <- ipdata_i[ipdata_i$subcohort == 0, "time"] - precision + } else if (surv_cols == 2) { + precision <- min(diff(sort(ipdata_i$time))) / 2 # + ipdata_i$time_in <- 0 + ipdata_i[ipdata_i$subcohort == 0, "time_in"] <- ipdata_i[ipdata_i$subcohort == 0, "time"] - precision + } + ## grad and hess - bbar = control$beta_init - full_cohort_size = control$full_cohort_size[control$sites==config$site_id] - cc_prep = prepare_case_cohort(list(ipdata), control$method, full_cohort_size) + bbar <- control$beta_init + full_cohort_size <- control$full_cohort_size[control$sites == config$site_id] + cc_prep <- prepare_case_cohort(list(ipdata), control, full_cohort_size) + # logL_D1 <- grad_plk(bbar, cc_prep) # logL_D2 <- hess_plk(bbar, cc_prep) # if(config$site_id == "site11") print(bbar) # if(config$site_id == "site11") print(cc_prep) - logL_D1 <- rcpp_cc_grad_plk(beta = bbar, site_num = 1, - covariate_list = cc_prep$covariate_list, - failure_position = cc_prep$failure_position, - failure_num = cc_prep$failure_num, - risk_sets = cc_prep$risk_sets, - risk_set_weights = cc_prep$risk_set_weights) - logL_D2 <- rcpp_cc_hess_plk(beta = bbar, site_num = 1, - covariate_list = cc_prep$covariate_list, - failure_position = cc_prep$failure_position, - failure_num = cc_prep$failure_num, - risk_sets = cc_prep$risk_sets, - risk_set_weights = cc_prep$risk_set_weights) + logL_D1 <- rcpp_cc_grad_plk( + beta = bbar, site_num = 1, + covariate_list = cc_prep$covariate_list, + failure_position = cc_prep$failure_position, + failure_num = cc_prep$failure_num, + risk_sets = cc_prep$risk_sets, + risk_set_weights = cc_prep$risk_set_weights + ) + logL_D2 <- rcpp_cc_hess_plk( + beta = bbar, site_num = 1, + covariate_list = cc_prep$covariate_list, + failure_position = cc_prep$failure_position, + failure_num = cc_prep$failure_num, + risk_sets = cc_prep$risk_sets, + risk_set_weights = cc_prep$risk_set_weights + ) # if(config$site_id == "site11") stop("Me after rcpp") - ## get intermediate (sandwich meat) for robust variance est of ODACH_CC - # fit_i <- tryCatch(coxph(formula_i, data=ipdata_i, robust=T), error=function(e) NULL) - formula_i <- as.formula(paste("Surv(time_in, time, status) ~", paste(control$risk_factor[!col_deg], collapse = "+"), '+ cluster(ID)')) + ## get intermediate (sandwich meat) for robust variance est of ODACH_CC + # fit_i <- tryCatch(coxph(formula_i, data=ipdata_i, robust=T), error=function(e) NULL) + formula_i <- as.formula(paste("Surv(time_in, time, status) ~", paste(control$risk_factor[!col_deg], collapse = "+"), "+ cluster(ID)")) # print(formula_i) # print(bbar[!col_deg]) # fit_i <- survival::coxph(formula_i, data=ipdata_i, robust=T) # print(fit_i) # fit_i <- survival::coxph(formula_i, data=ipdata_i, robust=T, init=bbar[!col_deg]) # print(fit_i) - fit_i <- tryCatch(survival::coxph(formula_i, data=ipdata_i, robust=T, init=bbar[!col_deg], iter=0), error=function(e) NULL) # 20250326: init/iter trick: coxPH will only calcualte the loglikelihood function for the input init point without iterating (iter == survival::coxph.control(max.iter=0)). Since the default value for `init` is avector of zeros, this will return a vector of zeros unless `init`` is initialised. + fit_i <- tryCatch(survival::coxph(formula_i, data = ipdata_i, robust = T, init = bbar[!col_deg], iter = 0), error = function(e) NULL) # 20250326: init/iter trick: coxPH will only calcualte the loglikelihood function for the input init point without iterating (iter == survival::coxph.control(max.iter=0)). Since the default value for `init` is avector of zeros, this will return a vector of zeros unless `init`` is initialised. # print(fit_i) - - score_resid <- resid(fit_i, type = "score") # n x p matrix - S_i = matrix(0, px, px) # this is the meat in sandwich var + + score_resid <- resid(fit_i, type = "score") # n x p matrix + S_i <- matrix(0, px, px) # this is the meat in sandwich var S_i[!col_deg, !col_deg] <- crossprod(score_resid) # S_i[!col_deg, !col_deg] <- logL_D2[!col_deg, !col_deg] %*% fit_i$var %*% logL_D2[!col_deg, !col_deg] # Skhat in Yudong's note... - - derivatives <- list(bbar=bbar, - site=config$site_id, site_size = nrow(ipdata), - full_cohort_size=full_cohort_size, - logL_D1=logL_D1, logL_D2=logL_D2, - S_i = S_i) - + + derivatives <- list( + bbar = bbar, + site = config$site_id, site_size = nrow(ipdata), + full_cohort_size = full_cohort_size, + logL_D1 = logL_D1, logL_D2 = logL_D2, + S_i = S_i + ) + return(derivatives) } #' @useDynLib pda #' @title PDA surrogate estimation -#' +#' #' @usage ODACH_CC.estimate(ipdata, control, config) #' @param ipdata local data in data frame #' @param control pda control #' @param config cloud config #' @import data.table -#' +#' #' @details step-4: construct and solve surrogate logL at the master/lead site -#' @import Rcpp +#' @import Rcpp #' @return list(btilde = sol$par, Htilde = sol$hessian, site=control$mysite, site_size=nrow(ipdata)) #' @keywords internal -ODACH_CC.estimate <- function(ipdata,control,config) { +ODACH_CC.estimate <- function(ipdata, control, config) { # data sanity check ... # time <- ipdata$time # status <- ipdata$status # X <- as.matrix(ipdata[,-c(1:3)]) # n <- length(time) # px <- ncol(X) - px <- ncol(ipdata) - 3 + call_obj <- str2lang(control$outcome) + surv_cols <- length(as.list(call_obj)) - 1L # -1L for the function name + n_skip_cols <- surv_cols + 1L # +1L for the subcohort column + px <- ncol(ipdata) - n_skip_cols + # hasTies <- any(duplicated(ipdata$time)) - + # download derivatives of other sites from the cloud - # calculate 2nd order approx of the total logL + # calculate 2nd order approx of the total logL logL_all_D1 <- rep(0, px) logL_all_D2 <- matrix(0, px, px) N <- 0 - for(site_i in control$sites){ - derivatives_i <- pdaGet(paste0(site_i,'_derive'),config) + for (site_i in control$sites) { + derivatives_i <- pdaGet(paste0(site_i, "_derive"), config) logL_all_D1 <- logL_all_D1 + derivatives_i$logL_D1 logL_all_D2 <- logL_all_D2 + matrix(unlist(derivatives_i$logL_D2), px, px) N <- N + derivatives_i$site_size } - + # initial beta # bbar <- derivatives_i$b_meta bbar <- control$beta_init - full_cohort_size = control$full_cohort_size[control$sites==config$site_id] - cc_prep = prepare_case_cohort(list(ipdata), control$method, full_cohort_size) - + full_cohort_size <- control$full_cohort_size[control$sites == config$site_id] + cc_prep <- prepare_case_cohort(list(ipdata), control, full_cohort_size) + # logL at local site # logL_local <- function(beta) log_plk(beta, cc_prep) # logL_local_D1 <- function(beta) grad_plk(beta, cc_prep) # logL_local_D2 <- function(beta) hess_plk(beta, cc_prep) - logL_local <- function(beta) rcpp_cc_log_plk(beta, site_num = 1, - covariate_list = cc_prep$covariate_list, - failure_position = cc_prep$failure_position, - failure_num = cc_prep$failure_num, - risk_sets = cc_prep$risk_sets, - risk_set_weights = cc_prep$risk_set_weights) - logL_local_D1 <- function(beta) rcpp_cc_grad_plk(beta, site_num = 1, - covariate_list = cc_prep$covariate_list, - failure_position = cc_prep$failure_position, - failure_num = cc_prep$failure_num, - risk_sets = cc_prep$risk_sets, - risk_set_weights = cc_prep$risk_set_weights) - logL_local_D2 <- function(beta) rcpp_cc_hess_plk(beta, site_num = 1, - covariate_list = cc_prep$covariate_list, - failure_position = cc_prep$failure_position, - failure_num = cc_prep$failure_num, - risk_sets = cc_prep$risk_sets, - risk_set_weights = cc_prep$risk_set_weights) - + logL_local <- function(beta) { + rcpp_cc_log_plk(beta, + site_num = 1, + covariate_list = cc_prep$covariate_list, + failure_position = cc_prep$failure_position, + failure_num = cc_prep$failure_num, + risk_sets = cc_prep$risk_sets, + risk_set_weights = cc_prep$risk_set_weights + ) + } + logL_local_D1 <- function(beta) { + rcpp_cc_grad_plk(beta, + site_num = 1, + covariate_list = cc_prep$covariate_list, + failure_position = cc_prep$failure_position, + failure_num = cc_prep$failure_num, + risk_sets = cc_prep$risk_sets, + risk_set_weights = cc_prep$risk_set_weights + ) + } + logL_local_D2 <- function(beta) { + rcpp_cc_hess_plk(beta, + site_num = 1, + covariate_list = cc_prep$covariate_list, + failure_position = cc_prep$failure_position, + failure_num = cc_prep$failure_num, + risk_sets = cc_prep$risk_sets, + risk_set_weights = cc_prep$risk_set_weights + ) + } + # surrogate log-L and its gradient - logL_diff_D1 <- logL_all_D1 - logL_local_D1(bbar) # / N / n - logL_diff_D2 <- logL_all_D2 - logL_local_D2(bbar) # / N / n - logL_tilde <- function(b) -(logL_local(b) + sum(b * logL_diff_D1) + 1/2 * t(b-bbar) %*% logL_diff_D2 %*% (b-bbar)) # / n - # logL_tilde_D1 <- function(b) -(logL_local_D1(b) / n + logL_diff_D1 + logL_diff_D2 %*% (b-bbar)) - - # optimize the surrogate logL - sol <- optim(par = bbar, - fn = logL_tilde, - # gr = logL_tilde_D1, - hessian = TRUE, - method = control$optim_method, - control = list(maxit=control$optim_maxit)) - + logL_diff_D1 <- logL_all_D1 - logL_local_D1(bbar) # / N / n + logL_diff_D2 <- logL_all_D2 - logL_local_D2(bbar) # / N / n + logL_tilde <- function(b) -(logL_local(b) + sum(b * logL_diff_D1) + 1 / 2 * t(b - bbar) %*% logL_diff_D2 %*% (b - bbar)) # / n + # logL_tilde_D1 <- function(b) -(logL_local_D1(b) / n + logL_diff_D1 + logL_diff_D2 %*% (b-bbar)) + + # optimize the surrogate logL + sol <- optim( + par = bbar, + fn = logL_tilde, + # gr = logL_tilde_D1, + hessian = TRUE, + method = control$optim_method, + control = list(maxit = control$optim_maxit) + ) + # robust var estimate: see Yudong's note # setilde = sqrt(diag(solve(sol$hessian))/N) # hess of surrogate log-L at btilde, this is slightly diff than Yudong's, to avoid another iteration... - logL_tilde_D2 = logL_local_D2(bbar) + logL_diff_D2 + logL_tilde_D2 <- logL_local_D2(bbar) + logL_diff_D2 # put together - Stilde = solve(logL_tilde_D2) %*% control$S_i_sum %*% solve(logL_tilde_D2) - setilde = sqrt(diag(Stilde)) - - surr <- list(bbar=bbar, full_cohort_size=full_cohort_size, - btilde = sol$par, setilde=setilde, Htilde = sol$hessian, site=config$site_id, site_size=nrow(ipdata)) + Stilde <- solve(logL_tilde_D2) %*% control$S_i_sum %*% solve(logL_tilde_D2) + setilde <- sqrt(diag(Stilde)) + + surr <- list( + bbar = bbar, full_cohort_size = full_cohort_size, + btilde = sol$par, setilde = setilde, Htilde = sol$hessian, site = config$site_id, site_size = nrow(ipdata) + ) return(surr) } @@ -285,33 +335,38 @@ ODACH_CC.estimate <- function(ipdata,control,config) { #' @useDynLib pda #' @title PDA synthesize surrogate estimates from all sites, optional -#' +#' #' @usage ODACH_CC.synthesize(ipdata, control, config) #' @param ipdata local data in data frame #' @param control pda control #' @param config cloud config -#' +#' #' @details Optional step-4: synthesize all the surrogate est btilde_i from each site, if step-3 from all sites is broadcasted -#' @import Rcpp +#' @import Rcpp #' @return list(btilde=btilde, Vtilde=Vtilde) #' @keywords internal -ODACH_CC.synthesize <- function(ipdata,control,config) { - px <- length(control$risk_factor) +ODACH_CC.synthesize <- function(ipdata, control, config) { + call_obj <- str2lang(control$outcome) + surv_cols <- length(as.list(call_obj)) - 1L # -1L for the function name + n_skip_cols <- surv_cols + 1L # +1L for the subcohort column + px <- ncol(ipdata) - n_skip_cols K <- length(control$sites) btilde_wt_sum <- rep(0, px) - wt_sum <- rep(0, px) # cov matrix? - - for(site_i in control$sites){ - surr_i <- pdaGet(paste0(site_i,'_estimate'),config) + wt_sum <- rep(0, px) # cov matrix? + + for (site_i in control$sites) { + surr_i <- pdaGet(paste0(site_i, "_estimate"), config) btilde_wt_sum <- btilde_wt_sum + surr_i$Htilde %*% surr_i$btilde wt_sum <- wt_sum + surr_i$Htilde } - + # inv-Var weighted average est, and final Var = average Var-tilde btilde <- solve(wt_sum, btilde_wt_sum) Vtilde <- solve(wt_sum) * K - + message("all surrogate estimates synthesized, no need to broadcast! ") - return(list(btilde=btilde, - Vtilde=Vtilde)) + return(list( + btilde = btilde, + Vtilde = Vtilde + )) } diff --git a/R/case_cohort.R b/R/case_cohort.R index ad5616a..0c510de 100644 --- a/R/case_cohort.R +++ b/R/case_cohort.R @@ -1,13 +1,16 @@ - ## prepare calculation for case-cohort design at ONE site # the purpose of this function is to "pre-calculate" the weight before calculating the log-likelihood # this would accelerate the subsequent calculation of log-likelihood # currently, we only provide Prentice weight; more options will be provided later ## this is Yudong's weight_CC() in functions_CC_1.R, can take multi-site data, or single-site as a list of length 1 ## data_list contains list of ipdata, with columns: time, status, subcohort, and covariates -prepare_case_cohort <- function(data_list, method, full_cohort_size){ +prepare_case_cohort <- function(data_list, control, full_cohort_size) { # for each site, pre-calculate the failure time points, the risk sets, and the respective weights # also, remove those sites with zero events + call_obj <- str2lang(control$outcome) + surv_cols <- length(as.list(call_obj)) - 1L # -1L for the function name + + method <- control$method site_to_remove <- c() K <- length(full_cohort_size) failure_num <- rep(NA, K) @@ -16,31 +19,42 @@ prepare_case_cohort <- function(data_list, method, full_cohort_size){ risk_set_weights <- as.list(rep(NA, K)) covariate_list <- as.list(rep(NA, K)) failure_position <- as.list(rep(NA, K)) - for(k in 1:K){ + for (k in 1:K) { # prepare a list for covariates in matrix format so as to speed up computation of log partial likelihood, gradient, and hessian - covariate_list[[k]] <- as.matrix(data_list[[k]][, -c(1:3)]) + if (surv_cols == 3) { + covariate_list[[k]] <- as.matrix(data_list[[k]][, -c(1:4)]) + } else if (surv_cols == 2) { + covariate_list[[k]] <- as.matrix(data_list[[k]][, -c(1:3)]) + } # find over which position lies the failure times failure_position[[k]] <- which(data_list[[k]]$status == 1) # find failure times failure_times[[k]] <- data_list[[k]]$time[which(data_list[[k]]$status == 1)] # the number of failures failure_num[k] <- length(failure_times[[k]]) - - if(failure_num[k] == 0){ + + if (failure_num[k] == 0) { site_to_remove <- c(site_to_remove, k) - }else{ + } else { risk_size <- 0 temp_risk <- as.list(rep(NA, failure_num[k])) temp_weight <- as.list(rep(NA, failure_num[k])) - for(j in 1:failure_num[k]){ - my_risk_set1 <- which((data_list[[k]]$subcohort == 1) & (data_list[[k]]$time >= failure_times[[k]][j])) + for (j in 1:failure_num[k]) { + my_risk_set1 <- which( + (data_list[[k]]$subcohort == 1) & + (data_list[[k]]$time >= failure_times[[k]][j]) & + (data_list[[k]]$time_in <= failure_times[[k]][j]) + ) risk_size <- risk_size + length(my_risk_set1) - if(method == "Prentice"){ + if (method == "Prentice") { my_weight1 <- rep(1, length(my_risk_set1)) - if(data_list[[k]]$subcohort[which(data_list[[k]]$time == failure_times[[k]][j])] == 0){ + if (data_list[[k]]$subcohort[ + which(data_list[[k]]$time == failure_times[[k]][j]) + ] == 0 + ) { my_risk_set2 <- which(data_list[[k]]$time == failure_times[[k]][j]) my_weight2 <- 1 - }else{ + } else { my_risk_set2 <- c() my_weight2 <- c() } @@ -50,13 +64,13 @@ prepare_case_cohort <- function(data_list, method, full_cohort_size){ } risk_sets[[k]] <- temp_risk risk_set_weights[[k]] <- temp_weight - if(risk_size == 0){ + if (risk_size == 0) { site_to_remove <- c(site_to_remove, k) } } } - - if(length(site_to_remove) > 0){ + + if (length(site_to_remove) > 0) { data_list <- data_list[-site_to_remove] full_cohort_size <- full_cohort_size[-site_to_remove] failure_num <- failure_num[-site_to_remove] @@ -67,15 +81,16 @@ prepare_case_cohort <- function(data_list, method, full_cohort_size){ covariate_list <- covariate_list[-site_to_remove] K <- K - length(site_to_remove) } - - return(list(# data_list = data_list, - full_cohort_size = full_cohort_size, - covariate_list = covariate_list, - failure_position = failure_position, - failure_num = failure_num, - risk_sets = risk_sets, - risk_set_weights = risk_set_weights, - site_num=K)) + + return(list( # data_list = data_list, + full_cohort_size = full_cohort_size, + covariate_list = covariate_list, + failure_position = failure_position, + failure_num = failure_num, + risk_sets = risk_sets, + risk_set_weights = risk_set_weights, + site_num = K + )) } ## below only take input single-site ipdata... @@ -90,13 +105,13 @@ prepare_case_cohort <- function(data_list, method, full_cohort_size){ # # find failure times # failure_times <- ipdata$time[which(ipdata$status == 1)] # # the number of failures -# failure_num <- length(failure_times) -# +# failure_num <- length(failure_times) +# # risk_sets <- as.list(rep(NA, failure_num)) # risk_set_weights <- as.list(rep(NA, failure_num)) -# +# # # if have any events -# if(failure_num > 0) { +# if(failure_num > 0) { # for(j in 1:failure_num){ # my_risk_set1 <- which((ipdata$subcohort == 1) & (ipdata$time >= failure_times[j])) # # risk_size <- risk_size + length(my_risk_set1) @@ -114,7 +129,7 @@ prepare_case_cohort <- function(data_list, method, full_cohort_size){ # risk_set_weights[[j]] <- c(my_weight1, my_weight2) # } # } -# +# # return(list(full_cohort_size = full_cohort_size, # covariate = covariate, # failure_position = failure_position, @@ -122,7 +137,7 @@ prepare_case_cohort <- function(data_list, method, full_cohort_size){ # risk_sets = risk_sets, # risk_set_weights = risk_set_weights )) # } - + # this function calculate the log pseudo-likelihood for ONE site # cc_prep is the output of prepare_case_cohort() @@ -130,7 +145,7 @@ log_plk <- function(beta, cc_prep, site_num) { eta <- cc_prep$covariate_list[[site_num]] %*% beta exp_eta <- exp(eta) res <- sum(eta[cc_prep$failure_position[[site_num]]]) - + for (j in 1:cc_prep$failure_num[site_num]) { idx <- cc_prep$risk_sets[[site_num]][[j]] weights <- cc_prep$risk_set_weights[[site_num]][[j]] @@ -145,15 +160,15 @@ grad_plk <- function(beta, cc_prep, site_num) { X <- cc_prep$covariate_list[[site_num]] eta <- X %*% beta exp_eta <- exp(eta) - + grad <- colSums(X[cc_prep$failure_position[[site_num]], , drop = FALSE]) - + for (j in 1:cc_prep$failure_num[site_num]) { idx <- cc_prep$risk_sets[[site_num]][[j]] weights <- cc_prep$risk_set_weights[[site_num]][[j]] temp_w <- exp_eta[idx] * weights denom <- sum(temp_w) - weighted_X <- sweep(X[idx, , drop = FALSE], 1, temp_w, '*') + weighted_X <- sweep(X[idx, , drop = FALSE], 1, temp_w, "*") grad <- grad - colSums(weighted_X) / denom } return(grad) @@ -169,19 +184,19 @@ hess_plk <- function(beta, cc_prep, site_num) { exp_eta <- exp(eta) d <- ncol(X) H <- matrix(0, d, d) - + for (j in 1:cc_prep$failure_num[site_num]) { idx <- cc_prep$risk_sets[[site_num]][[j]] weights <- cc_prep$risk_set_weights[[site_num]][[j]] temp_w <- exp_eta[idx] * weights denom <- sum(temp_w) - + X_sub <- X[idx, , drop = FALSE] - weighted_X <- sweep(X_sub, 1, temp_w, '*') + weighted_X <- sweep(X_sub, 1, temp_w, "*") mean_vec <- colSums(weighted_X) - - sqrt_wX <- sweep(X_sub, 1, sqrt(temp_w), '*') - + + sqrt_wX <- sweep(X_sub, 1, sqrt(temp_w), "*") + H <- H + (tcrossprod(mean_vec) / (denom^2)) - (crossprod(sqrt_wX) / denom) } return(H) @@ -193,72 +208,80 @@ hess_plk <- function(beta, cc_prep, site_num) { # notice this assumes varying baseline hazard functions across sites # cc_prep is the output of prepare_case_cohort() #' @export -cch_pooled <- function(formula, data, subcoh='subcohort', site='site', variables_lev, +cch_pooled <- function(formula, data, subcoh = "subcohort", site = "site", variables_lev, full_cohort_size, method = "Prentice", optim_method = "BFGS", - var_sandwich=T){ - n = nrow(data) - site_uniq = unique(data[,site]) - mf <- model.frame(formula, data, xlev=variables_lev) - - ipdata = data.table::data.table(site=data[,site], - time=as.numeric(model.response(mf))[1:n], - status=as.numeric(model.response(mf))[-c(1:n)], - subcohort = data[,subcoh], - model.matrix(formula, mf)[,-1]) - ipdata = data.table(data.frame(ipdata)) - risk_factor = colnames(ipdata)[-c(1:4)] - + var_sandwich = T) { + n <- nrow(data) + site_uniq <- unique(data[, site]) + mf <- model.frame(formula, data, xlev = variables_lev) + + ipdata <- data.table::data.table( + site = data[, site], + time = as.numeric(model.response(mf))[1:n], + status = as.numeric(model.response(mf))[-c(1:n)], + subcohort = data[, subcoh], + model.matrix(formula, mf)[, -1] + ) + ipdata <- data.table(data.frame(ipdata)) + risk_factor <- colnames(ipdata)[-c(1:4)] + # notice here we allow data degeneration (e.g. missing categories in some site) - px = ncol(ipdata)-4 - initial_beta = rep(0, px) - names(initial_beta) = names(ipdata)[-c(1:4)] + px <- ncol(ipdata) - 4 + initial_beta <- rep(0, px) + names(initial_beta) <- names(ipdata)[-c(1:4)] # pool_fun <- function(beta) sum(sapply(site_uniq, function(site_id) # log_plk(beta, prepare_case_cohort(ipdata[site==site_id,-'site'], method, full_cohort_size[site_id])))) - - data_split <- split(ipdata, by=site, keep.by=F) - cc_prep = prepare_case_cohort(data_split, method, full_cohort_size) - K = cc_prep$site_num - pool_fun <- function(beta) { - sum(vapply(1:K, function(i) rcpp_cc_log_plk(beta, site_num = i, - covariate_list = cc_prep$covariate_list, - failure_position = cc_prep$failure_position, - failure_num = cc_prep$failure_num, - risk_sets = cc_prep$risk_sets, - risk_set_weights = cc_prep$risk_set_weights), numeric(1))) + + data_split <- split(ipdata, by = site, keep.by = F) + cc_prep <- prepare_case_cohort(data_split, method, full_cohort_size) + K <- cc_prep$site_num + pool_fun <- function(beta) { + sum(vapply(1:K, function(i) { + rcpp_cc_log_plk(beta, + site_num = i, + covariate_list = cc_prep$covariate_list, + failure_position = cc_prep$failure_position, + failure_num = cc_prep$failure_num, + risk_sets = cc_prep$risk_sets, + risk_set_weights = cc_prep$risk_set_weights + ) + }, numeric(1))) } - - result <- optim(par = initial_beta, fn = pool_fun, - control = list(fnscale = -1), method = optim_method, hessian = T) - b_pooled = result$par - + + result <- optim( + par = initial_beta, fn = pool_fun, + control = list(fnscale = -1), method = optim_method, hessian = T + ) + b_pooled <- result$par + # calculate sandwich var estimate, degenerated data columns are given 0 coefs - if(var_sandwich==T){ + if (var_sandwich == T) { block1 <- result$hessian block2 <- NULL data_split <- split(ipdata, ipdata$site) - - for(i in 1:length(site_uniq)){ + + for (i in 1:length(site_uniq)) { site_id <- site_uniq[i] - ipdata_i = data_split[[i]] - col_deg = apply(ipdata_i[,-c(1:4)],2,var)==0 # degenerated X columns... - ipdata_i = ipdata_i[,-(which(col_deg)+4),with=F] + ipdata_i <- data_split[[i]] + col_deg <- apply(ipdata_i[, -c(1:4)], 2, var) == 0 # degenerated X columns... + ipdata_i <- ipdata_i[, -(which(col_deg) + 4), with = F] # use coxph(Surv(time_in, time, status)~.) to do cch... precision <- min(diff(sort(ipdata_i$time))) / 2 # - ipdata_i$time_in = 0 + ipdata_i$time_in <- 0 ipdata_i[ipdata_i$subcohort == 0, "time_in"] <- ipdata_i[ipdata_i$subcohort == 0, "time"] - precision - - formula_i <- as.formula(paste("Surv(time_in, time, status) ~", paste(risk_factor[!col_deg], collapse = "+"), '+ cluster(ID)')) - cch_i <- tryCatch(coxph(formula_i, data=cbind(ID=1:nrow(ipdata_i), ipdata_i), init=b_pooled[!col_deg], iter=0), error=function(e) NULL) - score_resid <- resid(cch_i, type = "score") # n x p matrix - S_i = matrix(0, px, px) # this is the meat in sandwich var + + formula_i <- as.formula(paste("Surv(time_in, time, status) ~", paste(risk_factor[!col_deg], collapse = "+"), "+ cluster(ID)")) + cch_i <- tryCatch(coxph(formula_i, data = cbind(ID = 1:nrow(ipdata_i), ipdata_i), init = b_pooled[!col_deg], iter = 0), error = function(e) NULL) + score_resid <- resid(cch_i, type = "score") # n x p matrix + S_i <- matrix(0, px, px) # this is the meat in sandwich var S_i[!col_deg, !col_deg] <- crossprod(score_resid) - + block2[[i]] <- S_i } - + var <- solve(block1) %*% Reduce("+", block2) %*% solve(block1) result$var <- var # this is the output for variance estimates } - + return(result) } diff --git a/R/pda.R b/R/pda.R index b7f83de..3c1cab1 100644 --- a/R/pda.R +++ b/R/pda.R @@ -534,15 +534,24 @@ pda <- function(ipdata=NULL,site_id,control=NULL,dir=NULL,uri=NULL,secret=NULL, } } else if(control$model=='ODACH_CC'){ if (!is.null(ipdata)){ - ipdata = data.table::data.table(time=as.numeric(model.response(mf))[1:n], - status=as.numeric(model.response(mf))[-c(1:n)], - subcohort = ipdata$subcohort, - # sampling_weight = ipdata$sampling_weight, - model.matrix(formula, mf)[,-1]) + if(ncol(model.response(mf))==3){ + ipdata = data.table::data.table(time_in=as.numeric(model.response(mf)[,1]), + time=as.numeric(model.response(mf)[,2]), + status=as.numeric(model.response(mf)[,3]), + subcohort = ipdata$subcohort, + # sampling_weight = ipdata$sampling_weight, + model.matrix(formula, mf)[,-1]) + } else if(ncol(model.response(mf))==2){ + ipdata = data.table::data.table(time=as.numeric(model.response(mf))[1:n], + status=as.numeric(model.response(mf))[-c(1:n)], + subcohort = ipdata$subcohort, + # sampling_weight = ipdata$sampling_weight, + model.matrix(formula, mf)[,-1]) + } # convert irregular risk factor names, e.g. `Group (A,B,C) B` to Group..A.B.C..B # this should (and will) apply to all other models... ipdata = data.table(data.frame(ipdata)) - control$risk_factor = colnames(ipdata)[-c(1:3)] + control$risk_factor = colnames(ipdata)[-c(1:(ncol(model.response(mf))+1))] } } From 558023a29513f8bd7ac37c1a8550303f39705eb5 Mon Sep 17 00:00:00 2001 From: Ali Date: Sat, 22 Nov 2025 11:33:00 +0100 Subject: [PATCH 09/10] remove comments --- R/ODACH_CC.R | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/R/ODACH_CC.R b/R/ODACH_CC.R index ce75ae5..d0d492d 100644 --- a/R/ODACH_CC.R +++ b/R/ODACH_CC.R @@ -85,15 +85,7 @@ ODACH_CC.initialize <- function(ipdata, control, config) { formula_i <- as.formula(paste("Surv(time_in, time, status) ~", paste(control$risk_factor[!col_deg], collapse = "+"), "+ cluster(ID)")) fit_i <- tryCatch(survival::coxph(formula_i, data = ipdata_i, robust = T), error = function(e) NULL) - # if (sum(ipdata_i$time<=ipdata_i$time_in) > 0) { - # print(ipdata_i[time<=time_in, c("ID", "time", "time_in", "status")]) - # } - - # print(fit_i) - # print(formula_i) - # w <- warnings() - # print(w) - # stop() + if (!is.null(fit_i)) { # for degenerated X, coef=0, var=Inf bhat_i <- rep(0, px) From da976b86ba2ff669f080b707a6759dad29fd868d Mon Sep 17 00:00:00 2001 From: Ali Date: Mon, 15 Dec 2025 11:58:30 +0100 Subject: [PATCH 10/10] remove merge headers --- R/case_cohort.R | 10 ---------- R/pda.R | 4 ---- 2 files changed, 14 deletions(-) diff --git a/R/case_cohort.R b/R/case_cohort.R index 47a313a..0fcb439 100644 --- a/R/case_cohort.R +++ b/R/case_cohort.R @@ -4,12 +4,7 @@ # currently, we only provide Prentice weight; more options will be provided later ## this is Yudong's weight_CC() in functions_CC_1.R, can take multi-site data, or single-site as a list of length 1 ## data_list contains list of ipdata, with columns: time, status, subcohort, and covariates -<<<<<<< HEAD -#' @keywords internal -prepare_case_cohort <- function(data_list, method, full_cohort_size){ -======= prepare_case_cohort <- function(data_list, control, full_cohort_size) { ->>>>>>> left_truncation # for each site, pre-calculate the failure time points, the risk sets, and the respective weights # also, remove those sites with zero events call_obj <- str2lang(control$outcome) @@ -215,13 +210,8 @@ hess_plk <- function(beta, cc_prep, site_num) { # this function fits Cox PH to case-cohort (survival::cch) with the pooled multi-site data # notice this assumes varying baseline hazard functions across sites # cc_prep is the output of prepare_case_cohort() -<<<<<<< HEAD #' @keywords internal cch_pooled <- function(formula, data, subcoh='subcohort', site='site', variables_lev, -======= -#' @export -cch_pooled <- function(formula, data, subcoh = "subcohort", site = "site", variables_lev, ->>>>>>> left_truncation full_cohort_size, method = "Prentice", optim_method = "BFGS", var_sandwich = T) { n <- nrow(data) diff --git a/R/pda.R b/R/pda.R index c34d46f..897fb96 100644 --- a/R/pda.R +++ b/R/pda.R @@ -380,11 +380,7 @@ pdaCatalog <- function(task=c('Regression', pda <- function(ipdata=NULL,site_id,control=NULL,dir=NULL,uri=NULL,secret=NULL, upload_without_confirm=F, silent_message=F, digits=16, hosdata=NULL # for dGEM -<<<<<<< HEAD ){ -======= - ){ ->>>>>>> left_truncation config <- getCloudConfig(site_id,dir,uri,secret,silent_message) mymessage <- function(mes, silent=silent_message) if(silent==F) message(mes) files <- pdaList(config)