Google recently released an interesting new library called JAX. It looks like it could be very useful for computational statistics, so I thought I’d take a look. In this post, I’ll describe some of the basic features of JAX, and then use them to implement Laplace’s approximation, a method used in Bayesian statistics.
JAX is described as “Autograd and XLA, brought together”. Autograd refers to automatic differentiation, and XLA stands for accelerated linear algebra. Both of these are very useful in computational statistics, where we often have to differentiate things and perform lots of matrix (or tensor) manipulations.
JAX effectively extends the numpy library to include these extra components. In doing so, it preserves an API that many scientists are familiar with, and introduces powerful new functionality.
Automatic differentiation (or autodiff, for short) is a set of methods for evaluating the derivative of functions using a computer. If you’ve taken a calculus course before, you will have taken the derivative of functions by hand, e.g.
For the simple function above, differentiating is quite straightforward. But, when the function is complicated and has many arguments (for example, the objective function of a neural network), differentiating by hand quickly becomes infeasible. Autodiff frameworks save us the trouble: we simply pass a function to the framework, and it returns another function that computes the gradient. This is really useful in computational statistics and machine learning, since we often want to optimise functions with respect to their arguments. Generally, we can carry out optimisation much more efficiently if we know gradients. It’s also really useful in Bayesian statistics, where the state-of-the-art Markov Chain Monte Carlo method, Hamiltonian Monte Carlo, relies on the gradient of the log posterior density.
So, how do we do this in JAX? Here’s a snippet that defines the logistic function in python, and then uses JAX to compute its derivative:
    import jax.numpy as np
    from jax import grad


    def logistic(x):
        # logistic function
        return (1. + np.exp(-x)) ** -1.0


    # differentiate with jax!
    grad_logistic = grad(logistic)

    print(logistic(0.0))
    print(grad_logistic(0.0))
That was easy! There are two lines to focus on in this snippet:
import jax.numpy as np: instead of importing regular numpy, I imported jax.numpy, which is JAX’s implementation of numpy functionality. After this line, we can pretty much forget about it and pretend that we’re using regular numpy most of the time.
grad_logistic = grad(logistic): this is where the magic happens. We passed the logistic function to JAX’s grad function, and it returned another function, which I called grad_logistic. This function takes the same inputs as logistic, but returns the gradient with respect to these inputs.
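A quick sanity check is to compare the output of grad with a derivative worked out by hand. The sketch below uses the known identity f'(x) = f(x)(1 - f(x)) for the logistic function; the helper name analytic_derivative is my own, and jax.numpy is imported as jnp here just to keep the block self-contained.

```python
import jax.numpy as jnp
from jax import grad


def logistic(x):
    # logistic function, as defined in the post
    return (1.0 + jnp.exp(-x)) ** -1.0


def analytic_derivative(x):
    # hand-derived derivative: f'(x) = f(x) * (1 - f(x))
    s = logistic(x)
    return s * (1.0 - s)


grad_logistic = grad(logistic)

# evaluate both versions at a few scalar points
for x0 in [-2.0, 0.0, 3.0]:
    print(x0, grad_logistic(x0), analytic_derivative(x0))
```

If autodiff is doing its job, the two columns should agree up to floating point error.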
To convince ourselves that this all worked, let’s plot the logistic function and its derivative:
    from jax import vmap
    import matplotlib.pyplot as plt

    plt.rcParams["figure.figsize"] = (10, 5)
    plt.rcParams["font.size"] = 20

    vectorised_grad_logistic = vmap(grad_logistic)

    x = np.linspace(-10.0, 10.0, 1000)
    fig, ax = plt.subplots()
    ax.plot(x, logistic(x), label="$f(x)$")
    ax.plot(x, vectorised_grad_logistic(x), label="$f'(x)$")
    _ = ax.legend()
    _ = ax.set_xlabel("$x$")
You’ll notice that I had to define another function, vectorised_grad_logistic, in order to make the plot. The reason is that functions produced by grad are not vectorised (they cannot accept multiple inputs and return the gradient at each of the inputs). To get around this, we can wrap our grad_logistic function with vmap, which automatically vectorises it for us.
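To make the difference concrete, here is a small sketch of what happens when grad is applied to an array input directly, and how vmap compares with an explicit Python loop. The variable names batched and looped are just for illustration.

```python
import jax.numpy as jnp
from jax import grad, vmap


def logistic(x):
    return (1.0 + jnp.exp(-x)) ** -1.0


xs = jnp.linspace(-2.0, 2.0, 5)

try:
    # the output has shape (5,), and grad only handles scalar outputs
    grad(logistic)(xs)
except TypeError as err:
    print("grad alone fails:", type(err).__name__)

# vmap maps the scalar gradient over the leading axis of xs
batched = vmap(grad(logistic))(xs)

# the same thing written as an explicit (slower) Python loop
looped = jnp.stack([grad(logistic)(x) for x in xs])

print(batched)
print(looped)
```

Both approaches should give the same numbers; vmap just avoids the loop.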
This is already pretty neat. We can also obtain higher-order derivatives by the repeated application of grad:

    second_order_grad_logistic = vmap(grad(grad(logistic)))

    x = np.linspace(-10.0, 10.0, 1000)
    fig, ax = plt.subplots()
    ax.plot(x, logistic(x), label="$f(x)$")
    ax.plot(x, vectorised_grad_logistic(x), label="$f'(x)$")
    ax.plot(x, second_order_grad_logistic(x), label="$f''(x)$")
    _ = ax.legend()
    _ = ax.set_xlabel("$x$")
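As with the first derivative, we can check the second derivative against a hand-derived expression. For the logistic function, f''(x) = f(x)(1 - f(x))(1 - 2f(x)). This is only a sketch; analytic_second_derivative and max_error are names I made up for the check.

```python
import jax.numpy as jnp
from jax import grad, vmap


def logistic(x):
    return (1.0 + jnp.exp(-x)) ** -1.0


def analytic_second_derivative(x):
    # hand-derived: f''(x) = f(x) * (1 - f(x)) * (1 - 2 * f(x))
    s = logistic(x)
    return s * (1.0 - s) * (1.0 - 2.0 * s)


xs = jnp.linspace(-5.0, 5.0, 11)

# apply grad twice, then vectorise over the grid
autodiff = vmap(grad(grad(logistic)))(xs)

# largest absolute disagreement between the two versions
max_error = jnp.max(jnp.abs(autodiff - analytic_second_derivative(xs)))
print(max_error)
```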
Before I demonstrate an application of JAX, I want to mention another useful feature: JIT compilation. As you may know, Python is an interpreted language rather than a compiled one. This is one of the reasons that Python code can run slower than the same logic in a compiled language (like C). I won’t go into detail about why this is, but here’s a great blog post by Jake VanderPlas on the subject.
One of the reasons why numpy is so useful is that it calls compiled C code under the hood. This means that it can be much faster than code that loops over native Python lists. JAX adds an additional feature on top of this: Just In Time (JIT) compilation. The “JIT” part means that the code is compiled at runtime, the first time that it is needed. Using this feature can speed up our code.
To do this, we just have to apply the jit decorator to our function:

    from jax import jit


    @jit
    def jit_logistic(x):
        # logistic function
        return (1. + np.exp(-x)) ** -1.0


    @jit
    def jit_grad_logistic(x):
        # compile the gradient as well
        return grad(logistic)(x)
Now we can compare our JIT compiled functions to the ones we made earlier, and see if there’s any difference in execution time:
    logistic, no JIT:    376 µs ± 10.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    logistic, with JIT:  90.4 µs ± 1.19 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    gradient, no JIT:    1.76 ms ± 7.45 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    gradient, with JIT:  86.5 µs ± 1.48 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
For the plain logistic function, JIT led to a 4x speedup. For the gradient, there was a 20x speedup. These are non-trivial gains! I could have obtained even bigger speedups if I was using JAX in combination with a GPU.
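If you want to run a comparison like this outside a notebook, something along these lines should work. It is a rough sketch rather than the exact benchmark behind the numbers above: the helper time_call is hypothetical, and the timings will depend on your machine. Two details matter: the jitted function is called once beforehand so that compilation time is not counted, and block_until_ready is used because JAX dispatches work asynchronously.

```python
import timeit

import jax.numpy as jnp
from jax import grad, jit, vmap


def logistic(x):
    return (1.0 + jnp.exp(-x)) ** -1.0


grad_plain = vmap(grad(logistic))        # no JIT
grad_jitted = jit(vmap(grad(logistic)))  # same function, JIT compiled

x = jnp.linspace(-10.0, 10.0, 1000)

# warm-up call: the first call triggers compilation
grad_jitted(x).block_until_ready()


def time_call(fn, repeats=20):
    # average wall-clock time per call, waiting for the result to be ready
    total = timeit.timeit(lambda: fn(x).block_until_ready(), number=repeats)
    return total / repeats


t_plain = time_call(grad_plain)
t_jit = time_call(grad_jitted)
print(f"no JIT: {t_plain * 1e6:.1f} µs, JIT: {t_jit * 1e6:.1f} µs")
```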
Putting it all together: the Laplace approximation
I’ve barely scratched the surface of what JAX can do, but already there are interesting and useful applications with the functionality I have described here. As an example, I’ll describe how to implement an important method in Bayesian statistics: Laplace’s approximation.
The Laplace approximation
Imagine we have some probability model with some parameters $\theta$, and we’ve constrained the model using some data $D$. In Bayesian inference, our goal is always to calculate integrals like this:

$$\mathbb{E}[h(\theta)] = \int h(\theta) \, p(\theta \mid D) \, \mathrm{d}\theta$$

That is, we are interested in the expectation of some function $h(\theta)$ with respect to the posterior distribution $p(\theta \mid D)$. For interesting models, the posterior is complex, and so we have no hope of calculating these integrals analytically. Because of this, Bayesians have devised many methods for approximating them. If you’ve got time, the best thing to do is use Markov Chain Monte Carlo. But, if your dataset is quite large relative to your time and computational budget, you may need to try something else. A typical choice is Variational Inference.
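To see what approximating such an integral looks like in practice, here is a toy sketch. It assumes we can already draw samples from the posterior (which is what MCMC gives us), and it uses a stand-in posterior, a Normal with mean 1 and standard deviation 0.5, so that the exact answer is known. The function h and all the numbers are made up for illustration.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)

# stand-in "posterior": Normal with mean 1 and standard deviation 0.5
samples = 1.0 + 0.5 * random.normal(key, (100_000,))


def h(theta):
    # the function whose posterior expectation we want
    return theta ** 2


# Monte Carlo estimate: average h over the posterior samples
estimate = jnp.mean(h(samples))

# exact value for this toy case: E[theta^2] = mean^2 + variance
exact = 1.0 ** 2 + 0.5 ** 2
print(estimate, exact)
```

The sample average converges on the exact value as the number of samples grows.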
Another, possibly less talked-about, approach is called Laplace’s approximation. It works really well when you have quite a lot of data because of the Bayesian central limit theorem. In this approach, we approximate the posterior distribution by a Normal distribution. This is a common approximation (it’s often used in Variational Inference too), but Laplace’s method has a specific way of finding the Normal distribution that best matches the posterior.
Suppose we know the location $\theta^*$ of the maximum point of the posterior [1]. Now let’s Taylor expand the log posterior around this point. To reduce clutter, I’ll use the notation $f(\theta) \equiv \log p(\theta \mid D)$. For simplicity, let’s consider the case when $\theta$ is scalar:

$$f(\theta) \approx f(\theta^*) + f'(\theta^*)(\theta - \theta^*) + \frac{1}{2} f''(\theta^*)(\theta - \theta^*)^2 = f(\theta^*) + \frac{1}{2} f''(\theta^*)(\theta - \theta^*)^2$$

The first derivative disappears because $\theta^*$ is a maximum point, so the gradient there is zero. Let’s compare this to the logarithm of a normal distribution with mean $\mu$ and standard deviation $\sigma$, which I’ll call $g(\theta)$:

$$g(\theta) = -\frac{1}{2} \log(2 \pi \sigma^2) - \frac{1}{2 \sigma^2} (\theta - \mu)^2$$

We can match up the terms in the expressions for $g(\theta)$ and the Taylor expansion of $f(\theta)$ (ignoring the constant additive terms) to see that

$$\mu = \theta^*, \qquad \sigma^2 = -\frac{1}{f''(\theta^*)}$$

Consequently, we might try approximating the posterior distribution with a Normal distribution, and set the mean and variance to these values. In multiple dimensions, the covariance matrix of the resulting multivariate normal is the inverse of the Hessian matrix of the negative log posterior at $\theta^*$:

$$\Sigma_{\mathrm{approx}} = \left[ -\nabla \nabla^\top f(\theta) \Big|_{\theta = \theta^*} \right]^{-1}$$

Already, we can see that Laplace’s approximation requires us to be able to twice differentiate the posterior distribution in order to obtain $\Sigma_{\mathrm{approx}}$. In addition, we have to find the location $\theta^*$ of the maximum of the posterior. We probably have to do this numerically, which means using some kind of optimisation routine. The most efficient of these optimisation routines require the gradient of the objective function. So, using Laplace’s approximation means we want to evaluate the first and second derivatives of the posterior. Sounds like a job for JAX!
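Before moving on to two dimensions, here is a one-dimensional sketch of the recipe. I’ve assumed a hypothetical Beta(8, 4) posterior, chosen only because its mode and exact variance are known in closed form, so no optimiser is needed. The variance of the Laplace approximation is minus one over the second derivative of the log posterior at the mode.

```python
import jax.numpy as jnp
from jax import grad

a, b = 8.0, 4.0  # hypothetical Beta(a, b) posterior, for illustration only


def log_posterior(theta):
    # unnormalised log density of Beta(a, b)
    return (a - 1.0) * jnp.log(theta) + (b - 1.0) * jnp.log(1.0 - theta)


# the mode of a Beta(a, b) is known in closed form
theta_star = (a - 1.0) / (a + b - 2.0)

# the gradient should vanish at the mode
print(grad(log_posterior)(theta_star))

# Laplace variance: minus the inverse second derivative at the mode
second_derivative = grad(grad(log_posterior))(theta_star)
laplace_variance = -1.0 / second_derivative

# exact variance of Beta(a, b), for comparison
exact_variance = a * b / ((a + b) ** 2 * (a + b + 1.0))
print(laplace_variance, exact_variance)
```

The two variances are close but not identical, which is what we should expect: the Beta distribution is not exactly Normal.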
Example: a Student-t posterior distribution
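Suppose our true posterior is a 2D Student-t:

$$p(\theta \mid D) \propto \left[ 1 + \frac{1}{\nu} (\theta - \mu)^\top \Sigma^{-1} (\theta - \mu) \right]^{-(\nu + 2)/2}$$

This is a simple example, and we can actually sample from a Student-t rather easily. Nevertheless, let’s go ahead and use it to implement Laplace’s method in JAX. Let’s set the values of the constants in the Student-t (the degrees of freedom $\nu$ is set alongside these in the code):

$$\mu = \begin{pmatrix} 0.5 \\ 2 \end{pmatrix}, \qquad \Sigma = \begin{pmatrix} 1 & 0.5 \\ 0.5 & 1 \end{pmatrix}$$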
First, let’s plot the log posterior:
    # choose some values for the Student-t
    sigma = np.array([(1.0, 0.5), (0.5, 1.0)])
    mu = np.array([0.5, 2.0])
    nu = 4.0  # degrees of freedom (assumed value)
    sigma_inv = np.linalg.inv(sigma)


    def log_posterior(theta):
        # unnormalised log density of the 2D Student-t
        return np.log(
            1.0 + nu ** -1.0 * np.dot((theta - mu), np.dot(sigma_inv, (theta - mu).T).T)
        ) * (0.5 * -(nu + theta.shape[0]))


    # plot the distribution
    x = np.linspace(-10, 10, 100)
    y = np.linspace(-10, 10, 100)
    X, Y = np.meshgrid(x, y)
    XY = np.stack((X, Y)).reshape(2, 10000).T
    Z = vmap(log_posterior, in_axes=0)(XY).reshape(100, 100)

    fig, ax = plt.subplots()
    ax.contourf(X, Y, Z)
    ax.set_xlabel(r"$\theta_0$")
    _ = ax.set_ylabel(r"$\theta_1$")
Now let’s implement Laplace’s method in JAX:
    from jax import hessian
    from scipy.optimize import minimize
    from scipy.stats import multivariate_normal


    @jit
    def negative_log_posterior(theta):
        # negative log posterior to minimise
        return (-np.log(
            1.0 + nu ** -1.0 * np.dot((theta - mu), np.dot(sigma_inv, (theta - mu).T).T)
        ) * (0.5 * -(nu + theta.shape[0])))


    @jit
    def grad_negative_log_posterior(theta):
        # gradient of the negative log posterior
        return grad(negative_log_posterior)(theta)


    @jit
    def approx_covariance_matrix(theta):
        # evaluate the covariance matrix of the approximate normal
        return np.linalg.inv(hessian(negative_log_posterior)(theta))


    # go!
    theta_star = minimize(
        negative_log_posterior,
        np.array([0.0, 0.0]),
        jac=grad_negative_log_posterior,
        method="BFGS"
    ).x

    sigma_approx = approx_covariance_matrix(theta_star)
This is a very short piece of code! I had to define the negative log posterior (and JIT compiled it for speed), since we will minimise this to find $\theta^*$. Then, I used JAX’s grad function to differentiate this once, so that we can use a gradient-based optimiser. Next, I used JAX’s hessian function to find the covariance matrix for our approximating normal. Finally, I used scipy’s minimize function to find the optimal point $\theta^*$.
Note that this code is actually rather general! As long as the function negative_log_posterior can be implemented in a way that JAX can differentiate (which it probably can), then the rest of the code stays exactly the same!
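One way to make that generality explicit is to wrap the steps in a single function that takes any differentiable negative log posterior. The sketch below is one possible packaging, not a definitive implementation: the names laplace_approximation and gaussian_nlp are hypothetical, I’ve switched JAX to double precision (an optional choice that tends to make the scipy optimiser behave better), and the outputs of the JAX functions are converted to plain floats and numpy arrays before being handed to scipy. The toy check uses a Gaussian posterior, for which Laplace’s approximation should recover the mean and covariance exactly.

```python
import jax
import jax.numpy as jnp
import numpy as onp
from scipy.optimize import minimize

# optional: double precision, which helps the optimiser converge cleanly
jax.config.update("jax_enable_x64", True)


def laplace_approximation(negative_log_posterior, theta_init):
    """Return (mode, covariance) of the approximating Normal."""
    value = jax.jit(negative_log_posterior)
    gradient = jax.jit(jax.grad(negative_log_posterior))
    result = minimize(
        lambda t: float(value(t)),            # scipy expects a plain float
        onp.asarray(theta_init, dtype=float),
        jac=lambda t: onp.asarray(gradient(t)),  # and a numpy gradient
        method="BFGS",
    )
    theta_star = jnp.asarray(result.x)
    # covariance = inverse Hessian of the negative log posterior at the mode
    covariance = jnp.linalg.inv(jax.hessian(negative_log_posterior)(theta_star))
    return theta_star, covariance


# toy check with a Gaussian "posterior" (made-up mean and covariance)
true_mean = jnp.array([1.0, -1.0])
true_cov = jnp.array([[2.0, 0.3], [0.3, 0.5]])
true_precision = jnp.linalg.inv(true_cov)


def gaussian_nlp(theta):
    # negative log density of the Gaussian, up to an additive constant
    diff = theta - true_mean
    return 0.5 * diff @ true_precision @ diff


mode, cov = laplace_approximation(gaussian_nlp, [0.0, 0.0])
print(mode)
print(cov)
```

Swapping gaussian_nlp for another differentiable function is the only change needed to approximate a different posterior.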
Let’s have a look at how good our normal approximation is:
    from scipy.stats import norm
    from scipy.stats import t

    # grid for the 1D marginals (assumed range, matching the contour plot)
    theta_grid = np.linspace(-10.0, 10.0, 1000)

    fig = plt.figure(constrained_layout=True, figsize=(15, 10))
    spec = fig.add_gridspec(ncols=2, nrows=2)
    fig.subplots_adjust(hspace=0, wspace=0)

    ax3 = fig.add_subplot(spec[0, 0])
    ax2 = fig.add_subplot(spec[1, 1])
    ax1 = fig.add_subplot(spec[1, 0])

    contour = ax1.contour(
        X, Y, Z / Z.max(), colors="0.4", levels=15, linestyles="-", linewidths=3
    )

    # calculate the density of the approximating Normal distribution
    Z_0 = (
        multivariate_normal(mean=theta_star, cov=sigma_approx).logpdf(XY).reshape(100, 100)
    )

    ax1.contour(
        X, Y, Z_0 / Z_0.max(), colors="#2c7fb8", levels=15, linestyles="--", linewidths=3
    )

    ax1.set_xlabel(r"$\theta_0$")
    ax1.set_ylabel(r"$\theta_1$")

    # marginal of theta_1 (right-hand panel)
    ax2.plot(
        norm.pdf(theta_grid, theta_star[1], np.sqrt(sigma_approx[1, 1])),
        theta_grid,
        c="#2c7fb8",
        ls="--",
        lw=3,
    )
    ax2.plot(
        t.pdf(theta_grid, nu, mu[1], np.sqrt(sigma[1, 1])), theta_grid, c="0.4", lw=3
    )

    # marginal of theta_0 (top panel)
    ax3.plot(
        theta_grid,
        norm.pdf(theta_grid, theta_star[0], np.sqrt(sigma_approx[0, 0])),
        c="#2c7fb8",
        ls="--",
        lw=3,
        label="Laplace",
    )
    ax3.plot(
        theta_grid,
        t.pdf(theta_grid, nu, mu[0], np.sqrt(sigma[0, 0])),
        c="0.4",
        lw=3,
        label="Exact",
    )
    ax3.legend()

    ax2.xaxis.set_visible(False)
    ax3.yaxis.set_visible(False)
At least by eye, the approximation seems reasonable. Of course, I have rather cheated here since a Student-t approaches a normal distribution as $\nu \to \infty$. Nonetheless, it’s still pleasing to see that the numerical implementation with JAX and scipy works as expected.
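For this particular posterior we can go beyond checking by eye, because the Hessian at the mode can be worked out by hand: the negative log posterior has Hessian $(\nu + 2)/\nu \, \Sigma^{-1}$ at $\theta = \mu$, so the Laplace covariance should be $\nu/(\nu + 2) \, \Sigma$. The sketch below checks this with JAX. The value nu = 4.0 is an assumption chosen for illustration, and the block redefines the constants so that it runs on its own.

```python
import jax.numpy as jnp
from jax import hessian

sigma = jnp.array([[1.0, 0.5], [0.5, 1.0]])
mu = jnp.array([0.5, 2.0])
nu = 4.0  # assumed degrees of freedom, for illustration only
sigma_inv = jnp.linalg.inv(sigma)


def negative_log_posterior(theta):
    # negative log density of the 2D Student-t, up to an additive constant
    diff = theta - mu
    return 0.5 * (nu + theta.shape[0]) * jnp.log(1.0 + diff @ sigma_inv @ diff / nu)


# the mode of the Student-t is at mu, so evaluate the Hessian there
laplace_cov = jnp.linalg.inv(hessian(negative_log_posterior)(mu))

# hand-derived result: nu / (nu + 2) * sigma
closed_form = nu / (nu + 2.0) * sigma

print(laplace_cov)
print(closed_form)
```

For comparison, the true covariance of a Student-t is $\nu/(\nu - 2) \, \Sigma$ when $\nu > 2$, so the Laplace approximation is narrower than the true posterior. It matches the curvature at the peak and misses the heavy tails, and the gap closes as $\nu$ grows.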
Hopefully this post has inspired you to go and play with JAX yourself. There are a ton of interesting applications that I can imagine for this library. Some already exist, such as the NumPyro library from Uber, which uses JAX under the hood to perform fast Hamiltonian Monte Carlo. In addition, it’ll be interesting to see how this library is adopted as compared with other popular autodiff frameworks like TensorFlow and PyTorch.
[1] I’m assuming there’s only one maximum, but in reality there might be several if the posterior is multimodal. Multimodality is a pain, and Laplace’s approximation won’t do as well in this case (in fact most methods in Bayesian statistics share this weakness).