Wasserstein Gradient Flows and the Fokker Planck Equation (Part I)
$$ \nonumber \newcommand{\xb}{\boldsymbol{x}} \newcommand{\yb}{\boldsymbol{y}} \newcommand{\grad}{\nabla} \newcommand{\RR}{\mathbb{R}} $$
The connection between partial differential equations arising in chemical physics, like the Fokker-Planck equation discussed below, and notions of distance on the space of probability measures is a relatively young set of mathematical ideas. While the theory of gradient flows on arbitrary metric spaces can get exceedingly intricate, the fundamental ideas are not unapproachable. In this note, my aim is to illustrate some of the main ideas of the abstract theory of Wasserstein gradient flows and highlight the connection first to chemistry via the Fokker-Planck equation, and then to machine learning, in the context of training neural networks.
Let’s begin with an intuitive picture of a gradient flow. Let $\Theta \in \RR^k$ and consider the standard discrete time gradient descent update $$ \begin{equation} \Theta^{(k+1)}=\Theta^{(k)} - \tau \nabla F(\Theta^{(k)}). \end{equation} $$ In the limit $\tau \to 0$, the interpolating sequence $\{ \Theta^{(k)}\}_{k\in \mathbb{N}}$ defines a flow of the parameters through Euclidean space. We can make the Euclidean metric explicit by reformulating the discrete time update with the proximal formulation $$ \begin{equation} \Theta^{(k+1)} = \textrm{argmin}_{\Theta}\; \tau F(\Theta) + \frac12 \| \Theta - \Theta^{(k)} \|^2 \label{eq:eucprox} \end{equation} $$ which leads to the implicit update scheme $$ \begin{equation} \label{eq:gd_implicit} \Theta^{(k+1)}=\Theta^{(k)} - \tau \nabla F(\Theta^{(k+1)}). \end{equation} $$ These two formulations define exactly the same flow $\Theta_t$ through Euclidean space in the continuous time limit. This basic observation will motivate a similar calculation for gradient flows in the Wasserstein metric.
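To make the explicit/implicit distinction concrete, here is a small numerical sketch (my own illustration, not from any reference): a quadratic objective is chosen so the implicit (proximal) step has a closed form, and both schemes are compared against the exact continuous-time flow.

```python
import numpy as np
from scipy.linalg import expm

# Quadratic objective F(theta) = 0.5 * theta^T A theta, so grad F(theta) = A theta.
# (Illustrative choice: for a quadratic, the implicit step has a closed form.)
A = np.array([[3.0, 1.0], [1.0, 2.0]])

def explicit_step(theta, tau):
    # Forward Euler: theta - tau * grad F(theta)
    return theta - tau * A @ theta

def implicit_step(theta, tau):
    # Backward Euler (proximal) step: theta+ = theta - tau * grad F(theta+),
    # i.e. solve (I + tau * A) theta+ = theta.
    return np.linalg.solve(np.eye(2) + tau * A, theta)

theta0 = np.array([1.0, -1.0])
T, tau = 1.0, 1e-3
te, ti = theta0.copy(), theta0.copy()
for _ in range(int(T / tau)):
    te = explicit_step(te, tau)
    ti = implicit_step(ti, tau)

# Both schemes approximate the same continuous-time flow theta(T) = expm(-T A) theta0.
exact = expm(-T * A) @ theta0
print(np.linalg.norm(te - exact), np.linalg.norm(ti - exact))  # both O(tau)
```

As $\tau$ shrinks, both errors vanish at rate $O(\tau)$: the two discretizations really do define the same flow in the limit.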
Wasserstein metric
The name “Wasserstein” gradient flows originates from a connection to the Wasserstein metric. This metric is sometimes called the “earth mover’s distance” because of its historical connection to the Monge problem, which asks, colloquially: given a pile of dirt, how should I move that dirt to fill a given hole in the ground in such a way that the amount of work I do is minimal? While I wouldn’t slight the impressive mathematical tradition of France with this joke myself, I heard Gabriel Peyré, both a local and an expert on the topic, point out that minimizing the work to accomplish a given task is a very French endeavor. The history of this type of problem is quite rich: while the Monge problem does not always have a solution, a somewhat relaxed formulation was given by Kantorovich. His work was motivated by Soviet central planning, in particular, the movement of ore for processing into steel. The last little historical tidbit I’ll mention concerns Wasserstein himself: ironically, he’s not actually responsible for the metric.
So, let us proceed to the definition.
Let $\mu_1, \mu_2$ be probability measures in $\mathcal{P}_2(\mathcal{M})$, the set of probability measures with finite second moment on $\mathcal{M}$, $$ \mathcal{P}_2(\mathcal{M}) = \bigg\{ \rho: \mathcal{M}\to [0,\infty) \,\bigg|\, \int_{\mathcal{M}} d\rho = 1, \quad \int_{\mathcal{M}} \|\xb\|^2 \rho(\xb) d\xb<+\infty \bigg\} $$
Let $\Pi(\mu, \nu)$ define the set of probability measures with marginals $\mu$ and $\nu$. The Wasserstein metric with $p=2$ is then
$$
W_2^2 ( \mu, \nu) = \inf_{\pi \in \Pi(\mu,\nu)}\int \|x-y\|^2 \, d\pi(x,y).
$$
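In one dimension this infimum is easy to compute: the monotone (sorted) coupling is optimal for the quadratic cost, so $W_2^2$ between two equal-size, equal-weight empirical measures reduces to a sorted-sample calculation. A minimal sketch (my own illustration):

```python
import numpy as np

def w2_squared_1d(x, y):
    # W_2^2 between the empirical measures (1/n) sum_i delta_{x_i} and
    # (1/n) sum_i delta_{y_i}. In 1-D the monotone (sorted) coupling is
    # optimal for the quadratic cost, so we just match order statistics.
    x, y = np.sort(x), np.sort(y)
    return np.mean((x - y) ** 2)

# Rigidly shifting a measure by c gives W_2^2 = c^2 exactly.
rng = np.random.default_rng(0)
x = rng.normal(size=1000)
print(w2_squared_1d(x, x + 3.0))  # -> 9.0
```

In higher dimensions no such shortcut exists and one solves the Kantorovich linear program (or an entropic regularization of it) over couplings $\pi$.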
Fokker-Planck Equation
Let $V:\RR^d\to \RR$ be the potential energy for a particle system and consider the stochastic differential equation $$ dX_t = -\nabla V (X_t)dt + \sqrt{2\beta^{-1}} dW_t \qquad X_0= x_0 $$ with $W_t$ a standard Wiener process. This equation is often interpreted as (and is empirically consistent with) a physical system of particles in which the noise comes from fluctuations due to a solvent. It is well known and easy to verify (simply using the chain rule in the Itô calculus) that the particle density satisfies the Fokker-Planck equation, $$ \begin{equation} \partial_t \rho_t = \grad \cdot \left( \rho_t \grad V\right) + \beta^{-1} \Delta \rho_t, \qquad \rho_0 = \rho^0. \end{equation} $$
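A minimal Euler-Maruyama simulation makes the SDE-to-density connection tangible. This is an illustrative sketch with $V(x) = x^2/2$ and $\beta = 1$ (my own choices), for which the stationary density worked out below is the standard Gaussian:

```python
import numpy as np

# Overdamped Langevin dynamics for V(x) = x^2/2 (so grad V(x) = x), beta = 1.
# The Gibbs density Z^{-1} exp(-beta V) is then the standard normal.
rng = np.random.default_rng(1)
beta, dt, n_steps = 1.0, 1e-2, 2000
X = np.zeros(10_000)  # an ensemble of independent particles, all started at 0

for _ in range(n_steps):
    drift = -X * dt                                    # -grad V(X) dt
    noise = np.sqrt(2 * dt / beta) * rng.normal(size=X.shape)
    X = X + drift + noise                              # Euler-Maruyama update

# After T = n_steps * dt = 20 the ensemble has relaxed to N(0, 1/beta).
print(X.mean(), X.var())  # approximately 0 and 1
```

Histogramming `X` at successive times shows the empirical density evolving exactly as the Fokker-Planck equation predicts.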
It is easy to verify by a direct calculation that the unique stationary distribution is the Gibbs measure,
$$
\rho_{\textrm s}(\xb) = Z^{-1} \exp\left( - \beta V(\xb) \right) \qquad Z = \int \exp\left( - \beta V(\xb) \right) d\xb.
$$
In order that the integral defining $Z$ converges, one generally needs coercivity assumptions on the potential $V$. Furthermore, we know (using ideas from statistical mechanics) that the stationary density satisfies a variational principle: it minimizes a corresponding free energy
$$
\begin{aligned}
\mathcal{F}[\rho] &= \int V(\xb) \rho(\xb) d\xb + \beta^{-1} \int \rho(\xb) \log \rho(\xb) d\xb \\
&\equiv E[\rho] + \beta^{-1} S[\rho]
\end{aligned}
$$
Indeed, a formal calculation shows
$$
\frac{\delta \mathcal{F}}{\delta \rho} = 0 \implies \rho(\xb) \propto e^{-\beta V(\xb)}
$$
and since $\mathcal{F}$ is convex in $\rho$, this critical point is the unique global minimizer: any other density has strictly higher free energy. We set $\beta=1$ in what follows to simplify notation.
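We can check this variational principle numerically. The sketch below (my own illustration, with $V(x) = x^2/2$, $\beta = 1$, and a simple grid quadrature) evaluates $\mathcal{F}$ for the Gibbs density and two competitor densities:

```python
import numpy as np

# Free energy F[rho] = \int V rho + \int rho log rho (beta = 1) on a 1-D grid,
# with the illustrative potential V(x) = x^2 / 2.
x = np.linspace(-8, 8, 4001)
dx = x[1] - x[0]
V = 0.5 * x**2

def free_energy(rho):
    rho = rho / (rho.sum() * dx)          # normalize to a probability density
    return np.sum(V * rho) * dx + np.sum(rho * np.log(rho + 1e-300)) * dx

gibbs = np.exp(-V)                        # Z^{-1} e^{-V}, normalized below
gibbs /= gibbs.sum() * dx
shifted = np.exp(-0.5 * (x - 1.0) ** 2)   # competitor: shifted Gaussian
wide = np.exp(-0.125 * x**2)              # competitor: wider Gaussian

print(free_energy(gibbs), free_energy(shifted), free_energy(wide))
# The Gibbs density attains the smallest value, as the variational principle predicts.
```

For this $V$, the minimum value is $-\log Z = -\tfrac12 \log(2\pi) \approx -0.919$, which the quadrature reproduces.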
Proximal Wasserstein Gradient Descent
We now formalize this notion of a variational principle. We assume that we have a smooth, nonnegative potential function whose gradient grows no faster than the potential itself:
$$ V \in C^{\infty} (\RR^n), $$
$$ V(\xb) \geq 0,\quad \forall \xb \in \RR^n, $$
$$ \| \grad V(\xb) \| \leq C(V(\xb)+1),\quad \forall\xb \in \mathbb{R}^n. $$
Just as with the proximal gradient steps in the Euclidean metric $\eqref{eq:eucprox}$, we can construct an iterative scheme
$$
\begin{equation}
\rho^{(k)} \leftarrow \textrm{argmin}_{\rho\in \mathcal{P}_2(\mathcal{M})} \frac12 W_2^2(\rho, \rho^{(k-1)}) + \epsilon \mathcal{F}[\rho].
\label{eq:wprox}
\end{equation}
$$
We will prove (i) that a solution to the minimization problem above exists for all $k$ and (ii) that the iterative scheme converges to a solution of the Fokker-Planck equation. This discretization is also known as a minimizing movement scheme.
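Before the proofs, here is a one-dimensional minimizing-movement sketch (my own illustration, not the construction in the original paper): a density is represented by $n$ equal-mass particles (its quantiles), $W_2^2$ is computed with the monotone coupling, and the entropy is discretized through the particle gaps. The choices of $n$, $\epsilon$, and $V$ are all illustrative.

```python
import numpy as np
from scipy.optimize import minimize
from scipy.stats import norm

# 1-D JKO / minimizing-movement sketch. A density is its n equal-mass quantile
# particles x_1 < ... < x_n; in 1-D the monotone coupling is optimal, so
# W_2^2(rho, rho_prev) ~ (1/n) sum (x_i - y_i)^2, and the entropy is
# discretized through the gaps: S[rho] ~ -mean(log(n * diff(x))).
n, eps = 40, 0.2
V = lambda x: 0.5 * x**2        # illustrative potential; Gibbs measure is N(0, 1)

def unpack(params):
    # Parameterize the gaps by their logs so the particles stay ordered.
    x1, g = params[0], params[1:]
    return x1 + np.concatenate(([0.0], np.cumsum(np.exp(g))))

def jko_objective(params, y):
    x = unpack(params)
    w2 = 0.5 * np.mean((x - y) ** 2)               # (1/2) W_2^2(rho, rho_prev)
    energy = np.mean(V(x))                         # E[rho]
    entropy = -np.mean(np.log(n * np.diff(x)))     # S[rho]
    return w2 + eps * (energy + entropy)

# Initial condition: quantiles of N(2, 0.25), well away from equilibrium.
s = (np.arange(n) + 0.5) / n
y = 2.0 + 0.5 * norm.ppf(s)

params = np.concatenate(([y[0]], np.log(np.diff(y))))
for _ in range(25):                                # 25 proximal steps, T = 25 * eps
    res = minimize(jko_objective, params, args=(y,), method="L-BFGS-B")
    params = res.x
    y = unpack(params)

print(y.mean(), y.var())  # near 0 and 1: the particles approximate N(0, 1)
```

Each outer iteration solves one proximal problem $\eqref{eq:wprox}$; the particle cloud drifts toward the origin and spreads out until it settles near the quantiles of the Gibbs measure.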
Existence of a minimizer
Let $\rho^0\in \mathcal{P}_2(\mathcal{M})$. We then show:

$S[\rho]$ is well-defined and takes values in $(-\infty,\infty]$, and there exist $\alpha < 1$ and $C<\infty$ such that $S[\rho]\geq -C (m_2[\rho] + 1)^\alpha$, where $m_2$ denotes the second moment.
Proof:
Essentially an application of Hölder’s inequality.

Let $m_p[\rho] = \int \|\xb\|^p \rho(\xb) d\xb$ denote the $p$th moment of $\rho$ (assuming it exists). The iterates are bounded below:
Using the elementary inequality
$$ \|y\|^2 \leq 2\|x\|^2+2\|x-y\|^2, $$
integrated against any coupling of $\rho_0$ and $\rho_1$, one shows that
$$
2 W_2^2(\rho_0, \rho_1) \geq m_2[\rho_1] - 2 m_2[\rho_0],
$$
which leads to
$$
\begin{aligned}
\frac12 W_2^2(\rho, \rho^{(k-1)}) + \epsilon E[\rho] + \epsilon S[\rho] &\geq \frac14 m_2[\rho] - \frac\epsilon2 m_2[\rho^{(k-1)}] + \epsilon S[\rho] \\
&\geq \frac14 m_2[\rho] - \frac\epsilon2 m_2[\rho^{(k-1)}] - C (m_2[\rho]+1)^\alpha
\end{aligned}
$$
and so the objective is bounded below (since we know the moments are finite). Let $\rho_\nu$ be a minimizing sequence for the proximal problem. The remaining step is to show that the $\rho_\nu$ converge weakly in $L^1$ to the iterate $\rho^{(k)}$, and that each term of the objective is lower semicontinuous along the sequence:
$$
\begin{aligned}
S[\rho^{(k)}] &\leq \liminf_{\nu \to\infty} S[\rho_\nu] \\
E[\rho^{(k)}] &\leq \liminf_{\nu \to\infty} E[\rho_\nu] \\
W_2^2(\rho^{(k-1)},\rho^{(k)}) &\leq \liminf_{\nu \to\infty} W_2^2(\rho^{(k-1)},\rho_\nu)
\end{aligned}
$$
Let’s carry out one such verification: let $p_\nu$ be such that $\int \|\xb-\yb\|^2 p_\nu (d\xb, d\yb) < W_2^2(\rho^{(k-1)}, \rho_\nu) + \frac{1}{\nu}$ and note that $p_\nu$ converges weakly (along a subsequence) to a measure $p\in \Pi(\rho^{(k-1)}, \rho^{(k)})$. Testing against indicator functions of compact sets, we can use the monotone convergence theorem to show that as the sets grow, we recover the inequality above.
The convexity of the entropy, the convexity of $W_2^2(\rho, \cdot)$, and the linearity of the potential term ensure that the minimizer is unique.
Interpolation to a weak solution
Let us state precisely the theorem in JKO:
Theorem [JKO] Consider an initial condition $\rho_{\rm init} \in \mathcal{P}_2(\mathcal{M})$. Let $\epsilon>0$ and consider the sequence of minimizers $\{ \rho^{(k), \epsilon}\}_{k\in \mathbb{N}}$, each of which we take to be a solution of $\eqref{eq:wprox}$. Define the piecewise-constant interpolation $\rho^{\epsilon}$ by $$ \rho^\epsilon_t = \rho^{(k), \epsilon}\quad \textrm{for}\quad t\in[k\epsilon, (k+1)\epsilon). $$ Then, as $\epsilon\downarrow 0$, $\rho^\epsilon_t \rightharpoonup \rho_t $ weakly in $L^1$ for all $t\geq 0$, where $\rho$ solves $$ \partial_t \rho_t = \grad\cdot \left( \rho \grad V\right) + \Delta \rho, \quad \rho_0 = \rho_{\rm init}. $$ The main idea of the proof is to show that the first variation of the free energy leads to a time discrete scheme that converges to the PDE. In essence, this proof requires ensuring that $$ \rho^{\epsilon} \grad \left( \frac{\delta \mathcal{F}}{\delta \rho}\bigg|_{\rho^{\epsilon}}\right) $$ makes sense as we take the limit $\epsilon\to 0$ with $k\epsilon \to t$. For the potential term, $\rho^\epsilon \rightharpoonup \rho$ clearly implies $$ \rho^{\epsilon} \grad \left( \frac{\delta E}{\delta \rho}\bigg|_{\rho^{\epsilon}}\right) \rightharpoonup \rho \grad V $$ because $E$ is linear in $\rho$, so its first variation is simply $V$. The entropy term is not much worse because $$ \rho \grad \left( \frac{\delta S}{\delta \rho}\bigg|_{\rho}\right) = \rho \grad \left( \log \rho + 1 \right) = \grad \rho, $$ which again yields straightforward convergence in the weak sense.
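The theorem's conclusion can be checked crudely by solving the limiting PDE directly. The sketch below (my own illustration: $V(x) = x^2/2$, $\beta = 1$, explicit finite differences on a truncated domain) integrates the Fokker-Planck equation from a density started away from equilibrium and confirms relaxation to the Gibbs measure:

```python
import numpy as np

# Explicit finite differences for d_t rho = d_x(rho V') + d_xx rho with
# V(x) = x^2/2 (so V'(x) = x); an illustrative check, not a production solver.
x = np.linspace(-6.0, 6.0, 241)
dx = x[1] - x[0]
dt = 0.2 * dx**2                      # respects the explicit diffusion CFL limit
rho = np.exp(-2.0 * (x - 2.0) ** 2)   # initial density centered away from 0
rho /= rho.sum() * dx

for _ in range(int(5.0 / dt)):        # integrate to T = 5
    flux = rho * x                                            # rho * V'
    new = rho.copy()
    new[1:-1] += dt * ((flux[2:] - flux[:-2]) / (2 * dx)      # d_x (rho V')
                       + (rho[2:] - 2 * rho[1:-1] + rho[:-2]) / dx**2)  # d_xx rho
    rho = np.clip(new, 0.0, None)
    rho /= rho.sum() * dx             # re-normalize; a little mass leaks at the box edge

gibbs = np.exp(-0.5 * x**2)
gibbs /= gibbs.sum() * dx
print(np.abs(rho - gibbs).sum() * dx)  # small L1 distance: rho_t has relaxed to rho_s
```

By $T=5$ the $L^1$ distance to the Gibbs density is small, consistent with the gradient-flow picture: the dynamics dissipate $\mathcal{F}$ until the unique minimizer is reached.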