fix: hotstart stack (#964)
* fix: overwrite train time when hotstarting

* fix: disallow learners without model in the hotstart stack

* chore: browser

* refactor: add debug messages

* feat: add debug messages to resample

* refactor: debug message

* fix: hotstarting works with set col role

* feat: add hotstart threshold

* fix: allow null

* chore: update news

* chore: typo
be-marc committed Nov 11, 2023
1 parent ad0ed92 commit 33b2473
Showing 8 changed files with 101 additions and 17 deletions.
6 changes: 5 additions & 1 deletion NEWS.md
@@ -1,9 +1,13 @@
# mlr3 0.16.1-9000

* Learners can no longer be added to the `HotstartStack` when their model is missing.
* Learners whose hotstart parameter value is below the new `hotstart_threshold` are no longer added to the `HotstartStack` (illustrated in the sketch after this list).
* The `learner$state$train_time` of hotstarted learners now reports only the time of the last training, not the accumulated time of all previous fits.
* Added debug messages to the hotstart stack.
* Fixed a bug where the `HotstartStack` did not work with column roles set in the task.
* Exported the generic function `col_info()` to allow adding new methods for backends.
* Added the `"mlr3.exec_chunk_bins"` option to split the resampling iterations into a number of bins.
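A minimal sketch of the two stack changes, assuming the `classif.debug` learner and its `iter` hotstart parameter:

```r
library(mlr3)

task = tsk("pima")

# an untrained learner now raises an error
try(HotstartStack$new(lrn("classif.debug", iter = 1)))

# a learner whose hotstart parameter falls below the threshold is skipped
learner = lrn("classif.debug", iter = 1)$train(task)
stack = HotstartStack$new(hotstart_threshold = c(iter = 2))
stack$add(learner)
nrow(stack$stack)  # 0
```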


# mlr3 0.16.1

* Function `data.table()` is now re-exported.
30 changes: 27 additions & 3 deletions R/HotstartStack.R
@@ -54,14 +54,22 @@ HotstartStack = R6Class("HotstartStack",
#' Stores hot start learners.
stack = NULL,

#' @field hotstart_threshold (named `numeric(1)`)\cr
#' Threshold for storing learners in the stack.
#' If the value of the hotstart parameter is below this threshold, the learner is not added to the stack.
hotstart_threshold = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
#' @param learners (List of [Learner]s)\cr
#' Learners are added to the hotstart stack. If `NULL` (default), an empty
#' stack is created.
-initialize = function(learners = NULL) {
+#' @param hotstart_threshold (named `numeric(1)`)\cr
+#' Threshold for storing learners in the stack.
+initialize = function(learners = NULL, hotstart_threshold = NULL) {
self$stack = data.table()
+self$hotstart_threshold = assert_numeric(hotstart_threshold, names = "named", null.ok = TRUE)

# add learners to stack
if (!is.null(learners)) self$add(learners)
@@ -77,19 +77,33 @@
add = function(learners) {
learners = assert_learners(as_learners(learners))

# check for models
if (any(map_lgl(learners, function(learner) is.null(learner$state$model)))) {
stopf("Learners must be trained before adding them to the hotstart stack.")
}

if (!is.null(self$hotstart_threshold)) {
learners = keep(learners, function(learner) {
hotstart_id = learner$param_set$ids(tags = "hotstart")
learner$param_set$values[[hotstart_id]] >= self$hotstart_threshold[hotstart_id]
})
if (!length(learners)) return(invisible(self))
}

# hashes
task_hash = map_chr(learners, function(learner) learner$state$task_hash)
learner_hash = map_chr(learners, learner_hotstart_hash)

lg$debug("Adding %s learner(s) to the hotstart stack.", length(learners))

self$stack = rbindlist(list(self$stack, data.table(start_learner = learners, task_hash, learner_hash)))
setkeyv(self$stack, c("task_hash", "learner_hash"))

invisible(self)
},

#' @description
-#' Calculates the cost for each learner of the stack to hot start the target
-#' `learner`.
+#' Calculates the cost for each learner of the stack to hot start the target `learner`.
#'
#' The following cost values can be returned:
#'
@@ -152,6 +174,8 @@ HotstartStack = R6Class("HotstartStack",
][, "cost" := map_dbl(start_learner, function(l) calculate_cost(l, learner, hotstart_id))
][which_min(get("cost"), na_rm = TRUE), start_learner]

lg$debug("Found %i start learner(s) in hotstart stack of size %i.", length(start_learner), nrow(self$stack))

if (!length(start_learner)) return(NULL)
learner$state = start_learner[[1]]$state
learner
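A minimal sketch of the resulting lookup, assuming a trained `classif.debug` learner and the forward-hotstart cost semantics documented above; `$start_cost()` and `task$hash` are existing mlr3 API:

```r
library(mlr3)

task = tsk("pima")
start = lrn("classif.debug", iter = 1)$train(task)

stack = HotstartStack$new(start, hotstart_threshold = c(iter = 1))

# cost of continuing the stored model to iter = 2
target = lrn("classif.debug", iter = 2)
stack$start_cost(target, task$hash)  # e.g. 1: one additional iteration to train
```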
3 changes: 1 addition & 2 deletions R/helper_hashes.R
@@ -21,7 +21,6 @@ task_hashes = function(task, resampling) {
map_chr(seq_len(resampling$iters), function(i) {
train_set = resampling$train_set(i)
row_roles$use = train_set
-calculate_hash(class(task), task$id, task$backend$hash, task$col_info, row_roles, task$col_roles,
-task$properties)
+calculate_hash(class(task), task$id, task$backend$hash, task$col_info, row_roles, task$col_roles, get_private(task)$.properties)
})
}
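The switch to `get_private(task)$.properties` matters because the public `$properties` active binding also reports dynamically derived properties. A minimal sketch of the behavior this fixes, assuming a stratum column role on a classification task:

```r
library(mlr3)

task = tsk("pima")
task$properties                            # "twoclass"

# setting a column role adds a derived property ...
task$col_roles$stratum = task$target_names
task$properties                            # now also contains "strata"

# ... which previously changed the train-set hashes and broke hotstart lookups
```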
7 changes: 6 additions & 1 deletion R/resample.R
@@ -77,17 +77,22 @@ resample = function(task, learner, resampling, store_models = FALSE, store_backe
lgr_threshold = map_int(mlr_reflections$loggers, "threshold")

grid = if (allow_hotstart) {
-hotstart_grid = map_dtr(seq_len(n), function(iteration) {
+
+lg$debug("Resampling with hotstart enabled.")
+
+hotstart_grid = map_dtr(seq_len(n), function(iteration) {
if (!is.null(learner$hotstart_stack)) {
# search for hotstart learner
task_hashes = task_hashes(task, resampling)
start_learner = get_private(learner$hotstart_stack)$.start_learner(learner$clone(), task_hashes[iteration])
}
if (is.null(learner$hotstart_stack) || is.null(start_learner)) {
# no hotstart learners stored or no adaptable model found
lg$debug("Resampling with hotstarting not possible. Not start learner found.")
mode = "train"
} else {
# hotstart learner found
lg$debug("Resampling with hotstarting.")
start_learner$param_set$values = insert_named(start_learner$param_set$values, learner$param_set$values)
learner = start_learner
mode = "hotstart"
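The new messages are emitted on the `"mlr3"` logger at the `debug` level and are hidden by default. A minimal sketch of how to surface them, assuming the lgr package that mlr3 logs through and the `classif.debug` learner:

```r
library(mlr3)

# lower the logger threshold so the new debug messages are printed
lgr::get_logger("mlr3")$set_threshold("debug")

task = tsk("pima")
resampling = rsmp("holdout")
resampling$instantiate(task)

# train start learners on the same instantiated resampling
rr = resample(task, lrn("classif.debug", iter = 1), resampling, store_models = TRUE)

learner = lrn("classif.debug", iter = 2)
learner$hotstart_stack = HotstartStack$new(rr$learners)

resample(task, learner, resampling, allow_hotstart = TRUE)
# DEBUG ... Resampling with hotstart enabled.
# DEBUG ... Found 1 start learner(s) in hotstart stack of size 1.
```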
9 changes: 2 additions & 7 deletions R/worker.R
@@ -65,13 +65,8 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL
.timeout = learner$timeout["train"]
)

if (mode == "train") {
log = append_log(NULL, "train", result$log$class, result$log$msg)
train_time = result$elapsed
} else { # hotstart
log = rbindlist(list(learner$state$log, append_log(NULL, "train", result$log$class, result$log$msg)))
train_time = learner$state$train_time + result$elapsed
}
log = append_log(NULL, "train", result$log$class, result$log$msg)
train_time = result$elapsed

proto = task$data(rows = integer())
learner$state = insert_named(learner$state, list(
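With the accumulating branch removed, `$state$train_time` of a hotstarted learner covers only the last fit. A minimal sketch of the new semantics, assuming the `classif.debug` learner and that `$train()` consults the attached `hotstart_stack`:

```r
library(mlr3)

task = tsk("pima")
start = lrn("classif.debug", iter = 1)$train(task)

learner = lrn("classif.debug", iter = 2)
learner$hotstart_stack = HotstartStack$new(start)
learner$train(task)  # continues the stored model instead of refitting

# only the elapsed time of the continued fit,
# no longer start$state$train_time plus the new fit
learner$state$train_time
```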
12 changes: 9 additions & 3 deletions man/HotstartStack.Rd

(generated file; diff not rendered)

28 changes: 28 additions & 0 deletions tests/testthat/test_HotstartStack.R
@@ -381,3 +381,31 @@ test_that("HotstartStack works without hotstart parameter", {
expect_equal(hot$start_cost(target_learner, task$hash), NA_real_)
expect_null(get_private(target_learner$hotstart_stack)$.start_learner(target_learner, task$hash))
})


test_that("Learners without a model are not added to the hotstart stack", {
task = tsk("pima")
learner = lrn("classif.debug", iter = 1)
learner$train(task)
learner$state$model = NULL

expect_error(HotstartStack$new(learner), "Learners must be trained before adding them to the hotstart stack.")
})

test_that("HotstartStack threshold works", {
hot = HotstartStack$new(hotstart_threshold = c(iter = 2))

task = tsk("pima")
learner_1 = lrn("classif.debug", iter = 1)
learner_1$train(task)

hot$add(learner_1)
expect_data_table(hot$stack, nrows = 0)

task = tsk("pima")
learner_2 = lrn("classif.debug", iter = 2)
learner_2$train(task)

hot$add(learner_2)
expect_data_table(hot$stack, nrows = 1)
})
23 changes: 23 additions & 0 deletions tests/testthat/test_hotstart.R
@@ -235,3 +235,26 @@ test_that("learners are cloned when hotstarting is applied", {
expect_equal(bmr$resample_result(1)$learners[[1]]$model$id,
hot$stack$start_learner[[1]]$model$id)
})

test_that("hotstarting works when col role is set in task", {
task = tsk("pima")
task$col_roles$stratum = task$target_names
learner_1 = lrn("classif.debug", iter = 1)
resampling = rsmp("holdout")
resampling$instantiate(task)

rr = resample(task, learner_1, resampling, store_models = TRUE)

learner = lrn("classif.debug", iter = 2)
hot = HotstartStack$new(rr$learners)
learner$hotstart_stack = hot

rr_2 = resample(task, learner, resampling, store_models = TRUE, allow_hotstart = TRUE)
pwalk(list(rr$learners, rr_2$learners), function(l1, l2) {
expect_equal(l2$param_set$values$iter, 2)
expect_class(l2$model, "classif.debug_model")
expect_equal(l2$model$iter, 2)
expect_equal(l1$model$id, l2$model$id)
expect_null(l2$hotstart_stack)
})
})
