
Bayesian Neural Networks

Weight-space posteriors over neural networks: Laplace approximation, MC-dropout as approximate VI, deep ensembles, stochastic-gradient MCMC (SGLD and SGHMC), calibration diagnostics, and the function-space view via NNGP and NTK

Overview

A trained neural network gives us a function, but it does not tell us how confident to be in that function. A point-estimate classifier confidently extrapolates its decision rule into regions where no training data lives, with no language for “I don’t know.” Bayesian neural networks are the response: instead of a single weight vector w^* that minimizes a training loss, we work with a distribution p(w \mid \mathcal{D}) over weights and integrate predictions against it, so predictive variance grows where the data has nothing to say. The catch is that p(w \mid \mathcal{D}) is a distribution on \mathbb{R}^p for p on the order of 10^4 to 10^9, and the posterior is intractable in four distinct ways for any non-trivial network.

This topic develops the practical recipes that make BNN inference work at deep-learning scale. §1 frames the problem on a 2D Two Moons toy classifier so the picture is visible. §2 lifts the geometric picture into a formal weight-space posterior, derives the weight-decay-as-MAP identity, and explains the three structural ways the Bernstein–von Mises asymptotic fails for neural networks. §§3–7 develop five recipes for fitting the posterior — Laplace approximation, MC-dropout, deep ensembles, stochastic-gradient Langevin dynamics (SGLD), and stochastic-gradient HMC (SGHMC) — each addressing a different subset of the §2.4 obstacles. §8 evaluates them head-to-head on calibration metrics: expected calibration error, the Brier score, and negative log-likelihood as a proper scoring rule, plus the epistemic-aleatoric decomposition and the cold-posterior effect. §9 closes with the function-space view from the neural network Gaussian process (NNGP, Neal 1996) and the neural tangent kernel (NTK, Jacot et al. 2018), which ties weight-space methods to the Gaussian Processes machinery.

This is the sixth topic of T5 Bayesian & Probabilistic ML, after Variational Inference, Gaussian Processes, Probabilistic Programming, Mixed-Effects Models, and Stacking & Predictive Ensembles. It is the T5 flagship — the topic where the substrate of the prior five topics (variational families, GP machinery, model declaration, hierarchical structure, predictive averaging) gets combined into the catalog of methods practitioners actually reach for when the model is a deep network and uncertainty matters.

1. Why Bayesian over the weights

A trained neural network gives us a function, but it does not tell us how confident to be in that function. A point-estimate classifier on the Two Moons distribution — two interleaving crescents in \mathbb{R}^2, separable by a smooth nonlinear boundary, contaminated with isotropic noise — produces a single decision rule and confidently extrapolates that rule into regions where no training data lives. The rule is wrong somewhere, but the model has no way to say where. For an ML practitioner, this is a structural problem: medical-imaging classifiers, autonomous-vehicle perception models, and clinical-decision-support systems all need to know when they don’t know. A model that returns “99% confident” everywhere — including on inputs unlike anything it has ever seen — is a liability.

Bayesian neural networks are the response. Instead of a single weight vector w^* \in \mathbb{R}^p that minimizes a training loss, we work with a distribution p(w \mid \mathcal{D}) over weights, conditioned on the training data \mathcal{D} = \{(x_i, y_i)\}_{i=1}^n. Predictions integrate over that distribution. This topic is the catalogue of practical recipes for building, fitting, and reading that distribution when the model has thousands or millions of weights and the posterior is intractable in every standard sense. §§3–7 develop five recipes — Laplace approximation, MC-dropout, deep ensembles, SGLD, SGHMC. §8 evaluates them head-to-head on calibration. §9 closes with the function-space view. This section frames why the recipes are needed and what they share.

1.1 What a point-estimate model can’t tell us

Fix notation. A neural network with weights w \in \mathbb{R}^p is a parametric function f_w : \mathcal{X} \to \mathcal{Y} assembled from compositions of affine maps and elementwise nonlinearities. For binary classification, \mathcal{Y} = \{0, 1\} and the model produces a real-valued logit f_w(x) \in \mathbb{R} which we squash through the logistic sigmoid \sigma(z) = (1 + e^{-z})^{-1} to get a class-1 probability \sigma(f_w(x)). The likelihood under the Bernoulli observation model is p(y \mid x, w) \;=\; \sigma(f_w(x))^y \bigl(1 - \sigma(f_w(x))\bigr)^{1 - y}.

A point-estimate network minimizes a regularized cross-entropy loss w^* \;=\; \arg\min_{w \in \mathbb{R}^p}\,\Bigl\{ -\sum_{i=1}^n \log p(y_i \mid x_i, w) + \tfrac{\lambda}{2}\|w\|^2 \Bigr\} using stochastic gradient descent or one of its momentum variants, and returns the single learned weight vector w^*. At test time the model produces a single predicted probability \sigma(f_{w^*}(x^*)). Two things are missing.

The first is uncertainty about the prediction. The model returns a number near 0 or near 1, with no companion estimate of how much that number would change under a different reasonable choice of w. The second is uncertainty about w itself. The training loss is non-convex in w; the SGD trajectory ends in some local minimum determined by initialization, batch order, and learning-rate schedule; a different run with a different seed lands in a different minimum, with potentially different predictions far from the training data. The point estimate w^* is one of many plausible weight vectors, and the model has no language for that fact.

1.2 The predictive distribution and the four obstacles

The Bayesian fix is to replace w^* with the full posterior. Place a prior p(w) on the weights — typically the isotropic Gaussian \mathcal{N}(0, \tau^2 I_p) that pairs with the \ell_2 regularizer above (§2.2 makes this correspondence explicit). Bayes’ rule gives p(w \mid \mathcal{D}) \;=\; \frac{p(\mathcal{D} \mid w)\,p(w)}{p(\mathcal{D})}, \qquad p(\mathcal{D}) \;=\; \int p(\mathcal{D} \mid w)\,p(w)\,dw, where p(\mathcal{D} \mid w) = \prod_{i=1}^n p(y_i \mid x_i, w) is the likelihood of the training data under weights w. The predictive distribution at a new input x^* marginalizes the weights out: p(y^* \mid x^*, \mathcal{D}) \;=\; \int p(y^* \mid x^*, w)\,p(w \mid \mathcal{D})\,dw.

Where the data is dense, plausible weight settings agree on what y^* should be, and the integral concentrates — predictive variance is small. Where the data is sparse, plausible weight settings disagree, the integral spreads, and predictive variance grows. The model knows when it doesn’t know, because “I don’t know” is encoded as breadth in the weight posterior.

The catch is that p(w \mid \mathcal{D}) is a distribution on \mathbb{R}^p for p on the order of 10^4 to 10^9, and the posterior is intractable in four distinct ways. The marginal likelihood p(\mathcal{D}) has no closed form for any non-trivial network, so we cannot evaluate p(w \mid \mathcal{D}) pointwise without an approximation. The negative log-likelihood is deeply non-convex in w: by symmetry, every loss-landscape mode has copies under permutations of the hidden units, and the modes are separated by sharp ridges where Gaussian approximations break down. The dimension p is high enough that vanilla MCMC mixes too slowly to be useful at deep-learning scale. And the likelihood gradient \nabla_w \log p(\mathcal{D} \mid w) = \sum_{i=1}^n \nabla_w \log p(y_i \mid x_i, w) requires a full pass over \mathcal{D} — feasible for n in the hundreds, infeasible for the millions that motivate using a neural network in the first place.

Every method in this topic is an answer to one of these obstacles. Laplace approximation (§3) gives up on the multimodal structure and fits a single local Gaussian centered at the MAP estimate — cheap and asymptotically justified by the Bernstein–von Mises theorem that §2 imports from formal statistics. MC-dropout (§4) reinterprets the dropout regularizer used at training time as a Bernoulli variational posterior and turns the deterministic predictor’s stochastic forward passes into Monte Carlo posterior samples — almost free, but the variational family is rigid. Deep ensembles (§5) abandon weight-space inference altogether and treat K independently-trained networks as approximate samples in function space — empirically the strongest of the cheap methods, but the theoretical justification is delicate. Stochastic-gradient Langevin dynamics (§6) and stochastic-gradient HMC (§7) bring the asymptotic exactness of MCMC into the deep-learning regime by injecting calibrated Gaussian noise into mini-batch gradient updates — the noise compensates for the bias mini-batching introduces, and the resulting stochastic process has the posterior as its stationary distribution.

1.3 The geometric picture and the road ahead

For the rest of this section we work on a 2D toy classifier so the picture is visible. Data are drawn from Two Moons with noise level 0.20 at n = 300 points. We fit a small MLP — three hidden layers, 32 ReLU units each — by minimizing the regularized binary cross-entropy loss with \lambda = 10^{-4} for 200 Adam epochs. The point-estimate model gives the decision surface in panel (a) of Figure 1: a clean boundary separating the two classes, with predicted probabilities near 0 or near 1 across most of the input space — high confidence even at distances from the data where any honest model should hesitate.
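The setup above is concrete enough to sketch in code. A minimal version, assuming PyTorch and scikit-learn are available; note that BCEWithLogitsLoss averages over the n points, so the weight_decay passed to Adam corresponds to the text’s \lambda only up to that 1/n reduction factor.

```python
import torch
import torch.nn as nn
from sklearn.datasets import make_moons

# Two Moons, noise 0.20, n = 300 (as in the text)
X_np, y_np = make_moons(n_samples=300, noise=0.20, random_state=0)
X = torch.tensor(X_np, dtype=torch.float32)
y = torch.tensor(y_np, dtype=torch.float32)

def make_mlp():
    # three hidden layers of 32 ReLU units, single output logit f_w(x)
    return nn.Sequential(
        nn.Linear(2, 32), nn.ReLU(),
        nn.Linear(32, 32), nn.ReLU(),
        nn.Linear(32, 32), nn.ReLU(),
        nn.Linear(32, 1),
    )

model = make_mlp()
opt = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=1e-4)
loss_fn = nn.BCEWithLogitsLoss()   # Bernoulli negative log-likelihood (mean-reduced)

for epoch in range(200):           # 200 full-batch Adam epochs
    opt.zero_grad()
    loss = loss_fn(model(X).squeeze(-1), y)
    loss.backward()
    opt.step()
```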

To see what an honest predictive distribution looks like we fit five copies of the same architecture from independent random initializations. Panel (b) overlays the five 0.5-probability decision boundaries. Where the data is dense, the boundaries agree and the overlay is sharp. Far from the data — in the corners of the input frame — the boundaries fan out: five plausible models trained on the same data give five different answers in the regions where the data has nothing to say. Panel (c) renders the disagreement as a heatmap of predictive variance over the input space, the variance of the predicted class-1 probabilities across the five models. Bright regions are where the model knows it doesn’t know. The five-model ensemble previewed here is the embryo of every method in this topic, and §5 formalizes it as a deep ensemble; §§3, 4, 6, 7 will produce comparable heatmaps from very different mechanisms.
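A sketch of the panel (b)/(c) computation, reusing make_mlp, X, and y from the previous snippet; the grid extent and the five members mirror the figure, everything else is illustrative.

```python
import numpy as np

def train_one(seed: int) -> nn.Module:
    torch.manual_seed(seed)        # fresh initialization per member
    m = make_mlp()
    opt = torch.optim.Adam(m.parameters(), lr=1e-2, weight_decay=1e-4)
    for _ in range(200):
        opt.zero_grad()
        loss = nn.functional.binary_cross_entropy_with_logits(m(X).squeeze(-1), y)
        loss.backward()
        opt.step()
    return m

members = [train_one(seed) for seed in range(5)]

# evaluation grid over the input frame (extent chosen to cover the data)
xs, ys = np.meshgrid(np.linspace(-2.0, 3.0, 200), np.linspace(-1.5, 2.0, 200))
grid = torch.tensor(np.stack([xs.ravel(), ys.ravel()], axis=1), dtype=torch.float32)

with torch.no_grad():
    probs = torch.stack([torch.sigmoid(m(grid).squeeze(-1)) for m in members])  # (5, 40000)
pred_mean = probs.mean(dim=0)      # averaged prediction
pred_var = probs.var(dim=0)        # panel (c): disagreement across the five models
```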

A second axis of uncertainty matters too. If two data points sit in the same input neighborhood but carry different class labels — which Two Moons can produce in the strip between the crescents at high noise — no model on any architecture can confidently predict that neighborhood. That is aleatoric uncertainty: irreducible noise inherent in the data-generating process. The variance from disagreement among trained models is epistemic uncertainty: it reflects what the model would learn from more data, not what no model can ever learn. §8 disentangles the two formally; for now we note that BNNs primarily address epistemic uncertainty, and that a calibrated BNN should report large epistemic variance off-distribution and large aleatoric variance in noisy regions of the support.

The §2 derivation lifts this geometric picture into a formal weight-space posterior — the object the rest of the topic approximates. With that machinery in hand, §3 fits its first Bayesian neural network.

Three panels on Two Moons data: panel (a) point-estimate predicted-probability heatmap with a sharp confident decision surface that is correct near the data but arbitrary far from it; panel (b) five independently-trained MLPs’ 0.5-probability contours overlaid as red lines, agreeing tightly near the data and fanning out far from any training point; panel (c) a viridis heatmap of predictive variance computed across the five MLP predictions, dark near the data and bright in the off-distribution corners.
Figure 1. The ensemble preview. (a) A single trained MLP confidently predicts everywhere — including regions far from the training data. (b) K MLPs trained from different seeds agree on the data and disagree off it — disagreement among independently-trained models is itself a kind of uncertainty quantification. (c) The variance over the input space recovers the desideratum: the model is uncertain where it lacks data. Drag the noise slider to retrain on a noisier dataset; vary K to see how few members are needed before the variance pattern stabilizes.

2. The weight-space posterior

The §1 picture — a model that should report breadth of plausible weight settings as predictive variance — is geometric. §2 lifts it into formal mathematics. The deliverable of this section is four facts: (i) the posterior under a Gaussian weight prior has a closed-form negative-log expression; (ii) maximizing that posterior is equivalent to L2-penalized empirical-risk minimization, so the well-known weight decay regularizer is a maximum-a-posteriori estimator under a specific prior; (iii) the Bayesian central-limit theorem (Bernstein–von Mises) guarantees that, in well-specified parametric models, the posterior asymptotically concentrates as a Gaussian centered at the MLE; and (iv) the regularity conditions for that theorem fail for typical neural networks, in three specific ways that motivate the four approximation strategies of §§3–7.

2.1 The Gaussian-prior MLP

Fix a parametric MLP f_w: \mathcal{X} \to \mathbb{R} with weights w \in \mathbb{R}^p. For binary classification the conditional distribution of the label given the input is Bernoulli with logit f_w(x): p(y \mid x, w) \;=\; \sigma(f_w(x))^y \bigl(1 - \sigma(f_w(x))\bigr)^{1 - y}, where \sigma(z) = (1 + e^{-z})^{-1} is the logistic sigmoid. We complete the model with a Gaussian prior on the weights:

Definition 2.1 (Gaussian-prior MLP).

Let f_w: \mathcal{X} \to \mathbb{R} be an MLP with weights w \in \mathbb{R}^p and Bernoulli observation likelihood as above. Place the isotropic Gaussian prior w \sim \mathcal{N}(0, \tau^2 I_p) for some scale \tau > 0. The weight-space posterior given training data \mathcal{D} = \{(x_i, y_i)\}_{i=1}^n is the conditional distribution p(w \mid \mathcal{D}) \;=\; \frac{p(\mathcal{D} \mid w)\,p(w)}{p(\mathcal{D})}, \qquad p(\mathcal{D} \mid w) \;=\; \prod_{i=1}^n p(y_i \mid x_i, w).

The choice of \tau^2 I_p as the prior covariance is the default in the BNN literature and the one we’ll work with throughout. It assumes weights are a priori independent and identically distributed — a strong assumption that ignores all structure in the weight tensor (layer, channel, position) but keeps the math tractable. Heavy-tailed and hierarchical priors are the subject of Sparse Bayesian Priors; per-layer prior scales are revisited in §8.5’s cold-posterior remark. For now, \tau is a single scalar hyperparameter — large \tau is a weak prior (close to MLE), small \tau is a strong prior (heavy regularization).

2.2 The negative log-posterior, and weight decay as a Bayesian prior

Take the negative log of the posterior: -\log p(w \mid \mathcal{D}) \;=\; -\log p(\mathcal{D} \mid w) \;-\; \log p(w) \;+\; \log p(\mathcal{D}). The first term is the negative log-likelihood — the cross-entropy training loss (modulo signs). The second is a quadratic in w, since the prior is Gaussian: -\log p(w) = \frac{1}{2\tau^2}\|w\|_2^2 + \mathrm{const}_{\tau, p}, where the constant absorbs \log\bigl((2\pi\tau^2)^{p/2}\bigr). The third does not depend on w. Stripping w-independent constants we get -\log p(w \mid \mathcal{D}) \;=\; -\sum_{i=1}^n \log p(y_i \mid x_i, w) \;+\; \frac{1}{2\tau^2}\|w\|_2^2 \;+\; C, where C is a constant in w. This is structurally the L2-penalized cross-entropy training loss with regularization strength \lambda = 1/\tau^2.

Proposition 2.2 (MAP equals weight decay).

Let \mathcal{L}_{\mathrm{WD}}(w) \;:=\; -\sum_{i=1}^n \log p(y_i \mid x_i, w) \;+\; \frac{\lambda}{2}\|w\|_2^2 be the L2-penalized cross-entropy loss with weight-decay strength \lambda > 0, and let \hat{w}_{\mathrm{WD}} = \arg\min_w \mathcal{L}_{\mathrm{WD}}(w). Under Definition 2.1’s Gaussian prior with \tau^2 = 1/\lambda, the maximum-a-posteriori estimator \hat{w}_{\mathrm{MAP}} \;:=\; \arg\max_w p(w \mid \mathcal{D}) coincides with \hat{w}_{\mathrm{WD}}.

Proof.

Maximizing a function and minimizing its negative are equivalent, so \hat{w}_{\mathrm{MAP}} = \arg\min_w \bigl[-\log p(w \mid \mathcal{D})\bigr]. From the calculation above, with \tau^2 = 1/\lambda, -\log p(w \mid \mathcal{D}) \;=\; -\sum_{i=1}^n \log p(y_i \mid x_i, w) \;+\; \frac{\lambda}{2}\|w\|_2^2 \;+\; C. The first two terms are exactly \mathcal{L}_{\mathrm{WD}}(w), and C does not depend on w. So \arg\min_w \bigl[-\log p(w \mid \mathcal{D})\bigr] = \arg\min_w \mathcal{L}_{\mathrm{WD}}(w) = \hat{w}_{\mathrm{WD}}.

Three corollaries are worth pulling out, each with downstream consequences for §§3–8.

Weight decay is a Bayesian regularizer. The standard “weight decay” hyperparameter \lambda in deep-learning libraries has a Bayesian interpretation: \lambda is the inverse variance of the implicit Gaussian prior on the weights, \lambda = 1/\tau^2. A weight-decay-trained model with \lambda = 10^{-4} is the MAP estimator under a \mathcal{N}(0, 10^4\, I_p) prior on each weight. The choice of \lambda is therefore a choice of prior, and “tuning \lambda on a validation set” is empirical-Bayes hyperparameter selection. This is the canonical home for the cross-cutting concept of weight decay.
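The \lambda \leftrightarrow \tau^2 dictionary is small enough to keep as a pair of helpers; the only subtlety, flagged in the comment, is the reduction convention of the library loss.

```python
def tau2_from_weight_decay(lam: float) -> float:
    """Prior variance implied by an L2 strength lambda (Proposition 2.2)."""
    return 1.0 / lam

def weight_decay_from_tau2(tau2: float) -> float:
    """L2 strength implied by a N(0, tau^2 I) prior."""
    return 1.0 / tau2

# The lambda = 1e-4 model of Section 1 is the MAP under a N(0, 1e4 I) prior.
# If the library averages the NLL over the n training points, the optimizer's
# weight_decay argument corresponds to lambda / n rather than lambda itself.
assert tau2_from_weight_decay(1e-4) == 1e4
```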

The MAP is the starting point for §3. The Laplace approximation builds a Gaussian posterior approximation centered at \hat{w}_{\mathrm{MAP}}. Because \hat{w}_{\mathrm{MAP}} = \hat{w}_{\mathrm{WD}}, any standard PyTorch model trained with weight decay is already the center of a Laplace approximation; we just haven’t computed the surrounding curvature yet. §3 fills in that missing piece.

The cold-posterior caveat. Many BNN practitioners empirically find that tempered posteriors p(w \mid \mathcal{D})^{1/T} with T < 1 — equivalently, larger \lambda than the prior naturally specifies — give better-calibrated predictions than the strict Bayesian T = 1. This is the “cold-posterior effect” (Wenzel et al. 2020) and indicates that the Gaussian prior is misspecified for typical training datasets: the data effectively want a stronger regularizer than the principled \lambda = 1/\tau^2 delivers. The phenomenon is one of the central open problems in BNNs and is treated formally in §8.5.

2.3 The Bernstein–von Mises asymptotic

What does the weight-space posterior look like for large n? The classical result is the Bayesian central-limit theorem, also known as Bernstein–von Mises:

Theorem 2.3 (Bernstein–von Mises).

Let \{p(\cdot \mid w) : w \in \mathbb{R}^p\} be a regular parametric family with p fixed, and suppose the data \mathcal{D}_n = \{(x_i, y_i)\}_{i=1}^n are iid from p(\cdot \mid w_0) for some interior w_0. Let w^*_n denote the MLE based on \mathcal{D}_n, and let I(w_0) denote the Fisher information matrix at the true parameter. Under standard regularity conditions (see the formal-statistics central-limit-theorem topic), as n \to \infty: d_{\mathrm{TV}}\!\Bigl(\,p\bigl(\sqrt{n}(w - w^*_n) \,\big|\, \mathcal{D}_n\bigr) \,,\; \mathcal{N}\!\bigl(0,\, I(w_0)^{-1}\bigr)\,\Bigr) \;\to\; 0 in probability (over realizations of the data), where d_{\mathrm{TV}} denotes total-variation distance.

The proof is the subject of the formal-statistics central-limit-theorem topic. In words: the posterior, recentered at the MLE and rescaled by \sqrt{n}, converges in total variation to a Normal distribution with covariance equal to the inverse Fisher information. Equivalently, for large n the un-rescaled posterior is approximately p(w \mid \mathcal{D}_n) \;\approx\; \mathcal{N}\!\Bigl(w^*_n,\; \tfrac{1}{n}\, I(w_0)^{-1}\Bigr), a Gaussian centered at the MLE with covariance the inverse total Fisher information H_n^{-1} \approx (n I(w_0))^{-1}.
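The statement is easy to check numerically in a model where the posterior is available in closed form. The sketch below (NumPy and SciPy assumed) uses a Bernoulli likelihood with a Beta(1, 1) prior and compares the exact Beta posterior with its BvM Gaussian; the gap shrinks as n grows.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
theta0, n = 0.3, 2000
x = rng.binomial(1, theta0, size=n)
k = int(x.sum())

theta_mle = k / n
fisher = 1.0 / (theta0 * (1.0 - theta0))        # I(theta_0) for one Bernoulli draw

grid = np.linspace(0.25, 0.35, 400)
exact = stats.beta(1 + k, 1 + n - k).pdf(grid)  # exact posterior under Beta(1, 1)
bvm = stats.norm(theta_mle, np.sqrt(1.0 / (n * fisher))).pdf(grid)

# relative sup-norm gap between the exact posterior and its BvM Gaussian
print(np.max(np.abs(exact - bvm)) / np.max(exact))
```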

This is the asymptotic license for §3’s Laplace approximation. Under BvM, fitting a Gaussian centered at the MAP with the right curvature gives a posterior approximation that is exact in the n \to \infty limit. The Laplace construction in §3 will use the observed Fisher information H_n = -\nabla^2 \log p(w \mid \mathcal{D}_n) \big|_{w = \hat{w}_{\mathrm{MAP}}}, which differs from n I(w_0) by a fluctuation that is itself O(\sqrt{n}). The point is: there is a principled reason to expect the posterior to be approximately Gaussian centered at the MAP for large n, provided the regularity conditions hold.

2.4 Why Bernstein–von Mises fails for neural networks

The regularity conditions for BvM are: p fixed, model well-specified, MLE consistent and asymptotically Normal, Fisher information matrix positive-definite at w_0, posterior absolutely continuous with respect to the prior. These conditions fail for typical neural networks in three structural ways.

Hidden-unit permutation symmetry. Consider an MLP layer with h hidden units, W \in \mathbb{R}^{d \times h} input weights and b \in \mathbb{R}^h biases. Permuting the columns of W (and the corresponding entries of b and the next layer’s row weights) gives an exactly equivalent function. So the function f_w is invariant under S_h, the symmetric group on h elements. The likelihood p(\mathcal{D} \mid w) inherits this symmetry, and so does the posterior. There are at least h! identical modes per hidden layer. For a 32-unit layer, 32! \approx 10^{35}. The posterior has more than 10^{35} identical modes, separated by ridges of low likelihood — interpolating linearly between two modes via w_t = (1 - t) w_a + t w_b does not generally produce another mode, since the loss along the interpolation rises through a barrier. Vanilla MCMC can mix over the connected component of the posterior support that it starts in, but cannot in finite wall-clock time mix between disconnected modes. BvM’s “single Gaussian centered at the MLE” is wrong globally because there is no single MLE.

ReLU positive scaling. For ReLU activations, \mathrm{ReLU}(c\, z) = c \cdot \mathrm{ReLU}(z) for any c > 0. So if we multiply the input weights to a hidden unit by c > 0 and divide the outgoing weights by c, the function is unchanged: f_w is invariant under the action of the multiplicative group (\mathbb{R}_{>0})^h on weights. This is a continuous symmetry, not a discrete one, and it means that even within a single permutation class the MLE is not isolated — it sits on an h-dimensional submanifold of equally-good weights. The Fisher information matrix is singular along this submanifold (the gradient of the log-likelihood is zero in those directions because the function does not change), so BvM’s invocation of I(w_0)^{-1} as a positive-definite covariance fails. Weight decay breaks the symmetry by penalizing rescalings — \|c\, w_{\mathrm{in}}\|^2 + \|w_{\mathrm{out}} / c\|^2 is minimized at a specific c — so the MAP is locally unique even when the MLE is not. But the underlying likelihood is not strongly identified.
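Both symmetries can be verified numerically on a one-hidden-layer ReLU network; the sketch below uses illustrative shapes and checks that a random permutation and a positive rescaling leave the function unchanged.

```python
import torch

torch.manual_seed(0)
W1, b1 = torch.randn(2, 32), torch.randn(32)    # hidden-layer input weights / biases
w2, b2 = torch.randn(32), torch.randn(())       # output weights / bias
x = torch.randn(100, 2)

def f(W1, b1, w2, b2):
    return torch.relu(x @ W1 + b1) @ w2 + b2

base = f(W1, b1, w2, b2)

perm = torch.randperm(32)                        # hidden-unit permutation symmetry
out_perm = f(W1[:, perm], b1[perm], w2[perm], b2)

c = 3.7                                          # ReLU positive-rescaling symmetry
out_scale = f(c * W1, c * b1, w2 / c, b2)

print(torch.allclose(out_perm, base, atol=1e-5),
      torch.allclose(out_scale, base, atol=1e-4))   # both True
```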

Over-parametrization. Modern deep networks have p \gg n. ResNet-50 has \sim 25 \times 10^6 parameters; ImageNet has \sim 10^6 training images. BvM operates in the regime p fixed, n \to \infty. The opposite regime — p \to \infty with n fixed or growing slowly — is the regime of high-dimensional statistics, and the asymptotics there are different. The infinite-width limit of §9 makes one specific version of this regime tractable (function-space rather than weight-space), but for finite-width over-parametrized networks, BvM gives no asymptotic guidance.

These three failure modes organize the rest of the topic. Laplace (§3) gives up on the global picture and approximates the posterior locally at the MAP — accepting the failure of BvM globally and pursuing a local Gaussian that captures local curvature even if it misses the h! permutation copies and the ReLU rescaling submanifold. MC-dropout (§4) uses a tractable but rigid variational family that sidesteps the exact-posterior question entirely. Deep ensembles (§5) draw K approximate samples from K different initializations, getting some coverage of distinct modes. SG-MCMC (§§6–7) accepts the multimodality and tries to mix between modes via Langevin dynamics on minibatches.

Remark (Why naïve sampling fails).

Three obstacles in compact form, as a checklist for §§3–7. (i) Multimodality: h! exact replicas per layer; vanilla MCMC cannot mix in any practical wall-clock time. (ii) Identifiability: ReLU positive-scaling submanifold; Fisher information singular within each mode under the un-regularized likelihood. (iii) Computational cost: full-data gradients infeasible at deep-learning scale; vanilla HMC needs the gradient at every leapfrog step, an O(n) pass. Each method in §§3–7 sidesteps a subset of these obstacles via a specific technical trick.

Two panels: panel (a) PCA scatter of trained MLP weight vectors projected to their first two principal components, color-coded by final training loss; panel (b) loss along a linear interpolation between two trained models, rising from the trained-model floor to a peak barrier and back down — a non-convex ridge separating two modes.
Figure 2. The loss landscape is genuinely multi-modal. (a) N MLPs from independent seeds project into discrete clusters in PCA space — each cluster a permutation/scaling class of §2.4’s hidden-unit symmetry. (b) Linearly interpolating between two trained models passes through a barrier of strictly higher loss. Click any two points in (a) to set the endpoints; a single Gaussian (Laplace, mean-field VI) cannot bridge that ridge — which is why §5 (deep ensembles) and §§6–7 (SG-MCMC) exist.

3. The Laplace approximation

§2 ends with a problem and a permission. The problem is that the true weight-space posterior p(w \mid \mathcal{D}) is intractable: no closed-form normalizer, more than h! identical modes, and continuous symmetries that make the Fisher information singular within each mode. The permission is Bernstein–von Mises: under the regularity conditions the posterior is asymptotically a Gaussian centered at the MLE. The Laplace approximation accepts that the BvM regularity conditions don’t strictly hold globally but observes that they hold locally at any sufficiently regular minimum of the negative log-posterior — and builds a Gaussian approximation by Taylor-expanding around such a minimum. The construction goes back to Laplace’s 1774 calculation of the asymptotic value of \int e^{-n h(\theta)}\,d\theta; MacKay’s 1992 PhD thesis brought it into Bayesian neural networks; the modern revival is Daxberger et al. 2021 (laplace-torch), which scales the construction to ImageNet-class networks via Kronecker-factored curvature.

3.1 The construction

Let U(w) := -\log p(w \mid \mathcal{D}) be the negative log-posterior, modulo the unknown constant \log p(\mathcal{D}). From §2.2 we have U(w) \;=\; -\sum_{i=1}^n \log p(y_i \mid x_i, w) \;+\; \frac{1}{2\tau^2}\|w\|_2^2 \;+\; \mathrm{const}. Suppose \hat{w} := \hat{w}_{\mathrm{MAP}} = \arg\min_w U(w) is a local minimum at which U is twice continuously differentiable and the Hessian H := \nabla^2 U(w)|_{w = \hat{w}} is positive-definite. Taylor-expand U to second order around \hat{w}: U(w) \;=\; U(\hat{w}) \;+\; \nabla U(\hat{w})^\top (w - \hat{w}) \;+\; \tfrac{1}{2}(w - \hat{w})^\top H (w - \hat{w}) \;+\; O(\|w - \hat{w}\|^3). At a local minimum \nabla U(\hat{w}) = 0, so the linear term vanishes. Dropping the higher-order terms — the Laplace approximation itself — we get U(w) \;\approx\; U(\hat{w}) \;+\; \tfrac{1}{2}(w - \hat{w})^\top H (w - \hat{w}). Exponentiating and renormalizing, p(w \mid \mathcal{D}) \;\approx\; q(w) \;:=\; \mathcal{N}\bigl(w \,\big|\, \hat{w},\, H^{-1}\bigr). The construction takes a Gaussian whose mean is the MAP and whose precision is the local curvature of the negative log-posterior. Two further facts come for free.

The Hessian decomposes additively. From the form of U, H \;=\; -\nabla^2_w \log p(\mathcal{D} \mid w)\big|_{w = \hat{w}} \;+\; \frac{1}{\tau^2} I_p \;=:\; H_{\mathrm{data}} + H_{\mathrm{prior}}, the data Hessian plus the prior precision. The prior contribution is constant in w, isotropic, and explicitly positive-definite; it stabilizes the data Hessian when the latter has near-zero or negative eigenvalues from saddle directions or the §2.4 ReLU rescaling submanifold. In practice we either compute H directly or compute H_{\mathrm{data}} and add \tau^{-2} I_p explicitly.

The marginal likelihood comes for free as a side effect. The Laplace approximation also gives a closed-form estimate of p(\mathcal{D}): \log p(\mathcal{D}) \;\approx\; -U(\hat{w}) \;+\; \frac{p}{2}\log(2\pi) \;-\; \frac{1}{2}\log\det H, known as the Laplace marginal-likelihood approximation. This is the basis of Variational Bayes for Model Selection and of the BIC’s asymptotic form (see the formal-statistics topic on model selection and information criteria).

Definition 3.1 (Laplace approximation).

Given a twice-differentiable negative log-posterior U(w) with positive-definite Hessian H = \nabla^2 U(\hat{w}) at a local minimum \hat{w}, the Laplace approximation to the posterior is the Gaussian q_{\mathrm{Lap}}(w) \;:=\; \mathcal{N}\bigl(w \,\big|\, \hat{w},\, H^{-1}\bigr).

3.2 Asymptotic exactness — and where it stops

Theorem 3.2 (Asymptotic exactness of Laplace under BvM).

Suppose the conditions of Theorem 2.3 hold. Let \hat{w}_n be the MAP based on n iid observations and H_n the negative-log-posterior Hessian at \hat{w}_n. Let q_{\mathrm{Lap}}^{(n)}(w) = \mathcal{N}(w \,|\, \hat{w}_n, H_n^{-1}) be the Laplace approximation. Then d_{\mathrm{TV}}\!\bigl(p(w \mid \mathcal{D}_n),\; q_{\mathrm{Lap}}^{(n)}(w)\bigr) \;\to\; 0 \quad \text{as } n \to \infty in probability.

Proof.

Two ingredients combine. First, \hat{w}_n = w^*_n + O(1/n), where w^*_n is the MLE: differentiating U_n(w) = -\log p(\mathcal{D}_n \mid w) - \log p(w) and setting to zero gives \nabla[-\log p(\mathcal{D}_n \mid w)]\big|_{\hat{w}_n} = \nabla \log p(w)\big|_{\hat{w}_n}. The right-hand side is -\hat{w}_n / \tau^2 for the Gaussian prior, so \nabla[\text{NLL}]\big|_{\hat{w}_n} = -\hat{w}_n/\tau^2. The MLE satisfies \nabla[\text{NLL}]\big|_{w^*_n} = 0, so \hat{w}_n - w^*_n solves the first-order condition \nabla^2[\text{NLL}]\big|_{w^*_n}(\hat{w}_n - w^*_n) = -w^*_n/\tau^2 + O(\|\hat{w}_n - w^*_n\|^2), which gives \hat{w}_n - w^*_n = O(\|H_n\|^{-1}) = O(1/n) since the data Hessian scales as n.

Second, H_n / n \to I(w_0) in probability by the law of large numbers applied to the per-observation Fisher information contributions. So H_n^{-1} \to (n I(w_0))^{-1} at first order.

Putting these together, \mathcal{N}(\hat{w}_n, H_n^{-1}) \to \mathcal{N}(w^*_n, (n I(w_0))^{-1}) in total variation. Theorem 2.3 says the true posterior also converges to this limit. By the triangle inequality for d_{\mathrm{TV}}, the Laplace approximation converges to the true posterior.

The §2.4 caveats apply with full force here. The proof above assumes (i) p fixed, (ii) H_n positive-definite at \hat{w}_n, and (iii) the BvM regularity conditions. For neural networks: p is not fixed in any meaningful sense (over-parametrization), H_{\mathrm{data}} has near-zero eigenvalues from the ReLU rescaling submanifold (we rescue this with H_{\mathrm{prior}} = \tau^{-2} I_p, but the result is a Laplace approximation around an artificially regularized minimum), and the global posterior has many modes that the Laplace approximation around any one of them cannot represent. So Theorem 3.2 should be read as a local guarantee that the Gaussian fit is the best second-order approximation around \hat{w}, plus an asymptotic guarantee that this best approximation converges to the truth in well-specified parametric settings. For BNNs we get the local part and lose most of the asymptotic part.

This is enough for a useful method. The Laplace BNN’s predictive mean recovers the point estimate, its predictive variance grows away from the data exactly because the local quadratic in U flattens away from \hat{w}, and the construction needs no sampling at all once \hat{w} and H are in hand. What it loses — multi-modality, ReLU-rescaling spread — is exactly what §§5–7 will recover.

3.3 The predictive distribution under linearization

A Gaussian posterior over weights does not directly give a Gaussian posterior over outputs, because f_w(x) is a nonlinear function of w. The integral p(y^* \mid x^*, \mathcal{D}) \;=\; \int p(y^* \mid x^*, w)\,q_{\mathrm{Lap}}(w)\,dw has no closed form. The standard reduction is first-order linearization of the network around the MAP: f_w(x) \;\approx\; f_{\hat{w}}(x) \;+\; J(x)^\top (w - \hat{w}), \qquad J(x) := \nabla_w f_w(x)\big|_{w = \hat{w}}. Under this linearization, f_w(x) is an affine function of w. If w \sim \mathcal{N}(\hat{w}, H^{-1}), then the logit f_w(x) is Gaussian with mean f_{\hat{w}}(x) and variance J(x)^\top H^{-1} J(x).

Proposition 3.3 (Linearized Laplace predictive).

Define \tilde{m}(x) := f_{\hat{w}}(x) and \tilde{v}(x) := J(x)^\top H^{-1} J(x). Under the linearized Laplace approximation, the logit at x is approximately Gaussian: f_w(x) \mid \mathcal{D} \;\stackrel{\mathrm{lin}}{\sim}\; \mathcal{N}\bigl(\tilde{m}(x),\, \tilde{v}(x)\bigr). The class-1 probability under the moment-matched probit approximation is \hat{p}(y = 1 \mid x, \mathcal{D}) \;\approx\; \sigma\!\left(\frac{\tilde{m}(x)}{\sqrt{1 + \pi \tilde{v}(x) / 8}}\right).

The factor \pi/8 comes from MacKay’s 1992 Gaussian-CDF approximation to the sigmoid: \int \sigma(z)\, \mathcal{N}(z; m, v)\,dz \approx \sigma\bigl(m / \sqrt{1 + \pi v / 8}\bigr). The reader sees the predictive variance \tilde{v}(x) enter explicitly: where \tilde{v}(x) is small (data-dense regions, where the network is well-determined), the predictive collapses to \sigma(\tilde{m}(x)) — the point estimate from §1. Where \tilde{v}(x) is large (off-distribution regions, where small changes in w produce large changes in f_w(x)), the predictive is squashed toward 0.5 — the model’s confession of ignorance.

In practice we have two implementation choices for the predictive: (a) closed-form via Prop 3.3, computing J(x) at each test point and the matrix product J(x)^\top H^{-1} J(x); (b) Monte Carlo, sampling w^{(s)} \sim \mathcal{N}(\hat{w}, H^{-1}) and averaging \sigma(f_{w^{(s)}}(x)) over s = 1, \ldots, S samples. For the §3.5 Two Moons figure we use (b) because it lets us also visualize sample decision boundaries; in production BNN libraries the closed-form (a) is faster.
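A sketch of route (a) for a single test point, assuming a trained MAP model and a precomputed posterior covariance H_inv; the Jacobian of the scalar logit comes from one ordinary backward pass. Names here are illustrative, not from any particular library.

```python
import math
import torch
from torch.nn.utils import parameters_to_vector

def linearized_laplace_prob(model, H_inv, x_star):
    """Probit-approximate class-1 probability at one test point (Proposition 3.3)."""
    params = [p for p in model.parameters() if p.requires_grad]
    logit = model(x_star.unsqueeze(0)).squeeze()      # m~(x) = f_{w_hat}(x)
    grads = torch.autograd.grad(logit, params)        # J(x), one block per weight tensor
    J = parameters_to_vector(grads)                   # flatten to a length-p vector
    v = J @ H_inv @ J                                 # v~(x) = J^T H^{-1} J
    kappa = 1.0 / math.sqrt(1.0 + math.pi * float(v) / 8.0)
    return torch.sigmoid(kappa * logit).item()
```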

3.4 Practical curvature

Computing and storing H \in \mathbb{R}^{p \times p} for p in the millions is impossible — p^2 memory, O(p^3) inversion. Three standard reductions trade fidelity for tractability.

Last-layer Laplace (Daxberger et al. 2021). Fix all but the final linear layer’s weights at their MAP values; do Laplace only over the last layer. The last layer typically has p_L \ll p parameters — for a 32-unit penultimate layer × C output classes, p_L \approx 32 C. H_{\mathrm{last\text{-}layer}} \in \mathbb{R}^{p_L \times p_L} is cheap. Empirically competitive with full Laplace on a wide range of tasks; the underlying intuition is that in over-parametrized networks, most of the weight-space uncertainty that affects predictions is concentrated in the final layer.

Kronecker-factored approximate curvature (KFAC; Martens & Grosse 2015). For each layer, approximate the Fisher information block as a Kronecker product H_\ell \approx A_\ell \otimes G_\ell, where A_\ell is the input-activation second moment and G_\ell is the output-gradient second moment. Storage cost reduces from p_\ell^2 to \dim(A_\ell)^2 + \dim(G_\ell)^2 per layer. The Kronecker assumption is structural (it ignores cross-layer covariance entirely), but matches the block-diagonal-plus-Kronecker structure of the Fisher information used by natural-gradient methods.

Diagonal Fisher. Take only the diagonal of H. Cheapest possible — O(p) memory — but discards all weight correlations. Often too aggressive: the diagonal entries are unreliable proxies for the full eigenvalue spectrum, and diagonal-Fisher Laplace BNNs frequently underestimate predictive variance in correlated directions.

For the §3.5 Two Moons example with p = 2241 parameters, full Laplace is feasible. For real-world BNNs, last-layer or KFAC is the practical choice; the laplace-torch library (Daxberger et al. 2021) implements all three reductions and is the default tool.
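For orientation, the last-layer / KFAC route with laplace-torch looks roughly like the sketch below. The call signatures are quoted from memory of the library’s documented interface and may differ across versions; treat it as the shape of the workflow rather than a reference, and assume train_loader is a DataLoader over the training set and X_test a batch of test inputs.

```python
from laplace import Laplace

# signatures from memory of laplace-torch's documented usage -- verify against your version
la = Laplace(model, "classification",
             subset_of_weights="last_layer",   # or "all"
             hessian_structure="kron")         # or "diag", "full"
la.fit(train_loader)                           # accumulate curvature at the trained MAP
la.optimize_prior_precision()                  # empirical-Bayes choice of the prior scale
probs = la(X_test)                             # linearized predictive probabilities
```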

3.5 Algorithm and worked example

Algorithm 3.4 (Laplace-fit BNN).

Input: training data \mathcal{D}, network f_w, prior scale \tau^2, sample count S. Output: predictive distribution \hat{p}(y \mid x) at test inputs.

  1. Compute \hat{w} := \arg\min_w U(w) via Adam or SGD with weight decay \lambda = 1/\tau^2.
  2. Compute the data Hessian H_{\mathrm{data}} := -\nabla^2 \log p(\mathcal{D} \mid w)|_{w = \hat{w}} via autodiff (or one of the §3.4 reductions).
  3. Form H := H_{\mathrm{data}} + \tau^{-2} I_p and stabilize: H \leftarrow H + \delta I_p for small \delta > 0 if any eigenvalue is non-positive.
  4. Cholesky-factor H = L L^\top.
  5. Sample z^{(s)} \sim \mathcal{N}(0, I_p) for s = 1, \ldots, S; set w^{(s)} := \hat{w} + L^{-\top} z^{(s)}.
  6. Predict via \hat{p}(y \mid x) := S^{-1} \sum_s p(y \mid x, w^{(s)}), or via Prop 3.3’s closed form.

The §1 first-seed model serves as \hat{w}. Step 2 uses torch.autograd.functional.hessian on the negative log-posterior of the full training set; for p = 2241 this is ~2 s on CPU. The \tau^{-2} I_p contribution is automatic (we trained with weight decay \lambda = 10^{-4}, so \tau^2 = 10^4 and the prior precision contribution is 10^{-4} I_p). Stabilization \delta = 10^{-3} handles the residual ReLU-rescaling near-singularity. The Cholesky factorization on a 2241 \times 2241 matrix is ~1 s; drawing S = 100 posterior weight samples and predicting on the 40{,}000-point grid is ~3 s.
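A sketch of Algorithm 3.4 for the small Two Moons model. It reuses model, X, y, and grid from the earlier snippets and uses torch.func (PyTorch ≥ 2.0) to express the negative log-likelihood as a function of a flat weight vector, which is what a Hessian call needs; the prior precision and jitter values follow the text.

```python
import math
import torch
from torch.func import functional_call, hessian

names = [n for n, _ in model.named_parameters()]
shapes = [p.shape for _, p in model.named_parameters()]
w_hat = torch.cat([p.detach().reshape(-1) for p in model.parameters()])   # step 1: the MAP
P = w_hat.numel()

def unflatten(w_vec):
    out, i = {}, 0
    for name, shape in zip(names, shapes):
        k = math.prod(shape)
        out[name] = w_vec[i:i + k].reshape(shape)
        i += k
    return out

def neg_log_lik(w_vec):
    logits = functional_call(model, unflatten(w_vec), (X,)).squeeze(-1)
    return torch.nn.functional.binary_cross_entropy_with_logits(logits, y, reduction="sum")

H = hessian(neg_log_lik)(w_hat)                        # step 2: data Hessian at the MAP
H = 0.5 * (H + H.T)                                    # symmetrize numerical noise
H = H + (1e-4 + 1e-3) * torch.eye(P)                   # step 3: prior precision 1/tau^2 plus jitter

posterior = torch.distributions.MultivariateNormal(w_hat, precision_matrix=H)  # steps 4-5
w_samples = posterior.sample((100,))

with torch.no_grad():                                  # step 6: Monte Carlo predictive on the grid
    pred = torch.stack([
        torch.sigmoid(functional_call(model, unflatten(w), (grid,)).squeeze(-1))
        for w in w_samples
    ])
pred_mean, pred_std = pred.mean(dim=0), pred.std(dim=0)
```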

Three panels on Two Moons data: panel (a) Laplace-BNN predictive mean RdBu_r heatmap visually indistinguishable from the §1 point-estimate predictive; panel (b) Laplace predictive standard deviation viridis heatmap with dark regions hugging the data and bright regions far from any training point; panel (c) sampled Laplace 0.5-probability decision boundaries overlaid in translucent blue.
Figure 3. The Laplace BNN. (a) The predictive mean is the point estimate — Laplace doesn’t move the MAP. (b) The predictive standard deviation grows away from the data, recovering the §1 desideratum from a single trained model’s local Gaussian. (c) Sampled decision boundaries fan out within one mode of the loss landscape, missing the multi-mode structure §5 will recover. Slide τ² (prior scale, log axis) to widen or sharpen the posterior; slide S to trade compute for sampling smoothness; reseed for a fresh sample set.

4. MC-dropout as approximate variational inference

The Laplace approximation of §3 builds a Gaussian posterior over weights from a single point estimate plus its local curvature. The construction is principled but expensive once the network is large — full Laplace needs p^2 Hessian storage, and even the §3.4 reductions (last-layer, KFAC) require non-trivial implementation effort. MC-dropout (Gal & Ghahramani 2016) is the cheap end of the BNN spectrum: it observes that a dropout-regularized network, trained the standard way, is already doing variational inference under a specific variational family, and that the only change needed at test time to extract a predictive distribution is to leave dropout on and average over T stochastic forward passes. No new training, no Hessian, no extra hyperparameters beyond what dropout already requires. The trade is that the variational family is rigid in a way that systematically underestimates epistemic uncertainty in some regimes; §8’s calibration analysis quantifies the trade.

4.1 Recap: dropout as a regularizer

Dropout (Srivastava et al. 2014) trains a network by, at each minibatch, sampling a Bernoulli mask b_\ell \in \{0, 1\}^{h_{\ell-1}} for each layer’s input activations with b_{\ell, j} \sim \mathrm{Bernoulli}(\rho_\ell), multiplying activations elementwise by b_\ell and rescaling by 1/\rho_\ell so the post-mask expected activation is unchanged. The forward pass uses the masked activations; the backward pass differentiates only through the unmasked units. At test time, dropout is conventionally turned off — the random mask is replaced by its expectation, making the forward pass deterministic — and the resulting prediction is interpreted as an approximate average over the implicit ensemble of “thinned” subnetworks.

The standard story stops here: dropout is a regularizer, the deterministic test-time prediction is the answer, full stop. Gal & Ghahramani 2016 disrupt this by proving that the standard dropout training loss is, up to constants and a fixed prior choice, the negative ELBO of a specific variational posterior — and that to extract that posterior’s predictive distribution at test time, you don’t turn dropout off, you leave it on and Monte-Carlo over T stochastic forward passes.

4.2 The Bernoulli variational family

Let the network have L weight matrices W_\ell \in \mathbb{R}^{h_{\ell-1} \times h_\ell}, \ell = 1, \ldots, L. For each layer, define a Bernoulli mask matrix Z_\ell as a diagonal matrix with \mathrm{Bern}(\rho_\ell) entries on the diagonal: Z_\ell = \mathrm{diag}(b_\ell), \qquad b_\ell \in \{0, 1\}^{h_{\ell-1}}, \qquad b_{\ell, j} \stackrel{\mathrm{iid}}{\sim} \mathrm{Bernoulli}(\rho_\ell). The variational weight matrix at layer \ell is the random matrix W_\ell^{\mathrm{var}} := Z_\ell M_\ell, where M_\ell \in \mathbb{R}^{h_{\ell-1} \times h_\ell} is a learnable mean weight matrix. Equivalently, the j-th row of W_\ell^{\mathrm{var}} is either the j-th row of M_\ell (with probability \rho_\ell) or the zero row (with probability 1 - \rho_\ell): dropping input unit j zeroes the row of weights that carries its activation forward.

Definition 4.1 (Bernoulli variational family).

Let \boldsymbol{\rho} = (\rho_1, \ldots, \rho_L) \in (0, 1]^L be a vector of dropout retain-probabilities, fixed in advance. The Bernoulli variational family over network weights is \mathcal{Q}_{\mathrm{Bern}}(\boldsymbol{\rho}) \;=\; \Bigl\{\, q_{\boldsymbol{M}}(W_1, \ldots, W_L) \;:\; W_\ell = \mathrm{diag}(b_\ell)\, M_\ell,\;\; b_{\ell, j} \stackrel{\mathrm{iid}}{\sim} \mathrm{Bern}(\rho_\ell) \,\Bigr\}. The variational parameters are the mean weight matrices \boldsymbol{M} = (M_1, \ldots, M_L).

This is a discrete variational family. At each evaluation, W_\ell is one of 2^{h_{\ell-1}} possible matrices, parametrized by which subset of M_\ell’s rows are zeroed. The family is rigid: \rho_\ell is a fixed hyperparameter, not learned (Concrete dropout in §4.5 lifts this), and the structure of the variational posterior is determined entirely by the network’s architecture and the choice of which activations to drop.

4.3 The Gal–Ghahramani equivalence

Theorem 4.2 (Dropout training is variational inference).

Let \mathcal{L}_{\mathrm{drop}}(\boldsymbol{M}) be the standard dropout training objective: minibatch-averaged cross-entropy of the network with Bernoulli-masked weights W_\ell = \mathrm{diag}(b_\ell)\, M_\ell, with L_2 weight decay \lambda on the mean weights \boldsymbol{M}. Suppose the prior over weights is the layer-wise isotropic Gaussian p(M_\ell) = \mathcal{N}(0, \tau^2 I) with \tau^2 = (1 - \rho_\ell) / (2 N \lambda), where N is the size of the training set. Then \mathcal{L}_{\mathrm{drop}}(\boldsymbol{M}) \;=\; -\mathrm{ELBO}\bigl(q_{\boldsymbol{M}} \,\big\|\, p(\cdot \mid \mathcal{D})\bigr) \;+\; \mathrm{const}, where the ELBO is the variational evidence lower bound of q_{\boldsymbol{M}} \in \mathcal{Q}_{\mathrm{Bern}}(\boldsymbol{\rho}) against the posterior p(\boldsymbol{M} \mid \mathcal{D}) of Definition 2.1’s BNN with the prior above. Minimizing \mathcal{L}_{\mathrm{drop}}(\boldsymbol{M}) over \boldsymbol{M} is equivalent to maximizing the ELBO, hence to reverse-KL minimization within \mathcal{Q}_{\mathrm{Bern}}(\boldsymbol{\rho}).

Proof.

The ELBO decomposes (cf. Variational Inference) as \mathrm{ELBO}(q) \;=\; \mathbb{E}_q[\log p(\mathcal{D} \mid w)] \;-\; \mathrm{KL}\bigl(q \,\big\|\, p\bigr). For q = q_{\boldsymbol{M}} \in \mathcal{Q}_{\mathrm{Bern}}(\boldsymbol{\rho}), the expected log-likelihood is the expectation over Bernoulli masks of the network’s log-likelihood at the masked weights: \mathbb{E}_{q_{\boldsymbol{M}}}[\log p(\mathcal{D} \mid w)] \;=\; \mathbb{E}_{b_1, \ldots, b_L}\!\left[\sum_{i=1}^N \log p(y_i \mid x_i, \mathrm{diag}(b_1)\, M_1, \ldots, \mathrm{diag}(b_L)\, M_L)\right]. Standard dropout training Monte-Carlo-estimates this expectation by drawing one mask per minibatch, which is the unbiased single-sample estimator of the inner expectation.

For the KL term, q_{\boldsymbol{M}} has all of its mass on points of the form W_\ell = \mathrm{diag}(b_\ell)\, M_\ell. The KL of q_{\boldsymbol{M}} to a continuous Gaussian prior p(W_\ell) = \mathcal{N}(0, \tau^2 I) is technically infinite (a discrete distribution against a continuous one), so Gal & Ghahramani treat the variational family as a limit of Gaussian approximations whose covariance shrinks to zero around the discrete support points. After the algebra (Gal & Ghahramani 2016 Appendix), the KL term reduces, up to additive constants in \boldsymbol{M}, to \mathrm{KL}\bigl(q_{\boldsymbol{M}} \,\big\|\, p\bigr) \;=\; \sum_{\ell = 1}^L \frac{1 - \rho_\ell}{2\tau^2}\,\|M_\ell\|_F^2 \;+\; \mathrm{const}. With \tau^2 = (1 - \rho_\ell)/(2 N \lambda) this becomes N \lambda \sum_\ell \|M_\ell\|_F^2, which is the L_2-weight-decay penalty in the standard dropout loss. Putting the two pieces together, -\mathrm{ELBO}(q_{\boldsymbol{M}}) equals \mathcal{L}_{\mathrm{drop}}(\boldsymbol{M}) up to a constant, as claimed.

The theorem’s content is that the same algorithm — minibatch SGD on cross-entropy plus L_2, with Bernoulli masks at each forward pass — can be read in two ways. The frequentist reads it as “regularized empirical-risk minimization with dropout noise”; the Bayesian reads it as “variational inference under the Bernoulli family.” Both readings yield the same trained network. The Bayesian reading does, however, instruct us to do something different at test time.

4.4 The MC-dropout predictive

The standard test-time recipe — turn dropout off and use the deterministic mean network — corresponds, in the variational reading, to evaluating f_{\boldsymbol{M}}(x) at the mean weights, which is the MAP-style point estimate under q_{\boldsymbol{M}}. To get a Bayesian predictive distribution we leave dropout on and Monte-Carlo:

Proposition 4.3 (MC-dropout predictive).

Let \boldsymbol{M} be the mean weights of a trained dropout network and let \hat{f}_t(x) denote the output at x produced by a stochastic forward pass with a fresh set of Bernoulli masks drawn at every layer, t = 1, \ldots, T. Then under the Bernoulli variational posterior q_{\boldsymbol{M}} \in \mathcal{Q}_{\mathrm{Bern}}(\boldsymbol{\rho}): \hat{p}(y = 1 \mid x, \mathcal{D}) \;\approx\; \frac{1}{T} \sum_{t=1}^T \sigma\bigl(\hat{f}_t(x)\bigr), \qquad \widehat{\mathrm{Var}}_{q_{\boldsymbol{M}}}[\sigma(f_w(x))] \;\approx\; \frac{1}{T}\sum_{t=1}^T \sigma\bigl(\hat{f}_t(x)\bigr)^2 - \left(\frac{1}{T}\sum_{t=1}^T \sigma\bigl(\hat{f}_t(x)\bigr)\right)^2. Both estimators converge in T at the standard O(T^{-1/2}) Monte Carlo rate.

Proof.

Direct application of the law of large numbers to iid samples \{\hat{f}_t(x)\}_{t=1}^T from the variational predictive. The first identity is the sample mean of the bounded random variable \sigma(f_w(x)) \in [0, 1] under q_{\boldsymbol{M}}, which has variance at most 1/4, so by the CLT the error is O(T^{-1/2}). The second identity is the corresponding sample variance, which converges at the same rate.

In code, this is two lines: keep the model in train() mode (which leaves dropout active) at test time and run T forward passes. The single hyperparameter is T. Practical choice: T \in [10, 100] for visual-quality predictive heatmaps, T \geq 50 for stable calibration-metric estimates (§8).
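Spelled out below, with the caveat that train() mode also affects batch-norm layers, so models that contain them should instead switch only their dropout modules into training mode.

```python
import torch

@torch.no_grad()
def mc_dropout_predict(model, x, T: int = 50):
    """Proposition 4.3: mean and variance over T stochastic forward passes."""
    model.train()                       # train() keeps nn.Dropout sampling fresh masks
    probs = torch.stack([torch.sigmoid(model(x).squeeze(-1)) for _ in range(T)])
    model.eval()                        # restore deterministic behaviour afterwards
    return probs.mean(dim=0), probs.var(dim=0)
```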

4.5 Limits and extensions

The Bernoulli family is rigid in a specific way: each weight row is either kept exactly (with probability \rho) or zeroed exactly (with probability 1 - \rho). The family cannot represent continuous deviations from the mean weights — it has zero probability mass on weight matrices that are slightly perturbed from M_\ell but not exactly some Bernoulli-masked version. Compared to the §3 Laplace family (a full Gaussian with covariance H^{-1}), MC-dropout has dramatically less expressive power. This shows up in three places.

Remark (Limits of MC-dropout).

Underestimated epistemic variance in over-parametrized regimes. When the network has many redundant units, dropout’s mask-induced variance is largely cancelled by redundancy — the predictive variance saturates at a value determined by the dropout rate, not by how far the test point is from the data. Empirical studies (Foong et al. 2020, “On the Expressiveness of Approximate Inference in Bayesian Neural Networks”) show MC-dropout predictive variance converges to a constant in the network width even in regions where the true posterior predictive variance grows.

No multi-modality. The Bernoulli family, like the §3 Laplace family, parametrizes a single mode of the loss landscape. The §2.4 hidden-unit permutation symmetry is invisible to MC-dropout — the variational posterior captures noise around the trained mean weights, not separated modes. §5 deep ensembles will recover multi-modality; MC-dropout cannot.

Concrete dropout (Gal et al. 2017) lifts the rigid-\rho limitation by making the dropout rates \rho_\ell learnable parameters optimized jointly with \boldsymbol{M} via a continuous relaxation of the Bernoulli (the Concrete distribution, Maddison et al. 2017). Variational dropout (Kingma et al. 2015) replaces Bernoulli masks with continuous Gaussian multiplicative noise, enabling weight-correlation modeling at the cost of more careful KL accounting. Structured dropout (DropConnect, DropBlock for CNNs, attention-head dropout for transformers) adapts the Bernoulli family to architecture-specific symmetries.

Three panels on Two Moons data: panel (a) MC-dropout predictive mean RdBu_r heatmap visually similar to a deterministic predictor; panel (b) MC-dropout predictive standard deviation viridis heatmap with flatter off-distribution std than a full-Hessian Laplace; panel (c) twenty MC-dropout sampled 0.5-probability contours overlaid in translucent green.
Figure 4. MC-dropout. (a) The predictive mean is essentially the point-estimate predictor — turning dropout off at test confirms it. (b) The predictive standard deviation captures epistemic uncertainty qualitatively but is flatter off-distribution than Laplace. (c) Sampled decision boundaries fan out narrowly. Drag the dropout-rate slider to retrain (1−ρ = 0 recovers a deterministic predictor; large 1−ρ over-regularizes); slide T to trade compute for variance smoothness; toggle dropout-on-test to compare deterministic vs MC predictions side by side.

5. Deep ensembles as a function-space posterior proxy

The §3 Laplace approximation and §4 MC-dropout share a structural limitation: each parametrizes a single mode of the loss landscape. Laplace’s Gaussian is centered at one MAP; MC-dropout’s Bernoulli posterior puts its mass around one mean weight matrix \boldsymbol{M}. Neither captures the §2.4 hidden-unit permutation symmetry, the §2.4 ReLU rescaling submanifold, or any of the genuinely separated modes that §2’s Figure 2 made visible. Deep ensembles (Lakshminarayanan, Pritzel & Blundell 2017) take the opposite tack: don’t fit a parametric posterior at all. Instead, train K networks from independent random initializations, treat the K trained weight vectors as approximate samples from K different modes of the posterior, and report the ensemble’s predictive as an equal-weighted mixture. The construction is brutally simple — the entire method is “train more times” — but it consistently outperforms more sophisticated single-mode methods on calibration and out-of-distribution detection benchmarks. Wilson & Izmailov 2020 articulate the why: deep ensembles are a coarse but legitimate function-space approximation to the Bayesian posterior predictive, in a regime where the function-space view dominates the weight-space view.

5.1 The construction

Definition 5.1 (Deep ensemble).

Let K \in \mathbb{N} be an ensemble size. A deep ensemble of size K is a collection \{\hat{w}^{(k)}\}_{k=1}^K of MAP estimates obtained by training the same model architecture from K independent random initializations on the same training data \mathcal{D}. The ensemble predictive distribution is the equal-weighted mixture \hat{p}_{\mathrm{ens}}(y \mid x, \mathcal{D}) \;:=\; \frac{1}{K}\sum_{k=1}^K p(y \mid x, \hat{w}^{(k)}).

Two technical points are worth pulling out before the theory.

First, random initialization is essential — without it, the K networks are identical. The standard recipe in PyTorch is to draw \hat{w}^{(k)}_{0} \sim Kaiming-normal (or Xavier-normal, for older networks), with the random seed varied across k. The data and the optimizer are otherwise identical: same minibatches, same learning rate, same weight decay, same number of epochs.

Second, the ensemble is heterogeneous in function space even when each member is well-trained. Two networks trained from different initializations on the same data converge to different points in weight space (§2.4’s argument that there is no unique MLE), and although they typically agree on the training data, they extrapolate differently — exactly the §1 observation that motivated the topic.

The construction has no hyperparameters beyond KK. No Hessians, no variational parameters, no Bernoulli rates, no learning-rate schedules for sampling. It is the simplest possible BNN method and, as §8 will show, often the most accurate.

5.2 The function-space-posterior interpretation

Why does training KK models with different initializations approximate the Bayesian posterior predictive? The argument has two layers.

Theorem 5.2 (Mode-collapse limit of the deep ensemble).

Suppose the weight-space posterior p(wD)p(w \mid \mathcal{D}) is supported on KK disjoint regions R1,,RK\mathcal{R}_1, \ldots, \mathcal{R}_K of equal posterior mass 1/K1/K, and that within each region Rk\mathcal{R}_k the posterior collapses to a Dirac at a point w(k)w^{(k)}_* — i.e., p(wD)=K1k=1Kδ(ww(k))p(w \mid \mathcal{D}) = K^{-1} \sum_{k=1}^K \delta(w - w^{(k)}_*). Suppose further that each independent training run produces w^(k)=w(k)\hat{w}^{(k)} = w^{(k)}_* exactly, with each region hit by exactly one run. Then the ensemble predictive equals the posterior predictive: p^ens(yx,D)  =  p(yx,D)  :=  p(yx,w)p(wD)dw.\hat{p}_{\mathrm{ens}}(y \mid x, \mathcal{D}) \;=\; p(y \mid x, \mathcal{D}) \;:=\; \int p(y \mid x, w)\,p(w \mid \mathcal{D})\,dw.

Proof.

By assumption, p(wD)=K1kδ(ww(k))p(w \mid \mathcal{D}) = K^{-1} \sum_k \delta(w - w^{(k)}_*). So p(yx,D)  =  p(yx,w)K1k=1Kδ(ww(k))dw  =  1Kk=1Kp(yx,w(k)).p(y \mid x, \mathcal{D}) \;=\; \int p(y \mid x, w)\,K^{-1}\sum_{k=1}^K \delta(w - w^{(k)}_*)\,dw \;=\; \frac{1}{K}\sum_{k=1}^K p(y \mid x, w^{(k)}_*). By the second assumption w^(k)=w(k)\hat{w}^{(k)} = w^{(k)}_* for each kk, so the right-hand side equals p^ens(yx,D)\hat{p}_{\mathrm{ens}}(y \mid x, \mathcal{D}).

The theorem’s preconditions are gross idealizations: real posteriors are not collapsed Diracs, real ensembles do not perfectly cover all KK modes, and the assumption of equal posterior mass across modes is a strong identifiability condition. But the spirit of the result is the right intuition. A deep ensemble is the right approximation to the Bayesian predictive in a function-space regime where (i) the posterior has multiple modes, (ii) the within-mode predictive variance is small relative to the between-mode predictive variance, and (iii) the modes have comparable posterior mass. For BNNs trained with weight decay, conditions (i) and (iii) hold approximately by §2.4’s symmetry arguments — every mode has h!h! permutation copies, each with the same posterior mass — and (ii) holds whenever the network has sufficient capacity to fit the training data (the within-mode predictive variance shrinks as the network overfits).

The function-space view of Wilson & Izmailov 2020 makes this explicit. The Bayesian posterior over functions is p(fD)  =  p(fw)p(wD)dw,p(f \mid \mathcal{D}) \;=\; \int p(f \mid w)\,p(w \mid \mathcal{D})\,dw, where p(fw)=δ(ffw)p(f \mid w) = \delta(f - f_w) is the deterministic mapping from weights to functions. Distinct weight modes that produce the same function on the training data are redundant in function space — they all map to the same point ff. Distinct weight modes that produce different functions on the training data correspond to different points in function space, and a deep ensemble’s diversity in function space is what matters for predictive uncertainty.

5.3 The mixture predictive form

For regression with Gaussian observation noise, each ensemble member’s predictive is Gaussian, and the ensemble predictive is a mixture of Gaussians with closed-form moments. This is the cleanest setting in which to read the epistemic-vs-aleatoric decomposition that §8 formalizes.

Proposition 5.3 (Mixture-of-Gaussians ensemble predictive).

Suppose the observation model is yx,wN(fw(x),σnoise2)y \mid x, w \sim \mathcal{N}(f_w(x), \sigma^2_{\mathrm{noise}}) with known noise variance σnoise2\sigma^2_{\mathrm{noise}}. Let fˉ(x):=K1k=1Kfw^(k)(x)\bar{f}(x) := K^{-1}\sum_{k=1}^K f_{\hat{w}^{(k)}}(x). The deep-ensemble predictive distribution is the Gaussian mixture p^ens(yx)  =  1Kk=1KN(y;fw^(k)(x),σnoise2),\hat{p}_{\mathrm{ens}}(y \mid x) \;=\; \frac{1}{K}\sum_{k=1}^K \mathcal{N}\bigl(y;\, f_{\hat{w}^{(k)}}(x),\, \sigma^2_{\mathrm{noise}}\bigr), with predictive mean fˉ(x)\bar{f}(x) and predictive variance Varp^ens[yx]  =  σnoise2aleatoric  +  1Kk=1K(fw^(k)(x)fˉ(x))2epistemic.\mathrm{Var}_{\hat{p}_{\mathrm{ens}}}[y \mid x] \;=\; \underbrace{\sigma^2_{\mathrm{noise}}}_{\text{aleatoric}} \;+\; \underbrace{\frac{1}{K}\sum_{k=1}^K \bigl(f_{\hat{w}^{(k)}}(x) - \bar{f}(x)\bigr)^2}_{\text{epistemic}}.

Proof.

The mixture form is immediate from Definition 5.1 with p(yx,w)=N(y;fw(x),σnoise2)p(y \mid x, w) = \mathcal{N}(y; f_w(x), \sigma^2_{\mathrm{noise}}). For the variance, apply the law of total variance to the random variable yy under the mixture: Var[yx]  =  Ek[Var[yx,k]]+Vark[E[yx,k]],\mathrm{Var}[y \mid x] \;=\; \mathbb{E}_k\bigl[\mathrm{Var}[y \mid x, k]\bigr] + \mathrm{Var}_k\bigl[\mathbb{E}[y \mid x, k]\bigr], where kk is the ensemble-member index drawn uniformly from {1,,K}\{1, \ldots, K\}. The conditional variance Var[yx,k]=σnoise2\mathrm{Var}[y \mid x, k] = \sigma^2_{\mathrm{noise}} is constant in kk, so its expectation equals itself. The conditional mean is fw^(k)(x)f_{\hat{w}^{(k)}}(x), whose variance over uniform kk is the sample variance K1k(fw^(k)(x)fˉ(x))2K^{-1}\sum_k (f_{\hat{w}^{(k)}}(x) - \bar{f}(x))^2. Adding gives the claim.

The decomposition is structural: aleatoric uncertainty is the noise level the model assumes (irreducible by adding more data), and epistemic uncertainty is the variance across ensemble members (reducible by adding more data, which would shrink each mode toward zero width and pull the modes themselves toward the truth). For binary classification with Bernoulli observation likelihood the analogous identity uses Var[yx,k]=σ(fw^(k)(x))(1σ(fw^(k)(x)))\mathrm{Var}[y \mid x, k] = \sigma(f_{\hat{w}^{(k)}}(x))(1 - \sigma(f_{\hat{w}^{(k)}}(x))), and the epistemic term is the sample variance of the σ(fw^(k)(x))\sigma(f_{\hat{w}^{(k)}}(x))‘s — exactly the predictive-variance heatmaps we have been plotting.
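Proposition 5.3's moments are two lines of array arithmetic. The sketch below assumes the regression setting with a known noise scale; the names member_means and sigma_noise and the toy numbers are illustrative, not values from the running example.

```python
import numpy as np

def ensemble_moments(member_means, sigma_noise):
    """Predictive mean and variance of a deep-ensemble Gaussian mixture (Prop 5.3).

    member_means : array of shape (K, N) -- f_{w^(k)}(x) for member k at test point x
    sigma_noise  : scalar observation-noise std, assumed known as in Prop 5.3
    """
    mean = member_means.mean(axis=0)                  # \bar f(x)
    epistemic = member_means.var(axis=0)              # spread of member means
    aleatoric = np.full_like(mean, sigma_noise ** 2)  # constant noise floor
    return mean, aleatoric + epistemic, aleatoric, epistemic

# toy check: K = 3 members that disagree at the first input and agree at the second
member_means = np.array([[0.9, 2.0],
                         [1.1, 2.0],
                         [1.0, 2.0]])
mean, total, alea, epi = ensemble_moments(member_means, sigma_noise=0.1)
print(mean)  # [1.0, 2.0]
print(epi)   # epistemic variance is nonzero only where members disagree
```

For the binary-classification case, replace the member means by the sigmoid outputs and the constant noise floor by the per-member Bernoulli variance, as in the paragraph above.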

5.4 Connection to stacking

Deep ensembles use uniform weights πk=1/K\pi_k = 1/K on their members. The Stacking & Predictive Ensembles topic develops a more general framework: given KK candidate predictive distributions, learn weights πΔK1\boldsymbol{\pi} \in \Delta^{K-1} that maximize the leave-one-out posterior-predictive log-density (Yao, Vehtari, Simpson & Gelman 2018). Stacking dominates uniform weighting when the candidate predictives are heterogeneous — different model classes, different priors, different architectures — because uniform weighting can be far from the optimum on the simplex. For the homogeneous deep-ensemble case (same architecture, different random seeds), the candidates are exchangeable by construction and uniform weighting is approximately optimal. The two methods are complementary: stacking is the right tool when you have a heterogeneous catalog of models; deep ensembles are the right tool when you have one architecture and want quick, well-calibrated Bayesian uncertainty.

Remark (Stacking weights vs. uniform weights).

A reader who has shipped the Stacking & Predictive Ensembles topic should think of deep ensembles as the special case “stacking on KK same-architecture, same-prior, same-data candidates with πk\pi_k fixed at 1/K1/K.” Lifting the uniform-weight constraint and learning π\boldsymbol{\pi} from PSIS-LOO is a strict improvement when the candidates have any genuine heterogeneity. For a homogeneous deep ensemble, the PSIS-LOO-optimal weights are within Monte Carlo noise of uniform, so the stacking machinery typically returns to the uniform answer.

5.5 Algorithm and worked example

Algorithm 5.5 (Deep ensemble).

Input: training data D\mathcal{D}, network architecture fwf_w, ensemble size KK, training recipe (optimizer, learning rate, epochs, weight decay). Output: ensemble predictive p^ens(yx)\hat{p}_{\mathrm{ens}}(y \mid x) at test inputs.

  1. For k=1,,Kk = 1, \ldots, K:
    • (a) Sample initialization w^0(k)\hat{w}^{(k)}_0 \sim Kaiming-normal with seed kk.
    • (b) Run the standard training loop to convergence: w^(k)argminwLWD(w;D)\hat{w}^{(k)} \leftarrow \arg\min_w \mathcal{L}_{\mathrm{WD}}(w; \mathcal{D}) via Adam.
  2. Predict at test input xx: p^ens(yx)=K1kp(yx,w^(k))\hat{p}_{\mathrm{ens}}(y \mid x) = K^{-1}\sum_k p(y \mid x, \hat{w}^{(k)}).

For the Two Moons running example we use K=10K = 10 — large enough to make the function-space mode coverage visible, small enough to fit comfortably in the runtime budget. Each member trains in ~0.5 s on CPU, so the §5 cell wall-clock is dominated by training: ~5 s. (This is also why deep ensembles are often called “expensive” in production settings — for ImageNet-scale models, training K=10K=10 networks costs K×K \times a single training run. The Two Moons cost is negligible because each network is tiny.)
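As a concrete reference for Algorithm 5.5, here is a minimal PyTorch sketch on Two Moons. The two-hidden-layer architecture, epoch count, learning rate, and weight-decay value are illustrative assumptions rather than the figures' exact recipe, and Adam's weight_decay argument stands in for the weight-decay objective.

```python
# Minimal deep-ensemble sketch (Algorithm 5.5) on Two Moons; hyperparameters are illustrative.
import torch, torch.nn as nn
from sklearn.datasets import make_moons

X, y = make_moons(n_samples=200, noise=0.2, random_state=0)
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)

def make_mlp():
    return nn.Sequential(nn.Linear(2, 32), nn.ReLU(),
                         nn.Linear(32, 32), nn.ReLU(),
                         nn.Linear(32, 1))

K = 10
members = []
for k in range(K):
    torch.manual_seed(k)                      # step 1(a): vary only the initialization seed
    net = make_mlp()
    opt = torch.optim.Adam(net.parameters(), lr=1e-2, weight_decay=1e-4)
    for _ in range(500):                      # step 1(b): standard MAP training to convergence
        opt.zero_grad()
        loss = nn.functional.binary_cross_entropy_with_logits(net(X).squeeze(-1), y)
        loss.backward()
        opt.step()
    members.append(net)

@torch.no_grad()
def ensemble_prob(x):                         # step 2: equal-weighted mixture predictive
    probs = torch.stack([torch.sigmoid(m(x).squeeze(-1)) for m in members])
    return probs.mean(0), probs.var(0)        # predictive mean and between-member spread

p_mean, p_spread = ensemble_prob(X[:5])
```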

Three panels on Two Moons data: panel (a) deep-ensemble predictive mean RdBu_r heatmap; panel (b) deep-ensemble predictive standard deviation viridis heatmap, noticeably brighter off-distribution than single-mode methods; panel (c) all K = 10 ensemble members 0.5-probability contours overlaid in translucent orange, fanning out wider off-distribution than single-mode samples.
Figure 5. Deep ensembles. (a) Ensemble mean — robust across methods, similar to §§1/3/4. (b) Predictive standard deviation — noticeably brighter off-distribution, reflecting genuine multi-mode coverage. (c) K decision boundaries fan out wider than single-mode samples — visual confirmation that ensembles cover function-space modes the §3/§4 single-mode methods miss. Drag K to see the diminishing-returns curve; click any boundary in (c) to inspect that member’s full predictive in (a); resample seeds to confirm diversity is robust.

6. Stochastic-gradient Langevin dynamics (SGLD)

§§3–5 each give up something to gain tractability. Laplace gives up multi-modality; MC-dropout gives up the Gaussian variational family in favor of a more rigid Bernoulli; deep ensembles give up the posterior interpretation in favor of a function-space mixture. Stochastic-gradient MCMC (SG-MCMC) takes a different bargain: keep the asymptotic exactness of MCMC and pay for it in wall-clock per posterior sample, but make the per-iteration cost cheap enough that the trade is favorable in practice. The first method in the family — stochastic-gradient Langevin dynamics (SGLD; Welling & Teh 2011) — is a discretization of the Langevin SDE on the negative log-posterior, with mini-batch gradients standing in for full-data gradients and calibrated Gaussian noise injected at every step. Under a square-summable step-size schedule, SGLD samples converge in distribution to the true posterior. §7 will lift this to second-order Langevin (SGHMC) for faster mixing; §8 will compare both to the §§3–5 methods on calibration.

6.1 The Langevin SDE and its stationary distribution

The starting point is the (overdamped) Langevin SDE on the negative log-posterior U(w)=logp(wD)U(w) = -\log p(w \mid \mathcal{D}): dwt  =  U(wt)dt  +  2dBt,dw_t \;=\; -\nabla U(w_t)\,dt \;+\; \sqrt{2}\,dB_t, where BtB_t is standard Brownian motion in Rp\mathbb{R}^p. The fundamental fact about this SDE — proved via the Fokker–Planck equation — is that under mild regularity conditions on UU (smoothness, growth at infinity), its unique stationary distribution is π(w)    exp(U(w))  =  p(wD).\pi(w) \;\propto\; \exp(-U(w)) \;=\; p(w \mid \mathcal{D}). That is, simulating the SDE forward in time and sampling wtw_t at large tt produces samples from the posterior. The proof is a verification: under the Fokker–Planck equation tρ=(ρU)+2ρ\partial_t \rho = \nabla \cdot (\rho \nabla U) + \nabla^2 \rho, the candidate ρ(w)eU(w)\rho_\infty(w) \propto e^{-U(w)} satisfies ρ=ρU\nabla \rho_\infty = -\rho_\infty \nabla U, so ρU+ρ=0\rho_\infty \nabla U + \nabla \rho_\infty = 0 and tρ=0\partial_t \rho_\infty = 0 — the proposed stationary distribution is genuinely stationary. Uniqueness follows from standard ergodicity arguments on the Langevin diffusion.

So the Langevin SDE solves the posterior-sampling problem in continuous time. To use it on a computer we need to discretize, and to scale it to deep learning we need to replace the full-data gradient with a mini-batch estimate. SGLD is exactly that.

6.2 The SGLD update

Definition 6.1 (SGLD update).

Let U(w)=logp(Dw)logp(w)U(w) = -\log p(\mathcal{D} \mid w) - \log p(w) be the negative log-posterior under Definition 2.1’s Gaussian-prior BNN. Let {ηt}t0\{\eta_t\}_{t \geq 0} be a positive step-size schedule, and let bb be the mini-batch size. At each iteration tt, draw a mini-batch BtD\mathcal{B}_t \subset \mathcal{D} uniformly at random with replacement, compute the stochastic gradient g^t  :=  nbiBtwlogp(yixi,wt)  +  1τ2wt,\hat g_t \;:=\; -\frac{n}{b}\sum_{i \in \mathcal{B}_t} \nabla_w \log p(y_i \mid x_i, w_t) \;+\; \frac{1}{\tau^2}\,w_t, draw a fresh isotropic Gaussian ξtN(0,Ip)\xi_t \sim \mathcal{N}(0, I_p), and update wt+1  =  wt    ηt2g^t  +  ηtξt.w_{t+1} \;=\; w_t \;-\; \frac{\eta_t}{2}\,\hat g_t \;+\; \sqrt{\eta_t}\,\xi_t.

Three structural notes. First, \hat g_t is an unbiased estimator of \nabla U(w_t): \mathbb{E}[\hat g_t \mid w_t] = \nabla U(w_t), because the mini-batch term averaged over the uniform random index has expectation equal to the full-data gradient and the prior term is exact. Second, the noise scale \sqrt{\eta_t} matches the Euler–Maruyama discretization of \sqrt{2}\,dB_t when the convention \eta_t/2 is used on the gradient — these factors of 2 are pure choice of parametrization, but they have to be consistent. Third, the step-size schedule \{\eta_t\} is the central design choice. Welling & Teh prove that, under the schedule \eta_t = a\,(b_0 + t)^{-\gamma} with \gamma \in (1/2, 1] (here b_0 is a schedule offset, not the mini-batch size), the chain converges to the posterior; this schedule satisfies the Robbins–Monro conditions \sum_t \eta_t = \infty (the chain explores all of weight space) and \sum_t \eta_t^2 < \infty (the discretization error vanishes asymptotically).

6.3 Asymptotic exactness

Theorem 6.2 (Welling and Teh 2011).

Suppose (i) UU is twice continuously differentiable with bounded Hessian; (ii) the per-example gradient variance Var[logp(yixi,w)]\mathrm{Var}[\nabla \log p(y_i \mid x_i, w)] is bounded uniformly in ww on compacts; (iii) the step-size schedule satisfies tηt=\sum_t \eta_t = \infty and tηt2<\sum_t \eta_t^2 < \infty. Then the SGLD chain {wt}t0\{w_t\}_{t \geq 0} in Definition 6.1 has p(wD)p(w \mid \mathcal{D}) as its asymptotic distribution: for any bounded measurable test function φ\varphi, t=0Tηtφ(wt)t=0Tηt  Ta.s.  φ(w)p(wD)dw.\frac{\sum_{t=0}^{T} \eta_t\,\varphi(w_t)}{\sum_{t=0}^{T} \eta_t} \;\xrightarrow[T \to \infty]{\text{a.s.}}\; \int \varphi(w)\,p(w \mid \mathcal{D})\,dw.

Proof.

Sketch. Decompose the SGLD step into three contributions: wt+1wt  =  ηt2U(wt)  +  ηtξt  +  ηt2(g^tU(wt))=:ζt.w_{t+1} - w_t \;=\; -\frac{\eta_t}{2} \nabla U(w_t) \;+\; \sqrt{\eta_t}\,\xi_t \;+\; \underbrace{-\frac{\eta_t}{2}\bigl(\hat g_t - \nabla U(w_t)\bigr)}_{=: \zeta_t}. The first two terms are exactly an Euler–Maruyama step of the Langevin SDE with step ηt/2\eta_t / 2 (and the corresponding ηt\sqrt{\eta_t} noise). The third term ζt\zeta_t is mini-batch gradient noise: zero-mean, variance (ηt2/4)Var[g^t](\eta_t^2 / 4) \mathrm{Var}[\hat g_t]. As ηt0\eta_t \to 0 on the schedule, the variance of ζt\zeta_t scales as ηt2\eta_t^2, while the variance of the Brownian noise scales as ηt\eta_t. So the ratio Var[ζt]/Var[ηtξt]=O(ηt)0\mathrm{Var}[\zeta_t] / \mathrm{Var}[\sqrt{\eta_t}\,\xi_t] = O(\eta_t) \to 0. The mini-batch noise is asymptotically dominated by the Brownian noise, and in the limit the chain mimics the exact Langevin SDE — whose stationary distribution is the posterior. The square-summability ηt2<\sum \eta_t^2 < \infty controls the cumulative discretization error; the divergence ηt=\sum \eta_t = \infty ensures the chain has time to explore. The full proof (Welling & Teh 2011, Vollmer, Zygalakis & Teh 2016) makes these ratio-and-cumulative arguments rigorous via martingale convergence and ergodic-theorem machinery.

This is the asymptotic-exactness guarantee that motivated SG-MCMC. Note what it does not say: nothing about a constant-stepsize regime, nothing about the §2.4 BvM-failure caveats for over-parametrized neural networks, nothing about practical mixing time. The theorem says SGLD’s invariant measure is the posterior; it does not say SGLD mixes to that measure quickly.

6.4 Mini-batch noise budget

For practical step sizes the mini-batch noise is not negligible relative to the Brownian noise, and understanding how the two interact is the difference between a working SGLD chain and a chain that biases the wrong way.

Proposition 6.3 (Mini-batch noise as a fraction of the noise budget).

The conditional variance of the SGLD step at iteration t decomposes as \mathrm{Var}[w_{t+1} - w_t \mid w_t] \;=\; \underbrace{\eta_t\,I_p}_{\text{Brownian noise}} \;+\; \underbrace{\frac{\eta_t^2}{4}\,\Sigma_{\hat g}(w_t)}_{\text{minibatch noise}}, where \Sigma_{\hat g}(w_t) = \mathrm{Cov}[\hat g_t \mid w_t]. The two contributions are independent (the mini-batch index is drawn independently of the noise injection). The mini-batch noise contribution scales as \eta_t^2 and the Brownian as \eta_t, so their per-coordinate ratio is of order \eta_t\,[\Sigma_{\hat g}]_{jj}/4 = O(\eta_t).

Proof.

By definition wt+1wt=ηt2g^t+ηtξtw_{t+1} - w_t = -\tfrac{\eta_t}{2}\,\hat g_t + \sqrt{\eta_t}\,\xi_t. The two random variables are independent (mini-batch and noise injection draws are independent), so the variance is the sum of variances: Var[wt+1wtwt]  =  Var[ηt2g^twt]  +  Var[ηtξt]  =  ηt24Σg^(wt)  +  ηtIp,\mathrm{Var}[w_{t+1} - w_t \mid w_t] \;=\; \mathrm{Var}\Bigl[-\tfrac{\eta_t}{2}\,\hat g_t \,\Big|\, w_t\Bigr] \;+\; \mathrm{Var}\bigl[\sqrt{\eta_t}\,\xi_t\bigr] \;=\; \frac{\eta_t^2}{4}\,\Sigma_{\hat g}(w_t) \;+\; \eta_t\,I_p, which is the claim.

The practical takeaway: when ηt\eta_t is small, the mini-batch noise is negligible relative to the Brownian noise, and the chain is well-approximated by the exact Langevin SDE. When ηt\eta_t is large (constant-stepsize or aggressive schedule), the mini-batch noise becomes a significant — and anisotropic — perturbation, and the chain’s stationary distribution is no longer exactly the posterior.

6.5 The constant-stepsize bias-variance tradeoff

In production we often run SGLD with a constant step-size η\eta for a fixed wall-clock budget rather than the asymptotically-correct decaying schedule. This trades asymptotic exactness for faster mixing and a fixed iteration cost, and the bias structure has been quantified.

Proposition 6.4 (Bias-variance tradeoff in constant-stepsize SGLD).

With constant step-size η>0\eta > 0, the SGLD chain {wt}\{w_t\} has stationary distribution πη(w)p(wD)\pi_\eta(w) \neq p(w \mid \mathcal{D}) in general. Under regularity conditions (Vollmer, Zygalakis & Teh 2016), for any test function φ\varphi in a suitable function space, the asymptotic bias of the time-averaged Monte Carlo estimator decomposes as Eπη[φ]Ep(D)[φ]  =  O(η)  +  O ⁣(ηtrΣg^n),\Bigl|\,\mathbb{E}_{\pi_\eta}[\varphi] - \mathbb{E}_{p(\cdot \mid \mathcal{D})}[\varphi]\,\Bigr| \;=\; O(\eta) \;+\; O\!\left(\frac{\eta\,\mathrm{tr}\,\Sigma_{\hat g}}{n}\right), the first term from Euler–Maruyama discretization error and the second from minibatch gradient noise. The Monte Carlo variance of the time-averaged estimator over TT post-burn-in iterations is Var[φ^]=O(τauto/T)\mathrm{Var}[\hat\varphi] = O(\tau_{\mathrm{auto}} / T), where τauto\tau_{\mathrm{auto}} is the chain’s autocorrelation time.

Proof.

Sketch. Both bias terms come from the stationarity equation of the SGLD chain. Setting tπη=0\partial_t \pi_\eta = 0 in the modified Fokker–Planck equation that includes mini-batch noise, expanding πη=eU(1+O(η)+O(ηtrΣg^/n))\pi_\eta = e^{-U} (1 + O(\eta) + O(\eta\,\mathrm{tr}\,\Sigma_{\hat g}/n)), and integrating against φ\varphi produces the stated rates. The variance term is the standard MCMC variance scaling.

The reader should leave §6.5 with a concrete heuristic. Larger η\eta mixes faster but biases more; more samples reduces variance but not bias. The bias-variance tradeoff in constant-stepsize SGLD is qualitatively different from the bias-variance tradeoff in standard MCMC (which is bias-free at any stepsize, so only variance trades against burn-in). For practical BNN inference, the tradeoff is usually navigated by picking η\eta small enough that the bias is below the noise floor of the downstream calibration metrics — a few times 10410^{-4} for typical small-to-medium networks, smaller for ImageNet-scale.

6.6 Algorithm and worked example

Algorithm 6.5 (SGLD posterior sampling).

Input: training data D\mathcal{D}, network fwf_w, prior scale τ2\tau^2, step-size schedule {ηt}\{\eta_t\} or constant η\eta, mini-batch size bb, burn-in TburnT_{\mathrm{burn}}, sample count TT, thinning interval Δ\Delta. Output: posterior samples w(1),,w(T)w^{(1)}, \ldots, w^{(T)}.

  1. Initialize w0w_0 — random init or warm-start from a MAP w^\hat{w}.
  2. For t=0,1,,Tburn+TΔt = 0, 1, \ldots, T_{\mathrm{burn}} + T \cdot \Delta:
    • (a) Sample mini-batch BtD\mathcal{B}_t \subset \mathcal{D}.
    • (b) Compute stochastic gradient g^t\hat g_t per Definition 6.1.
    • (c) Draw ξtN(0,Ip)\xi_t \sim \mathcal{N}(0, I_p).
    • (d) Update wt+1wtηt2g^t+ηtξtw_{t+1} \leftarrow w_t - \tfrac{\eta_t}{2}\,\hat g_t + \sqrt{\eta_t}\,\xi_t.
  3. Discard the first TburnT_{\mathrm{burn}} iterations as burn-in.
  4. Return w(s):=wTburn+sΔw^{(s)} := w_{T_{\mathrm{burn}} + s\Delta} for s=1,,Ts = 1, \ldots, T (thinning to reduce autocorrelation).

For the §6.7 Two Moons example we use constant η=103\eta = 10^{-3}, mini-batch size b=32b = 32, burn-in Tburn=200T_{\mathrm{burn}} = 200, sample count T=100T = 100, thinning Δ=10\Delta = 10 — total 200+10010=1200200 + 100 \cdot 10 = 1200 iterations, each iteration a single forward+backward pass on a b=32b = 32 mini-batch. Total cell runtime ~7 s.
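A minimal PyTorch sketch of Algorithm 6.5 with the constant-step settings quoted above. The MLP architecture and the prior scale \tau^2 = 1 are illustrative assumptions; the rest follows Definition 6.1 directly (mini-batch NLL rescaled by n/b, exact Gaussian-prior gradient, \sqrt{\eta}\,\xi noise injection).

```python
# Minimal SGLD sketch (Algorithm 6.5) on Two Moons; tau2 and the MLP are illustrative assumptions.
import torch, torch.nn as nn
from sklearn.datasets import make_moons

X, y = make_moons(n_samples=200, noise=0.2, random_state=0)
X = torch.tensor(X, dtype=torch.float32); y = torch.tensor(y, dtype=torch.float32)
n, tau2 = len(X), 1.0

net = nn.Sequential(nn.Linear(2, 32), nn.ReLU(),
                    nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 1))
eta, b, T_burn, T, Delta = 1e-3, 32, 200, 100, 10
samples = []

for t in range(T_burn + T * Delta + 1):
    idx = torch.randint(0, n, (b,))                        # (a) mini-batch, with replacement
    logits = net(X[idx]).squeeze(-1)
    nll = nn.functional.binary_cross_entropy_with_logits(logits, y[idx], reduction="sum")
    loss = (n / b) * nll                                   # (b) rescaled mini-batch NLL
    net.zero_grad()
    loss.backward()
    with torch.no_grad():
        for w in net.parameters():
            g_hat = w.grad + w / tau2                      # add the exact Gaussian-prior gradient
            xi = torch.randn_like(w)                       # (c) fresh isotropic noise
            w.add_(-0.5 * eta * g_hat + eta ** 0.5 * xi)   # (d) SGLD update
    if t > T_burn and (t - T_burn) % Delta == 0:           # steps 3-4: burn-in and thinning
        samples.append([w.detach().clone() for w in net.parameters()])

print(len(samples))   # T posterior samples (here 100)
```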

Three spatial panels and one diagnostic panel: (a) SGLD predictive mean on Two Moons; (b) SGLD predictive standard deviation, comparable to the deep ensemble; (c) sampled 0.5-probability decision boundaries from SGLD chain in purple; (d) single-component weight trace across the chain plus the autocorrelation function.
Figure 6. SGLD samples the posterior. (a, b) Mean and standard deviation across post-burn-in samples — comparable to the deep-ensemble §5 with a single optimizer instead of K independent retrains. (c) Sampled decision boundaries. (d) The single-component trace and ACF are the practical mixing diagnostic — large noise early signals exploration, smaller post-burn-in oscillations signal the chain has settled into the typical set. Slide η to trade noise for mixing speed; reseed for a new chain.

7. Stochastic-gradient HMC (SGHMC)

SGLD’s mixing is rate-limited by the first-order Langevin diffusion: each step is a small drift along the gradient plus an isotropic Brownian kick, so traversing a long, low-curvature ridge in the loss landscape requires many small steps. Stochastic-gradient Hamiltonian Monte Carlo (SGHMC; Chen, Fox & Guestrin 2014) lifts the dynamics from first-order to second-order Langevin: introduce a momentum variable, let the chain accumulate velocity along the gradient, and damp the velocity with a friction term that simultaneously enforces stationarity and absorbs the variance contributed by stochastic gradients. Empirically the result mixes considerably faster than SGLD at the same wall-clock cost.

7.1 The second-order Langevin SDE

Augment the state with a momentum variable vRpv \in \mathbb{R}^p and let MRp×pM \in \mathbb{R}^{p \times p} be a positive-definite mass matrix (typically M=IpM = I_p). The underdamped Langevin SDE is dwt=M1vtdt,dvt=U(wt)dtCM1vtdt+2CdBt,dw_t = M^{-1} v_t\,dt, \qquad dv_t = -\nabla U(w_t)\,dt - C\, M^{-1} v_t\,dt + \sqrt{2 C}\,dB_t, where CRp×pC \in \mathbb{R}^{p \times p} is a positive-semidefinite friction matrix. The Fokker–Planck calculation (analogous to §6.1) gives the stationary distribution π(w,v)    exp ⁣(U(w)12vM1v),\pi(w, v) \;\propto\; \exp\!\Bigl(-U(w) - \tfrac{1}{2}\, v^\top M^{-1} v\Bigr), the joint posterior over (w,v)(w, v) in which the marginal π(w)eU(w)=p(wD)\pi(w) \propto e^{-U(w)} = p(w \mid \mathcal{D}) is exactly the BNN posterior we want and the velocity is independent Gaussian noise we can discard. So the second-order Langevin SDE samples the posterior at higher mixing rate than SGLD because momentum carries the chain through low-curvature regions in many fewer steps.

7.2 The complication: stochastic gradients add variance to velocity

If we discretize the SDE via Euler–Maruyama with full-data gradients, we get a ww-and-vv analogue of SGLD that converges to π\pi as the step size decays. But replacing U\nabla U with a stochastic mini-batch estimate g^\hat g introduces an extra noise term in the velocity dynamics: dvstoch=g^dtCM1vdt+2CdB  =  UdtζdtCM1vdt+2CdB,dv_{\mathrm{stoch}} = -\hat g\,dt - C M^{-1} v\,dt + \sqrt{2 C}\,dB \;=\; -\nabla U\,dt - \zeta\,dt - C M^{-1} v\,dt + \sqrt{2 C}\,dB, where ζ:=g^U\zeta := \hat g - \nabla U is zero-mean stochastic-gradient noise with covariance B(w):=Cov[g^w]B(w) := \mathrm{Cov}[\hat g \mid w]. The extra ζ\zeta term changes the effective noise covariance from 2C2C to 2C+B2C + B, and the stationary distribution shifts away from π\pi. Without correction, stochastic-gradient HMC samples a perturbed posterior. The Chen–Fox–Guestrin fix is the friction-compensation: choose the friction matrix CC large enough that the injected Brownian noise can absorb the gradient noise.

7.3 The SGHMC update

Definition 7.1 (SGHMC update).

Let η>0\eta > 0 be the step size, α:=ηC\alpha := \eta C the friction parameter (with CC a positive-semidefinite matrix; typically C=cIpC = c I_p for scalar c>0c > 0), MM the mass matrix (typically IpI_p), and B^(w)\hat B(w) a non-negative-definite estimate of the stochastic-gradient noise covariance B(w)B(w) satisfying CB^C \succeq \hat B. At each iteration, draw mini-batch Bt\mathcal{B}_t, compute g^t\hat g_t as in Definition 6.1, draw ξtN(0,Ip)\xi_t \sim \mathcal{N}(0, I_p), and update: vt+1  =  (1α)vt    ηg^t  +  2η(CB^(wt))ξt,v_{t+1} \;=\; (1 - \alpha) v_t \;-\; \eta\, \hat g_t \;+\; \sqrt{2\eta\,(C - \hat B(w_t))}\,\xi_t, wt+1  =  wt  +  ηM1vt+1.w_{t+1} \;=\; w_t \;+\; \eta\, M^{-1} v_{t+1}.

The structural reading: (1 - \alpha) scales the previous velocity (friction damping), -\eta\,\hat g_t accelerates along the negative gradient (Hamiltonian drift), and the Brownian-noise injection has scale \sqrt{2\eta(C - \hat B)}, which is smaller than the \sqrt{2\eta C} the un-compensated SDE would inject by exactly the amount needed to absorb the additional variance the stochastic gradient contributes. When \hat B = 0 (the simplest practical choice — no covariance estimation), we use \sqrt{2\eta C} noise injection and accept an O(\eta) bias of the same order as constant-stepsize SGLD; when \hat B is exact and C \succeq B(w), the chain has the correct stationary distribution at any step size.

7.4 Stationary distribution

Theorem 7.2 (Chen, Fox and Guestrin 2014).

Suppose the friction CB(w)C \succeq B(w) uniformly in ww and B^(w)=B(w)\hat B(w) = B(w) exactly. Then the continuous-time analogue of the SGHMC update has stationary distribution π(w,v)    exp ⁣(U(w)12vM1v),\pi(w, v) \;\propto\; \exp\!\Bigl(-U(w) - \tfrac{1}{2}\, v^\top M^{-1} v\Bigr), whose ww-marginal is the BNN posterior. Under appropriate step-size schedules, the discretized SGHMC chain converges in distribution to π\pi.

Proof.

Sketch. The Fokker–Planck equation for the joint (w,v)(w, v) density under the continuous-time SGHMC dynamics — including the stochastic-gradient noise term ζ\zeta — has the form tρ  =  w(ρM1v)  +  v(ρ[U+CM1v])  +  v2[(CB^+B(w))ρ].\partial_t \rho \;=\; -\nabla_w \cdot (\rho\, M^{-1} v) \;+\; \nabla_v \cdot \bigl(\rho\, [\nabla U + C M^{-1} v]\bigr) \;+\; \nabla_v^2 \cdot \bigl[(C - \hat B + B(w))\,\rho\bigr]. Setting tρ=0\partial_t \rho = 0 and inserting ρ=π(w,v)\rho_\infty = \pi(w, v) above: the gradient terms vanish (as for the noiseless second-order Langevin), and the diffusion-Laplacian terms vanish if and only if the effective noise covariance equals CC — i.e., CB^+B=CC - \hat B + B = C, which holds when B^=B\hat B = B. So with exact noise compensation, π\pi is stationary. The discrete-time convergence argument follows the standard SG-MCMC machinery (Vollmer, Zygalakis & Teh 2016, applied to the joint (w,v)(w, v) Markov chain): the per-iteration discretization error vanishes under square-summable step sizes, and the time-averaged Monte Carlo estimator converges almost surely to the posterior expectation.

In practice B^(w)=0\hat B(w) = 0 is the standard simplifying choice — the stochastic-gradient covariance is hard to estimate cheaply and depends on ww. The resulting chain has O(η)O(\eta) bias from the un-compensated gradient noise (Prop 7.3 below), of the same order as constant-stepsize SGLD’s bias, and the friction CC is tuned to be large enough that the bias is below the noise floor of downstream metrics.

7.5 Friction-vs-noise tradeoff

Proposition 7.3 (Friction compensates stochastic-gradient variance).

With \hat B = 0, the SGHMC chain's stationary distribution \pi_{\eta, C}(w, v) satisfies, for any test function \varphi in a suitable class, \bigl|\,\mathbb{E}_{\pi_{\eta, C}}[\varphi(w)] - \mathbb{E}_{p(\cdot \mid \mathcal{D})}[\varphi]\,\bigr| \;=\; O\!\left(\frac{\eta\, \|B(w)\|}{C}\right) \;+\; O(\eta), where the first term is the noise-compensation bias and the second is the discretization bias. Increasing C shrinks the first term but slows mixing (the friction damps velocity, eliminating the momentum advantage); decreasing \eta shrinks both terms but slows wall-clock progress. Practical tuning: pick C large enough that the un-compensated gradient-noise variance is a small fraction of the injected noise 2\eta C, then pick \eta small enough that the remaining bias is below the calibration noise floor.

Proof.

Sketch. The first bias term comes from solving the Fokker–Planck equation for πη,C\pi_{\eta, C} with the un-compensated stochastic-gradient covariance: setting tρ=0\partial_t \rho = 0 in the equation from §7.4 with B^=0\hat B = 0 gives ρexp(U)exp(12vM1v)\rho_\infty \propto \exp(-U^*) \exp(-\tfrac{1}{2} v^\top M^{*-1} v) for shifted potential U(w)=U(w)+O(B(w)/C)U^*(w) = U(w) + O(B(w)/C) and shifted mass matrix; the integrated bias on ww-marginal expectations is the integrated shift in UU, which scales as B(w)/CB(w)/C. The second bias term is standard Euler-Maruyama-on-second-order-Langevin discretization error, O(η)O(\eta).

The Chen–Fox–Guestrin contribution, in summary: SGHMC at fixed η\eta has the same O(η)O(\eta) bias as SGLD, but with the momentum-variable advantage that mixing is much faster — the chain explores the posterior in fewer iterations. For the same wall-clock budget, SGHMC typically produces a more accurate estimate of posterior expectations than SGLD, even at the cost of carrying the velocity variable through every iteration.

7.6 Practical preconditioning

Remark (Adaptive preconditioning).

Vanilla SGHMC uses a constant friction C = c I_p across all dimensions, but neural-network loss landscapes have wildly varying curvature across weights — sharp directions in the early layers, flat directions in the later layers, almost-degenerate directions along the §2.4 ReLU rescaling submanifold. Preconditioned SGHMC (Li, Chen, Carlson & Carin 2016) makes C a per-parameter or per-layer quantity, often using running estimates of squared gradients in the spirit of RMSProp or Adam. The result is a faster-mixing chain on ill-conditioned posteriors at the cost of some additional bookkeeping per iteration. Off-the-shelf support varies across libraries, so check your probabilistic-programming library's documentation before relying on a built-in preconditioned variant. For the §7.7 Two Moons example we use vanilla SGHMC with a single scalar friction c = 0.1 — the network is small enough that preconditioning gains are small.

7.7 Algorithm and worked example

Algorithm 7.5 (SGHMC posterior sampling).

Input: training data D\mathcal{D}, network fwf_w, prior scale τ2\tau^2, step size η\eta, friction C=cIpC = c I_p, mass M=IpM = I_p, mini-batch size bb, burn-in TburnT_{\mathrm{burn}}, sample count TT, thinning interval Δ\Delta. Output: posterior samples w(1),,w(T)w^{(1)}, \ldots, w^{(T)}.

  1. Initialize w0w_0 (random init or MAP warm-start), v0=0v_0 = 0.
  2. For t=0,1,,Tburn+TΔt = 0, 1, \ldots, T_{\mathrm{burn}} + T \cdot \Delta:
    • (a) Sample mini-batch BtD\mathcal{B}_t \subset \mathcal{D}.
    • (b) Compute stochastic gradient g^t\hat g_t per Definition 6.1.
    • (c) Draw ξtN(0,Ip)\xi_t \sim \mathcal{N}(0, I_p).
    • (d) Update: vt+1(1ηc)vtηg^t+2ηcξtv_{t+1} \leftarrow (1 - \eta c) v_t - \eta\, \hat g_t + \sqrt{2 \eta c}\, \xi_t;   wt+1wt+ηvt+1\;w_{t+1} \leftarrow w_t + \eta\, v_{t+1}.
  3. Discard the first TburnT_{\mathrm{burn}} iterations as burn-in.
  4. Return w(s):=wTburn+sΔw^{(s)} := w_{T_{\mathrm{burn}} + s\Delta} for s=1,,Ts = 1, \ldots, T.

For the Two Moons example we use η=103\eta = 10^{-3}, c=0.1c = 0.1, b=32b = 32, Tburn=200T_{\mathrm{burn}} = 200, T=100T = 100, Δ=10\Delta = 10 — same total iteration count as §6, exposing the head-to-head mixing comparison. Total cell runtime ~7 s.
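The SGHMC update itself is three lines. The sketch below implements Definition 7.1 with \hat B = 0 and checks it on a toy 2-D Gaussian target whose stationary w-marginal is known to be \mathcal{N}(0, I); the step size, friction, and iteration counts are illustrative. The same sghmc_step drops into step 2(d) of Algorithm 7.5 once the exact gradient is replaced by the Definition 6.1 stochastic gradient.

```python
# SGHMC update (Definition 7.1 with B_hat = 0), checked on a toy 2-D Gaussian
# target U(w) = ||w||^2 / 2, whose stationary w-marginal is N(0, I).
import torch

def sghmc_step(w, v, grad, eta, c):
    """One SGHMC iteration: friction-damped momentum plus sqrt(2*eta*c) noise."""
    v = (1 - eta * c) * v - eta * grad + (2 * eta * c) ** 0.5 * torch.randn_like(w)
    w = w + eta * v
    return w, v

torch.manual_seed(0)
w, v = torch.zeros(2), torch.zeros(2)
eta, c = 1e-2, 0.5
chain = []
for t in range(60_000):
    grad = w                  # exact gradient of U; a mini-batch estimate in the BNN case
    w, v = sghmc_step(w, v, grad, eta, c)
    if t >= 10_000:           # discard burn-in
        chain.append(w.clone())
chain = torch.stack(chain)
print(chain.var(0))           # close to [1, 1], up to O(eta) bias and Monte Carlo error
```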

Three spatial panels and one diagnostic panel: (a) SGHMC predictive mean on Two Moons; (b) SGHMC predictive standard deviation; (c) sampled 0.5-probability decision boundaries from SGHMC chain in red; (d) single-component weight trace plus autocorrelation function with optional SGLD overlay for mixing comparison.
Figure 7. SGHMC and the momentum-induced mixing speedup. (a, b) The predictive distribution is similar to SGLD’s at the same iteration budget. (c) Sampled decision boundaries. (d) The autocorrelation function decays faster for SGHMC than SGLD — the visual signature of why momentum helps: each effective sample takes fewer iterations of wall-clock to produce. Toggle the SGLD overlay on (d) to see the side-by-side mixing comparison.

The §§6–7 development of SGLD and SGHMC stops at the recipes that work in practice. The deeper theory — the Itô-SDE view, the Fokker–Planck stationary-distribution proofs, the Vollmer–Zygalakis–Teh O(η+1/B)O(\eta + 1/B) bias bound, Riemann-manifold preconditioning for hierarchical models, and a head-to-head with NUTS that pins down where SG-MCMC genuinely wins — is the subject of stochastic-gradient-mcmc.

8. Calibration and uncertainty quantification

§§3–7 each produce a predictive distribution. §1's panel (c) and the per-section heatmaps make the qualitative claim that BNN methods produce sharper uncertainty in the right places, but qualitative is not enough — reading a predictive distribution responsibly requires knowing whether the reported probabilities match empirical frequencies. A model that confidently predicts "90% chance class 1" should be right on roughly 90% of inputs that get that prediction; if it is right 70% of the time, it is over-confident and downstream decisions made under those probabilities will be miscalibrated. This section develops three calibration metrics — expected calibration error, the Brier score, and negative log-likelihood — that quantify how well predictive probabilities match empirical reality, decomposes BNN predictive variance into epistemic and aleatoric components, and runs the five §§3–7 methods plus the §1 point estimate head-to-head on a held-out Two Moons test set. The cold-posterior effect (Wenzel et al. 2020) and post-hoc temperature scaling (Guo et al. 2017) get their own remarks at the end.

8.1 Expected calibration error

The simplest calibration metric. Bin the test predictions by predicted probability into BB bins of equal width on [0,1][0, 1], and compare in-bin accuracy to in-bin average confidence.

Definition 8.1 (Expected calibration error (ECE)).

Let {(xi,yi)}i=1N\{(x_i, y_i)\}_{i=1}^N be a held-out test set, p^i[0,1]\hat{p}_i \in [0, 1] the model’s predicted probability of the predicted class at xix_i, and y^i{0,1,,K1}\hat{y}_i \in \{0, 1, \ldots, K-1\} the predicted class. Partition [0,1][0, 1] into BB equal-width bins B1,,BB\mathcal{B}_1, \ldots, \mathcal{B}_B and let Ib:={i:p^iBb}I_b := \{i : \hat{p}_i \in \mathcal{B}_b\}. Define the bin accuracy and bin confidence acc(b):=1IbiIb1[y^i=yi],conf(b):=1IbiIbp^i.\mathrm{acc}(b) := \frac{1}{|I_b|}\sum_{i \in I_b} \mathbb{1}[\hat{y}_i = y_i], \qquad \mathrm{conf}(b) := \frac{1}{|I_b|}\sum_{i \in I_b} \hat{p}_i. The expected calibration error is ECE  :=  b=1BIbNacc(b)conf(b).\mathrm{ECE} \;:=\; \sum_{b=1}^B \frac{|I_b|}{N}\,\bigl|\mathrm{acc}(b) - \mathrm{conf}(b)\bigr|.

A perfectly-calibrated model has \mathrm{ECE} = 0: in every bin, the empirical accuracy matches the average confidence. ECE is always non-negative. Standard practice uses B = 10 or B = 15 bins; results are not very sensitive to B for N \geq 10^3. ECE has a known weakness — it does not reward sharpness, so a model that always predicts something close to the marginal class rate can score near-zero ECE without discriminating at all — but its interpretability (it has units of "probability points off") makes it the most-reported calibration metric in the BNN literature.
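Definition 8.1 translates directly into a short histogram-style computation. A minimal numpy sketch with toy inputs; the helper name ece and the bin-edge convention (left-open, right-closed) are illustrative choices consistent with the definition above.

```python
import numpy as np

def ece(confidences, predictions, labels, n_bins=10):
    """Expected calibration error (Definition 8.1) with equal-width bins on [0, 1]."""
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    total, n = 0.0, len(labels)
    for lo, hi in zip(bins[:-1], bins[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.sum() == 0:
            continue
        acc = (predictions[in_bin] == labels[in_bin]).mean()   # acc(b)
        conf = confidences[in_bin].mean()                      # conf(b)
        total += (in_bin.sum() / n) * abs(acc - conf)
    return total

# toy usage: p_hat is the predicted class-1 probability from any of the §§3-7 methods
p_hat = np.array([0.9, 0.8, 0.7, 0.55, 0.95, 0.3])
labels = np.array([1, 1, 0, 1, 1, 0])
conf = np.maximum(p_hat, 1 - p_hat)            # confidence of the predicted class
pred = (p_hat > 0.5).astype(int)
print(ece(conf, pred, labels))
```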

8.2 The Brier score

A strictly proper scoring rule for binary classification. Where ECE is bin-based, Brier is point-wise: it averages the squared error between predicted probability and binary outcome.

Definition 8.2 (Brier score and its decomposition).

Let {(xi,yi)}i=1N\{(x_i, y_i)\}_{i=1}^N be a held-out test set with yi{0,1}y_i \in \{0, 1\} and p^i[0,1]\hat{p}_i \in [0, 1] the model’s predicted probability of class 1. The Brier score is BS  :=  1Ni=1N(p^iyi)2.\mathrm{BS} \;:=\; \frac{1}{N}\sum_{i=1}^N (\hat{p}_i - y_i)^2. Murphy 1973 decomposes this score as BS  =  bIbN(conf(b)acc(b))2Reliability    bIbN(acc(b)yˉ)2Resolution  +  yˉ(1yˉ)Uncertainty,\mathrm{BS} \;=\; \underbrace{\sum_b \frac{|I_b|}{N}\,(\mathrm{conf}(b) - \mathrm{acc}(b))^2}_{\text{Reliability}} \;-\; \underbrace{\sum_b \frac{|I_b|}{N}\,(\mathrm{acc}(b) - \bar{y})^2}_{\text{Resolution}} \;+\; \underbrace{\bar{y}(1 - \bar{y})}_{\text{Uncertainty}}, where yˉ=N1iyi\bar{y} = N^{-1}\sum_i y_i is the marginal class rate and the bins Bb\mathcal{B}_b are as in Def 8.1.

Proof.

Direct expansion, with the binning convention made explicit: here \mathrm{conf}(b) is the within-bin mean of the class-1 probability \hat p_i, \mathrm{acc}(b) is the within-bin frequency of y_i = 1, and each \hat p_i is replaced by its bin mean \mathrm{conf}(b(i)) — the standard binned-forecast convention, under which the identity is exact. Writing y_i as \mathrm{acc}(b(i)) + r_i for the within-bin residual (so \sum_{i \in I_b} r_i = 0), the term (\hat p_i - y_i)^2 becomes (\mathrm{conf}(b) - \mathrm{acc}(b))^2 - 2(\mathrm{conf}(b) - \mathrm{acc}(b))\, r_i + r_i^2. Sum within each bin: the cross-term vanishes because \sum r_i = 0, the first squared term gives the Reliability piece, and the residual variance term equals \mathrm{acc}(b)(1 - \mathrm{acc}(b)). Sum across bins and rearrange using the identity \sum_b \frac{|I_b|}{N}\,\mathrm{acc}(b)(1 - \mathrm{acc}(b)) = \bar y(1 - \bar y) - \sum_b \frac{|I_b|}{N}(\mathrm{acc}(b) - \bar y)^2, which gives the claimed decomposition.

The decomposition reads: BS = (how miscalibrated within bins) − (how varied accuracy is across bins) + (irreducible class-base-rate variance). Uncertainty is the “no-skill” baseline; Resolution rewards a model whose accuracy varies meaningfully across bins; Reliability is the calibration-error analog of ECE. A perfectly calibrated model has Reliability = 0; a fully-discriminating model has Resolution = Uncertainty.
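The decomposition is easy to check numerically. The sketch below computes the Brier score and its Murphy components, then verifies BS = Reliability − Resolution + Uncertainty for the binned forecasts (the convention used in the proof); the function and variable names are illustrative.

```python
import numpy as np

def brier_decomposition(p_hat, y, n_bins=10):
    """Brier score (Definition 8.2) and its Murphy (1973) decomposition.

    The identity BS = REL - RES + UNC is exact for the binned forecasts, i.e.
    every p_hat in a bin replaced by the bin's mean, as in the Def 8.2 proof.
    """
    y_bar = y.mean()
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    bin_idx = np.clip(np.digitize(p_hat, edges[1:-1]), 0, n_bins - 1)
    rel = res = 0.0
    p_binned = np.empty_like(p_hat)
    for b in range(n_bins):
        mask = bin_idx == b
        if not mask.any():
            continue
        conf, acc, w = p_hat[mask].mean(), y[mask].mean(), mask.mean()
        rel += w * (conf - acc) ** 2          # Reliability
        res += w * (acc - y_bar) ** 2         # Resolution
        p_binned[mask] = conf
    unc = y_bar * (1 - y_bar)                 # Uncertainty
    bs_raw = np.mean((p_hat - y) ** 2)
    bs_binned = np.mean((p_binned - y) ** 2)
    return bs_raw, bs_binned, rel, res, unc

rng = np.random.default_rng(0)
p = rng.uniform(size=1000)
y = rng.binomial(1, p)                        # labels drawn from the stated probabilities
bs_raw, bs_binned, rel, res, unc = brier_decomposition(p, y)
print(np.isclose(bs_binned, rel - res + unc))   # True: the decomposition identity holds
```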

8.3 Negative log-likelihood

Definition 8.3 (Negative log-likelihood as a proper scoring rule).

The negative log-likelihood of the model on the test set is NLL  :=  1Ni=1Nlogp(yixi,D),\mathrm{NLL} \;:=\; -\frac{1}{N}\sum_{i=1}^N \log p(y_i \mid x_i, \mathcal{D}), where p(yixi,D)p(y_i \mid x_i, \mathcal{D}) is the model’s predictive probability of the true label. NLL is a strictly proper scoring rule (Gneiting & Raftery 2007): it is uniquely minimized in expectation by the true conditional distribution.

NLL has two practical properties that ECE and Brier do not. It penalizes over-confident wrong predictions most aggressively: logp-\log p blows up as p0p \to 0 for the true class. And it directly compares to held-out log-likelihoods used elsewhere in Bayesian model selection — the BIC, the marginal likelihood, the WAIC, and the LOO predictive log-density all share NLL’s units (nats per observation). For BNN inference, NLL is usually the headline metric, with ECE and Brier as complementary diagnostics.
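A minimal NLL helper, with a toy pair of inputs showing how a single over-confident miss dominates the average; the function name and clipping floor are illustrative.

```python
import numpy as np

def nll(p_true, eps=1e-12):
    """Mean negative log-likelihood (Definition 8.3), in nats per observation.
    p_true[i] is the predictive probability assigned to the *true* label of point i."""
    return -np.mean(np.log(np.clip(p_true, eps, None)))

print(nll(np.array([0.9, 0.9, 0.9, 0.9])))    # ~0.11 nats: uniformly confident and right
print(nll(np.array([0.9, 0.9, 0.9, 0.001])))  # ~1.81 nats: one over-confident miss dominates
```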

8.4 The epistemic-aleatoric decomposition

A BNN’s predictive variance decomposes into two pieces — what the model would learn from more data versus what no model can ever learn — and the decomposition is exactly the law of total variance applied to the BNN predictive.

Proposition 8.4 (Epistemic-aleatoric decomposition).

For a BNN with weight posterior p(wD)p(w \mid \mathcal{D}) and observation likelihood p(yx,w)p(y \mid x, w), the predictive variance at test point xx^* decomposes as Var[yx,D]  =  EwD[Var[yx,w]]aleatoric  +  VarwD[E[yx,w]]epistemic.\mathrm{Var}\bigl[y \mid x^*, \mathcal{D}\bigr] \;=\; \underbrace{\mathbb{E}_{w \mid \mathcal{D}}\bigl[\mathrm{Var}[y \mid x^*, w]\bigr]}_{\mathrm{aleatoric}} \;+\; \underbrace{\mathrm{Var}_{w \mid \mathcal{D}}\bigl[\mathbb{E}[y \mid x^*, w]\bigr]}_{\mathrm{epistemic}}. The aleatoric term is the average of within-model conditional variance (irreducible label noise); the epistemic term is the variance of the conditional mean across weight samples (model uncertainty, which would shrink to zero with infinite data and a correctly specified family).

Proof.

Direct application of the law of total variance to the random variable yy given xx^*, treating ww as an auxiliary random variable conditional on D\mathcal{D}: Var[yx,D]  =  Ew[Var[yx,w]]  +  Varw[E[yx,w]],\mathrm{Var}[y \mid x^*, \mathcal{D}] \;=\; \mathbb{E}_{w}\bigl[\mathrm{Var}[y \mid x^*, w]\bigr] \;+\; \mathrm{Var}_{w}\bigl[\mathbb{E}[y \mid x^*, w]\bigr], where ww is averaged over p(wD)p(w \mid \mathcal{D}). The first term is the average within-weight conditional variance — for a Bernoulli observation, Var[yx,w]=σ(fw(x))(1σ(fw(x)))\mathrm{Var}[y \mid x^*, w] = \sigma(f_w(x^*))(1 - \sigma(f_w(x^*))), which is large in the aleatoric regions where the conditional probability is near 0.50.5. The second term is the variance across weight samples of the predictive mean — large where weight uncertainty produces different mean predictions, exactly the §1 desideratum.

Practically, the epistemic term is what BNN methods compute via Monte Carlo: the §3, §4, §5, §6, §7 predictive standard-deviation heatmaps are epistemic variance\sqrt{\text{epistemic variance}}. The aleatoric term is the average over weight samples of the per-sample conditional variance — for a Bernoulli observation, this is the sigmoid-of-logit variance and is large near the decision boundary regardless of how confident the model is in its weights. The §8.7 head-to-head figure renders both components separately for one method (deep ensemble) so the reader can see the decomposition concretely.
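The decomposition costs one pass over the Monte Carlo samples. A minimal sketch for the binary case, where probs stacks sampled class-1 probabilities from any of the §§3–7 methods; the toy numbers are chosen to isolate a purely-aleatoric and a purely-epistemic test point.

```python
import numpy as np

def decompose_uncertainty(probs):
    """Epistemic-aleatoric split (Prop 8.4) from Monte Carlo samples of the
    class-1 probability. probs has shape (S, N): S weight samples, N test points."""
    aleatoric = np.mean(probs * (1 - probs), axis=0)   # E_w[ Var[y | x, w] ]
    epistemic = np.var(probs, axis=0)                  # Var_w[ E[y | x, w] ]
    return aleatoric, epistemic, aleatoric + epistemic

# probs could come from Laplace samples, dropout masks, ensemble members,
# or SG-MCMC draws pushed through the sigmoid.
probs = np.array([[0.5, 0.95],
                  [0.5, 0.05]])     # point 1: on the boundary; point 2: samples disagree
alea, epi, total = decompose_uncertainty(probs)
print(alea)   # [0.25, 0.0475]  -> first point sits on the decision boundary
print(epi)    # [0.0, 0.2025]   -> second point: the weight samples disagree
```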

8.5 The cold-posterior effect

Remark (Cold posteriors).

Wenzel et al. 2020 (“How Good is the Bayes Posterior in Deep Neural Networks Really?”) observed empirically that BNN predictive accuracy and calibration consistently improve when the posterior is tempered by raising it to a power 1/T1/T for T<1T < 1: pT(wD)    p(wD)1/T  =  p(Dw)1/Tp(w)1/T.p_T(w \mid \mathcal{D}) \;\propto\; p(w \mid \mathcal{D})^{1/T} \;=\; p(\mathcal{D} \mid w)^{1/T}\,p(w)^{1/T}. Equivalently, with the Gaussian prior of Definition 2.1, tempering by 1/T1/T scales the negative log-likelihood by 1/T1/T and the prior precision by 1/T1/T, so the effective weight-decay strength is λeff=λ/T>λ\lambda_{\mathrm{eff}} = \lambda / T > \lambda when T<1T < 1 — i.e., a stronger regularizer than the principled prior calls for. Across image classification, regression, and language tasks, the optimal TT tends to be in the range 0.010.010.10.1, an order of magnitude or more away from the strict-Bayesian T=1T = 1.

The phenomenon is one of the central open problems in BNNs. The two leading hypotheses are: (i) the Gaussian prior is misspecified — real-world weight distributions are heavier-tailed than N(0,τ2I)\mathcal{N}(0, \tau^2 I), so the principled prior over-regularizes and tempering compensates; (ii) the likelihood is misspecified in some other way (data augmentation, label noise) that interacts with the prior. Aitchison 2021 (“A statistical theory of cold posteriors”), Adlam, Snoek & Smith 2020, and Izmailov et al. 2021 contribute partial resolutions, but the question is not closed.

8.6 Temperature scaling

Remark (Post-hoc temperature scaling).

A pragmatic alternative to choosing T in the prior is post-hoc temperature scaling (Guo et al. 2017): after training a network the standard way, learn a single scalar temperature \hat T that minimizes NLL on a held-out validation set, and rescale all test-time logits by 1/\hat T. The construction does not change the model's accuracy (the argmax of the rescaled logits matches the argmax of the original logits), but it can dramatically reduce ECE by softening over-confident predictions. Temperature scaling is now the default post-hoc calibration step in production deep-learning pipelines and is included in the laplace-torch library as an automatic post-processing step. The §8.7 head-to-head figure exposes a manual-temperature slider as a sanity check; fitting \hat T on a held-out set (and reporting NLL before and after) is left for a later iteration of the comparison.
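A minimal sketch of the Guo et al. recipe: optimize a single scalar over held-out logits with L-BFGS. Here val_logits and val_labels are assumed to come from an already-trained classifier; the learning rate and iteration cap are illustrative.

```python
# Post-hoc temperature scaling sketch: fit one scalar T on held-out validation
# logits by minimizing NLL, then divide test-time logits by T.
import torch

def fit_temperature(val_logits, val_labels, max_iter=50):
    log_T = torch.zeros(1, requires_grad=True)          # optimize log T so T stays positive
    opt = torch.optim.LBFGS([log_T], lr=0.1, max_iter=max_iter)

    def closure():
        opt.zero_grad()
        loss = torch.nn.functional.cross_entropy(val_logits / log_T.exp(), val_labels)
        loss.backward()
        return loss

    opt.step(closure)
    return log_T.exp().item()

# usage sketch: accuracy is unchanged, NLL and ECE usually improve
# T_hat = fit_temperature(val_logits, val_labels)
# calibrated_probs = torch.softmax(test_logits / T_hat, dim=-1)
```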

8.7 Head-to-head comparison

Algorithm 8.7 (Head-to-head calibration evaluation).

Input: training data D\mathcal{D}, held-out test set Dtest={(xi,yi)}i=1Ntest\mathcal{D}_{\mathrm{test}} = \{(x^*_i, y^*_i)\}_{i=1}^{N_{\mathrm{test}}}, the six method outputs from §§1, 3, 4, 5, 6, 7 (point estimate, Laplace, MC-dropout, deep ensemble, SGLD, SGHMC). Output: table of {ECE,BS,NLL}\{\mathrm{ECE}, \mathrm{BS}, \mathrm{NLL}\} for each method, with reliability diagrams.

  1. For each method mm, compute predicted class-1 probabilities p^i(m)\hat{p}^{(m)}_i on the test set via Monte Carlo over the method’s posterior samples.
  2. Compute ECE with B=10B = 10 bins, BS via Def 8.2, NLL via Def 8.3.
  3. Plot reliability diagrams: predicted-confidence bin centers on xx-axis, acc(b)\mathrm{acc}(b) on yy-axis, with the diagonal y=xy = x marked as the perfectly-calibrated reference.

For Two Moons we use Ntest=500N_{\mathrm{test}} = 500 held-out points generated with make_moons at the same noise level as training but a different random_state, so the test set is iid from the same distribution as training and ECE is well-defined.

Two panels: panel (a) reliability diagram with predicted confidence on the x-axis and empirical accuracy on the y-axis, with the diagonal y=x as the reference line and one connected line per method (point estimate, Laplace, MC-dropout, deep ensemble, SGLD, SGHMC); the point estimate's curve sits below the diagonal in mid-confidence bins (over-confidence) and the BNN methods' curves sit closer to the diagonal; panel (b) grouped bar chart of ECE, Brier × 10, and NLL × 10 for the six methods, with deep ensemble and SGHMC having the lowest values.
Figure 8. Head-to-head calibration on Two Moons. (a) Reliability diagram — distance from the y=x diagonal measures miscalibration. (b) Bar chart of ECE / Brier×10 / NLL×10 across methods. Toggle methods to compare; slide bin count to see binning sensitivity. The optional manual-temperature slider rescales logits as a sanity check; full temperature scaling fits T on a held-out set rather than letting the reader pick — that step is left for a v3 enhancement.

9. Function-space view: NNGP, NTK, and open problems

§§3–7 work in weight space. Each method approximates the posterior p(wD)p(w \mid \mathcal{D}) over the network’s parameters and reads predictions off the resulting weight distribution. The §2.4 obstacles — multimodality, ReLU rescaling, over-parametrization — are all weight-space pathologies, and the methods of §§3–7 are weight-space workarounds. Function space offers a different vantage. What we ultimately care about is the predictive distribution over outputs, which lives in function space; weight space is the awkward intermediate representation. Two classical results — Neal’s 1996 neural network Gaussian process (NNGP) and Jacot, Gabriel & Hongler’s 2018 neural tangent kernel (NTK) — show that in the infinite-width limit, both Bayesian inference and gradient-descent training reduce to operations on a fixed kernel function. The function-space view connects BNNs to Gaussian Processes, explains why deep ensembles work, and provides the asymptotic reference against which the §§3–7 methods can be evaluated. This section gives the four key facts, with proofs deferred to references; the running example is one panel of NNGP-prior samples that visualizes the infinite-width convergence.

9.1 The neural network Gaussian process

Remark (NNGP — Neal 1996; Lee et al. 2017).

Consider an MLP with LL hidden layers of widths h1,,hLh_1, \ldots, h_L, weights Wij()iidN(0,σw2/h1)W^{(\ell)}_{ij} \stackrel{\mathrm{iid}}{\sim} \mathcal{N}(0, \sigma_w^2 / h_{\ell - 1}), biases bj()iidN(0,σb2)b^{(\ell)}_j \stackrel{\mathrm{iid}}{\sim} \mathcal{N}(0, \sigma_b^2), and elementwise nonlinearity ϕ\phi. As hh_\ell \to \infty for all \ell, the prior over the function fw(x)f_w(x) converges, in the sense of finite-dimensional distributions, to a Gaussian process GP(0,kNNGP)\mathcal{GP}(0, k_{\mathrm{NNGP}}) with covariance kernel computable by the recursion k(0)(x,x)=σb2+σw2xxd,k()(x,x)=σb2+σw2E(u,u)N(0,K(1)(x,x))[ϕ(u)ϕ(u)],k^{(0)}(x, x') = \sigma_b^2 + \sigma_w^2\,\frac{x \cdot x'}{d}, \qquad k^{(\ell)}(x, x') = \sigma_b^2 + \sigma_w^2\,\mathbb{E}_{(u, u') \sim \mathcal{N}(0, K^{(\ell-1)}(x, x'))}[\phi(u)\,\phi(u')], with K(1)(x,x)=(k(1)(x,x)k(1)(x,x)k(1)(x,x)k(1)(x,x))K^{(\ell-1)}(x, x') = \begin{pmatrix} k^{(\ell-1)}(x, x) & k^{(\ell-1)}(x, x') \\ k^{(\ell-1)}(x', x) & k^{(\ell-1)}(x', x') \end{pmatrix}. The output kernel is kNNGP=k(L)k_{\mathrm{NNGP}} = k^{(L)}. For ReLU, the per-layer expectation has the closed-form arc-cosine kernel (Cho & Saul 2009).

Neal's 1996 PhD thesis proved the one-hidden-layer case via a direct CLT argument: each output f_w(x) = \sum_j w^{(2)}_j \phi(W^{(1)}_j x + b^{(1)}_j) is a sum of h_1 iid terms with suitably bounded moments, each of variance O(1/h_1) thanks to the \sigma_w^2 / h_1 prior scaling on the output weights, so the CLT applies and f_w(x) converges to a Gaussian. Joint distributions across multiple inputs x_1, \ldots, x_n likewise converge to a multivariate Gaussian — i.e., a Gaussian process. Lee et al. 2017 extended the argument to deep networks via induction over layers, giving the recursive kernel formula above.

The implication is structural: at infinite width, BNN posterior inference reduces to GP regression (or GP classification, via Laplace or EP, per Gaussian Processes). The NNGP kernel encodes the entire architectural choice — depth, activation, prior scales — and once the kernel is in hand the O(n3)O(n^3) GP-inference machinery applies directly. There is no weight-space optimization, no MAP, no Hessian, no SGLD. The infinite-width BNN is a Gaussian process.
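The recursion is a few lines of numpy once the ReLU expectation is written in its arc-cosine closed form. The depth, \sigma_w^2, and \sigma_b^2 below are illustrative choices, not values tied to the running example.

```python
# NNGP kernel for a ReLU MLP via the §9.1 recursion, using the Cho-Saul
# arc-cosine closed form for the per-layer expectation.
import numpy as np

def nngp_relu_kernel(X, depth=3, sigma_w2=2.0, sigma_b2=0.1):
    """Return the (n, n) NNGP kernel matrix k^(L) for inputs X of shape (n, d)."""
    n, d = X.shape
    K = sigma_b2 + sigma_w2 * (X @ X.T) / d                  # k^(0)
    for _ in range(depth):
        diag = np.sqrt(np.diag(K))
        cos_theta = np.clip(K / np.outer(diag, diag), -1.0, 1.0)
        theta = np.arccos(cos_theta)
        # E[relu(u) relu(u')] for (u, u') ~ N(0, K_{2x2}), arc-cosine kernel:
        E = np.outer(diag, diag) * (np.sin(theta) + (np.pi - theta) * np.cos(theta)) / (2 * np.pi)
        K = sigma_b2 + sigma_w2 * E                          # k^(l)
    return K

X = np.random.default_rng(0).normal(size=(5, 2))
K = nngp_relu_kernel(X)
print(np.linalg.eigvalsh(K))   # positive eigenvalues: a valid GP covariance
```

The returned matrix drops directly into the closed-form GP-regression posterior from the Gaussian Processes topic.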

9.2 The neural tangent kernel

Remark (NTK — Jacot, Gabriel and Hongler 2018).

Consider an MLP trained by gradient descent on the squared-error loss starting from random initialization w0w_0. The neural tangent kernel at w0w_0 is Θ(x,x)  :=  wfw(x),wfw(x)w=w0.\Theta(x, x') \;:=\; \bigl\langle \nabla_w f_w(x),\, \nabla_w f_w(x') \bigr\rangle\bigm|_{w = w_0}. As widths \to \infty, Θ\Theta becomes deterministic (concentrated at its expectation Θ\Theta_*) and constant during training — the gradient is dominated by the linearization at w0w_0, and the network evolves as if it were the linearized model fw0(x)+wfw0(x)(ww0)f_{w_0}(x) + \nabla_w f_{w_0}(x)^\top (w - w_0). The training dynamics converge to a deterministic ODE in function space: dft(x)dt  =  i=1nΘ(x,xi)(ft(xi)yi),\frac{df_t(x)}{dt} \;=\; -\sum_{i=1}^n \Theta_*(x, x_i)\,(f_t(x_i) - y_i), whose solution at tt \to \infty is exactly the kernel-regression predictor under Θ\Theta_*.

Jacot et al.’s argument has two parts. First, at initialization, the gradient wfw0(x)\nabla_w f_{w_0}(x) is itself a random function whose pairwise inner products converge in the infinite-width limit to a deterministic kernel Θ\Theta_* — a CLT analogous to NNGP’s. Second, during training, the change in weights wtw0w_t - w_0 stays small in a width-dependent norm (the lazy regime of Chizat & Bach 2019), so the linearization fw(x)fw0(x)+wfw0(x)(ww0)f_w(x) \approx f_{w_0}(x) + \nabla_w f_{w_0}(x)^\top (w - w_0) remains valid and the network’s outputs evolve linearly. The combination gives kernel-regression dynamics in function space.

The NTK is not the same as the NNGP kernel. NNGP describes the prior over functions before any training; NTK describes the trained function under gradient descent. In general ΘNTK(x,x)kNNGP(x,x)\Theta_{\mathrm{NTK}}(x, x') \neq k_{\mathrm{NNGP}}(x, x'), and at infinite width they describe two different inference regimes (Bayesian posterior vs. gradient-descent training). Lee et al. 2019 (“Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent”) gave the precise statement of NTK convergence and the relationship between the two kernels.
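The empirical (finite-width) NTK is just a Gram matrix of weight-gradients, which autograd computes directly. The width, inputs, and architecture below are illustrative; at larger widths the matrix concentrates toward its deterministic limit, per the remark above.

```python
# Empirical NTK at initialization: Theta(x, x') = <grad_w f(x), grad_w f(x')>,
# computed for a small randomly initialized MLP.
import torch, torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(2, 512), nn.ReLU(),
                    nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 1))

def weight_gradient(x):
    """Flattened gradient of the scalar output f_w(x) with respect to all weights."""
    out = net(x.unsqueeze(0)).squeeze()
    grads = torch.autograd.grad(out, list(net.parameters()))
    return torch.cat([g.reshape(-1) for g in grads])

xs = [torch.tensor([1.0, 0.0]), torch.tensor([0.0, 1.0]), torch.tensor([1.0, 1.0])]
G = torch.stack([weight_gradient(x) for x in xs])    # (3, p) rows of the Jacobian
Theta = G @ G.T                                      # 3x3 empirical NTK
print(Theta)
# At larger widths this matrix concentrates around a deterministic limit Theta_*
# and, per Jacot et al. 2018, stays approximately constant during training.
```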

9.3 What this means for the §§3–7 methods

Remark (Function-space asymptotic ordering).

At infinite width, the §§3–7 methods can be ordered by how well they recover the NNGP posterior.

SG-MCMC (§§6–7). Asymptotically exact: in the joint limit of infinite width and infinitesimal step size with infinite chain length, SGLD and SGHMC sample exactly the NNGP posterior. This is the strongest function-space guarantee in the catalog.

Deep ensembles (§5). Approximately correct: the modes of an infinitely-wide network with random initialization are approximately samples from the NNGP prior, so KK-ensemble averaging at large KK recovers the NNGP posterior up to mode-coverage error (Wilson & Izmailov 2020). The argument is delicate but the spirit is right.

Laplace (§3). Coarse approximation: a single local Gaussian centered at one MAP estimate captures the local curvature of the NNGP posterior at one point, but misses the full function-space distribution. At infinite width the NNGP posterior is itself a Gaussian, and Laplace’s approximation converges to a sub-Gaussian within it — the bias is O(1)O(1) in the function-space sense.

MC-dropout (§4). Structurally biased: the Bernoulli variational family does not converge to the NNGP posterior at any width, because no Bernoulli posterior over weights induces a Gaussian process over functions. Foong et al. 2020 (“On the Expressiveness of Approximate Inference in Bayesian Neural Networks”) formalize this gap.

The §8 head-to-head ordering on Two Moons is roughly consistent with this asymptotic ordering: deep ensembles and SG-MCMC are the most-calibrated, Laplace is in the middle, and MC-dropout is at the bottom. For finite-width networks the differences are smaller than the infinite-width view suggests, and for some narrow regimes the ordering inverts — but the function-space asymptotic gives the right intuition for which method to reach for first.

9.4 Open problems

Remark (Active research directions).

Four open problems organize ongoing work in BNNs.

The cold-posterior effect. Wenzel et al. 2020’s empirical observation (§8.5) that BNN performance improves under tempering with T<1T < 1 remains poorly understood. Aitchison 2021, Adlam, Snoek & Smith 2020, and Izmailov et al. 2021 offer partial explanations — prior misspecification, label noise, data augmentation interacting with the implicit posterior — but no consensus has emerged.

Finite-width corrections to NNGP and NTK. The infinite-width theory is clean but practical networks have finite width. The leading-order correction in 1/h1 / h is the subject of active work — Yang & Hu 2021 (“Feature Learning in Infinite-Width Neural Networks”) propose a width-rescaled regime where features are learned even at infinite width; Bordelon, Canatar & Pehlevan 2020 give explicit finite-width corrections to NTK regression. The gap between the infinite-width theory and practical BNN inference at finite width remains one of the most active questions in deep-learning theory.

Scalable asymptotically-exact MCMC. Beyond SGHMC, methods like symmetric splitting integrators (Leimkuhler & Matthews 2013), cyclical SG-MCMC (Zhang et al. 2020), and full-batch HMC at scale (Izmailov et al. 2021) push the asymptotically-exact frontier toward ImageNet-scale problems. The planned stochastic-gradient-mcmc (coming soon) topic will develop these methods in detail.

Function-space variational inference. Inferring directly in function space rather than weight space — Wang, Shi & Cheng 2019 (“Function-space VI through Stein discrepancy”), Sun, Zhang, Shi & Grosse 2019, Burt, Ober, Garriga-Alonso & van der Wilk 2020 — sidesteps the weight-space identifiability problems §2.4 catalogues. The conceptual appeal: do inference where the inference target lives. The technical difficulty: function space is infinite-dimensional, so variational families and divergences have to be chosen carefully.

Figure 9. The function-space view. (a) Bar chart of empirical Var f(x_0) for finite-width MLPs at widths h ∈ {50, 100, 200, 500, 1000, 2000}, with a horizontal reference line at the closed-form arc-cosine NNGP kernel value, which the empirical variance approaches as h grows. (b) NNGP-kernel GP regression posterior on a small synthetic dataset: closed-form posterior mean and ±2σ band, with no training and with uncertainty growing between training points.
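A short sketch of the panel (a) computation, under an assumed parameterization (one hidden layer, ReLU, N(0, 1) weights with a 1/√h output scaling, which may differ from the figure's exact setup): it compares the empirical variance of f(x_0) over random initializations against the closed-form NNGP value, which for this parameterization is E[relu(w · x_0)^2] = ||x_0||^2 / 2.

```python
import numpy as np

rng = np.random.default_rng(1)
x0 = np.array([0.6, -0.8])                       # test input, ||x0|| = 1

def f_init(x, width):
    """f(x) = a^T relu(W x) / sqrt(width) at a fresh random N(0,1) initialization."""
    W = rng.standard_normal((width, x.shape[0]))
    a = rng.standard_normal(width)
    return a @ np.maximum(W @ x, 0.0) / np.sqrt(width)

nngp_value = 0.5 * (x0 @ x0)                     # closed-form NNGP variance for this setup
for h in [50, 100, 200, 500, 1000, 2000]:
    draws = np.array([f_init(x0, h) for _ in range(4000)])
    print(f"width {h:5d}: Var f(x0) = {draws.var():.3f}  (NNGP value {nngp_value:.3f})")
```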

Connections and further reading

The function-space view connects this topic to neighboring formalML topics. Gaussian Processes develops the GP machinery the NNGP relies on: Cholesky factorization of the kernel matrix, conditional-MVN posteriors, hyperparameter learning by marginal likelihood. Variational Inference is the substrate for §4's MC-dropout derivation and for the function-space VI directions of Rem 9.4. Stacking & Predictive Ensembles generalizes §5's uniform-weighted deep ensemble to learned weights. The planned Stochastic-Gradient MCMC and meta-learning topics build directly on §§6–7's SG-MCMC machinery; Sparse Bayesian Priors revisits §2.1's prior choice with heavy-tailed alternatives that resolve some of Rem 9.4's cold-posterior questions. Cross-site, formalstatistics's formalStatistics: Bayesian Foundations and Prior Selection provides the prior-and-likelihood machinery this topic takes as given, and formalStatistics: Central Limit Theorem is the rigorous source of the Bernstein–von Mises invocation in §2.3.

Connections

  • MC-dropout (§4) is a Bernoulli variational family on weights; the BNN topic uses VI's ELBO machinery as substrate, and the §4.3 Gal–Ghahramani equivalence reduces dropout training to reverse-KL minimization within the Bernoulli family that VI develops the general theory for. variational-inference
  • The infinite-width MLP prior converges to a Gaussian process (NNGP); the §9 sidebar derives this connection and uses GP machinery for the function-space view. The §9.1 NNGP regression posterior is a direct application of the GP topic's closed-form conditional Gaussian construction. gaussian-processes
  • Deep ensembles (§5) are the special case of stacking with K same-architecture, same-prior candidates and uniform weights; §5.4 pulls this connection out explicitly. Lifting the uniform-weight constraint and learning weights via PSIS-LOO is a strict improvement when candidates have any genuine heterogeneity. stacking-and-predictive-ensembles
  • §4's Gal–Ghahramani equivalence minimizes reverse KL between the Bernoulli variational family and the BNN posterior; the topic invokes the KL machinery developed there for the variational characterization of dropout training. kl-divergence
  • §§6–7's SG-MCMC methods (SGLD, SGHMC) are direct extensions of the gradient-descent machinery developed there: the same minibatch-gradient update plus calibrated Gaussian noise to turn optimization into posterior sampling. gradient-descent

References & Further Reading