
JAX Learning 1

Gan · AI Engineer · 3 min read

1. Auto-vectorization (vmap)

JAX's vmap function can automatically vectorize functions to handle batched data.

import jax
import jax.numpy as jnp

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

def convolve(x, w):
    # Slide a length-3 window across x and dot it with the kernel w
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

convolve(x, w)

# Batch processing: stack two copies of x and w along a new leading axis
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])

# vmap maps convolve over the leading (batch) axis of both arguments
auto_batch_convolve = jax.vmap(convolve)
auto_batch_convolve(xs, ws)

Output:

Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)
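
By default vmap maps over the leading axis of every argument. If only some arguments are batched, the in_axes parameter controls which axes are mapped. Here is a small sketch, reusing the convolve, xs, and w defined above, that shares a single kernel across the batch:

# Map over the leading axis of xs only; None means w is passed through unbatched
batch_convolve_shared_w = jax.vmap(convolve, in_axes=[0, None])
batch_convolve_shared_w(xs, w)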

2. JIT Compilation

JAX's Just-In-Time (JIT) compilation can significantly improve function execution speed.

Basic Usage

import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)

# Non-compiled version
%timeit selu(x).block_until_ready()
# 1.05 ms ± 7.85 μs per loop

# JIT compiled version
selu_jit = jax.jit(selu)
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()
# 228 μs ± 1.69 μs per loop - approximately 4.6x speedup
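
One detail worth keeping in mind: jit traces the Python function on its first call and caches the compiled result, so Python-level side effects run only during tracing. A small sketch to illustrate (the print is just a marker added for demonstration):

@jax.jit
def add_one(x):
    # Executes only while JAX traces the function, not on cached calls
    print("tracing add_one")
    return x + 1

add_one(jnp.arange(3))  # first call: traces, compiles, and prints
add_one(jnp.arange(3))  # reuses the cached compilation; nothing is printed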

Inspecting Compiled Code

Use make_jaxpr to inspect the jaxpr, JAX's intermediate representation, that a function is traced into:

from jax import make_jaxpr

def f(x, y):
    return jnp.dot(x + 1, y + 1)

y = jnp.ones(x.shape)
make_jaxpr(f)(x, y)

Output:

{ lambda ; a:i32[1000000] b:f32[1000000]. let
    c:i32[1000000] = add a 1:i32[]
    d:f32[1000000] = add b 1.0:f32[]
    e:f32[] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  in (e,) }
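
To look one level lower than the jaxpr, recent JAX versions also expose an ahead-of-time lowering API on jitted functions. A minimal sketch, assuming the f, x, and y defined above:

# Lower the jitted function for these concrete inputs and print the compiler IR
lowered = jax.jit(f).lower(x, y)
print(lowered.as_text())

# compile() then produces the executable that jit would otherwise cache internally
compiled = lowered.compile()
compiled(x, y)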

Static Arguments

Mark arguments you don't want JIT to trace as static. A static argument is treated as a compile-time constant, so it can be used in ordinary Python control flow; note that changing its value triggers a recompilation:

from functools import partial
from jax import jit

@partial(jit, static_argnums=(1,))
def f(x, neg):
    # neg is static, so this Python-level branch is allowed under jit
    return -x if neg else x

f(1, True)
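
Equivalently, static arguments can be marked by name with static_argnames. A small sketch of the same function, which also shows that changing a static value recompiles:

@partial(jit, static_argnames=['neg'])
def g(x, neg):
    return -x if neg else x

g(1, True)   # traced and compiled for neg=True
g(1, False)  # neg changed, so this call triggers a recompilation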

3. Automatic Differentiation (grad)

JAX's grad transform computes gradients of functions via automatic differentiation.

import jax
import jax.numpy as jnp
from jax import grad

key = jax.random.key(0)

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

# Outputs probability of a label being true
def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)

# Build a toy dataset
inputs = jnp.array([[0.52, 1.12, 0.77],
                    [0.88, -1.08, 0.15],
                    [0.52, 0.06, -1.30],
                    [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])

# Training loss is the negative log-likelihood of the training examples
def loss(W, b):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

# Initialize random model coefficients
key, W_key, b_key = jax.random.split(key, 3)
W = jax.random.normal(W_key, (3,))
b = jax.random.normal(b_key, ())

# Differentiate with respect to the first positional argument
W_grad = grad(loss, argnums=0)(W, b)
print(f'{W_grad=}')

# Since argnums=0 is the default, this does the same thing
W_grad = grad(loss)(W, b)
print(f'{W_grad=}')

# Differentiate with respect to the second argument
b_grad = grad(loss, 1)(W, b)
print(f'{b_grad=}')

# Differentiate with respect to multiple arguments
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print(f'{W_grad=}')
print(f'{b_grad=}')

Output:

W_grad=Array([-0.433146 , -0.7354605, -1.2598922], dtype=float32)
W_grad=Array([-0.433146 , -0.7354605, -1.2598922], dtype=float32)
b_grad=Array(-0.6900178, dtype=float32)
W_grad=Array([-0.433146 , -0.7354605, -1.2598922], dtype=float32)
b_grad=Array(-0.6900178, dtype=float32)
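
In a training loop you usually want the loss value together with the gradients; jax.value_and_grad returns both in a single pass. Below is a minimal sketch of one gradient-descent update, where the learning rate 0.1 is an arbitrary value chosen for illustration:

from jax import value_and_grad

# Loss and gradients with respect to both W and b in one call
loss_value, (W_grad, b_grad) = value_and_grad(loss, (0, 1))(W, b)

# One gradient-descent step with an arbitrary learning rate
lr = 0.1
W = W - lr * W_grad
b = b - lr * b_grad
print(loss_value, loss(W, b))  # the loss should decrease after the update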