JAX - AshokBhat/ml GitHub Wiki
- Python library for machine learning research
- API similar to NumPy and SciPy
- Uses XLA to compile and run your programs on GPUs and TPUs
JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.
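The NumPy-like API can be seen by writing the same expression once with `numpy` and once with `jax.numpy`. This is a minimal sketch, assuming NumPy and JAX are installed. The sample values and the choice of `jax.scipy.special.logsumexp` as the SciPy-style example are illustrative, and the reported backend depends on which JAX build is installed.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# The same expression, written once with NumPy and once with jax.numpy.
x_np = np.linspace(0.0, 1.0, 5)
x_jax = jnp.linspace(0.0, 1.0, 5)

y_np = np.sin(x_np) ** 2 + np.cos(x_np) ** 2       # float64 NumPy array
y_jax = jnp.sin(x_jax) ** 2 + jnp.cos(x_jax) ** 2  # float32 JAX array by default

# jax.scipy mirrors parts of SciPy, e.g. scipy.special.logsumexp.
lse = logsumexp(x_jax)

# Arrays live on the default backend: "cpu", or "gpu"/"tpu" if such a build is installed.
print("Backend:", jax.default_backend())
print("Devices:", jax.devices())
print("Max |NumPy - JAX|:", float(np.max(np.abs(y_np - np.asarray(y_jax)))))
print("logsumexp:", float(lse))
```

The small difference between the two results comes from JAX defaulting to 32-bit floats, while NumPy uses 64-bit floats.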
- jit() - for speeding up your code
- grad() - for taking derivatives
- vmap() - for automatic vectorization or batching
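The three transformations can be sketched on a single toy function. This is a minimal illustration, assuming a working JAX install. The function `f` and the input values are hypothetical examples, not from the original page. The last lines show that the transformations compose, here producing compiled per-row gradients of a batch.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy function: the sum of squares of a 1-D array.
def f(x):
    return jnp.sum(x ** 2)

# jit(): trace f and compile it with XLA; later calls with the same shape/dtype reuse the compiled code.
f_fast = jax.jit(f)

# grad(): returns a new function computing df/dx, which for sum(x**2) is 2*x.
df = jax.grad(f)

# vmap(): applies f to each row of a batch (leading axis) without a Python loop.
f_batched = jax.vmap(f)

x = jnp.array([1.0, 2.0, 3.0])
batch = jnp.array([[1.0, 2.0], [3.0, 4.0]])

print(f_fast(x))         # value 14 (1 + 4 + 9)
print(df(x))             # values 2, 4, 6
print(f_batched(batch))  # values 5 and 25, one per row

# Transformations compose: compiled, per-row gradients of f.
per_row_grads = jax.jit(jax.vmap(jax.grad(f)))(batch)
print(per_row_grads)     # rows [2, 4] and [6, 8]
```

grad() requires a function with a scalar output and floating-point inputs, which is why `x` uses `1.0` rather than `1`.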
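Arm64 support for the JAX packages on PyPI:

| PyPI package | arm64 (Apple M1) | aarch64 | Notes |
|---|---|---|---|
| jaxlib | ✔️ | ✔️ | Compiled; architecture-specific wheels |
| jax | N/A | N/A | Pure Python; not architecture-specific |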
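Python's standard library can show which column of the table applies to a given machine and which versions of the two packages are installed. This is a minimal sketch. `installed_versions` is a hypothetical helper that only reads package metadata. It does not check that the installed jaxlib wheel matches the platform.

```python
import platform
from importlib.metadata import PackageNotFoundError, version

def installed_versions(packages=("jax", "jaxlib")):
    """Hypothetical helper: map each package name to its installed version, or None."""
    versions = {}
    for pkg in packages:
        try:
            versions[pkg] = version(pkg)
        except PackageNotFoundError:
            versions[pkg] = None
    return versions

# "arm64" on Apple silicon macOS, "aarch64" on 64-bit Arm Linux, "x86_64" on Intel/AMD Linux and macOS.
print("Machine:", platform.machine())
print(installed_versions())
```

Because jax itself is pure Python, only the jaxlib version is tied to the architecture reported here.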
```python
import jax.numpy as jnp

# Create a JAX array and square it element-wise
array = jnp.array([1, 2, 3])
squared_array = jnp.square(array)
print("Squared array:", squared_array)
```

Output:

```
Squared array: [1 4 9]
```