A brief introduction to JAX and Laplace’s method

10 minute read

Introduction

Google recently released an interesting new library called JAX. It looks like it could be very useful for computational statistics, so I thought I’d take a look. In this post, I’ll describe some of the basic features of JAX, and then use them to implement Laplace’s approximation, a method used in Bayesian statistics.

JAX is described as “Autograd and XLA, brought together”. Autograd refers to automatic differentiation, and XLA stands for accelerated linear algebra. Both of these are very useful in computational statistics, where we often have to differentiate things and perform lots of matrix (or tensor) manipulations.

JAX effectively extends the numpy library to include these extra components. In doing so, it preserves an API that many scientists are familiar with, and introduces powerful new functionality.

Automatic differentiation

Automatic differentiation (or autodiff, for short) is a set of methods for evaluating the derivative of functions using a computer. If you’ve taken a calculus course before, you will have taken the derivative of functions by hand, e.g.

\[f(x) = x ^ 2 \implies f'(x) = 2 x\]

For the simple function above, differentiating is quite straightforward. But when the function is complicated and has many arguments (for example, the objective function of a neural network), differentiating by hand quickly becomes infeasible. Autodiff frameworks save us the trouble: we simply pass a function to the framework, and it returns another function that computes the gradient. This is really useful in computational statistics and machine learning, since we often want to optimise functions with respect to their arguments, and optimisation is generally much more efficient when gradients are available. It’s also really useful in Bayesian statistics, where the state-of-the-art Markov Chain Monte Carlo method, Hamiltonian Monte Carlo, relies on the gradient of the log-posterior density.

So, how do we do this in JAX? Here’s a snippet that defines the logistic function in python, and then uses JAX to compute its derivative:

import jax.numpy as np
from jax import grad

def logistic(x):
    # logistic function
    return (1. + np.exp(-x)) ** -1.0

# differentiate with jax!
grad_logistic = grad(logistic)

print(logistic(0.0))
print(grad_logistic(0.0))
0.5
0.25

That was easy! There are two lines to focus on in this snippet:

  1. import jax.numpy as np: instead of importing regular numpy, I imported jax.numpy which is JAX’s implementation of numpy functionality. After this line, we can pretty much forget about it and pretend that we’re using regular numpy most of the time.
  2. grad_logistic = grad(logistic): this is where the magic happens. We passed the logistic function to JAX’s grad function, and it returned another function, which I called grad_logistic. This function takes the same inputs as logistic, but returns the gradient with respect to these inputs.
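
The result 0.25 makes sense: the logistic function satisfies \(f'(x) = f(x)\,(1 - f(x))\), so at \(x = 0\) the derivative is \(0.5 \times 0.5 = 0.25\). We can check the same identity at any other point (the value of x0 below is arbitrary):

x0 = 2.0
print(grad_logistic(x0))                    # gradient from JAX's autodiff
print(logistic(x0) * (1.0 - logistic(x0)))  # analytic derivative of the logistic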

To convince ourselves that this all worked, let’s plot the logistic function and its derivative:

from jax import vmap
import matplotlib.pyplot as plt

plt.rcParams["figure.figsize"] = (10, 5)
plt.rcParams["font.size"] = 20

vectorised_grad_logistic = vmap(grad_logistic)

x = np.linspace(-10.0, 10.0, 1000)
fig, ax = plt.subplots()
ax.plot(x, logistic(x), label="$f(x)$")
ax.plot(x, vectorised_grad_logistic(x), label="$f'(x)$")
_ = ax.legend()
_ = ax.set_xlabel("$x$")

[Figure: the logistic function and its first derivative.]

You’ll notice that I had to define another function, vectorised_grad_logistic, in order to make the plot. The reason is that functions produced by grad are not vectorised: they expect a single input point and return the gradient at that point, rather than accepting a whole array of inputs. To get around this, we can wrap our grad_logistic function with vmap, which automatically vectorises it for us.
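
To make the distinction concrete, here’s a small illustration (the particular points are arbitrary): the vmapped function accepts a whole array of inputs in one call, which is equivalent to calling grad_logistic on each point separately.

points = np.array([-2.0, 0.0, 2.0])

# one call, gradients at all three points
print(vectorised_grad_logistic(points))

# equivalent, but one point at a time
print([float(grad_logistic(p)) for p in points])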

This is already pretty neat. We can also obtain higher-order derivatives by the repeated application of grad:

second_order_grad_logistic = vmap(grad(grad(logistic)))

x = np.linspace(-10.0, 10.0, 1000)
fig, ax = plt.subplots()
ax.plot(x, logistic(x), label="$f(x)$")
ax.plot(x, vectorised_grad_logistic(x), label="$f'(x)$")
ax.plot(x, second_order_grad_logistic(x), label="$f''(x)$")
_ = ax.legend()
_ = ax.set_xlabel("$x$")

[Figure: the logistic function with its first and second derivatives.]

JIT compilation

Before I demonstrate an application of JAX, I want to mention another useful feature: JIT compilation. As you may know, python is an interpreted language, rather than being compiled. This is one of the reasons that python code can run slower than the same logic in a compiled language (like C). I won’t go into detail about why this is, but here’s a great blog post by Jake VanderPlas on the subject.

One of the reasons why numpy is so useful is that it is calling C code under the hood, which is compiled. This means that it can be much faster than code that uses native python arrays. JAX adds an additional feature on top of this: Just In Time (JIT) compilation. The “JIT” part means that the code is compiled at runtime the first time that it is needed. Using this feature can speed up our code.

To do this, we just have to apply the jit decorator to our function:

from jax import jit

@jit
def jit_logistic(x):
    # logistic function
    return (1. + np.exp(-x)) ** -1.0

@jit
def jit_grad_logistic(x):
    # compile the gradient as well
    return grad(logistic)(x)

Now we can compare our JIT compiled functions to the ones we made earlier, and see if there’s any difference in execution time:

%%timeit
logistic(0.0)
376 µs ± 10.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
jit_logistic(0.0)
90.4 µs ± 1.19 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%%timeit
grad_logistic(0.0)
1.76 ms ± 7.45 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
jit_grad_logistic(0.0)
86.5 µs ± 1.48 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

For the plain logistic function, JIT led to a roughly 4x speedup. For the gradient, there was a roughly 20x speedup. These are non-trivial gains! I could have obtained even bigger speedups had I been using JAX in combination with a GPU.
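
One caveat when timing JAX code: JAX can dispatch work asynchronously (especially on accelerators), so a line like jit_grad_logistic(0.0) may return before the computation has actually finished. For more careful benchmarks, it’s common to call block_until_ready() on the result, along these lines:

%%timeit
# wait for the result to actually be computed before the timer stops
jit_grad_logistic(0.0).block_until_ready()

For tiny scalar operations like these the difference is negligible, but it matters for larger computations, particularly on a GPU.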

Putting it all together: the Laplace approximation

I’ve barely scratched the surface of what JAX can do, but already there are interesting and useful applications with the functionality I have described here. As an example, I’ll describe how to implement an important method in Bayesian statistics: Laplace’s approximation.

The Laplace approximation

Imagine we have a probability model with parameters \(\theta\), and we’ve conditioned the model on some data \(D\). In Bayesian inference, our goal is always to calculate integrals like this:

\[\mathbb{E}\left[h(\theta)\right] = \int \mathrm{d}\theta \, h(\theta) \, p(\theta | D)\]

That is, we are interested in the expectation of some function \(h(\theta)\) with respect to the posterior distribution \(p(\theta | D)\). For interesting models, the posterior is complex, and so we have no hope of calculating these integrals analytically. Because of this, Bayesians have devised many methods for approximating them. If you’ve got time, the best thing to do is use Markov Chain Monte Carlo. But if your dataset is quite large relative to your time and computational budget, you may need to try something else; a typical choice is Variational Inference.

Another, possibly less talked-about, approach is called Laplace’s approximation. It works really well when you have quite a lot of data, because of the Bayesian central limit theorem (the Bernstein–von Mises theorem): as the amount of data grows, the posterior looks more and more like a Normal distribution. In this approach, we approximate the posterior distribution by a Normal distribution. This is a common approximation (it’s often used in Variational Inference too), but Laplace’s method has a specific way of finding the Normal distribution that best matches the posterior.

Suppose we know the location \(\theta^*\) of the maximum point of the posterior1. Now let’s Taylor expand the log posterior around this point. To reduce clutter, I’ll use the notation \(\log p(\theta | D) \equiv f(\theta)\). For simplicity, let’s consider the case when \(\theta\) is scalar:

\[f(\theta) \approx f(\theta^*) + \frac{\partial f}{\partial \theta}\bigg|_{\theta^*}\,(\theta - \theta^*) + \dfrac{1}{2}\frac{\partial^2 f}{\partial \theta^2}\bigg|_{\theta^*}\,(\theta - \theta^*)^2 \\ = f(\theta^*) + \dfrac{1}{2}\frac{\partial^2 f}{\partial \theta^2}\bigg|_{\theta^*}\,(\theta - \theta^*)^2\]

The first derivative disappears because \(\theta^*\) is a maximum point, so the gradient there is zero. Let’s compare this to the log-density of a Normal distribution with mean \(\mu\) and standard deviation \(\sigma\), which I’ll call \(g(\theta)\):

\[g(\theta) = -\frac{1}{2}\log (2\pi\sigma^2) - \dfrac{1}{2}\dfrac{1}{\sigma^2}(\theta - \mu)^2\]

We can match up the terms in the expressions for \(g(\theta)\) and the Taylor expansion of \(f(\theta)\) (ignoring the constant additive terms) to see that

\[\mu = \theta^* \\ \sigma^2 = \left(-\dfrac{\partial^2 f}{\partial \theta^2}\bigg|_{\theta^*}\right)^{-1}\]

Consequently, we might try approximating the posterior distribution with a Normal distribution, and set the mean and variance to these values. In multiple dimensions, the covariance matrix of the resulting multivariate normal is the inverse of the Hessian matrix of the negative log posterior at \(\theta^*\):

\[\Sigma = H^{-1}, \qquad H_{ij} = -\dfrac{\partial^2 f}{\partial \theta_i \partial \theta_j}\bigg|_{\theta^*}\]

Already, we can see that Laplace’s approximation requires us to be able to twice differentiate the log posterior in order to obtain \(\Sigma\). In addition, we have to find the location \(\theta^*\) of the maximum of the posterior. We probably have to do this numerically, which means using some kind of optimisation routine, and the most efficient optimisation routines require the gradient of the objective function. So, using Laplace’s approximation means we want to evaluate the first and second derivatives of the log posterior. Sounds like a job for JAX!
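
Before the full example, here’s a minimal one-dimensional sketch of this recipe. I’ll use an unnormalised Gamma(3, 2) log density as a stand-in posterior, chosen purely because its mode and the corresponding Laplace variance are known in closed form: \((a - 1)/b\) and \((a - 1)/b^2\) respectively.

# toy 1D example: unnormalised log density of a Gamma(a, b) distribution
a, b = 3.0, 2.0

def log_post_1d(theta):
    return (a - 1.0) * np.log(theta) - b * theta

theta_star_1d = (a - 1.0) / b                 # the mode, known analytically here
d2f = grad(grad(log_post_1d))(theta_star_1d)  # second derivative at the mode, via JAX
sigma2_laplace = -1.0 / d2f                   # should equal (a - 1) / b**2 = 0.5

print(theta_star_1d, sigma2_laplace)

In higher dimensions, the second derivative becomes a Hessian matrix, which is exactly what the example below uses.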

Example: a Student-t posterior distribution

Suppose our true posterior is a 2D Student-t:

\[p(\theta | D) \propto \left(1+\frac{1}{\nu}(\theta - \mu)^T \Sigma^{-1}(\theta - \mu)\right)^{-(\nu + \mathrm{dim}(\theta))/2}\]

This is a simple example, and we can actually sample from a Student-t rather easily. Nevertheless, let’s go ahead and use it to implement Laplace’s method in JAX. Let’s set the values of the constants in the Student-t:

\[\mu = \begin{pmatrix} 0.5 \\ 2 \end{pmatrix} \\ \Sigma = \begin{pmatrix} 1 & 0.5 \\ 0.5 & 1 \end{pmatrix} \\ \nu = 7\]
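
As an aside, sampling from this posterior directly really is easy: a multivariate Student-t is just a multivariate Normal scaled by \(\sqrt{\nu / W}\), where \(W \sim \chi^2_\nu\). A minimal sketch of this construction in plain numpy, using the constants above, might look like:

import numpy as onp  # regular numpy, kept separate from jax.numpy

rng = onp.random.default_rng(0)
n_samples = 5000

# multivariate Student-t = multivariate normal scaled by sqrt(nu / chi^2_nu)
z = rng.multivariate_normal(onp.zeros(2), onp.array([[1.0, 0.5], [0.5, 1.0]]), size=n_samples)
w = rng.chisquare(7.0, size=n_samples)
samples = onp.array([0.5, 2.0]) + z * onp.sqrt(7.0 / w)[:, None]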

First, let’s plot the log posterior:

# choose some values for the Student-t
sigma = np.array([(1.0, 0.5), (0.5, 1.0)])
mu = np.array([0.5, 2.0])
nu = np.array([7])

sigma_inv = np.linalg.inv(sigma)

def log_posterior(theta):
    return np.log(
            1.0 + nu ** -1.0 * np.dot((theta - mu), np.dot(sigma_inv, (theta - mu).T).T)
        ) * (0.5  * -(nu + theta.shape[0]))

# plot the distribution
x = np.linspace(-10, 10, 100)
y = np.linspace(-10, 10, 100)

X, Y = np.meshgrid(x, y)
XY = np.stack((X, Y)).reshape(2, 10000).T

Z = vmap(log_posterior, in_axes=0)(XY).reshape(100, 100)

fig, ax = plt.subplots()
ax.contourf(X, Y, Z)
ax.set_xlabel(r"$\theta_0$")
_ = ax.set_ylabel(r"$\theta_1$")

[Figure: contour plot of the Student-t log posterior.]

Now let’s implement Laplace’s method in JAX:

from jax import hessian
from scipy.optimize import minimize
from scipy.stats import multivariate_normal

@jit
def negative_log_posterior(theta):
    # negative log posterior to minimise
    return (-np.log(
        1.0 + nu ** -1.0 * np.dot((theta - mu), np.dot(sigma_inv, (theta - mu).T).T)
    ) * (0.5  * -(nu + theta.shape[0])))[0]

@jit
def grad_negative_log_posterior(theta):
    # gradient of the negative log posterior
    return grad(negative_log_posterior)(theta)

@jit
def approx_covariance_matrix(theta):
    # evaluate the covariance matrix of the approximate normal
    return np.linalg.inv(hessian(negative_log_posterior)(theta))

# go!
theta_star = minimize(
    negative_log_posterior, 
    np.array([0.0, 0.0]), 
    jac=grad_negative_log_posterior, 
    method="BFGS"
).x

sigma_approx = approx_covariance_matrix(theta_star)

This is a very short piece of code! I had to define the negative log posterior (and JIT compiled it for speed), since we will minimise this to find \(\theta^*\). Then, I used JAX’s grad function to differentiate it once, so that we can use a gradient-based optimiser, and JAX’s hessian function to compute the covariance matrix of the approximating Normal. Finally, I used scipy’s minimize function to find the optimal point \(\theta^*\) and evaluated the covariance matrix there.

Note that this code is actually rather general! As long as the function negative_log_posterior can be implemented in a way that JAX can differentiate (which it probably can), then the rest of the code stays exactly the same!
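
To emphasise that point, here’s a sketch of how the same steps might be wrapped into a reusable helper (the function name and signature are just illustrative):

def laplace_approximation(neg_log_posterior, theta_init):
    # generic Laplace approximation: find the posterior mode and the covariance
    # of the approximating normal, for any JAX-differentiable negative log posterior
    theta_mode = minimize(
        neg_log_posterior,
        theta_init,
        jac=jit(grad(neg_log_posterior)),
        method="BFGS",
    ).x
    covariance = np.linalg.inv(hessian(neg_log_posterior)(theta_mode))
    return theta_mode, covariance

# this should reproduce theta_star and sigma_approx from above
theta_star_check, sigma_approx_check = laplace_approximation(
    negative_log_posterior, np.array([0.0, 0.0])
)

With that said, let’s have a look at how good our normal approximation is: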

from scipy.stats import norm
from scipy.stats import t

fig = plt.figure(constrained_layout=True, figsize=(15, 10))
spec = fig.add_gridspec(ncols=2, nrows=2)
fig.subplots_adjust(hspace=0, wspace=0)

ax3 = fig.add_subplot(spec[0, 0])
ax2 = fig.add_subplot(spec[1, 1])
ax1 = fig.add_subplot(spec[1, 0])

contour = ax1.contour(
    X, Y, Z / Z.max(), colors="0.4", levels=15, linestyles="-", linewidths=3
)

# calculate the density of the approximating Normal distribution
Z_0 = (
    multivariate_normal(mean=theta_star, cov=sigma_approx).logpdf(XY).reshape(100, 100)
)

ax1.contour(
    X, Y, Z_0 / Z_0.max(), colors="#2c7fb8", levels=15, linestyles="--", linewidths=3
)

ax1.set_xlabel(r"$\theta_0$")
ax1.set_ylabel(r"$\theta_1$")


# grid of theta values for the 1D marginal plots (the range just needs to cover the posterior mass)
theta_grid = np.linspace(-4.0, 7.0, 500)

ax2.plot(
    norm.pdf(theta_grid, theta_star[1], np.sqrt(sigma_approx[1, 1])),
    theta_grid,
    c="#2c7fb8",
    ls="--",
    lw=3,
)
ax2.plot(
    t.pdf(theta_grid, nu[0], mu[1], np.sqrt(sigma[1, 1])), theta_grid, c="0.4", lw=3
)


ax3.plot(
    theta_grid,
    norm.pdf(theta_grid, theta_star[0], np.sqrt(sigma_approx[0, 0])),
    c="#2c7fb8",
    ls="--",
    lw=3,
    label="Laplace",
)
ax3.plot(
    theta_grid,
    t.pdf(theta_grid, nu[0], mu[0], np.sqrt(sigma[0, 0])),
    c="0.4",
    lw=3,
    label="Exact",
)
ax3.legend()

ax2.xaxis.set_visible(False)
ax3.yaxis.set_visible(False)

[Figure: contours of the exact Student-t posterior and its Laplace approximation, with the corresponding 1D marginals along each axis.]

At least by eye, the approximation seems reasonable. Of course, I have rather cheated here since a Student-t approaches a normal distribution as \(\nu \rightarrow \infty\). Nonetheless, it’s still pleasing to see that the numerical implementation with JAX and scipy works as expected.
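
For a slightly more quantitative check, we can compare the covariance of the Laplace approximation with the exact covariance of the Student-t, which is \(\frac{\nu}{\nu - 2}\Sigma\) for \(\nu > 2\):

# exact covariance of the Student-t, versus the covariance of the Laplace approximation
print(nu[0] / (nu[0] - 2.0) * sigma)  # exact: nu / (nu - 2) * Sigma
print(sigma_approx)                   # Laplace

Working through the Hessian at the mode shows that the Laplace covariance here is \(\frac{\nu}{\nu + \mathrm{dim}(\theta)}\Sigma\), which is noticeably smaller than the exact value: the heavy tails of the Student-t are exactly what a normal approximation misses, even when the contours look similar.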

Conclusion

Hopefully this post has inspired you to go and play with JAX yourself. There are a ton of interesting applications that I can imagine for this library. Some already exist, such as the numpyro library from Uber, which uses JAX under the hood to perform fast Hamiltonian Monte Carlo. It’ll also be interesting to see how widely the library is adopted compared with other popular autodiff frameworks like TensorFlow and PyTorch.

  1. I’m assuming there’s only one maximum, but in reality there might be several if the posterior is multimodal. Multimodality is a pain, and Laplace’s approximation won’t do as well in this case (in fact most methods in Bayesian statistics share this weakness).