Skip to content

Commit

Permalink
feat: allow response different from quantiles (#1178)
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Sep 14, 2024
1 parent 6f032fe commit 461588a
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 3 deletions.
4 changes: 2 additions & 2 deletions R/PredictionDataRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ check_prediction_data.PredictionDataRegr = function(pdata, ...) { # nolint
stopf("No probs attribute stored in 'quantile'")
}

if (is.null(attr(quantiles, "response"))) {
stopf("No response attribute stored in 'quantile'")
if (is.null(attr(quantiles, "response")) && is.null(pdata$response)) {
stopf("No response attribute stored in 'quantile' or response stored in 'pdata'")
}

if (any(apply(quantiles, 1L, is.unsorted))) {
Expand Down
2 changes: 1 addition & 1 deletion R/PredictionRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction,
# response is in saved in quantiles matrix
if ("quantiles" %in% predict_types) predict_types = union(predict_types, "response")
self$predict_types = predict_types
private$.quantile_response = attr(quantiles, "response")
if (is.null(pdata$response)) private$.quantile_response = attr(quantiles, "response")
}
),

Expand Down
35 changes: 35 additions & 0 deletions tests/testthat/test_PredictionDataRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,38 @@ test_that("construction of empty PredictionDataRegr", {
expect_numeric(pred$se, len = 0L)
expect_data_table(as.data.table(pred), nrows = 0L, ncols = 4L)
})

test_that("PredictionDataRegr with quantiles", {
n = 100
probs = c(0.1, 0.5, 0.9)
y = runif(n)
task = as_task_regr(data.table(y = y), target = "y")

quantiles = quantile(y, probs = probs)
quantiles = matrix(rep(quantiles, n), nrow = n, byrow = TRUE)
attr(quantiles, "probs") = probs
attr(quantiles, "response") = 0.5

data = list(quantiles = quantiles)
pdata = as_prediction_data(data, task)

pred = as_prediction(pdata)
expect_prediction_regr(pred)
})

test_that("PredictionDataRegr with quantiles and response", {
n = 100
probs = c(0.1, 0.9)
y = runif(n)
task = as_task_regr(data.table(y = y), target = "y")

quantiles = quantile(y, probs = probs)
quantiles = matrix(rep(quantiles, n), nrow = n, byrow = TRUE)
attr(quantiles, "probs") = probs

data = list(quantiles = quantiles, response = rep(0.5, n))
pdata = as_prediction_data(data, task)

pred = as_prediction(pdata)
expect_prediction_regr(pred)
})

0 comments on commit 461588a

Please sign in to comment.