通向最优分布之路:概率空间的最小化
By 苏剑林 | 2024-08-06 | 19210位读者 |当要求函数的最小值时,我们通常会先求导函数然后寻找其零点,比较幸运的情况下,这些零点之一正好是原函数的最小值点。如果是向量函数,则将导数改为梯度并求其零点。当梯度零点不易求得时,我们可以使用梯度下降来逐渐逼近最小值点。
以上这些都是无约束优化的基础结果,相信不少读者都有所了解。然而,本文的主题是概率空间中的优化,即目标函数的输入是一个概率分布,这类目标的优化更为复杂,因为它的搜索空间不再是无约束的,如果我们依旧去求解梯度零点或者执行梯度下降,所得结果未必能保证是一个概率分布。因此,我们需要寻找一种新的分析和计算方法,以确保优化结果能够符合概率分布的特性。
对此,笔者一直以来也感到颇为头疼,所以近来决定”痛定思痛“,针对概率分布的优化问题系统学习了一番,最后将学习所得整理在此,供大家参考。
梯度下降 #
我们先来重温一下无约束优化的相关内容。首先,假设我们的目标是
\begin{equation}\boldsymbol{x}_* = \mathop{\text{argmin}}_{\boldsymbol{x}\in\mathbb{R}^n} F(\boldsymbol{x})\end{equation}
高中生都知道,要求函数的最值,往往先求导在让它等于零来找极值点,这对很多人来说已经成为“常识”。但这里不妨考考各位读者,有多少人能证明这个结论?换句话说,函数的最值为什么会跟“导数等于零”扯上关系呢?
搜索视角 #
我们可以从搜索的视角来探究这个问题。假设我们当前所知的$\boldsymbol{x}$为$\boldsymbol{x}_t$,我们怎么判断$\boldsymbol{x}_t$是不是最小值点呢?这个问题可以反过来思考:如果我们能找到$\boldsymbol{x}_{t+\eta}$,使得$F(\boldsymbol{x}_{t+\eta}) < F(\boldsymbol{x}_t)$,那么$\boldsymbol{x}_t$自然就不可能是最小值点了。为此,我们可以搜索如下格式的$\boldsymbol{x}_{t+\eta}$:
\begin{equation}\boldsymbol{x}_{t+\eta} = \boldsymbol{x}_t + \eta \boldsymbol{u}_t,\quad 0 < \eta \ll 1\end{equation}
当$F(\boldsymbol{x})$足够光滑、$\eta$足够小时,我们认为一阶近似的精度是够用的,于是可以利用一阶近似:
\begin{equation}F(\boldsymbol{x}_{t+\eta}) = F(\boldsymbol{x}_t + \eta \boldsymbol{u}_t) \approx F(\boldsymbol{x}_t) + \eta \boldsymbol{u}_t \cdot \nabla_{\boldsymbol{x}_t}F(\boldsymbol{x}_t)\end{equation}
只要$\nabla_{\boldsymbol{x}_t}F(\boldsymbol{x}_t)\neq 0$,我们我们就可以选取$\boldsymbol{u}_t = -\nabla_{\boldsymbol{x}_t}F(\boldsymbol{x}_t)$,使得
\begin{equation}F(\boldsymbol{x}_{t+\eta}) \approx F(\boldsymbol{x}_t) - \eta \Vert\nabla_{\boldsymbol{x}_t}F(\boldsymbol{x}_t)\Vert^2 < F(\boldsymbol{x}_t)\end{equation}
这意味着,对于足够光滑的函数,它的最小值只能在满足$\nabla_{\boldsymbol{x}_t}F(\boldsymbol{x}_t) = 0$的点或者无穷远处取到,这也就是为什么求最值的第一步通常是“导数等于零”。如果$\nabla_{\boldsymbol{x}_t}F(\boldsymbol{x}_t) \neq 0$,我们总可选择足够小的$\eta$,通过
\begin{equation}\boldsymbol{x}_{t+\eta} = \boldsymbol{x}_t-\eta\nabla_{\boldsymbol{x}_t}F(\boldsymbol{x}_t)\label{eq:gd}\end{equation}
来得到让$f$更小的点,这便是梯度下降。如果让$\eta\to 0$,我们可以得到ODE:
\begin{equation}\frac{d\boldsymbol{x}_t}{dt} = -\nabla_{\boldsymbol{x}_t}F(\boldsymbol{x}_t)\end{equation}
这就是《梯度流:探索通向最小值之路》介绍过的“梯度流”,它可以视为我们用梯度下降搜索最小值点的轨迹。
投影下降 #
刚才我们说的都是无约束优化,现在我们来简单讨论梯度下降在约束优化中的一个简单推广。假设我们面临的问题是:
\begin{equation}\boldsymbol{x}_* = \mathop{\text{argmin}}_{\boldsymbol{x}\in\mathbb{X}} F(\boldsymbol{x})\label{eq:c-loss}\end{equation}
其中$\mathbb{X}$是$\mathbb{R}^n$的一个子集,如果是理论分析,通常要给$\mathbb{X}$加上“有界凸集”的要求,但如果是简单了解的话,我们就可以先不管这些细节了。
如果此时我们仍用梯度下降$\eqref{eq:gd}$,那么最大的问题就是无法保证$\boldsymbol{x}_{t+\eta}\in\mathbb{X}$,但其实我们可以多加一步投影运算
\begin{equation}\Pi_{\mathbb{X}} (\boldsymbol{y}) = \mathop{\text{argmin}}_{\boldsymbol{x}\in\mathbb{X}}\Vert\boldsymbol{x}-\boldsymbol{y}\Vert\label{eq:project}\end{equation}
从而形成“投影梯度下降”:
\begin{equation}\boldsymbol{x}_{t+\eta} = \Pi_{\mathbb{X}}(\boldsymbol{x}_t-\eta\nabla_{\boldsymbol{x}_t}F(\boldsymbol{x}_t))\label{eq:pgd}\end{equation}
说白了,投影梯度下降就是先梯度下降,然后在$\mathbb{X}$中找到跟梯度下降结果最相近的点作为输出,这样就保证了输出结果一定在$\mathbb{X}$内。在《让炼丹更科学一些(一):SGD的平均损失收敛》中我们证明了,在一定假设下,投影梯度下降可以找到约束优化问题$\eqref{eq:c-loss}$的最优解。
从结果来看,投影梯度下降将约束优化$\eqref{eq:c-loss}$转化为了“梯度下降+投影”两步,而投影$\eqref{eq:project}$本身也是一个约束优化问题,虽然优化目标已经固定了,但仍属于未解决的问题,需要具体$\mathbb{X}$具体分析,因此还需要进一步探索下去。
离散分布 #
本文聚焦于概率空间中的优化,即搜索对象必须是一个概率分布,这一节我们先关注离散分布,搜索空间我们记为$\Delta^{n-1}$,它是全体$n$元离散型概率分布的集合,即
\begin{equation}\Delta^{n-1} = \left\{\boldsymbol{p}=(p_1,p_2,\cdots,p_n)\left|\, p_1,p_2,\cdots,p_n\geq 0,\sum_{i=1}^n p_i = 1\right.\right\}\end{equation}
我们的优化目标则是
\begin{equation}\boldsymbol{p}_* = \mathop{\text{argmin}}_{\boldsymbol{p}\in\Delta^{n-1}} F(\boldsymbol{p})\label{eq:p-loss}\end{equation}
拉氏乘子 #
对于等式或不等式约束下的优化问题,标准方法通常是“拉格朗日乘子法”,它将约束优化问题$\eqref{eq:p-loss}$转化为一个弱约束的$\text{min-max}$问题:
\begin{equation}\min_{\boldsymbol{p}\in\Delta^{n-1}} F(\boldsymbol{p}) = \min_{\boldsymbol{p}\in\mathbb{R}^n} \max_{\mu_i \geq 0,\lambda\in\mathbb{R}}F(\boldsymbol{p}) - \sum_{i=1}^n \mu_i p_i + \lambda\left(\sum_{i=1}^n p_i - 1\right)\label{eq:min-max}\end{equation}
注意在这个$\text{min-max}$优化中,我们去掉了$\boldsymbol{p}\in\Delta^{n-1}$这个约束,只是在$\max$这一步有一个比较简单的$\mu_i \geq 0$的约束。怎么证明右边的优化问题等价于左边呢?其实并不难,分三步来理解:
1、我们先要理解右端的$\text{min-max}$含义:$\min$在左,$\max$在右,这意味着我们最终是要求一个尽可能小的结果,但这个目标函数是先要对某些变量取$\max$;
2、当$p_i < 0$是,那么$\max$这一步必然有$\mu_i\to\infty$,此时结果目标函数值是$\infty$,而如果$p_i \geq 0$,那么$\max$这一步就必然有$\mu_i p_i =0$,此时目标函数值是有限的,显然后者更小一点,因此当右端取最优值时$p_i\geq 0$成立,同理我们也可以证明$\sum_{i=1}^n p_i = 1$成立;
3、由第2步的分析可知,当右端取最优值时,必然满足$\boldsymbol{p}\in\Delta^{n-1}$,且多出来的项为零,那么就等价于左边的优化问题。
接下来要用到一个“Minimax定理”:
如果$\mathbb{X},\mathbb{Y}$是两个凸集,$\boldsymbol{x}\in\mathbb{X},\boldsymbol{y}\in\mathbb{Y}$,并且$f(\boldsymbol{x},\boldsymbol{y})$关于$\boldsymbol{x}$是凸函数的(对于任意固定$\boldsymbol{y}$),关于$\boldsymbol{y}$是凹函数的(对于任意固定$\boldsymbol{x}$),那么成立 \begin{equation}\min_{\boldsymbol{x}\in\mathbb{X}}\max_{\boldsymbol{y}\in\mathbb{Y}} f(\boldsymbol{x},\boldsymbol{y}) = \max_{\boldsymbol{y}\in\mathbb{Y}}\min_{\boldsymbol{x}\in\mathbb{X}} f(\boldsymbol{x},\boldsymbol{y})\end{equation}
Minimax定理提供了$\min,\max$可交换的一个充分条件,这里边出现了一个新名词“凸集”,指的是集合内任意两点的加权平均,结果依然在集合内,即
\begin{equation}(1-\lambda)\boldsymbol{x}_1 + \lambda \boldsymbol{x}_2\in \mathbb{X},\qquad\forall \boldsymbol{x}_1,\boldsymbol{x}_2\in \mathbb{X},\quad\forall \lambda\in [0, 1]\end{equation}
由此可见凸集的条件并不是太苛刻,$\mathbb{R}^n,\Delta^{n-1}$都是凸集,还有全体非负数也是凸集,等等。
对于式$\eqref{eq:min-max}$右端的目标函数,它关于$\mu_i,\lambda$是一次函数,因此符合关于$\mu_i,\lambda$是凹函数的条件,并且除开$F(\boldsymbol{p})$外的项关于$\boldsymbol{p}$也是一次的,所以整个目标函数关于$\boldsymbol{p}$的凸性,等价于$F(\boldsymbol{p})$关于$\boldsymbol{p}$的凸性,即如果$F(\boldsymbol{p})$是关于$\boldsymbol{p}$的凸函数,那么式$\eqref{eq:min-max}$的$\min,\max$就可以交换:
\begin{equation}\small\min_{\boldsymbol{p}\in\mathbb{R}^n} \max_{\mu_i \geq 0,\lambda\in\mathbb{R}}F(\boldsymbol{p}) - \sum_{i=1}^n \mu_i p_i + \lambda\left(\sum_{i=1}^n p_i - 1\right) = \max_{\mu_i \geq 0,\lambda\in\mathbb{R}} \min_{\boldsymbol{p}\in\mathbb{R}^n}
F(\boldsymbol{p}) - \sum_{i=1}^n \mu_i p_i + \lambda\left(\sum_{i=1}^n p_i - 1\right)\end{equation}
这样我们就可以先对$\boldsymbol{p}$求$\min$了,这是一个无约束最小化问题,可以通过求解梯度等于零的方程组来完成,结果将带有参数$\lambda$和$\mu_i$,最后通过$p_i \geq 0$、$\mu_i p_i = 0$和$\sum_{i=1}^n p_i = 1$来确定这些参数。
凸集搜索 #
然而,尽管拉格朗日乘子法被视为求解约束优化问题的标准方法,但它并不算直观,而且它只能通过解方程来求得精确解,并不能导出类似梯度下降的迭代逼近算法,因此我们不能满足于拉格朗日乘子法。
从搜索的视角看,求解概率空间中的优化问题的关键,是保证在搜索过程中试探点都在集合$\Delta^{n-1}$内。换句话说,假设当前概率分布为$\boldsymbol{p}_t\in \Delta^{n-1}$,我们怎么构造下一个试探点$\boldsymbol{p}_{t+\eta}$呢?它有两个要求,一是$\boldsymbol{p}_{t+\eta}\in \Delta^{n-1}$,二是可以通过控制$\eta$的大小来控制它跟$\boldsymbol{p}_t$的接近程度。这时候$\Delta^{n-1}$的“凸集”性质就派上用场了,利用这一性质我们可以将$\boldsymbol{p}_{t+\eta}$定为
\begin{equation}\boldsymbol{p}_{t+\eta} = (1-\eta)\boldsymbol{p}_t + \eta \boldsymbol{q}_t,\quad \boldsymbol{q}_t\in \Delta^{n-1}\end{equation}
那么有
\begin{equation}F(\boldsymbol{p}_{t+\eta}) = F((1-\eta)\boldsymbol{p}_t + \eta \boldsymbol{q}_t) \approx F(\boldsymbol{p}_t) + \eta(\boldsymbol{q}_t - \boldsymbol{p}_t)\cdot\nabla_{\boldsymbol{p}_t} F(\boldsymbol{p}_t)\end{equation}
假设一阶近似的精度足够,那么要获得下降最快的方向,就相当于求解
\begin{equation}\mathop{\text{argmin}}_{\boldsymbol{q}_t\in\Delta^{n-1}}\,\boldsymbol{q}_t\cdot\nabla_{\boldsymbol{p}_t} F(\boldsymbol{p}_t) \end{equation}
这个目标函数倒是很简单,答案是
\begin{equation}\boldsymbol{q}_t = \text{onehot}(\text{argmin}(\nabla_{\boldsymbol{p}_t} F(\boldsymbol{p}_t)))\end{equation}
这里$\nabla_{\boldsymbol{p}_t} F(\boldsymbol{p}_t)$是一个向量,对一个向量做$\text{argmin}$指的是找出最小分量的位置。所以,上式也就是说$\boldsymbol{q}_t$是一个one hot分布,其中$1$所在的位置,就是梯度$\nabla_{\boldsymbol{p}_t} F(\boldsymbol{p}_t)$最小的分量所在的位置。
由此可见,概率空间的梯度下降形式是
\begin{equation}\boldsymbol{p}_{t+\eta} = (1 - \eta)\boldsymbol{p}_t + \eta\, \text{onehot}(\text{argmin}(\nabla_{\boldsymbol{p}_t} F(\boldsymbol{p}_t)))\end{equation}
以及$\boldsymbol{p}_t$是$F(\boldsymbol{p}_t)$极小值点的条件是:
\begin{equation}\boldsymbol{p}_t\cdot\nabla_{\boldsymbol{p}_t} F(\boldsymbol{p}_t) = (\nabla_{\boldsymbol{p}_t} F(\boldsymbol{p}_t))_{\min}\label{eq:p-min}\end{equation}
这里对向量的$\min$是指返回最小的那个分量。
一个例子 #
以《通向概率分布之路:盘点Softmax及其替代品》介绍的Sparsemax为例,它的原始定义为
\begin{equation}Sparsemax(\boldsymbol{x}) = \mathop{\text{argmin}}\limits_{\boldsymbol{p}\in\Delta^{n-1}}\Vert \boldsymbol{p} - \boldsymbol{x}\Vert^2\end{equation}
其中$\boldsymbol{x}\in\mathbb{R}^n$。不难发现,从前面投影梯度下降的角度来看,Sparsemax正好是从$\mathbb{R}^n$到$\Delta^{n-1}$的“投影”操作。
我们记$F(\boldsymbol{p})=\Vert \boldsymbol{p} - \boldsymbol{x}\Vert^2$,它对$\boldsymbol{p}$的梯度是$2(\boldsymbol{p} - \boldsymbol{x})$,所以根据式$\eqref{eq:p-min}$,极小值点满足的方程就是
\begin{equation}\boldsymbol{p}\cdot(\boldsymbol{p}-\boldsymbol{x}) = (\boldsymbol{p}-\boldsymbol{x})_{\min}\end{equation}
我们约定$x_i = x_j\Leftrightarrow p_i = p_j$,这里没有加粗的下标如$p_i$表示向量$\boldsymbol{p}$的第$i$个分量(即是一个标量),前一节加粗的下标如$\boldsymbol{p}_t$表示$\boldsymbol{p}$的第$t$次迭代结果(即还是一个向量),请读者细心区分。
在该约定下,由上式可以得到
\begin{equation}p_i > 0 \quad \Leftrightarrow \quad p_i-x_i = (\boldsymbol{p}-\boldsymbol{x})_{\min}\end{equation}
因为$\boldsymbol{p}$可以由$\boldsymbol{x}$确定,所以$(\boldsymbol{p}-\boldsymbol{x})_{\min}$是$\boldsymbol{x}$的函数,我们记为$-\lambda(\boldsymbol{x})$,那么$p_i = x_i - \lambda(\boldsymbol{x})$,但这只是对于$p_i > 0$成立,对于$p_i=0$,我们有$p_i-x_i > (\boldsymbol{p}-\boldsymbol{x})_{\min}$,即$x_i - \lambda(\boldsymbol{x}) < 0$。基于这两点,我们可以统一记
\begin{equation}p_i = \text{relu}(x_i - \lambda(\boldsymbol{x}))\end{equation}
其中$\lambda(\boldsymbol{x})$由$\boldsymbol{p}$的各分量之和为1来确定,其他细节内容请参考《通向概率分布之路:盘点Softmax及其替代品》。
连续分布 #
说完离散分布,接下来我们就转到连续分布了。看上去连续型分布只是离散型分布的极限版本,结果似乎不应该有太大差别,但事实上它们之间的特性有着本质不同,以至于我们需要为连续分布构建全新的方法论。
目标泛函 #
首先我们来说说目标函数。我们知道,描述连续型分布的方式是概率密度函数,所以此时目标函数的输入是一个概率密度函数,而此时的目标函数其实也不是普通的函数了,我们通常称之为“泛函”——从一整个函数到一个标量的映射。换言之,我们需要寻找一个概率密度函数,使得某个目标泛函最小化。
尽管很多人觉得“泛函分析心犯寒”,但实际上我们大部份人都接触过泛函,因为满足“输入函数,输出标量”的映射太多了,比如定积分
\begin{equation}\mathcal{I}[f]\triangleq \int_a^b f(x) dx\end{equation}
就是一个函数到标量的映射,所以它也是泛函。事实上,我们在实际应用中会遇到的泛函,基本上都是由定积分构建出来的,比如概率分布的KL散度:
\begin{equation}\mathcal{KL}[p\Vert q] = \int p(\boldsymbol{x})\log \frac{p(\boldsymbol{x})}{q(\boldsymbol{x})}d\boldsymbol{x}\end{equation}
其中积分默认在全空间(整个$\mathbb{R}^n$)进行。更一般的泛函的被积函数里边可能还包含导数项,如理论物理中的最小作用量:
\begin{equation}\mathcal{A}[x] = \int_{t_a}^{t_b} L(x(t),x'(t),t)dt\end{equation}
而接下来我们要最小化的目标泛函,则可以一般地写成
\begin{equation}\mathcal{F}[p] = \int F(p(\boldsymbol{x}))d\boldsymbol{x}\end{equation}
方便起见,我们还可以定义泛函导数
\begin{equation}\frac{\delta\mathcal{F}[p]}{\delta p}(\boldsymbol{x}) = \frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\end{equation}
紧支撑集 #
此外,我们还需要一个连续型概率空间的记号,它的基本定义是
\begin{equation}\mathbb{P} = \left\{p(\boldsymbol{x}) \,\Bigg|\, p(\boldsymbol{x})\geq 0(\forall\boldsymbol{x}\in\mathbb{R}^n),\int p(\boldsymbol{x})d\boldsymbol{x} = 1\right\}\end{equation}
不难证明,如果概率密度函数$p(\boldsymbol{x})$的极限$\lim_{\Vert\boldsymbol{x}\Vert\to\infty} p(\boldsymbol{x})$存在,那么必然有$\lim_{\Vert\boldsymbol{x}\Vert\to\infty} p(\boldsymbol{x}) = 0$,这也是后面的证明中要用到的一个性质。
然而,可以举例证明的是,并非所有概率密度函数在无穷远处都存在极限。为了避免理论上的困难,我们通常在理论证明时假设$p(\boldsymbol{x})$的支撑集是紧集。这里边又有两个概念:支撑集(Support)和紧集(Compact Set),支撑集指的是让$p(\boldsymbol{x}) > 0$的全体$\boldsymbol{x}$的集合,即
\begin{equation}\text{supp}(p) = \{\boldsymbol{x} | p(\boldsymbol{x}) > 0\}\end{equation}
紧集的一般定义比较复杂,不过在$\mathbb{R}^n$中,紧集等价于有界闭集。所以说白了,$p(\boldsymbol{x})$的支撑集是紧集的假设,直接作用是让$p(\boldsymbol{x})$具有“存在常数$C$,使得$\forall |\boldsymbol{x}| > C$都有$p(\boldsymbol{x}) = 0$”的性质,简化了$p(\boldsymbol{x})$在无穷远处的性态,从根本上避免了$\lim_{\Vert\boldsymbol{x}\Vert\to\infty} p(\boldsymbol{x}) = 0$的讨论。
从理论上来看,这个假设是非常强的,它甚至排除了像正态分布这样的简单分布(正态分布的支撑集是$\mathbb{R}^n$)。不过,从实践上来说,这个假设并不算离谱,因为我们说了如果极限$\lim_{\Vert\boldsymbol{x}\Vert\to\infty} p(\boldsymbol{x})$存在就必然为零,因此在超出一定范围后它就跟等于零没有太大区别了。极限不存在的例子确实有,但一般都需要比较刻意构造,对于我们实际能遇到的数据,基本上都满足极限存在的条件。
旧路不通 #
直觉上,连续分布的优化应该是照搬离散分布的思路,即设$\boldsymbol{p}_{t+\eta}(\boldsymbol{x}) = (1 - \eta)\boldsymbol{p}_t(\boldsymbol{x}) + \eta \boldsymbol{q}_t(\boldsymbol{x})$,因为跟离散分布一样,连续分布的概率密度函数集$\mathbb{P}$同样是一个凸集。现在我们将它代入目标泛函
\begin{equation}\begin{aligned}
\mathcal{F}[p_{t+\eta}] =&\, \int F(p_{t+\eta}(\boldsymbol{x}))d\boldsymbol{x} \\
=&\, \int F((1 - \eta)\boldsymbol{p}_t(\boldsymbol{x}) + \eta \boldsymbol{q}_t(\boldsymbol{x}))d\boldsymbol{x} \\
\approx&\,\int \left[F(p_t(\boldsymbol{x})) + \eta\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\Big(q_t(\boldsymbol{x}) - p_t(\boldsymbol{x})\Big)\right]d\boldsymbol{x}
\end{aligned}\end{equation}
假设一阶近似够用,那么问题转化为
\begin{equation}\mathop{\text{argmin}}_{q_t\in \mathbb{P}}\int\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}q_t(\boldsymbol{x})d\boldsymbol{x}\end{equation}
这个问题倒也是不难解,答案跟离散分布的one hot类似
\begin{equation}q_t(\boldsymbol{x}) = \delta\left(\boldsymbol{x} - \mathop{\text{argmin}}_{\boldsymbol{x}'} \frac{\partial F(p_t(\boldsymbol{x}'))}{\partial p_t(\boldsymbol{x}')}\right)\end{equation}
这里的$\delta(\cdot)$是狄拉克delta函数,表示单点分布的概率密度。
看上去很顺利,然而实际上此路并不通。首先,狄拉克delta函数并不是常规意义下的函数,它是广义函数(也是泛函的一种);其次,如果我们用普通函数的视角去看的话,狄拉克delta函数在某点处具有无穷大的值,而既然是无穷大的值,那么推导过程中的“一阶近似够用”的假设就不可能成立了。
变量代换 #
我们可以考虑继续对上一节的推导做一些修补,比如加上$q_t(\boldsymbol{x}) \leq C$的限制,以获得有意义的结果,然而这种缝缝补补的做法终究显得不够优雅。可是如果不利用凸集的性质,又该如何构建下一步的试探分布$\boldsymbol{p}_{t+\eta}(\boldsymbol{x})$呢?
这时候就要充分发挥概率密度函数的特性了——我们可以通过变量代换,来将一个概率密度函数变换为另一个概率密度函数,这是连续型分布的独有性质。具体来说,如果$p(\boldsymbol{x})$是一个概率密度函数,$\boldsymbol{y}=\boldsymbol{T}(\boldsymbol{x})$是一个可逆变换,那么$p(\boldsymbol{T}(\boldsymbol{x}))\left|\frac{\partial \boldsymbol{T}(\boldsymbol{x})}{\partial\boldsymbol{x}}\right|$同样是一个概率密度函数,其中$|\cdot|$表示矩阵的行列式绝对值。
基于这个特性,我们将下一步要试探的概率分布定义为
\begin{equation}\begin{aligned}
p_{t+\eta}(\boldsymbol{x}) =&\, p_t(\boldsymbol{x} + \eta\boldsymbol{\mu}_t(\boldsymbol{x}))\left|\boldsymbol{I} + \eta\frac{\partial \boldsymbol{\mu}_t(\boldsymbol{x})}{\partial\boldsymbol{x}}\right| \\
\approx &\, \Big[p_t(\boldsymbol{x}) + \eta\boldsymbol{\mu}_t(\boldsymbol{x})\cdot\nabla_{\boldsymbol{x}} p_t(\boldsymbol{x})\Big]\left[1 + \eta\,\text{Tr}\frac{\partial \boldsymbol{\mu}_t(\boldsymbol{x})}{\partial\boldsymbol{x}}\right] \\[3pt]
\approx &\, p_t(\boldsymbol{x}) + \eta\boldsymbol{\mu}_t(\boldsymbol{x})\cdot\nabla_{\boldsymbol{x}} p_t(\boldsymbol{x}) + \eta\, p_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\cdot\boldsymbol{\mu}_t(\boldsymbol{x}) \\[5pt]
= &\, p_t(\boldsymbol{x}) + \eta\nabla_{\boldsymbol{x}}\cdot\big[p_t(\boldsymbol{x})\boldsymbol{\mu}_t(\boldsymbol{x})\big] \\
\end{aligned}\end{equation}
同样的结果我们在《生成扩散模型漫谈(十二):“硬刚”扩散ODE》已经推导过,其中行列式的近似展开,可以参考《行列式的导数》一文。
积分变换 #
利用这个新的$p_{t+\eta}(\boldsymbol{x})$,我们可以得到
\begin{equation}\begin{aligned}
\mathcal{F}[p_{t+\eta}] =&\, \int F(p_{t+\eta}(\boldsymbol{x}))d\boldsymbol{x} \\
\approx&\, \int F\Big(p_t(\boldsymbol{x}) + \eta\nabla_{\boldsymbol{x}}\cdot\big[p_t(\boldsymbol{x})\boldsymbol{\mu}_t(\boldsymbol{x})\big]\Big)d\boldsymbol{x} \\
\approx&\, \int \left[F(p_t(\boldsymbol{x})) + \eta\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\nabla_{\boldsymbol{x}}\cdot\big[p_t(\boldsymbol{x})\boldsymbol{\mu}_t(\boldsymbol{x})\big]\right]d\boldsymbol{x} \\
=&\, \mathcal{F}[p_t] + \eta\int \frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\nabla_{\boldsymbol{x}}\cdot\big[p_t(\boldsymbol{x})\boldsymbol{\mu}_t(\boldsymbol{x})\big] d\boldsymbol{x} \\
\end{aligned}\label{eq:px-approx}\end{equation}
接下来需要像《测试函数法推导连续性方程和Fokker-Planck方程》一样,推导一个概率密度相关的积分恒等式。首先我们有
\begin{equation}\begin{aligned}
&\,\int \nabla_{\boldsymbol{x}}\cdot\left[\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})} p_t(\boldsymbol{x})\boldsymbol{\mu}_t(\boldsymbol{x})\right] d\boldsymbol{x} \\[5pt]
=&\, \int \frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\nabla_{\boldsymbol{x}}\cdot\big[p_t(\boldsymbol{x})\boldsymbol{\mu}_t(\boldsymbol{x})\big] d\boldsymbol{x} + \int \left(\nabla_{\boldsymbol{x}}\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\right)\cdot\big[p_t(\boldsymbol{x})\boldsymbol{\mu}_t(\boldsymbol{x})\big] d\boldsymbol{x}
\end{aligned}\end{equation}
根据散度定理,我们有
\begin{equation}\int_{\Omega} \nabla_{\boldsymbol{x}}\cdot\left[\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})} p_t(\boldsymbol{x})\boldsymbol{\mu}_t(\boldsymbol{x})\right] d\boldsymbol{x} = \int_{\partial\Omega} \left[\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})} p_t(\boldsymbol{x})\boldsymbol{\mu}_t(\boldsymbol{x})\right]\cdot \hat{\boldsymbol{n}} dS\end{equation}
其中$\Omega$是积分区域,在这里是整个$\mathbb{R}^n$,$\partial\Omega$是区域边界,$\mathbb{R}^n$的边界自然是无穷远处,$\hat{\boldsymbol{n}}$是边界的外向单位法向量,$dS$是面积微元。在紧支撑集的假设下,无穷远处$p_t(\boldsymbol{x})=0$,所以上式右端实际上就是零的积分,结果是零。因此我们有
\begin{equation}\int \frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\nabla_{\boldsymbol{x}}\cdot\big[p_t(\boldsymbol{x})\boldsymbol{\mu}_t(\boldsymbol{x})\big] d\boldsymbol{x} = - \int \left(\nabla_{\boldsymbol{x}}\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\right)\cdot\big[p_t(\boldsymbol{x})\boldsymbol{\mu}_t(\boldsymbol{x})\big] d\boldsymbol{x}\end{equation}
代入式$\eqref{eq:px-approx}$得到
\begin{equation}\mathcal{F}[p_{t+\eta}] \approx \mathcal{F}[p_t] - \eta\int \left(p_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\right)\cdot \boldsymbol{\mu}_t(\boldsymbol{x}) d\boldsymbol{x} \label{eq:px-approx-2}\end{equation}
梯度之流 #
根据式$\eqref{eq:px-approx-2}$,让$\mathcal{F}[p_{t+\eta}] \leq \mathcal{F}[p_t]$的一个简单选择是
\begin{equation}\boldsymbol{\mu}_t(\boldsymbol{x}) = \nabla_{\boldsymbol{x}}\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\end{equation}
相应的迭代格式是
\begin{equation}p_{t+\eta}(\boldsymbol{x}) \approx p_t(\boldsymbol{x}) + \eta\nabla_{\boldsymbol{x}}\cdot\left[p_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\right] \end{equation}
我们还可以取$\eta\to 0$的极限得
\begin{equation}\frac{\partial}{\partial t}p_t(\boldsymbol{x}) = \nabla_{\boldsymbol{x}}\cdot\left[p_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\right] \end{equation}
或者简写成
\begin{equation}\frac{\partial p_t}{\partial t} = \nabla\cdot\left[p_t\nabla\frac{\delta \mathcal{F}[p_t]}{\delta p_t}\right] \end{equation}
这就是《梯度流:探索通向最小值之路》介绍过的Wasserstein梯度流,但这里我们没有引入Wasserstein距离的概念也得到了相同的结果。
由于$p_{t+\eta}(\boldsymbol{x})$是$p_t(\boldsymbol{x})$通过变换$\boldsymbol{x}\to \boldsymbol{x} + \eta \boldsymbol{\mu}_t(\boldsymbol{x})$得到的,所以我们还可以写出$\boldsymbol{x}$的运动轨迹ODE:
\begin{equation}\boldsymbol{x}_t = \boldsymbol{x}_{t+\eta} + \eta \boldsymbol{\mu}_t(\boldsymbol{x}_{t+\eta})\quad\Rightarrow\quad \frac{d\boldsymbol{x}_t}{dt} = -\boldsymbol{\mu}_t(\boldsymbol{x}_t) = -\nabla_{\boldsymbol{x}}\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\end{equation}
这个ODE的意义是,从分布$p_0(\boldsymbol{x})$的采样结果$\boldsymbol{x}_0$出发,按照此ODE运动到$\boldsymbol{x}_t$时,$\boldsymbol{x}_t$所服从的分布正是$p_t(\boldsymbol{x})$。
文章小结 #
本文系统整理了概率空间中目标函数的最小化方法,包括取到极小值的必要条件、类似梯度下降的迭代法等内容,相关结果在最优化、生成模型(尤其是扩散模型)等场景中时有用到。
转载到请包括本文地址:https://kexue.fm/archives/10289
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Aug. 06, 2024). 《通向最优分布之路:概率空间的最小化 》[Blog post]. Retrieved from https://kexue.fm/archives/10289
@online{kexuefm-10289,
title={通向最优分布之路:概率空间的最小化},
author={苏剑林},
year={2024},
month={Aug},
url={\url{https://kexue.fm/archives/10289}},
}
September 19th, 2024
苏老师,文中的这句话
“这时候就要充分发挥概率密度函数的特性了——我们可以通过变量代换,来将一个概率密度函数变换为另一个概率密度函数,这是连续型分布的独有性质。具体来说,如果$p(\boldsymbol{x})$是一个概率密度函数,$\boldsymbol{y}=\boldsymbol{T}(\boldsymbol{x})$是一个可逆变换,那么$p(\boldsymbol{T}(\boldsymbol{x}))\left|\frac{\partial \boldsymbol{T}(\boldsymbol{x})}{\partial\boldsymbol{x}}\right|$同样是一个概率密度函数”
这里是不是写错了,如果是概率密度函数的变换,应该是$p(\boldsymbol{x})\left|\frac{\partial(\boldsymbol{x})}{\partial\boldsymbol{y}}\right|$,这里应该要用到逆变换,
你写的这个应该是积分中的变量替换的公式,下述的公式36是不是要改一下,不知道我的理解是不是正确
苏老师,您是对的,我看错了
September 20th, 2024
苏老师,$p_{t+\eta}(\boldsymbol{x}) =\, p_t(\boldsymbol{x} + \eta\boldsymbol{\mu}(\boldsymbol{x}))\left|\boldsymbol{I} + \eta\frac{\partial \boldsymbol{\mu}(\boldsymbol{x})}{\partial\boldsymbol{x}}\right|$,这里是不是应该换成 $p_{t}(\boldsymbol{x}) =\, p_{t+\eta}(\boldsymbol{x} + \eta\boldsymbol{\mu}(\boldsymbol{x}))\left|\boldsymbol{I} + \eta\frac{\partial \boldsymbol{\mu}(\boldsymbol{x})}{\partial\boldsymbol{x}}\right|$,因为状态是从$\boldsymbol{x}\to \boldsymbol{x} + \eta \boldsymbol{\mu}(\boldsymbol{x})$,即是 $\boldsymbol{x}_{t+\eta} = \boldsymbol{x}_t + \eta \boldsymbol{\mu}(\boldsymbol{x}_t)$?
呃,我这里其实没把它当作变量代换用,而是将它当作一种构造新分布的方法。不过你指出的也对,但真正要改的是最后的运动方程,右端要多加个负号(目前已经加上)
October 22nd, 2024
苏神你好,拜读了你的这篇博客,对于你那个n 元离散型概率分布的定义(10)式感觉困惑,这里如果是n元的,那么每一个元的独立性不会让∑i=1npi=1成立的,这里是否应该是1元的离散型概率分布,只是离散成了n个值来表示。
你这样理解为,这里说的$n$元离散分布就是$(10)$式所定义的东西,而不是你头脑里想的另外一个东西。