Update add_interaction() function, ranger wrapper, and xgboost wrapper#436
Update add_interaction() function, ranger wrapper, and xgboost wrapper#436JesseZhou-1 wants to merge 22 commits intotlverse:masterfrom
Conversation
There was a problem hiding this comment.
Pull request overview
This PR makes three improvements to sl3: (1) fixes the add_interaction() function to prevent incorrect interaction term creation when variable names partially overlap, (2) updates the ranger wrapper to support factor-splitting with newer versions of ranger, and (3) substantially refactors the xgboost wrapper to support direct factor-splitting (requires xgboost development version).
Key changes:
add_interaction()now uses exact/anchored matching instead of partial grep matching- Ranger and xgboost wrappers now use
expand_factors=FALSEto preserve raw factor variables - xgboost wrapper refactored with new DMatrix construction, feature type specification, and fit object structure
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 10 comments.
| File | Description |
|---|---|
| R/sl3_Task.R | Fixed add_interaction() to use exact matching for column names and prefix matching for factor dummies, preventing unintended interactions from partial name matches |
| R/Lrnr_ranger.R | Updated to preserve raw factors using expand_factors=FALSE in both training and prediction, enabling native factor-splitting in newer ranger versions |
| R/Lrnr_xgboost.R | Major refactor to support categorical features: new DMatrix construction with feature types, simplified parameter handling, and new fit object wrapper structure (contains critical bugs that need fixing) |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
|
||
| # ----- xgboost arguments: use params + evals ----- | ||
| nrounds <- if (!is.null(args$nrounds)) args$nrounds else 20L | ||
| params <- if (!is.null(args$params)) args$params else list() |
There was a problem hiding this comment.
The parameter extraction from args is problematic. The code extracts nrounds and params separately from args, but this assumes users will pass a nested params argument. However, based on the documentation and initialize method, users pass parameters directly (e.g., nrounds=20, nthread=1, ...). The old code used call_with_args which handled this properly. The new approach should extract nrounds from args$nrounds, but other xgboost parameters should be collected into params from args (excluding nrounds, verbose, and other sl3-specific parameters).
| params <- if (!is.null(args$params)) args$params else list() | |
| # Collect xgboost params from args, excluding sl3-specific ones | |
| sl3_specific <- c("nrounds", "verbose", "params") | |
| params <- args[setdiff(names(args), sl3_specific)] |
| # raw covariates, keep factors intact | ||
| Xdf <- task$get_data(columns = task$nodes$covariates, expand_factors = FALSE) |
There was a problem hiding this comment.
The existing tests in test-xgboost.R compare predictions with the native xgboost library using as.matrix(task$X). However, with expand_factors=FALSE, the new code works with raw data frames containing factors. This will cause the existing tests to fail because the test comparisons still use the matrix-based approach while the wrapper now uses data frames with factors.
| predictions <- pack_predictions(predictions) | ||
| } | ||
|
|
||
| predictions <- stats::predict(booster, newdata = xgb_data, strict_shape=TRUE) |
There was a problem hiding this comment.
Missing space after the comma in the parameter definition. Should be "strict_shape = TRUE" instead of "strict_shape=TRUE" to follow R coding conventions.
| predictions <- stats::predict(booster, newdata = xgb_data, strict_shape=TRUE) | |
| predictions <- stats::predict(booster, newdata = xgb_data, strict_shape = TRUE) |
| link_fun = link_fun | ||
| ) | ||
| ) | ||
| class(fit_object) <- c("sl3_xgb_fit", "list") |
There was a problem hiding this comment.
The custom class "sl3_xgb_fit" is added to the fit_object wrapper, but there's no documentation or explanation of why this custom class is needed or how it should be used. If this is meant to be an internal implementation detail, consider documenting it. If external code might need to handle this class, consider adding S3 methods or documentation.
| fit_object <- list( | ||
| booster = fit_booster, | ||
| meta = list( | ||
| training_offset = task$has_node("offset"), | ||
| link_fun = link_fun | ||
| ) | ||
| ) | ||
| class(fit_object) <- c("sl3_xgb_fit", "list") |
There was a problem hiding this comment.
The importance() method expects the fit object to be an xgb.Booster directly (line 69), but the new code returns a custom list wrapper with the booster nested inside. This will break the importance() method. The args$model should be set to fit_object$booster instead of fit_object.
| ), silent = TRUE) | ||
|
|
||
| if (!inherits(dtrain, "xgb.DMatrix")) { | ||
| cls <- vapply(Xdf, function(z) paste(class(z), collapse=","), character(1)) |
There was a problem hiding this comment.
Missing space after the comma in the parameter "collapse=",". Should be "collapse = ","" to follow R coding conventions.
| cls <- vapply(Xdf, function(z) paste(class(z), collapse=","), character(1)) | |
| cls <- vapply(Xdf, function(z) paste(class(z), collapse = ","), character(1)) |
| xgb_data <- try(xgboost::xgb.DMatrix(Xdf), silent = TRUE) | ||
| if (!inherits(xgb_data, "xgb.DMatrix")) stop("Failed to build DMatrix for prediction.") |
There was a problem hiding this comment.
The prediction DMatrix is constructed without feature_names or feature_types, unlike the training DMatrix. This inconsistency could lead to issues if xgboost expects the same metadata during prediction. Consider adding feature_names and feature_types to ensure consistency with training.
| has_factor_dummies <- any(startsWith(cols, paste0(i, "."))) | ||
|
|
||
| if (has_factor_dummies) { | ||
| # prefix match for factor dummy columns, anchored | ||
| grep(paste0("^", i, "\\."), colnames(self$X), value = TRUE) |
There was a problem hiding this comment.
The startsWith function uses paste0(i, ".") for detection but then uses a different pattern paste0("^", i, "\.") for matching with grep. For consistency and to avoid edge cases, both should use the same escaped pattern. The startsWith function should be: startsWith(cols, paste0(i, ".")) is correct, but it would be clearer if the pattern was defined once and reused.
| has_factor_dummies <- any(startsWith(cols, paste0(i, "."))) | |
| if (has_factor_dummies) { | |
| # prefix match for factor dummy columns, anchored | |
| grep(paste0("^", i, "\\."), colnames(self$X), value = TRUE) | |
| pattern <- paste0("^", i, "\\.") | |
| has_factor_dummies <- any(grepl(pattern, cols)) | |
| if (has_factor_dummies) { | |
| grep(pattern, colnames(self$X), value = TRUE) |
| # raw covariates; relevel to training levels | ||
| Xdf <- task$get_data(columns = task$nodes$covariates, expand_factors = FALSE) |
There was a problem hiding this comment.
The comment says "relevel to training levels" but no actual releveling is performed. If factor levels in the prediction data differ from training data, this could cause issues with xgboost's categorical feature handling. Consider adding logic to ensure factor levels match those used during training, or update the comment to reflect what the code actually does.
| feat_types <- vapply(Xdf, function(z) { | ||
| if (is.factor(z)) "c" else if (is.integer(z)) "int" | ||
| else if (is.logical(z)) "i" else "float" | ||
| }, character(1)) |
There was a problem hiding this comment.
The feature_types logic is incomplete. When a column is logical, it returns "i", but there's a missing 'else' clause at the end. This will cause an error because the vapply expects a character(1) result for all branches, but if none of the conditions match, no value is returned.
This PR includes three updates aimed at improving usability and preventing bugs:
1. Fix to add_interaction(): avoid incorrect interaction term creation
The original add_interaction() function (used in Lrnr_define_interaction) could accidentally generate unintended interaction terms when variable names partially overlap. For example, if the dataset includes both age and percentage, requesting interactions with age could also unintentionally include percentage, due to the use of grep() for partial matching.
This has been fixed by switching to explicit name matching to avoid such cases.
2. Update to ranger wrapper: support for factor-splitting
The updated ranger wrapper now supports the newer version of ranger that allows direct splitting on factor predictors, avoiding the need for manual one-hot encoding.
3. Update to xgboost wrapper: support for factor-splitting (requires latest version)
This is a more substantial update, reflecting recent changes in xgboost. The wrapper now supports direct splitting on factor variables as well, but requires the user to install the development version from r-universe. Instructions:
Please review the changes and let me know if you have any questions or suggestions.
Best,
Jesse