Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] sktime API compliance test to loop through tests individually #94

Merged

Conversation

fkiraly
Copy link
Contributor

@fkiraly fkiraly commented Jul 18, 2024

This PR replaces the current logic in the sktime API compliance test, a bulk run of check_estimator, with parametrize_with_checks decoration and a run per test.

This allows to report individual failures of tests, and better diagnose memory leaks, such as those reported in #89.

Also removes the test skip.

@fkiraly fkiraly changed the title [ENH] sktime compliance test to loop through tests individually [ENH] sktime API compliance test to loop through tests individually Jul 18, 2024
Copy link

codecov bot commented Jul 18, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 94.83%. Comparing base (eb5bad3) to head (861f40a).

Additional details and impacted files
@@            Coverage Diff             @@
##             main      #94      +/-   ##
==========================================
+ Coverage   94.64%   94.83%   +0.18%     
==========================================
  Files          26       26              
  Lines        1177     1180       +3     
==========================================
+ Hits         1114     1119       +5     
+ Misses         63       61       -2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@felipeangelimvieira
Copy link
Owner

felipeangelimvieira commented Jul 18, 2024

Ubuntu breaks without logging error message, but tells error code 143.
It seems that test_update_predict_predicted_index breaks Windows. Some weird messages: Windows fatal exception: access violation 🤔

Seems related to jax/numpyro?

@fkiraly
Copy link
Contributor Author

fkiraly commented Jul 18, 2024

well, maybe my first guess is going to be proven accurate?

(from earlier discord discussion)

I have not gone in details, but my prima facie best guess would be:

  • some weird backend interaction around sth like numpyro
  • this causes excessive memory usage given the selected test parameters
  • i.e., sth specific to the estimator

@fkiraly
Copy link
Contributor Author

fkiraly commented Jul 18, 2024

maybe turning on verbose mode in the pytest logs helps further diagnose?

Feel free to take over this PR.

@fkiraly
Copy link
Contributor Author

fkiraly commented Jul 18, 2024

You could also try locally to loop through tests_to_run, and see where in the loop it breaks.

@felipeangelimvieira
Copy link
Owner

well, maybe my first guess is going to be proven accurate?

(from earlier discord discussion)

I have not gone in details, but my prima facie best guess would be:

  • some weird backend interaction around sth like numpyro
  • this causes excessive memory usage given the selected test parameters
  • i.e., sth specific to the estimator

Yes 🎯

maybe turning on verbose mode in the pytest logs helps further diagnose?

Feel free to take over this PR.

Thanks, I will dive deeper into this problem and try to find the source

@fkiraly
Copy link
Contributor Author

fkiraly commented Jul 24, 2024

could the failures perhaps be related to this, similar issue in sktime?
sktime/sktime#6826

Something about the load order of modules - in this specific case, if pandas is loaded before tensorflow, it causes a memory freeze.

FYI @fnhirwa

@fnhirwa
Copy link

fnhirwa commented Jul 27, 2024

could the failures perhaps be related to this, similar issue in sktime? sktime/sktime#6826

Something about the load order of modules - in this specific case, if pandas is loaded before tensorflow, it causes a memory freeze.

FYI @fnhirwa

This is not the same problem here, it is a computational resource issue, I just checked locally and the tests start to hang when testing test_predict_time_index A workaround here would be to use one param in get_test_params of ProphetVerse and test for one likelihood parameter as other models should be being tested through Prophet for normal ProphetGamma for gamma and ProphetNegBinomial for negbinomial

This can prevent the hang.

@@ -1,7 +1,7 @@
"""Test the sktime contract for Prophet and HierarchicalProphet."""

import pytest
from sktime.utils.estimator_checks import check_estimator
from sktime.utils.estimator_checks import check_estimator, parametrize_with_checks

from prophetverse.sktime import HierarchicalProphet, Prophetverse
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you import all the other models, parametrize them as PROPHET_MODELS and add a get_test_params to each would cause the tests to run without the existing issue.

@felipeangelimvieira
Copy link
Owner

Thank you @fnhirwa! Just got back from vacation and will take a look at what you've suggested

@felipeangelimvieira
Copy link
Owner

felipeangelimvieira commented Aug 5, 2024

This is not my area of expertise, but I suspect the issue might be related to garbage collection and JAX. I noticed that test_predict_time_index has around 54 parameterizations, and I suspect that something may be hanging after each iteration. I found this discussion on JAX's GitHub repo jax-ml/jax#15972 where a simple code using device_get causes a memory problem, and NumPyro's SVI class uses device_get during inference.

@felipeangelimvieira
Copy link
Owner

felipeangelimvieira commented Aug 5, 2024

Possibly also related to jax-ml/jax#17432

@fnhirwa
Copy link

fnhirwa commented Aug 5, 2024

This is not my area of expertise, but I suspect the issue might be related to garbage collection and JAX. I noticed that test_predict_time_index has around 54 parameterizations, and I suspect that something may be hanging after each iteration. I found this discussion on JAX's GitHub repo google/jax#15972 where a simple code using device_get causes a memory problem, and NumPyro's SVI class uses device_get during inference.

Sounds strange, will look into this in details 🤞

@felipeangelimvieira
Copy link
Owner

Since the test runs the inference with non-static observations, seems to be related to pyro-ppl/numpyro#1347

Will also try the suggestions that were given in that issue.

@felipeangelimvieira
Copy link
Owner

It seems that it is indeed related to memory leak when using Numpyro/JAX. After wrapping the SVI in a function as suggested in pyro-ppl/numpyro#1347, the tests have passed. Now the problem is how long it takes to run all those tests, which is strange since the models are relatively fast with the given test parameters.

@fkiraly
Copy link
Contributor Author

fkiraly commented Aug 6, 2024

Well, so the tests were useful after all...

The runtimes indeed look strange, these are in the order of tests for all estimators on sktime itself.

@felipeangelimvieira
Copy link
Owner

@fkiraly, I really liked using this decorator. It made it easier to debug where the problem was coming from.

On my MacBook, I detected two tests that are taking more time to finish: test_predict_time_index and test_predict_time_index_with_X. To me, it seems that the cause is a delay numpyro has when starting the first iteration. So, even if we set 1 MCMC sample or 1 optimization step for MAP, it takes about 1 second to execute fit and then some more time to predict, especially with negative binomial and gamma likelihoods. Since these tests have many parameterizations (for example, different FH values and types), these seconds add up, and each parameterization of test_predict_time_index can take about 3 seconds.

To avoid the 5-hour duration of the test, I'm not considering negative binomial and gamma in get_test_params of Prophetverse, since their difference from the normal likelihood is mainly the model attribute, and should not affect API compliance once the normal likelihood is OK.

========================================================================== slowest durations ==========================================================================
553.00s call     tests/sktime/test_sktime_check_estimator.py::test_sktime_api_compliance[ProphetNegBinomial-test_predict_time_index]
313.63s call     tests/sktime/test_sktime_check_estimator.py::test_sktime_api_compliance[ProphetGamma-test_predict_time_index]
291.02s call     tests/sktime/test_sktime_check_estimator.py::test_sktime_api_compliance[ProphetNegBinomial-test_update_predict_predicted_index]
198.98s call     tests/sktime/test_sktime_check_estimator.py::test_sktime_api_compliance[HierarchicalProphet-test_predict_time_index]
153.27s call     tests/sktime/test_sktime_check_estimator.py::test_sktime_api_compliance[ProphetNegBinomial-test_predict_time_index_with_X]
132.41s call     tests/sktime/test_sktime_check_estimator.py::test_sktime_api_compliance[ProphetGamma-test_update_predict_predicted_index]

@felipeangelimvieira felipeangelimvieira merged commit 3f6c300 into felipeangelimvieira:main Aug 13, 2024
14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants