diff --git a/CHANGELOG.md b/CHANGELOG.md index fc46891..a6c1c54 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -### Version 0.1.0.9004-6 +### Version 0.1.0.9004-7 * Provide backwards compatibility with [legacy mungebits](https://github.com/robertzk/mungebits) diff --git a/DESCRIPTION b/DESCRIPTION index 48369d2..e0197b8 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -6,7 +6,7 @@ Description: A way of thinking about data preparation that online prediction so that both can be described by the same codebase. With mungebits, you can save time on having to re-implement your R code to work in production and instead re-use the same codebase. -Version: 0.1.0.9006 +Version: 0.1.0.9007 Author: Robert Krzyzanowski Maintainer: Robert Krzyzanowski Authors@R: c(person("Robert", "Krzyzanowski", diff --git a/R/munge.R b/R/munge.R index a7d43ef..14d7fd9 100644 --- a/R/munge.R +++ b/R/munge.R @@ -247,7 +247,9 @@ munge <- function(data, mungelist, stagerunner = FALSE, list = FALSE, parse = TRUE) { stopifnot(is.data.frame(data) || (is.environment(data) && - (!identical(stagerunner, FALSE) || any(ls(data) == "data")))) + ## We have to be slightly careful here to ensure that we handle + ## [objectdiff](https://github.com/robertzk/objectdiff) environments correctly. + (!identical(stagerunner, FALSE) || environment_has_data(data)))) if (length(mungelist) == 0L) { return(data) @@ -301,7 +303,11 @@ munge_ <- function(data, mungelist, stagerunner, list_output, parse) { ## by the `mungepiece_stages` helper. stages <- mungepiece_stages(mungelist) if (is.environment(data)) { - context <- data + if (identical(stagerunner, FALSE)) { + context <- normalize_environment(data) + } else { + context <- data + } } else { context <- list2env(list(data = data), parent = emptyenv()) } @@ -376,7 +382,9 @@ mungepiece_stage_body <- function() { ## the trained mungepiece. # Make a fresh copy to avoid shared stage problems. 
piece <- mungepieces[[mungepiece_index]]$duplicate(private = TRUE) - piece$run(env) + ## We don't do the straightforward `piece$run(env)` to ensure + ## compatibility with [tracked environments](https://github.com/robertzk/objectdiff). + env$data <- piece$run(env$data) newpieces[[mungepiece_index]] <<- piece ## When we are out of mungepieces, that is, when the current index equals @@ -386,6 +394,7 @@ mungepiece_stage_body <- function() { ## the munging actions on new data by passing the dataframe as the second ## argument to the `munge` function. if (mungepiece_index == size) { + names(newpieces) <- names(mungepieces) attr(env$data, "mungepieces") <- append(attr(env$data, "mungepieces"), newpieces) } @@ -418,9 +427,32 @@ legacy_mungepiece_stage_body <- function() { piece$run(env) if (mungepiece_index == size) { + names(newpieces) <- names(mungepieces) attr(env$data, "mungepieces") <- append(attr(env$data, "mungepieces"), newpieces) } }) } +normalize_environment <- function(env) { + ## For compatibility with [objectdiff](https://github.com/robertzk/objectdiff), + ## we pass tracked environments through its own `environment` function. + if (is(env, "tracked_environment") && + is.element("objectdiff", loadedNamespaces())) { + getFromNamespace("environment", "objectdiff")(env) + } else { + env + } +} + +environment_has_data <- function(env) { + ## For compatibility with [objectdiff](https://github.com/robertzk/objectdiff), + ## we use its special-purpose `ls`. + if (is(env, "tracked_environment") && + is.element("objectdiff", loadedNamespaces())) { + any(getFromNamespace("ls", "objectdiff")(env) == "data") + } else { + any(ls(env) == "data") + } +} + diff --git a/R/mungepiece.R b/R/mungepiece.R index 787a618..eba6453 100644 --- a/R/mungepiece.R +++ b/R/mungepiece.R @@ -68,7 +68,7 @@ mungepiece <- R6::R6Class("mungepiece", duplicate_mungepiece <- function(piece, ...) 
{ ## To ensure backwards compatibility with ## [legacy mungebits](https://github.com/robertzk/mungebits), - ## we perform nothing is the piece is not an R6 object (and hence + ## we perform nothing if the piece is not an R6 object (and hence ## a new mungepiece in the style of this package). if (is.legacy_mungepiece(piece)) { piece diff --git a/R/parse_mungepiece.R b/R/parse_mungepiece.R index 346cc39..aab658b 100644 --- a/R/parse_mungepiece.R +++ b/R/parse_mungepiece.R @@ -235,7 +235,9 @@ #' # The munge function uses the attached "mungepieces" attribute, a list of #' # trained mungepieces. parse_mungepiece <- function(args) { - if (is.mungepiece(args) || is.mungebit(args)) { args <- list(args) } + if (is.mungepiece(args) || is.mungebit(args) || is.function(args)) { + args <- list(args) + } if (length(args) == 1L && is.mungepiece(args[[1L]])) { ## We duplicate the mungepiece to avoid training it. @@ -243,7 +245,11 @@ parse_mungepiece <- function(args) { } else if (length(args) == 1L && is.mungebit(args[[1L]])) { ## This case is technically handled already in parse_mungepiece_single, ## but we make it explicit here. - mungepiece$new(duplicate_mungebit(args[[1L]])) + if (is.legacy_mungebit(args[[1L]])) { + getFromNamespace("mungepiece", "mungebits")$new(args[[1L]]) + } else { + mungepiece$new(duplicate_mungebit(args[[1L]])) + } ## The third permissible format requires no unnamed arguments, since it ## must be a list consisting of a "train" and "predict" key. } else if (is.list(args) && length(args) > 0L) { @@ -292,9 +298,10 @@ parse_mungepiece_dual <- function(args) { args <- Map(list, parse_mungepiece_dual_chunk(args$train, type = "train"), parse_mungepiece_dual_chunk(args$predict, type = "predict")) - ## This is the format we need to use the `mungebit` and `mungepiece` - ## constructors. 
- do.call(mungepiece$new, c(list(do.call(mungebit$new, args[[1L]])), args[[2L]])) + ## We use the `create_mungepiece` helper defined below to ensure this + ## construction works for new and [legacy](https://github.com/robertzk/mungebits) + ## mungepieces. + do.call(create_mungepiece, c(args[[1L]], args[[2L]])) } ## We perform [type dispatch](http://adv-r.had.co.nz/OO-essentials.html#s3) to diff --git a/tests/testthat/test-legacy.R b/tests/testthat/test-legacy.R index 4e55a0c..89a4c30 100644 --- a/tests/testthat/test-legacy.R +++ b/tests/testthat/test-legacy.R @@ -80,6 +80,16 @@ describe("Creating legacy mungebits using the munge function", { attr(iris2, "mungepieces") <- NULL expect_equal(iris2, iris[-c(1,2)]) }) + + test_that("it should be able to create a legacy mungebit using the third munge format", { + legacy_fn <- function(df, ...) { + eval.parent(substitute({ df[[1]] <- NULL })) + } + class(legacy_fn) <- "legacy_mungebit_function" + iris2 <- munge(iris, list(list(train = list(legacy_fn, "foo"), predict = list(legacy_fn, "bar")))) + attr(iris2, "mungepieces") <- NULL + expect_equal(iris2, iris[-1L]) + }) }) diff --git a/tests/testthat/test-munge.R b/tests/testthat/test-munge.R index 3004590..f4bf1e4 100644 --- a/tests/testthat/test-munge.R +++ b/tests/testthat/test-munge.R @@ -10,6 +10,22 @@ describe("Invalid inputs", { test_that("when munging against a data.frame it must have a mungepieces attribute", { expect_error(munge(iris, beaver2), "must have a ") }) + + test_that("when passing an environment it contains a data key", { + env <- list2env(list(foo = iris)) + expect_error(munge(env, identity), "is.data.frame") + env <- list2env(list(data = iris)) + munge(env, list(list(identity))) + }) + + test_that("when passing a tracked_environment it contains a data key", { + if (requireNamespace("objectdiff", quietly = TRUE)) { + env <- objectdiff::tracked_environment(list2env(list(foo = iris))) + expect_error(munge(env, identity), "is.data.frame") + env <- 
objectdiff::tracked_environment(list2env(list(data = iris))) + munge(env, list(list(identity))) + } + }) }) test_that("it does nothing when no mungepieces are passed", { @@ -79,6 +95,29 @@ describe("it can procure the mungepieces list", { }) }) +test_that("mungepiece names are preserved", { + iris2 <- munge(iris, list("Step 1" = list(identity), "Step 2" = list(identity))) + expect_equal(names(attr(iris2, "mungepieces")), c("Step 1", "Step 2")) +}) + +test_that("mungepiece names are preserved for legacy mungebits", { + legacy_function <- function(x) { x } + class(legacy_function) <- c("legacy_mungebit_function", class(legacy_function)) + iris2 <- munge(iris, list("Step 1" = list(legacy_function), "Step 2" = list(legacy_function))) + expect_equal(names(attr(iris2, "mungepieces")), c("Step 1", "Step 2")) +}) + +test_that("munging works with a stagerunner generated for an objectdiff tracked environment", { + if (requireNamespace("objectdiff", quietly = TRUE)) { + env <- objectdiff::tracked_environment(list2env(list(data = iris))) + runner <- munge(env, list("Step 1" = list(identity)), stagerunner = list(remember = TRUE)) + runner$run(1) + result <- runner$context$data + attr(result, "mungepieces") <- NULL + expect_equal(result, iris) + } +}) + describe("using mungepieces with inputs", { simple_imputer <- function(...) { @@ -182,3 +221,4 @@ describe("using mungepieces with inputs", { }) }) + diff --git a/tests/testthat/test-parse_mungepiece.R b/tests/testthat/test-parse_mungepiece.R index dc1cb60..2dc940f 100644 --- a/tests/testthat/test-parse_mungepiece.R +++ b/tests/testthat/test-parse_mungepiece.R @@ -3,7 +3,7 @@ context("parse_mungepiece") describe("Invalid inputs", { test_that("it breaks when it does not receive a list", { expect_error(parse_mungepiece(5)) - expect_error(parse_mungepiece(identity)) + expect_error(parse_mungepiece(NULL)) expect_error(parse_mungepiece(iris)) }) @@ -37,6 +37,19 @@ train_fn <- function(data, by = 2) { predict_fn <- function(data, ...) 
{ data[[1]] <- input$by * data[[1]]; data } +test_that("it can receive a simple function", { + piece <- parse_mungepiece(identity) + piece2 <- mungepiece$new(mungebit$new(identity)) + expect_same_piece(piece, piece2) +}) + +test_that("it can parse a pre-existing legacy mungebit", { + if (requireNamespace("mungebits", quietly = TRUE)) { ## optional dependency: the body is not run when mungebits cannot be loaded + legacy_mungebit <- mungebits:::mungebit$new(function(x) { x }) + parse_mungepiece(list(legacy_mungebit)) + } +}) + describe("First format", { test_that("it correctly creates a mungepiece using the first format with no additional arguments", { piece <- parse_mungepiece(list(train_fn, 2))