JAX - AshokBhat/ml GitHub Wiki
- Python library for machine learning research
- API similar to NumPy and SciPy
- Uses XLA to compile and run your programs on GPUs and TPUs
- Described by its authors as "NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research"
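
A minimal sketch of the NumPy-like API: the same expression is written once with `numpy` and once with `jax.numpy`, and the two results match. The `linspace` inputs are arbitrary and chosen only for illustration. By default `jax.numpy` uses 32-bit floats, and 64-bit floats require enabling the `jax_enable_x64` flag. The comparison therefore uses a tolerance rather than exact equality.

```python
import numpy as np
import jax.numpy as jnp

# Arbitrary sample inputs, created with each library
x_np = np.linspace(0.0, 1.0, 5)
x_jnp = jnp.linspace(0.0, 1.0, 5)

# The same expression; only the module prefix changes
y_np = np.sin(x_np) ** 2 + np.cos(x_np) ** 2
y_jnp = jnp.sin(x_jnp) ** 2 + jnp.cos(x_jnp) ** 2

# np.asarray converts a JAX array back to a NumPy array;
# allclose tolerates the float32 vs float64 rounding difference
print(np.allclose(y_np, np.asarray(y_jnp)))  # True
```
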
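
Key function transformations:

- `jit()` - just-in-time compilation with XLA, for speeding up your code
- `grad()` - for taking derivatives (automatic differentiation)
- `vmap()` - for automatic vectorization or batching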
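
The sketch below combines the three transformations listed above. The function `loss`, the weight `w`, and the inputs `xs` are hypothetical and exist only for illustration. `grad()` differentiates `loss` with respect to `w`. `jit()` compiles that gradient function with XLA. `vmap()` maps the compiled gradient over a batch of inputs without a Python loop.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Hypothetical toy scalar function of the parameter w for one input x
    return jnp.sum((w * x) ** 2)

# grad(): derivative of loss with respect to its first argument, w
dloss_dw = jax.grad(loss)

# jit(): compile the gradient function with XLA (first call traces and compiles)
fast_dloss_dw = jax.jit(dloss_dw)

# vmap(): keep w fixed (None) and map over axis 0 of x
batched = jax.vmap(fast_dloss_dw, in_axes=(None, 0))

w = 2.0
xs = jnp.array([1.0, 2.0, 3.0])
print(batched(w, xs))  # analytically 2 * w * x**2, i.e. 4, 16, 36
```

For a scalar `x`, `loss` reduces to `(w * x) ** 2`, so its derivative with respect to `w` is `2 * w * x**2`. The transformations compose freely. This particular order (grad, then jit, then vmap) is only one option; for example, `jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))` would give the same result.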

Arm support for the JAX PyPI packages:

| PyPI package | arm64 (Apple M1) | aarch64 (Linux) | Notes |
|---|---|---|---|
| jaxlib | ✔️ | ✔️ | |
| jax | N/A | N/A | Pure Python, not architecture-specific |
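
The table covers package availability only. To check what a particular Arm machine actually reports, you could run a quick sketch like the one below. It uses only the standard library and public `jax` functions. Its output depends on the machine and the installed wheels: on Apple silicon macOS, `platform.machine()` typically returns `arm64`, and on 64-bit Arm Linux it returns `aarch64`.

```python
import platform
from importlib.metadata import version

import jax

# CPU architecture as reported by the operating system
print("Machine:", platform.machine())

# Installed package versions (jax is pure Python; jaxlib is the compiled part)
print("jax:", jax.__version__)
print("jaxlib:", version("jaxlib"))

# Backend and devices that JAX initialised on this machine
print("Default backend:", jax.default_backend())
print("Devices:", jax.devices())
```
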

Example:

```python
import jax.numpy as jnp

array = jnp.array([1, 2, 3])
squared_array = jnp.square(array)  # element-wise square, like numpy.square
print("Squared array:", squared_array)
```

Output:

```
Squared array: [1 4 9]
```