diff --git a/R/DESCRIPTION b/R/DESCRIPTION index 7dc3b7c9b..c6f36a1a2 100644 --- a/R/DESCRIPTION +++ b/R/DESCRIPTION @@ -1,7 +1,7 @@ Package: Robyn Type: Package Title: Semi-Automated Marketing Mix Modeling (MMM) from Meta Marketing Science -Version: 3.12.0.9005 +Version: 3.12.0.9006 Authors@R: c( person("Gufeng", "Zhou", , "gufeng@meta.com", c("cre", "aut")), person("Igor", "Skokan", , "igorskokan@meta.com", c("aut")), diff --git a/R/R/allocator.R b/R/R/allocator.R index 3a0e94d21..1c380dd22 100644 --- a/R/R/allocator.R +++ b/R/R/allocator.R @@ -133,7 +133,8 @@ robyn_allocator <- function(robyn_object = NULL, ## set local variables, sort & prompt # paid_media_spends <- InputCollect$paid_media_spends - paid_media_selected <- InputCollect$paid_media_selected + paid_media_selected <- if ("paid_media_selected" %in% names(InputCollect)) + InputCollect$paid_media_selected else InputCollect$paid_media_spends dep_var_type <- InputCollect$dep_var_type if (is.null(select_model) && length(OutputCollect$allSolutions == 1)) { select_model <- OutputCollect$allSolutions diff --git a/R/R/auxiliary.R b/R/R/auxiliary.R index 95a7cd53f..b42e131ab 100644 --- a/R/R/auxiliary.R +++ b/R/R/auxiliary.R @@ -115,3 +115,29 @@ baseline_vars <- function(InputCollect, baseline_level) { # Calculate MSE .mse_loss <- function(y, y_hat) mean((y - y_hat)^2) + +# next_date(c("2021-01-01", "2021-02-01")) +# next_date(c("2021-01-01", "2021-01-08", "2021-01-15")) +# next_date(c(Sys.Date() - 1, Sys.Date())) +.next_date <- function(dates) { + dates <- as.Date(dates) + diffs <- diff(dates) + if (all(diffs == 1)) { + frequency <- "daily" + } else if (all(diffs == 7)) { + frequency <- "weekly" + } else if (all(format(dates[-length(dates)], "%Y-%m") != format(dates[-1], "%Y-%m"))) { + frequency <- "monthly" + } else { + warning(paste( + "Unable to determine frequency to calculate next logical date.", + "Returning last available date.")) + return(as.Date(tail(dates, 1))) + } + next_date <- switch( + frequency, + "daily" = dates[length(dates)] + 1, + "weekly" = dates[length(dates)] + 7, + "monthly" = seq(dates[length(dates)], by = "1 month", length.out = 2)[2]) + return(as.Date(next_date)) +} diff --git a/R/R/checks.R b/R/R/checks.R index 48c412bf0..fd6388885 100644 --- a/R/R/checks.R +++ b/R/R/checks.R @@ -395,30 +395,31 @@ check_windows <- function(dt_input, date_var, all_media, window_start, window_en refreshAddedStart <- window_start if (is.null(window_end)) { - window_end <- max(dates_vec) + window_end <- .next_date(dates_vec) - 1 } else { window_end <- as.Date(as.character(window_end), "%Y-%m-%d", origin = "1970-01-01") if (is.na(window_end)) { stop(sprintf("Input 'window_end' must have date format, i.e. '%s'", Sys.Date())) - } else if (window_end > max(dates_vec)) { - window_end <- max(dates_vec) + } else if (window_end > .next_date(dates_vec) - 1) { + window_end <- .next_date(dates_vec) - 1 message(paste( - "Input 'window_end' is larger than the latest date in input data.", - "It's automatically set to the latest date:", window_end + "Input 'window_end' is larger than the latest dates available in input data.", + "Automatically set to date:", window_end )) } else if (window_end < window_start) { - window_end <- max(dates_vec) + window_end <- .next_date(dates_vec) - 1 message(paste( "Input 'window_end' must be >= 'window_start.", - "It's automatically set to the latest date:", window_end + "Automatically set to date:", window_end )) } } + # Find closest date contained in input data rollingWindowEndWhich <- which.min(abs(difftime(dates_vec, window_end, units = "days"))) - if (!(window_end %in% dates_vec)) { - window_end <- dt_input[rollingWindowEndWhich, date_var][[1]] - message("Input 'window_end' is adapted to the closest date contained in input data: ", window_end) + if (!window_end %in% c(dates_vec, .next_date(dates_vec) - 1)) { + window_end <- .next_date(dt_input[seq(rollingWindowEndWhich), date_var][[1]]) - 1 + message("Input 'window_end' is adapted to the closest available date from input data: ", window_end) } rollingWindowLength <- rollingWindowEndWhich - rollingWindowStartWhich + 1 diff --git a/R/R/model.R b/R/R/model.R index 29138683c..5bd594fb7 100644 --- a/R/R/model.R +++ b/R/R/model.R @@ -1366,7 +1366,7 @@ init_msgs_run <- function(InputCollect, refresh, lambda_control = NULL, quiet = nrow(InputCollect$dt_mod), InputCollect$intervalType, min(InputCollect$dt_mod$ds), - max(InputCollect$dt_mod$ds) + .next_date(InputCollect$dt_mod$ds) - 1 )) depth <- ifelse( "refreshDepth" %in% names(InputCollect), diff --git a/R/R/plots.R b/R/R/plots.R index 74d827b96..c113ee2e6 100644 --- a/R/R/plots.R +++ b/R/R/plots.R @@ -484,8 +484,6 @@ robyn_onepagers <- function( ## 4. Response curves dt_scurvePlot <- temp[[sid]]$plot4data$dt_scurvePlot dt_scurvePlotMean <- temp[[sid]]$plot4data$dt_scurvePlotMean - paid_media_selected <- if ("paid_media_selected" %in% names(InputCollect)) - InputCollect$paid_media_selected else InputCollect$paid_media_spends trim_rate <- 1.3 # maybe enable as a parameter if (trim_rate > 0) { dt_scurvePlot <- dt_scurvePlot %>% @@ -496,7 +494,7 @@ robyn_onepagers <- function( filter( .data$spend < max(dt_scurvePlotMean$mean_spend_adstocked) * trim_rate, .data$response < max(dt_scurvePlotMean$mean_response) * trim_rate, - .data$channel %in% paid_media_selected + .data$channel %in% InputCollect$paid_media_vars ) %>% left_join( dt_scurvePlotMean[, c("channel", "mean_carryover")], "channel"