Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: rtemis
Version: 1.0.0
Version: 1.0.1
Title: Machine Learning and Visualization
Date: 2026-03-14
Date: 2026-04-03
Comment on lines +2 to +4

Copilot AI Apr 5, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR title suggests initializing a VariableImportance class, but the changes here are primarily cleanup (removing redundant check_is_S7() calls), small documentation/comment tweaks, and a version/date bump. Please either update the PR title/description to match the actual changes or include the missing VariableImportance implementation changes in this PR.

Copilot uses AI. Check for mistakes.
Authors@R: person(given = "E.D.", family = "Gennatas", role = c("aut", "cre", "cph"),
email = "gennatas@gmail.com", comment = c(ORCID = "0000-0001-9280-3609"))
Description: Machine learning and visualization package with an 'S7' backend
Expand Down
6 changes: 3 additions & 3 deletions R/00_S7init.R
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ train_ <- new_generic(
x,
weights = NULL,
dat_validation = NULL,
verbosity = 1L,
...
execution_config = setup_ExecutionConfig(),
verbosity = 1L
) {
S7_dispatch()
}
Expand All @@ -158,7 +158,7 @@ train_ <- new_generic(
predict_super <- new_generic(
"predict_super",
"model",
function(model, newdata, type = NULL, ...) {
function(model, newdata, type = NULL, verbosity = 0L) {
S7_dispatch()
}
) # /rtemis::predict_super
Expand Down
21 changes: 0 additions & 21 deletions R/02_Hyperparameters.R
Original file line number Diff line number Diff line change
Expand Up @@ -2242,27 +2242,6 @@ stopifnot(all(
))


# %% get_ranger_config ----
#' Get Ranger Configuration
#'
#' Get Ranger configuration from RangerHyperparameters object.
#'
#' @param hyperparameters `RangerHyperparameters` object.
#'
#' @return List with Ranger configuration.
#'
#' @author EDG
#'
#' @keywords internal
#' @noRd
get_ranger_config <- function(hyperparameters) {
check_is_S7(hyperparameters, RangerHyperparameters)
hpr <- hyperparameters@hyperparameters
hpr[["ifw"]] <- NULL
hpr
} # /get_ranger_config


# %% list_to_Hyperparameters ----
list_to_Hyperparameters <- function(x) {
fn <- paste0("setup_", x[["algorithm"]])
Expand Down
123 changes: 103 additions & 20 deletions R/07_Supervised.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,82 @@
# https://rconsortium.github.io/S7/articles/classes-objects.html?q=computed#computed-properties
# https://utf8-icons.com/

# %% VariableImportance ----
#' @title VariableImportance
#'
#' @description
#' Class for variable importance objects. Allows for one or more variable importance measures,
#' stored in a data.table with columns "variable", and at least one
#' more column with a descriptive name.
#'
#' @author EDG
#' @noRd
VariableImportance <- new_class(
name = "VariableImportance",
properties = list(
data = class_data.table
),
validator = function(self) {
# Must include at least two columns
if (NCOL(self@data) < 2L) {
cli::cli_abort(
"Variable importance data must include at least two columns: 'variable' and at least one importance measure."
)
}
# Must include column "variable" of type character
if (!"variable" %in% names(self@data)) {
cli::cli_abort(
"Variable importance data must include a 'variable' column."
)
}
if (!is.character(self@data[["variable"]])) {
cli::cli_abort("Column 'variable' must be of type character.")
}
# All other columns must be numeric
other_cols <- setdiff(names(self@data), "variable")
if (!all(self@data[, sapply(.SD, is.numeric), .SDcols = other_cols])) {
cli::cli_abort(
"All columns other than 'variable' must be numeric."
)
}
# Number of rows will be checked by Supervised to be at least as many as
# the number of predictors.
}
) # /rtemis::VariableImportance


# %% repr.VariableImportance ----
method(repr, VariableImportance) <- function(x, pad = 0L, output_type = NULL) {
output_type <- get_output_type(output_type)
# "N variable importance measures for M predictors"
n_m <- NCOL(x@data) - 1L
paste0(
repr_S7name("VariableImportance", pad = pad, output_type = output_type),
strrep(" ", pad),
fmt(n_m, col = highlight_col, bold = TRUE, output_type = output_type),
ngettext(
n_m,
" variable importance measure for ",
" variable importance measures for "
),
fmt(
NROW(x@data),
col = highlight_col,
bold = TRUE,
output_type = output_type
),
ngettext(NROW(x@data), " predictor", " predictors")
)
} # /rtemis::repr.VariableImportance


# %% print.VariableImportance ----
method(print, VariableImportance) <- function(x, output_type = NULL, ...) {
cat(repr(x, output_type = output_type), "\n")
invisible(x)
} # /rtemis::print.VariableImportance


# Plot methods
# Supervised: plot_varimp
# SupervisedRes: plot_varimp, plot_metric
Expand Down Expand Up @@ -45,7 +121,7 @@ Supervised <- new_class(
metrics_validation = Metrics | NULL,
metrics_test = Metrics | NULL,
xnames = class_character,
varimp = class_any,
varimp = VariableImportance | NULL,
question = class_character | NULL,
extra = class_any,
session_info = class_any
Expand Down Expand Up @@ -159,7 +235,7 @@ method(predict, Supervised) <- function(object, newdata, verbosity = 1L, ...) {
model = object@model,
newdata = newdata,
type = object@type,
...
verbosity = verbosity
)
} # /rtemis::predict.Supervised

Expand Down Expand Up @@ -1123,7 +1199,7 @@ SupervisedRes <- new_class(
metrics_training = MetricsRes,
metrics_test = MetricsRes,
xnames = class_character,
varimp = class_any,
varimp = class_list | NULL,
question = class_character | NULL,
extra = class_any,
session_info = class_any
Expand Down Expand Up @@ -1946,6 +2022,7 @@ method(plot_metric, SupervisedRes) <- function(
# %% plot_varimp.Supervised ----
method(plot_varimp, Supervised) <- function(
x,
measure = NULL,
theme = choose_theme(getOption("rtemis_theme")),
filename = NULL,
...
Expand All @@ -1954,13 +2031,20 @@ method(plot_varimp, Supervised) <- function(
msg(highlight2("No variable importance available."))
return(invisible())
}
draw_varimp(x@varimp, theme = theme, filename = filename, ...)
if (is.null(measure)) {
vi <- x@varimp@data[[2L]]
} else {
vi <- x@varimp@data[[measure]]
}
names(vi) <- x@varimp@data[["variable"]]
draw_varimp(vi, theme = theme, filename = filename, ...)
} # /rtemis::plot_varimp.Supervised


# %% plot_varimp.SupervisedRes ----
method(plot_varimp, SupervisedRes) <- function(
x,
measure = NULL,
ylab = NULL,
summarize_fn = "mean",
show_top = 20L,
Expand All @@ -1974,34 +2058,33 @@ method(plot_varimp, SupervisedRes) <- function(
}
check_inherits(summarize_fn, "character")

# ! Variable importance may be returned in different order in each resample !
# Order varimp vectors by variable names
# First, check each varimp vector is named
if (!all(sapply(x@varimp, function(z) !is.null(names(z))))) {
cli::cli_abort("Variable importance elements must be named vectors.")
}
# Not every variable gets a variable importance score necessarily
# Each varimp vector as a one row data.table in order to rbindlist them, filling in NAs as needed.
# x@varimp[[i]] may be named vector or data.frame
varimp_dt <- lapply(x@varimp, function(z) {
as.data.table(as.list(z), keep.rownames = TRUE)
# Extract named numeric vectors from each VariableImportance object.
# Not every variable gets a score in every resample, so rbindlist with fill.
varimp_list <- lapply(x@varimp, function(z) {
vi <- if (is.null(measure)) z@data[[2L]] else z@data[[measure]]
names(vi) <- z@data[["variable"]]
as.data.table(as.list(vi))
})

varimp <- rbindlist(varimp_dt, use.names = TRUE, fill = TRUE)
# Convert NA values to 0
varimp <- rbindlist(varimp_list, use.names = TRUE, fill = TRUE)
# Missing scores (variable absent in a resample) treated as 0
setDF(varimp)
varimp[is.na(varimp)] <- 0
# Summarize variable importance
# Summarize and sort
varimp_summary <- apply(varimp, 2, summarize_fn)
# Sort columns by descending variable importance
varimp_sorted <- varimp_summary[order(-varimp_summary)]
if (length(varimp_sorted) > show_top) {
varimp_sorted <- varimp_sorted[seq_len(show_top)]
}
# ylab
if (is.null(ylab)) {
measure_name <- if (is.null(measure)) {
names(x@varimp[[1L]]@data)[2L]

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If x@varimp is an empty list (length 0), accessing x@varimp[[1L]] will result in an error. While the check on line 2055 handles NULL, it does not account for an empty list which is a valid state for a class_list property.

      if (length(x@varimp) > 0) names(x@varimp[[1L]]@data)[2L] else "Importance"

} else {
measure
}
ylab <- paste0(
labelify(paste(summarize_fn, "Variable Importance")),
labelify(paste(summarize_fn, measure_name)),
"\n(across ",
desc(x@outer_resampler),
")"
Expand Down
2 changes: 1 addition & 1 deletion R/draw_dist.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ draw_dist <- function(
),
histfunc = c("count", "sum", "avg", "min", "max"),
hist_n_bins = 20,
barmode = "overlay", # TODO: alternatives
barmode = "overlay", # ?alternatives
ridge_sharex = TRUE,
ridge_y_labs = FALSE,
ridge_order_on_mean = TRUE,
Expand Down
2 changes: 1 addition & 1 deletion R/draw_varimp.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#' A simple `plotly` wrapper to plot horizontal barplots, sorted by value,
#' which can be used to visualize variable importance, model coefficients, etc.
#'
#' @param x Numeric vector: Input.
#' @param x Numeric vector (or coercible to numeric): Input.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The pull request title mentions initializing a VariableImportance class, but no such class definition or related logic is present in the diff. This discrepancy should be addressed by either including the missing implementation or updating the PR title to reflect the actual changes.

#' @param names Vector, string: Names of features.
#' @param main Character: Main title.
#' @param type Character: "bar" or "line".
Expand Down
1 change: 1 addition & 0 deletions R/train.R
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ train <- function(
x = x,
weights = weights,
dat_validation = dat_validation_for_training,
execution_config = execution_config, # used by LightRuleFit
verbosity = verbosity
)

Expand Down
12 changes: 10 additions & 2 deletions R/train_CART.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ method(train_, CARTHyperparameters) <- function(
x,
weights = NULL,
dat_validation = NULL,
execution_config = setup_ExecutionConfig(),
verbosity = 1L
) {
# Dependencies ----
Expand Down Expand Up @@ -89,7 +90,8 @@ method(train_, CARTHyperparameters) <- function(
method(predict_super, class_rpart) <- function(
model,
newdata,
type = NULL
type = NULL,
verbosity = 0L
) {
if (type == "Classification") {
# Classification
Expand All @@ -115,5 +117,11 @@ method(predict_super, class_rpart) <- function(
#' @keywords internal
#' @noRd
method(varimp_super, class_rpart) <- function(model) {
model[["variable.importance"]]
vi <- model[["variable.importance"]]
VariableImportance(
data.table(
variable = names(vi),
importance = unname(vi)
)
)
} # /rtemis::varimp_super.rpart
43 changes: 23 additions & 20 deletions R/train_GAM.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,12 @@ method(train_, GAMHyperparameters) <- function(
x,
weights = NULL,
dat_validation = NULL,
execution_config = setup_ExecutionConfig(),
verbosity = 1L
) {
# Dependencies ----
check_dependencies("mgcv")

# Checks ----
check_is_S7(hyperparameters, GAMHyperparameters)

# Hyperparameters ----
# Hyperparameters must be either untunable or frozen by `train`.
if (needs_tuning(hyperparameters)) {
Expand Down Expand Up @@ -125,7 +123,8 @@ method(train_, GAMHyperparameters) <- function(
method(predict_super, class_gam) <- function(
model,
newdata,
type = NULL
type = NULL,
verbosity = 0L
) {
out <- predict(object = model, newdata = newdata, type = "response")
if (model[["family"]][["family"]] == "binomial") {
Expand All @@ -137,30 +136,34 @@ method(predict_super, class_gam) <- function(


# %% varimp_super.class_gam ----
#' Get coefficients from GAM model
#' Get variable importance from GAM model
#'
#' Variable importance for GAM is estimated as the variance of each predictor's partial effect,
#' obtained via predict(model, type = "terms"). This measures each smooth term's contribution to
#' the variance of the fitted values. Values are normalized to sum to one, representing each
#' predictor's proportion of total predicted variance. This approach is computationally efficient
#' (no refitting required) and analogous to importance measures in tree-based methods. It assumes
#' approximate uncorrelatedness of partial effects, which penalized smooths tend to satisfy. For
#' models with high concurvity, consider hierarchical partitioning of R² (e.g. via the gam.hp
#' package) as an alternative.
#'
#' @param model mgcv gam model.
#'
#' @keywords internal
#' @noRd
method(varimp_super, class_gam) <- function(
model,
type = c("p-value", "edf", "coefficients")
type = c("partial_effect", "F-test")
) {
type <- match.arg(type)
if (type == "p-value") {
# Get parametric and smooth term p-values
summary_ <- summary(model)
# Exclude intercept
-log10(c(
summary_[["s.table"]][, "p-value"],
summary_[["p.table"]][, ncol(summary_[["p.table"]])][-1]
))
} else if (type == "edf") {
summary(model)[["s.table"]][, "edf"]
} else if (type == "coefficients") {
coef(model)
}
peff <- predict(model, type = "terms")
vi <- apply(peff, 2, var)
npeff <- vi / sum(vi) # normalized importance
VariableImportance(
data.table(
variable = names(npeff),
Partial_Effect_Variance = unname(npeff)
)
)
Comment on lines +161 to +166

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The column name for the importance measure is hardcoded to "Coefficient". However, this method supports multiple types of importance measures including "p-value" and "edf". Using the type argument as the column name would provide more accurate labels in plots and summaries.

  .vi_dt <- data.table(variable = names(.coef), importance = unname(.coef))
  setnames(.vi_dt, "importance", type)
  VariableImportance(.vi_dt)

} # /rtemis::varimp_super.gam


Expand Down
Loading
Loading