Vectorize params in benchmark (#954)
* benchmark_grid + assertions

* chore: update

* fix

* cleanup assertion

* cleanups

* restore helper

* ...

* ...

* ...

* ...

* ...

* fix: prohibit cloning of measures

* fix: remove benchmark

* fix: remove clone from as.measure

* fix: prediction

* ...

* cleanup

* docs

* ...

* cleanup

* docs

* ...

* fix: constant values are inserted now

---------

Co-authored-by: be-marc <marcbecker@posteo.de>
Co-authored-by: Marc Becker <33069354+be-marc@users.noreply.github.com>
3 people committed Sep 27, 2023
1 parent 4e12931 commit 9ba6f82
Showing 11 changed files with 170 additions and 30 deletions.
10 changes: 5 additions & 5 deletions R/ResultData.R
@@ -40,15 +40,14 @@ ResultData = R6Class("ResultData",
self$data = star_init()
} else {
assert_names(names(data),
permutation.of = c("task", "learner", "learner_state", "resampling", "iteration", "prediction", "uhash"))
permutation.of = c("task", "learner", "learner_state", "resampling", "iteration", "param_values", "prediction", "uhash", "learner_hash"))

if (nrow(data) == 0L) {
self$data = star_init()
} else {
fact = data[, c("uhash", "iteration", "learner_state", "prediction", "task", "learner", "resampling"),
fact = data[, c("uhash", "iteration", "learner_state", "prediction", "task", "learner", "resampling", "param_values", "learner_hash"),
with = FALSE]
set(fact, j = "task_hash", value = hashes(fact$task))
set(fact, j = "learner_hash", value = hashes(fact$learner))
set(fact, j = "learner_phash", value = phashes(fact$learner))
set(fact, j = "resampling_hash", value = hashes(fact$resampling))

@@ -59,12 +58,13 @@
keyby = "learner_phash"]
resamplings = fact[, list(resampling = .SD$resampling[1L]),
keyby = "resampling_hash"]
learner_components = fact[, list(learner_param_vals = list(.SD$learner[[1L]]$param_set$values)),
learner_components = fact[, list(learner_param_vals = list(.SD$param_values[[1]])),
keyby = "learner_hash"]

set(fact, j = "task", value = NULL)
set(fact, j = "learner", value = NULL)
set(fact, j = "resampling", value = NULL)
set(fact, j = "param_values", value = NULL)
setkeyv(fact, c("uhash", "iteration"))

if (!store_backends) {
@@ -354,8 +354,8 @@ star_init = function() {
learner_state = list(),
prediction = list(),

task_hash = character(),
learner_hash = character(),
task_hash = character(),
learner_phash = character(),
resampling_hash = character(),

2 changes: 2 additions & 0 deletions R/as_result_data.R
@@ -72,7 +72,9 @@ as_result_data = function(task, learners, resampling, iterations, predictions, l
ResultData$new(data.table(
task = list(task),
learner = learners,
learner_hash = hashes(learners),
learner_state = learner_states,
param_values = map(learners, function(x) x$param_set$values),
resampling = list(resampling),
iteration = iterations,
prediction = predictions,
13 changes: 13 additions & 0 deletions R/assertions.R
@@ -332,3 +332,16 @@ assert_row_sums = function(prob) {
}
}
}

assert_param_values = function(x, n_learners = NULL, .var.name = vname(x)) {
assert_list(x, len = n_learners, .var.name = .var.name)

ok = every(x, function(x) {
test_list(x) && every(x, test_list, names = "unique", null.ok = TRUE)
})

if (!ok) {
stopf("'%s' must be a three-time nested list and the most inner list must be named", .var.name)
}
invisible(x)
}
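
As an illustrative sketch (not part of the diff, values made up), the shape the new assertion expects is one list element per learner, each holding zero or more named configurations:

ok  = list(list(list(x = 1), list(x = 0.5)), list())  # two configurations for learner 1, defaults only for learner 2
bad = list(list(x = 1))                               # configurations not wrapped in a per-learner list
assert_param_values(ok, n_learners = 2)               # passes and returns the input invisibly
try(assert_param_values(bad, n_learners = 1))         # errors: not a three-level nested list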
31 changes: 24 additions & 7 deletions R/benchmark.R
@@ -8,6 +8,7 @@
#' Each row defines a resampling by providing a [Task], [Learner] and an instantiated [Resampling] strategy.
#' The helper function [benchmark_grid()] can assist in generating an exhaustive design (see examples) and
#' instantiate the [Resampling]s per [Task].
#' Additionally, you can set the column 'param_values' in the design; see [benchmark_grid()].
#' @template param_store_models
#' @template param_store_backends
#' @template param_encapsulate
@@ -79,10 +80,15 @@
benchmark = function(design, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling")) {
assert_subset(clone, c("task", "learner", "resampling"))
assert_data_frame(design, min.rows = 1L)
assert_names(names(design), permutation.of = c("task", "learner", "resampling"))
assert_names(names(design), must.include = c("task", "learner", "resampling"))
design$task = list(assert_tasks(as_tasks(design$task)))
design$learner = list(assert_learners(as_learners(design$learner)))
design$resampling = list(assert_resamplings(as_resamplings(design$resampling), instantiated = TRUE))
if (is.null(design$param_values)) {
design$param_values = list()
} else {
design$param_values = list(assert_param_values(design$param_values, n_learners = length(design$learner)))
}
assert_flag(store_models)
assert_flag(store_backends)

@@ -113,15 +119,24 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps
# set encapsulation + fallback
set_encapsulation(design$learner, encapsulate)

# expand the design: add rows for each resampling iteration
grid = pmap_dtr(design, function(task, learner, resampling) {
# expand the design: add rows for each resampling iteration and param_values
grid = pmap_dtr(design, function(task, learner, resampling, param_values) {
# learner = assert_learner(as_learner(learner, clone = TRUE))
assert_learnable(task, learner)

iters = resampling$iters
n_params = max(1L, length(param_values))
# insert constant values
param_values = map(param_values, function(values) insert_named(learner$param_set$values, values))

data.table(
task = list(task), learner = list(learner), resampling = list(resampling),
iteration = seq_len(resampling$iters), uhash = UUIDgenerate()
iteration = rep(seq_len(iters), times = n_params),
param_values = if (is.null(param_values)) list() else rep(param_values, each = iters),
uhash = rep(UUIDgenerate(n = n_params), each = iters)
)
})

n = nrow(grid)
lgr_threshold = map_int(mlr_reflections$loggers, "threshold")

@@ -167,17 +182,19 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps
}

res = future_map(n, workhorse,
task = grid$task, learner = grid$learner, resampling = grid$resampling, iteration = grid$iteration, mode = grid$mode,
task = grid$task, learner = grid$learner, resampling = grid$resampling, iteration = grid$iteration, param_values = grid$param_values, mode = grid$mode,
MoreArgs = list(store_models = store_models, lgr_threshold = lgr_threshold, pb = pb)
)

grid = insert_named(grid, list(
learner_state = map(res, "learner_state"),
prediction = map(res, "prediction")
prediction = map(res, "prediction"),
param_values = map(res, "param_values"),
learner_hash = map_chr(res, "learner_hash")
))

lg$info("Finished benchmark")

grid$mode = NULL
set(grid, j = "mode", value = NULL)
BenchmarkResult$new(ResultData$new(grid, store_backends = store_backends))
}
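
For illustration (a sketch assuming mlr3 and data.table are attached; it mirrors the tests added in tests/testthat/test_benchmark.R below), a manual design with the new 'param_values' column could look like this:

library(mlr3)
library(data.table)

task = tsk("iris")
resampling = rsmp("cv", folds = 3)$instantiate(task)
learner = lrn("classif.rpart", minsplit = 12)  # minsplit stays set and acts as a constant value

design = data.table(
  task = list(task),
  learner = list(learner),
  resampling = list(resampling),
  param_values = list(list(list(cp = 0.1), list(cp = 0.01)))  # two configurations for the single learner
)
bmr = benchmark(design)
# 3 folds x 2 configurations = 6 iterations, grouped into 2 resample results;
# each trained learner keeps minsplit = 12 and additionally receives its cp value.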
28 changes: 16 additions & 12 deletions R/benchmark_grid.R
@@ -18,6 +18,7 @@
#' @param tasks (list of [Task]).
#' @param learners (list of [Learner]).
#' @param resamplings (list of [Resampling]).
#' @template param_param_values
#' @param paired (`logical(1)`)\cr
#' Set this to `TRUE` if the resamplings are instantiated on the tasks, i.e., the tasks and resamplings are paired.
#' You need to provide the same number of tasks and instantiated resamplings.
@@ -64,10 +65,13 @@
#' benchmark(grid)
#' }
#'
benchmark_grid = function(tasks, learners, resamplings, paired = FALSE) {
benchmark_grid = function(tasks, learners, resamplings, param_values = NULL, paired = FALSE) {
tasks = assert_tasks(as_tasks(tasks))
learners = assert_learners(as_learners(learners))
resamplings = assert_resamplings(as_resamplings(resamplings))
if (!is.null(param_values)) {
assert_param_values(param_values, n_learners = length(learners))
}

if (assert_flag(paired)) {
if (length(tasks) != length(resamplings)) {
@@ -86,13 +90,7 @@ benchmark_grid = function(tasks, learners, resamplings, paired = FALSE) {
}

grid = CJ(task = seq_along(tasks), learner = seq_along(learners))
grid$instance = seq_row(grid)

tab = data.table(
task = tasks[grid$task],
learner = learners[grid$learner],
resampling = resamplings[grid$task]
)
tab = data.table(task = tasks[grid$task], learner = learners[grid$learner], resampling = resamplings[grid$task])
} else {
grid = CJ(task = seq_along(tasks), resampling = seq_along(resamplings))
is_instantiated = map_lgl(resamplings, "is_instantiated")
@@ -116,14 +114,20 @@
tab = data.table(task = tasks[grid$task], learner = learners[grid$learner], resampling = instances[grid$instance])
}

if (!is.null(param_values)) {
set(tab, j = "param_values", value = list(param_values[grid$learner]))
}

set_data_table_class(tab, "benchmark_grid")
return(tab)
}

#' @export
print.benchmark_grid = function(x, ...) {
task_ids = map(x$task, "id")
learner_ids = map(x$learner, "id")
resampling_ids = map(x$resampling, "id")
print(data.table(task = task_ids, learner = learner_ids, resampling = resampling_ids))
print(data.table(
task = ids(x$task),
learner = ids(x$learner),
resampling = ids(x$resampling),
param_values = if (is.null(x$param_values)) NULL else sprintf("<%i>", lengths(x$param_values))
))
}
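
A usage sketch of the extended benchmark_grid() (task, learner, and resampling choices are only illustrative):

library(mlr3)

design = benchmark_grid(
  tasks = tsks("iris"),
  learners = lrns(c("classif.debug", "classif.featureless")),
  resamplings = rsmps("holdout"),
  param_values = list(
    list(list(x = 1), list(x = 0.5)),  # two configurations for classif.debug
    list()                             # classif.featureless runs with its defaults
  )
)
print(design)  # the param_values column is summarized as "<2>" and "<0>"
bmr = benchmark(design)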
4 changes: 3 additions & 1 deletion R/resample.R
@@ -116,7 +116,9 @@ resample = function(task, learner, resampling, store_models = FALSE, store_backe
resampling = list(resampling),
iteration = seq_len(n),
prediction = map(res, "prediction"),
uhash = UUIDgenerate()
uhash = UUIDgenerate(),
param_values = map(res, "param_values"),
learner_hash = map_chr(res, "learner_hash")
)

ResampleResult$new(ResultData$new(data, store_backends = store_backends))
12 changes: 9 additions & 3 deletions R/worker.R
@@ -224,7 +224,7 @@ learner_predict = function(learner, task, row_ids = NULL) {
}


workhorse = function(iteration, task, learner, resampling, lgr_threshold, store_models = FALSE, pb = NULL, mode = "train") {
workhorse = function(iteration, task, learner, resampling, param_values = NULL, lgr_threshold, store_models = FALSE, pb = NULL, mode = "train") {
if (!is.null(pb)) {
pb(sprintf("%s|%s|i:%i", task$id, learner$id, iteration))
}
@@ -246,7 +246,13 @@ workhorse = function(iteration, task, learner, resampling, lgr_threshold, store_
)

# train model
learner = learner_train(learner$clone(), task, sets[["train"]], sets[["test"]], mode = mode)
learner = learner$clone()
if (length(param_values)) {
learner$param_set$values = list()
learner$param_set$set_values(.values = param_values)
}
learner_hash = learner$hash
learner = learner_train(learner, task, sets[["train"]], sets[["test"]], mode = mode)

# predict for each set
sets = sets[learner$predict_sets]
@@ -261,7 +267,7 @@ workhorse = function(iteration, task, learner, resampling, lgr_threshold, store_
learner$state$model = NULL
}

list(learner_state = learner$state, prediction = pdatas)
list(learner_state = learner$state, prediction = pdatas, param_values = learner$param_set$values, learner_hash = learner_hash)
}

append_log = function(log = NULL, stage = NA_character_, class = NA_character_, msg = character()) {
9 changes: 9 additions & 0 deletions man-roxygen/param_param_values.R
@@ -0,0 +1,9 @@
#' @param param_values (`list()`)\cr
#' If you want to try many parameter settings for learners, you can pass them through the design,
#' which is faster than creating a new learner for each setting.
#'
#' A list of lists of named lists, from outer to inner:
#' 1. One list element for each [Learner].
#' 2. One list element for each hyperparameter configuration to try.
#' 3. Named list of hyperparameter settings to set in the Learner, possibly overwriting
#'    already set hyperparameters in the [Learner].
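
Spelled out for two learners (values are illustrative only), the nesting described above looks as follows:

param_values = list(              # 1. one element per Learner
  list(                           # 2. one element per configuration for the first Learner
    list(cp = 0.1),               # 3. named settings applied to that Learner
    list(cp = 0.01, minsplit = 5)
  ),
  list()                          # the second Learner is evaluated with its current values only
)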
3 changes: 2 additions & 1 deletion man/benchmark.Rd


20 changes: 19 additions & 1 deletion man/benchmark_grid.Rd


68 changes: 68 additions & 0 deletions tests/testthat/test_benchmark.R
@@ -408,3 +408,71 @@ test_that("benchmark_grid works if paired = TRUE", {
resamplings = rev(resamplings)
expect_error(benchmark_grid(tasks, learners, resamplings, paired = TRUE))
})

test_that("param_values in benchmark", {
# setup
tasks = tsks("iris")
resamplings = list(rsmp("cv", folds = 3)$instantiate(tasks[[1]]))
learners = lrns("classif.debug")

# single parameter set via manual design
design = data.table(task = tasks, learner = learners, resampling = resamplings, param_values = list(list(list(x = 1))))
bmr = benchmark(design)
expect_benchmark_result(bmr)
expect_equal(bmr$n_resample_results, 1)
expect_equal(nrow(as.data.table(bmr)), 3)
learner = bmr$learners$learner[[1]]
expect_equal(learner$param_set$values$x, 1)
expect_equal(nrow(as.data.table(bmr)), 3)

# multiple parameters set via manual design
design = data.table(task = tasks, learner = learners, resampling = resamplings, param_values = list(list(list(x = 1), list(x = 0.5))))
bmr = benchmark(design)
expect_benchmark_result(bmr)
expect_equal(bmr$n_resample_results, 2)
expect_equal(nrow(as.data.table(bmr)), 6)
learner = bmr$learners$learner[[1]]
expect_equal(learner$param_set$values$x, 1)
learner = bmr$learners$learner[[2]]
expect_equal(learner$param_set$values$x, 0.5)

# benchmark grid does not attach param_values if empty
design = benchmark_grid(tasks, learners, resamplings)
expect_names(names(design), permutation.of = c("task", "learner", "resampling"))

# benchmark grid with param_values
design = benchmark_grid(tasks, learners, resamplings, param_values = list(list(list(x = 1))))
expect_data_table(design, nrows = 1)
expect_names(names(design), permutation.of = c("task", "learner", "resampling", "param_values"))
bmr = benchmark(design)
expect_benchmark_result(bmr)

# benchmark grid with param_values and paired = TRUE
design = benchmark_grid(tasks, learners, resamplings, param_values = list(list(list(x = 1))), paired = TRUE)
expect_data_table(design, nrows = 1)
bmr = benchmark(design)
expect_benchmark_result(bmr)
expect_equal(bmr$n_resample_results, 1)

# benchmark grid with multiple params
design = benchmark_grid(tasks, learners, resamplings, param_values = list(list(list(x = 1), list(x = 0.5))))
expect_data_table(design, nrows = 1)
bmr = benchmark(design)
expect_benchmark_result(bmr)
expect_equal(bmr$n_resample_results, 2)


# benchmark grid with multiple params and multiple learners
design = benchmark_grid(tasks, lrns(c("classif.debug", "classif.debug")), rsmp("holdout"), param_values = list(list(list(x = 1), list(x = 0.5)), list()))
bmr = benchmark(design)
expect_benchmark_result(bmr)
expect_equal(bmr$n_resample_results, 3)

# constant values are inserted
learners = lrns("classif.rpart", minsplit = 12)
design = data.table(task = tasks, learner = learners, resampling = resamplings, param_values = list(list(list(cp = 0.1), list(minbucket = 2))))
bmr = benchmark(design)

expect_equal(bmr$learners$learner[[1]]$param_set$values, list(xval = 0, minsplit = 12, minbucket = 2))
expect_equal(bmr$learners$learner[[2]]$param_set$values, list(xval = 0, minsplit = 12, cp = 0.1))
})
