From 33b247308514874052d9b73922f42325816ebb77 Mon Sep 17 00:00:00 2001 From: Marc Becker <33069354+be-marc@users.noreply.github.com> Date: Sat, 11 Nov 2023 14:29:16 +0100 Subject: [PATCH] fix: hotstart stack (#964) * fix: overwrite train time when hotstarting * fix: disallow learners without model in the hotstart stack * chore: browser * refactor: add debug messages * feat: add debug messages to resample * refactor: debug message * fix: hotstarting works with set col role * feat: add hotstart threshold * fix: allow null * chore: update news * chore: typo --- NEWS.md | 6 +++++- R/HotstartStack.R | 30 ++++++++++++++++++++++++++--- R/helper_hashes.R | 3 +-- R/resample.R | 7 ++++++- R/worker.R | 9 ++------- man/HotstartStack.Rd | 12 +++++++++--- tests/testthat/test_HotstartStack.R | 28 +++++++++++++++++++++++++++ tests/testthat/test_hotstart.R | 23 ++++++++++++++++++++++ 8 files changed, 101 insertions(+), 17 deletions(-) diff --git a/NEWS.md b/NEWS.md index dde56f9b4..166e10615 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,9 +1,13 @@ # mlr3 0.16.1-9000 +* Learners cannot be added to the `HotstartStack` anymore when the model is missing. +* Learners below the `hotstart_threshold` are not added to the `HotstartStack` anymore. +* The `learner$state$train_time` in hotstarted learners is now only the time of the last training. +* Added debug messages to the hotstart stack. +* Fixed bug where the `HotstartStack` did not work with column roles set in the task. * export generic function `col_info` to allow adding new methods for backends * Add `"mlr3.exec_chunk_bins"` option to split the resampling iterations into a number of bins. - # mlr3 0.16.1 * Function `data.table()` is now re-exported. diff --git a/R/HotstartStack.R b/R/HotstartStack.R index 5484837e5..11b845915 100644 --- a/R/HotstartStack.R +++ b/R/HotstartStack.R @@ -54,14 +54,22 @@ HotstartStack = R6Class("HotstartStack", #' Stores hot start learners. 
stack = NULL, + #' @field hotstart_threshold (named `numeric(1)`)\cr + #' Threshold for storing learners in the stack. + #' If the value of the hotstart parameter is below this threshold, the learner is not added to the stack. + hotstart_threshold = NULL, + #' @description #' Creates a new instance of this [R6][R6::R6Class] class. #' #' @param learners (List of [Learner]s)\cr #' Learners are added to the hotstart stack. If `NULL` (default), empty #' stack is created. - initialize = function(learners = NULL) { + #' @param hotstart_threshold (named `numeric(1)`)\cr + #' Threshold for storing learners in the stack. + initialize = function(learners = NULL, hotstart_threshold = NULL) { self$stack = data.table() + self$hotstart_threshold = assert_numeric(hotstart_threshold, names = "named", null.ok = TRUE) # add learners to stack if (!is.null(learners)) self$add(learners) @@ -77,10 +85,25 @@ HotstartStack = R6Class("HotstartStack", add = function(learners) { learners = assert_learners(as_learners(learners)) + # check for models + if (any(map_lgl(learners, function(learner) is.null(learner$state$model)))) { + stopf("Learners must be trained before adding them to the hotstart stack.") + } + + if (!is.null(self$hotstart_threshold)) { + learners = keep(learners, function(learner) { + hotstart_id = learner$param_set$ids(tags = "hotstart") + learner$param_set$values[[hotstart_id]] >= self$hotstart_threshold[hotstart_id] + }) + if (!length(learners)) return(invisible(self)) + } + # hashes task_hash = map_chr(learners, function(learner) learner$state$task_hash) learner_hash = map_chr(learners, learner_hotstart_hash) + lg$debug("Adding %s learner(s) to the hotstart stack.", length(learners)) + self$stack = rbindlist(list(self$stack, data.table(start_learner = learners, task_hash, learner_hash))) setkeyv(self$stack, c("task_hash", "learner_hash")) @@ -88,8 +111,7 @@ HotstartStack = R6Class("HotstartStack", }, #' @description - #' Calculates the cost for each learner of the stack 
to hot start the target - #' `learner`. + #' Calculates the cost for each learner of the stack to hot start the target `learner`. #' #' The following cost values can be returned: #' @@ -152,6 +174,8 @@ HotstartStack = R6Class("HotstartStack", ][, "cost" := map_dbl(start_learner, function(l) calculate_cost(l, learner, hotstart_id)) ][which_min(get("cost"), na_rm = TRUE), start_learner] + lg$debug("Found %i start learner(s) in hotstart stack of size %i.", length(start_learner), nrow(self$stack)) + if (!length(start_learner)) return(NULL) learner$state = start_learner[[1]]$state learner diff --git a/R/helper_hashes.R b/R/helper_hashes.R index 91f9e4d4c..f70989e2a 100644 --- a/R/helper_hashes.R +++ b/R/helper_hashes.R @@ -21,7 +21,6 @@ task_hashes = function(task, resampling) { map_chr(seq_len(resampling$iters), function(i) { train_set = resampling$train_set(i) row_roles$use = train_set - calculate_hash(class(task), task$id, task$backend$hash, task$col_info, row_roles, task$col_roles, - task$properties) + calculate_hash(class(task), task$id, task$backend$hash, task$col_info, row_roles, task$col_roles, get_private(task)$.properties) }) } diff --git a/R/resample.R b/R/resample.R index dde558049..9e3467e46 100644 --- a/R/resample.R +++ b/R/resample.R @@ -77,7 +77,10 @@ resample = function(task, learner, resampling, store_models = FALSE, store_backe lgr_threshold = map_int(mlr_reflections$loggers, "threshold") grid = if (allow_hotstart) { - hotstart_grid = map_dtr(seq_len(n), function(iteration) { + + lg$debug("Resampling with hotstart enabled.") + + hotstart_grid = map_dtr(seq_len(n), function(iteration) { if (!is.null(learner$hotstart_stack)) { # search for hotstart learner task_hashes = task_hashes(task, resampling) @@ -85,9 +88,11 @@ resample = function(task, learner, resampling, store_models = FALSE, store_backe } if (is.null(learner$hotstart_stack) || is.null(start_learner)) { # no hotstart learners stored or no adaptable model found + lg$debug("Resampling with 
hotstarting not possible. Not start learner found.") mode = "train" } else { # hotstart learner found + lg$debug("Resampling with hotstarting.") start_learner$param_set$values = insert_named(start_learner$param_set$values, learner$param_set$values) learner = start_learner mode = "hotstart" diff --git a/R/worker.R b/R/worker.R index c6ff9dffd..e2d8e312a 100644 --- a/R/worker.R +++ b/R/worker.R @@ -65,13 +65,8 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL .timeout = learner$timeout["train"] ) - if (mode == "train") { - log = append_log(NULL, "train", result$log$class, result$log$msg) - train_time = result$elapsed - } else { # hotstart - log = rbindlist(list(learner$state$log, append_log(NULL, "train", result$log$class, result$log$msg))) - train_time = learner$state$train_time + result$elapsed - } + log = append_log(NULL, "train", result$log$class, result$log$msg) + train_time = result$elapsed proto = task$data(rows = integer()) learner$state = insert_named(learner$state, list( diff --git a/man/HotstartStack.Rd b/man/HotstartStack.Rd index 30eb3724c..3ce607906 100644 --- a/man/HotstartStack.Rd +++ b/man/HotstartStack.Rd @@ -55,6 +55,10 @@ learner$train(task) \describe{ \item{\code{stack}}{\code{\link[data.table:data.table]{data.table::data.table()}}\cr Stores hot start learners.} + +\item{\code{hotstart_threshold}}{(named \code{numeric(1)})\cr +Threshold for storing learners in the stack. +If the value of the hotstart parameter is below this threshold, the learner is not added to the stack.} } \if{html}{\out{}} } @@ -75,7 +79,7 @@ Stores hot start learners.} \subsection{Method \code{new()}}{ Creates a new instance of this \link[R6:R6Class]{R6} class. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{HotstartStack$new(learners = NULL)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{HotstartStack$new(learners = NULL, hotstart_threshold = NULL)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -84,6 +88,9 @@ Creates a new instance of this \link[R6:R6Class]{R6} class. \item{\code{learners}}{(List of \link{Learner}s)\cr Learners are added to the hotstart stack. If \code{NULL} (default), empty stack is created.} + +\item{\code{hotstart_threshold}}{(named \code{numeric(1)})\cr +Threshold for storing learners in the stack.} } \if{html}{\out{}} } @@ -113,8 +120,7 @@ self (invisibly). \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-HotstartStack-start_cost}{}}} \subsection{Method \code{start_cost()}}{ -Calculates the cost for each learner of the stack to hot start the target -\code{learner}. +Calculates the cost for each learner of the stack to hot start the target \code{learner}. The following cost values can be returned: \itemize{ diff --git a/tests/testthat/test_HotstartStack.R b/tests/testthat/test_HotstartStack.R index 4e7906857..4abcfc16d 100644 --- a/tests/testthat/test_HotstartStack.R +++ b/tests/testthat/test_HotstartStack.R @@ -381,3 +381,31 @@ test_that("HotstartStack works without hotstart parameter", { expect_equal(hot$start_cost(target_learner, task$hash), NA_real_) expect_null(get_private(target_learner$hotstart_stack)$.start_learner(target_learner, task$hash)) }) + + +test_that("Learners without a model are not added to the hotstart stack", { + task = tsk("pima") + learner = lrn("classif.debug", iter = 1) + learner$train(task) + learner$state$model = NULL + + expect_error(HotstartStack$new(learner), "Learners must be trained before adding them to the hotstart stack.") +}) + +test_that("HotstartStack threshold works", { + hot = HotstartStack$new(hotstart_threshold = c(iter = 2)) + + task = tsk("pima") + learner_1 = lrn("classif.debug", iter = 1) + learner_1$train(task) + + hot$add(learner_1) + expect_data_table(hot$stack, nrows = 0) + + task = tsk("pima") + learner_2 = lrn("classif.debug", iter = 2) + learner_2$train(task) + + hot$add(learner_2) + expect_data_table(hot$stack, nrows = 1) +}) diff --git 
a/tests/testthat/test_hotstart.R b/tests/testthat/test_hotstart.R index 4daabd09a..130108406 100644 --- a/tests/testthat/test_hotstart.R +++ b/tests/testthat/test_hotstart.R @@ -235,3 +235,26 @@ test_that("learners are cloned when hotstarting is applied", { expect_equal(bmr$resample_result(1)$learners[[1]]$model$id, hot$stack$start_learner[[1]]$model$id) }) + +test_that("hotstarting works when col role is set in task", { + task = tsk("pima") + task$col_roles$stratum = task$target_names + learner_1 = lrn("classif.debug", iter = 1) + resampling = rsmp("holdout") + resampling$instantiate(task) + + rr = resample(task, learner_1, resampling, store_models = TRUE) + + learner = lrn("classif.debug", iter = 2) + hot = HotstartStack$new(rr$learners) + learner$hotstart_stack = hot + + rr_2 = resample(task, learner, resampling, store_models = TRUE, allow_hotstart = TRUE) + pwalk(list(rr$learners, rr_2$learners), function(l1, l2) { + expect_equal(l2$param_set$values$iter, 2) + expect_class(l2$model, "classif.debug_model") + expect_equal(l2$model$iter, 2) + expect_equal(l1$model$id, l2$model$id) + expect_null(l2$hotstart_stack) + }) +})