diff --git a/.codecov.yml b/.codecov.yml index d6db3da47..ce1a68fd0 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -1,5 +1,5 @@ codecov: - require_ci_to_pass: no + require_ci_to_pass: yes strict_yaml_branch: main coverage: diff --git a/.github/workflows/notebook_tests.yml b/.github/workflows/notebook_tests.yml index a3f327e5d..ea47a87a5 100644 --- a/.github/workflows/notebook_tests.yml +++ b/.github/workflows/notebook_tests.yml @@ -31,8 +31,8 @@ jobs: - name: Print versions run: | python -VV - python -c "import jax; print('jax', jax.__version__)" - python -c "import jaxlib; print('jaxlib', jaxlib.__version__)" + python -c "import jax; print('jax==', jax.__version__)" + python -c "import jaxlib; print('jaxlib==', jaxlib.__version__)" - name: Intall Jupyter kernel run: | diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index be1e3dbd5..589cf3095 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,6 +1,8 @@ name: tests on: + schedule: + - cron: 00 00 * * 1 push: branches: [main] pull_request: @@ -39,18 +41,18 @@ jobs: - name: Print versions run: | python -VV - python -c "import jax; print('jax', jax.__version__)" - python -c "import jaxlib; print('jaxlib', jaxlib.__version__)" + python -c "import jax; print('jax==', jax.__version__)" + python -c "import jaxlib; print('jaxlib==', jaxlib.__version__)" - name: Run fast tests if: ${{ matrix.test_mark == 'fast' }} run: | - python -m pytest --cov=ott --cov-append --cov-report=xml --cov-report=term-missing --cov-config=setup.cfg --memray -m fast -n auto + python -m pytest --cov=ott --cov-append --cov-report=xml --cov-report=term-missing --cov-config=pyproject.toml --memray -m fast -n auto - name: Run all tests if: ${{ matrix.test_mark == 'all' }} run: | - python -m pytest --cov=ott --cov-append --cov-report=xml --cov-report=term-missing --cov-config=setup.cfg --memray + python -m pytest --cov=ott --cov-append --cov-report=xml --cov-report=term-missing --cov-config=pyproject.toml --memray - name: Upload coverage uses: codecov/codecov-action@v3 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 21102a7c7..4c354bc85 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,7 +20,7 @@ repos: hooks: - id: isort - repo: https://github.com/asottile/yesqa - rev: v1.3.0 + rev: v1.4.0 hooks: - id: yesqa additional_dependencies: @@ -35,7 +35,7 @@ repos: - id: pretty-format-yaml args: [--autofix, --indent, '2'] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.4.0 hooks: - id: detect-private-key - id: check-ast @@ -45,7 +45,7 @@ repos: - id: trailing-whitespace - id: check-case-conflict - repo: https://github.com/myint/autoflake - rev: v1.4 + rev: v2.0.0 hooks: - id: autoflake args: @@ -54,7 +54,7 @@ repos: - --remove-unused-variable - --ignore-init-module-imports - repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 + rev: 6.0.0 hooks: - id: flake8 additional_dependencies: @@ -65,7 +65,7 @@ repos: - flake8-blind-except args: [--docstring-convention, google] - repo: https://github.com/asottile/pyupgrade - rev: v2.37.1 + rev: v3.2.2 hooks: - id: pyupgrade args: [--py3-plus, --py37-plus, --keep-runtime-typing] diff --git a/MANIFEST.in b/MANIFEST.in index f7eda1c4c..652d6fd3b 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,5 @@ prune docs -prune .github +prune examples prune images prune tests +prune .github diff --git a/pyproject.toml b/pyproject.toml index 410371928..dc166f055 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,19 +1,135 @@ [build-system] -requires = [ - "setuptools>=45", - "setuptools-scm[toml]>=6.2", -] +requires = ["setuptools>=61", "setuptools-scm[toml]>=6.2"] build-backend = "setuptools.build_meta" +[project] +name = "ott-jax" +description = "Optimal Transport Tools in JAX." +requires-python = ">=3.7" +dynamic = ["version"] +readme = {file = "README.md", content-type = "text/markdown"} +license = {file = "LICENSE"} +authors = [ + {name = "OTT team", email = "optimal.transport.tools@gmail.com"} +] +dependencies = [ + 'importlib-metadata>=1.0; python_version<"3.8"', + "jax>=0.1.67", + "jaxlib>=0.1.47", + # https://github.com/google/jax/discussions/9951#discussioncomment-3017784 + "numpy>=1.18.4, !=1.23.0", + "matplotlib>=3.0.0", + "flax>=0.5.2", + "optax>=0.1.1", + 'typing_extensions; python_version<"3.8"', + "scipy>=1.7.0", +] +keywords = [ + "optimal transport", + "gromov wasserstein", + "sinkhorn", + "low-rank sinkhorn", + "sinkhorn divergences", + "wasserstein", + "wasserstein barycenter", + "jax", + "autodiff", + "implicit differentiation", +] +classifiers = [ + "Typing :: Typed", + "Development Status :: 5 - Production/Stable", + "License :: OSI Approved :: Apache Software License", + "Topic :: Scientific/Engineering :: Mathematics", + "Natural Language :: English", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Operating System :: POSIX :: Linux", + "Operating System :: MacOS :: MacOS X", + "Operating System :: Microsoft :: Windows", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", +] + +[project.urls] +"Source Code" = "https://github.com/ott-jax/ott" +Documentation = "https://ott-jax.readthedocs.io" +"Issue Tracker" = "https://github.com/ott-jax/ott/issues" +Changelog = "https://github.com/ott-jax/ott/releases" + +[project.optional-dependencies] +dev = [ + "pre-commit", +] +test = [ + "pytest", + "pytest-xdist", + "pytest-cov", + "coverage[toml]", + "testbook", + "chex", + "networkx>=2.5", + "scikit-learn>=1.0" +] +experimental = [ + "scikit-sparse>=0.4.6", +] +docs = [ + "sphinx>=4.0", + "nbsphinx>=0.8.0", + "recommonmark>=0.7.1", + "ipython>=7.20.0", + "sphinx_autodoc_typehints>=1.12.0", + "sphinx-book-theme>=0.3.3", + "sphinxcontrib-bibtex", +] + +[tool.setuptools] +package-dir = {"" = "src"} +packages = {find = {where = ["src"], namespaces = false}} + [tool.setuptools_scm] +[tool.black] +line-length = 80 +target-version = ["py38"] +include = '\.ipynb$' + [tool.isort] profile = "black" include_trailing_comma = true multi_line_output = 3 skip_glob = ["docs/*"] -[tool.black] -line-length = 80 -target-version = ['py38'] -include = '\.ipynb$' +[tool.pytest.ini_options] +minversion = "6.0" +addopts = '-v -m "not notebook"' +testpaths = [ + "tests", +] +markers = [ + "fast: Mark tests as fast.", + "notebook: Mark tests as notebook related.", +] + +[tool.coverage.run] +branch = true +parallel = true +source = ["src/"] +omit = ["*/__init__.py"] + +[tool.coverage.report] +exclude_lines = [ + '\#.*pragma:\s*no.?cover', + "^if __name__ == .__main__.:$", + '^\s*raise AssertionError\b', + '^\s*raise NotImplementedError\b', + '^\s*return NotImplemented\b', +] +precision = 2 +show_missing = true +skip_empty = true +sort = "Miss" diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 2f5826944..000000000 --- a/setup.cfg +++ /dev/null @@ -1,99 +0,0 @@ -[metadata] -name = ott-jax -license = Apache 2.0 -license_files = LICENSE -author_email = optimal.transport.tools@gmail.com -description = OTT: Optimal Transport Tools in Jax. -keywords = - optimal transport - sinkhorn - wasserstein - jax -long_description = file: README.md -long_description_content_type = text/markdown -url = https://github.com/ott-jax/ott -project_urls = - Documentation = https://ott-jax.readthedocs.io - Source Code = https://github.com/ott-jax/ott -classifiers = - Development Status :: 5 - Production/Stable - License :: OSI Approved :: Apache Software License - Topic :: Scientific/Engineering :: Mathematics - Natural Language :: English - Intended Audience :: Developers - Intended Audience :: Science/Research - Operating System :: POSIX :: Linux - Operating System :: MacOS :: MacOS X - Operating System :: Microsoft :: Windows - Programming Language :: Python :: 3 - Programming Language :: Python :: 3.7 - Programming Language :: Python :: 3.8 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 - Typing :: Typed - -[options] -packages = find: -zip_safe = False -python_requires = >=3.7 -install_requires = - importlib-metadata>=1.0; python_version < "3.8" - absl-py>=0.7.0 - jax>=0.1.67 - jaxlib>=0.1.47 - numpy>=1.18.4, !=1.23.0 # https://github.com/google/jax/discussions/9951#discussioncomment-3017784 - matplotlib>=2.0.1 - flax>=0.3.6 - optax>=0.1.1 - typing_extensions; python_version < "3.8" - PyYAML>=6.0 # https://github.com/google/flax/issues/2190 - scipy>=1.7.0 - scikit-learn>=1.0 - -[options.extras_require] -test = - pytest - pytest-xdist - pytest-cov - testbook - chex - networkx>=2.5 -experimental = - scikit-sparse>=0.4.6 -docs = - sphinx>=4.0 - nbsphinx>=0.8.0 - recommonmark>=0.7.1 - ipython>=7.20.0 - sphinx_autodoc_typehints>=1.12.0 - sphinx-book-theme>=0.3.3 - sphinxcontrib-bibtex -dev = - pre-commit - -[coverage:run] -branch = true -parallel = true -source = ott -omit = */__init__.py - -[coverage:report] -exclude_lines = - \#.*pragma:\s*no.?cover - ^if __name__ == .__main__.:$ - ^\s*raise AssertionError\b - ^\s*raise NotImplementedError\b - ^\s*return NotImplemented\b -precision = 2 -show_missing = True -skip_empty = True -sort = Miss - -[tool:pytest] -minversion = 6.0 -addopts = -v -m "not notebook" -testpaths = - tests -markers = - fast: Mark tests as fast. - notebook: Mark tests as notebook related. diff --git a/ott/__init__.py b/src/ott/__init__.py similarity index 100% rename from ott/__init__.py rename to src/ott/__init__.py diff --git a/ott/_version.py b/src/ott/_version.py similarity index 100% rename from ott/_version.py rename to src/ott/_version.py diff --git a/ott/geometry/__init__.py b/src/ott/geometry/__init__.py similarity index 100% rename from ott/geometry/__init__.py rename to src/ott/geometry/__init__.py diff --git a/ott/geometry/costs.py b/src/ott/geometry/costs.py similarity index 99% rename from ott/geometry/costs.py rename to src/ott/geometry/costs.py index 94c06112a..6311b5e22 100644 --- a/ott/geometry/costs.py +++ b/src/ott/geometry/costs.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Several cost/norm functions for relevant vector types.""" import abc import functools diff --git a/ott/geometry/epsilon_scheduler.py b/src/ott/geometry/epsilon_scheduler.py similarity index 99% rename from ott/geometry/epsilon_scheduler.py rename to src/ott/geometry/epsilon_scheduler.py index 284fa29f7..e1dd3ad8f 100644 --- a/ott/geometry/epsilon_scheduler.py +++ b/src/ott/geometry/epsilon_scheduler.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """A class to define a scheduler for the entropic regularization epsilon.""" from typing import Any, Optional diff --git a/ott/geometry/geometry.py b/src/ott/geometry/geometry.py similarity index 99% rename from ott/geometry/geometry.py rename to src/ott/geometry/geometry.py index ac57e1248..93968bc7d 100644 --- a/ott/geometry/geometry.py +++ b/src/ott/geometry/geometry.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """A class describing operations used to instantiate and use a geometry.""" import functools from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union diff --git a/ott/geometry/graph.py b/src/ott/geometry/graph.py similarity index 100% rename from ott/geometry/graph.py rename to src/ott/geometry/graph.py diff --git a/ott/geometry/grid.py b/src/ott/geometry/grid.py similarity index 99% rename from ott/geometry/grid.py rename to src/ott/geometry/grid.py index 729449790..500e51018 100644 --- a/ott/geometry/grid.py +++ b/src/ott/geometry/grid.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Implements a geometry class for points supported on a cartesian product.""" import itertools from typing import Any, List, NoReturn, Optional, Sequence, Tuple diff --git a/ott/geometry/low_rank.py b/src/ott/geometry/low_rank.py similarity index 99% rename from ott/geometry/low_rank.py rename to src/ott/geometry/low_rank.py index 792f2adb0..a4a423d11 100644 --- a/ott/geometry/low_rank.py +++ b/src/ott/geometry/low_rank.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """A class describing low-rank geometries.""" from typing import Any, Callable, Optional, Tuple, Union diff --git a/ott/geometry/pointcloud.py b/src/ott/geometry/pointcloud.py similarity index 99% rename from ott/geometry/pointcloud.py rename to src/ott/geometry/pointcloud.py index 97b9b13dd..c4f1fde34 100644 --- a/ott/geometry/pointcloud.py +++ b/src/ott/geometry/pointcloud.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """A geometry defined using 2 point clouds and a cost function between them.""" import math from typing import Any, Callable, Optional, Tuple, Union diff --git a/ott/geometry/segment.py b/src/ott/geometry/segment.py similarity index 100% rename from ott/geometry/segment.py rename to src/ott/geometry/segment.py diff --git a/ott/initializers/__init__.py b/src/ott/initializers/__init__.py similarity index 100% rename from ott/initializers/__init__.py rename to src/ott/initializers/__init__.py diff --git a/ott/initializers/linear/__init__.py b/src/ott/initializers/linear/__init__.py similarity index 100% rename from ott/initializers/linear/__init__.py rename to src/ott/initializers/linear/__init__.py diff --git a/ott/initializers/linear/initializers.py b/src/ott/initializers/linear/initializers.py similarity index 100% rename from ott/initializers/linear/initializers.py rename to src/ott/initializers/linear/initializers.py diff --git a/ott/initializers/linear/initializers_lr.py b/src/ott/initializers/linear/initializers_lr.py similarity index 100% rename from ott/initializers/linear/initializers_lr.py rename to src/ott/initializers/linear/initializers_lr.py diff --git a/ott/initializers/nn/__init__.py b/src/ott/initializers/nn/__init__.py similarity index 100% rename from ott/initializers/nn/__init__.py rename to src/ott/initializers/nn/__init__.py diff --git a/ott/initializers/nn/initializers.py b/src/ott/initializers/nn/initializers.py similarity index 100% rename from ott/initializers/nn/initializers.py rename to src/ott/initializers/nn/initializers.py diff --git a/ott/initializers/quadratic/__init__.py b/src/ott/initializers/quadratic/__init__.py similarity index 100% rename from ott/initializers/quadratic/__init__.py rename to src/ott/initializers/quadratic/__init__.py diff --git a/ott/initializers/quadratic/initializers.py b/src/ott/initializers/quadratic/initializers.py similarity index 100% rename from ott/initializers/quadratic/initializers.py rename to src/ott/initializers/quadratic/initializers.py diff --git a/ott/math/__init__.py b/src/ott/math/__init__.py similarity index 100% rename from ott/math/__init__.py rename to src/ott/math/__init__.py diff --git a/ott/math/decomposition.py b/src/ott/math/decomposition.py similarity index 100% rename from ott/math/decomposition.py rename to src/ott/math/decomposition.py diff --git a/ott/math/fixed_point_loop.py b/src/ott/math/fixed_point_loop.py similarity index 99% rename from ott/math/fixed_point_loop.py rename to src/ott/math/fixed_point_loop.py index 3b22ad20a..653c0bc62 100644 --- a/ott/math/fixed_point_loop.py +++ b/src/ott/math/fixed_point_loop.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """jheek@ backprop-friendly implementation of fixed point loop.""" from typing import Any, Callable diff --git a/ott/math/matrix_square_root.py b/src/ott/math/matrix_square_root.py similarity index 99% rename from ott/math/matrix_square_root.py rename to src/ott/math/matrix_square_root.py index 761e67af2..e2981726e 100644 --- a/ott/math/matrix_square_root.py +++ b/src/ott/math/matrix_square_root.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """A Jax backprop friendly version of Matrix square root.""" import functools diff --git a/ott/math/unbalanced_functions.py b/src/ott/math/unbalanced_functions.py similarity index 99% rename from ott/math/unbalanced_functions.py rename to src/ott/math/unbalanced_functions.py index 296e25005..4d40e50f0 100644 --- a/ott/math/unbalanced_functions.py +++ b/src/ott/math/unbalanced_functions.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Functions useful to define unbalanced OT problems.""" from typing import Callable diff --git a/ott/math/utils.py b/src/ott/math/utils.py similarity index 100% rename from ott/math/utils.py rename to src/ott/math/utils.py diff --git a/ott/problems/__init__.py b/src/ott/problems/__init__.py similarity index 100% rename from ott/problems/__init__.py rename to src/ott/problems/__init__.py diff --git a/ott/problems/linear/__init__.py b/src/ott/problems/linear/__init__.py similarity index 100% rename from ott/problems/linear/__init__.py rename to src/ott/problems/linear/__init__.py diff --git a/ott/problems/linear/barycenter_problem.py b/src/ott/problems/linear/barycenter_problem.py similarity index 100% rename from ott/problems/linear/barycenter_problem.py rename to src/ott/problems/linear/barycenter_problem.py diff --git a/ott/problems/linear/linear_problem.py b/src/ott/problems/linear/linear_problem.py similarity index 100% rename from ott/problems/linear/linear_problem.py rename to src/ott/problems/linear/linear_problem.py diff --git a/ott/problems/linear/potentials.py b/src/ott/problems/linear/potentials.py similarity index 100% rename from ott/problems/linear/potentials.py rename to src/ott/problems/linear/potentials.py diff --git a/ott/problems/quadratic/__init__.py b/src/ott/problems/quadratic/__init__.py similarity index 100% rename from ott/problems/quadratic/__init__.py rename to src/ott/problems/quadratic/__init__.py diff --git a/ott/problems/quadratic/gw_barycenter.py b/src/ott/problems/quadratic/gw_barycenter.py similarity index 100% rename from ott/problems/quadratic/gw_barycenter.py rename to src/ott/problems/quadratic/gw_barycenter.py diff --git a/ott/problems/quadratic/quadratic_costs.py b/src/ott/problems/quadratic/quadratic_costs.py similarity index 100% rename from ott/problems/quadratic/quadratic_costs.py rename to src/ott/problems/quadratic/quadratic_costs.py diff --git a/ott/problems/quadratic/quadratic_problem.py b/src/ott/problems/quadratic/quadratic_problem.py similarity index 100% rename from ott/problems/quadratic/quadratic_problem.py rename to src/ott/problems/quadratic/quadratic_problem.py diff --git a/ott/solvers/__init__.py b/src/ott/solvers/__init__.py similarity index 100% rename from ott/solvers/__init__.py rename to src/ott/solvers/__init__.py diff --git a/ott/solvers/linear/__init__.py b/src/ott/solvers/linear/__init__.py similarity index 100% rename from ott/solvers/linear/__init__.py rename to src/ott/solvers/linear/__init__.py diff --git a/ott/solvers/linear/acceleration.py b/src/ott/solvers/linear/acceleration.py similarity index 100% rename from ott/solvers/linear/acceleration.py rename to src/ott/solvers/linear/acceleration.py diff --git a/ott/solvers/linear/continuous_barycenter.py b/src/ott/solvers/linear/continuous_barycenter.py similarity index 99% rename from ott/solvers/linear/continuous_barycenter.py rename to src/ott/solvers/linear/continuous_barycenter.py index 715a72384..c4d557bf9 100644 --- a/ott/solvers/linear/continuous_barycenter.py +++ b/src/ott/solvers/linear/continuous_barycenter.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """A Jax version of the W barycenter algorithm (Cuturi Doucet 2014).""" import functools from typing import Any, NamedTuple, Optional, Tuple diff --git a/ott/solvers/linear/discrete_barycenter.py b/src/ott/solvers/linear/discrete_barycenter.py similarity index 99% rename from ott/solvers/linear/discrete_barycenter.py rename to src/ott/solvers/linear/discrete_barycenter.py index 044295129..19dcbb506 100644 --- a/ott/solvers/linear/discrete_barycenter.py +++ b/src/ott/solvers/linear/discrete_barycenter.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Implementation of :cite:`janati:20` Wasserstein barycenter algorithm.""" import functools diff --git a/ott/solvers/linear/implicit_differentiation.py b/src/ott/solvers/linear/implicit_differentiation.py similarity index 100% rename from ott/solvers/linear/implicit_differentiation.py rename to src/ott/solvers/linear/implicit_differentiation.py diff --git a/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py similarity index 99% rename from ott/solvers/linear/sinkhorn.py rename to src/ott/solvers/linear/sinkhorn.py index aaa2259c3..cbec3044a 100644 --- a/ott/solvers/linear/sinkhorn.py +++ b/src/ott/solvers/linear/sinkhorn.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """A Jax implementation of the Sinkhorn algorithm.""" from typing import Any, Callable, Mapping, NamedTuple, Optional, Sequence, Tuple, Union diff --git a/ott/solvers/linear/sinkhorn_lr.py b/src/ott/solvers/linear/sinkhorn_lr.py similarity index 99% rename from ott/solvers/linear/sinkhorn_lr.py rename to src/ott/solvers/linear/sinkhorn_lr.py index 4b3530746..d402e1e80 100644 --- a/ott/solvers/linear/sinkhorn_lr.py +++ b/src/ott/solvers/linear/sinkhorn_lr.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """A Jax implementation of the Low-Rank Sinkhorn algorithm.""" from typing import Any, Mapping, NamedTuple, NoReturn, Optional, Tuple, Union diff --git a/ott/solvers/nn/__init__.py b/src/ott/solvers/nn/__init__.py similarity index 100% rename from ott/solvers/nn/__init__.py rename to src/ott/solvers/nn/__init__.py diff --git a/ott/solvers/nn/icnn.py b/src/ott/solvers/nn/icnn.py similarity index 99% rename from ott/solvers/nn/icnn.py rename to src/ott/solvers/nn/icnn.py index 4e93bc91c..dc121ca90 100644 --- a/ott/solvers/nn/icnn.py +++ b/src/ott/solvers/nn/icnn.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Implementation of :cite:`amos:17` input convex neural networks (ICNN).""" from typing import Any, Callable, Sequence, Tuple, Union diff --git a/ott/solvers/nn/layers.py b/src/ott/solvers/nn/layers.py similarity index 99% rename from ott/solvers/nn/layers.py rename to src/ott/solvers/nn/layers.py index 770d966dd..2aa72862a 100644 --- a/ott/solvers/nn/layers.py +++ b/src/ott/solvers/nn/layers.py @@ -9,8 +9,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Layers used in input convex neural networks :cite:`amos:17,bunne:22`.""" from typing import Any, Callable, Tuple diff --git a/ott/solvers/nn/neuraldual.py b/src/ott/solvers/nn/neuraldual.py similarity index 99% rename from ott/solvers/nn/neuraldual.py rename to src/ott/solvers/nn/neuraldual.py index 306afa5c2..aa8eb18a9 100644 --- a/ott/solvers/nn/neuraldual.py +++ b/src/ott/solvers/nn/neuraldual.py @@ -9,8 +9,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """A Jax implementation of the ICNN based Kantorovich dual.""" import warnings diff --git a/ott/solvers/quadratic/__init__.py b/src/ott/solvers/quadratic/__init__.py similarity index 100% rename from ott/solvers/quadratic/__init__.py rename to src/ott/solvers/quadratic/__init__.py diff --git a/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py similarity index 99% rename from ott/solvers/quadratic/gromov_wasserstein.py rename to src/ott/solvers/quadratic/gromov_wasserstein.py index 7318ec8f3..59b1994c6 100644 --- a/ott/solvers/quadratic/gromov_wasserstein.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """A Jax version of the regularised GW Solver (Peyre et al. 2016).""" from typing import Any, Dict, Mapping, NamedTuple, Optional, Sequence, Tuple, Union diff --git a/ott/solvers/quadratic/gw_barycenter.py b/src/ott/solvers/quadratic/gw_barycenter.py similarity index 100% rename from ott/solvers/quadratic/gw_barycenter.py rename to src/ott/solvers/quadratic/gw_barycenter.py diff --git a/ott/solvers/was_solver.py b/src/ott/solvers/was_solver.py similarity index 99% rename from ott/solvers/was_solver.py rename to src/ott/solvers/was_solver.py index c823ef1d4..5d9fd4913 100644 --- a/ott/solvers/was_solver.py +++ b/src/ott/solvers/was_solver.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """A Jax version of the regularised GW Solver (Peyre et al. 2016).""" from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union diff --git a/ott/tools/__init__.py b/src/ott/tools/__init__.py similarity index 100% rename from ott/tools/__init__.py rename to src/ott/tools/__init__.py diff --git a/ott/tools/gaussian_mixture/__init__.py b/src/ott/tools/gaussian_mixture/__init__.py similarity index 100% rename from ott/tools/gaussian_mixture/__init__.py rename to src/ott/tools/gaussian_mixture/__init__.py diff --git a/ott/tools/gaussian_mixture/fit_gmm.py b/src/ott/tools/gaussian_mixture/fit_gmm.py similarity index 100% rename from ott/tools/gaussian_mixture/fit_gmm.py rename to src/ott/tools/gaussian_mixture/fit_gmm.py diff --git a/ott/tools/gaussian_mixture/fit_gmm_pair.py b/src/ott/tools/gaussian_mixture/fit_gmm_pair.py similarity index 100% rename from ott/tools/gaussian_mixture/fit_gmm_pair.py rename to src/ott/tools/gaussian_mixture/fit_gmm_pair.py diff --git a/ott/tools/gaussian_mixture/gaussian.py b/src/ott/tools/gaussian_mixture/gaussian.py similarity index 100% rename from ott/tools/gaussian_mixture/gaussian.py rename to src/ott/tools/gaussian_mixture/gaussian.py diff --git a/ott/tools/gaussian_mixture/gaussian_mixture.py b/src/ott/tools/gaussian_mixture/gaussian_mixture.py similarity index 99% rename from ott/tools/gaussian_mixture/gaussian_mixture.py rename to src/ott/tools/gaussian_mixture/gaussian_mixture.py index 25d4990ed..234b133f2 100644 --- a/ott/tools/gaussian_mixture/gaussian_mixture.py +++ b/src/ott/tools/gaussian_mixture/gaussian_mixture.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python 3 """Pytree for a Gaussian mixture model.""" from typing import List, Optional, Tuple, Union diff --git a/ott/tools/gaussian_mixture/gaussian_mixture_pair.py b/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py similarity index 100% rename from ott/tools/gaussian_mixture/gaussian_mixture_pair.py rename to src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py diff --git a/ott/tools/gaussian_mixture/linalg.py b/src/ott/tools/gaussian_mixture/linalg.py similarity index 100% rename from ott/tools/gaussian_mixture/linalg.py rename to src/ott/tools/gaussian_mixture/linalg.py diff --git a/ott/tools/gaussian_mixture/probabilities.py b/src/ott/tools/gaussian_mixture/probabilities.py similarity index 100% rename from ott/tools/gaussian_mixture/probabilities.py rename to src/ott/tools/gaussian_mixture/probabilities.py diff --git a/ott/tools/gaussian_mixture/scale_tril.py b/src/ott/tools/gaussian_mixture/scale_tril.py similarity index 100% rename from ott/tools/gaussian_mixture/scale_tril.py rename to src/ott/tools/gaussian_mixture/scale_tril.py diff --git a/ott/tools/k_means.py b/src/ott/tools/k_means.py similarity index 100% rename from ott/tools/k_means.py rename to src/ott/tools/k_means.py diff --git a/ott/tools/plot.py b/src/ott/tools/plot.py similarity index 100% rename from ott/tools/plot.py rename to src/ott/tools/plot.py diff --git a/ott/tools/segment_sinkhorn.py b/src/ott/tools/segment_sinkhorn.py similarity index 100% rename from ott/tools/segment_sinkhorn.py rename to src/ott/tools/segment_sinkhorn.py diff --git a/ott/tools/sinkhorn_divergence.py b/src/ott/tools/sinkhorn_divergence.py similarity index 100% rename from ott/tools/sinkhorn_divergence.py rename to src/ott/tools/sinkhorn_divergence.py diff --git a/ott/tools/soft_sort.py b/src/ott/tools/soft_sort.py similarity index 100% rename from ott/tools/soft_sort.py rename to src/ott/tools/soft_sort.py diff --git a/ott/tools/transport.py b/src/ott/tools/transport.py similarity index 100% rename from ott/tools/transport.py rename to src/ott/tools/transport.py diff --git a/ott/types.py b/src/ott/types.py similarity index 100% rename from ott/types.py rename to src/ott/types.py diff --git a/ott/utils.py b/src/ott/utils.py similarity index 100% rename from ott/utils.py rename to src/ott/utils.py diff --git a/tests/geometry/costs_test.py b/tests/geometry/costs_test.py index 924561eb3..76f58b7ab 100644 --- a/tests/geometry/costs_test.py +++ b/tests/geometry/costs_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Tests for the cost/norm functions.""" import jax diff --git a/tests/geometry/low_rank_test.py b/tests/geometry/low_rank_test.py index 1a5bbc1a4..49eb7647c 100644 --- a/tests/geometry/low_rank_test.py +++ b/tests/geometry/low_rank_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Test Low-Rank Geometry.""" from typing import Callable, Optional, Union diff --git a/tests/geometry/pointcloud_test.py b/tests/geometry/pointcloud_test.py index f3d2a501b..24ee95211 100644 --- a/tests/geometry/pointcloud_test.py +++ b/tests/geometry/pointcloud_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Tests for apply_cost and apply_kernel.""" from typing import Union diff --git a/tests/initializers/linear/sinkhorn_init_test.py b/tests/initializers/linear/sinkhorn_init_test.py index 27e3f347d..9b073a7ff 100644 --- a/tests/initializers/linear/sinkhorn_init_test.py +++ b/tests/initializers/linear/sinkhorn_init_test.py @@ -9,8 +9,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Tests for Sinkhorn initializers.""" import functools diff --git a/tests/initializers/linear/sinkhorn_lr_init_test.py b/tests/initializers/linear/sinkhorn_lr_init_test.py index bff23ec2e..0a8001f21 100644 --- a/tests/initializers/linear/sinkhorn_lr_init_test.py +++ b/tests/initializers/linear/sinkhorn_lr_init_test.py @@ -9,8 +9,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Tests for Sinkhorn initializers.""" import jax diff --git a/tests/initializers/quadratic/gw_init_test.py b/tests/initializers/quadratic/gw_init_test.py index b900b60fc..255c27bc0 100644 --- a/tests/initializers/quadratic/gw_init_test.py +++ b/tests/initializers/quadratic/gw_init_test.py @@ -9,8 +9,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Tests for Gromov-Wasserstein initializers.""" import jax diff --git a/tests/math/lse_test.py b/tests/math/lse_test.py index f0b076147..f3368613b 100644 --- a/tests/math/lse_test.py +++ b/tests/math/lse_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Tests for the jvp of a custom implementation of lse.""" import jax diff --git a/tests/math/matrix_square_root_test.py b/tests/math/matrix_square_root_test.py index b5d3c08d8..6eb7319ba 100644 --- a/tests/math/matrix_square_root_test.py +++ b/tests/math/matrix_square_root_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Tests for matrix square roots.""" from typing import Callable diff --git a/tests/solvers/linear/continuous_barycenter_test.py b/tests/solvers/linear/continuous_barycenter_test.py index 36fdf1d97..20f2ff4cd 100644 --- a/tests/solvers/linear/continuous_barycenter_test.py +++ b/tests/solvers/linear/continuous_barycenter_test.py @@ -10,8 +10,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Tests for continuous barycenter.""" import functools from typing import Tuple diff --git a/tests/solvers/linear/discrete_barycenter_test.py b/tests/solvers/linear/discrete_barycenter_test.py index 8f2bcfa0b..cd0c99f10 100644 --- a/tests/solvers/linear/discrete_barycenter_test.py +++ b/tests/solvers/linear/discrete_barycenter_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Lint as: python3 import jax.numpy as jnp import pytest diff --git a/tests/solvers/linear/sinkhorn_diff_test.py b/tests/solvers/linear/sinkhorn_diff_test.py index 80d9e62c8..c4b1dd2d1 100644 --- a/tests/solvers/linear/sinkhorn_diff_test.py +++ b/tests/solvers/linear/sinkhorn_diff_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Tests for the differentiability of reg_ot_cost w.r.t weights/locations.""" import functools from typing import Tuple diff --git a/tests/solvers/linear/sinkhorn_grid_test.py b/tests/solvers/linear/sinkhorn_grid_test.py index 7937ce717..0f3a8446c 100644 --- a/tests/solvers/linear/sinkhorn_grid_test.py +++ b/tests/solvers/linear/sinkhorn_grid_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Tests for Sinkhorn when applied on a grid.""" import jax diff --git a/tests/solvers/linear/sinkhorn_lr_test.py b/tests/solvers/linear/sinkhorn_lr_test.py index 55d95a533..8b70fd469 100644 --- a/tests/solvers/linear/sinkhorn_lr_test.py +++ b/tests/solvers/linear/sinkhorn_lr_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Tests Sinkhorn Low-Rank solver with various initializations.""" import jax import jax.numpy as jnp diff --git a/tests/solvers/linear/sinkhorn_misc_test.py b/tests/solvers/linear/sinkhorn_misc_test.py index 96da5c6bc..d93e985ce 100644 --- a/tests/solvers/linear/sinkhorn_misc_test.py +++ b/tests/solvers/linear/sinkhorn_misc_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Tests Anderson acceleration for Sinkhorn.""" import functools from typing import Callable, Tuple diff --git a/tests/solvers/linear/sinkhorn_test.py b/tests/solvers/linear/sinkhorn_test.py index d3372f1be..4500a5d1a 100644 --- a/tests/solvers/linear/sinkhorn_test.py +++ b/tests/solvers/linear/sinkhorn_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Tests for Sinkhorn.""" import jax diff --git a/tests/solvers/nn/icnn_test.py b/tests/solvers/nn/icnn_test.py index b45d2c521..f7f64127a 100644 --- a/tests/solvers/nn/icnn_test.py +++ b/tests/solvers/nn/icnn_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Tests for ICNN network architecture.""" import jax diff --git a/tests/solvers/nn/neuraldual_test.py b/tests/solvers/nn/neuraldual_test.py index 4fa9c009b..526d8903f 100644 --- a/tests/solvers/nn/neuraldual_test.py +++ b/tests/solvers/nn/neuraldual_test.py @@ -10,8 +10,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Tests for implementation of ICNN-based Kantorovich dual by Makkuva+(2020).""" from typing import Iterator, Sequence, Tuple diff --git a/tests/solvers/quadratic/fgw_test.py b/tests/solvers/quadratic/fgw_test.py index 338ffafb1..5256613c9 100644 --- a/tests/solvers/quadratic/fgw_test.py +++ b/tests/solvers/quadratic/fgw_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Tests for the Fused Gromov Wasserstein.""" from typing import Tuple, Union diff --git a/tests/solvers/quadratic/gw_barycenter_test.py b/tests/solvers/quadratic/gw_barycenter_test.py index 94cd5759b..f824955c8 100644 --- a/tests/solvers/quadratic/gw_barycenter_test.py +++ b/tests/solvers/quadratic/gw_barycenter_test.py @@ -10,8 +10,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Tests for Gromov-Wasserstein barycenter.""" from typing import Any, Optional, Sequence, Tuple diff --git a/tests/solvers/quadratic/gw_test.py b/tests/solvers/quadratic/gw_test.py index b3ffd3051..5ec47d2e5 100644 --- a/tests/solvers/quadratic/gw_test.py +++ b/tests/solvers/quadratic/gw_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Tests for the Gromov Wasserstein.""" from typing import Tuple, Union diff --git a/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py b/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py index db1c0338d..8790a6d48 100644 --- a/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py +++ b/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python 3 """Tests for gaussian_mixture_pair.""" import jax diff --git a/tests/tools/gaussian_mixture/gaussian_mixture_test.py b/tests/tools/gaussian_mixture/gaussian_mixture_test.py index f7ff09d9d..0c9857921 100644 --- a/tests/tools/gaussian_mixture/gaussian_mixture_test.py +++ b/tests/tools/gaussian_mixture/gaussian_mixture_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python 3 """Tests for gaussian_mixture.""" import jax diff --git a/tests/tools/segment_sinkhorn_test.py b/tests/tools/segment_sinkhorn_test.py index 5a1e81c1e..44f2cbe5e 100644 --- a/tests/tools/segment_sinkhorn_test.py +++ b/tests/tools/segment_sinkhorn_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Tests for Segmented Sinkhorn.""" import jax diff --git a/tests/tools/sinkhorn_divergence_test.py b/tests/tools/sinkhorn_divergence_test.py index 87b82b3d8..458147bf2 100644 --- a/tests/tools/sinkhorn_divergence_test.py +++ b/tests/tools/sinkhorn_divergence_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Tests for the Sinkhorn divergence.""" from typing import Any, Dict, Optional @@ -43,10 +41,10 @@ def setUp(self, rng: jnp.ndarray): cost_fn=[costs.Euclidean(), costs.SqEuclidean(), costs.SqPNorm(p=2.1)], - epsilon=[.01, .001], + epsilon=[1e-2, 1e-3], only_fast={ - "costs_fn": costs.SqEuclidean(), - "epsilon": .01 + "cost_fn": costs.SqEuclidean(), + "epsilon": 1e-2 }, ) def test_euclidean_point_cloud(self, cost_fn, epsilon): diff --git a/tests/tools/soft_sort_test.py b/tests/tools/soft_sort_test.py index fcf91e617..a6fe6bb37 100644 --- a/tests/tools/soft_sort_test.py +++ b/tests/tools/soft_sort_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Lint as: python3 """Tests for the soft sort tools.""" import functools from typing import Tuple