Jax Ecosystem - jejjohnson/ml4eo GitHub Wiki
Tutorials
- Jax Tutorial 101
- PyTrees | Jax PyTrees Tutorial
- Differentiable Programming
- Autodidax: Jax From Scratch
Software
Math
Linear Algebra
- CoLA, Lab
- autoray
- einx
- einops
- matfree - matrix-free linear algebra in JAX
- opt-einsum - optimized einsum (NumPy, JAX, TF, PyTorch, Dask, CuPy, Sparse)
Symbolic Math
- sympy2jax
Convolutions
Integration
Interpolation
- Nyx, RBF, KernelBiome
Special Data Structures
- quax, jaxtyping
Neural Networks
- Equinox, Flax, Keras
- xarray_jax
Optimization
- Optimistix, Lineax, Optax, JaxOpt, OTT
- varz - Simple, multi-backend constrained (L-BFGS) and unconstrained optimization (Adam).
Kernels
- mlkernels - Kernel Matrices (JAX, TF, PyTorch, Julia).
Probabilistic
- BlackJAX
- numpyro
- numpyro-ext
- tfp.substrates.jax
- fenbux
- bayeux
- jaxns
- efax
- SGMCMCJax - stochastic gradient samplers in JAX
Normalizing Flows
Gaussian Processes
State Space Models
Numerical Methods
Differentiation
- FiniteDiffX, FiniteVolX, SpectralDiffX
- jax-fem
- Probfindiff
- LapJax
- RBF-FDax
ODE Solvers
- Diffrax
- probdiffeq - probabilistic solvers for differential equations
ODE Implementations
PDE Implementations
Basis Functions
- orthojax
- jax-wavelet-toolbox
- cr-wavelets
- s2fft
- s2ball
- s2wav
- orthax
- SphericalHarmonics - spherical harmonics (numpy, JAX, PyTorch, TF)
- Jax Implementation
Parallel Programming