linox is a Python package that provides a collection of linear operators for JAX, enabling efficient and flexible linear algebra operations with lazy evaluation. The package is designed as a JAX alternative to probnum.linops, but it is still under development, so it currently offers fewer features and is less stable. It has no dependencies other than JAX and plum for multiple dispatch.
- Lazy Evaluation: All operators support lazy evaluation, allowing for efficient computation of complex linear transformations
- JAX Integration: Built on top of JAX, providing automatic differentiation, parallelization, JIT compilation, and GPU/TPU support
- Composable Operators: Operators can be combined to form complex linear transformations
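To make the idea of lazy, composable operators concrete, here is a minimal sketch in plain JAX. It is not how linox is implemented: the helper names (`dense_op`, `diagonal_op`, `add_ops`, `compose_ops`) are hypothetical and exist only for this illustration. Each operator is represented by its matrix-vector product, so sums and products of operators can be described without ever building the combined dense matrix.

```python
import jax.numpy as jnp

# Hypothetical helpers for illustration only; these are NOT linox classes.
def dense_op(A):
    """Matvec for a dense matrix A."""
    return lambda x: A @ x

def diagonal_op(d):
    """Matvec for diag(d) without forming the dense matrix."""
    return lambda x: d * x

def add_ops(f, g):
    """Lazy sum: (F + G) x = F x + G x."""
    return lambda x: f(x) + g(x)

def compose_ops(f, g):
    """Lazy product: (F G) x = F (G x)."""
    return lambda x: f(g(x))

A = jnp.array([[1.0, 2.0], [3.0, 4.0]])
d = jnp.array([1.0, 2.0])

# Describe (A + diag(d)) A lazily; no matrix product is computed here.
op = compose_ops(add_ops(dense_op(A), diagonal_op(d)), dense_op(A))

x = jnp.ones(2)
y = op(x)  # the work happens here, using only matrix-vector products

# Reference: build the dense matrix explicitly and compare.
y_dense = ((A + jnp.diag(d)) @ A) @ x
```

The lazy version and the dense reference agree, but the lazy one never forms the matrix-matrix product, which is where the savings come from for large or structured operators.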
- Matrix: General matrix operator
- Identity: Identity matrix operator
- Diagonal: Diagonal matrix operator
- Scalar: Scalar multiple of identity
- Zero: Zero matrix operator
- Ones: Matrix of ones operator

- BlockMatrix: General block matrix operator
- BlockMatrix2x2: 2x2 block matrix operator
- BlockDiagonal: Block diagonal matrix operator

- LowRank: General low rank operator
- SymmetricLowRank: Symmetric low rank operator
- IsotropicScalingPlusSymmetricLowRank: Isotropic scaling plus symmetric low rank
- PositiveDiagonalPlusSymmetricLowRank: Positive diagonal plus symmetric low rank

- Kronecker: Kronecker product operator
- Permutation: Permutation matrix operator
- EigenD: Eigenvalue decomposition operator
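As an illustration of why structured operators such as a Kronecker product are worth keeping in lazy form, the sketch below applies a Kronecker product to a vector without materializing it. It uses only plain JAX and a hypothetical helper `kron_matvec`; it does not show linox's own Kronecker class, whose constructor and internals are not described here.

```python
import jax.numpy as jnp

def kron_matvec(A, B, x):
    """Compute (A kron B) @ x without forming the Kronecker product.

    Uses the identity (A kron B) vec(X) = vec(A X B^T) with row-major
    flattening, where X = x.reshape(A.shape[1], B.shape[1]).
    """
    X = x.reshape(A.shape[1], B.shape[1])
    return (A @ X @ B.T).reshape(-1)

A = jnp.array([[1.0, 2.0], [3.0, 4.0]])
B = jnp.array([[0.0, 1.0, 2.0], [1.0, 0.0, 1.0], [2.0, 1.0, 0.0]])
x = jnp.arange(6.0)

y_lazy = kron_matvec(A, B, x)     # two small matrix products
y_dense = jnp.kron(A, B) @ x      # reference: builds the full 6x6 matrix
```

For factors of size m x m and n x n, the dense product stores (mn)^2 entries, while the structured version only ever touches the two factors and an m x n intermediate.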
- Automatic Differentiation: Compute gradients automatically through operator compositions
- JIT Compilation: Speed up computations with just-in-time compilation
- Vectorization: Efficient batch processing of linear operations via, e.g., jax.vmap
- GPU/TPU Support: Run computations on accelerators without code changes
- Functional Programming: Pure functions enable better optimization and parallelization
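The JAX features listed above can be sketched on a hand-written matrix-vector product. The example assumes a "diagonal plus symmetric low rank" structure, diag(d) + U U^T, written as plain functions; `matvec` and `quadratic_form` are hypothetical names for this illustration and are not part of the linox API.

```python
import jax
import jax.numpy as jnp

def matvec(d, U, x):
    """Apply (diag(d) + U U^T) to x without forming the dense matrix."""
    return d * x + U @ (U.T @ x)

def quadratic_form(d, U, x):
    """Scalar x^T (diag(d) + U U^T) x."""
    return x @ matvec(d, U, x)

d = jnp.array([1.0, 2.0, 3.0])
U = jnp.array([[1.0], [0.0], [1.0]])
x = jnp.array([1.0, 1.0, 1.0])

# JIT compilation: trace once, then run the compiled function.
fast_matvec = jax.jit(matvec)
y = fast_matvec(d, U, x)

# Automatic differentiation: gradient of the quadratic form w.r.t. d,
# which is x**2 elementwise.
grad_d = jax.grad(quadratic_form, argnums=0)(d, U, x)

# Vectorization: map over a batch of vectors, sharing d and U.
xs = jnp.ones((10, 3))
ys = jax.vmap(matvec, in_axes=(None, None, 0))(d, U, xs)
```

Because `matvec` is a pure function of its inputs, the same definition can be passed to `jax.jit`, `jax.grad`, and `jax.vmap` without modification.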
import jax
import jax.numpy as jnp
from linox import Matrix, Diagonal, BlockMatrix
# Create operators
A = Matrix(jnp.array([[1, 2], [3, 4]], dtype=jnp.float32))
D = Diagonal(jnp.array([1, 2], dtype=jnp.float32))
# Compose operators
B = BlockMatrix([[A, D], [D, A]])
# Apply to vector
x = jnp.ones((4,), dtype=jnp.float32)
y = B @ x  # Lazy evaluation
# Parallelize over batch of vectors
x_batched = jnp.ones((10, 4), dtype=jnp.float32)
y_batched = jax.vmap(B)(x_batched)