-
Notifications
You must be signed in to change notification settings - Fork 20
Init VariableImportance class #129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
09ec7a8
11d5b8b
5ddfaf6
bbfb669
a27242f
8356110
217d354
a95b882
cbe284c
c96bf11
ebfc9b9
4d06765
d702f60
bc7e7da
7fd626a
ee49826
7879e6a
68c6f6a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,6 +8,82 @@ | |
| # https://rconsortium.github.io/S7/articles/classes-objects.html?q=computed#computed-properties | ||
| # https://utf8-icons.com/ | ||
|
|
||
| # %% VariableImportance ---- | ||
| #' @title VariableImportance | ||
| #' | ||
| #' @description | ||
| #' Class for variable importance objects. Allows for one or more variable importance measures, | ||
| #' stored in a data.table with columns "variable", and at least one | ||
| #' more column with a descriptive name. | ||
| #' | ||
| #' @author EDG | ||
| #' @noRd | ||
| VariableImportance <- new_class( | ||
| name = "VariableImportance", | ||
| properties = list( | ||
| data = class_data.table | ||
| ), | ||
| validator = function(self) { | ||
| # Must include at least two columns | ||
| if (NCOL(self@data) < 2L) { | ||
| cli::cli_abort( | ||
| "Variable importance data must include at least two columns: 'variable' and at least one importance measure." | ||
| ) | ||
| } | ||
| # Must include column "variable" of type character | ||
| if (!"variable" %in% names(self@data)) { | ||
| cli::cli_abort( | ||
| "Variable importance data must include a 'variable' column." | ||
| ) | ||
| } | ||
| if (!is.character(self@data[["variable"]])) { | ||
| cli::cli_abort("Column 'variable' must be of type character.") | ||
| } | ||
| # All other columns must be numeric | ||
| other_cols <- setdiff(names(self@data), "variable") | ||
| if (!all(self@data[, sapply(.SD, is.numeric), .SDcols = other_cols])) { | ||
| cli::cli_abort( | ||
| "All columns other than 'variable' must be numeric." | ||
| ) | ||
| } | ||
| # Number of rows will be checked by Supervised to be at least as many as | ||
| # the number of predictors. | ||
| } | ||
| ) # /rtemis::VariableImportance | ||
|
|
||
|
|
||
| # %% repr.VariableImportance ---- | ||
| method(repr, VariableImportance) <- function(x, pad = 0L, output_type = NULL) { | ||
| output_type <- get_output_type(output_type) | ||
| # "N variable importance measures for M predictors" | ||
| n_m <- NCOL(x@data) - 1L | ||
| paste0( | ||
| repr_S7name("VariableImportance", pad = pad, output_type = output_type), | ||
| strrep(" ", pad), | ||
| fmt(n_m, col = highlight_col, bold = TRUE, output_type = output_type), | ||
| ngettext( | ||
| n_m, | ||
| " variable importance measure for ", | ||
| " variable importance measures for " | ||
| ), | ||
| fmt( | ||
| NROW(x@data), | ||
| col = highlight_col, | ||
| bold = TRUE, | ||
| output_type = output_type | ||
| ), | ||
| ngettext(NROW(x@data), " predictor", " predictors") | ||
| ) | ||
| } # /rtemis::repr.VariableImportance | ||
|
|
||
|
|
||
| # %% print.VariableImportance ---- | ||
| method(print, VariableImportance) <- function(x, output_type = NULL, ...) { | ||
| cat(repr(x, output_type = output_type), "\n") | ||
| invisible(x) | ||
| } # /rtemis::print.VariableImportance | ||
|
|
||
|
|
||
| # Plot methods | ||
| # Supervised: plot_varimp | ||
| # SupervisedRes: plot_varimp, plot_metric | ||
|
|
@@ -45,7 +121,7 @@ Supervised <- new_class( | |
| metrics_validation = Metrics | NULL, | ||
| metrics_test = Metrics | NULL, | ||
| xnames = class_character, | ||
| varimp = class_any, | ||
| varimp = VariableImportance | NULL, | ||
| question = class_character | NULL, | ||
| extra = class_any, | ||
| session_info = class_any | ||
|
|
@@ -159,7 +235,7 @@ method(predict, Supervised) <- function(object, newdata, verbosity = 1L, ...) { | |
| model = object@model, | ||
| newdata = newdata, | ||
| type = object@type, | ||
| ... | ||
| verbosity = verbosity | ||
| ) | ||
| } # /rtemis::predict.Supervised | ||
|
|
||
|
|
@@ -1123,7 +1199,7 @@ SupervisedRes <- new_class( | |
| metrics_training = MetricsRes, | ||
| metrics_test = MetricsRes, | ||
| xnames = class_character, | ||
| varimp = class_any, | ||
| varimp = class_list | NULL, | ||
| question = class_character | NULL, | ||
| extra = class_any, | ||
| session_info = class_any | ||
|
|
@@ -1946,6 +2022,7 @@ method(plot_metric, SupervisedRes) <- function( | |
| # %% plot_varimp.Supervised ---- | ||
| method(plot_varimp, Supervised) <- function( | ||
| x, | ||
| measure = NULL, | ||
| theme = choose_theme(getOption("rtemis_theme")), | ||
| filename = NULL, | ||
| ... | ||
|
|
@@ -1954,13 +2031,20 @@ method(plot_varimp, Supervised) <- function( | |
| msg(highlight2("No variable importance available.")) | ||
| return(invisible()) | ||
| } | ||
| draw_varimp(x@varimp, theme = theme, filename = filename, ...) | ||
| if (is.null(measure)) { | ||
| vi <- x@varimp@data[[2L]] | ||
| } else { | ||
| vi <- x@varimp@data[[measure]] | ||
| } | ||
| names(vi) <- x@varimp@data[["variable"]] | ||
| draw_varimp(vi, theme = theme, filename = filename, ...) | ||
| } # /rtemis::plot_varimp.Supervised | ||
|
|
||
|
|
||
| # %% plot_varimp.SupervisedRes ---- | ||
| method(plot_varimp, SupervisedRes) <- function( | ||
| x, | ||
| measure = NULL, | ||
| ylab = NULL, | ||
| summarize_fn = "mean", | ||
| show_top = 20L, | ||
|
|
@@ -1974,34 +2058,33 @@ method(plot_varimp, SupervisedRes) <- function( | |
| } | ||
| check_inherits(summarize_fn, "character") | ||
|
|
||
| # ! Variable importance may be returned in different order in each resample ! | ||
| # Order varimp vectors by variable names | ||
| # First, check each varimp vector is named | ||
| if (!all(sapply(x@varimp, function(z) !is.null(names(z))))) { | ||
| cli::cli_abort("Variable importance elements must be named vectors.") | ||
| } | ||
| # Not every variable gets a variable importance score necessarily | ||
| # Each varimp vector as a one row data.table in order to rbindlist them, filling in NAs as needed. | ||
| # x@varimp[[i]] may be named vector or data.frame | ||
| varimp_dt <- lapply(x@varimp, function(z) { | ||
| as.data.table(as.list(z), keep.rownames = TRUE) | ||
| # Extract named numeric vectors from each VariableImportance object. | ||
| # Not every variable gets a score in every resample, so rbindlist with fill. | ||
| varimp_list <- lapply(x@varimp, function(z) { | ||
| vi <- if (is.null(measure)) z@data[[2L]] else z@data[[measure]] | ||
| names(vi) <- z@data[["variable"]] | ||
| as.data.table(as.list(vi)) | ||
| }) | ||
|
|
||
| varimp <- rbindlist(varimp_dt, use.names = TRUE, fill = TRUE) | ||
| # Convert NA values to 0 | ||
| varimp <- rbindlist(varimp_list, use.names = TRUE, fill = TRUE) | ||
| # Missing scores (variable absent in a resample) treated as 0 | ||
| setDF(varimp) | ||
| varimp[is.na(varimp)] <- 0 | ||
| # Summarize variable importance | ||
| # Summarize and sort | ||
| varimp_summary <- apply(varimp, 2, summarize_fn) | ||
| # Sort columns by descending variable importance | ||
| varimp_sorted <- varimp_summary[order(-varimp_summary)] | ||
| if (length(varimp_sorted) > show_top) { | ||
| varimp_sorted <- varimp_sorted[seq_len(show_top)] | ||
| } | ||
| # ylab | ||
| if (is.null(ylab)) { | ||
| measure_name <- if (is.null(measure)) { | ||
| names(x@varimp[[1L]]@data)[2L] | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| } else { | ||
| measure | ||
| } | ||
| ylab <- paste0( | ||
| labelify(paste(summarize_fn, "Variable Importance")), | ||
| labelify(paste(summarize_fn, measure_name)), | ||
| "\n(across ", | ||
| desc(x@outer_resampler), | ||
| ")" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,7 +9,7 @@ | |
| #' A simple `plotly` wrapper to plot horizontal barplots, sorted by value, | ||
| #' which can be used to visualize variable importance, model coefficients, etc. | ||
| #' | ||
| #' @param x Numeric vector: Input. | ||
| #' @param x Numeric vector (or coercible to numeric): Input. | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| #' @param names Vector, string: Names of features. | ||
| #' @param main Character: Main title. | ||
| #' @param type Character: "bar" or "line". | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,14 +23,12 @@ method(train_, GAMHyperparameters) <- function( | |
| x, | ||
| weights = NULL, | ||
| dat_validation = NULL, | ||
| execution_config = setup_ExecutionConfig(), | ||
| verbosity = 1L | ||
| ) { | ||
| # Dependencies ---- | ||
| check_dependencies("mgcv") | ||
|
|
||
| # Checks ---- | ||
| check_is_S7(hyperparameters, GAMHyperparameters) | ||
|
|
||
| # Hyperparameters ---- | ||
| # Hyperparameters must be either untunable or frozen by `train`. | ||
| if (needs_tuning(hyperparameters)) { | ||
|
|
@@ -125,7 +123,8 @@ method(train_, GAMHyperparameters) <- function( | |
| method(predict_super, class_gam) <- function( | ||
| model, | ||
| newdata, | ||
| type = NULL | ||
| type = NULL, | ||
| verbosity = 0L | ||
| ) { | ||
| out <- predict(object = model, newdata = newdata, type = "response") | ||
| if (model[["family"]][["family"]] == "binomial") { | ||
|
|
@@ -137,30 +136,34 @@ method(predict_super, class_gam) <- function( | |
|
|
||
|
|
||
| # %% varimp_super.class_gam ---- | ||
| #' Get coefficients from GAM model | ||
| #' Get variable importance from GAM model | ||
| #' | ||
| #' Variable importance for GAM is estimated as the variance of each predictor's partial effect, | ||
| #' obtained via predict(model, type = "terms"). This measures each smooth term's contribution to | ||
| #' the variance of the fitted values. Values are normalized to sum to one, representing each | ||
| #' predictor's proportion of total predicted variance. This approach is computationally efficient | ||
| #' (no refitting required) and analogous to importance measures in tree-based methods. It assumes | ||
| #' approximate uncorrelatedness of partial effects, which penalized smooths tend to satisfy. For | ||
| #' models with high concurvity, consider hierarchical partitioning of R² (e.g. via the gam.hp | ||
| #' package) as an alternative. | ||
| #' | ||
| #' @param model mgcv gam model. | ||
| #' | ||
| #' @keywords internal | ||
| #' @noRd | ||
| method(varimp_super, class_gam) <- function( | ||
| model, | ||
| type = c("p-value", "edf", "coefficients") | ||
| type = c("partial_effect", "F-test") | ||
| ) { | ||
| type <- match.arg(type) | ||
| if (type == "p-value") { | ||
| # Get parametric and smooth term p-values | ||
| summary_ <- summary(model) | ||
| # Exclude intercept | ||
| -log10(c( | ||
| summary_[["s.table"]][, "p-value"], | ||
| summary_[["p.table"]][, ncol(summary_[["p.table"]])][-1] | ||
| )) | ||
| } else if (type == "edf") { | ||
| summary(model)[["s.table"]][, "edf"] | ||
| } else if (type == "coefficients") { | ||
| coef(model) | ||
| } | ||
| peff <- predict(model, type = "terms") | ||
| vi <- apply(peff, 2, var) | ||
| npeff <- vi / sum(vi) # normalized importance | ||
| VariableImportance( | ||
| data.table( | ||
| variable = names(npeff), | ||
| Partial_Effect_Variance = unname(npeff) | ||
| ) | ||
| ) | ||
|
Comment on lines
+161
to
+166
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The column name for the importance measure is hardcoded to "Coefficient". However, this method supports multiple types of importance measures including "p-value" and "edf". Using the `type` argument as the column name would keep the label consistent with the measure that is actually returned, for example:
.vi_dt <- data.table(variable = names(.coef), importance = unname(.coef))
setnames(.vi_dt, "importance", type)
VariableImportance(.vi_dt) |
||
| } # /rtemis::varimp_super.gam | ||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR title suggests initializing a
`VariableImportance` class, but the changes here are primarily cleanup (removing redundant `check_is_S7()` calls), small documentation/comment tweaks, and a version/date bump. Please either update the PR title/description to match the actual changes or include the missing `VariableImportance` implementation changes in this PR.