Skip to content

Commit

Permalink
Test default_progress_fn, skip types.py (#404)
Browse files Browse the repository at this point in the history
* Test `default_progress_fn`, skip `types.py`

* Fix ignore path

* [ci skip] Use glob
  • Loading branch information
michalk8 authored Aug 2, 2023
1 parent 28a2bf8 commit f275dc4
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ source = ["src/"]
omit = [
"*/__init__.py",
"*/_version.py",
"*/types.py",
]

[tool.coverage.report]
Expand Down
3 changes: 2 additions & 1 deletion src/ott/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
import numpy as np

__all__ = [
"register_pytree_node", "deprecate", "is_jax_array", "default_progress_fn"
"register_pytree_node", "deprecate", "is_jax_array", "default_prng_key",
"default_progress_fn"
]


Expand Down
17 changes: 16 additions & 1 deletion tests/solvers/linear/sinkhorn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import jax.numpy as jnp
import numpy as np
import pytest
from ott import utils
from ott.geometry import costs, epsilon_scheduler, geometry, grid, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import acceleration, sinkhorn
Expand Down Expand Up @@ -561,8 +562,22 @@ def test_f_potential_is_zero_centered(self, lse_mode: bool):

np.testing.assert_allclose(f_mean, 0., rtol=1e-6, atol=1e-6)

def test_default_progress_fn(self, capsys):
geom = pointcloud.PointCloud(self.x, self.y, epsilon=1e-1)

_ = sinkhorn.solve(
geom,
progress_fn=utils.default_progress_fn,
min_iterations=0,
inner_iterations=7,
max_iterations=13,
)

captured = capsys.readouterr()
assert captured.out.startswith("7 / 13 -- "), captured

@pytest.mark.fast.with_args("num_iterations", [30, 60])
def test_callback_fn(self, num_iterations: int):
def test_custom_callback_fn(self, num_iterations: int):
"""Check that the callback function is actually called."""

def progress_fn(
Expand Down

0 comments on commit f275dc4

Please sign in to comment.