Learn JAX with Real Code Examples
Updated Nov 24, 2025
Code Sample Descriptions
1. JAX Simple Linear Regression Example
import jax.numpy as jnp
from jax import grad, jit
# Sample data
x = jnp.array([1,2,3,4])
y = jnp.array([2,4,6,8])
# Initialize parameters
a = 0.0
b = 0.0
# Define loss function
def loss(a, b):
    y_pred = a * x + b
    return jnp.mean((y - y_pred)**2)
# Compute gradients (jit-compile the gradient function so the import above is used)
grad_loss = jit(grad(loss, argnums=(0, 1)))
# Simple gradient descent loop
for _ in range(1000):
    da, db = grad_loss(a, b)
    a -= 0.01 * da
    b -= 0.01 * db
print('Learned parameters:', a, b)
A minimal JAX example performing linear regression using automatic differentiation.
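Going one step further, the whole parameter update can be wrapped in a single jit-compiled function so the loop body runs as one compiled computation. The sketch below is not part of the original sample; the update helper and its hyperparameters are illustrative choices.
import jax.numpy as jnp
from jax import grad, jit

x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = jnp.array([2.0, 4.0, 6.0, 8.0])

def loss(a, b):
    return jnp.mean((y - (a * x + b)) ** 2)

@jit
def update(a, b):
    # One gradient-descent step, compiled as a single XLA computation
    da, db = grad(loss, argnums=(0, 1))(a, b)
    return a - 0.01 * da, b - 0.01 * db

a, b = 0.0, 0.0
for _ in range(1000):
    a, b = update(a, b)
print('Learned parameters:', a, b)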
2. JAX Logistic Regression Example
import jax.numpy as jnp
from jax import grad
# Sample data
X = jnp.array([[0,0],[0,1],[1,0],[1,1]])
y = jnp.array([0,1,1,1]) # OR labels (linearly separable; plain logistic regression cannot fit XOR)
# Initialize parameters
w = jnp.zeros(2)
b = 0.0
# Sigmoid function
def sigmoid(z):
    return 1 / (1 + jnp.exp(-z))
# Loss function
def loss(w, b):
    y_pred = sigmoid(jnp.dot(X, w) + b)
    return -jnp.mean(y * jnp.log(y_pred) + (1-y) * jnp.log(1-y_pred))
grad_loss = grad(loss, argnums=(0,1))
# Gradient descent
for _ in range(1000):
    dw, db = grad_loss(w, b)
    w -= 0.1 * dw
    b -= 0.1 * db
print('Learned weights:', w, 'bias:', b)
A logistic regression implementation using JAX for binary classification.
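To turn the learned parameters into class predictions, threshold the sigmoid output at 0.5. The short sketch below continues directly from the sample above and reuses its X, y, w, b, and sigmoid.
probs = sigmoid(jnp.dot(X, w) + b)
preds = (probs > 0.5).astype(jnp.int32)
print('Probabilities:', probs)
print('Predictions:', preds, 'targets:', y)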
3. JAX Neural Network Forward Pass Example
import jax.numpy as jnp
# Input
x = jnp.array([1.0, 2.0, 3.0])
# Network parameters
W1 = jnp.array([[0.1,0.2,0.3],[0.4,0.5,0.6]])
b1 = jnp.array([0.1,0.2])
W2 = jnp.array([[0.7,0.8]])
b2 = jnp.array([0.3])
# Forward pass
def relu(x):
    return jnp.maximum(0, x)
h = relu(jnp.dot(W1, x) + b1)
y_pred = jnp.dot(W2, h) + b2
print('Output:', y_pred)
Forward pass of a simple 2-layer neural network in JAX.
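The same forward pass can be applied to a whole batch of inputs with jax.vmap, which maps a single-example function over a leading batch axis. A self-contained sketch; the parameters are repeated from the sample above and the batch values are chosen only for illustration.
import jax.numpy as jnp
from jax import vmap

W1 = jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
b1 = jnp.array([0.1, 0.2])
W2 = jnp.array([[0.7, 0.8]])
b2 = jnp.array([0.3])

def forward(x):
    h = jnp.maximum(0, jnp.dot(W1, x) + b1)  # ReLU hidden layer
    return jnp.dot(W2, h) + b2

# vmap maps the per-example forward pass over the batch dimension
batch = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print('Batched output:', vmap(forward)(batch))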
4. JAX Mean Squared Error Example
import jax.numpy as jnp
# Sample predictions and targets
y_true = jnp.array([1.0,2.0,3.0])
y_pred = jnp.array([1.1,1.9,3.2])
# Mean squared error
def mse(y_true, y_pred):
    return jnp.mean((y_true - y_pred)**2)
print('MSE:', mse(y_true, y_pred))
Computing mean squared error using JAX for vectorized operations.
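Because mse is written in jax.numpy, it can be differentiated directly; the gradient with respect to the predictions matches the analytic formula 2*(y_pred - y_true)/n. A small verification sketch, with the argument order swapped so grad differentiates with respect to y_pred.
import jax.numpy as jnp
from jax import grad

y_true = jnp.array([1.0, 2.0, 3.0])
y_pred = jnp.array([1.1, 1.9, 3.2])

def mse(y_pred, y_true):
    return jnp.mean((y_true - y_pred) ** 2)

# grad differentiates with respect to the first argument (y_pred here)
g = grad(mse)(y_pred, y_true)
print('Autodiff gradient:', g)
print('Analytic gradient:', 2 * (y_pred - y_true) / y_pred.shape[0])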
5. JAX Gradient Computation Example
import jax.numpy as jnp
from jax import grad
# Function
def f(x):
    return x**2 + 3*x + 2
# Compute gradient
grad_f = grad(f)
x = 5.0
print('Gradient at x=5:', grad_f(x))
Illustrating automatic differentiation using JAX.
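grad composes with itself and with vmap, so second derivatives and batched derivatives come essentially for free. For f(x) = x^2 + 3x + 2 we know f'(x) = 2x + 3 and f''(x) = 2; a short sketch building on the function above:
import jax.numpy as jnp
from jax import grad, vmap

def f(x):
    return x**2 + 3*x + 2

df = grad(f)          # f'(x) = 2x + 3
d2f = grad(grad(f))   # f''(x) = 2
xs = jnp.array([0.0, 1.0, 5.0])
print("f'(xs):", vmap(df)(xs))   # expect [3., 5., 13.]
print("f''(5):", d2f(5.0))       # expect 2.0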
6. JAX Vectorized Operations Example
import jax.numpy as jnp
# Arrays
x = jnp.array([1,2,3,4])
y = jnp.array([2,4,6,8])
# Element-wise operations
z = x + y
d = x * y
print('Sum:', z)
print('Product:', d)
Demonstrating JAX's vectorized operations on arrays.
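Beyond element-wise arithmetic, the same arrays support reductions and broadcasting without any Python loops. A small additional sketch:
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4])
y = jnp.array([2, 4, 6, 8])
print('Dot product:', jnp.dot(x, y))   # reduction to a single scalar
print('Sum of x:', jnp.sum(x))
# Broadcasting: a column plus a row yields a 4x4 matrix of pairwise sums
print('Outer sum:', x[:, None] + y[None, :])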
7. JAX JIT Compilation Example
import jax.numpy as jnp
from jax import jit
@jit
def f(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2
x = jnp.linspace(0, 10, 1000)
print('JIT function output:', f(x))
Using JIT compilation in JAX to speed up function execution.
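To see whether jit actually helps, time the compiled function after a warm-up call; block_until_ready forces JAX's asynchronous dispatch to finish before the timer stops. For a function this small the gain may be modest, so the numbers are only illustrative. A sketch:
import time
import jax.numpy as jnp
from jax import jit

def f(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2

f_jit = jit(f)
x = jnp.linspace(0, 10, 1000)
f_jit(x).block_until_ready()  # warm-up call pays the compilation cost

start = time.perf_counter()
f_jit(x).block_until_ready()
print('Compiled call:', time.perf_counter() - start, 'seconds')

start = time.perf_counter()
f(x).block_until_ready()
print('Uncompiled call:', time.perf_counter() - start, 'seconds')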
8. JAX Neural Network Training Example
import jax.numpy as jnp
from jax import grad
# Data
X = jnp.array([[1.0],[2.0],[3.0],[4.0]])
y = jnp.array([[2.0],[4.0],[6.0],[8.0]])  # same (4, 1) shape as the predictions, so the loss broadcasts correctly
# Parameters
w = 0.0
b = 0.0
# Prediction
def predict(w, b, X):
    return w * X + b
# Loss
def loss(w, b):
    y_pred = predict(w, b, X)
    return jnp.mean((y - y_pred)**2)
grad_loss = grad(loss, argnums=(0,1))
# Gradient descent
for _ in range(1000):
    dw, db = grad_loss(w, b)
    w -= 0.01 * dw
    b -= 0.01 * db
print('Learned w,b:', w, b)
Training a minimal single-neuron model with JAX and gradient descent; the same loop structure carries over to larger networks.
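jax.value_and_grad returns the loss value and its gradients in a single pass, which is handy for monitoring training. A minimal sketch of the same model with that helper (hyperparameters unchanged):
import jax.numpy as jnp
from jax import value_and_grad

X = jnp.array([[1.0], [2.0], [3.0], [4.0]])
y = jnp.array([[2.0], [4.0], [6.0], [8.0]])

def loss(w, b):
    return jnp.mean((y - (w * X + b)) ** 2)

# value_and_grad gives the loss and both gradients at once
loss_and_grads = value_and_grad(loss, argnums=(0, 1))
w, b = 0.0, 0.0
for _ in range(1000):
    l, (dw, db) = loss_and_grads(w, b)
    w -= 0.01 * dw
    b -= 0.01 * db
print('Final loss:', l, 'learned w,b:', w, b)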
9. JAX Softmax Classification Example
import jax.numpy as jnp
from jax import grad
# Data
X = jnp.array([[1,2],[3,4],[5,6]])
y = jnp.array([0,1,2])
# Parameters
W = jnp.zeros((3,2))
b = jnp.zeros(3)
# Softmax (vectorized over the rows of a logits matrix)
def softmax(z):
    e_z = jnp.exp(z - jnp.max(z, axis=-1, keepdims=True))
    return e_z / e_z.sum(axis=-1, keepdims=True)
# Loss
def loss(W, b):
    logits = jnp.dot(X, W.T) + b
    y_pred = softmax(logits)
    return -jnp.mean(jnp.log(y_pred[jnp.arange(len(y)), y]))
grad_loss = grad(loss, argnums=(0,1))
Performing multi-class classification using softmax in JAX.
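The sample above only builds the gradient function; a few gradient-descent steps followed by an argmax over the logits turn it into a working classifier. The snippet below continues from the sample (the learning rate and step count are arbitrary illustrative choices):
for _ in range(500):
    dW, db = grad_loss(W, b)
    W -= 0.5 * dW
    b -= 0.5 * db
logits = jnp.dot(X, W.T) + b
print('Predicted classes:', jnp.argmax(logits, axis=1), 'targets:', y)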
10. JAX Convolution Example
import jax.numpy as jnp
from jax import lax
# Input sequence
x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
# Kernel
w = jnp.array([0.2, 0.5, 0.2])
# 1D convolution: input shaped (N, W, C), kernel shaped (W, I, O) to match 'WIO'
conv = lax.conv_general_dilated(x[jnp.newaxis, :, jnp.newaxis],   # (1, 5, 1)
                                w[:, jnp.newaxis, jnp.newaxis],   # (3, 1, 1)
                                window_strides=(1,),
                                padding='VALID',
                                dimension_numbers=('NWC','WIO','NWC'))
print('Convolution output:', conv)
Performing 1D convolution using JAX for sequence data.
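As a sanity check, jnp.convolve gives the same numbers here because the kernel [0.2, 0.5, 0.2] is symmetric, so flipping it changes nothing; note that lax.conv_general_dilated itself computes a cross-correlation and does not flip the kernel. A short sketch:
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
w = jnp.array([0.2, 0.5, 0.2])
print('jnp.convolve output:', jnp.convolve(x, w, mode='valid'))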