Skip to content

Commit

Permalink
feat(prophet): enable confidence intervals and y_hat without forecast (
Browse files Browse the repository at this point in the history
…#17658)

* enable confidence intervals and y_hat without forecast

* fix if statement

Co-authored-by: Ville Brofeldt <33317356+villebro@users.noreply.github.com>

Co-authored-by: Ville Brofeldt <33317356+villebro@users.noreply.github.com>
  • Loading branch information
exemplary-citizen and villebro authored Dec 8, 2021
1 parent 418c0b4 commit cd88b8e
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 4 deletions.
2 changes: 1 addition & 1 deletion superset/charts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ class ChartDataProphetOptionsSchema(ChartDataPostProcessingOperationOptionsSchem
)
periods = fields.Integer(
descrption="Time periods (in units of `time_grain`) to predict into the future",
min=1,
min=0,
example=7,
required=True,
)
Expand Down
4 changes: 2 additions & 2 deletions superset/utils/pandas_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,8 +820,8 @@ def prophet( # pylint: disable=too-many-arguments
freq = PROPHET_TIME_GRAIN_MAP[time_grain]
# check type at runtime due to marhsmallow schema not being able to handle
# union types
if not periods or periods < 0 or not isinstance(periods, int):
raise QueryObjectValidationError(_("Periods must be a positive integer value"))
if not isinstance(periods, int) or periods < 0:
raise QueryObjectValidationError(_("Periods must be a whole number"))
if not confidence_interval or confidence_interval <= 0 or confidence_interval >= 1:
raise QueryObjectValidationError(
_("Confidence interval must be between 0 and 1 (exclusive)")
Expand Down
24 changes: 23 additions & 1 deletion tests/integration_tests/pandas_postprocessing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,28 @@ def test_prophet_valid(self):
assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 5, 31)
assert len(df) == 9

def test_prophet_valid_zero_periods(self):
pytest.importorskip("prophet")

df = proc.prophet(
df=prophet_df, time_grain="P1M", periods=0, confidence_interval=0.9
)
columns = {column for column in df.columns}
assert columns == {
DTTM_ALIAS,
"a__yhat",
"a__yhat_upper",
"a__yhat_lower",
"a",
"b__yhat",
"b__yhat_upper",
"b__yhat_lower",
"b",
}
assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2021, 12, 31)
assert len(df) == 4

def test_prophet_import(self):
prophet = find_spec("prophet")
if prophet is None:
Expand Down Expand Up @@ -875,7 +897,7 @@ def test_prophet_incorrect_periods(self):
proc.prophet,
df=prophet_df,
time_grain="P1M",
periods=0,
periods=-1,
confidence_interval=0.8,
)

Expand Down

0 comments on commit cd88b8e

Please sign in to comment.