Skip to content

2bys/linox

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

44 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

linox: Linear Operators in JAX

linox is a Python package that provides a collection of linear operators for JAX, enabling efficient and flexible linear algebra operations with lazy evaluation. This package is designed as an JAX alternative to probnum.linops, but it is currently still under development having less and more instable features. It has no dependencies other than JAX and plum for multiple dispatch.

Features

  • Lazy Evaluation: All operators support lazy evaluation, allowing for efficient computation of complex linear transformations
  • JAX Integration: Built on top of JAX, providing automatic differentiation, parallelization, JIT compilation, and GPU/TPU support
  • Composable Operators: Operators can be combined to form complex linear transformations

Linear Operators

Basic Operators

  • Matrix: General matrix operator
  • Identity: Identity matrix operator
  • Diagonal: Diagonal matrix operator
  • Scalar: Scalar multiple of identity
  • Zero: Zero matrix operator
  • Ones: Matrix of ones operator

Block Operators

  • BlockMatrix: General block matrix operator
  • BlockMatrix2x2: 2x2 block matrix operator
  • BlockDiagonal: Block diagonal matrix operator

Low Rank Operators

  • LowRank: General low rank operator
  • SymmetricLowRank: Symmetric low rank operator
  • IsotropicScalingPlusSymmetricLowRank: Isotropic scaling plus symmetric low rank
  • PositiveDiagonalPlusSymmetricLowRank: Positive diagonal plus symmetric low rank

Special Operators

  • Kronecker: Kronecker product operator
  • Permutation: Permutation matrix operator
  • EigenD: Eigenvalue decomposition operator

Benefits of JAX Integration

  • Automatic Differentiation: Compute gradients automatically through operator compositions
  • JIT Compilation: Speed up computations with just-in-time compilation
  • Vectorization: Efficient batch processing of linear operations via e.g. jax.vmap
  • GPU/TPU Support: Run computations on accelerators without code changes
  • Functional Programming: Pure functions enable better optimization and parallelization

Quick Example

import jax
import jax.numpy as jnp
from linox import Matrix, Diagonal, BlockMatrix

# Create operators
A = Matrix(jnp.array([[1, 2], [3, 4]], dtype=jnp.float32))
D = Diagonal(jnp.array([1, 2], dtype=jnp.float32))

# Compose operators
B = BlockMatrix([[A, D], [D, A]])

# Apply to vector
x = jnp.ones((4,), dtype=jnp.float32)
y = B @ x  # Lazy evaluation

# Parallelize over batch of vectors
x_batched = jnp.ones((10, 4), dtype=jnp.float32)
y_batched = jax.vmap(B)(x_batched)

About

Light-weight linear operators in JAX with plum-dispatch and without additional dependencies

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Contributors 2

  •  
  •  

Languages