| name | python-jax |
| description | Expert guidance for JAX (Just After eXecution) - high-performance numerical computing with automatic differentiation, JIT compilation, vectorization, and GPU/TPU acceleration; includes transformations (grad, jit, vmap, pmap), sharp bits, gotchas, and differences from NumPy |
| allowed-tools | ["*"] |
JAX - High-Performance Numerical Computing
Overview
JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. It combines a familiar NumPy-style API with powerful function transformations for automatic differentiation, compilation, vectorization, and parallelization.
Core value: Write NumPy-like Python code and automatically get gradients, GPU/TPU acceleration, vectorization, and parallelization through composable function transformations—without changing your mathematical notation.
When to Use JAX
Use JAX when you:
- Need automatic differentiation for optimization or machine learning
- Want GPU/TPU acceleration with minimal code changes
- Require high-performance numerical computing
- Are building custom gradient-based algorithms
- Need to vectorize or parallelize functions automatically
- Are doing research that requires flexible differentiation
- Want a functional programming approach to numerical code
Avoid JAX when:
- You only need simple NumPy operations with no performance requirements (the compilation overhead is not justified)
- The code relies heavily on in-place mutation (JAX arrays are immutable)
- The code is imperative and stateful, with side effects
- Control flow depends on runtime data values (under jit this requires jax.lax primitives such as cond and while_loop)
- You depend on libraries that require NumPy arrays (arrays must be converted, which may copy data off the accelerator)
Core Transformations
JAX provides four fundamental, composable transformations:
1. jax.grad - Automatic Differentiation
Compute gradients automatically using reverse-mode autodiff.
```python
import jax
import jax.numpy as jnp

def loss(x):
    return jnp.sum(x**2)

grad_loss = jax.grad(loss)        # d/dx sum(x^2) = 2x
x = jnp.array([1.0, 2.0, 3.0])
gradient = grad_loss(x)
print(gradient)  # [2. 4. 6.]
```
Key features:
- Returns a function that computes gradients
- Composes for higher-order derivatives
- Works with complex nested structures (pytrees)
2. jax.jit - Just-In-Time Compilation
Compile functions with XLA, which fuses operations and removes per-operation Python overhead.
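```python
import jax
import jax.numpy as jnp

@jax.jit
def fast_function(x):
    return jnp.sum(x**2 + 3*x + 1)

result = fast_function(jnp.array([1.0, 2.0, 3.0]))  # first call: trace + compile
result = fast_function(jnp.array([4.0, 5.0, 6.0]))  # same shape/dtype: cached executable
```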
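Performance:
- Speedups vary with the workload: fusing many small operations into one compiled computation can help a lot, while code dominated by a single large operation (such as one big matmul) gains little
- The first call is slower because the function is traced and compiled
- The compiled version is cached and reused for later calls with the same shapes and dtypes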
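JAX dispatches work asynchronously. When you time a jitted call, you therefore need `block_until_ready()`; without it, the clock stops before the computation finishes. The minimal sketch below separates the first, compiling call from a cached one. The function name `poly` is illustrative, and the timings depend entirely on your hardware.

```python
import time

import jax
import jax.numpy as jnp

def poly(x):
    # Several elementwise ops that XLA can fuse into one kernel
    return jnp.sum(x**2 + 3 * x + 1)

poly_jit = jax.jit(poly)
x = jnp.arange(1_000_000, dtype=jnp.float32)

t0 = time.perf_counter()
first = poly_jit(x).block_until_ready()   # includes tracing and compilation
t1 = time.perf_counter()
second = poly_jit(x).block_until_ready()  # reuses the cached executable
t2 = time.perf_counter()

print(f"first call: {t1 - t0:.4f}s, cached call: {t2 - t1:.4f}s")
```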
3. jax.vmap - Automatic Vectorization
Vectorize functions across batch dimensions automatically.
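```python
import jax
import jax.numpy as jnp

def process_single(x):
    """Process a single example: a vector of shape (2,)."""
    return jnp.sum(x**2)

process_batch = jax.vmap(process_single)   # maps over axis 0 by default
batch = jnp.array([[1, 2], [3, 4], [5, 6]])
results = process_batch(batch)
print(results)  # [ 5 25 61]
```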
Benefits:
- Eliminates manual loop writing
- Often faster than explicit loops
- Cleaner, more declarative code
4. jax.pmap - Parallel Map
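Parallelize across multiple devices (GPUs/TPUs). pmap maps a function over the leading array axis and runs one slice per device. For new code, recent JAX documentation steers users toward jax.jit with sharded arrays, or toward shard_map.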
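```python
import jax
import jax.numpy as jnp

@jax.pmap
def parallel_fn(x):
    return x**2

# The leading axis must equal the number of devices; each device gets one slice
n_devices = jax.local_device_count()
x = jnp.arange(n_devices)
results = parallel_fn(x)
```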
Use for:
- Multi-GPU/TPU computation
- Data parallelism in training
- Large-scale simulations
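For comparison, here is a minimal sketch of the sharding-based style that newer JAX documentation favors for multi-device work. You place an array across a device mesh with `NamedSharding`, then let `jax.jit` partition the computation. The axis name `"data"` is an arbitrary choice. The example also runs on a single CPU device.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over all available devices with one named axis, "data"
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))  # split axis 0 across the "data" axis

n = 4 * len(jax.devices())  # length divisible by the device count
x = jax.device_put(jnp.arange(n, dtype=jnp.float32), sharding)

@jax.jit
def square(v):
    return v**2  # compiled for the input's sharding; the output stays sharded

y = square(x)
print(y.sharding)
```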
Transformation Composition
JAX transformations compose seamlessly:
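```python
import jax
import jax.numpy as jnp

@jax.jit
def fast_loss(x):
    return jnp.sum(x**2)

grad_fast_loss = jax.grad(fast_loss)                      # gradient of a jitted function
fast_grad_loss = jax.jit(jax.grad(fast_loss))             # compiled gradient
batch_grad_loss = jax.vmap(jax.grad(fast_loss))           # per-example gradients (rows of a batch)
fast_batch_grad = jax.jit(jax.vmap(jax.grad(fast_loss)))  # compiled per-example gradients
```
Where jit sits, inside or outside other transformations, affects performance but not results. The relative order of grad and vmap does change the meaning. vmap(grad(f)) gives one gradient per example. grad, by contrast, requires a scalar output, so a batched loss must be reduced (for example with jnp.mean) before grad is applied.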
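Composing grad and vmap in different orders computes different things. This sketch uses a toy function (names are illustrative) to compare per-example gradients with the gradient of a batch-mean loss. For `sum(x**2)`, the per-example gradient is `2x`. Averaging over a batch of two examples halves it.

```python
import jax
import jax.numpy as jnp

def loss(x):
    return jnp.sum(x**2)  # scalar loss for ONE example

xs = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # a batch of 2 examples

# vmap(grad): one gradient per row, equal to 2 * xs
per_example = jax.vmap(jax.grad(loss))(xs)

# Reduce the batched loss to a scalar first, then take grad: equal to xs
mean_grad = jax.grad(lambda b: jnp.mean(jax.vmap(loss)(b)))(xs)
```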
JAX vs NumPy - Critical Differences
1. Immutability
NumPy (mutable):
```python
import numpy as np

x = np.array([1, 2, 3])
x[0] = 10   # modifies x in place
print(x)    # [10  2  3]
```
JAX (immutable):
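```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
# x[0] = 10 would raise TypeError: JAX arrays are immutable
x = x.at[0].set(10)   # returns an updated copy
print(x)              # [10  2  3]
```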
Functional updates:
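```python
import jax.numpy as jnp

x = jnp.zeros(5)
x = x.at[0].set(10)        # x[0] = 10
x = x.at[0].add(5)         # x[0] += 5
x = x.at[0].multiply(2)    # x[0] *= 2
x = x.at[0].min(5)         # x[0] = min(x[0], 5)
x = x.at[0].max(5)         # x[0] = max(x[0], 5)
x = x.at[[0, 2]].set(10)   # integer-array indexing
x = x.at[0:3].set(10)      # slices

m = jnp.zeros((3, 3))
m = m.at[0, 1].set(10)     # multi-dimensional index
```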
2. Random Number Generation
NumPy (global state):
```python
import numpy as np

np.random.seed(42)            # hidden global state
x = np.random.normal(size=3)
y = np.random.normal(size=3)  # different numbers: the global state advanced
```
JAX (explicit keys):
```python
import jax

key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)      # never reuse a key: split first
x = jax.random.normal(subkey, shape=(3,))
key, subkey = jax.random.split(key)
y = jax.random.normal(subkey, shape=(3,))

# Many independent samples: split into 10 keys and vmap over them
keys = jax.random.split(key, num=10)
samples = jax.vmap(lambda k: jax.random.normal(k, shape=(3,)))(keys)  # shape (10, 3)
```
Key management pattern:
```python
import jax

key = jax.random.PRNGKey(0)
key, subkey1, subkey2 = jax.random.split(key, 3)  # keep `key` for future splits
x = jax.random.normal(subkey1, shape=(10,))
y = jax.random.uniform(subkey2, shape=(10,))
```
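Two properties of this design are easy to check directly. First, a key acts as a deterministic seed, so reusing it repeats the same numbers. Second, `split` derives independent keys. A minimal demonstration:

```python
import jax

key = jax.random.PRNGKey(0)

# The same key always produces the same numbers ...
a = jax.random.normal(key, shape=(3,))
b = jax.random.normal(key, shape=(3,))

# ... so draw independent values from freshly split subkeys instead
k1, k2 = jax.random.split(key)
c = jax.random.normal(k1, shape=(3,))
d = jax.random.normal(k2, shape=(3,))
```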
3. 64-bit Precision
JAX defaults to 32-bit types (float32, int32) for performance. You must enable 64-bit types explicitly.
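```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0])
print(x.dtype)  # float32

# Enable 64-bit mode, ideally once at program startup before creating arrays
jax.config.update("jax_enable_x64", True)
x = jnp.array([1.0])
print(x.dtype)  # float64
```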
4. Out-of-Bounds Indexing
NumPy (raises error):
```python
import numpy as np
x = np.array([1, 2, 3])
# x[10]  -> IndexError: index 10 is out of bounds for axis 0 with size 3
```
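JAX (no error; reads clamp, updates are dropped):
```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
print(x[10])          # 3: the read index is clamped to the last valid position
x = x.at[10].set(99)  # no-op: out-of-bounds updates are skipped
print(x)              # [1 2 3]
```
⚠️ Warning: Neither case raises an error, so bugs in index arithmetic can pass unnoticed. Don't rely on this behavior; validate indices instead.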
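The `.at[]` interface also accepts a `mode` argument that makes out-of-bounds handling explicit. A short sketch that uses `mode="fill"` for reads and `mode="drop"` for updates:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

# Reads: return a sentinel instead of a clamped value
val = x.at[10].get(mode="fill", fill_value=jnp.nan)  # nan

# Updates: "drop" makes the skip-out-of-bounds behavior explicit
y = x.at[10].set(99.0, mode="drop")  # unchanged copy of x
z = x.at[1].set(99.0, mode="drop")   # in-bounds indices update normally
```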
5. Non-Array Inputs
NumPy (accepts lists):
```python
import numpy as np
result = np.sum([1, 2, 3])  # lists are accepted
```
JAX (requires arrays):
```python
import jax.numpy as jnp
# jnp.sum([1, 2, 3]) raises TypeError: convert to an array first
result = jnp.sum(jnp.array([1, 2, 3]))
```
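Reason: If lists were accepted, each element would be traced as a separate value. That can silently degrade performance, so JAX requires an explicit conversion instead.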
Sharp Bits & Gotchas
1. Pure Functions Required
❌ Bad: Side effects
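```python
import jax

counter = 0

@jax.jit
def impure_fn(x):
    global counter
    counter += 1      # runs only while tracing, not on every call
    return x**2

result = impure_fn(2.0)  # traced: counter becomes 1
result = impure_fn(3.0)  # cached: counter stays 1
```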
✅ Good: Pure function
```python
import jax

@jax.jit
def pure_fn(x, counter):
    return x**2, counter + 1  # state passed in and returned explicitly

result, counter = pure_fn(2.0, 0)
result, counter = pure_fn(3.0, counter)  # counter == 2
```
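The tracing behavior behind this gotcha can also be useful. A Python side effect inside a jitted function runs once per trace, so a counter reveals exactly when JAX retraces and recompiles. A small sketch (`trace_count` and `square` are illustrative names):

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def square(x):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    return x**2

square(jnp.ones(3))   # first call with shape (3,): traces
square(jnp.zeros(3))  # same shape and dtype: cached, no new trace
square(jnp.ones(4))   # new shape (4,): traces again
print(trace_count)    # 2
```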
2. Control Flow Limitations
❌ Bad: Value-dependent control flow
```python
import jax

@jax.jit
def conditional(x):
    if x > 0:       # fails when called: a traced value has no concrete truth value
        return x**2
    else:
        return x**3
```
✅ Good: Use jax.lax.cond
```python
import jax

@jax.jit
def conditional(x):
    return jax.lax.cond(
        x > 0,
        lambda x: x**2,   # true branch
        lambda x: x**3,   # false branch
        x,                # operand passed to the selected branch
    )
```
For loops:
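```python
import jax

@jax.jit
def loop_bad(x, n):
    for i in range(n):   # fails: range() needs a concrete int, but n is traced
        x = x + 1
    return x

@jax.jit
def loop_good(x, n):
    def body(i, val):
        return val + 1
    # n may be traced; it then lowers to a while_loop (not reverse-differentiable)
    return jax.lax.fori_loop(0, n, body, x)
```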
While loops:
```python
import jax

@jax.jit
def while_loop(x):
    def cond_fun(val):
        return val < 10   # keep looping while this is True
    def body_fun(val):
        return val + 1
    return jax.lax.while_loop(cond_fun, body_fun, x)
```
3. Dynamic Shapes
❌ Bad: Shape depends on runtime values
```python
import jax

@jax.jit
def dynamic_shape(x, mask):
    return x[mask]   # fails: output size depends on how many mask entries are True
```
✅ Good: Use jnp.where for masking
```python
import jax
import jax.numpy as jnp
@jax.jit
def static_shape(x, mask):
    return jnp.where(mask, x, 0)   # same shape as x; masked-out entries become 0
```
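Sometimes you need the selected elements themselves rather than a masked copy. One option is `jnp.nonzero` with a fixed `size=` argument, which keeps the output shape static and pads unused slots with `fill_value`. A sketch in which the helper `first_k_selected` and the choice `k=3` are illustrative:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="k")
def first_k_selected(x, mask, k):
    # size=k fixes the output shape at trace time; unused slots get fill_value
    (idx,) = jnp.nonzero(mask, size=k, fill_value=0)
    valid = jnp.arange(k) < jnp.sum(mask)   # which of the k slots are real
    return jnp.where(valid, x[idx], 0.0), valid

x = jnp.array([10.0, 20.0, 30.0, 40.0])
mask = jnp.array([False, True, False, True])
vals, valid = first_k_selected(x, mask, k=3)
```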
4. In-Place Update Semantics
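❌ Bad: Expecting an update to show up through another reference (NumPy semantics)
```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
y = x
x = x.at[0].set(10)   # binds x to a new array
print(y[0])           # 1, not 10: y still refers to the original array
```
Note: .at returns a NEW array and never modifies the original. NumPy code that expects mutations to be visible through aliases therefore changes behavior silently when ported.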
5. Print Debugging in JIT
❌ Bad: print() in JIT context
```python
import jax

@jax.jit
def debug(x):
    print(f"x = {x}")   # runs once, at trace time, and shows a tracer, not values
    return x**2
```
✅ Good: Use jax.debug.print
```python
import jax

@jax.jit
def debug(x):
    jax.debug.print("x = {}", x)   # runs on every call with the runtime value
    return x**2
```
6. Gradient Through Discrete Operations
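❌ Bad: Expecting gradients to flow through argmax
```python
import jax.numpy as jnp

def loss(x):
    idx = jnp.argmax(x)   # integer index: no gradient flows through the choice
    return x[idx]         # only the selected element receives gradient (one-hot)
```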
✅ Good: Use differentiable approximations
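```python
import jax
import jax.numpy as jnp
def loss(x, temperature=10.0):
    weights = jax.nn.softmax(x * temperature)  # larger temperature: closer to hard argmax
    return jnp.sum(x * weights)

grad_loss = jax.grad(loss)  # every element now receives some gradient
```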
Automatic Differentiation
Basic Gradient
```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x**2)

grad_f = jax.grad(f)
x = jnp.array([1.0, 2.0, 3.0])
print(grad_f(x))  # [2. 4. 6.]
```
Value and Gradient
```python
value_and_grad_f = jax.value_and_grad(f)   # one pass computes both
value, gradient = value_and_grad_f(x)
print(f"Value: {value}, Gradient: {gradient}")  # Value: 14.0, Gradient: [2. 4. 6.]
```
Multiple Arguments
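```python
def f(x, y):
    return jnp.sum(x**2 + y**3)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([1.0, 2.0, 3.0])

grad_f = jax.grad(f)                         # w.r.t. the first argument (argnums=0)
print(grad_f(x, y))                          # 2x -> [2, 4, 6]

grad_f_wrt_y = jax.grad(f, argnums=1)
print(grad_f_wrt_y(x, y))                    # 3y^2 -> [3, 12, 27]

grad_f_both = jax.grad(f, argnums=(0, 1))    # returns a tuple of gradients
grad_x, grad_y = grad_f_both(x, y)
```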
Auxiliary Data
```python
def loss_with_aux(x):
    loss = jnp.sum(x**2)
    aux_data = {'norm': jnp.linalg.norm(x), 'mean': jnp.mean(x)}
    return loss, aux_data   # (value to differentiate, extra outputs)

grad_fn = jax.grad(loss_with_aux, has_aux=True)
gradient, aux = grad_fn(x)   # aux is returned as-is, not differentiated
print(f"Gradient: {gradient}")
print(f"Auxiliary: {aux}")
```
Higher-Order Derivatives
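```python
def f(x):
    return jnp.sum(x**3)

# Nesting grad gives higher derivatives of scalar functions
d2 = jax.grad(jax.grad(lambda s: s**3))
print(d2(2.0))                       # 6 * 2.0 = 12.0

hessian_f = jax.hessian(f)           # full Hessian matrix
x = jnp.array([1.0, 2.0, 3.0])
print(hessian_f(x))                  # diagonal matrix with entries 6x = 6, 12, 18
hess_diag = jnp.diag(hessian_f(x))   # Hessian diagonal
```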
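For large inputs, materializing the full Hessian is often too expensive. You can compute a Hessian-vector product instead by differentiating the gradient in forward mode (forward-over-reverse). A sketch, where the helper name `hvp` is our own:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x**3)

def hvp(f, x, v):
    # Forward-mode derivative of grad(f) at x along direction v, i.e. H(x) @ v
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, 1.0])
result = hvp(f, x, v)  # H = diag(6x), so H @ v = [6, 0, 18]
```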
Jacobian
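```python
def vector_fn(x):
    """Vector-to-vector function: R^2 -> R^3"""
    return jnp.array([x[0]**2, x[1]**3, x[0]*x[1]])

jac_fwd = jax.jacfwd(vector_fn)   # forward mode: efficient when outputs >= inputs
jac_rev = jax.jacrev(vector_fn)   # reverse mode: efficient when inputs > outputs
x = jnp.array([2.0, 3.0])
print(jac_fwd(x))   # 3x2 matrix [[4, 0], [0, 27], [3, 2]]
print(jac_rev(x))   # same matrix
```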
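jacfwd and jacrev build full Jacobians from the lower-level primitives `jax.jvp` (Jacobian-vector products) and `jax.vjp` (vector-Jacobian products). If you only need the product with one vector, calling these primitives directly avoids forming the matrix. A sketch on the same function, with illustrative vectors `v` and `u`:

```python
import jax
import jax.numpy as jnp

def vector_fn(x):
    return jnp.array([x[0]**2, x[1]**3, x[0] * x[1]])

x = jnp.array([2.0, 3.0])

# jvp: Jacobian times a tangent vector, in one forward pass
v = jnp.array([1.0, 0.0])
y, jv = jax.jvp(vector_fn, (x,), (v,))   # jv == J @ v

# vjp: cotangent vector times the Jacobian, in one reverse pass
y, vjp_fn = jax.vjp(vector_fn, x)
u = jnp.array([0.0, 0.0, 1.0])
(uj,) = vjp_fn(u)                        # uj == u @ J
```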
Custom Gradients
Use @jax.custom_vjp to define a reverse-mode rule directly. Use @jax.custom_jvp to define a forward-mode rule; JAX derives reverse mode from it automatically.
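```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.exp(x)

def f_fwd(x):
    result = jnp.exp(x)
    return result, result      # (primal output, residuals saved for the backward pass)

def f_bwd(result, g):
    return (g * result,)       # d/dx exp(x) = exp(x), reused from the residual

f.defvjp(f_fwd, f_bwd)
grad_f = jax.grad(f)
```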
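`custom_jvp` is a common choice when you want numerical stability rather than a different gradient. The sketch below follows the `log1pexp` example from the JAX custom-derivatives tutorial. In default 32-bit mode, `exp(100.0)` overflows to inf. The automatic chain rule then computes 0 * inf = nan, while the hand-written rule stays finite.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    ans = log1pexp(x)
    # d/dx log(1 + e^x) = 1 - 1/(1 + e^x), which stays finite when e^x overflows
    ans_dot = (1 - 1 / (1 + jnp.exp(x))) * x_dot
    return ans, ans_dot

def naive(x):
    return jnp.log1p(jnp.exp(x))

print(jax.grad(naive)(100.0))     # nan in float32: exp(100.) overflows to inf
print(jax.grad(log1pexp)(100.0))  # 1.0
```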
JIT Compilation
Basic Usage
```python
import jax
import jax.numpy as jnp

def slow_fn(x):
    return jnp.sum(x**2 + 3*x + 1)

fast_fn = jax.jit(slow_fn)   # function-call form

@jax.jit                     # decorator form, equivalent
def fast_fn2(x):
    return jnp.sum(x**2 + 3*x + 1)
```
Static Arguments
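```python
from functools import partial

import jax
import jax.numpy as jnp

x = jnp.zeros(3)

@jax.jit
def fn(x, n):
    for i in range(n):   # fails: n is traced, so range(n) has no concrete length
        x = x + 1
    return x

# result = fn(x, 5)  # would raise a concretization error

@partial(jax.jit, static_argnums=(1,))
def fn_static(x, n):
    for i in range(n):   # n is a plain Python int at trace time
        x = x + 1
    return x

result = fn_static(x, 5)    # compiles for n=5
result = fn_static(x, 10)   # new static value: compiles again for n=10
```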
Avoiding Recompilation
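```python
import jax
import jax.numpy as jnp

@jax.jit
def process(x):
    return jnp.sum(x**2)

process(jnp.ones(10))   # compiles for shape (10,)
process(jnp.ones(20))   # new shape (20,): recompiles
process(jnp.ones(10))   # shape (10,) again: reuses the cached executable

# Keep shapes fixed, e.g. a constant batch size (pad the last batch)
batch_size = 32
x1 = jnp.ones((batch_size, 10))  # both batches share one compiled version
x2 = jnp.ones((batch_size, 10))
```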
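A common way to apply the fixed-shape advice to variable-length data is to pad inputs up to a small set of bucket sizes and carry a mask. Only a few shapes are then ever compiled. A sketch under these assumptions (`pad_to_bucket`, `masked_sum_sq`, and the bucket sizes are hypothetical):

```python
import jax
import jax.numpy as jnp

@jax.jit
def masked_sum_sq(x, mask):
    # Padded entries are zeroed by the mask, so they do not affect the result
    return jnp.sum(jnp.where(mask, x**2, 0.0))

def pad_to_bucket(x, buckets=(16, 32, 64)):
    """Hypothetical helper: pad a 1-D array to the smallest bucket that fits."""
    n = x.shape[0]
    size = next(b for b in buckets if b >= n)  # StopIteration if n > largest bucket
    padded = jnp.zeros(size, x.dtype).at[:n].set(x)
    mask = jnp.arange(size) < n
    return padded, mask

# Lengths 10 and 13 both land in bucket 16, so one compilation serves both
for n in (10, 13):
    padded, mask = pad_to_bucket(jnp.ones(n))
    print(n, padded.shape, masked_sum_sq(padded, mask))
```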
Vectorization (vmap)
Basic Batching
```python
import jax
import jax.numpy as jnp

def process_single(x):
    """Process a single example: scalar input"""
    return x**2 + 3*x

def process_batch_manual(xs):
    return jnp.array([process_single(x) for x in xs])  # Python loop: one call per element

process_batch = jax.vmap(process_single)
batch = jnp.array([1.0, 2.0, 3.0, 4.0])
print(process_batch(batch))  # [ 4. 10. 18. 28.]
```
Batching Matrix Operations
```python
def matrix_vector_product(matrix, vector):
    """Single matrix-vector product"""
    return matrix @ vector

# Share one matrix (None: not batched); batch over vectors along axis 0
batch_mvp = jax.vmap(matrix_vector_product, in_axes=(None, 0))
A = jnp.ones((3, 3))
vectors = jnp.ones((10, 3))
results = batch_mvp(A, vectors)          # shape (10, 3)

# Batch over both arguments along axis 0
batch_both = jax.vmap(matrix_vector_product, in_axes=(0, 0))
matrices = jnp.ones((10, 3, 3))
results = batch_both(matrices, vectors)  # shape (10, 3)
```
Nested vmap
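```python
def fn(x, y):
    return x * y

# Inner vmap maps over y, outer vmap maps over x: every (x, y) pair
fn_batch = jax.vmap(jax.vmap(fn, in_axes=(None, 0)), in_axes=(0, None))
x = jnp.array([1, 2, 3])
y = jnp.array([10, 20])
result = fn_batch(x, y)   # shape (3, 2): [[10, 20], [20, 40], [30, 60]]
```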
PyTrees - Nested Structures
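JAX transformations accept and return pytrees: nested Python containers (lists, tuples, dicts, and registered custom types) with arrays at the leaves.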
```python
import jax
import jax.numpy as jnp

params = {
    'w1': jnp.ones((10, 5)),
    'b1': jnp.zeros(5),
    'w2': jnp.ones((5, 1)),
    'b2': jnp.zeros(1)
}

def loss(params, x):
    h = x @ params['w1'] + params['b1']
    h = jax.nn.relu(h)
    out = h @ params['w2'] + params['b2']
    return jnp.mean(out**2)

grad_fn = jax.grad(loss)
x = jnp.ones((32, 10))
grads = grad_fn(params, x)   # same structure as params: one gradient per leaf
print(grads.keys())          # the keys 'w1', 'b1', 'w2', 'b2'
```
PyTree Operations
```python
scaled_params = jax.tree.map(lambda x: x * 0.9, params)   # applied to every leaf

total_params = jax.tree.reduce(
    lambda total, x: total + x.size,
    params,
    initializer=0
)   # 10*5 + 5 + 5*1 + 1 = 61

leaves, treedef = jax.tree.flatten(params)         # list of arrays + structure
reconstructed = jax.tree.unflatten(treedef, leaves)
```
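The same tree utilities cover optimizer updates. `jax.tree.map` accepts several trees with matching structure and applies a function leaf by leaf. A minimal plain-SGD sketch on a toy linear model; `sgd_step`, the shapes, and the learning rate are illustrative, and libraries such as Optax package this pattern:

```python
import jax
import jax.numpy as jnp

params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros(2)}

def loss(params, x):
    return jnp.mean((x @ params['w'] + params['b']) ** 2)

@jax.jit
def sgd_step(params, x, lr=0.1):
    grads = jax.grad(loss)(params, x)  # a dict with the same keys as params
    # tree.map walks params and grads in lockstep (they share one structure)
    return jax.tree.map(lambda p, g: p - lr * g, params, grads)

x = jnp.ones((4, 3))
new_params = sgd_step(params, x)
```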
Common Patterns
Optimization Loop
```python
import jax
import jax.numpy as jnp

params = jnp.array([1.0, 2.0, 3.0])

def loss(params, x, y):
    pred = jnp.dot(x, params)          # (100, 3) @ (3,) -> (100,)
    return jnp.mean((pred - y)**2)

grad_fn = jax.jit(jax.grad(loss))      # compile the gradient once

x_train = jnp.ones((100, 3))
y_train = jnp.ones(100)
learning_rate = 0.01

for step in range(1000):
    grads = grad_fn(params, x_train, y_train)
    params = params - learning_rate * grads
    if step % 100 == 0:
        l = loss(params, x_train, y_train)
        print(f"Step {step}, Loss: {l:.4f}")
```
Mini-batch Training
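```python
import jax
import jax.numpy as jnp

def train_step(params, batch):
    """Single training step on one batch"""
    x, y = batch
    def batch_loss(params):
        pred = jnp.dot(x, params)
        return jnp.mean((pred - y)**2)
    loss_value, grads = jax.value_and_grad(batch_loss)(params)
    params = params - 0.01 * grads
    return params, loss_value

train_step = jax.jit(train_step)

# Stand-in for a real data loader: any iterable of (x, y) batches with fixed shapes
params = jnp.zeros(3)
data_loader = [(jnp.ones((32, 3)), jnp.ones(32)) for _ in range(4)]

for epoch in range(10):
    for batch in data_loader:
        params, loss = train_step(params, batch)
```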
Scan for Efficient Loops
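```python
import jax
import jax.numpy as jnp

def cumulative_sum(xs):
    """Efficient cumulative sum using scan"""
    def step(carry, x):
        new_carry = carry + x
        output = new_carry
        return new_carry, output            # (next carry, per-step output)
    init = jnp.zeros((), dtype=xs.dtype)    # carry keeps a fixed shape and dtype
    final_carry, outputs = jax.lax.scan(step, init, xs)
    return outputs

xs = jnp.array([1, 2, 3, 4, 5])
print(cumulative_sum(xs))  # [ 1  3  6 10 15]
```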
RNN with Scan
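```python
def rnn_step(carry, x):
    """Single RNN step"""
    h = carry
    h_new = jnp.tanh(jnp.dot(W_h, h) + jnp.dot(W_x, x))
    return h_new, h_new          # carry the new state and also emit it

def rnn(xs, h0):
    """Run RNN over sequence"""
    final_h, all_h = jax.lax.scan(rnn_step, h0, xs)   # iterates over axis 0 of xs
    return all_h

W_h = jnp.ones((5, 5))   # hidden-to-hidden weights
W_x = jnp.ones((5, 3))   # input-to-hidden weights
xs = jnp.ones((10, 3))   # sequence of 10 inputs with 3 features each
h0 = jnp.zeros(5)
outputs = rnn(xs, h0)    # shape (10, 5)
```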
Performance Best Practices
1. JIT Your Critical Paths
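```python
import jax
import jax.numpy as jnp

@jax.jit
def expensive_fn(x):
    # The Python loop is unrolled at trace time into 100 matmuls compiled together
    for _ in range(100):
        x = jnp.dot(x, x.T)
    return x

def trivial(x):
    # A single cheap op gains little from jit on its own
    return x + 1
```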
2. Use vmap Instead of Loops
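```python
# process: any per-example function (illustrative)
def slow_batch(xs):
    return jnp.array([process(x) for x in xs])  # one dispatch per example

fast_batch = jax.vmap(process)                   # one batched computation
```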
3. Minimize Host-Device Transfers
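```python
# compute_on_gpu: any jitted device computation (illustrative name)
# Bad: float(x) copies to the host and blocks on every iteration
for i in range(1000):
    x = compute_on_gpu(x)
    print(float(x))

# Good: stay on the device and transfer once at the end
for i in range(1000):
    x = compute_on_gpu(x)
print(float(x))
```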
4. Use Appropriate Precision
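```python
x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)  # the default; fastest on GPU/TPU

# Opt into float64 only when precision demands it (set once at startup)
from jax import config
config.update("jax_enable_x64", True)
```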
5. Preallocate When Possible
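```python
import jax
import jax.numpy as jnp

def compute(i):               # illustrative per-index computation
    return i * 2.0

# Bad: jnp.append copies the whole array on every iteration (quadratic cost)
result = jnp.array([])
for i in range(1000):
    result = jnp.append(result, compute(i))

# Better: fixed shape, but outside jit each .at[].set still returns a copy
result = jnp.zeros(1000)
for i in range(1000):
    result = result.at[i].set(compute(i))

# Best: compute all values in one vectorized call
result = jax.vmap(compute)(jnp.arange(1000))
```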
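Sometimes a loop is genuinely sequential, with each step depending on earlier ones. In that case, moving the loop inside jit with `jax.lax.fori_loop` lets XLA update the buffer in place instead of copying it every iteration. A sketch in which `compute` and `fill_buffer` are hypothetical stand-ins:

```python
import jax
import jax.numpy as jnp

def compute(i):
    # Hypothetical per-index computation, used only for illustration
    return i * 2.0

@jax.jit
def fill_buffer():
    def body(i, buf):
        # Inside jit, the compiler can perform this update in place
        return buf.at[i].set(compute(i))
    return jax.lax.fori_loop(0, 1000, body, jnp.zeros(1000))

result = fill_buffer()
```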
Debugging
Checking Array Values
```python
@jax.jit
def debug_fn(x):
    jax.debug.print("x = {}", x)                                    # runtime values
    jax.debug.print("x shape = {}, dtype = {}", x.shape, x.dtype)  # static metadata works too
    return x**2
```
Gradient Checking
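```python
import jax.numpy as jnp
from jax.test_util import check_grads

def f(x):
    return jnp.sum(x**3)

x = jnp.array([1.0, 2.0, 3.0])
check_grads(f, (x,), order=2)   # compares autodiff against finite differences; raises on mismatch
```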
Inspecting Compiled Code
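```python
import jax

def f(x):
    return x**2 + 3*x

jaxpr = jax.make_jaxpr(f)(1.0)               # JAX's intermediate representation
print(jaxpr)

compiled = jax.jit(f).lower(1.0).compile()   # lower, then compile with XLA
print(compiled.as_text())                     # the compiled module as text
```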
Common Gotchas - Quick Reference
| Gotcha | NumPy | JAX | Solution |
|---|---|---|---|
| In-place update | x[0] = 1 | ❌ TypeError | x = x.at[0].set(1) |
| Random state | np.random.seed() | ❌ No global RNG state | key = jax.random.PRNGKey(seed), then jax.random.split |
| List inputs | np.sum([1,2,3]) | ❌ TypeError | jnp.sum(jnp.array([1,2,3])) |
| Out-of-bounds | IndexError | ⚠️ Reads clamp, updates dropped | Avoid; validate indices |
| Value-dependent if | Works | ❌ Error under jit | jax.lax.cond() |
| Dynamic shapes | Works | ❌ Error under jit | Keep shapes static |
| Default precision | float64 | float32 | Set jax_enable_x64 |
| Print in JIT | Works | Prints tracers once, at trace time | jax.debug.print() |
Installation
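```bash
# CPU only
pip install jax
# NVIDIA GPU (CUDA 12); the quotes stop shells such as zsh from expanding the brackets
pip install "jax[cuda12]"
# Cloud TPU
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```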
Ecosystem Libraries
JAX has a rich ecosystem built on its transformation primitives:
Neural Networks:
- Flax - Official neural network library
- Haiku - DeepMind's neural network library (DEPRECATED: migrate to Flax/NNX)
- Equinox - Elegant PyTorch-like library
Optimization:
- Optax - Gradient processing and optimization
- JAXopt - Non-linear optimization (no longer actively developed; parts have been ported to Optax)
Scientific Computing:
- JAX-MD - Molecular dynamics
- Diffrax - Differential equation solvers
- BlackJAX - MCMC sampling
Utilities:
- jaxtyping - Type annotations for arrays
- chex - Testing utilities
Additional Resources
Related Skills
- python-optimization - Numerical optimization with scipy, pyomo
- python-ase - Atomic simulation environment (can use JAX for forces)
- pycse - Scientific computing utilities
- python-best-practices - Code quality for JAX projects