JAX

About

Positioning

  • NumPy on the CPU, GPU, and TPU
  • with great automatic differentiation
  • for high-performance machine learning research

Key functions

  • jit() - for speeding up your code
  • grad() - for taking derivatives
  • vmap() - for automatic vectorization or batching (all three are shown in the sketch below)
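
A minimal sketch combining all three transformations; the function f and the sample values are illustrative, not from the source:

import jax
import jax.numpy as jnp

def f(x):
    return x ** 2  # simple scalar function to transform

# grad(): build a new function that computes df/dx
df = jax.grad(f)
print(df(3.0))  # 6.0

# jit(): compile f with XLA so repeated calls reuse the compiled kernel
fast_f = jax.jit(f)
print(fast_f(3.0))  # 9.0

# vmap(): map f over a leading batch axis without a Python loop
batched_f = jax.vmap(f)
print(batched_f(jnp.array([1.0, 2.0, 3.0])))  # [1. 4. 9.]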

JAX and XLA

  • Uses XLA to compile and run your NumPy code on accelerators such as GPUs and TPUs
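
A minimal sketch of this in practice; jax.devices(), jax.jit, and the jax.numpy calls are real JAX APIs, while the normalize function is just an example:

import jax
import jax.numpy as jnp

# Backends XLA can target on this machine (CPU by default; GPU/TPU if available)
print(jax.devices())

@jax.jit  # traced once, compiled by XLA, then reused on later calls
def normalize(x):
    return (x - x.mean()) / x.std()

print(normalize(jnp.arange(10.0)))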

Arm support

PyPI package   arm64 (Apple M1)   aarch64   Notes
jaxlib         ✔️                 ✔️
jax            NA                 NA        Not architecture-specific
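
To confirm which backend an Arm installation is using, one option (assuming jax and jaxlib were installed with pip) is:

import jax

print(jax.__version__)        # version of the architecture-independent jax package
print(jax.default_backend())  # e.g. 'cpu' on Apple M1 / aarch64 without an accelerator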

Example - Create an array and square its elements

import jax.numpy as jnp

# Create a JAX array and square each element
array = jnp.array([1, 2, 3])
squared_array = jnp.square(array)
print("Squared array:", squared_array)

Output: Squared array: [1 4 9]

See also
