<h1>How should we compare models?</h1>
<p>Angus Williams, 2020-10-04, <a href="https://anguswilliams91.github.io/statistics/model-comparison">https://anguswilliams91.github.io/statistics/model-comparison</a></p>
<p>At various points in the past few years I have had discussions or debates with friends and colleagues about model comparison in the context of Bayesian inference.
What is the most “principled” way to do it?
What are the relative merits of different approaches?
My opinion has evolved alongside my understanding of the subject, and I recently read a paper that conceptually explained some of the intuition that I had developed over the years.
Consequently, I wanted to write a note on this subject in case it is useful for others.</p>
<h2 id="introduction-to-model-comparison">Introduction to model comparison</h2>
<p>Before diving into different approaches for model comparison, let me first define it.
Suppose you are building a statistical model of a particular process, and have some data \(y\) with which to fit and test it.
While analysing the data, you come up with two distinct candidate models: \(M_1\) and \(M_2\).
For example, if you are solving a regression problem, \(M_1\) might be a GLM of some sort, whereas \(M_2\) could be a GAM.
Both models seem to fit the data, but you want to quantify which is doing a better job.
Any approach to answering this question falls into the category of <em>model comparison</em>.</p>
<p>What should we care about when comparing models?
Fundamentally, we are interested in knowing which model provides a better approximation to the underlying data generating process.
In that sense, we want to know which model <em>generalises</em> better beyond the immediate data \(y\) that we have available.
In other words, if we were to receive some new data \(y_\mathrm{new}\), which model would describe it better?</p>
<p>The field of model comparison seeks to answer this question, particularly when we must try to make our best guess without access to an unlimited supply of data.</p>
<h2 id="recap-of-the-bayesian-approach">Recap of the Bayesian approach</h2>
<p>No piece of writing about Bayesian methods would be complete without stating Bayes’ theorem.
In the context of a single model, we can write this as:</p>
\[p(\theta | y) = \dfrac{p(y | \theta)p(\theta)}{p(y)}, \label{bayes}\tag{1}\]
<p>where \(\theta\) denotes the set of model parameters (e.g., the coefficients in a linear regression) and \(y\) are the data we are analysing.
\(p(\theta | y)\) is called the <em>posterior distribution</em>, because it is the distribution of the model parameters <em>conditional</em> on the data \(y\) (i.e., <em>after</em> we receive the data).
\(p(\theta)\) is the <em>prior</em> because it is the <em>unconditional</em> distribution of the model parameters (i.e., prior to seeing the data \(y\)).
\(p(y | \theta)\) is called the <em>likelihood</em>, and is the probability of the data given a particular set of model parameters.
\(p(y)\) has a few names: the <em>marginal likelihood</em> or the <em>evidence</em> are probably the two most common.
I’ll discuss this part in more detail shortly.</p>
<p>The above discussion is in the context of a single model, so everything is implicitly conditional on the choice of model \(M\):</p>
\[p(\theta | y, M) = \dfrac{p(y | \theta, M)p(\theta | M)}{p(y | M)}. \label{model_bayes}\tag{2}\]
<p>We don’t normally write it like this because it is clear that we are implicitly conditioning on the particular model we are considering, but it’s helpful to write Bayes’ theorem like this when thinking about model comparison.</p>
<h2 id="bayes-factors">Bayes factors</h2>
<p>Since we can use conditional probability to quantify how we should update our beliefs about the parameters \(\theta\) of an individual model given some data \(y\), why not use it to update our beliefs about which model might be the best choice from an available set?
This feels like a very natural way to approach the problem of model comparison.
Concretely, we can write down Bayes’ theorem as</p>
\[p(M | y) = \dfrac{p(y | M) p(M)}{p(y)}.\]
<p>Here \(p(M | y)\) is the probability that model \(M\) is the “true” model given the data at hand, \(p(M)\) is the prior probability that model \(M\) is correct, and \(p(y) = \sum_k p(y | M_k)\,p(M_k)\) sums over the set of candidate models.</p>
<p>On the face of it, this seems to be exactly what we need!
Suppose we are comparing two models, \(M_1\) and \(M_2\).
We can make a very intuitive rule for choosing between them: if \(p(M_1 | y) > p(M_2 | y)\), then choose \(M_1\), and choose \(M_2\) if the converse is true.</p>
<p>If we have no prior preference for \(M_1\) or \(M_2\), so that \(p(M_1) = p(M_2) = \frac{1}{2}\), it’s easy to show that</p>
\[\dfrac{p(M_1 | y)}{p(M_2 | y)} = \dfrac{p(y | M_1)}{p(y | M_2)} = K.\]
<p>We call \(K\) the <em>Bayes factor</em>.
In terms of \(K\), we should choose \(M_1\) if \(K > 1\), and choose \(M_2\) if \(K < 1\).</p>
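<p>Spelling out the “easy to show” step above: writing Bayes’ theorem for each model and taking the ratio, the evidence \(p(y)\) is common to both models and cancels, and with equal prior probabilities the ratio \(p(M_1)/p(M_2)\) equals one:</p>
\[\dfrac{p(M_1 | y)}{p(M_2 | y)} = \dfrac{p(y | M_1)\,p(M_1)/p(y)}{p(y | M_2)\,p(M_2)/p(y)} = \dfrac{p(y | M_1)}{p(y | M_2)} \times \dfrac{p(M_1)}{p(M_2)} = K \times 1 = K.\]
<p>Without the equal-prior assumption, the same algebra gives posterior odds equal to \(K\) times the prior odds \(p(M_1)/p(M_2)\).</p>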
<p>So, if we can find a way to calculate \(p(y | M_i)\), then we can use the Bayes factor to choose between models.
How can we calculate it?
You might have noticed that \(p(y | M)\) appeared earlier in (\ref{model_bayes}) - it appears as the denominator on the RHS - we called it the <em>marginal likelihood</em> or the <em>evidence</em>.
But what is it?
A clue comes from one of its names: the <em>marginal</em> likelihood.
We can write it as follows:</p>
\[p(y | M) = \int \mathrm{d}\theta\, p(y | \theta, M)\,p(\theta | M) \label{marg_lik}\tag{3}\]
<p>i.e., we <em>marginalise</em> out the parameters of the model \(\theta\) to obtain the probability of the data \(y\) given the model choice \(M\).
We marginalise out \(\theta\) using the prior distribution, \(p(\theta | M)\).
This gives some intuition for the Bayes factor: we choose the model under which the data \(y\) are <em>most probable on average over the parameter values that the prior considers plausible</em>.</p>
<p>This argument for using Bayes factors for model comparison is quite persuasive, and at first it can seem almost irrefutable because it is so intuitive.</p>
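<p>Equation (\ref{marg_lik}) says that the marginal likelihood is the likelihood averaged over the prior, which suggests the most naive Monte Carlo estimator: draw parameter values from the prior and average the likelihood at those values. The sketch below is only an illustration (this estimator is known to be very noisy when the prior is much wider than the likelihood, so it is not a practical method). It assumes the normal likelihood and prior of the toy example in the next section, with illustrative values \(C = 1\), \(\mu_\ell = 1\), \(\sigma_\ell = 1\) and \(\sigma_\mathrm{prior} = 10\), so that the estimate can be checked against the closed form (\ref{normal_marginal}) derived there; the function names are my own.</p>

```python
import math
import random


def normal_pdf(x, mean, var):
    """Density of a normal distribution with the given mean and variance, evaluated at x."""
    return math.exp(-((x - mean) ** 2) / (2 * var)) / math.sqrt(2 * math.pi * var)


def likelihood(theta, mu_ell=1.0, sigma_ell=1.0, c=1.0):
    """Toy likelihood in theta: C times a normal curve (C, mu_ell, sigma_ell are assumed values)."""
    return c * normal_pdf(theta, mu_ell, sigma_ell ** 2)


def marginal_likelihood_mc(sigma_prior, n_draws=200_000, seed=0):
    """Naive estimate of p(y | M): average the likelihood over draws from the prior N(0, sigma_prior^2)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_draws):
        theta = rng.gauss(0.0, sigma_prior)  # one draw from the prior p(theta | M)
        total += likelihood(theta)
    return total / n_draws


def marginal_likelihood_exact(sigma_prior, mu_ell=1.0, sigma_ell=1.0, c=1.0):
    """Closed form (6): C times a N(0, sigma_ell^2 + sigma_prior^2) density evaluated at mu_ell."""
    return c * normal_pdf(mu_ell, 0.0, sigma_ell ** 2 + sigma_prior ** 2)


estimate = marginal_likelihood_mc(sigma_prior=10.0)
exact = marginal_likelihood_exact(sigma_prior=10.0)
print(f"Monte Carlo: {estimate:.5f}   closed form: {exact:.5f}")
```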
<h2 id="all-that-glitters-is-not-gold-bayes-factors-are-overly-sensitive-to-the-prior">All that glitters is not gold: Bayes factors are overly sensitive to the prior</h2>
<p>Despite seeming watertight at first, the Bayes factor has some undesirable traits.
These are well documented, and in this note I’ll focus on just one: Bayes factors are overly sensitive to the prior.</p>
<p>Practitioners of Bayesian inference quickly learn the rule of thumb that the more data you have, the less influence the prior has on your final inference.
This makes sense - the more evidence you accumulate by collecting more data, the less weight you will place on your prior beliefs.</p>
<p>Suppose we have a model \(M\) with a single parameter \(\theta\) and some data \(y\).
Let’s further assume that the <em>likelihood</em> \(p(y | \theta, M)\) is a normal distribution with mean \(\mu_\ell\) and variance \(\sigma_\ell ^2\):</p>
\[p(y | \theta, M) = \dfrac{C}{\sqrt{2\pi \sigma_\ell ^ 2}} \exp [ -\dfrac{(\theta - \mu_\ell)^2}{2 \sigma_\ell ^2}].\]
<p>Note that because I’m considering the likelihood as a function of \(\theta\), it need not be normalised, which is why the constant \(C\) appears on the right hand side.
Further suppose that the prior on \(\theta\) is a normal distribution with mean zero and variance \(\sigma_\mathrm{prior} ^2\):</p>
\[p(\theta | M) = \dfrac{1}{\sqrt{2\pi \sigma_\mathrm{prior} ^ 2}} \exp [ -\dfrac{\theta ^2}{2 \sigma_\mathrm{prior} ^2}].\]
<p>Given these assumptions, we can exactly work out some of the quantities of interest for model comparison.
Let’s first work out the numerator of the RHS of (\ref{model_bayes}). It turns out to be proportional to another normal density:</p>
\[p(y | \theta, M)\,p(\theta) = A \times \dfrac{1}{\sqrt{2\pi \sigma_\mathrm{post} ^ 2}}\exp [ -\dfrac{(\theta - \mu_\mathrm{post})^2}{2 \sigma_\mathrm{post} ^2}],\label{integral}\tag{4}\]
<p>where</p>
\[\mu_\mathrm{post} = \mu_\ell \dfrac{\sigma_\mathrm{prior} ^2}{\sigma_\mathrm{prior} ^2 + \sigma_\ell ^2} \quad;\quad \sigma_\mathrm{post}^2 = \dfrac{\sigma_\mathrm{prior}^2 \sigma_\ell^2}{\sigma_\mathrm{prior}^2 + \sigma_\ell^2}.\label{norm_post}\tag{5}\]
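<p>The constant \(A\) is \(C\) times yet another normal density, this one with mean zero and variance \(\sigma_\ell^2 + \sigma_\mathrm{prior}^2\), evaluated at \(\mu_\ell\):</p>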
\[A = \dfrac{C}{\sqrt{2\pi (\sigma_\ell ^ 2 + \sigma_\mathrm{prior} ^ 2)}} \exp [ -\dfrac{\mu_\ell^2}{2 (\sigma_\ell ^2 + \sigma_\mathrm{prior}^2)}].\label{normal_marginal}\tag{6}\]
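<p>Equations (\ref{integral}) to (\ref{normal_marginal}) follow from completing the square in \(\theta\). One way to gain confidence in them is to check numerically that the likelihood times the prior equals \(A\) times the posterior normal density at a handful of \(\theta\) values. This is a minimal sketch with illustrative values \(C = 1\), \(\mu_\ell = 1\), \(\sigma_\ell = 1\) and \(\sigma_\mathrm{prior} = 10\) (my choice, not part of the derivation):</p>

```python
import math


def normal_pdf(x, mean, var):
    """Density of a normal distribution with the given mean and variance, evaluated at x."""
    return math.exp(-((x - mean) ** 2) / (2 * var)) / math.sqrt(2 * math.pi * var)


def posterior_params(mu_ell, var_ell, var_prior):
    """Posterior mean and variance from equation (5)."""
    mu_post = mu_ell * var_prior / (var_prior + var_ell)
    var_post = var_prior * var_ell / (var_prior + var_ell)
    return mu_post, var_post


def constant_a(mu_ell, var_ell, var_prior, c=1.0):
    """The constant A from equation (6)."""
    return c * normal_pdf(mu_ell, 0.0, var_ell + var_prior)


# Illustrative (assumed) values: C = 1, mu_ell = 1, sigma_ell^2 = 1, sigma_prior^2 = 100.
mu_ell, var_ell, var_prior, c = 1.0, 1.0, 100.0, 1.0
mu_post, var_post = posterior_params(mu_ell, var_ell, var_prior)
a = constant_a(mu_ell, var_ell, var_prior, c)

max_rel_err = 0.0
for theta in [-3.0, -0.5, 0.0, 0.7, 1.0, 2.5, 4.0]:
    lhs = c * normal_pdf(theta, mu_ell, var_ell) * normal_pdf(theta, 0.0, var_prior)  # left side of (4)
    rhs = a * normal_pdf(theta, mu_post, var_post)  # right side of (4)
    max_rel_err = max(max_rel_err, abs(lhs - rhs) / lhs)
print(f"largest relative difference: {max_rel_err:.2e}")
```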
<p>Now, looking at Bayes’ theorem (\ref{model_bayes}), we can see that the left hand side is a probability distribution, which means that it integrates to one:</p>
\[\int p(\theta | y, M) \, \mathrm{d}\theta = 1.\]
<p>Substituting the right-hand side of Bayes’ theorem, and noting that \(p(y | M)\) does not depend on \(\theta\), we can rearrange this to:</p>
\[\int p(y | \theta, M) p(\theta | M) \, \mathrm{d}\theta = p(y | M).\]
<p>Thus we can see that the marginal likelihood can be regarded as a <em>normalisation constant</em> - it guarantees that the posterior distribution integrates to one.
Looking at (\ref{integral}), we can deduce that the constant \(A\) is in fact the marginal likelihood:</p>
<p>\begin{equation}
\begin{aligned}
\int p(y | \theta, M)\,p(\theta)\,\mathrm{d} \theta &= A \times \int \dfrac{1}{\sqrt{2\pi \sigma_\mathrm{post} ^ 2}}\exp [ -\dfrac{(\theta - \mu_\mathrm{post})^2}{2 \sigma_\mathrm{post} ^2}] \,\mathrm{d}\theta \newline
&= A \newline
\implies p(y | M) &= A
\end{aligned}
\end{equation}</p>
<p>This result follows from the fact that the normal distribution integrates to one.
This means that the posterior distribution is</p>
\[p(\theta | M, y) = \dfrac{1}{\sqrt{2\pi \sigma_\mathrm{post} ^ 2}}\exp [ -\dfrac{(\theta - \mu_\mathrm{post})^2}{2 \sigma_\mathrm{post} ^2}]\]
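<p>The same conclusion can be checked by brute force: integrating likelihood times prior numerically should reproduce \(A\), and dividing by \(A\) should leave a density with unit area. A minimal sketch using the trapezoidal rule on a wide grid, with the same illustrative values \(C = 1\), \(\mu_\ell = 1\), \(\sigma_\ell = 1\) and \(\sigma_\mathrm{prior} = 10\) (assumptions, not values from the derivation):</p>

```python
import math


def normal_pdf(x, mean, var):
    """Density of a normal distribution with the given mean and variance, evaluated at x."""
    return math.exp(-((x - mean) ** 2) / (2 * var)) / math.sqrt(2 * math.pi * var)


def trapezoid(f, lo, hi, n=20_000):
    """Trapezoidal-rule approximation of the integral of f over [lo, hi] with n intervals."""
    h = (hi - lo) / n
    total = 0.5 * (f(lo) + f(hi))
    for k in range(1, n):
        total += f(lo + k * h)
    return total * h


# Illustrative (assumed) values for the toy model.
mu_ell, var_ell, var_prior, c = 1.0, 1.0, 100.0, 1.0


def unnormalised_posterior(theta):
    """Numerator of Bayes' theorem (2): likelihood times prior."""
    return c * normal_pdf(theta, mu_ell, var_ell) * normal_pdf(theta, 0.0, var_prior)


evidence_numeric = trapezoid(unnormalised_posterior, -20.0, 20.0)
evidence_exact = c * normal_pdf(mu_ell, 0.0, var_ell + var_prior)  # A from equation (6)
# Dividing by A should give a properly normalised posterior.
posterior_mass = trapezoid(lambda t: unnormalised_posterior(t) / evidence_exact, -20.0, 20.0)
print(evidence_numeric, evidence_exact, posterior_mass)
```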
<p>That was a fair amount of work, but we now have expressions for the posterior \(p(\theta | M, y)\) and the marginal likelihood \(p(y | M)\).
With them in hand, let’s think about the limit where we have a lot of data.
In this case, the likelihood will be more informative than the prior - it will be more strongly peaked around its mean.
We can express this quantitatively by saying that \(\sigma_\mathrm{prior}^2 \gg \sigma_\ell^2\), since the larger the variance of a normal distribution, the less strongly peaked it is.</p>
<p>What does the posterior look like in this limit?
We know it is a normal distribution with mean and variance given by (\ref{norm_post}).
In the limit we’re interested in, these become:</p>
\[\mu_\mathrm{post} \approx \mu_\ell \quad;\quad \sigma_\mathrm{post}^2 \approx \sigma_\ell^2.\]
<p>This makes intuitive sense - when the likelihood dominates the prior, the posterior will approach the likelihood as the prior loses influence.
What about the marginal likelihood?
Looking at (\ref{normal_marginal}), we can see that this becomes:</p>
\[p(y | M) \approx \dfrac{C}{\sqrt{2\pi\sigma_\mathrm{prior} ^ 2}} \exp [ -\dfrac{\mu_\ell^2}{2 \sigma_\mathrm{prior}^2}].\label{approx_marginal}\tag{7}\]
<p>Notice immediately that, in the limit where the likelihood dominates the prior, the posterior distribution is approximately <em>independent</em> of the prior, since it is approximately equal to the normalised likelihood.
But quite the opposite is true of the marginal likelihood!
It explicitly depends on the prior variance \(\sigma_\mathrm{prior}^2\).</p>
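<p>The contrast is easy to see numerically. The sketch below, again a toy with illustrative values \(C = 1\), \(\mu_\ell = 1\) and \(\sigma_\ell = 1\) (assumptions), evaluates (\ref{norm_post}) and (\ref{normal_marginal}) exactly for increasingly wide priors: the posterior mean and variance settle down, while the log marginal likelihood keeps falling by roughly \(\log 10 \approx 2.3\) each time \(\sigma_\mathrm{prior}\) grows tenfold.</p>

```python
import math


def posterior_and_log_evidence(sigma_prior, mu_ell=1.0, sigma_ell=1.0, c=1.0):
    """Exact posterior mean/variance (5) and log marginal likelihood (6) for the toy model."""
    var_prior, var_ell = sigma_prior ** 2, sigma_ell ** 2
    mu_post = mu_ell * var_prior / (var_prior + var_ell)
    var_post = var_prior * var_ell / (var_prior + var_ell)
    var_marg = var_ell + var_prior
    log_evidence = math.log(c) - 0.5 * math.log(2 * math.pi * var_marg) - mu_ell ** 2 / (2 * var_marg)
    return mu_post, var_post, log_evidence


rows = []
for sigma_prior in [10.0, 100.0, 1000.0, 10000.0]:
    mu_post, var_post, log_ev = posterior_and_log_evidence(sigma_prior)
    rows.append((sigma_prior, mu_post, var_post, log_ev))
    print(f"sigma_prior={sigma_prior:>7.0f}  mu_post={mu_post:.4f}  "
          f"var_post={var_post:.4f}  log p(y|M)={log_ev:.3f}")
```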
<p>Let’s consider two versions of \(M\).
Both use the same likelihood and data, so they have the same likelihood function, but have different priors.
For concreteness, let’s say that \(\sigma_\ell = 1\), and \(\mu_\ell = 1\).
In one case, call it \(M_1\), \(\sigma_\mathrm{prior} = 10\), and in the other, \(M_2\), we have \(\sigma_\mathrm{prior} = 100\).
In both cases, \(\sigma_\ell \ll \sigma_\mathrm{prior}\), so we’re in the limit considered above.
But let’s compute the Bayes factor for these two models using (\ref{approx_marginal}):</p>
\[K = \dfrac{p(y | M_1)}{p(y | M_2)} \approx 10.\]
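<p>For the record, here is that Bayes factor computed both from the exact marginal likelihood (\ref{normal_marginal}) and from the approximation (\ref{approx_marginal}), using the values above (\(\mu_\ell = \sigma_\ell = 1\)). The constant \(C\) is shared by both models and cancels, so this minimal sketch sets it to one:</p>

```python
import math


def log_evidence_exact(sigma_prior, mu_ell=1.0, sigma_ell=1.0):
    """log p(y | M) from equation (6), with C = 1 since it cancels in the Bayes factor."""
    var = sigma_ell ** 2 + sigma_prior ** 2
    return -0.5 * math.log(2 * math.pi * var) - mu_ell ** 2 / (2 * var)


def log_evidence_approx(sigma_prior, mu_ell=1.0):
    """log p(y | M) from the large-prior-variance approximation (7), again with C = 1."""
    var = sigma_prior ** 2
    return -0.5 * math.log(2 * math.pi * var) - mu_ell ** 2 / (2 * var)


# M_1 has sigma_prior = 10, M_2 has sigma_prior = 100.
k_exact = math.exp(log_evidence_exact(10.0) - log_evidence_exact(100.0))
k_approx = math.exp(log_evidence_approx(10.0) - log_evidence_approx(100.0))
print(f"K (exact) = {k_exact:.3f}, K (approximation 7) = {k_approx:.3f}")
```

<p>Both versions favour \(M_1\) by a factor of roughly ten, even though the two posteriors are nearly identical.</p>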
<p>So, even though these two models produce almost <em>identical</em> inference (they have almost the same posterior), the Bayes factor tells us that one of them is <em>ten times</em> more likely than the other!
Thus, our model selection policy tells us that we should <em>strongly</em> favour \(M_1\) over \(M_2\), even though these two models produce essentially identical predictions for new data.
At this point, a quote from <a href="http://www.stat.columbia.edu/~gelman/research/published/philosophy.pdf">Gelman & Shalizi (2012)</a> is appropriate:</p>
<p>“The main point where we disagree with many Bayesians is that we do not see Bayesian methods as generally useful for giving the posterior probability that a model is true, or the probability for preferring model A over model B, or whatever. Beyond the philosophical difficulties, there are technical problems with methods that purport to determine the posterior probability of models, most notably that in models with continuous parameters, aspects of the model that have essentially no effect on posterior inferences within a model can have huge effects on the comparison of posterior probability among models.”</p>
<h2 id="an-alternative-approach-cross-validation">An alternative approach: cross validation</h2>
<p>Cross validation used to feel to me like a concept that belonged more to machine learning, and less elegant than ideas like Bayes factors.
That said, cross validation makes great sense intuitively as a method for model comparison.
Following <a href="https://arxiv.org/abs/1507.04544">Vehtari et al. (2016)</a>, let’s first define a useful metric for comparing probabilistic models, the expected log pointwise predictive density for a new dataset:</p>
\[\mathrm{elpd} = \sum_{i=1}^{N} \int \mathrm{d}\tilde y_i \, p_t(\tilde y_i)\, \log p(\tilde y_i | y, M).\label{elpd}\tag{8}\]
<p>In this expression:</p>
<ul>
<li>\(\tilde y\) is a new, unseen dataset (i.e., one that is not part of the training data \(y\)), assumed to contain the same number of data points, \(N\), as the training data.</li>
<li>\(p_t(\tilde y)\) is the true data generating distribution of \(\tilde y\).</li>
<li>\(p(\tilde y | y, M)\) is the predictive distribution of the model being assessed.</li>
</ul>
<p>For Bayesian models, the predictive distribution is</p>
\[p(\tilde y | y, M) = \int \mathrm{d}\theta \, p(\tilde y | \theta, M) p(\theta | y, M).\]
<p>It is called the (posterior) predictive distribution because it is the density we would use to predict new data: it incorporates the information we learned from the training data \(y\).
Notice how similar this looks to the marginal likelihood (\ref{marg_lik}).
Both are integrals of the product of the likelihood with a distribution over \(\theta\).
For the marginal likelihood, this distribution is the prior \(p(\theta | M)\), whereas for the predictive distribution it is the posterior \(p(\theta | y, M)\).</p>
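<p>The parallel between the two integrals can be made concrete with one Monte Carlo recipe: average \(p(\tilde y | \theta, M)\) over draws of \(\theta\), taken either from the posterior (giving the predictive density) or from the prior (giving the prior predictive density, i.e. a marginal likelihood for a single point). The sketch assumes, hypothetically, observations \(y_i \sim \mathcal{N}(\theta, \sigma^2)\) with known \(\sigma = 1\), a prior \(\mathcal{N}(0, 10^2)\), made-up data, and a new point \(\tilde y = 1.2\). For this model the likelihood in \(\theta\) is normal with \(\mu_\ell = \bar y\) and \(\sigma_\ell^2 = \sigma^2 / N\), so the posterior follows from (\ref{norm_post}) and both densities also have closed forms to compare against.</p>

```python
import math
import random


def normal_pdf(x, mean, var):
    """Density of a normal distribution with the given mean and variance, evaluated at x."""
    return math.exp(-((x - mean) ** 2) / (2 * var)) / math.sqrt(2 * math.pi * var)


# Hypothetical setup: y_i ~ N(theta, sigma^2) with known sigma, prior theta ~ N(0, sigma_prior^2).
y = [0.8, 1.3, 0.4, 1.9, 1.1, 0.6, 1.5, 0.9]  # made-up data
sigma, sigma_prior = 1.0, 10.0

# Likelihood in theta is normal with mu_ell = mean(y), sigma_ell^2 = sigma^2 / N; apply equation (5).
n = len(y)
mu_ell, var_ell = sum(y) / n, sigma ** 2 / n
mu_post = mu_ell * sigma_prior ** 2 / (sigma_prior ** 2 + var_ell)
var_post = sigma_prior ** 2 * var_ell / (sigma_prior ** 2 + var_ell)


def average_likelihood(y_new, theta_mean, theta_var, n_draws=100_000, seed=1):
    """Monte Carlo estimate of the integral of p(y_new | theta) over theta ~ N(theta_mean, theta_var)."""
    rng = random.Random(seed)
    sd = math.sqrt(theta_var)
    return sum(normal_pdf(y_new, rng.gauss(theta_mean, sd), sigma ** 2) for _ in range(n_draws)) / n_draws


y_new = 1.2  # hypothetical new observation
predictive_mc = average_likelihood(y_new, mu_post, var_post)            # weight by the posterior
prior_predictive_mc = average_likelihood(y_new, 0.0, sigma_prior ** 2)  # weight by the prior
predictive_exact = normal_pdf(y_new, mu_post, sigma ** 2 + var_post)
prior_predictive_exact = normal_pdf(y_new, 0.0, sigma ** 2 + sigma_prior ** 2)
print(predictive_mc, predictive_exact, prior_predictive_mc, prior_predictive_exact)
```

<p>The two calls differ only in the distribution used to weight the likelihood, which is exactly the difference between the predictive distribution and the marginal likelihood.</p>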
<p>The elpd looks like a very useful alternative metric for comparing models.
The larger the elpd for a given model, the better we expect it to generalise to new datasets.
Concretely, for two models \(M_1\) and \(M_2\), we would choose \(M_1\) if \(\mathrm{elpd}_{M_1} > \mathrm{elpd}_{M_2}\), and choose \(M_2\) if the converse is true.
Take the example above where the Bayes factor strongly favoured \(M_1\), but where \(M_1\) and \(M_2\) have the same likelihood function and posterior distribution.
What would the elpd look like for these models?
Since they have the same likelihood and posterior distributions, we can say:</p>
\[\begin{aligned}
p(\tilde y | y, M_1) &\simeq p(\tilde y | y, M_2) \\
\implies \mathrm{elpd}_{M_1} &\simeq \mathrm{elpd}_{M_2}.
\end{aligned}\]
<p>Thus, the elpd would say that these two models are the same in terms of predictive performance, and we could choose either one of them.
This makes a lot more sense than the result using Bayes factors!</p>
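<p>In a simulation, where we choose the true data generating distribution ourselves, (\ref{elpd}) can be evaluated directly, which makes this claim easy to check. The sketch below assumes (hypothetically) that the truth is \(y_i \sim \mathcal{N}(\theta_\mathrm{true}, \sigma^2)\) with \(\theta_\mathrm{true} = 1\) and known \(\sigma = 1\), and that both models use this normal observation model with priors \(\mathcal{N}(0, 10^2)\) for \(M_1\) and \(\mathcal{N}(0, 100^2)\) for \(M_2\). Because each predictive distribution is then \(\mathcal{N}(\mu_\mathrm{post}, \sigma^2 + \sigma_\mathrm{post}^2)\), the expected log density under the truth has a closed form, so no sampling of \(\tilde y\) is needed.</p>

```python
import math
import random


def expected_log_predictive(theta_true, sigma, mu_post, var_post):
    """E over y ~ N(theta_true, sigma^2) of log N(y | mu_post, sigma^2 + var_post), in closed form."""
    s2 = sigma ** 2 + var_post
    # E[(y - mu_post)^2] = sigma^2 + (theta_true - mu_post)^2 (variance plus squared bias)
    return -0.5 * math.log(2 * math.pi * s2) - (sigma ** 2 + (theta_true - mu_post) ** 2) / (2 * s2)


def elpd(y, sigma, sigma_prior, theta_true):
    """Equation (8) for N new points from the (known, simulated) truth."""
    n = len(y)
    mu_ell, var_ell = sum(y) / n, sigma ** 2 / n  # likelihood in theta is N(mean(y), sigma^2 / n)
    mu_post = mu_ell * sigma_prior ** 2 / (sigma_prior ** 2 + var_ell)
    var_post = sigma_prior ** 2 * var_ell / (sigma_prior ** 2 + var_ell)
    return n * expected_log_predictive(theta_true, sigma, mu_post, var_post)


rng = random.Random(42)
theta_true, sigma = 1.0, 1.0  # assumed truth for the simulation
y = [rng.gauss(theta_true, sigma) for _ in range(50)]  # simulated training data

elpd_m1 = elpd(y, sigma, sigma_prior=10.0, theta_true=theta_true)
elpd_m2 = elpd(y, sigma, sigma_prior=100.0, theta_true=theta_true)
print(f"elpd M1 = {elpd_m1:.4f}, elpd M2 = {elpd_m2:.4f}, difference = {elpd_m1 - elpd_m2:.5f}")
```

<p>The elpd difference is tiny, whereas the Bayes factor in the toy example favoured \(M_1\) by a factor of about ten.</p>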
<p>However, there is a problem - in order to evaluate the elpd, we require the <em>true data generating distribution</em>, \(p_t(\tilde y_i)\).
Obviously we don’t know what that is, otherwise we would not be bothering to build a model to approximate it!
The way we proceed is through the use of cross validation.
We can’t get at \(p_t(\tilde y_i)\) directly, but we can note that the training data \(y\) should be a representative draw from this distribution.
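Consequently, we can approximate the elpd by Monte Carlo integration, treating each training point \(y_i\) as a draw from \(p_t\) and, so that no point is used both to fit and to score the model, predicting it from the remaining data:</p>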
\[\mathrm{elpd} \approx \sum_{i=1}^N \log p(y_i | y_{-i}, M),\]
<p>where \(y_{-i}\) means “the training set with data point \(i\) removed”.
Using this approximation, we now have a recipe for performing leave-one-out cross validation (LOO CV) for Bayesian models.
For all data points \(y_i\) in the training set \(y\), do:</p>
<ol>
<li>Compute the posterior distribution of the model \(M\) using the dataset with \(y_i\) removed (\(y_{-i}\)).</li>
<li>Calculate the log predictive density of \(y_i\) using this posterior.</li>
<li>Add the result to the approximate elpd being calculated.</li>
</ol>
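<p>For a conjugate normal model the recipe can be followed exactly, because each leave-one-out posterior is available in closed form and step 1 needs no MCMC. A minimal sketch, assuming (hypothetically) observations \(y_i \sim \mathcal{N}(\theta, \sigma^2)\) with known \(\sigma = 1\), a prior \(\theta \sim \mathcal{N}(0, \sigma_\mathrm{prior}^2)\) and made-up data, comparing the two priors from the Bayes factor example:</p>

```python
import math


def log_normal_pdf(x, mean, var):
    """Log density of a normal distribution with the given mean and variance."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)


def posterior(y, sigma, sigma_prior):
    """Posterior mean and variance of theta given observations y (normal prior with mean zero)."""
    precision = 1.0 / sigma_prior ** 2 + len(y) / sigma ** 2
    var_post = 1.0 / precision
    mu_post = var_post * sum(y) / sigma ** 2
    return mu_post, var_post


def elpd_loo(y, sigma, sigma_prior):
    """Exact leave-one-out estimate of the elpd: sum_i log p(y_i | y_{-i}, M)."""
    total = 0.0
    for i, y_i in enumerate(y):
        y_minus_i = y[:i] + y[i + 1:]                       # step 1: drop point i ...
        mu, var = posterior(y_minus_i, sigma, sigma_prior)  # ... and refit the posterior
        total += log_normal_pdf(y_i, mu, sigma ** 2 + var)  # steps 2-3: log predictive density of y_i
    return total


y = [0.8, 1.3, 0.4, 1.9, 1.1, 0.6, 1.5, 0.9, 1.2, 0.7]  # made-up data
loo_m1 = elpd_loo(y, sigma=1.0, sigma_prior=10.0)
loo_m2 = elpd_loo(y, sigma=1.0, sigma_prior=100.0)
print(f"elpd_loo M1 = {loo_m1:.4f}, M2 = {loo_m2:.4f}")
```

<p>The two priors give almost the same LOO estimate, in line with the elpd argument above.</p>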
<p>This procedure will result in an approximate elpd that can be used to compare two models.
This approach does not share the Bayes factor’s problematic sensitivity to the prior, and it is widely used (implementations include the <a href="https://cran.r-project.org/web/packages/loo/index.html">loo package for R</a> and <a href="https://arviz-devs.github.io/arviz/">ArviZ for Python</a>).
Note that step (1) in the above recipe could be computationally very expensive: imagine running MCMC 1000 times for a dataset with 1000 observations.
Fortunately, <a href="https://arxiv.org/abs/1507.04544">Vehtari et al. (2016)</a> provide an efficient approximation to this procedure, based on Pareto-smoothed importance sampling (PSIS-LOO), that lets us carry out MCMC just once (or at most a few times).</p>
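<p>The idea behind that approximation can be sketched with plain importance sampling: since \(p(y_i | y_{-i}, M) = 1 / \mathbb{E}_{\theta \sim p(\theta | y, M)}[1 / p(y_i | \theta, M)]\), a single set of draws from the full-data posterior is enough to estimate every leave-one-out term. The sketch below uses this raw estimator, which is <em>not</em> the Pareto-smoothed version of Vehtari et al. (the smoothing is what makes it reliable in practice), and, as an assumption, it draws from the exact posterior of the same hypothetical conjugate normal model instead of running MCMC.</p>

```python
import math
import random


def log_normal_pdf(x, mean, var):
    """Log density of a normal distribution with the given mean and variance."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)


def posterior(y, sigma, sigma_prior):
    """Posterior (mean, variance) of theta for y_i ~ N(theta, sigma^2), theta ~ N(0, sigma_prior^2)."""
    precision = 1.0 / sigma_prior ** 2 + len(y) / sigma ** 2
    return sum(y) / sigma ** 2 / precision, 1.0 / precision


def elpd_loo_exact(y, sigma, sigma_prior):
    """Reference answer: refit the closed-form posterior once per held-out point."""
    total = 0.0
    for i, y_i in enumerate(y):
        mu, var = posterior(y[:i] + y[i + 1:], sigma, sigma_prior)
        total += log_normal_pdf(y_i, mu, sigma ** 2 + var)
    return total


def elpd_loo_is(y, sigma, sigma_prior, n_draws=20_000, seed=0):
    """Raw (unsmoothed) importance-sampling LOO from one set of full-data posterior draws."""
    rng = random.Random(seed)
    mu, var = posterior(y, sigma, sigma_prior)
    thetas = [rng.gauss(mu, math.sqrt(var)) for _ in range(n_draws)]  # stand-in for MCMC output
    total = 0.0
    for y_i in y:
        # p(y_i | y_{-i}) = 1 / E_post[1 / p(y_i | theta)], i.e. a harmonic mean of the likelihood
        mean_inv_lik = sum(math.exp(-log_normal_pdf(y_i, t, sigma ** 2)) for t in thetas) / n_draws
        total -= math.log(mean_inv_lik)
    return total


y = [0.8, 1.3, 0.4, 1.9, 1.1, 0.6, 1.5, 0.9, 1.2, 0.7]  # made-up data
exact = elpd_loo_exact(y, sigma=1.0, sigma_prior=10.0)
approx = elpd_loo_is(y, sigma=1.0, sigma_prior=10.0)
print(f"exact LOO: {exact:.4f}, importance-sampling LOO: {approx:.4f}")
```

<p>In practice one would rely on a tested implementation such as the packages linked above, which add the Pareto smoothing and diagnostics that this raw estimator lacks.</p>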
<h2 id="bridging-the-gap-the-marginal-likelihood-and-cross-validation">Bridging the gap: the marginal likelihood and cross validation</h2>
<p>I was recently directed to <a href="https://academic.oup.com/biomet/article/107/2/489/5715611">Fong & Holmes (2020)</a> which bridges the gap between these two concepts, and provides intuition as to why the marginal likelihood is overly sensitive to the prior.
The paper proves that the marginal likelihood is equivalent to <em>“leave-\(p\)-out cross validation averaged over all values of \(p\) and all held-out test sets”</em>.
Concretely, this means:</p>
\[\log p(y | M) = \sum_{p=1}^N S_\mathrm{CV}(y; p) \label{result}\tag{9}\]
<p>The cross validation score is</p>
\[S_\mathrm{CV}(y; p) = \frac{1}{N \choose p} \sum_{t=1}^{N \choose p} \frac{1}{p} \sum_{j=1}^{p} \log p(y_{t,j} | y_{-t}, M),\]
<p>where \(y_{t,j}\) is the \(j\)-th data point in holdout set \(t\). We can pick this apart a bit:</p>
<ul>
<li>There are \(N \choose p\) possible holdout sets of size \(p\) when there are \(N\) data points, so we average over all of these possible holdout sets.</li>
<li>For each holdout set, indexed by \(t\), we evaluate the log predictive density of each held-out point \(y_{t,j}\), conditioned on the full dataset with the whole holdout set removed, \(y_{-t}\), and average over the \(p\) held-out points.</li>
</ul>
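<p>Identity (\ref{result}) can be checked by brute force on a tiny dataset, enumerating every holdout set with <code>itertools.combinations</code>. The sketch below assumes (hypothetically) the conjugate model \(y_i \sim \mathcal{N}(\theta, \sigma^2)\), \(\theta \sim \mathcal{N}(0, \sigma_\mathrm{prior}^2)\) with known \(\sigma\) and made-up data, scores each held-out point individually by its log posterior predictive density, and compares the sum over \(p\) with \(\log p(y | M)\) computed directly from the joint normal distribution of \(y\), whose covariance is \(\sigma^2 I + \sigma_\mathrm{prior}^2 \mathbf{1}\mathbf{1}^\top\).</p>

```python
import math
from itertools import combinations


def log_normal_pdf(x, mean, var):
    """Log density of a normal distribution with the given mean and variance."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)


def log_predictive(y_new, train, sigma, sigma_prior):
    """log p(y_new | train, M); with an empty training set this is the prior predictive."""
    precision = 1.0 / sigma_prior ** 2 + len(train) / sigma ** 2
    mu, var = sum(train) / sigma ** 2 / precision, 1.0 / precision
    return log_normal_pdf(y_new, mu, sigma ** 2 + var)


def s_cv(y, p, sigma, sigma_prior):
    """Leave-p-out score: average over all holdout sets of the mean pointwise log predictive."""
    n = len(y)
    total = 0.0
    for held_out in combinations(range(n), p):
        train = [y[j] for j in range(n) if j not in held_out]
        total += sum(log_predictive(y[j], train, sigma, sigma_prior) for j in held_out) / p
    return total / math.comb(n, p)


def log_marginal_likelihood(y, sigma, sigma_prior):
    """Direct log p(y | M) for the jointly normal y (determinant and inverse via Sherman-Morrison)."""
    n, s, ss = len(y), sum(y), sum(v * v for v in y)
    s2, t2 = sigma ** 2, sigma_prior ** 2
    log_det = (n - 1) * math.log(s2) + math.log(s2 + n * t2)
    quad = (ss - t2 * s * s / (s2 + n * t2)) / s2
    return -0.5 * (n * math.log(2 * math.pi) + log_det + quad)


y = [0.8, 1.3, 0.4, 1.9, 1.1, 0.6]  # made-up data
sigma, sigma_prior = 1.0, 10.0
lhs = log_marginal_likelihood(y, sigma, sigma_prior)
rhs = sum(s_cv(y, p, sigma, sigma_prior) for p in range(1, len(y) + 1))
print(f"log p(y|M)   = {lhs:.10f}\nsum of S_CV  = {rhs:.10f}")
```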
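<p>This score looks very similar to the elpd (\ref{elpd}) and its cross validation approximation, as you’d expect: for \(p = 1\), \(S_\mathrm{CV}(y; 1)\) is exactly the leave-one-out estimate of the elpd divided by \(N\).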
The result (\ref{result}) draws a concrete connection between cross validation and the marginal likelihood.
In the previous section, I only considered <em>leave-one-out</em> cross validation.
But the marginal likelihood considers <em>all possible holdout set sizes</em>.
This explains why the two quantities behave differently (in the toy example I gave, the marginal likelihood favoured \(M_1\) strongly, whereas LOO CV could not choose between \(M_1\) and \(M_2\)).
In particular, it explains why the marginal likelihood can be sensitive to the prior.
Fong & Holmes put it well (I have modified the equation numbers and the notation slightly to match this post):</p>
<p>“The last term on the right-hand side of (\ref{result}), <br />
\(S_\mathrm{CV}(y;N) = \frac{1}{N}\sum\limits_{i=1}^N \log \int \,p(y_i | \theta, M)\,p(\theta)\,\mathrm{d}\theta\),<br />
involves no training data and scores the model entirely on how well the analyst is able to specify the prior. In many situations, the analyst may not want this term to contribute to model evaluation. Moreover, there is conflict between the desire to specify vague priors to safeguard their influence and the fact that diffuse priors can lead to an arbitrarily large and negative model score for real-valued parameters from (\ref{result}). It may seem inappropriate to penalize a model based on the subjective ability to specify the prior, or to compare models using a score that includes contributions from predictions made using only a handful of training points even with informative priors.”</p>
<p>Fong & Holmes go on to recommend an approach where, instead of considering all holdout set sizes between one and \(N\), analysts cap the holdout set size at some maximum that is \(< N\), in order to avoid the problem above.
LOO CV is the simplest such approach: its maximum holdout set size is one.</p>
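<p>To see why capping the holdout size helps, the sketch below compares, for two priors, the full sum in (\ref{result}) (the log marginal likelihood) with a truncated sum over holdout sizes up to \(P = 3\) only, so that every prediction is made from at least \(N - P\) training points. This mirrors the description above rather than reproducing the exact score defined by Fong & Holmes, and it assumes the same hypothetical conjugate normal model with known \(\sigma = 1\) and made-up data.</p>

```python
import math
from itertools import combinations


def log_predictive(y_new, train, sigma, sigma_prior):
    """log p(y_new | train, M) for y_i ~ N(theta, sigma^2), theta ~ N(0, sigma_prior^2)."""
    precision = 1.0 / sigma_prior ** 2 + len(train) / sigma ** 2
    mu, var = sum(train) / sigma ** 2 / precision, 1.0 / precision
    s2 = sigma ** 2 + var
    return -0.5 * math.log(2 * math.pi * s2) - (y_new - mu) ** 2 / (2 * s2)


def s_cv(y, p, sigma, sigma_prior):
    """Leave-p-out score S_CV(y; p), scoring held-out points individually."""
    n = len(y)
    total = 0.0
    for held_out in combinations(range(n), p):
        train = [y[j] for j in range(n) if j not in held_out]
        total += sum(log_predictive(y[j], train, sigma, sigma_prior) for j in held_out) / p
    return total / math.comb(n, p)


def cumulative_score(y, max_p, sigma, sigma_prior):
    """Sum of S_CV(y; p) for p = 1..max_p; max_p = len(y) recovers log p(y | M) via (9)."""
    return sum(s_cv(y, p, sigma, sigma_prior) for p in range(1, max_p + 1))


y = [0.8, 1.3, 0.4, 1.9, 1.1, 0.6]  # made-up data
n = len(y)
for max_p in (3, n):
    m1 = cumulative_score(y, max_p, sigma=1.0, sigma_prior=10.0)
    m2 = cumulative_score(y, max_p, sigma=1.0, sigma_prior=100.0)
    print(f"holdout sizes up to {max_p}: M1 = {m1:.4f}, M2 = {m2:.4f}, difference = {m1 - m2:.4f}")
```

<p>Almost all of the gap between the two priors in the full sum comes from the largest holdout sizes, where predictions lean heavily on the prior.</p>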
<h2 id="conclusion">Conclusion</h2>
<p>I really enjoyed seeing the result from <a href="https://academic.oup.com/biomet/article/107/2/489/5715611">Fong & Holmes (2020)</a>, because it joined the dots between different concepts in model comparison that I have learned about.
It reinforced the point that the marginal likelihood may not be appropriate for comparing models, and gave an elegant explanation as to why that is.
For me, a simple conclusion is: when comparing models, try to replicate as closely as possible the way that the model will be used in the future.
I’ll continue to use various flavours of cross validation to compare my models!</p>Angus WilliamsAt various points in the past few years I have had discussions or debates with friends and colleagues about model comparison in the context of Bayesian inference. What is the most “principled” way to do it? What are the relative merits of different approaches? My opinion has evolved alongside my understanding of the subject, and I recently read a paper that conceptually explained some of the intuition that I had developed over the years. Consequently, I wanted to write a note on this subject in case it is useful for others. Introduction to model comparison Before diving into different approaches for model comparison, let me first define it. Suppose you are building a statistical model for a particular process, and have some data \(y\) to use in order to fit and test your approaches. In the process of analysing the data, you come up with two distinct approaches for modelling: \(M_1\) and \(M_2\). For example, if you are solving a regression problem, \(M_1\) might be a GLM of some sort, whereas \(M_2\) could be a GAM. Both models seem to fit the data, but you want to quantify which is doing a better job. Any approach to answering this question falls into the category of model comparison. What should we care about when comparing models? Fundamentally, we are interested in knowing which model provides a better approximation to the underlying data generating process. In that sense, we want to know which model generalises better beyond the immediate data \(y\) that we have available. In other words, if were were to receive some new data \(y_\mathrm{new}\), which model would describe it better? The field of model comparison seeks to answer this question, particularly when we must try to make our best guess without access to an unlimited supply of data. Recap of the Bayesian approach No piece of writing about Bayesian methods would be complete without stating Bayes’ theorem. 
In the context of a single model, we can write this as: \[p(\theta | y) = \dfrac{p(y | \theta)p(\theta)}{p(y)}, \label{bayes}\tag{1}\] where \(\theta\) are the set of model parameters (e.g., the coefficients in a linear regression) and \(y\) are the data we are analysing. \(p(\theta | y)\) is called the posterior distribution, because it is the distribution of the model parameters conditional on the data \(y\) (i.e., after we receive the data). \(p(\theta)\) is the prior because it is the unconditional distribution of the model parameters (i.e., prior to seeing the data \(y\)). \(p(y | \theta)\) is called the likelihood, and is the probability of the data given a particular set of model parameters. \(p(y)\) has a few names: the marginal likelihood or the evidence are probably the two most common. I’ll discuss this part in more detail shortly. The above discussion is in the context of a single model, so everything is implicitly conditional on the choice of model \(M\): \[p(\theta | y, M) = \dfrac{p(y | \theta, M)p(\theta | M)}{p(y | M)}. \label{model_bayes}\tag{2}\] We don’t normally write it like this because it is clear that we are implicitly conditioning on the particular model we are considering, but it’s helpful to write Bayes’ theorem like this when thinking about model comparison. Bayes factors Since we can use conditional probability to quantify how we should update our beliefs about the parameters \(\theta\) of an individual model given some data \(y\), why not use it to update our beliefs about which model might be the best choice from an available set? This feels like a very natural way to approach the problem of model comparison. Concretely, we can write down Bayes’ theorem as \[p(M | y) = \dfrac{p(y | M) p(M)}{p(y)}.\] Now we have the probability that model \(M\) is the “true” model given the data at hand: \(p(M | y)\). \(p(M)\) is the prior probability that model \(M\) is correct. On the face of it, this seems to be exactly what we need! 
Suppose we are comparing two models,\(M_1\) and \(M_2\) We can make a very intuitive rule for choosing between them: if \(p(M_1 | y) > p(M_2 | y)\), then choose \(M_1\), and choose \(M_2\) if the converse is true. If we have no prior preference for \(M_1\) or \(M_2\), so that \(p(M_1) = p(M_2) = \frac{1}{2}\), it’s easy to show that \[\dfrac{p(M_1 | y)}{p(M_2 | y)} = \dfrac{p(y | M_1)}{p(y | M_2)} = K.\] We call \(K\) the Bayes factor. In terms of \(K\), we should choose \(M_1\) if \(K > 1\), and choose \(M_2\) if \(K < 1\). So, if we can find a way to calculate \(p(y | M_i)\), then we can use the Bayes factor to choose between models. How can we calculate it? You might have noticed that \(p(y | M)\) appeared earlier in (\ref{model_bayes}) - it appears as the denominator on the RHS - we called it the marginal likelihood or the evidence. But what is it? A clue comes from one of its names: the marginal likelihood. We can write it as follows: \[p(y | M) = \int \mathrm{d}\theta\, p(y | \theta, M)\,p(\theta | M) \label{marg_lik}\tag{3}\] i.e., we marginalise out the parameters of the model \(\theta\) to obtain the probability of the data \(y\) given the model choice \(M\). We marginalise out \(\theta\) using the prior distribution, \(p(\theta | M)\). This adds some intuition about the Bayes’ factor: we choose the model for which the data \(y\) are most probable, given all likely configurations of the model parameters. This argument for using Bayes’ factors for model comparison is quite persuasive, and at first it can seem almost irrefutable because it is so intuitive. All that glitters is not gold: Bayes factors are overly sensitive to the prior Despite seeming apparently watertight at first, the Bayes factor has some undesirable traits. These are well documented, and I’ll focus on just one of them in this note, that they are overly sensitive to the prior. 
Practitioners of Bayesian inference quickly learn the rule of thumb that the more data you have, the less influence the prior has on your final inference. This makes sense - the more evidence you accumulate by collecting more data, the less weight you will place on your prior beliefs. Suppose we have a model \(M\) with a single parameter \(\theta\) and some data \(y\). Let’s further assume that the likelihood \(p(y | \theta, M)\) is a normal distribution with mean \(\mu_\ell\) and variance \(\sigma_\ell ^2\): \[p(y | \theta, M) = \dfrac{C}{\sqrt{2\pi \sigma_\ell ^ 2}} \exp [ -\dfrac{(\theta - \mu_\ell)^2}{2 \sigma_\ell ^2}].\] Note that because I’m considering the likelihood as a function of \(\theta\), it need not be normalised, which is why the constant \(C\) appears on the right hand side. Further suppose that the prior on \(\theta\) is a normal distribution with mean zero and variance \(\sigma_\mathrm{prior} ^2\) \[p(\theta | M) = \dfrac{1}{\sqrt{2\pi \sigma_\mathrm{prior} ^ 2}} \exp [ -\dfrac{\theta ^2}{2 \sigma_\mathrm{prior} ^2}].\] Given these assumptions, we can exactly work out some of the quantities of interest for model comparison. Let’s first work out the numerator of the RHS of (\ref{model_bayes}). 
It turns out to be another normal distribution: \[p(y | \theta, M)\,p(\theta) = A \times \dfrac{1}{\sqrt{2\pi \sigma_\mathrm{post} ^ 2}}\exp [ -\dfrac{(\theta - \mu_\mathrm{post})^2}{2 \sigma_\mathrm{post} ^2}],\label{integral}\tag{4}\] where \[\mu_\mathrm{post} = \mu_\ell \dfrac{\sigma_\mathrm{prior} ^2}{\sigma_\mathrm{prior} ^2 + \sigma_\ell ^2} \quad;\quad \sigma_\mathrm{post}^2 = \dfrac{\sigma_\mathrm{prior}^2 \sigma_\ell^2}{\sigma_\mathrm{prior}^2 + \sigma_\ell^2}.\label{norm_post}\tag{5}\] The constant \(A\) is equal to yet another normal distribution: \[A = \dfrac{C}{\sqrt{2\pi (\sigma_\ell ^ 2 + \sigma_\mathrm{prior} ^ 2)}} \exp [ -\dfrac{\mu_\ell^2}{2 (\sigma_\ell ^2 + \sigma_\mathrm{prior}^2)}].\label{normal_marginal}\tag{6}\] Now, looking at Bayes’ theorem (\ref{model_bayes}), we can see that the left hand side is a probability distribution, which means that it integrates to one: \[\int p(\theta | y, M) \, \mathrm{d}\theta = 1.\] Using right hand side of Bayes theorem, we can rearrange this to: \[\int p(y | \theta, M) p(\theta | M) \, \mathrm{d}\theta = p(y | M).\] Thus we can see that the marginal likelihood can be regarded as a normalisation constant - it guarantees that the posterior distribution integrates to one. Looking at (\ref{integral}), we can deduce that the constant \(A\) is in fact the marginal likelihood: \begin{equation} \begin{aligned} \int p(y | \theta, M)\,p(\theta)\,\mathrm{d} \theta &= A \times \int \dfrac{1}{\sqrt{2\pi \sigma_\mathrm{post} ^ 2}}\exp [ -\dfrac{(\theta - \mu_\mathrm{post})^2}{2 \sigma_\mathrm{post} ^2}] \,\mathrm{d}\theta \newline &= A \newline \implies p(y | M) &= A \end{aligned} \end{equation} This result follows from the fact that the normal distribution integrates to one. 
This means that the posterior distribution is \[p(\theta | M, y) = \dfrac{1}{\sqrt{2\pi \sigma_\mathrm{post} ^ 2}}\exp [ -\dfrac{(\theta - \mu_\mathrm{post})^2}{2 \sigma_\mathrm{post} ^2}]\] So, that seems like a lot of work, but now we have expressions for the posterior \(p(\theta | M, y)\) and the marginal likelihood \(p(y | M)\). Now that we have them, lets’s think about the limit when we have a lot of data. In this case, the likelihood will be more informative than the prior - it will be more strongly peaked around its mean. We can express this quantitatively by saying that \(\sigma_\mathrm{prior}^2 \gg \sigma_\ell^2\), since the larger the variance of a normal distribution, the less strongly peaked it is. What does the posterior look like in this limit? We know it is a normal distribution with mean and variance given by (\ref{norm_post}). In the limit we’re interested in, these become: \[\mu_\mathrm{post} \approx \mu_\ell \quad;\quad \sigma_\mathrm{post}^2 \approx \sigma_\ell^2.\] This makes intuitive sense - when the likelihood dominates the prior, the posterior will approach the likelihood as the prior loses influence. What about the marginal likelihood? Looking at (\ref{normal_marginal}), we can see that this becomes: \[p(y | M) \approx \dfrac{C}{\sqrt{2\pi\sigma_\mathrm{prior} ^ 2}} \exp [ -\dfrac{\mu_\ell^2}{2 \sigma_\mathrm{prior}^2}].\label{approx_marginal}\tag{7}\] One thing to immediately notice is that, in the limit where the likelihood dominates the prior, the posterior distribution is approximately independent of the prior: we can see that because it is approximately equal to the likelihood. But quite the opposite is true of the marginal likelihood! It explicitly depends on the prior variance \(\sigma_\mathrm{prior}^2\). Let’s consider two versions of \(M\). Both use the same likelihood and data, so they have the same likelihood function, but have different priors. For concreteness, let’s say that \(\sigma_\ell = 1\), and \(\mu_\ell = 1\). 
In one case, call it \(M_1\), \(\sigma_\mathrm{prior} = 10\), and in the other, \(M_2\), we have \(\sigma_\mathrm{prior} = 100\). In both cases \(\sigma_\ell \ll \sigma_\mathrm{prior}\), so we’re in the limit considered above. But let’s compute the Bayes factor for these two models using (\ref{approx_marginal}); the constant \(C\) cancels, leaving \[K = \dfrac{p(y | M_1)}{p(y | M_2)} \approx \dfrac{100}{10} \exp \left[ -\dfrac{1}{2 \times 10^2} + \dfrac{1}{2 \times 100^2} \right] \approx 10.\] So, even though these two models produce almost identical inference (they have almost the same posterior), the Bayes factor tells us that, given equal prior probabilities for the two models, \(M_1\) is about ten times more probable than \(M_2\)! Our model selection policy therefore tells us to strongly favour \(M_1\) over \(M_2\), even though the two models produce essentially identical predictions for new data. At this point, a quote from Gelman & Shalizi (2012) is appropriate: “The main point where we disagree with many Bayesians is that we do not see Bayesian methods as generally useful for giving the posterior probability that a model is true, or the probability for preferring model A over model B, or whatever. Beyond the philosophical difficulties, there are technical problems with methods that purport to determine the posterior probability of models, most notably that in models with continuous parameters, aspects of the model that have essentially no effect on posterior inferences within a model can have huge effects on the comparison of posterior probability among models.”
An alternative approach: cross validation
Cross validation used to feel to me like a concept from machine learning, and less elegant than ideas like Bayes factors. That said, cross validation makes great intuitive sense as a method for model comparison. Following Vehtari et al. 
(2016), let’s first define a useful metric for comparing probabilistic models, the expected log pointwise predictive density for a new dataset: \[\mathrm{elpd} = \sum_{i=1}^{N} \int \mathrm{d}\tilde y_i \, p_t(\tilde y_i)\, \log p(\tilde y_i | y, M).\label{elpd}\tag{8}\] In this expression, \(\tilde y\) is a new, unseen dataset (i.e., it is not part of the training data \(y\)), assumed to contain \(N\) data points; \(p_t(\tilde y)\) is the true data generating distribution of \(\tilde y\); and \(p(\tilde y | y, M)\) is the predictive distribution of the model being assessed. For Bayesian models, the predictive distribution is \[p(\tilde y | y, M) = \int \mathrm{d}\theta \, p(\tilde y | \theta, M) p(\theta | y, M).\] It is so called because it is the density we would use to predict new data: it incorporates what we learned from the training data \(y\) through the posterior. Notice how similar this looks to the marginal likelihood (\ref{marg_lik}). Both are integrals of the product of the likelihood with a distribution over \(\theta\). For the marginal likelihood, this distribution is the prior \(p(\theta | M)\), whereas for the predictive distribution it is the posterior \(p(\theta | y, M)\). The elpd is a natural alternative metric for comparing models: the larger the elpd for a given model, the better we expect it to generalise to new datasets. Concretely, for two models \(M_1\) and \(M_2\), we would choose \(M_1\) if \(\mathrm{elpd}_{M_1} > \mathrm{elpd}_{M_2}\), and \(M_2\) otherwise. Take the example above where the Bayes factor strongly favoured \(M_1\), even though \(M_1\) and \(M_2\) have the same likelihood function and (approximately) the same posterior distribution. What would the elpd look like for these models? Since they have the same likelihood and posterior distributions, we can say: \[\begin{aligned} p(\tilde y | y, M_1) &\simeq p(\tilde y | y, M_2) \\ \implies \mathrm{elpd}_{M_1} &\simeq \mathrm{elpd}_{M_2}. 
\end{aligned}\] Thus, the elpd says that these two models are equivalent in terms of predictive performance, and we could choose either one of them. This makes a lot more sense than the result using Bayes factors! However, there is a problem: in order to evaluate the elpd, we require the true data generating distribution, \(p_t(\tilde y_i)\). Obviously we don’t know what that is, otherwise we would not be bothering to build a model to approximate it! Cross validation offers a way around this. We can’t get at \(p_t(\tilde y_i)\) directly, but the training data \(y\) should be a representative draw from this distribution. Consequently, we can treat the observed data points as Monte Carlo samples from \(p_t\) and approximate the elpd by: \[\mathrm{elpd} \approx \sum_{i=1}^N \log p(y_i | y_{-i}, M),\] where \(y_{-i}\) means “the training set with data point \(i\) removed”. Holding each point out in turn ensures that no \(y_i\) is used both to fit the model and to score it. Using this approximation, we now have a recipe for performing leave-one-out cross validation (LOO CV) for Bayesian models. For each data point \(y_i\) in the training set \(y\):
(1) Compute the posterior distribution of the model \(M\) using the dataset with \(y_i\) removed (\(y_{-i}\)).
(2) Calculate the log predictive density of \(y_i\) using this posterior.
(3) Add the result to the running total for the approximate elpd.
This procedure results in an approximate elpd that can be used to compare two models. It does not have the same problematic sensitivity to the prior as the Bayes factor approach, and is favoured by many (indeed, it is implemented in both R and python). Step (1) in the above recipe can be very expensive computationally: imagine carrying out MCMC once per data point for a dataset with 1000 observations. Fortunately, Vehtari et al. (2016) provide an efficient approximation to this procedure, based on Pareto-smoothed importance sampling, that alleviates the problem by letting us carry out MCMC just once (or at most a few times). 
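To make the contrast concrete, here is a minimal sketch that computes both the Bayes factor and the exact LOO elpd for the two toy models. It assumes a hypothetical dataset of four points with mean 1 under a conjugate model \(y_i \sim \mathcal{N}(\theta, s^2)\) with known noise standard deviation \(s = 2\) and a zero-mean normal prior on \(\theta\); these choices reproduce \(\mu_\ell = 1\) and \(\sigma_\ell = 2 / \sqrt{4} = 1\) from the example. Because the posterior is available in closed form, the marginal likelihood can be computed exactly via the chain rule \(p(y | M) = \prod_i p(y_i | y_1, \ldots, y_{i-1}, M)\), and LOO can be done exactly by refitting for each held-out point, with no MCMC involved.

```python
import math

def log_normal_pdf(x, mean, var):
    return -0.5 * math.log(2.0 * math.pi * var) - (x - mean) ** 2 / (2.0 * var)

def predictive(train, noise_var, prior_var):
    """Posterior predictive mean and variance for one new point under the
    conjugate model y_i ~ N(theta, noise_var), theta ~ N(0, prior_var)."""
    post_prec = 1.0 / prior_var + len(train) / noise_var
    post_mean = (sum(train) / noise_var) / post_prec
    return post_mean, noise_var + 1.0 / post_prec

def log_marginal_likelihood(y, noise_var, prior_var):
    # chain rule: log p(y) = sum_i log p(y_i | y_1, ..., y_{i-1})
    total = 0.0
    for i, y_i in enumerate(y):
        mean, var = predictive(y[:i], noise_var, prior_var)
        total += log_normal_pdf(y_i, mean, var)
    return total

def elpd_loo(y, noise_var, prior_var):
    # exact leave-one-out: fit on y_{-i}, score the held-out y_i
    total = 0.0
    for i, y_i in enumerate(y):
        mean, var = predictive(y[:i] + y[i + 1:], noise_var, prior_var)
        total += log_normal_pdf(y_i, mean, var)
    return total

# Hypothetical data: mean 1, known noise variance 4 => sigma_ell = 1
y = [0.0, 0.5, 1.5, 2.0]
noise_var = 4.0
prior_var_m1, prior_var_m2 = 10.0 ** 2, 100.0 ** 2  # sigma_prior = 10 and 100

log_K = (log_marginal_likelihood(y, noise_var, prior_var_m1)
         - log_marginal_likelihood(y, noise_var, prior_var_m2))
elpd_m1 = elpd_loo(y, noise_var, prior_var_m1)
elpd_m2 = elpd_loo(y, noise_var, prior_var_m2)

print("Bayes factor K:", math.exp(log_K))
print("elpd_loo M1:", elpd_m1)
print("elpd_loo M2:", elpd_m2)
```

With these numbers the Bayes factor comes out close to 10, as in the approximate calculation, while the two LOO elpds differ by less than 0.01 nats.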
Bridging the gap: the marginal likelihood and cross validation
I was recently pointed to Fong & Holmes (2020), which bridges the gap between these two concepts and provides intuition as to why the marginal likelihood is overly sensitive to the prior. The paper proves that the marginal likelihood is equivalent to “leave-\(p\)-out cross validation averaged over all values of \(p\) and all held-out test sets”. Concretely, this means: \[\log p(y | M) = \sum_{p=1}^N S_\mathrm{CV}(y; p) \label{result}\tag{9}\] The cross validation score is \[S_\mathrm{CV}(y; p) = \frac{1}{N \choose p}\times \frac{1}{p} \sum_{t=1}^{N \choose p} \log p(y_t | y_{-t}, M).\] We can pick this apart a bit. There are \(N \choose p\) possible holdout sets of size \(p\) when there are \(N\) datapoints, so we average over all of these possible holdout sets. For each holdout set, indexed by \(t\), we evaluate the log predictive density of the holdout set \(y_t\) (summed over the \(p\) points it contains), conditioned on the full dataset with this holdout set removed, \(y_{-t}\). This looks very similar to (\ref{elpd}), as you’d expect. The result (\ref{result}) draws a concrete connection between cross validation and the marginal likelihood. In the previous section, I only considered leave-one-out cross validation, but the marginal likelihood considers all possible holdout set sizes. This explains why the two quantities behave differently (in the toy example I gave, the marginal likelihood favoured \(M_1\) strongly, whereas LOO CV could not choose between \(M_1\) and \(M_2\)). In particular, it explains why the marginal likelihood can be sensitive to the prior. Fong & Holmes put it well (I have modified the equation numbers and the notation slightly to match this post): “The last term on the right-hand side of (\ref{result}), \(S_\mathrm{CV}(y;N) = \frac{1}{N}\sum\limits_{i=1}^N \log \int \,p(y_i | \theta, M)\,p(\theta)\,\mathrm{d}\theta\), involves no training data and scores the model entirely on how well the analyst is able to specify the prior. 
In many situations, the analyst may not want this term to contribute to model evaluation. Moreover, there is conflict between the desire to specify vague priors to safeguard their influence and the fact that diffuse priors can lead to an arbitrarily large and negative model score for real-valued parameters from (\ref{result}). It may seem inappropriate to penalize a model based on the subjective ability to specify the prior, or to compare models using a score that includes contributions from predictions made using only a handful of training points even with informative priors.” Fong & Holmes go on to recommend an approach where, instead of considering all holdout set sizes between one and \(N\), analysts only include holdout sets up to some maximum size smaller than \(N\), which avoids the problem above. LOO CV is the special case in which the maximum holdout size is one.
Conclusion
I really enjoyed seeing the result from Fong & Holmes (2020), because it joined the dots between different concepts in model comparison that I have learned about. It reinforced the point that the marginal likelihood may not be appropriate for comparing models, and gave an elegant explanation as to why. For me, a simple conclusion is: when comparing models, try to replicate as closely as possible the way that the model will be used in the future. I’ll continue to use various flavours of cross validation to compare my models!
A brief introduction to JAX and Laplace’s method2019-11-23T00:00:00+00:002019-11-23T00:00:00+00:00https://anguswilliams91.github.io/statistics/computing/jax<h2 id="introduction">Introduction</h2>
<p>Google recently released an interesting new library called <a href="https://github.com/google/jax">JAX</a>.
It looks like it could be very useful for computational statistics, so I thought I’d take a look.
In this post, I’ll describe some of the basic features of JAX, and then use them to implement Laplace’s approximation, a method used in Bayesian statistics.</p>
<p>JAX is described as “Autograd and XLA, brought together”.
Autograd is a python library for <em>automatic differentiation</em>, and XLA (<em>accelerated linear algebra</em>) is a compiler that optimises linear algebra computations.
Both of these are very useful in computational statistics, where we often have to differentiate things and perform lots of matrix (or tensor) manipulations.</p>
<p>JAX effectively extends the numpy library to include these extra components.
In doing so, it preserves an API that many scientists are familiar with, and introduces powerful new functionality.</p>
<h2 id="automatic-differentiation">Automatic differentiation</h2>
<p>Automatic differentiation (or autodiff, for short) is a set of methods for evaluating the derivative of functions using a computer.
If you’ve taken a calculus course, you will have differentiated functions by hand, e.g.</p>
\[f(x) = x ^ 2 \implies f'(x) = 2 x\]
<p>For the simple function above, differentiating is quite straightforward.
But when the function is complicated and has many arguments (for example, the objective function of a neural network), differentiating by hand quickly becomes infeasible.
Autodiff frameworks save us the trouble: we simply pass a function to the framework, and it returns another function that computes the gradient.
This is really useful in computational statistics and machine learning, since we often want to <em>optimise</em> functions with respect to their arguments.
Generally, we can carry out optimisation much more efficiently if we know gradients.
It’s also really useful in Bayesian statistics, where the state-of-the-art Markov Chain Monte Carlo method, Hamiltonian Monte Carlo, relies on the gradient of the log posterior density.</p>
<p>So, how do we do this in JAX?
Here’s a snippet that defines the logistic function in python, and then uses JAX to compute its derivative:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">from</span> <span class="nn">jax</span> <span class="kn">import</span> <span class="n">grad</span>
<span class="k">def</span> <span class="nf">logistic</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="c1"># logistic function
</span>    <span class="k">return</span> <span class="p">(</span><span class="mf">1.</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">x</span><span class="p">))</span> <span class="o">**</span> <span class="o">-</span><span class="mf">1.0</span>
<span class="c1"># differentiate with jax!
</span><span class="n">grad_logistic</span> <span class="o">=</span> <span class="n">grad</span><span class="p">(</span><span class="n">logistic</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">logistic</span><span class="p">(</span><span class="mf">0.0</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="n">grad_logistic</span><span class="p">(</span><span class="mf">0.0</span><span class="p">))</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>0.5
0.25
</code></pre></div></div>
<p>That was easy! There are two lines to focus on in this snippet:</p>
<ol>
<li><code class="language-plaintext highlighter-rouge">import jax.numpy as np</code>: instead of importing regular numpy, I imported <code class="language-plaintext highlighter-rouge">jax.numpy</code> which is JAX’s implementation of numpy functionality. After this line, we can pretty much forget about it and pretend that we’re using regular numpy most of the time.</li>
<li><code class="language-plaintext highlighter-rouge">grad_logistic = grad(logistic)</code>: this is where the magic happens. We passed the <code class="language-plaintext highlighter-rouge">logistic</code> function to JAX’s <code class="language-plaintext highlighter-rouge">grad</code> function, and it returned another function, which I called <code class="language-plaintext highlighter-rouge">grad_logistic</code>. This function takes the same inputs as <code class="language-plaintext highlighter-rouge">logistic</code>, but returns the gradient with respect to these inputs.</li>
</ol>
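<p>JAX’s <code class="language-plaintext highlighter-rouge">grad</code> uses reverse-mode autodiff. To give a feel for how a framework can hand back a derivative function without any symbolic algebra, here is a toy <em>forward-mode</em> sketch in plain python using dual numbers: each number carries its value together with its derivative, and every arithmetic operation applies the matching differentiation rule. The <code class="language-plaintext highlighter-rouge">Dual</code> class and <code class="language-plaintext highlighter-rouge">derivative</code> helper are hypothetical names for illustration, not part of JAX, and only the few operations the logistic function needs are implemented.</p>

```python
import math

class Dual:
    """Number value + deriv * eps with eps**2 == 0; deriv carries the derivative."""

    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # sum rule
        return Dual(self.value + other.value, self.deriv + other.deriv)

    __radd__ = __add__

    def __neg__(self):
        return Dual(-self.value, -self.deriv)

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # product rule
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)

    __rmul__ = __mul__

    def __pow__(self, p):
        # power rule for a constant (plain float) exponent p
        return Dual(self.value ** p, p * self.value ** (p - 1) * self.deriv)

def exp(x):
    # chain rule: d/dx exp(u) = exp(u) * u'
    e = math.exp(x.value)
    return Dual(e, e * x.deriv)

def derivative(f):
    """Return a function computing f'(x) by seeding the input's derivative with 1."""
    return lambda x: f(Dual(x, 1.0)).deriv

def logistic(x):
    return (1.0 + exp(-x)) ** -1.0

grad_logistic = derivative(logistic)
print(logistic(Dual(0.0)).value)  # 0.5
print(grad_logistic(0.0))         # 0.25
```

<p>It prints the same two values as the JAX snippet above. The design choice is the key idea behind autodiff: derivatives are propagated through the elementary operations as the function runs, rather than being derived symbolically.</p>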
<p>To convince ourselves that this all worked, let’s plot the logistic function and its derivative:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">jax</span> <span class="kn">import</span> <span class="n">vmap</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="n">plt</span><span class="p">.</span><span class="n">rcParams</span><span class="p">[</span><span class="s">"figure.figsize"</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">rcParams</span><span class="p">[</span><span class="s">"font.size"</span><span class="p">]</span> <span class="o">=</span> <span class="mi">20</span>
<span class="n">vectorised_grad_logistic</span> <span class="o">=</span> <span class="n">vmap</span><span class="p">(</span><span class="n">grad_logistic</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mf">10.0</span><span class="p">,</span> <span class="mf">10.0</span><span class="p">,</span> <span class="mi">1000</span><span class="p">)</span>
<span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">()</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">logistic</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">label</span><span class="o">=</span><span class="s">"$f(x)$"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">vectorised_grad_logistic</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">label</span><span class="o">=</span><span class="s">"$f'(x)$"</span><span class="p">)</span>
<span class="n">_</span> <span class="o">=</span> <span class="n">ax</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">_</span> <span class="o">=</span> <span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">"$x$"</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="/assets/images/jax_images/post_3_0.png" alt="png" /></p>
<p>You’ll notice that I had to define another function, <code class="language-plaintext highlighter-rouge">vectorised_grad_logistic</code>, in order to make the plot. The reason is that <code class="language-plaintext highlighter-rouge">grad</code> only differentiates functions with a scalar output, so <code class="language-plaintext highlighter-rouge">grad_logistic</code> cannot take an array of inputs and return the derivative at each of them. To get this behaviour, we can wrap our <code class="language-plaintext highlighter-rouge">grad_logistic</code> function with <code class="language-plaintext highlighter-rouge">vmap</code>, which automatically vectorises it for us.</p>
<p>This is already pretty neat. We can also obtain higher-order derivatives by the repeated application of <code class="language-plaintext highlighter-rouge">grad</code>:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">second_order_grad_logistic</span> <span class="o">=</span> <span class="n">vmap</span><span class="p">(</span><span class="n">grad</span><span class="p">(</span><span class="n">grad</span><span class="p">(</span><span class="n">logistic</span><span class="p">)))</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mf">10.0</span><span class="p">,</span> <span class="mf">10.0</span><span class="p">,</span> <span class="mi">1000</span><span class="p">)</span>
<span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">()</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">logistic</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">label</span><span class="o">=</span><span class="s">"$f(x)$"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">vectorised_grad_logistic</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">label</span><span class="o">=</span><span class="s">"$f'(x)$"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">second_order_grad_logistic</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">label</span><span class="o">=</span><span class="s">"$f''(x)$"</span><span class="p">)</span>
<span class="n">_</span> <span class="o">=</span> <span class="n">ax</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">_</span> <span class="o">=</span> <span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">"$x$"</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="/assets/images/jax_images/post_5_0.png" alt="png" /></p>
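<p>As a sanity check on these curves, the logistic function’s derivatives have simple closed forms: \(f'(x) = f(x)\,(1 - f(x))\), and differentiating again with the product rule gives \(f''(x) = f'(x)\,(1 - 2 f(x)) = f(x)\,(1 - f(x))\,(1 - 2 f(x))\). So \(f'\) peaks at \(0.25\) at \(x = 0\), where \(f''\) crosses zero. The sketch below, in plain python rather than JAX, compares these formulas with central finite differences at a few points.</p>

```python
import math

def logistic(x):
    return 1.0 / (1.0 + math.exp(-x))

def d1(x):
    # f'(x) = f(x) (1 - f(x))
    f = logistic(x)
    return f * (1.0 - f)

def d2(x):
    # f''(x) = f(x) (1 - f(x)) (1 - 2 f(x))
    f = logistic(x)
    return f * (1.0 - f) * (1.0 - 2.0 * f)

h = 1e-4  # finite-difference step
for x in [-5.0, -1.0, 0.0, 0.5, 3.0]:
    # central differences for the first and second derivatives
    fd1 = (logistic(x + h) - logistic(x - h)) / (2 * h)
    fd2 = (logistic(x + h) - 2 * logistic(x) + logistic(x - h)) / h ** 2
    print(f"x={x:+.1f}  f'={d1(x):.6f} (fd {fd1:.6f})  f''={d2(x):+.6f} (fd {fd2:+.6f})")
```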
<h2 id="jit-complilation">JIT compilation</h2>
<p>Before I demonstrate an application of JAX, I want to mention another useful feature: JIT compilation.
As you may know, python is an interpreted language rather than a compiled one.
This is one of the reasons that python code can run slower than the same logic in a compiled language (like C).
I won’t go into detail about why this is, but <a href="https://jakevdp.github.io/blog/2014/05/09/why-python-is-slow/">here’s a great blog post</a> by Jake VanderPlas on the subject.</p>
<p>One of the reasons why numpy is so useful is that it calls compiled C code under the hood.
This means that it can be much faster than code that uses native python lists.
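JAX adds an additional feature on top of this: Just In Time (JIT) compilation.
The “JIT” part means that the code is compiled at runtime: JAX compiles a function the first time it is called with inputs of a given shape and type, and reuses the compiled version on later calls.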
Using this feature can speed up our code.</p>
<p>To do this, we just have to apply the <code class="language-plaintext highlighter-rouge">jit</code> decorator to our function:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">jax</span> <span class="kn">import</span> <span class="n">jit</span>
<span class="o">@</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">jit_logistic</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="c1"># logistic function
</span>    <span class="k">return</span> <span class="p">(</span><span class="mf">1.</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">x</span><span class="p">))</span> <span class="o">**</span> <span class="o">-</span><span class="mf">1.0</span>
<span class="o">@</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">jit_grad_logistic</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="c1"># compile the gradient as well
</span>    <span class="k">return</span> <span class="n">grad</span><span class="p">(</span><span class="n">logistic</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
</code></pre></div></div>
<p>Now we can compare our JIT compiled functions to the ones we made earlier, and see if there’s any difference in execution time:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">%%</span><span class="n">timeit</span>
<span class="n">logistic</span><span class="p">(</span><span class="mf">0.0</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>376 µs ± 10.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">%%</span><span class="n">timeit</span>
<span class="n">jit_logistic</span><span class="p">(</span><span class="mf">0.0</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>90.4 µs ± 1.19 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">%%</span><span class="n">timeit</span>
<span class="n">grad_logistic</span><span class="p">(</span><span class="mf">0.0</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>1.76 ms ± 7.45 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">%%</span><span class="n">timeit</span>
<span class="n">jit_grad_logistic</span><span class="p">(</span><span class="mf">0.0</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>86.5 µs ± 1.48 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
</code></pre></div></div>
<p>For the plain logistic function, JIT led to a 4x speedup. For the gradient, there was a 20x speedup. These are non-trivial gains! For larger, array-valued computations, running JAX on a GPU could give even bigger speedups.</p>
<h2 id="putting-it-all-together-the-laplace-approximation">Putting it all together: the Laplace approximation</h2>
<p>I’ve barely scratched the surface of what JAX can do, but already there are interesting and useful applications with the functionality I have described here.
As an example, I’ll describe how to implement an important method in Bayesian statistics: <em>Laplace’s approximation</em>.</p>
<h3 id="the-laplace-approximation">The Laplace approximation</h3>
<p>Imagine we have some probability model with some parameters \(\theta\), and we’ve constrained the model using some data \(D\).
In Bayesian inference, our goal is always to calculate integrals like this:</p>
\[\mathbb{E}\left[h(\theta)\right] = \int \mathrm{d}\theta \, h(\theta) \, p(\theta | D)\]
<p>that is, we want the expectation of some function \(h(\theta)\) with respect to the <em>posterior distribution</em> \(p(\theta | D)\).
For most models of practical interest, the posterior is complex enough that we have no hope of calculating these integrals analytically.
Because of this, Bayesians have devised many methods for approximating them.
If you’ve got time, the best thing to do is use <a href="https://en.wikipedia.org/wiki/Markov_chain_Monte_Carlo">Markov Chain Monte Carlo</a>.
But, if your dataset is quite large relative to your time and computational budget, you may need to try something else. A typical choice is <a href="https://en.wikipedia.org/wiki/Variational_Bayesian_methods">Variational Inference</a>.</p>
<p>Another, possibly less talked-about, approach is called Laplace’s approximation.
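It tends to work well when you have a lot of data, because of the <em>Bayesian central limit theorem</em> (the Bernstein–von Mises theorem), which says that the posterior becomes approximately normal as the amount of data grows.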
In this approach, we approximate the posterior distribution by a Normal distribution.
This is a common approximation (it’s often used in Variational Inference too), but Laplace’s method has a specific way of finding the Normal distribution that best matches the posterior.</p>
<p>Suppose we know the location \(\theta^*\) of the maximum point of the posterior<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote">1</a></sup>.
Now let’s Taylor expand the log posterior around this point.
To reduce clutter, I’ll use the notation \(\log p(\theta | D) \equiv f(\theta)\).
For simplicity, let’s consider the case when \(\theta\) is scalar:</p>
\[f(\theta)
\approx
f(\theta^*)
+ \frac{\partial f}{\partial \theta}\bigg|_{\theta^*}\,(\theta - \theta^*)
+ \dfrac{1}{2}\frac{\partial^2 f}{\partial \theta^2}\bigg|_{\theta^*}\,(\theta - \theta^*)^2
\\
= f(\theta^*)
+ \dfrac{1}{2}\frac{\partial^2 f}{\partial \theta^2}\bigg|_{\theta^*}\,(\theta - \theta^*)^2\]
<p>The first derivative disappears because \(\theta^*\) is a maximum point, so the gradient there is zero.
Let’s compare this to the logarithm of a normal distribution with mean \(\mu\) and standard deviation \(\sigma\), which I’ll call \(g(\theta)\):</p>
\[g(\theta) = -\frac{1}{2}\log (2\pi\sigma^2) - \dfrac{1}{2}\dfrac{1}{\sigma^2}(\theta - \mu)^2\]
<p>We can match up the terms in the expressions for \(g(\theta)\) and the Taylor expansion of \(f(\theta)\) (ignoring the constant additive terms) to see that</p>
\[\mu = \theta^* \\
\sigma^2 = \left(-\dfrac{\partial^2 f}{\partial \theta^2}\bigg|_{\theta^*}\right)^{-1}\]
<p>Consequently, we might try approximating the posterior distribution with a Normal distribution, and set the mean and variance to these values.
In multiple dimensions, the covariance matrix of the resulting multivariate normal is the inverse of the Hessian matrix of the negative log posterior at \(\theta^*\):</p>
\[\Sigma = H^{-1}, \quad H_{ij} = -\dfrac{\partial^2 f}{\partial \theta_i \partial \theta_j}\bigg|_{\theta^*}\]
<p>Already, we can see that Laplace’s approximation requires us to be able to twice differentiate the log posterior in order to obtain \(\sigma\). In addition, we have to find the location \(\theta^*\) of the maximum of the posterior. We will usually have to do this numerically, which means using some kind of optimisation routine, and the most efficient of these routines use the gradient of the objective function. So, using Laplace’s approximation means we want to evaluate the <em>first and second derivatives of the log posterior</em>. Sounds like a job for JAX!</p>
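<p>Before bringing in JAX, here is a minimal pure-python sketch of the one-dimensional recipe: find \(\theta^*\) with Newton’s method, then set \(\sigma^2 = -1 / f''(\theta^*)\). Central finite differences stand in for the exact derivatives that JAX would provide, and the <code class="language-plaintext highlighter-rouge">laplace_1d</code> helper is a hypothetical name for illustration. As a test case it uses an unnormalised Gamma log density \(f(\theta) = (a - 1)\log\theta - b\theta\), chosen because its mode \((a - 1)/b\) and curvature are known analytically, giving \(\sigma^2 = (a - 1)/b^2\).</p>

```python
import math

def laplace_1d(log_post, theta0, h=1e-4, tol=1e-10, max_iter=100):
    """Laplace approximation for a scalar parameter.

    Finds the mode theta* by Newton's method, using central finite
    differences for f' and f'', then returns (theta*, sigma^2) with
    sigma^2 = -1 / f''(theta*).
    """
    def d1(t):
        return (log_post(t + h) - log_post(t - h)) / (2 * h)

    def d2(t):
        return (log_post(t + h) - 2 * log_post(t) + log_post(t - h)) / h ** 2

    theta = theta0
    for _ in range(max_iter):
        step = d1(theta) / d2(theta)  # Newton step on f'(theta) = 0
        theta -= step
        if abs(step) < tol:
            break
    return theta, -1.0 / d2(theta)

# Hypothetical example: unnormalised Gamma(a, b) log density
a, b = 5.0, 2.0

def log_post(t):
    return (a - 1.0) * math.log(t) - b * t

mode, var = laplace_1d(log_post, theta0=1.0)
# Should be close to the analytic values (a - 1) / b = 2 and (a - 1) / b**2 = 1
print(mode, var)
```

<p>Newton’s method is a natural choice here because it needs exactly the two derivatives that Laplace’s approximation already requires.</p>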
<h2 id="example-a-student-t-posterior-distribution">Example: a Student-t posterior distribution</h2>
<p>Suppose our true posterior is a 2D Student-t:</p>
\[p(\theta | D)
\propto
\left(1+\frac{1}{\nu}(\theta - \mu)^T \Sigma^{-1}(\theta - \mu)\right)^{-(\nu + \mathrm{dim}(\theta))/2}\]
<p>This is a simple example, and we can actually sample from a Student-t rather easily.
Nevertheless, let’s go ahead and use it to implement Laplace’s method in JAX.
Let’s set the values of the constants in the Student-t:</p>
\[\mu = \begin{pmatrix}
0.5 \\
2
\end{pmatrix}
\\
\Sigma = \begin{pmatrix}
1 & 0.5 \\
0.5 & 1
\end{pmatrix}
\\
\nu = 7\]
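<p>For this particular posterior we can also work out the answer by hand, which gives a useful cross-check. The log density depends on \(\theta\) only through \(q = (\theta - \mu)^T \Sigma^{-1} (\theta - \mu)\), so the mode is \(\theta^* = \mu\). Near the mode \(\log(1 + q/\nu) \approx q/\nu\), so the Hessian of the negative log posterior there is \(\frac{\nu + \mathrm{dim}(\theta)}{\nu}\Sigma^{-1}\), and Laplace’s method should return a normal with mean \(\mu\) and covariance \(\frac{\nu}{\nu + \mathrm{dim}(\theta)}\Sigma = \frac{7}{9}\Sigma\). The sketch below checks this in plain python, with a finite-difference Hessian standing in for JAX. The helper names are hypothetical.</p>

```python
import math

# Constants from the text; d = dim(theta) = 2
mu = (0.5, 2.0)
sigma = ((1.0, 0.5), (0.5, 1.0))
nu, d = 7.0, 2

def inv2(m):
    """Inverse of a 2x2 matrix given as nested sequences."""
    (p, q), (r, s) = m
    det = p * s - q * r
    return ((s / det, -q / det), (-r / det, p / det))

sigma_inv = inv2(sigma)

def log_posterior(theta):
    # unnormalised log density of the 2D Student-t
    dx = (theta[0] - mu[0], theta[1] - mu[1])
    quad = sum(dx[i] * sigma_inv[i][j] * dx[j] for i in range(2) for j in range(2))
    return -0.5 * (nu + d) * math.log1p(quad / nu)

def neg_hessian(f, x, h=1e-4):
    """Central finite-difference Hessian of -f at x (2D only)."""
    def shifted(i, di, j, dj):
        y = list(x)
        y[i] += di
        y[j] += dj
        return f(y)
    H = [[0.0, 0.0], [0.0, 0.0]]
    for i in range(2):
        for j in range(2):
            H[i][j] = -(shifted(i, h, j, h) - shifted(i, h, j, -h)
                        - shifted(i, -h, j, h) + shifted(i, -h, j, -h)) / (4 * h * h)
    return H

# Laplace covariance: inverse Hessian of -log posterior at the mode mu
cov_laplace = inv2(neg_hessian(log_posterior, mu))
expected = tuple(tuple(nu / (nu + d) * entry for entry in row) for row in sigma)
print(cov_laplace)
print(expected)
```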
<p>First, let’s plot the log posterior:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># choose some values for the Student-t
</span><span class="n">sigma</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([(</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">),</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">)])</span>
<span class="n">mu</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">])</span>
<span class="n">nu</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span><span class="mi">7</span><span class="p">])</span>
<span class="n">sigma_inv</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">inv</span><span class="p">(</span><span class="n">sigma</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">log_posterior</span><span class="p">(</span><span class="n">theta</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">(</span>
        <span class="mf">1.0</span> <span class="o">+</span> <span class="n">nu</span> <span class="o">**</span> <span class="o">-</span><span class="mf">1.0</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">dot</span><span class="p">((</span><span class="n">theta</span> <span class="o">-</span> <span class="n">mu</span><span class="p">),</span> <span class="n">np</span><span class="p">.</span><span class="n">dot</span><span class="p">(</span><span class="n">sigma_inv</span><span class="p">,</span> <span class="p">(</span><span class="n">theta</span> <span class="o">-</span> <span class="n">mu</span><span class="p">).</span><span class="n">T</span><span class="p">).</span><span class="n">T</span><span class="p">)</span>
    <span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="mf">0.5</span> <span class="o">*</span> <span class="o">-</span><span class="p">(</span><span class="n">nu</span> <span class="o">+</span> <span class="n">theta</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span>
<span class="c1"># plot the distribution
</span><span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span>
<span class="n">X</span><span class="p">,</span> <span class="n">Y</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">meshgrid</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
<span class="n">XY</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">stack</span><span class="p">((</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">)).</span><span class="n">reshape</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">10000</span><span class="p">).</span><span class="n">T</span>
<span class="n">Z</span> <span class="o">=</span> <span class="n">vmap</span><span class="p">(</span><span class="n">log_posterior</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="mi">0</span><span class="p">)(</span><span class="n">XY</span><span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span>
<span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">()</span>
<span class="n">ax</span><span class="p">.</span><span class="n">contourf</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">Z</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">r"$\theta_0$"</span><span class="p">)</span>
<span class="n">_</span> <span class="o">=</span> <span class="n">ax</span><span class="p">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">r"$\theta_1$"</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="/assets/images/jax_images/post_13_0.png" alt="png" /></p>
<p>Now let’s implement Laplace’s method in JAX:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">jax</span> <span class="kn">import</span> <span class="n">hessian</span>
<span class="kn">from</span> <span class="nn">scipy.optimize</span> <span class="kn">import</span> <span class="n">minimize</span>
<span class="kn">from</span> <span class="nn">scipy.stats</span> <span class="kn">import</span> <span class="n">multivariate_normal</span>
<span class="o">@</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">negative_log_posterior</span><span class="p">(</span><span class="n">theta</span><span class="p">):</span>
    <span class="c1"># negative log posterior to minimise
</span>    <span class="k">return</span> <span class="p">(</span><span class="o">-</span><span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">(</span>
        <span class="mf">1.0</span> <span class="o">+</span> <span class="n">nu</span> <span class="o">**</span> <span class="o">-</span><span class="mf">1.0</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">dot</span><span class="p">((</span><span class="n">theta</span> <span class="o">-</span> <span class="n">mu</span><span class="p">),</span> <span class="n">np</span><span class="p">.</span><span class="n">dot</span><span class="p">(</span><span class="n">sigma_inv</span><span class="p">,</span> <span class="p">(</span><span class="n">theta</span> <span class="o">-</span> <span class="n">mu</span><span class="p">).</span><span class="n">T</span><span class="p">).</span><span class="n">T</span><span class="p">)</span>
    <span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="mf">0.5</span> <span class="o">*</span> <span class="o">-</span><span class="p">(</span><span class="n">nu</span> <span class="o">+</span> <span class="n">theta</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])))[</span><span class="mi">0</span><span class="p">]</span>
<span class="o">@</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">grad_negative_log_posterior</span><span class="p">(</span><span class="n">theta</span><span class="p">):</span>
    <span class="c1"># gradient of the negative log posterior
</span>    <span class="k">return</span> <span class="n">grad</span><span class="p">(</span><span class="n">negative_log_posterior</span><span class="p">)(</span><span class="n">theta</span><span class="p">)</span>
<span class="o">@</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">approx_covariance_matrix</span><span class="p">(</span><span class="n">theta</span><span class="p">):</span>
    <span class="c1"># evaluate the covariance matrix of the approximate normal
</span>    <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">inv</span><span class="p">(</span><span class="n">hessian</span><span class="p">(</span><span class="n">negative_log_posterior</span><span class="p">)(</span><span class="n">theta</span><span class="p">))</span>
<span class="c1"># go!
</span><span class="n">theta_star</span> <span class="o">=</span> <span class="n">minimize</span><span class="p">(</span>
    <span class="n">negative_log_posterior</span><span class="p">,</span>
    <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">]),</span>
    <span class="n">jac</span><span class="o">=</span><span class="n">grad_negative_log_posterior</span><span class="p">,</span>
    <span class="n">method</span><span class="o">=</span><span class="s">"BFGS"</span>
<span class="p">).</span><span class="n">x</span>
<span class="n">sigma_approx</span> <span class="o">=</span> <span class="n">approx_covariance_matrix</span><span class="p">(</span><span class="n">theta_star</span><span class="p">)</span>
</code></pre></div></div>
<p>This is a <em>very</em> short piece of code! I had to define the negative log posterior (and JIT-compiled it for speed), since we will minimise it to find \(\theta^*\). Then, I used JAX’s <code class="language-plaintext highlighter-rouge">grad</code> function to differentiate it once, so that we can use a gradient-based optimiser. Next, I used JAX’s <code class="language-plaintext highlighter-rouge">hessian</code> function to compute the matrix of second derivatives of the negative log posterior; its inverse is the covariance matrix of our approximating normal.
Finally, I used scipy’s <code class="language-plaintext highlighter-rouge">minimize</code> function (with the BFGS method and the JAX gradient passed as <code class="language-plaintext highlighter-rouge">jac</code>) to find the optimal point \(\theta^*\), and evaluated the covariance matrix there.</p>
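<p>The steps above can be packaged into one reusable function. Here’s a minimal sketch of that idea; it is my own illustration rather than code from the original post. The helper name <code class="language-plaintext highlighter-rouge">laplace_approximation</code> and the toy posterior are hypothetical, I write <code class="language-plaintext highlighter-rouge">jnp</code> for <code class="language-plaintext highlighter-rouge">jax.numpy</code> (rather than the post’s <code class="language-plaintext highlighter-rouge">np</code>) so that it can sit alongside plain numpy, and I assume a JAX version where <code class="language-plaintext highlighter-rouge">jax.config.update("jax_enable_x64", True)</code> turns on double precision, which helps scipy’s BFGS see accurate gradients.</p>

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize

# double precision so that scipy's BFGS sees accurate function values and gradients
jax.config.update("jax_enable_x64", True)


def laplace_approximation(negative_log_posterior, theta0):
    """Hypothetical helper: normal approximation to a posterior via Laplace's method.

    negative_log_posterior: scalar function of a 1D parameter array, written
    with jax.numpy so that JAX can differentiate it.
    Returns the mode theta_star and the approximate covariance matrix.
    """
    f = jax.jit(negative_log_posterior)
    grad_f = jax.jit(jax.grad(negative_log_posterior))
    hess_f = jax.jit(jax.hessian(negative_log_posterior))

    # locate the posterior mode theta* with a gradient-based optimiser
    result = minimize(
        lambda th: float(f(th)),
        np.asarray(theta0, dtype=float),
        jac=lambda th: np.asarray(grad_f(th)),
        method="BFGS",
    )
    theta_star = result.x

    # covariance = inverse Hessian of the negative log posterior at theta*
    covariance = np.linalg.inv(np.asarray(hess_f(theta_star)))
    return theta_star, covariance


# hypothetical toy posterior: a 1D normal with mean 3 and standard deviation 2,
# for which Laplace's approximation is exact
def toy_negative_log_posterior(theta):
    return 0.5 * jnp.sum(((theta - 3.0) / 2.0) ** 2)


theta_star, covariance = laplace_approximation(toy_negative_log_posterior, jnp.zeros(1))
print(theta_star, covariance)
```

<p>Because Laplace’s approximation is exact for a normal posterior, the toy call should recover a mode close to 3 and a variance close to 4, which makes it a cheap smoke test before pointing the function at a real model.</p>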
<p>Note that this code is actually rather general! As long as the function <code class="language-plaintext highlighter-rouge">negative_log_posterior</code> is written with operations that JAX can differentiate (for example, <code class="language-plaintext highlighter-rouge">jax.numpy</code> functions), the rest of the code stays exactly the same!
Let’s have a look at how good our normal approximation is:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">scipy.stats</span> <span class="kn">import</span> <span class="n">norm</span>
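<span class="kn">from</span> <span class="nn">scipy.stats</span> <span class="kn">import</span> <span class="n">t</span>
<span class="c1"># grid of parameter values for the 1D marginal densities
</span><span class="n">theta_grid</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mf">10.0</span><span class="p">,</span> <span class="mf">10.0</span><span class="p">,</span> <span class="mi">1000</span><span class="p">)</span>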
<span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">constrained_layout</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">15</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
<span class="n">spec</span> <span class="o">=</span> <span class="n">fig</span><span class="p">.</span><span class="n">add_gridspec</span><span class="p">(</span><span class="n">ncols</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">nrows</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">fig</span><span class="p">.</span><span class="n">subplots_adjust</span><span class="p">(</span><span class="n">hspace</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">wspace</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">ax3</span> <span class="o">=</span> <span class="n">fig</span><span class="p">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="n">spec</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">])</span>
<span class="n">ax2</span> <span class="o">=</span> <span class="n">fig</span><span class="p">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="n">spec</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
<span class="n">ax1</span> <span class="o">=</span> <span class="n">fig</span><span class="p">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="n">spec</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">])</span>
<span class="n">contour</span> <span class="o">=</span> <span class="n">ax1</span><span class="p">.</span><span class="n">contour</span><span class="p">(</span>
    <span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">Z</span> <span class="o">/</span> <span class="n">Z</span><span class="p">.</span><span class="nb">max</span><span class="p">(),</span> <span class="n">colors</span><span class="o">=</span><span class="s">"0.4"</span><span class="p">,</span> <span class="n">levels</span><span class="o">=</span><span class="mi">15</span><span class="p">,</span> <span class="n">linestyles</span><span class="o">=</span><span class="s">"-"</span><span class="p">,</span> <span class="n">linewidths</span><span class="o">=</span><span class="mi">3</span>
<span class="p">)</span>
<span class="c1"># calculate the density of the approximating Normal distribution
</span><span class="n">Z_0</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">multivariate_normal</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="n">theta_star</span><span class="p">,</span> <span class="n">cov</span><span class="o">=</span><span class="n">sigma_approx</span><span class="p">).</span><span class="n">logpdf</span><span class="p">(</span><span class="n">XY</span><span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span>
<span class="p">)</span>
<span class="n">ax1</span><span class="p">.</span><span class="n">contour</span><span class="p">(</span>
    <span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">Z_0</span> <span class="o">/</span> <span class="n">Z_0</span><span class="p">.</span><span class="nb">max</span><span class="p">(),</span> <span class="n">colors</span><span class="o">=</span><span class="s">"#2c7fb8"</span><span class="p">,</span> <span class="n">levels</span><span class="o">=</span><span class="mi">15</span><span class="p">,</span> <span class="n">linestyles</span><span class="o">=</span><span class="s">"--"</span><span class="p">,</span> <span class="n">linewidths</span><span class="o">=</span><span class="mi">3</span>
<span class="p">)</span>
<span class="n">ax1</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">r"$\theta_0$"</span><span class="p">)</span>
<span class="n">ax1</span><span class="p">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">r"$\theta_1$"</span><span class="p">)</span>
<span class="n">ax2</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span>
    <span class="n">norm</span><span class="p">.</span><span class="n">pdf</span><span class="p">(</span><span class="n">theta_grid</span><span class="p">,</span> <span class="n">theta_star</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">sigma_approx</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])),</span>
    <span class="n">theta_grid</span><span class="p">,</span>
    <span class="n">c</span><span class="o">=</span><span class="s">"#2c7fb8"</span><span class="p">,</span>
    <span class="n">ls</span><span class="o">=</span><span class="s">"--"</span><span class="p">,</span>
    <span class="n">lw</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">ax2</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span>
    <span class="n">t</span><span class="p">.</span><span class="n">pdf</span><span class="p">(</span><span class="n">theta_grid</span><span class="p">,</span> <span class="n">nu</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">mu</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">sigma</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])),</span> <span class="n">theta_grid</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s">"0.4"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">3</span>
<span class="p">)</span>
<span class="n">ax3</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span>
    <span class="n">theta_grid</span><span class="p">,</span>
    <span class="n">norm</span><span class="p">.</span><span class="n">pdf</span><span class="p">(</span><span class="n">theta_grid</span><span class="p">,</span> <span class="n">theta_star</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">sigma_approx</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">])),</span>
    <span class="n">c</span><span class="o">=</span><span class="s">"#2c7fb8"</span><span class="p">,</span>
    <span class="n">ls</span><span class="o">=</span><span class="s">"--"</span><span class="p">,</span>
    <span class="n">lw</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span>
    <span class="n">label</span><span class="o">=</span><span class="s">"Laplace"</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">ax3</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span>
    <span class="n">theta_grid</span><span class="p">,</span>
    <span class="n">t</span><span class="p">.</span><span class="n">pdf</span><span class="p">(</span><span class="n">theta_grid</span><span class="p">,</span> <span class="n">nu</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">mu</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">sigma</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">])),</span>
    <span class="n">c</span><span class="o">=</span><span class="s">"0.4"</span><span class="p">,</span>
    <span class="n">lw</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span>
    <span class="n">label</span><span class="o">=</span><span class="s">"Exact"</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">ax3</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">ax2</span><span class="p">.</span><span class="n">xaxis</span><span class="p">.</span><span class="n">set_visible</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span>
<span class="n">ax3</span><span class="p">.</span><span class="n">yaxis</span><span class="p">.</span><span class="n">set_visible</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="/assets/images/jax_images/post_17_0.png" alt="png" /></p>
<p>At least by eye, the approximation seems reasonable. Of course, I have rather cheated here since a Student-t approaches a normal distribution as \(\nu \rightarrow \infty\). Nonetheless, it’s still pleasing to see that the numerical implementation with JAX and scipy works as expected.</p>
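<p>We can also check the numbers analytically. For this Student-t, the negative log posterior is \(\frac{\nu + d}{2}\log(1 + q/\nu)\), with \(q = (\theta - \mu)^T\Sigma^{-1}(\theta - \mu)\) and \(d = \mathrm{dim}(\theta)\). At the mode \(\theta^* = \mu\) the gradient of \(q\) vanishes, so the Hessian is \(\frac{\nu + d}{\nu}\Sigma^{-1}\) and Laplace’s covariance is \(\frac{\nu}{\nu + d}\Sigma = \frac{7}{9}\Sigma\). The true covariance of a Student-t is \(\frac{\nu}{\nu - 2}\Sigma = \frac{7}{5}\Sigma\), so the normal approximation is somewhat too narrow. Here’s a minimal standalone sketch (a re-implementation with double precision enabled, not the code used for the figure) that confirms the Hessian-based covariance matches the closed form.</p>

```python
import jax
import jax.numpy as jnp
import numpy as np

# double precision for a tight comparison with the closed-form result
jax.config.update("jax_enable_x64", True)

# same constants as the Student-t example above
mu = jnp.array([0.5, 2.0])
sigma = jnp.array([[1.0, 0.5], [0.5, 1.0]])
nu = 7.0
d = mu.shape[0]
sigma_inv = jnp.linalg.inv(sigma)


def negative_log_posterior(theta):
    # unnormalised negative log density of the 2D Student-t
    q = (theta - mu) @ sigma_inv @ (theta - mu)
    return 0.5 * (nu + d) * jnp.log1p(q / nu)


# the mode of a Student-t is its location mu, so take the Hessian there
laplace_cov = np.linalg.inv(np.asarray(jax.hessian(negative_log_posterior)(mu)))

# closed forms: Laplace gives nu / (nu + d) * Sigma,
# while the Student-t's true covariance is nu / (nu - 2) * Sigma
analytic_laplace_cov = nu / (nu + d) * np.asarray(sigma)
true_cov = nu / (nu - 2.0) * np.asarray(sigma)

print(laplace_cov)
print(analytic_laplace_cov)
print(true_cov)
```

<p>The marginal standard deviations therefore come out at roughly \(\sqrt{7/9} \approx 0.88\) times the Student-t scale for Laplace, compared with \(\sqrt{7/5} \approx 1.18\) times the scale for the exact distribution.</p>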
<h2 id="conclusion">Conclusion</h2>
<p>Hopefully this post has inspired you to go and play with JAX yourself. There are a ton of interesting applications that I can imagine for this library. Some already exist, such as the <a href="https://github.com/pyro-ppl/numpyro">numpyro</a> library from Uber, which uses JAX under the hood to perform fast Hamiltonian Monte Carlo. In addition, it’ll be interesting to see how this library is adopted compared with other popular autodiff frameworks like TensorFlow and PyTorch.</p>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:1" role="doc-endnote">
<p>I’m assuming there’s only one maximum, but in reality there might be several if the posterior is multimodal. Multimodality is a pain, and Laplace’s approximation won’t do as well in this case (in fact most methods in Bayesian statistics share this weakness). <a href="#fnref:1" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>Angus WilliamsIntroduction Google recently released an interesting new library called JAX. It looks like it could be very useful for computational statistics, so I thought I’d take a look. In this post, I’ll describe some of the basic features of JAX, and then use them to implement Laplace’s approximation, a method used in Bayesian statistics. JAX is described as “Autograd and XLA, brought together”. Autograd refers to automatic differentiation, and XLA stands for accelerated linear algebra. Both of these are very useful in computational statistics, where we often have to differentiate things and perform lots of matrix (or tensor) manipulations. JAX effectively extends the numpy library to include these extra components. In doing so, it preserves an API that many scientists are familiar with, and introduces powerful new functionality. Automatic differentiation Automatic differentiation (or autodiff, for short) is a set of methods for evaluating the derivative of functions using a computer. If you’ve taken a calculus course before, you will have taken the derivative of functions by hand, e.g. \[f(x) = x ^ 2 \implies f'(x) = 2 x\] For the simple function above, differentiating is quite straightforward. But, when the function is complicated and has many arguments (for example, the objective function of a neural network), differentiating by hand quickly becomes unfeasible. Autodiff frameworks save us the trouble: we simply pass a function to the framework, and it returns another function that computes the gradient. This is really useful in computational statistics and machine learning, since we often want to optimise functions with respect to their arguments. Generally, we can carry out optimisation much more efficiently if we know gradients. It’s also really useful in Bayesian statistics, where the state-of-the-art Markov Chain Monte Carlo method, Hamiltonian Monte Carlo, relies on the calculation of the gradient of the posterior distribution. 
So, how do we do this in JAX? Here’s a snippet that defines the logistic function in python, and then uses JAX to compute its derivative: import jax.numpy as np from jax import grad def logistic(x): # logistic function return (1. + np.exp(-x)) ** -1.0 # differentiate with jax! grad_logistic = grad(logistic) print(logistic(0.0)) print(grad_logistic(0.0)) 0.5 0.25 That was easy! There are two lines to focus on in this snippet: import jax.numpy as np: instead of importing regular numpy, I imported jax.numpy which is JAX’s implementation of numpy functionality. After this line, we can pretty much forget about it and pretend that we’re using regular numpy most of the time. grad_logistic = grad(logistic): this is where the magic happens. We passed the logistic function to JAX’s grad function, and it returned another function, which I called grad_logistic. This function takes the same inputs as logistic, but returns the gradient with respect to these inputs. To convince ourselves that this all worked, let’s plot the logistic function and its derivative: from jax import vmap import matplotlib.pyplot as plt plt.rcParams["figure.figsize"] = (10, 5) plt.rcParams["font.size"] = 20 vectorised_grad_logistic = vmap(grad_logistic) x = np.linspace(-10.0, 10.0, 1000) fig, ax = plt.subplots() ax.plot(x, logistic(x), label="$f(x)$") ax.plot(x, vectorised_grad_logistic(x), label="$f'(x)$") _ = ax.legend() _ = ax.set_xlabel("$x$") You’ll notice that I had to define another function: vectorised_grad_logistic in order to make the plot. The reason is that functions produced by grad are not vectorised (cannot accept multiple inputs and return the gradient at each of the inputs). To facilitate this, we can wrap our grad_logistic function with vmap, which automatically vectorises it for us. This is already pretty neat. 
We can also obtain higher-order derivatives by the repeated application of grad: second_order_grad_logistic = vmap(grad(grad(logistic))) x = np.linspace(-10.0, 10.0, 1000) fig, ax = plt.subplots() ax.plot(x, logistic(x), label="$f(x)$") ax.plot(x, vectorised_grad_logistic(x), label="$f'(x)$") ax.plot(x, second_order_grad_logistic(x), label="$f''(x)$") _ = ax.legend() _ = ax.set_xlabel("$x$") JIT complilation Before I demonstrate an application of JAX, I want to mention another useful feature: JIT compilation. As you may know, python is an interpreted language, rather than being compiled. This is one of the reasons that python code can run slower than the same logic in a compiled language (like C). I won’t go into detail about why this is, but here’s a great blog post by Jake VanderPlas on the subject. One of the reasons why numpy is so useful is that it is calling C code under the hood, which is compiled. This means that it can be much faster than code that uses native python arrays. JAX adds an additional feature on top of this: Just In Time (JIT) compilation. The “JIT” part means that the code is compiled at runtime the first time that it is needed. Using this feature can speed up our code. To do this, we just have to apply the jit decorator to our function: from jax import jit @jit def jit_logistic(x): # logistic function return (1. + np.exp(-x)) ** -1.0 @jit def jit_grad_logistic(x): # compile the gradient as well return grad(logistic)(x) Now we can compare our JIT compiled functions to the ones we made earlier, and see if there’s any difference in execution time: %%timeit logistic(0.0) 376 µs ± 10.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) %%timeit jit_logistic(0.0) 90.4 µs ± 1.19 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) %%timeit grad_logistic(0.0) 1.76 ms ± 7.45 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) %%timeit jit_grad_logistic(0.0) 86.5 µs ± 1.48 µs per loop (mean ± std. dev. 
of 7 runs, 10000 loops each) For the plain logistic function, JIT led to a 4x speedup. For the gradient, there was a 20x speedup. These are non-trivial gains! I could have obtained even bigger speedups if I was using JAX in combination with a GPU. Putting it all together: the Laplace approximation I’ve barely scratched the surface of what JAX can do, but already there are interesting and useful applications with the functionality I have described here. As an example, I’ll describe how to implement an important method in Bayesian statistics: Laplace’s approximation. The Laplace approximation Imagine we have some probability model with some parameters \(\theta\), and we’ve constrained the model using some data \(D\). In Bayesian inference, our goal is always to calculate integrals like this: \[\mathbb{E}\left[h(\theta)\right] = \int \mathrm{d}\theta \, h(\theta) \, p(\theta | D)\] we are interested in the expectation of some function \(h(\theta)\) with respect to the posterior distribution \(p(\theta | D)\). For interesting models, the posterior is complex, and so we have no hope of calculating these integrals analytically. Because of this, Bayesians have devised many methods for approximating them. If you’ve got time, the best thing to do is use Markov Chain Monte Carlo. But, if your dataset is quite large relative to your time and computational budget, you may need to try something else. A typical choice is Variational Inference. Another, possibly less talked-about, approach is called Laplace’s approximation. It works really well when you have quite a lot of data because of the Bayesian central limit theorem. In this approach, we approximate the posterior distribution by a Normal distribution. This is a common approximation (it’s often used in Variational Inference too), but Laplace’s method has a specific way of finding the Normal distribution that best matches the posterior. Suppose we know the location \(\theta^*\) of the maximum point of the posterior1. 
Now let’s Taylor expand the log posterior around this point. To reduce clutter, I’ll use the notation \(\log p(\theta | D) \equiv f(\theta)\). For simplicity, let’s consider the case when \(\theta\) is scalar: \[f(\theta) \approx f(\theta^*) + \frac{\partial f}{\partial \theta}\bigg|_{\theta^*}\,(\theta - \theta^*) + \dfrac{1}{2}\frac{\partial^2 f}{\partial \theta^2}\bigg|_{\theta^*}\,(\theta - \theta^*)^2 \\ = f(\theta^*) + \dfrac{1}{2}\frac{\partial^2 f}{\partial \theta^2}\bigg|_{\theta^*}\,(\theta - \theta^*)^2\] The first derivative disappears because \(\theta^*\) is a maximum point, so the gradient there is zero. Let’s compare this to the logarithm of a normal distribution with mean \(\mu\) and standard deviation \(\sigma\), which I’ll call \(g(\theta)\): \[g(\theta) = -\frac{1}{2}\log (2\pi\sigma^2) - \dfrac{1}{2}\dfrac{1}{\sigma^2}(\theta - \mu)^2\] We can match up the terms in the expressions for \(g(\theta)\) and the Taylor expansion of \(f(\theta)\) (ignoring the constant additive terms) to see that \[\mu = \theta^* \\ \sigma^2 = \left(-\dfrac{\partial^2 f}{\partial \theta^2}\bigg|_{\theta^*}\right)^{-1}\] Consequently, we might try approximating the posterior distribution with a Normal distribution, and set the mean and variance to these values. In multiple dimensions, the covariance matrix of the resulting multivariate normal is the inverse of the Hessian matrix of the negative log posterior at \(\theta^*\): \[\Sigma_{ij} = \dfrac{\partial ^2 (-f)}{\partial \theta_i \partial \theta_j}^{-1}\bigg|_{\theta^*}\] Already, we can see that Laplace’s approximation requires us to be able to twice differentiate the posterior distribution in order to obtain \(\sigma\). In addition, we have to find the location \(\theta^*\) of the maximum of the posterior. We probably have to do this numerically, which means using some kind of optimisation routine. The most efficient of these optimisation routines require the gradient of the objective function. 
So, using Laplace’s approximation means we want to evaluate the first and second derivatives of the posterior. Sounds like a job for JAX! Example: a Student-t posterior distribution Suppose our true posterior is a 2D Student-t: \[p(\theta | D) \propto \left(1+\frac{1}{\nu}(\theta - \mu)^T \Sigma^{-1}(\theta - \mu)\right)^{-(\nu + \mathrm{dim}(\theta))/2}\] This is a simple example, and we can actually sample from a Student-t rather easily. Nevertheless, let’s go ahead and use it to implement Laplace’s method in JAX. Let’s set the values of the constants in the Student-t: \[\mu = \begin{pmatrix} 0.5 \\ 2 \end{pmatrix} \\ \Sigma = \begin{pmatrix} 1 & 0.5 \\ 0.5 & 1 \end{pmatrix} \\ \nu = 7\] First, let’s plot the log posterior: # choose some values for the Student-t sigma = np.array([(1.0, 0.5), (0.5, 1.0)]) mu = np.array([0.5, 2.0]) nu = np.array([7]) sigma_inv = np.linalg.inv(sigma) def log_posterior(theta): return np.log( 1.0 + nu ** -1.0 * np.dot((theta - mu), np.dot(sigma_inv, (theta - mu).T).T) ) * (0.5 * -(nu + theta.shape[0])) # plot the distribution x = np.linspace(-10, 10, 100) y = np.linspace(-10, 10, 100) X, Y = np.meshgrid(x, y) XY = np.stack((X, Y)).reshape(2, 10000).T Z = vmap(log_posterior, in_axes=0)(XY).reshape(100, 100) fig, ax = plt.subplots() ax.contourf(X, Y, Z) ax.set_xlabel(r"$\theta_0$") _ = ax.set_ylabel(r"$\theta_1$") Now let’s implement Laplace’s method in JAX: from jax import hessian from scipy.optimize import minimize from scipy.stats import multivariate_normal @jit def negative_log_posterior(theta): # negative log posterior to minimise return (-np.log( 1.0 + nu ** -1.0 * np.dot((theta - mu), np.dot(sigma_inv, (theta - mu).T).T) ) * (0.5 * -(nu + theta.shape[0])))[0] @jit def grad_negative_log_posterior(theta): # gradient of the negative log posterior return grad(negative_log_posterior)(theta) @jit def approx_covariance_matrix(theta): # evaluate the covariance matrix of the approximate normal return 
np.linalg.inv(hessian(negative_log_posterior)(theta)) # go! theta_star = minimize( negative_log_posterior, np.array([0.0, 0.0]), jac=grad_negative_log_posterior, method="BFGS" ).x sigma_approx = approx_covariance_matrix(theta_star) This is a very short piece of code! I had to define the negative log posterior (and JIT compiled it for speed), since we will minimise this to find \(\theta^*\). Then, I used JAX’s grad function to differentiate this once, so that we can used a gradient-based optimiser. Next, I used JAX’s hessian function to find the covariance matrix for our approximating normal. Finally, I used scipy’s minimize function to find the optimal point \(\theta^*\). Note that this code is actually rather general! As long as the function negative_log_posterior can be implemented in a way that JAX can differentiate (which it probably can), then the rest of the code stays exactly the same! Let’s have a look at how good our normal approximation is: from scipy.stats import norm from scipy.stats import t fig = plt.figure(constrained_layout=True, figsize=(15, 10)) spec = fig.add_gridspec(ncols=2, nrows=2) fig.subplots_adjust(hspace=0, wspace=0) ax3 = fig.add_subplot(spec[0, 0]) ax2 = fig.add_subplot(spec[1, 1]) ax1 = fig.add_subplot(spec[1, 0]) contour = ax1.contour( X, Y, Z / Z.max(), colors="0.4", levels=15, linestyles="-", linewidths=3 ) # calculate the density of the approximating Normal distribution Z_0 = ( multivariate_normal(mean=theta_star, cov=sigma_approx).logpdf(XY).reshape(100, 100) ) ax1.contour( X, Y, Z_0 / Z_0.max(), colors="#2c7fb8", levels=15, linestyles="--", linewidths=3 ) ax1.set_xlabel(r"$\theta_0$") ax1.set_ylabel(r"$\theta_1$") ax2.plot( norm.pdf(theta_grid, theta_star[1], np.sqrt(sigma_approx[1, 1])), theta_grid, c="#2c7fb8", ls="--", lw=3, ) ax2.plot( t.pdf(theta_grid, nu[1], mu[1], np.sqrt(sigma[1, 1])), theta_grid, c="0.4", lw=3 ) ax3.plot( theta_grid, norm.pdf(theta_grid, theta_star[0], np.sqrt(sigma_approx[0, 0])), c="#2c7fb8", ls="--", 
lw=3, label="Laplace", ) ax3.plot( theta_grid, t.pdf(theta_grid, nu[0], mu[0], np.sqrt(sigma[0, 0])), c="0.4", lw=3, label="Exact", ) ax3.legend() ax2.xaxis.set_visible(False) ax3.yaxis.set_visible(False) At least by eye, the approximation seems reasonable. Of course, I have rather cheated here since a Student-t approaches a normal distribution as \(\nu \rightarrow \infty\). Nonetheless, it’s still pleasing to see that the numerical implementation with JAX and scipy works as expected. Conclusion Hopefully this post has inspired you to go and play with JAX yourself. There are a ton of interesting applications that I can imagine for this library. Some already exist, such as the NumPyro library from Uber, which uses JAX under the hood to perform fast Hamiltonian Monte Carlo. In addition, it’ll be interesting to see how this library is adopted as compared with other popular autodiff frameworks like TensorFlow and PyTorch. I’m assuming there’s only one maximum, but in reality there might be several if the posterior is multimodal. Multimodality is a pain, and Laplace’s approximation won’t do as well in this case (in fact most methods in Bayesian statistics share this weakness). ↩World cup redux2019-10-19T00:00:00+00:002019-10-19T00:00:00+00:00https://anguswilliams91.github.io/statistics/sport/world-cup-redux<p>Now that the group stages are over, here are the probabilities that the remaining teams will win the world cup:</p>
<table>
<thead>
<tr>
<th>Team</th>
<th>Probability of winning the world cup</th>
</tr>
</thead>
<tbody>
<tr>
<td>New Zealand</td>
<td>0.428</td>
</tr>
<tr>
<td>South Africa</td>
<td>0.206</td>
</tr>
<tr>
<td>England</td>
<td>0.165</td>
</tr>
<tr>
<td>Wales</td>
<td>0.082</td>
</tr>
<tr>
<td>Ireland</td>
<td>0.057</td>
</tr>
<tr>
<td>Australia</td>
<td>0.038</td>
</tr>
<tr>
<td>France</td>
<td>0.022</td>
</tr>
<tr>
<td>Japan</td>
<td>0.001</td>
</tr>
</tbody>
</table>
<p>This uses the same model as in my <a href="/statistics/sport/rugby-world-cup/">previous post</a>.
I’m quite sure that I am underrating Japan by not allowing form to vary!
Otherwise, the results aren’t wildly different to my predictions before the group stages.</p>Angus WilliamsWho will win the rugby world cup?2019-09-21T00:00:00+00:002019-09-21T00:00:00+00:00https://anguswilliams91.github.io/statistics/sport/rugby-world-cup<p>The men’s rugby union world cup is just starting<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote">1</a></sup>, so I thought it would be fun to make some predictions before the tournament gets going.
The plan is to build a statistical model using previous match results, and then use it to evaluate the probability that each of the teams will win.
(All of the code I used to obtain the data and produce the results in this post is <a href="https://github.com/anguswilliams91/ruwc_2019">here</a>.)</p>
<h2 id="historical-data">Historical data</h2>
<p>I have modelled football data before, and have found it very easy to obtain.
Rugby data, on the other hand, proved a bit trickier to get hold of.
I couldn’t find a website where I could simply download a file containing historical results, so I had to resort to scraping ESPN.
I hadn’t done this kind of thing for a while, so it took some time to inspect the HTML and figure out how to extract the data.
In the end, I downloaded all of the men’s international rugby union results from 1st January 2013 to the present day to use as a training set for my model.
A combination of the Python standard library and Beautiful Soup did the job!</p>
<h2 id="model">Model</h2>
<p>I will model the number of points scored by each of the teams in a match using an independent negative binomial model<sup id="fnref:2" role="doc-noteref"><a href="#fn:2" class="footnote">2</a></sup>:</p>
\[\mathrm{points}_{ij} \sim \mathrm{NegBinom2}(\alpha_i \beta_j, \phi).\]
<p>\(\mathrm{points}_{ij}\) refers to the number of points scored by team \(i\) against team \(j\).
Each team is assigned an <em>attacking aptitude</em> \(\alpha_i\) and a <em>defending aptitude</em> \(\beta_i\).
The expected number of points scored by team \(i\) against team \(j\) is then equal to the product of team \(i\)’s attacking aptitude with team \(j\)’s defending aptitude.
This is fairly intuitive: the better team \(i\) is at attacking (larger \(\alpha_i\)), the more points they’ll score.
The better team \(j\) is at defending (smaller \(\beta_j\)), the fewer points team \(i\) will score.</p>
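<p>The sketch below makes the data-generating process concrete by simulating a single match from this model in plain Python (standard library only). It assumes the NB2 parameterisation described in the footnote, with mean \(\mu\) and variance \(\mu + \mu^2/\phi\), and draws from it as a gamma-Poisson mixture. The function names and the example aptitudes are hypothetical and do not come from the code used for this post.</p>

```python
import math
import random
from typing import Dict, Tuple


def poisson(rng: random.Random, lam: float) -> int:
    """Poisson draw via Knuth's multiplication method (fine for rates of a few dozen)."""
    threshold = math.exp(-lam)
    k, p = 0, 1.0
    while True:
        p *= rng.random()
        if p <= threshold:
            return k
        k += 1


def neg_binomial_2(rng: random.Random, mu: float, phi: float) -> int:
    """NB2 draw with mean mu and variance mu + mu**2 / phi, as a gamma-Poisson mixture."""
    rate = rng.gammavariate(phi, mu / phi)  # shape phi, scale mu / phi, so E[rate] = mu
    return poisson(rng, rate)


def simulate_match(
    rng: random.Random,
    attack: Dict[str, float],
    defence: Dict[str, float],
    phi: float,
    home: str,
    away: str,
) -> Tuple[int, int]:
    """Team i scores NB2(attack[i] * defence[j], phi) points against team j."""
    home_points = neg_binomial_2(rng, attack[home] * defence[away], phi)
    away_points = neg_binomial_2(rng, attack[away] * defence[home], phi)
    return home_points, away_points


rng = random.Random(2019)
attack = {"A": 1.6, "B": 1.0}     # hypothetical alpha_i: near 1, as log(a) has prior mean 0
defence = {"A": 15.0, "B": 22.0}  # hypothetical beta_j: these carry the scale of the scores
print(simulate_match(rng, attack, defence, phi=5.0, home="A", away="B"))
```

<p>Setting \(\phi\) very large collapses the gamma step onto \(\mu\), which recovers the Poisson model that is usual for football.</p>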
<p>This kind of model is very commonly used in the context of football matches (check out the classic <a href="http://web.math.ku.dk/~rolf/teaching/thesis/DixonColes.pdf">Dixon & Coles</a> paper).<sup id="fnref:3" role="doc-noteref"><a href="#fn:3" class="footnote">3</a></sup>
In that case, a Poisson likelihood is typically used, but I found that a negative binomial better replicated the distribution of scores in rugby matches.</p>
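<p>I also use a hierarchical prior on the attack and defence aptitudes, e.g.:</p>
\[\log \alpha_i \sim \mathcal{N}(\mu_\alpha, \sigma_\alpha).\]
<p>This partially pools the aptitudes towards a common value, which should regularise the model, especially for teams that have played few matches.
In the code below \(\mu_\alpha\) is fixed at zero. Only the products \(\alpha_i \beta_j\) enter the likelihood, so the overall scale of the scoring rates is carried by \(\mu_\beta\).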
Here’s the Stan code for the model:</p>
<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>data {
  int<lower=1> nteam;
  int<lower=1> nmatch;
  array[nmatch] int<lower=1, upper=nteam> home_team;
  array[nmatch] int<lower=1, upper=nteam> away_team;
  array[nmatch] int<lower=0> home_points;
  array[nmatch] int<lower=0> away_points;
}
parameters {
  vector[nteam] log_a_tilde;
  vector[nteam] log_b_tilde;
  real<lower=0> sigma_a;
  real<lower=0> sigma_b;
  real mu_b;
  real<lower=0> phi;
}
transformed parameters {
  // non-centred: log(a) ~ normal(0, sigma_a), log(b) ~ normal(mu_b, sigma_b)
  vector[nteam] a = exp(sigma_a * log_a_tilde);
  vector[nteam] b = exp(mu_b + sigma_b * log_b_tilde);
  vector[nmatch] home_rate = a[home_team] .* b[away_team];
  vector[nmatch] away_rate = a[away_team] .* b[home_team];
}
model {
  phi ~ normal(0, 5);
  sigma_a ~ normal(0, 1);
  sigma_b ~ normal(0, 1);
  mu_b ~ normal(0, 5);
  log_a_tilde ~ normal(0, 1);
  log_b_tilde ~ normal(0, 1);
  home_points ~ neg_binomial_2(home_rate, phi);
  away_points ~ neg_binomial_2(away_rate, phi);
}
generated quantities {
  // one simulated dataset per posterior draw, for predictive checks
  array[nmatch] int home_points_rep;
  array[nmatch] int away_points_rep;
  for (i in 1:nmatch) {
    home_points_rep[i] = neg_binomial_2_rng(home_rate[i], phi);
    away_points_rep[i] = neg_binomial_2_rng(away_rate[i], phi);
  }
}
</code></pre></div></div>
<p>If the implementation looks a bit funny, that’s because I used a <a href="https://arxiv.org/abs/1312.0906">non-centered version</a> of the model so that Stan’s sampler would work better.</p>
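<p>The non-centred version rests on a simple identity. If \(\tilde{z} \sim \mathcal{N}(0, 1)\), then \(\mu + \sigma \tilde{z} \sim \mathcal{N}(\mu, \sigma)\). The sampler explores the standard-normal \(\tilde{z}\) (the <code class="language-plaintext highlighter-rouge">log_a_tilde</code> and <code class="language-plaintext highlighter-rouge">log_b_tilde</code> parameters), and the aptitudes are then built from it deterministically. This avoids the funnel-shaped geometry between \(\sigma\) and the team-level parameters that tends to trouble Hamiltonian Monte Carlo. The following quick stdlib-only Python check of the identity uses arbitrary illustrative values for \(\mu_\beta\) and \(\sigma_\beta\):</p>

```python
import math
import random
import statistics
from typing import List


def centred_draws(rng: random.Random, mu: float, sigma: float, n: int) -> List[float]:
    """log(b) drawn directly from Normal(mu, sigma)."""
    return [rng.gauss(mu, sigma) for _ in range(n)]


def non_centred_draws(rng: random.Random, mu: float, sigma: float, n: int) -> List[float]:
    """log(b) built as mu + sigma * z with z ~ Normal(0, 1), mirroring the Stan code."""
    return [mu + sigma * rng.gauss(0.0, 1.0) for _ in range(n)]


rng = random.Random(42)
mu_b, sigma_b = 3.0, 0.4  # arbitrary illustrative values
direct = centred_draws(rng, mu_b, sigma_b, 50_000)
transformed = non_centred_draws(rng, mu_b, sigma_b, 50_000)

for name, draws in [("centred", direct), ("non-centred", transformed)]:
    print(f"{name:12s} mean={statistics.fmean(draws):.3f} sd={statistics.stdev(draws):.3f}")

# The aptitudes themselves are exp() of these draws, i.e. log-normally distributed.
print("median b:", round(math.exp(statistics.median(transformed)), 2))
```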
<h2 id="model-checks">Model checks</h2>
<p>In the interest of brevity, I won’t spend long on this.
But, it would be pretty bad practice not to spend <em>some</em> time showing that the model produces reasonable simulated data!
In the above Stan code you can see that I generate some simulated data for this purpose.
I end up with the same number of simulated datasets as there are steps in my MCMC chain.</p>
<p>One nice way to do visual checks is to plot the distribution of the data on the same axes as the distribution of a single simulated dataset.
Since we have lots of simulated datasets, we can make this plot multiple times.
This gives us an idea of whether the real data are “typical” of the model.
Here’s a figure like that, where I plot the distribution of total points scored in a match:</p>
<p><img src="/assets/images/rugby_wc_post/total_points_distro.png" alt="points-ppc" class="img-responsive" /></p>
<p>and another where I plot the distribution of the difference in points between the two teams:</p>
<p><img src="/assets/images/rugby_wc_post/difference_distro.png" alt="diff-ppc" class="img-responsive" /></p>
<p>The model seems to consistently produce a few matches with <em>very</em> high points totals relative to the data, but otherwise seems to be doing a reasonable job.</p>
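<p>A numerical check can complement the visual one. Pick a summary statistic and compute the fraction of simulated datasets whose statistic is at least as large as the observed one. A good choice here is the largest total score in a dataset, since that is where the model looks off. This fraction is a posterior predictive p-value, and values very close to 0 or 1 flag a mismatch. Below is a minimal sketch; the toy numbers are made up purely to exercise the function.</p>

```python
from typing import Callable, Sequence


def ppc_p_value(
    observed: Sequence[int],
    replicated: Sequence[Sequence[int]],
    statistic: Callable[[Sequence[int]], float] = max,
) -> float:
    """Fraction of replicated datasets whose statistic is >= the observed statistic."""
    t_obs = statistic(observed)
    exceed = sum(1 for rep in replicated if statistic(rep) >= t_obs)
    return exceed / len(replicated)


# Toy example: total points per match in the "real" data and in three simulated datasets.
observed_totals = [38, 45, 52, 61, 70]
replicated_totals = [
    [40, 44, 55, 90, 120],
    [35, 48, 50, 66, 101],
    [30, 41, 57, 63, 69],
]
print(ppc_p_value(observed_totals, replicated_totals))  # 2 of 3 datasets reach the observed maximum
```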
<h2 id="simulating-the-world-cup">Simulating the world cup</h2>
<p>Now that I have posterior samples from the model, I can simulate the world cup many times and use the results to evaluate the probability that each of the teams will win.
To do this, I need to know the rules of the world cup.
In the group stages, teams are allocated 4 points for a win, 2 points for a draw and 0 points for a loss.
Additionally, teams are awarded a bonus point if they score 4 or more tries, or if they lose by 7 or fewer points.
Since my model predicts total points but not the number of tries, I approximate the try bonus by awarding a bonus point to any team that scores more than 25 points.</p>
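<p>These rules fit in a few lines, including the 25-point stand-in for the four-try bonus. This is a minimal Python sketch under that approximation. The function name is hypothetical and this is not the code in the linked repository.</p>

```python
TRY_BONUS_THRESHOLD = 25  # proxy for "4 or more tries": score more than 25 points
LOSING_BONUS_MARGIN = 7   # losing by 7 or fewer points earns a bonus point


def group_points(points_for: int, points_against: int) -> int:
    """League-table points earned by one team from a single group match."""
    if points_for > points_against:
        table = 4
    elif points_for == points_against:
        table = 2
    else:
        table = 0
        if points_against - points_for <= LOSING_BONUS_MARGIN:
            table += 1  # losing bonus point
    if points_for > TRY_BONUS_THRESHOLD:
        table += 1  # approximate try bonus point
    return table


# A 30-10 win earns 4 + 1 (try bonus proxy); the 20-point loser gets nothing.
print(group_points(30, 10), group_points(10, 30))  # 5 0
# A narrow 20-25 loss still earns a losing bonus point.
print(group_points(20, 25), group_points(25, 20))  # 1 4
```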
<p>My recipe will be as follows:</p>
<ol>
<li>Select a set of model parameters \(\theta = (\{\alpha_i\}, \{\beta_i\}, \phi)\) from a single iteration of MCMC.</li>
<li>Use the parameters to simulate a single realisation of each of the group matches, and use the rules of the tournament to figure out which teams will graduate into the knockout stages.</li>
<li>Simulate each of the knockout stage matches, eventually ending up with a winner.</li>
<li>Store the results, and repeat (1) to (3) for every iteration of MCMC.</li>
</ol>
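<p>The recipe maps onto a short loop. In this sketch, <code class="language-plaintext highlighter-rouge">posterior_draws</code> stands for the list of MCMC draws, and <code class="language-plaintext highlighter-rouge">simulate_tournament</code> is a hypothetical placeholder for steps (2) and (3). The toy version below simply lets the team with the largest noisy "strength" win, purely so that the loop can run.</p>

```python
import random
from collections import Counter
from typing import Callable, Dict, List


def tournament_win_probabilities(
    posterior_draws: List[Dict[str, float]],
    simulate_tournament: Callable[[Dict[str, float], random.Random], str],
    seed: int = 0,
) -> Dict[str, float]:
    """Simulate one tournament per posterior draw and return each team's win fraction."""
    rng = random.Random(seed)
    winners = Counter(simulate_tournament(draw, rng) for draw in posterior_draws)
    n = len(posterior_draws)
    return {team: count / n for team, count in winners.most_common()}


def toy_tournament(draw: Dict[str, float], rng: random.Random) -> str:
    # Hypothetical stand-in for the group and knockout stages: strongest noisy team wins.
    return max(draw, key=lambda team: draw[team] + rng.gauss(0.0, 0.3))


draw_rng = random.Random(1)
draws = [  # toy "posterior": one strength per team per draw
    {
        "NZL": draw_rng.gauss(1.4, 0.05),
        "ENG": draw_rng.gauss(1.1, 0.05),
        "RSA": draw_rng.gauss(1.0, 0.05),
    }
    for _ in range(4000)
]
print(tournament_win_probabilities(draws, toy_tournament))
```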
<p>Once these calculations have been done, I have thousands of simulated world cups.
To calculate the posterior predictive probability of a given team winning, all I have to do is calculate the fraction of times that team won in my simulations – simple!
I really like this aspect of MCMC: it makes it straightforward to approximate the posterior predictive distribution of complicated functions of the model, such as the winner of a whole tournament.</p>
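<p>Because each probability is a simple fraction, its Monte Carlo error can be gauged with the binomial standard error \(\sqrt{p(1-p)/N}\). This treats the \(N\) simulated tournaments as independent. MCMC draws are autocorrelated, so strictly the effective sample size should replace \(N\), which makes the error somewhat larger. The 4000 draws used below are an assumption, since only "thousands" is stated.</p>

```python
import math


def mc_standard_error(p: float, n_effective: float) -> float:
    """Binomial standard error of a probability estimated from n_effective simulations."""
    return math.sqrt(p * (1.0 - p) / n_effective)


# Assumed 4000 simulated tournaments; the exact number is not stated in the post.
for team, p in [("New Zealand", 0.50), ("England", 0.15), ("France", 0.01)]:
    print(f"{team:12s} p={p:.2f} +/- {mc_standard_error(p, 4000):.4f}")
```

<p>Under that assumption the error is below 0.01 even for the favourite, so the two-decimal probabilities in the tables should be roughly stable to simulation noise.</p>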
<h2 id="results">Results</h2>
<p>Ok – so who is going to win?
Here are the probabilities assigned to each of the teams by the model (I only display probabilities for teams for whom the probability is 0.01 or larger):</p>
<table>
<thead>
<tr>
<th>Team</th>
<th>Probability of winning the world cup</th>
</tr>
</thead>
<tbody>
<tr>
<td>New Zealand</td>
<td>0.50</td>
</tr>
<tr>
<td>England</td>
<td>0.15</td>
</tr>
<tr>
<td>South Africa</td>
<td>0.13</td>
</tr>
<tr>
<td>Ireland</td>
<td>0.08</td>
</tr>
<tr>
<td>Wales</td>
<td>0.06</td>
</tr>
<tr>
<td>Australia</td>
<td>0.05</td>
</tr>
<tr>
<td>France</td>
<td>0.01</td>
</tr>
<tr>
<td>Scotland</td>
<td>0.01</td>
</tr>
</tbody>
</table>
<p>New Zealand are massive favourites, with England, South Africa and Ireland all hovering around a 0.1 chance of winning.
I am somewhat surprised that Wales are assigned a noticeably lower probability than England, but this may be because their likely path to the final is more difficult than England’s.
Also, Wales had a dire spell a few years ago, and the simple model I used does not account for changing form, so it might underrate Wales somewhat.</p>
<p>As an England fan, I am also curious to know how likely England are to get to various stages of the tournament.
The model gives England a probability of 0.94 of getting out of the group – so there’s a very good chance they’ll do better than at the last world cup, and we should get to see them in a quarter-final.
The probability of them getting to the semi-finals is 0.65, and 0.30 for the final.
So, fans should be disappointed if they don’t see England win at least one knockout match!</p>
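<p>Each stage can only be reached through the previous one, so these numbers also imply England’s chance of winning each knockout match, conditional on reaching it:</p>
\[P(\text{semi-final} \mid \text{quarter-final}) = \frac{0.65}{0.94} \approx 0.69, \qquad P(\text{final} \mid \text{semi-final}) = \frac{0.30}{0.65} \approx 0.46.\]
<p>In other words, the model makes England fairly clear favourites in a quarter-final, while a semi-final is close to a coin flip.</p>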
<p>Whilst I’m at it, here are the model outputs for each of the groups (again I leave out teams with probability < 0.01, and round to the nearest 0.01).</p>
<h3 id="group-a">Group A</h3>
<table>
<thead>
<tr>
<th>Team</th>
<th>Probability of winning group A</th>
</tr>
</thead>
<tbody>
<tr>
<td>Ireland</td>
<td>0.71</td>
</tr>
<tr>
<td>Scotland</td>
<td>0.25</td>
</tr>
<tr>
<td>Japan</td>
<td>0.03</td>
</tr>
<tr>
<td>Samoa</td>
<td>0.01</td>
</tr>
</tbody>
</table>
<table>
<thead>
<tr>
<th>Team</th>
<th>Probability of being runner up of group A</th>
</tr>
</thead>
<tbody>
<tr>
<td>Scotland</td>
<td>0.50</td>
</tr>
<tr>
<td>Ireland</td>
<td>0.24</td>
</tr>
<tr>
<td>Japan</td>
<td>0.18</td>
</tr>
<tr>
<td>Samoa</td>
<td>0.08</td>
</tr>
</tbody>
</table>
<h3 id="group-b">Group B</h3>
<table>
<thead>
<tr>
<th>Team</th>
<th>Probability of winning group B</th>
</tr>
</thead>
<tbody>
<tr>
<td>New Zealand</td>
<td>0.75</td>
</tr>
<tr>
<td>South Africa</td>
<td>0.25</td>
</tr>
</tbody>
</table>
<table>
<thead>
<tr>
<th>Team</th>
<th>Probability of being runner up of group B</th>
</tr>
</thead>
<tbody>
<tr>
<td>South Africa</td>
<td>0.72</td>
</tr>
<tr>
<td>New Zealand</td>
<td>0.25</td>
</tr>
<tr>
<td>Italy</td>
<td>0.03</td>
</tr>
</tbody>
</table>
<h3 id="group-c">Group C</h3>
<table>
<thead>
<tr>
<th>Team</th>
<th>Probability of winning group C</th>
</tr>
</thead>
<tbody>
<tr>
<td>England</td>
<td>0.74</td>
</tr>
<tr>
<td>France</td>
<td>0.14</td>
</tr>
<tr>
<td>Argentina</td>
<td>0.11</td>
</tr>
</tbody>
</table>
<table>
<thead>
<tr>
<th>Team</th>
<th>Probability of being runner up of group C</th>
</tr>
</thead>
<tbody>
<tr>
<td>France</td>
<td>0.41</td>
</tr>
<tr>
<td>Argentina</td>
<td>0.35</td>
</tr>
<tr>
<td>England</td>
<td>0.20</td>
</tr>
<tr>
<td>United States of America</td>
<td>0.03</td>
</tr>
<tr>
<td>Tonga</td>
<td>0.02</td>
</tr>
</tbody>
</table>
<h3 id="group-d">Group D</h3>
<table>
<thead>
<tr>
<th>Team</th>
<th>Probability of winning group D</th>
</tr>
</thead>
<tbody>
<tr>
<td>Wales</td>
<td>0.49</td>
</tr>
<tr>
<td>Australia</td>
<td>0.45</td>
</tr>
<tr>
<td>Fiji</td>
<td>0.06</td>
</tr>
</tbody>
</table>
<table>
<thead>
<tr>
<th>Team</th>
<th>Probability of being runner up of group D</th>
</tr>
</thead>
<tbody>
<tr>
<td>Australia</td>
<td>0.41</td>
</tr>
<tr>
<td>Wales</td>
<td>0.38</td>
</tr>
<tr>
<td>Fiji</td>
<td>0.15</td>
</tr>
<tr>
<td>Georgia</td>
<td>0.05</td>
</tr>
</tbody>
</table>
<p>Three of the four groups have a clear favourite.
Group D, unsurprisingly, is a toss-up between Wales and Australia to top the group.</p>
<h2 id="conclusions">Conclusions</h2>
<p>This was a nice end-to-end bit of analysis: scraping the data, building the model and then simulating the results.
The conclusions are slightly at odds with the articles I’ve seen about the competition, which claim that it is very open and a few teams have a relatively even chance of winning.
On the other hand, when I checked the bookies’ odds on New Zealand, they were quite consistent with my results.</p>
<p>Nonetheless, England’s chances aren’t too bad, so there’s still hope!</p>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:1" role="doc-endnote">
<p>I am a bit late in posting this, and a few group stage matches have already happened! <a href="#fnref:1" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:2" role="doc-endnote">
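<p>I am using the alternative parameterisation of the negative binomial, in terms of its expectation \(\mu\) and dispersion parameter \(\phi\), so that the variance is \(\mu + \mu^2/\phi\). As \(\phi \rightarrow \infty\) this recovers the Poisson distribution. <a href="#fnref:2" class="reversefootnote" role="doc-backlink">↩</a></p>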
</li>
<li id="fn:3" role="doc-endnote">
<p>I actually have a python package on <a href="https://github.com/anguswilliams91/bpl">GitHub</a> that implements this model for football. <a href="#fnref:3" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>Angus WilliamsTracking my fitness2019-09-07T00:00:00+00:002019-09-07T00:00:00+00:00https://anguswilliams91.github.io/statistics/sport/gradient-adjusted-pace<p>I like to go running, and I also like to see how I’m doing by using the Strava app.
It’s satisfying to see progress over time if I’m training for an event.</p>
<p>This is pretty easy to do if I am repeating the same routes.
However, I recently moved house, and consequently have a new set of typical routes.
A big difference in my new surroundings is that it’s <em>really hilly</em>!
This means that my typical pace now comes out slower than before the move.</p>
<p>Fortunately, Strava has a nifty feature called <em>Gradient Adjusted Pace</em> (GAP).
GAP estimates of your pace are corrected for the vertical gradient of the terrain you’re running on.
This facilitates comparison between runs that took place in different locations – great!</p>
<p>A lazy search didn’t instantly tell me how GAP is calculated, so I thought it would be fun to try and come up with my own recipe instead.</p>
<h2 id="what-is-gap">What is GAP?</h2>
<p>To get going, I need a definition for GAP.
Let’s go with this:</p>
<p><em>“If I had done the same run in a parallel universe where all of the hills were removed, but everything else stayed the same, how fast would I have been?”</em></p>
<p>This achieves the desired goal by putting all my runs on a (literally) level playing field for comparison.
This definition also means that what I’m trying to calculate is a <em>counterfactual</em>, so I’ll need to interpret whatever model I build as causal.
To produce estimates of the answer to this question, I will need to use data from my previous runs to build a statistical model, which I can then query.</p>
<h2 id="data">Data</h2>
<p>To come up with my own GAP estimates, I’ll use data from Strava (there’s an API that you can use to download your data).
To keep things simple, I’ll model the <strong>average speed</strong> of each run as a function of the <strong>elevation gain</strong> during the run and the <strong>total distance</strong> of the run.
All of my runs are loops, so using elevation gain takes into account that I lose as much elevation as I gain.</p>
<p>Here’s a plot of the data:</p>
<p><img src="/assets/images/strava_post/distance_vs_speed.png" alt="distance-vs-speed" class="img-responsive" /></p>
<p>The relationship between the log of distance and the log of speed looks like it would be well modelled as a straight line with some scatter.
As expected, my average speed is lower when I run further.
The points are coloured by the elevation gain, which lets us see that I get slower if there’s more elevation gain in a run.
I’ve got data from 128 runs in total.</p>
<h2 id="statistical-model">Statistical model</h2>
<p>Based on the visualisation above, I came up with the following super-simple linear model</p>
\[\log\,(\mathrm{speed}) =
\alpha
+ \beta_\mathrm{elevation}\log\left(\mathrm{elevation} + \delta\right)
+ \beta_\mathrm{distance}\log\left(\mathrm{distance}\right)
+ \epsilon\]
<p>where \(\epsilon\) is normally distributed noise with variance \(\sigma ^2\).
Note that \(\delta\) is not a parameter: it’s a constant that I add to elevation so that I don’t run into trouble with logs when elevation is zero.
I set \(\delta = 10\mathrm{m}\).
I don’t have much data, so point estimates of the model parameters aren’t going to cut it.
To properly quantify my uncertainty, I’ll take a Bayesian approach and sample the posterior using <a href="https://mc-stan.org">Stan</a>.</p>
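<p>The Stan code below has no intercept, so the inputs need preparing before they are passed in. Below is a minimal sketch of that step, assuming the raw data are lists of average speeds, elevation gains in metres and distances. It applies \(\delta = 10\,\mathrm{m}\) and centres every column. Centring the predictors as well as log speed matters: with an uncentred predictor, a fit without an intercept shrinks the slope towards zero. The toy single-predictor least-squares fit at the end illustrates this; it stands in for, and is not, the Bayesian fit.</p>

```python
import math
from typing import Dict, List, Sequence

DELTA_M = 10.0  # constant added to elevation gain so that log() is defined at zero


def centre(values: Sequence[float]) -> List[float]:
    mean = sum(values) / len(values)
    return [v - mean for v in values]


def prepare_inputs(
    speed: Sequence[float], elevation_m: Sequence[float], distance: Sequence[float]
) -> Dict[str, List[float]]:
    """Log-transform and centre each column (hypothetical helper, not from the post)."""
    return {
        "log_speed": centre([math.log(s) for s in speed]),
        "log_elevation": centre([math.log(e + DELTA_M) for e in elevation_m]),
        "log_distance": centre([math.log(d) for d in distance]),
    }


def slope_without_intercept(x: Sequence[float], y: Sequence[float]) -> float:
    """Least-squares slope for y = b * x with no intercept term."""
    return sum(xi * yi for xi, yi in zip(x, y)) / sum(xi * xi for xi in x)


# Toy demonstration: y = 2x + 1 exactly, with x far from zero.
x = [11.0, 12.0, 13.0, 14.0, 15.0]
y = [2.0 * xi + 1.0 for xi in x]
print(slope_without_intercept(x, centre(y)))          # attenuated, far below 2
print(slope_without_intercept(centre(x), centre(y)))  # recovers the slope of 2
```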
<p>Here’s the Stan code for this model:</p>
<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>data {
  int n;
  vector[n] log_speed;
  vector[n] log_elevation;
  vector[n] log_distance;
}
parameters {
  real beta_elevation;
  real beta_distance;
  real<lower=0> sigma;
}
transformed parameters {
  vector[n] z = beta_elevation * log_elevation
                + beta_distance * log_distance;
}
model {
  beta_elevation ~ normal(0, 1);
  beta_distance ~ normal(0, 1);
  sigma ~ normal(0, 5);
  log_speed ~ normal(z, sigma);
}
generated quantities {
  vector[n] log_speed_rep;
  for (i in 1:n) {
    log_speed_rep[i] = normal_rng(z[i], sigma);
  }
}
</code></pre></div></div>
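<p>I centred log speed, so there’s no explicit intercept term; dropping \(\alpha\) in this way is only equivalent to estimating it if the log predictors are centred as well.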
I also generate some simulated data to be used for model checking in the <code class="language-plaintext highlighter-rouge">generated quantities</code> block.
The marginal posterior distributions for the slopes look like this:</p>
<p><img src="/assets/images/strava_post/model_params.png" alt="slopes-posterior" class="img-responsive" /></p>
<p>Both elevation gain and total distance cause my average speed to go down, as expected.
Let’s also do a quick check of the model by plotting the distribution of the residual</p>
\[\log(\mathrm{speed}) - \log(\mathrm{speed})_\mathrm{rep}\]
<p>for each of the runs.
\(\log(\mathrm{speed})_\mathrm{rep}\) are the simulated speeds from the model posterior predictive distribution.
For each run (in temporal order), I display a box and whisker plot of these residuals.
The speeds and simulated speeds are standardised using the mean and standard deviation of the observed data, so we should expect to see the residuals distributed something like a unit normal, with some variation between runs.</p>
<p><img src="/assets/images/strava_post/residuals_simple.png" alt="posterior-checks" class="img-responsive" /></p>
<p>The figure shows that the model is doing a reasonable job at replicating the data, but there is obviously some autocorrelation visible (remember that the runs are plotted in temporal order).
This is probably because my fitness changed over time, which this model does not account for (we’ll get to that later).
Nonetheless, the model is a decent enough representation of the data and I’ll use it to calculate GAP estimates.</p>
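<p>The autocorrelation spotted by eye can also be summarised as a single number. Take the mean residual of each run in temporal order and compute its lag-1 autocorrelation. The model assumes values near zero, and a clearly positive value points to slowly drifting fitness. Below is a minimal stdlib sketch, with made-up residual series:</p>

```python
from typing import Sequence


def lag1_autocorrelation(series: Sequence[float]) -> float:
    """Sample lag-1 autocorrelation of a series."""
    n = len(series)
    mean = sum(series) / n
    dev = [x - mean for x in series]
    denom = sum(d * d for d in dev)
    return sum(dev[i] * dev[i + 1] for i in range(n - 1)) / denom


drifting = [-1.0, -0.8, -0.5, -0.1, 0.2, 0.4, 0.7, 1.1]  # slow trend, like changing fitness
alternating = [1.0, -1.0, 1.0, -1.0, 1.0, -1.0]          # no persistence at all
print(round(lag1_autocorrelation(drifting), 2))
print(round(lag1_autocorrelation(alternating), 2))
```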
<h2 id="calculating-gap">Calculating GAP</h2>
<p>Armed with this model, I can now calculate my simple GAP estimates by re-arranging my linear model and setting elevation to zero:</p>
\[\log (\mathrm{GAP}) = \log(\mathrm{speed}) + \beta_\mathrm{elevation}\left[\log(\delta) - \log(\mathrm{elevation} + \delta) \right].\]
<p>The counterfactual GAP speed is related to the actual speed through a correction proportional to \(\beta_\mathrm{elevation}\).
Because I have uncertainty about the value of \(\beta_\mathrm{elevation}\), I’ll also have uncertainty about the GAP value that I infer for each run.
The more elevation there is in a run, the more uncertainty there will be in the estimate of GAP.
To compute this uncertainty, I can just plug my MCMC samples for \(\beta_\mathrm{elevation}\) into the above formula.
Here’s the result of estimating GAP for all of my runs:</p>
<p><img src="/assets/images/strava_post/gap_vs_true.png" alt="gap-estimates" class="img-responsive" /></p>
<p>In this figure, GAP is plotted against my actual pace for the run.
The points are coloured by the elevation gain of the run.
The error-bars are the 95% credible interval for GAP.</p>
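<p>Getting GAP with uncertainty means applying the formula above once per posterior draw of \(\beta_\mathrm{elevation}\) and then summarising the resulting GAP draws. Below is a minimal sketch, assuming speeds in km/h and made-up values for the slope draws.</p>

```python
import math
import random
import statistics
from typing import List, Sequence, Tuple

DELTA_M = 10.0  # same constant added to elevation gain in the model


def gap_draws(speed: float, elevation_m: float, beta_elevation: Sequence[float]) -> List[float]:
    """One GAP speed per posterior draw of beta_elevation:
    log(GAP) = log(speed) + beta * [log(delta) - log(elevation + delta)]."""
    shift = math.log(DELTA_M) - math.log(elevation_m + DELTA_M)
    return [math.exp(math.log(speed) + b * shift) for b in beta_elevation]


def interval_95(draws: Sequence[float]) -> Tuple[float, float, float]:
    """(2.5%, 50%, 97.5%) quantiles of the draws."""
    cuts = statistics.quantiles(draws, n=40, method="inclusive")
    return cuts[0], statistics.median(draws), cuts[-1]


rng = random.Random(7)
beta_samples = [rng.gauss(-0.05, 0.01) for _ in range(4000)]  # hypothetical posterior draws

for elevation in (0.0, 100.0, 300.0):
    lo, mid, hi = interval_95(gap_draws(11.0, elevation, beta_samples))
    print(f"elevation {elevation:5.0f} m: GAP {mid:.2f} km/h (95% CI {lo:.2f} to {hi:.2f})")
```

<p>The shift term grows with elevation, so hillier runs get wider intervals, matching the pattern in the figure.</p>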
<p>The plot broadly makes sense: the GAP estimates are always lower than my actual pace (i.e., I would have run faster on the flat), except for my single run with zero elevation gain, where the two are identical.
Furthermore, runs with more elevation have a bigger difference between GAP and my actual pace.</p>
<p>It’s interesting to note that I only did a single run with zero elevation!
This means that I’m leaning on the model assumptions, and hoping that they are plausible when extrapolating to zero elevation gain.
To really test if this is true, I’d need to go out and do some more runs at zero elevation in a variety of conditions and distances (i.e., try to observe something close to the counterfactual).</p>
<p>I couldn’t see how to get Strava’s own GAP estimate out of the API, so I didn’t do a full comparison between the average GAP produced by Strava and my simple model.
I manually grabbed Strava’s GAP for the run with the largest elevation gain and did a comparison:</p>
<p><img src="/assets/images/strava_post/gap_vs_strava.png" alt="me-vs-strava" class="img-responsive" /></p>
<p>At least in this case, my approach and the Strava data are consistent with one another.
I probably won’t start using this instead of Strava’s estimates, but it was good fun to build a simple model myself.</p>
<h2 id="modelling-my-fitness">Modelling my fitness</h2>
<p>The simple approach to modelling my running pace produced reasonable results, but we saw in the model checks that there was some autocorrelation in the model errors.
The simple model above does not allow my fitness to vary: it says that my pace depends only on how far I run and how hilly the route is.
To include some notion of fitness, I expanded the model so that it looks like this:</p>
\[\log\,(\mathrm{speed}_i) =
\alpha_i
+ \beta_\mathrm{elevation}\log\left(\mathrm{elevation}_i + \delta\right)
+ \beta_\mathrm{distance}\log\left(\mathrm{distance}_i\right)
+ \epsilon_i\]
<p>where the index \(i\) encodes an ordering to my runs (e.g. run 2 comes after run 1 and before run 3).
The key difference between this model and the original one is that the intercept \(\alpha_i\) now varies from run to run instead of being a constant.
It can be thought of as my “base pace”: i.e., fitness.
I should probably index using time explicitly, but I was too lazy for that.
Since I expect my fitness to vary smoothly over time, I put a random walk prior on the intercepts:</p>
<p>\(\alpha_i = \alpha_{i - 1} + \zeta_i\),</p>
<p>where each \(\zeta_i\) is independent, normally distributed noise with mean zero and variance \(\sigma_\mathrm{rw}^2\).
The size of the variance controls how rapidly my fitness can change between consecutive runs.
Now, I can infer the set \(\{\alpha_i\}\) and interpret them as my “fitness” on each of my runs.
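<p>To make “how rapidly” concrete: the random-walk noise terms are drawn independently at each step, so the change in fitness over \(k\) consecutive runs is a sum of \(k\) of them, and its prior variance grows linearly in \(k\):</p>
\[\alpha_{i+k} - \alpha_i = \sum_{j=i+1}^{i+k} \zeta_j, \qquad \mathrm{Var}\left(\alpha_{i+k} - \alpha_i\right) = k\,\sigma_\mathrm{rw}^2.\]
<p>A priori, then, fitness typically drifts by about \(\sqrt{k}\,\sigma_\mathrm{rw}\) on the log-speed scale over \(k\) runs.</p>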
Here’s the stan code for this model:</p>
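<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>data {
  int n;                        // number of runs, in chronological order
  vector[n] log_speed;          // log average speed of each run
  vector[n] log_elevation_gain; // log(elevation_i + delta)
  vector[n] log_distance;
}
parameters {
  real beta_elevation;
  real beta_distance;
  real<lower=0> sigma;          // observation noise scale
  vector[n] fitness_std;        // standardised random-walk innovations
  real<lower=0> sigma_rw;       // random-walk step size
}
transformed parameters {
  vector[n] fitness;            // alpha_i in the equation above
  vector[n] z;                  // linear predictor for log speed
  // Non-centred random walk: implies fitness[i] ~ normal(fitness[i - 1], sigma_rw)
  fitness[1] = fitness_std[1];
  for (i in 2:n) {
    fitness[i] = fitness_std[i] * sigma_rw
                 + fitness[i - 1];
  }
  z = fitness
      + beta_elevation * log_elevation_gain
      + beta_distance * log_distance;
}
model {
  // The lower=0 constraints make the priors on sigma and sigma_rw half-normal
  beta_elevation ~ normal(0, 1);
  beta_distance ~ normal(0, 1);
  sigma ~ normal(0, 1);
  log_speed ~ normal(z, sigma);
  fitness_std ~ normal(0, 1);   // also gives fitness[1] ~ normal(0, 1)
  sigma_rw ~ normal(0, 1);
}
generated quantities {
  // Posterior predictive replicates, used for the residual checks
  vector[n] log_speed_rep;
  for (i in 1:n) {
    log_speed_rep[i] = normal_rng(z[i], sigma);
  }
}
</code></pre></div></div>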
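<p>The <code>transformed parameters</code> block builds the fitness path from standard-normal innovations <code>fitness_std</code> scaled by <code>sigma_rw</code>. This non-centred parameterisation is often recommended for Stan’s HMC sampler over writing <code>fitness[i] ~ normal(fitness[i - 1], sigma_rw)</code> directly, because it tends to sample more easily when the data only weakly constrain each step. The same construction is a cumulative sum. The NumPy sketch below (hypothetical helpers, not part of the notebook) mirrors the Stan loop and simulates fitness paths from the prior for a chosen <code>sigma_rw</code>.</p>

```python
import numpy as np


def fitness_path(fitness_std, sigma_rw):
    """Vectorised mirror of the Stan loop (0-based indices here):

    fitness[0] = fitness_std[0]
    fitness[i] = fitness[i - 1] + sigma_rw * fitness_std[i],  i >= 1

    Works along the last axis, so an array of shape (n_draws, n_runs)
    gives one path per row.
    """
    z = np.asarray(fitness_std, dtype=float)
    steps = np.concatenate([z[..., :1], sigma_rw * z[..., 1:]], axis=-1)
    return np.cumsum(steps, axis=-1)


def fitness_path_loop(fitness_std, sigma_rw):
    """Direct translation of the Stan loop, for a single 1-D path."""
    z = np.asarray(fitness_std, dtype=float)
    out = np.empty_like(z)
    out[0] = z[0]
    for i in range(1, len(z)):
        out[i] = z[i] * sigma_rw + out[i - 1]
    return out


# Prior simulation: 5000 fitness paths over 30 runs with sigma_rw = 0.1.
rng = np.random.default_rng(42)
paths = fitness_path(rng.normal(size=(5000, 30)), sigma_rw=0.1)
```

<p>Across the simulated paths, the variance of <code>paths[:, k] - paths[:, 0]</code> should be close to <code>k * sigma_rw**2</code>, which is the sense in which <code>sigma_rw</code> controls how quickly fitness can change.</p>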
<p>After running MCMC, I can plot my inferred fitness over time:</p>
<p><img src="/assets/images/strava_post/fitness_trend.png" alt="fitness-trend" class="img-responsive" /></p>
<p>The grey band is the 68% credible interval, and the black dots mark when a run took place.
The results look broadly as I would expect them to.
I know I was pretty fit last summer, but then had an injury which bothered me until early January.
I then started training again for a couple of half-marathons in May / June.
A clear limitation of this approach is that the effort I put into runs varies, whereas the model assumes that I am trying my best in every run.
One way to get around this would be to include heart-rate data, which provide a measure of how strained I am during the run.
But as a simple first approach, the results are reasonable.</p>
<p>Let’s see if that autocorrelation we saw in the previous model check has been reduced:</p>
<p><img src="/assets/images/strava_post/residuals.png" alt="posterior-checks-2" class="img-responsive" /></p>
<p>It definitely has, although there’s still some present in the middle of the plot (the large negative residuals clustered together).
I think that this is highlighting an incorrect assumption that I made: fitness varies smoothly over time <em>unless</em> you get injured.
Then it changes very abruptly.
I actually got injured last year, and so my runs became notably slower for a while as I recovered.
The injury is effectively a <em>change-point</em> in my fitness.
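I reckon what’s going on is this: because of the random walk prior, the model can only move towards the lower post-injury fitness gradually, so it smooths over the drop. That means it underestimates my speed on the runs just before the injury and overestimates it on the runs just after, which would show up as a cluster of negative residuals.</p>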
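<p>A toy check of this explanation, under simplifying assumptions of my own: drop the elevation and distance terms and treat <code>sigma</code> and <code>sigma_rw</code> as known. The model then becomes a Gaussian random walk observed with Gaussian noise, so the posterior mean of the fitness path can be computed exactly with one linear solve instead of MCMC. The data are synthetic, with a hypothetical abrupt drop halfway through standing in for an injury.</p>

```python
import numpy as np


def random_walk_posterior_mean(y, sigma, sigma_rw, sigma0=10.0):
    """Posterior mean of alpha under the all-Gaussian model

        alpha[0] ~ normal(0, sigma0)
        alpha[i] ~ normal(alpha[i - 1], sigma_rw)
        y[i]     ~ normal(alpha[i], sigma)

    The posterior precision is prior precision + I / sigma^2, and with a
    zero prior mean the posterior mean solves precision @ mean = y / sigma^2.
    """
    y = np.asarray(y, dtype=float)
    n = len(y)
    # First-difference matrix: (D @ alpha)[i] = alpha[i + 1] - alpha[i]
    D = np.diff(np.eye(n), axis=0)
    prior_precision = D.T @ D / sigma_rw**2
    prior_precision[0, 0] += 1.0 / sigma0**2
    posterior_precision = prior_precision + np.eye(n) / sigma**2
    return np.linalg.solve(posterior_precision, y / sigma**2)


# Synthetic fitness with an abrupt drop after run 20 (no observation noise).
y = np.concatenate([np.full(20, 1.0), np.full(20, 0.0)])
alpha_hat = random_walk_posterior_mean(y, sigma=0.2, sigma_rw=0.05)
residuals = y - alpha_hat  # observed minus fitted
```

<p>The fitted path sits below the pre-drop level and above the post-drop level, so the residuals are positive for the last runs before the drop and negative for the first runs after it: same-signed runs of residuals either side of the change-point, as described above.</p>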
<h2 id="final-thoughts">Final thoughts</h2>
<p>This was a fun exercise, and it was quite satisfying to do analysis of my own running data!
Very simple statistical models produced relatively interesting insights – especially the fitness metric in the final section.
I might try adding heart rate data into the model at some point, and see if this improves the results.
If you’re interested in trying this for yourself, I put the notebook I used to generate the results from this post in a <a href="https://github.com/anguswilliams91/negsplit/">github repo</a> (be warned: it’s a bit scrappy).</p>Angus WilliamsI like to go running, and I also like to see how I’m doing by using the Strava app. It’s satisfying to see progress over time if I’m training for an event. This is pretty easy to do if I am repeating the same routes. However, I recently moved house, and consequently have a new set of typical routes. A big difference in my new surroundings is that it’s really hilly! This means that my typical pace now comes out slower than before the move. Fortunately, Strava has a nifty feature called Gradient Adjusted Pace (GAP). GAP estimates of your pace are corrected for the vertical gradient of the terrain you’re running on. This facilitates comparison between runs that took place in different locations – great! A lazy search didn’t instantly tell me how GAP is calculated, so I thought it would be fun to try and come up with my own recipe instead. What is GAP? To get going, I need a definition for GAP. Let’s go with this: “If I had done the same run in a parallel universe where all of the hills were removed, but everything else stayed the same, how fast would I have been?” This achieves the desired goal by putting all my runs on a (literally) level playing field for comparison. This definition also means that what I’m trying to calculate is a counterfactual, so I’ll need to interpret whatever model I build as causal. To produce estimates of the answer to this question, I will need to use data from my previous runs to build a statistical model, which I can then query. Data To come up with my own GAP estimates, I’ll use data from Strava (there’s an API that you can use to download your data). To keep things simple, I’ll model the average speed of each run as a function of the elevation gain during the run and the total distance of the run. 
All of my runs are loops, so using elevation gain takes into account that I lose as much elevation as I gain. Here’s a plot of the data: The relationship between the log of distance and the log of speed looks like it would be well modelled as a straight line with some scatter. As expected, my average speed is lower when I run further. The points are coloured by the elevation gain, which lets us see that I get slower if there’s more elevation gain in a run. I’ve got data from 128 runs in total. Statistical model Based on the visualisation above, I came up with the following super-simple linear model \[\log\,(\mathrm{speed}) = \alpha + \beta_\mathrm{elevation}\log\left(\mathrm{elevation} + \delta\right) + \beta_\mathrm{distance}\log\left(\mathrm{distance}\right) + \epsilon\] where \(\epsilon\) is normally distributed noise with variance \(\sigma ^2\). Note that \(\delta\) is not a parameter: it’s a constant that I add to elevation so that I don’t run into trouble with logs when elevation is zero. I set \(\delta = 10\mathrm{m}\). I don’t have much data, so point estimates of the model parameters aren’t going to cut it. To properly quantify my uncertainty, I’ll take a Bayesian approach and sample the posterior using Stan. Here’s the stan code for this model: data { int n; vector[n] log_speed; vector[n] log_elevation; vector[n] log_distance; } parameters { real beta_elevation; real beta_distance; real<lower=0> sigma; } transformed parameters { vector[n] z = beta_elevation * log_elevation + beta_distance * log_distance; } model { beta_elevation ~ normal(0, 1); beta_distance ~ normal(0, 1); sigma ~ normal(0, 5); log_speed ~ normal(z, sigma); } generated quantities { vector[n] log_speed_rep; for (i in 1:n) { log_speed_rep[i] = normal_rng(z[i], sigma); } } I centred log speed, so there’s no explicit intercept term. I also generate some simulated data to be used for model checking in the generated quantities block. 
The marginal posterior distributions for the slopes look like this: Both elevation gain and total distance cause my average speed to go down, as expected. Let’s also do a quick check of the model by plotting the distribution of the residual \[\log(\mathrm{speed}) - \log(\mathrm{speed})_\mathrm{rep}\] for each of the runs. \(\log(\mathrm{speed})_\mathrm{rep}\) are the simulated speeds from the model posterior predictive distribution. For each run (in temporal order), I display a box and whisker plot of these residuals. The speeds and simulated speeds are scaled according to the mean and variance of the observed data, so we should expect to see the residuals distributed something like a unit normal, with some variation between runs. The figure shows that the model is doing a reasonable job at replicating the data, but there is obviously some autocorrelation visible (remember that the runs are plotted in temporal order). This is probably because my fitness changed over time, which this model does not account for (we’ll get to that later). Nonetheless, the model is a decent enough representation of the data and I’ll use it to calculate GAP estimates. Calculating GAP Armed with this model, I can now calculate my simple GAP estimates by re-arranging my linear model and setting elevation to zero: \[\log (\mathrm{GAP}) = \log(\mathrm{speed}) + \beta_\mathrm{elevation}\left[\log(\delta) - \log(\mathrm{elevation} + \delta) \right].\] The counterfactual GAP speed is related to the actual speed through a correction proportional to \(\beta_\mathrm{elevation}\). Because I have uncertainty about the value of \(\beta_\mathrm{elevation}\), I’ll also have uncertainty about the GAP value that I infer for each run. The more elevation there is in a run, the more uncertainty there will be in the estimate of GAP. To compute this uncertainty, I can just plug my MCMC samples for \(\beta_\mathrm{elevation}\) into the above formula. 
Here’s the result of estimating GAP for all of my runs: In this figure, GAP is plotted against my actual pace for the run. The points are coloured by the elevation gain of the run. The error-bars are the 95% credible interval for GAP. The plot broadly makes sense: the GAP estimates are always lower than the true pace (i.e., I would have run faster on the flat), except for the one instance where I did a run where there was zero elevation gain. Furthermore, runs with more elevation have a bigger difference between GAP and my actual pace. It’s interesting to note that I only did a single run with zero elevation! This means that I’m leaning on the model assumptions, and hoping that they are plausible when extrapolating to zero elevation gain. To really test if this is true, I’d need to go out and do some more runs at zero elevation in a variety of conditions and distances (i.e., try to observe something close to the counterfactual). I couldn’t see how to get Strava’s own GAP estimate out of the API, so I didn’t do a full comparison between the average GAP produced by Strava and my simple model. I manually grabbed Strava’s GAP for the run with the largest elevation gain and did a comparison: At least in this case, my approach and the Strava data are consistent with one another. I probably won’t start using this instead of Strava’s estimates, but it was good fun to build a simple model myself. Modelling my fitness The simple approach to modelling my running pace produced reasonable results, but we saw in the model checks that there was some autocorrelation in the model errors. The simple model above does not allow for variation in fitness – it just says that my pace is a function of how far I’ll go and how hilly the run is. 
To include some notion of fitness, I expanded the model so that it looks like this: \[\log\,(\mathrm{pace}_i) = \alpha_i + \beta_\mathrm{elevation}\log\left(\mathrm{elevation}_i + \delta\right) + \beta_\mathrm{distance}\log\left(\mathrm{distance}_i\right) + \epsilon\] where the index \(i\) encodes an ordering to my runs (e.g. run 2 comes after run 1 and before run 3). The key difference between this model and the original one is that now the intercept \(\alpha\) is a function of time instead of a constant. It can be thought of as my “base pace”: i.e., fitness. I should probably index using time explicitly, but I was too lazy for that. Since I know that my fitness varies smoothly with time, I then put a random walk prior on the intercepts: \(\alpha_i = \alpha_{i - 1} + \zeta\), where \(\zeta\) is normally distributed noise with variance \(\sigma_\mathrm{rw}^2\). The size of the variance controls how rapidly my fitness can change between consecutive runs. Now, I can infer the set \(\{\alpha_i\}\) and interpret them as my “fitness” on each of my runs. 
Here’s the stan code for this model: data { int n; vector[n] log_speed; vector[n] log_elevation_gain; vector[n] log_distance; } parameters { real beta_elevation; real beta_distance; real<lower=0> sigma; vector[n] fitness_std; real<lower=0> sigma_rw; } transformed parameters { vector[n] fitness; vector[n] z; fitness[1] = fitness_std[1]; for (i in 2:n) { fitness[i] = fitness_std[i] * sigma_rw + fitness[i - 1]; } z = fitness + beta_elevation * log_elevation_gain + beta_distance * log_distance; } model { beta_elevation ~ normal(0, 1); beta_distance ~ normal(0, 1); sigma ~ normal(0, 1); log_speed ~ normal(z, sigma); fitness_std ~ normal(0, 1); sigma_rw ~ normal(0, 1); } generated quantities { vector[n] log_speed_rep; for (i in 1:n) { log_speed_rep[i] = normal_rng(z[i], sigma); } } After running MCMC, I can plot my inferred fitness over time: The grey band is the 68% credible interval, and the black dots mark when a run took place. The results look broadly as I would expect them to. I know I was pretty fit last summer, but then had an injury which bothered me until early January. I then started training again for a couple of half-marathons in May / June. A definite issue with this approach is that the amount of effort I put into runs is variable, but the model assumes that I am trying my best in every run. One way to get around this would be to include heart-rate data, which provide a measure of how strained I am during the run. But as a simple first approach, the results are reasonable. Let’s see if that autocorrelation we saw in the previous model check has been reduced: It definitely has, although there’s still some present in the middle of the plot (the large negative residuals clustered together). I think that this is highlighting an incorrect assumption that I made: fitness varies smoothly over time unless you get injured. Then it changes very abruptly. I actually got injured last year, and so my runs became notably slower for a while as I recovered. 
The injury is effectively a change-point in my fitness. I reckon what’s going on is this: because of the random walk prior, the model is forced to smoothly approach a low fitness, which means it underestimates the speed of the runs just before the injury. Final thoughts This was a fun exercise, and it was quite satisfying to do analysis of my own running data! Very simple statistical models produced relatively interesting insights – especially the fitness metric in the final section. I might try adding heart rate data into the model at some point, and see if this improves the results. If you’re interested in trying this for yourself, I put the notebook I used to generate the results from this post in a github repo (be warned: it’s a bit scrappy).