Skip to content

Commit

Permalink
[ENH] sktime API compliance test to loop through tests individually
Browse files Browse the repository at this point in the history
Closes #47
  • Loading branch information
felipeangelimvieira authored Aug 13, 2024
2 parents eb5bad3 + 861f40a commit 3f6c300
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 38 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ jobs:
run: echo "PYTHONPATH=$GITHUB_WORKSPACE/src" >> $GITHUB_ENV

- name: Test with pytest
run: python -m pytest --cov=prophetverse --cov-report=xml -m "not smoke"
run: python -m pytest --cov=prophetverse --cov-report=xml -m "not smoke" --durations=10

- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v4.5.0
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,13 @@ dev = [
]

[tool.pytest.ini_options]
log_cli = true
markers = [
"ci: marks tests for Continuous Integration",
"smoke: marks tests for smoke testing",
]


[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
Expand Down
53 changes: 43 additions & 10 deletions src/prophetverse/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,20 @@ def infer(self, **kwargs):
The updated MAPInferenceEngine object.
"""
self.guide_ = AutoDelta(self.model, init_loc_fn=init_to_mean())
svi_ = SVI(self.model, self.guide_, self.optimizer_factory(), loss=Trace_ELBO())
self.run_results_: SVIRunResult = svi_.run(
rng_key=self.rng_key, num_steps=self.num_steps, **kwargs

def get_result(
rng_key, model, guide, optimizer, num_steps, **kwargs
) -> SVIRunResult:
svi_ = SVI(model, guide, optimizer, loss=Trace_ELBO())
return svi_.run(rng_key=rng_key, num_steps=num_steps, **kwargs)

self.run_results_: SVIRunResult = get_result(
self.rng_key,
self.model,
self.guide_,
self.optimizer_factory(),
self.num_steps,
**kwargs
)

self.raise_error_if_nan_loss(self.run_results_)
Expand Down Expand Up @@ -255,13 +266,36 @@ def infer(self, **kwargs):
self
The MCMCInferenceEngine object.
"""
self.mcmc_ = MCMC(
NUTS(self.model, dense_mass=self.dense_mass, init_strategy=init_to_mean()),

def get_posterior_samples(
rng_key,
model,
dense_mass,
init_strategy,
num_samples,
num_warmup,
num_chains,
**kwargs
) -> MCMC:
mcmc_ = MCMC(
NUTS(model, dense_mass=dense_mass, init_strategy=init_strategy),
num_samples=num_samples,
num_warmup=num_warmup,
num_chains=num_chains,
)
mcmc_.run(rng_key, **kwargs)
return mcmc_.get_samples()

self.posterior_samples_ = get_posterior_samples(
self.rng_key,
self.model,
self.dense_mass,
init_strategy=init_to_mean,
num_samples=self.num_samples,
num_warmup=self.num_warmup,
num_chains=self.num_chains,
**kwargs
)
self.mcmc_.run(self.rng_key, **kwargs)
self.posterior_samples_ = self.mcmc_.get_samples()
return self

def predict(self, **kwargs):
Expand All @@ -282,9 +316,8 @@ def predict(self, **kwargs):
self.model, self.posterior_samples_, num_samples=self.num_samples
)

numpyro.samples_predictive_ = predictive(self.rng_key, **kwargs)
numpyro.samples_ = self.mcmc_.get_samples()
return numpyro.samples_predictive_
self.samples_predictive_ = predictive(self.rng_key, **kwargs)
return self.samples_predictive_


class MAPInferenceEngineError(Exception):
Expand Down
15 changes: 11 additions & 4 deletions src/prophetverse/sktime/multivariate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Contains the implementation of the HierarchicalProphet forecaster."""

from typing import List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -386,7 +386,7 @@ def n_series(self):
)

@classmethod
def get_test_params(cls, parameter_set="default") -> List[dict[str, int]]:
def get_test_params(cls, parameter_set="default") -> List[dict[str, Any]]:
"""Params to be used in sktime unit tests.
Parameters
Expand All @@ -401,6 +401,13 @@ def get_test_params(cls, parameter_set="default") -> List[dict[str, int]]:
"""
return [
{
"optimizer_steps": 1_000,
}
"optimizer_steps": 1,
"inference_method": "map",
},
{
"inference_method": "mcmc",
"mcmc_samples": 1,
"mcmc_warmup": 1,
"mcmc_chains": 1,
},
]
21 changes: 13 additions & 8 deletions src/prophetverse/sktime/univariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,14 +357,19 @@ def get_test_params(cls, parameter_set="default"): # pragma: no cover
List[dict[str, int]]
A list of dictionaries containing the test parameters.
"""
params = []
for likelihood in _LIKELIHOOD_MODEL_MAP.keys():
params.append(
{
"likelihood": likelihood,
"optimizer_steps": 10,
}
)
params = [
{
"optimizer_steps": 1,
"inference_method": "map",
},
{
"inference_method": "mcmc",
"mcmc_samples": 1,
"mcmc_warmup": 1,
"mcmc_chains": 1,
},
]

return params


Expand Down
8 changes: 4 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
"""Configure tests and declare global fixtures."""

import warnings
import warnings # noqa: F401

import jax.numpy as jnp
import numpyro
import numpyro # noqa: F401
import pandas as pd
import pytest

from prophetverse.effects.base import BaseAdditiveOrMultiplicativeEffect

warnings.filterwarnings("ignore")
# warnings.filterwarnings("ignore")


def pytest_sessionstart(session):
"""Avoid NaNs in tests."""
numpyro.enable_x64()
# numpyro.enable_x64()


@pytest.fixture(name="effects_sample_data")
Expand Down
19 changes: 8 additions & 11 deletions tests/sktime/test_sktime_check_estimator.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
"""Test the sktime contract for Prophet and HierarchicalProphet."""

import pytest
from sktime.utils.estimator_checks import check_estimator
import gc # noqa: F401

import pytest # noqa: F401
from sktime.utils.estimator_checks import check_estimator, parametrize_with_checks

from prophetverse.sktime import HierarchicalProphet, Prophetverse

PROPHET_MODELS = [
Prophetverse,
HierarchicalProphet,
]
PROPHET_MODELS = [Prophetverse, HierarchicalProphet]


@pytest.mark.skip(reason="Temporarily disabled")
@pytest.mark.parametrize("model", PROPHET_MODELS)
def test_check_estimator(model):
@parametrize_with_checks(PROPHET_MODELS)
def test_sktime_api_compliance(obj, test_name):
"""Test the sktime contract for Prophet and HierarchicalProphet."""

check_estimator(model, raise_exceptions=True)
check_estimator(obj, tests_to_run=test_name, raise_exceptions=True)

0 comments on commit 3f6c300

Please sign in to comment.