-
Notifications
You must be signed in to change notification settings - Fork 45
Update add_interaction() function, ranger wrapper, and xgboost wrapper #436
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
92b52c4
98cc84f
99e4430
5380db4
9959dfa
39dac15
eb69493
e953d7a
454d2df
71769fb
56c72a9
f5935b9
3df1c23
0dcf613
97a1a7f
29b3364
e7b29fe
e89db12
4deae51
97e76c7
74c53ef
1f7073b
6c85ecb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -80,121 +80,123 @@ Lrnr_xgboost <- R6Class( | |||||||||
| "offset", "importance" | ||||||||||
| ), | ||||||||||
| .train = function(task) { | ||||||||||
| # Safe helper for %||% | ||||||||||
| `%||%` <- function(a, b) if (!is.null(a)) a else b | ||||||||||
|
|
||||||||||
| args <- self$params | ||||||||||
|
|
||||||||||
|
|
||||||||||
| # verbosity | ||||||||||
| verbose <- args$verbose | ||||||||||
| if (is.null(verbose)) { | ||||||||||
| verbose <- getOption("sl3.verbose") | ||||||||||
| } | ||||||||||
| if (is.null(verbose)) verbose <- getOption("sl3.verbose") | ||||||||||
| args$verbose <- as.integer(verbose) | ||||||||||
|
|
||||||||||
| # set up outcome | ||||||||||
| # outcome | ||||||||||
| outcome_type <- self$get_outcome_type(task) | ||||||||||
| Y <- outcome_type$format(task$Y) | ||||||||||
| if (outcome_type$type == "categorical") { | ||||||||||
| Y <- as.numeric(Y) - 1 | ||||||||||
| if (outcome_type$type == "categorical") Y <- as.numeric(Y) - 1L | ||||||||||
|
|
||||||||||
| # raw covariates, keep factors intact | ||||||||||
| Xdf <- task$get_data(columns = task$nodes$covariates, expand_factors = FALSE) | ||||||||||
|
|
||||||||||
| # (optional but recommended) explicit feature types | ||||||||||
| feat_types <- vapply(Xdf, function(z) { | ||||||||||
| if (is.factor(z)) "c" else if (is.integer(z)) "int" | ||||||||||
| else if (is.logical(z)) "i" else "float" | ||||||||||
| }, character(1)) | ||||||||||
|
Comment on lines
+102
to
+105
|
||||||||||
|
|
||||||||||
| # DMatrix | ||||||||||
| dtrain <- try(xgboost::xgb.DMatrix( | ||||||||||
| data = Xdf, label = Y, | ||||||||||
| feature_names = colnames(Xdf), | ||||||||||
| feature_types = feat_types | ||||||||||
| ), silent = TRUE) | ||||||||||
|
|
||||||||||
| if (!inherits(dtrain, "xgb.DMatrix")) { | ||||||||||
| cls <- vapply(Xdf, function(z) paste(class(z), collapse=","), character(1)) | ||||||||||
|
||||||||||
| cls <- vapply(Xdf, function(z) paste(class(z), collapse=","), character(1)) | |
| cls <- vapply(Xdf, function(z) paste(class(z), collapse = ","), character(1)) |
Copilot
AI
Dec 16, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The parameter extraction from args is problematic. The code extracts nrounds and params separately from args, but this assumes users will pass a nested params argument. However, based on the documentation and initialize method, users pass parameters directly (e.g., nrounds=20, nthread=1, ...). The old code used call_with_args which handled this properly. The new approach should extract nrounds from args$nrounds, but other xgboost parameters should be collected into params from args (excluding nrounds, verbose, and other sl3-specific parameters).
| params <- if (!is.null(args$params)) args$params else list() | |
| # Collect xgboost params from args, excluding sl3-specific ones | |
| sl3_specific <- c("nrounds", "verbose", "params") | |
| params <- args[setdiff(names(args), sl3_specific)] |
Copilot
AI
Dec 16, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The custom class "sl3_xgb_fit" is added to the fit_object wrapper, but there's no documentation or explanation of why this custom class is needed or how it should be used. If this is meant to be an internal implementation detail, consider documenting it. If external code might need to handle this class, consider adding S3 methods or documentation.
Copilot
AI
Dec 16, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The importance() method expects the fit object to be an xgb.Booster directly (line 69), but the new code returns a custom list wrapper with the booster nested inside. This will break the importance() method. The args$model should be set to fit_object$booster instead of fit_object.
Copilot
AI
Dec 16, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment says "relevel to training levels" but no actual releveling is performed. If factor levels in the prediction data differ from training data, this could cause issues with xgboost's categorical feature handling. Consider adding logic to ensure factor levels match those used during training, or update the comment to reflect what the code actually does.
Copilot
AI
Dec 16, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The prediction DMatrix is constructed without feature_names or feature_types, unlike the training DMatrix. This inconsistency could lead to issues if xgboost expects the same metadata during prediction. Consider adding feature_names and feature_types to ensure consistency with training.
Copilot
AI
Dec 16, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing space after the comma in the parameter definition. Should be "strict_shape = TRUE" instead of "strict_shape=TRUE" to follow R coding conventions.
| predictions <- stats::predict(booster, newdata = xgb_data, strict_shape=TRUE) | |
| predictions <- stats::predict(booster, newdata = xgb_data, strict_shape = TRUE) |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -205,8 +205,23 @@ sl3_Task <- R6Class( | |||||||||||||||||||||
| } else { | ||||||||||||||||||||||
| # match interaction terms to X | ||||||||||||||||||||||
| Xmatch <- lapply(int, function(i) { | ||||||||||||||||||||||
| grep(i, colnames(self$X), value = TRUE) | ||||||||||||||||||||||
| }) | ||||||||||||||||||||||
| cols <- colnames(self$X) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # detect if 'i' is represented by factor dummies in the design | ||||||||||||||||||||||
| has_factor_dummies <- any(startsWith(cols, paste0(i, "."))) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if (has_factor_dummies) { | ||||||||||||||||||||||
| # prefix match for factor dummy columns, anchored | ||||||||||||||||||||||
| grep(paste0("^", i, "\\."), colnames(self$X), value = TRUE) | ||||||||||||||||||||||
|
Comment on lines
+211
to
+215
|
||||||||||||||||||||||
| has_factor_dummies <- any(startsWith(cols, paste0(i, "."))) | |
| if (has_factor_dummies) { | |
| # prefix match for factor dummy columns, anchored | |
| grep(paste0("^", i, "\\."), colnames(self$X), value = TRUE) | |
| pattern <- paste0("^", i, "\\.") | |
| has_factor_dummies <- any(grepl(pattern, cols)) | |
| if (has_factor_dummies) { | |
| grep(pattern, colnames(self$X), value = TRUE) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The existing tests in test-xgboost.R compare predictions with the native xgboost library using as.matrix(task$X). However, with expand_factors=FALSE, the new code works with raw data frames containing factors. This will cause the existing tests to fail because the test comparisons still use the matrix-based approach while the wrapper now uses data frames with factors.