Gradient Flows II: Convexity and Connections to Machine Learning

Connections to Machine Learning

$$ \nonumber \newcommand{\xb}{\boldsymbol{x}} \newcommand{\yb}{\boldsymbol{y}} \newcommand{\zb}{\boldsymbol{z}} \newcommand{\thetab}{\boldsymbol{\theta}} \newcommand{\grad}{\nabla} \newcommand{\RR}{\mathbb{R}} $$

In my previous post, I introduced the notion of proximal gradient descent and explained the way in which the “geometry” or the metric used in the proximal scheme allows us to define gradient flows on arbitrary metric spaces. This concept is important in the context of statistical mechanics because analysis of the Fokker-Planck equation naturally yields a gradient flow in the Wasserstein metric. One of the most important consequences of this change of perspective is the statement that the free energy functional underlying the Fokker-Planck relaxation is displacement convex; this fact can be leveraged to prove that (aside from some technical caveats necessary to formally state this result), from any initial condition, a system of particles governed by the Fokker-Planck equation relaxes to its equilibrium exponentially fast. In this post, I want to go through the argument that leads to that proof of exponential convergence. Furthermore, I want to illustrate a connection with the basic problem of supervised machine learning: minimizing the “loss” function with stochastic gradient descent.

Let’s begin by recalling the Fokker-Planck equation. We take $V:\RR^d\to \RR$ to be the potential energy for a particle system and consider the stochastic differential equation $$ dX_t = -\nabla V (X_t)dt + \sqrt{2\beta^{-1}} dW_t \qquad X_0= x_0 $$ with $W_t$ a standard Wiener process. We do need a technical assumption on the potential, which is that it’s lower semi-continuous and uniformly $K$-Lipschitz. Introducing the particle empirical measure $$ \begin{equation} \mu^{(n)}(d\xb) = \frac1n \sum_{i=1}^n \delta_{\xb_i}(d\xb), \end{equation} $$ where $\xb_1,\dots,\xb_n$ are the positions of $n$ independent copies of the process, we see, with an application of Itô’s lemma, that in the limit $n\to\infty$ $$ \begin{equation} \label{eq:fpe} \partial_t\mu^{(n)} = \grad\cdot\left( \mu^{(n)} \grad V\right) + \beta^{-1}\Delta \mu^{(n)}. \end{equation} $$ We won’t linger on it here, but in order to perform this calculation, one needs to use bounded continuous test functions, and the convergence with respect to the number of particles $n$ is actually “weak” convergence, meaning that the integrals of the test functions against the particle measure converge.
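To make the particle picture concrete, here is a minimal sketch of an Euler-Maruyama discretization of the stochastic differential equation above. The quadratic potential $V(\xb) = k|\xb|^2/2$, the step size, and the function names are assumptions chosen purely for illustration; they are not part of the argument in this post.

```python
import numpy as np

def grad_V(x, k=1.0):
    # Gradient of the illustrative potential V(x) = k |x|^2 / 2 (an assumption).
    return k * x

def euler_maruyama(x0, beta=1.0, dt=1e-2, n_steps=2000, seed=0):
    # Discretize dX = -grad V(X) dt + sqrt(2 / beta) dW for independent particles.
    rng = np.random.default_rng(seed)
    x = np.array(x0, dtype=float)
    for _ in range(n_steps):
        noise = rng.standard_normal(x.shape)
        x = x - grad_V(x) * dt + np.sqrt(2.0 * dt / beta) * noise
    return x

# 5000 particles in d = 2, all started far from equilibrium.
particles = euler_maruyama(np.full((5000, 2), 3.0))

# For this quadratic potential the equilibrium density exp(-beta V) is Gaussian
# with variance 1 / (beta k) per coordinate, so these should be near 0 and 1.
print(particles.mean(axis=0), particles.var(axis=0))
```

The empirical mean and variance of the particle cloud are simple test functions integrated against $\mu^{(n)}$, which is the sense in which the empirical measure approximates the solution of the Fokker-Planck equation.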

Essentially, what we showed last time is that a proximal algorithm for gradient descent on the free energy (in which the proximity metric is the Wasserstein distance) recovers the Fokker-Planck equation as its limit. We discussed one very useful property of the gradient flow corresponding to the evolution of the Fokker-Planck equation, namely “displacement convexity”. This is a generalization, due to McCann, of the classical notion of convexity to dynamics on a metric space; it asserts that there is convexity along geodesics. This property is strong enough to ensure that, given an initial condition satisfying $$ \mu_0\geq 0\qquad \int \mu_0 =1, $$ the solution of $\eqref{eq:fpe}$ will converge to $\mu_{\rm eq} \propto e^{-\beta V}$ exponentially fast.

In the case of a uniformly convex potential, a straightforward argument for this exponential convergence is based on Grönwall’s inequality and does not rely on the theory of gradient flows. If we define the process $$ dY_t = -\nabla V (Y_t)dt + \sqrt{2\beta^{-1}} dW_t \qquad Y_0= y_0 $$ where $y_0$ and $x_0$ are sampled independently but the noise $dW_t$ driving the two stochastic differential equations is the same, then the noise cancels in the difference and the process $$ a_t := X_t-Y_t $$ solves, path by path, the noise-free ODE $$ \frac{da_t}{dt} = -\left(\grad V(X_t)-\grad V(Y_t)\right). $$ Uniform convexity with constant $K>0$ means that $(\grad V(\xb)-\grad V(\yb))\cdot(\xb-\yb) \geq K|\xb-\yb|^2$, so the magnitude of the discrepancy between $X_t$ and $Y_t$ will diminish as a function of time as a consequence of Grönwall’s lemma: $$ \frac{d}{dt}|a_t|^2 = -2 a_t\cdot\left(\grad V(X_t)-\grad V(Y_t)\right) \leq - 2K |a_t|^2 \implies |a_t|^2 \leq e^{-2Kt} |a_0|^2. $$ So, taking expectations, we see that $$ \mathbb{E} |a_t|^2 \leq e^{-2Kt} \mathbb{E} |a_0|^2 \leq 2e^{-2Kt} \left[\mathbb{E} |X_0|^2 + \mathbb{E}|Y_0|^2 \right] $$ which goes to zero exponentially fast. However, displacement convexity is by itself enough to ensure exponential convergence to the equilibrium measure. This type of result was developed by Otto and Villani for a variety of systems and is related to logarithmic Sobolev inequalities, which I will detail in a later post.
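The coupling argument can be checked numerically. The sketch below drives two discretized copies of the dynamics with the same noise increments and compares $|a_t|^2$ with the bound $e^{-2Kt}|a_0|^2$. The quadratic potential $V(\xb) = K|\xb|^2/2$ (uniformly convex with constant $K$), the initial conditions, and the step size are all assumptions made for illustration.

```python
import numpy as np

K = 2.0  # convexity constant of the illustrative potential V(x) = K |x|^2 / 2

def grad_V(x):
    return K * x

def coupled_paths(x0, y0, beta=1.0, dt=1e-3, n_steps=1000, seed=0):
    # Evolve X and Y with the SAME Wiener increments (synchronous coupling)
    # and record the squared distance |X_t - Y_t|^2 along the way.
    rng = np.random.default_rng(seed)
    x, y = np.array(x0, dtype=float), np.array(y0, dtype=float)
    sq_dist = [np.sum((x - y) ** 2)]
    for _ in range(n_steps):
        dW = np.sqrt(dt) * rng.standard_normal(x.shape)  # shared noise increment
        x = x - grad_V(x) * dt + np.sqrt(2.0 / beta) * dW
        y = y - grad_V(y) * dt + np.sqrt(2.0 / beta) * dW
        sq_dist.append(np.sum((x - y) ** 2))
    return np.array(sq_dist)

dt, n_steps = 1e-3, 1000
sq_dist = coupled_paths([3.0, -1.0], [-2.0, 0.5], dt=dt, n_steps=n_steps)
times = dt * np.arange(n_steps + 1)
bound = np.exp(-2.0 * K * times) * sq_dist[0]  # Gronwall bound e^{-2Kt} |a_0|^2

print(sq_dist[-1], bound[-1])
```

Because the noise is shared, it drops out of the difference exactly, and each path of $X_t$ and $Y_t$ remains random while their separation shrinks deterministically.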

Neural networks and the lack of exponential convergence

Connections between the dynamics that describes interacting particle systems (Wasserstein gradient flows) and the optimization dynamics of neural networks were initially discovered essentially simultaneously by several groups: Eric Vanden-Eijnden and myself; Lénaïc Chizat and Francis Bach; Song Mei, Phan-Minh Nguyen, and Andrea Montanari; Justin Sirignano and Konstantinos Spiliopoulos. Remarkably, these all appeared on the arXiv within days of one another. The basic framework, which I’ll outline below, has come to be called the “mean-field” approach. This nomenclature comes from the fact that each parameter is treated as a “charged particle” and the particles interact through a term that averages over all the particles in the system (i.e., a mean field). To see how this point of view arises, it’s easiest to contextualize it in the case of function approximation.

Suppose we have access to some information about a function $f:\RR^k\to\RR$. We might, for example, know its value on a data set. We want to approximate this function with a neural network. Let’s take the simplest case (and the only case for which we have explicit “mean-field” results and global convergence results): a single-hidden-layer network, which we write as $$ \begin{equation} \label{eq:nn} f^{(n)}(\xb; \thetab) = \frac1n \sum_{i=1}^n c_i \varphi(\xb, \zb_i) \end{equation} $$ where $\thetab_i = (c_i,\zb_i).$ As has been known for decades, if $\varphi$ is a continuous, non-polynomial function, then $f^{(n)}$ is a universal approximator. Loosely, that means that for $n$ large enough there exists a set of parameters such that $f^{(n)}$ is arbitrarily close to the target function $f$. Of course, one of the major problems in machine learning is finding that set of parameters through some optimization algorithm.
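As a concrete instance of this parameterization, here is a minimal NumPy sketch of the network with the $1/n$ normalization. The choice of unit $\varphi(\xb, \zb) = \tanh(\zb\cdot\xb)$ and the function names are assumptions for illustration; any continuous, non-polynomial unit would do.

```python
import numpy as np

def phi(x, z):
    # Illustrative unit phi(x, z) = tanh(<z, x>) (an assumption).
    # x has shape (m, k), z has shape (n, k); the result has shape (m, n).
    return np.tanh(x @ z.T)

def f_n(x, c, z):
    # Mean-field network: (1 / n) * sum_i c_i * phi(x, z_i).
    return phi(x, z) @ c / c.shape[0]

rng = np.random.default_rng(0)
n, k = 64, 3                      # n units ("particles"), inputs in R^k
c = rng.standard_normal(n)        # output weights c_i
z = rng.standard_normal((n, k))   # hidden-layer parameters z_i
x = rng.standard_normal((10, k))  # ten input points

out = f_n(x, c, z)
print(out.shape)
```

Note that the output is unchanged if the pairs $(c_i, \zb_i)$ are relabeled, which is what makes it natural to treat the parameters as indistinguishable particles.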

In the mean-field perspective, we consider what happens when we let the number of parameters (or width of the network) tend towards infinity. We can always rewrite $\eqref{eq:nn}$ as $$ \begin{equation} \label{eq:nnmu} f^{(n)}(\xb; \thetab) = \int_D c \varphi(\xb, \zb) d\mu(c,\zb) \end{equation} $$ where $\mu$ is the empirical measure for the parameters and $D$ is the parameter space. There’s nothing sophisticated going on here, just replacing the values with a bunch of $\delta$ masses at those values. In general, we have one class of methods for determining the optimal parameters: some variant of gradient descent on a loss function. We want to minimize $$ \begin{equation} \mathcal{L}(\mu) = \frac12 \int_{\Omega} |f(\xb) - f^{(n)}(\xb; \mu)|^2 d\nu(\xb) \end{equation} $$ which is sometimes called the population loss function because computing it would require integrating over the entire data distribution $\nu$. In almost all machine learning problems, we don’t know $\nu$ (or even $f$), so instead we approximate the integral defining $\mathcal{L}$ with an empirical sum over some dataset. We use that data to estimate the gradients and then we update the parameters accordingly.
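The replacement of the population loss by an empirical sum can be sketched as follows. Everything specific here is an assumption made for illustration: the target $f(\xb) = \sin(x_1)$, the standard Gaussian data distribution $\nu$, and the unit $\tanh(\zb\cdot\xb)$.

```python
import numpy as np

def target(x):
    # Hypothetical target function f, chosen only for illustration.
    return np.sin(x[:, 0])

def f_n(x, c, z):
    # Mean-field network with the illustrative unit phi(x, z) = tanh(<z, x>).
    return np.tanh(x @ z.T) @ c / c.shape[0]

def empirical_loss(x, c, z):
    # (1 / 2m) * sum_j |f(x_j) - f_n(x_j)|^2: a Monte Carlo estimate of the
    # population loss built from m samples x_j drawn from nu.
    residual = target(x) - f_n(x, c, z)
    return 0.5 * np.mean(residual ** 2)

rng = np.random.default_rng(0)
n, k = 50, 2
c, z = rng.standard_normal(n), rng.standard_normal((n, k))

# Datasets of increasing size drawn from nu (assumed standard Gaussian).
for m in (100, 1_000, 10_000, 100_000):
    x = rng.standard_normal((m, k))
    print(m, empirical_loss(x, c, z))
```

As the dataset grows, the printed values should settle down toward the population loss for this fixed set of parameters.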

The question we now ask is a fundamental one, so we’re going to stick with the population loss.

Question: Does gradient descent on $\mathcal{L}$ converge asymptotically (in $n$ and $t$) to the optimum?

To address this, we first take the limit $n\to\infty$ and ask how the measure evolves as a function of time. A calculation much like the one we described for deriving the Fokker-Planck equation leads to the following equation for the evolution of the parameter measure $$ \partial_t \mu = \grad \cdot (\mu \grad V(\mu)). $$ This equation differs from Fokker-Planck in a couple of very important ways. First, it is nonlinear because $V$ is a functional of $\mu$. Second, there is no Laplacian term in this equation, and that means we don’t have an explicit description of the stationary distribution, nor can we guarantee displacement convexity. These two factors make the analysis of global convergence markedly more involved. Chizat and Bach proved global convergence using a rigorous argument that is quite clearly presented in their paper. Currently, we do not have convergence rates that are competitive with the exponential rates established for Fokker-Planck; this is a significant open problem. It should be noted, though, that empirically the convergence properties are typically quite good in the overparameterized regime.
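At finite $n$, the evolution equation above corresponds to each particle $\thetab_i = (c_i, \zb_i)$ moving down the gradient of the loss, with time rescaled by a factor of $n$ to compensate for the $1/n$ in the network. The following is a minimal sketch of that particle dynamics with an explicit Euler step. The target $\sin(x_1)$, the Gaussian data, the unit $\tanh(\zb\cdot\xb)$, and the use of a fixed dataset in place of the population loss are all assumptions for illustration.

```python
import numpy as np

def target(x):
    # Hypothetical target function f, chosen only for illustration.
    return np.sin(x[:, 0])

def loss_and_gradients(x, y, c, z):
    m, n = x.shape[0], c.shape[0]
    h = np.tanh(x @ z.T)   # phi(x_j, z_i), shape (m, n)
    r = h @ c / n - y      # residual f_n(x_j) - f(x_j), shape (m,)
    loss = 0.5 * np.mean(r ** 2)
    # n * dL/dc_i and n * dL/dz_i: the force felt by particle i.
    grad_c = h.T @ r / m
    grad_z = (r[:, None] * (1.0 - h ** 2) * c[None, :]).T @ x / m
    return loss, grad_c, grad_z

rng = np.random.default_rng(0)
n, k, m = 200, 2, 2000
c, z = rng.standard_normal(n), rng.standard_normal((n, k))
x = rng.standard_normal((m, k))  # fixed dataset standing in for nu
y = target(x)

dt, losses = 0.1, []
for _ in range(500):
    loss, grad_c, grad_z = loss_and_gradients(x, y, c, z)
    losses.append(loss)
    c -= dt * grad_c  # explicit Euler step of the particle gradient flow
    z -= dt * grad_z

print(losses[0], losses[-1])
```

There is no noise term in the update, mirroring the absence of the Laplacian in the evolution equation for $\mu$.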

Assistant Professor of Chemistry

Grant M. Rotskoff is an assistant professor at Stanford. He studies statistical mechanics with a focus on nonequilibrium phenomena.