PKU Summer School 2025 | 4 从理论到算法:扩散过程的离散化

Abstract.

扩散过程的稳态分布与优化问题

扩散过程的稳态分布究竟给我们带来什么?考虑光滑非凸函数 $f$ 无约束优化问题:

$$
\mathrm{minimize} \quad f(\boldsymbol{x})
$$

考虑跑一个 $f$ 给出的(带温度的)Langevin 扩散过程来进行采样:

$$
\mathrm{d}\boldsymbol{X}_t = -\nabla f(\boldsymbol{X}_t)\mathrm{d}t + \sqrt{\frac{2}{\beta}}\mathrm{d}\boldsymbol{B}_t
$$

可以算出该扩散过程的稳态是

$$
\pi(\boldsymbol{x}) = \frac{\exp(-\beta f(\boldsymbol{x}))}{\int \exp(-\beta f(\boldsymbol{y}))\mathrm{d}\boldsymbol{y}}
$$

方便起降将归一化因子记作 $z_\beta$。换言之,$f(\boldsymbol{x})$ 越小,从中采出 $\boldsymbol{x}$ 的概率越大。现在来考察采样的效果:希望分析 $\mathbb{E}_\pi[f(\boldsymbol{x})] - f_\min$。

$$
\begin{aligned}
\int \pi(\boldsymbol{x})f(\boldsymbol{x})\mathrm{d}\boldsymbol{x} &= \int \pi(\boldsymbol{x})\left(-\frac{1}{\beta}\log \pi(\boldsymbol{x}) - \frac{1}{\beta}z_{\beta}\right)\mathrm{d}\boldsymbol{x} \\
&= \frac{1}{\beta}\mathrm{Ent}[\pi(\boldsymbol{x})] - \frac{1}{\beta} z_{\beta}
\end{aligned}
$$

第一项可以用最大熵估计(熟知正态分布熵最大)证明其不超过 $\frac{d}{2}\log \mathbb{E}_\pi[\boldsymbol{x}^2]$。据说这个很小,但我算不来。

第二项可以用光滑性硬算一下

$$
\begin{aligned}
z_\beta &= f_\min - \frac{1}{\beta}\log \int e^{-\beta (f(\boldsymbol{z}) - f_\min)}\mathrm{d}\boldsymbol{z} \\
&= f_\min - \frac{1}{\beta}\log \int e^{-\frac{\beta}{2}L(\boldsymbol{z} - \boldsymbol{z}^*)^2}\mathrm{d}\boldsymbol{z} \\
&= f_\min + \frac{d}{\beta}\log(\cdots)
\end{aligned}
$$

因此误差基本上是 $\mathbb{E}_\pi[f(\boldsymbol{x})] - f_\min \leq \frac{d}{\beta}\log(\cdots)$ 级别,这里 $\cdots$ 里面是某些高斯积分的结果,总之为常数。

前向欧拉方法

在上述推导中我们知道扩散过程确实可以用来做优化问题。但是为了用计算机模拟扩散过程,(除 Ornstein-Uhlembeck 过程有美观的闭式解外)需将其离散化。离散化之后,随机过程发生了变化,但是仍然希望它(在 Wasserstein 距离等意义下)能收敛于稳态分布。得到的算法的开销可基于如下两个方面:

  • 总时间;
  • 维度。

Langevin 扩散

先考虑

$$
\mathrm{d}\boldsymbol{X}_t = -\nabla f(\boldsymbol{X}_t)\mathrm{d}t + \sqrt 2 \mathrm{d}\boldsymbol{B}_t
$$

其中 $f(\cdot)$ 是一个强凸光滑函数(参数为 $\lambda, L$)。以时间步长 $\eta$ 做离散化:

$$
\widetilde{\boldsymbol{X}}_{(k + 1)\eta} = \widetilde{\boldsymbol{X}}_{k\eta} - \eta\nabla f(\widetilde{\boldsymbol{X}}_{k\eta}) + \sqrt{2\eta} \xi_k
$$

其中 $\xi_k$ 服从标准正态分布 $\mathcal{N}(0, \mathbf{I})$。现在,随机过程发生了变化。为了方便进行耦合,我们写出等价的连续时间过程

$$
\widetilde{\boldsymbol{X}}_{t} = -\nabla f(\widetilde{\boldsymbol{X}}_t)\mathrm{d}t + \sqrt 2 \mathrm{d}\boldsymbol{B}_t \qquad \text{for $t\in [k\eta, (k + 1)\eta)$}
$$

做同步耦合,计算

$$
\begin{aligned}
\frac{\mathrm{d}}{\mathrm{d}t} \mathbb{E}\left[\left\lVert \widetilde{\boldsymbol{X}}_t - \boldsymbol{X}_t \right\rVert^2\right] &=
-2\mathbb{E}\left[\langle\widetilde{\boldsymbol{X}}_t - \boldsymbol{X}_t, \nabla f(\widetilde{\boldsymbol{X}}_{k\eta}) - \nabla f(\boldsymbol{X}_t)\rangle\right] \\
&=
-2\mathbb{E}\left[\langle\widetilde{\boldsymbol{X}}_t - \boldsymbol{X}_t, \nabla f(\widetilde{\boldsymbol{X}}_{t}) - \nabla f(\boldsymbol{X}_t)\rangle\right] -
2\mathbb{E}\left[\langle\widetilde{\boldsymbol{X}}_t - \boldsymbol{X}_t, \nabla f(\widetilde{\boldsymbol{X}}_{k\eta}) - \nabla f(\widetilde{\boldsymbol{X}}_t)\rangle\right]
\end{aligned}
$$

第一项,依强凸性不会超过 $-\lambda \mathbb{E}\left[\left\lVert\boldsymbol{X}_t - \widetilde{\boldsymbol{X}}_t\right\rVert^2\right]$,而第二项

$$
\begin{aligned}
-2\mathbb{E}\left[\langle\widetilde{\boldsymbol{X}}_t - \boldsymbol{X}_t, \nabla f(\widetilde{\boldsymbol{X}}_{k\eta}) - \nabla f(\widetilde{\boldsymbol{X}}_t)\rangle\right]
&\leq 2\sqrt{\mathbb{E}\left[\left\lVert\widetilde{\boldsymbol{X}}_t - \boldsymbol{X}_t\right\rVert^2\right]} \cdot \sqrt{\mathbb{E}\left[\left\lVert\nabla f(\widetilde{\boldsymbol{X}}_{k\eta}) - \nabla f(\widetilde{\boldsymbol{X}}_t)\right\rVert^2\right]} & \color{blue}{\text{(Cauchy-Schwartz)}} \\
&\leq 2L^2 \sqrt{\mathbb{E}\left[\left\lVert\widetilde{\boldsymbol{X}}_t - \boldsymbol{X}_t\right\rVert^2\right]} \cdot {\color{red}\sqrt{\mathbb{E}\left[\left\lVert\widetilde{\boldsymbol{X}}_{k\eta} - \widetilde{\boldsymbol{X}}_t\right\rVert^2\right]}} & \color{blue}{\text{(Smoothness)}}
\end{aligned}
$$

红色部分里面的差实际上来自于一个黎曼积分和一个随机积分

$$
\widetilde{\boldsymbol{X}}_{k\eta} - \widetilde{\boldsymbol{X}}_t = \underbrace{-(k\eta - t) \nabla f(\widetilde{\boldsymbol{X}}_{k\eta})}_{\text{$O(\eta \sqrt d)$ by smoothness}} + \underbrace{\sqrt 2 \int_t^{k\eta} \mathrm{d}\boldsymbol{B}_t}_{\text{$O(\sqrt{\eta d})$ in expectation}}
$$

定义 $\Phi_t = \mathbb{E}\left[\left\lVert\widetilde{\boldsymbol{X}}_t - \boldsymbol{X}_t\right\rVert^2\right]$。综合以上结果有

$$
\begin{aligned}
\frac{\mathrm{d}\Phi_t}{\mathrm{d}t} &\leq -\lambda \Phi_t + 2\sqrt{\Phi_t}\cdot O(\sqrt{\eta d}) \\
&\leq -\lambda \Phi_t + \frac{\lambda}{2}\Phi_t + \frac{2}{\lambda}O(d\eta) & \color{blue}{\text{(AM-GM)}}
\end{aligned}
$$

解上述微分不等式得到

$$
\Phi_t \leq \exp\left(-\frac{\lambda t}{2}\right)\Phi_0 + O\left(\frac{\eta}{\lambda}\right)
$$

因此我们得到了

定理 2.1. 取 $\eta = O\left(\frac{\varepsilon^2}{d}\right)$,$t = \Omega\left(\frac{1}{\lambda}\log \frac{1}{\varepsilon}\right)$,有

$$
\mathcal{W}_2\left(\widetilde{\pi}_t, \pi\right) \leq \varepsilon
$$

因此我们得到了所谓的“朴素 Langevin 算法(Unadjusted Langevin Algorithm,ULA)”。其查询一个正态分布采样器的 oracle 的次数为

$$
O\left(\frac{d}{\varepsilon^2}\log\frac{1}{\varepsilon^2}\right)
$$

补充 2.1. 严格意义上来说,我们还需要算复杂度对 $\kappa = L / \lambda$ 的依赖性。但上述分析中,简便起见将其当作常数。实际上你加进去算也是容易的。

补充 2.1. 回忆随机游走里面有一个经典的操作叫做 Metropolis 过程:在转移之后以概率

$$
\min\left(1, \frac{p(x | y)\pi(x)}{p(y | x)\pi(y)}\right)
$$

接受转移。在 Langevin 过程中,$p(x | y) / p(y | x)$ 是容易求的(两个正态分布的商)。因此可以离散化 Langevin 扩散作为转移概率,运行一个 Metropolis 过程(Metropolis Adjusted Langevin Algorithm,MALA)。文献 [1] 证明了热启动 MALA 在 $\Delta_{TV}$ 意义下达到误差不超过 $\varepsilon$ 只需要 $O(d\log (1 / \varepsilon))$ 次迭代,显著快于 ULA。(似乎对关于 $\kappa$ 的优化也有所改进)

欠阻尼 Langevin 扩散

考虑(其中 $f$ 仍然光滑强凸)

$$
\begin{cases}
\mathrm{d}\boldsymbol{v}_t = -\gamma \boldsymbol{v}_t\mathrm{d}t - u\nabla f(\boldsymbol{x}_t)\mathrm{d}t + \sqrt{2\gamma u}\mathrm{d}\boldsymbol{B}_t \\
\mathrm{d}\boldsymbol{x}_t = \boldsymbol{v}_t\mathrm{d}t
\end{cases}
$$

警告. 下面的推导是课上讲的,skip 了莫名其妙的一步($(\ref{waht})$ 中动能的估计)。但是得到的结果确实是一个和文献 [2] 一样的结果。该文献的方法又不一样,没有时间看。

还是考虑之前的那个函数,并采纳记号 $\boldsymbol{z}_t = \widetilde{\boldsymbol{x}}_t - \boldsymbol{x}_t, \boldsymbol{y}_t = \widetilde{\boldsymbol{v}}_t - \boldsymbol{v}_t$:

$$
\Phi_t = \mathbb{E}\left[\left\lVert\begin{pmatrix}
1 & 1 \\ 1 & 0
\end{pmatrix}\begin{pmatrix}
\widetilde{\boldsymbol{x}}_t - \boldsymbol{x}_t \\ \widetilde{\boldsymbol{v}}_t - \boldsymbol{v}_t
\end{pmatrix}\right\rVert^2\right]
$$

计算

$$
\frac{\mathrm{d}}{\mathrm{d}t} \Phi_t = \mathbb{E}\left[\begin{pmatrix}
\boldsymbol{z}_t + \boldsymbol{y}_t \\ \boldsymbol{z}_t
\end{pmatrix}^\top \begin{pmatrix}
(1 - \gamma) \boldsymbol{y}_t - u(\nabla f(\widetilde{\boldsymbol{x}}_{k\eta}) - \nabla f(\boldsymbol{x}_t)) \\
- \boldsymbol{y}_t
\end{pmatrix}\right]
$$

这里如果是 $\nabla f(\tilde{\boldsymbol{x}}_t) - \nabla f(\boldsymbol{x}_t)$,就和我们在第二节课拿到的东西一模一样(一个 $\boldsymbol{v}^\top \mathbf{Q}_t \boldsymbol{v}$,其中 $\mathbf{Q}$ 是正定矩阵,最小特征值为 $\frac{\lambda}{2L}$ 级别,称为“something nice”),于是

$$
\begin{aligned}
\frac{\mathrm{d}}{\mathrm{d}t} \Phi_t &= - {\color{blue} \text{(something nice)}} + u\mathbb{E}\left[
\begin{pmatrix}
\boldsymbol{z}_t + \boldsymbol{y}_t \\ \boldsymbol{z}_t
\end{pmatrix}^\top
\begin{pmatrix}
-(\nabla f(\widetilde{\boldsymbol{x}}_{k\eta}) - \nabla f(\widetilde{\boldsymbol{x}}_t)) \\ 0
\end{pmatrix}\right] \\
&\leq -\frac{\lambda}{2L}\Phi_t + uL^2\sqrt{\Phi_t \cdot \mathbb{E}\left[\left\lVert \widetilde{\boldsymbol{x}}_{k\eta} - \widetilde{\boldsymbol{x}}_t \right\rVert^2\right]}
\end{aligned}
$$

然后是喜闻乐见的积分环节

$$
\begin{aligned}
\widetilde{\boldsymbol{x}}_t - \widetilde{\boldsymbol{x}}_{k\eta} &= \int_t^{k\eta} \widetilde{\boldsymbol{v}}_t\mathrm{d}t
\end{aligned}
$$

因此

$$
\begin{aligned}
\mathbb{E}\left[\left\lVert\widetilde{\boldsymbol{x}}_t - \widetilde{\boldsymbol{x}}_{k\eta}\right\rVert^2\right]
&\leq \eta^2 \max \mathbb{E}\left[\widetilde{\boldsymbol{v}}^2\right]
\end{aligned} \label{waht}
$$

可以证明,$\mathbb{E}\left[\widetilde{\boldsymbol{v}}^2\right] = O(d)$。这倒是很符合直觉。但我想了半天没想明白具体怎么证。到这里和 Langevin 扩散的区别是常数项从 $O(d\eta)$ 变成了 $O(d\eta^2)$,现在我们拿到了和前面相似的微分不等式,解出

$$
\Phi_t \leq \exp\left(-\frac{\lambda}{2L}t\right)\Phi_0 + O(\eta^2 d)
$$

现在有

定理 2.2. 取 $\eta = O\left(\frac{\varepsilon}{\sqrt{d}}\right)$,$t = \Omega\left(\frac{1}{\lambda}\log(1 / \varepsilon)\right)$,有

$$
\mathcal{W}_2\left(\widetilde{\pi}_t, \pi\right) \leq \varepsilon
$$

得到的算法查询一个正态分布采样器的 oracle 的次数为如下值,比 ULA 收敛更快

$$
O\left(\frac{\sqrt d}{\varepsilon}\log\frac{1}{\varepsilon}\right)
$$

参考资料

  1. Raaz Dwivedi, Yuansi Chen, Martin J. Wainwright, Bin Yu. Log-concave sampling: Metropolis-Hastings algorithms are fast, JMLR (2019)
  2. Xiang Cheng, Niladri S. Chatterji, Peter L. Bartlett, Michael I. Jordan. Underdamped Langevin MCMC: A non-asymptotic analysis, arxiv: 1707.03663