上篇文章中,我们从Fisher信息矩阵(FIM)的定义出发,推导出Fisher矩阵与KL散度的关系,并建立如下结论:FIM可以作为概率模型的参数空间的一种黎曼度量。在本篇文章中,我们利用上篇得到的结论,推导自然梯度中为何引入FIM来修正梯度方向,并介绍自然梯度的一些性质。
梯度下降:欧氏距离下的最速下降
考虑一个最优化任务($f:\Theta\to\mathbb{R}$
):
$$ \underset{\theta}{\operatorname{min}} f(\theta) $$
最常见的一阶优化方法是梯度下降/steepest descent:
$$ \theta^+=\theta-\eta\nabla_\theta f $$
其中$\eta$
是学习率。这里的「steepest」指的是在约束欧氏距离定义下的步长在极小范围内时,选取梯度的负方向能最大化一步之内目标函数下降的程度。
$$ \lim_{\epsilon\to 0}\frac{1}{\epsilon}\left(\underset{\delta:\|\delta\|\le \epsilon}{\operatorname{argmin}} f(\theta+\delta)\right)=-\frac{\nabla_\theta f}{\|\nabla_\theta f\|}\tag{1} $$
proof
对极限内的目标函数做一阶泰勒展开:
$$ \begin{align} \underset{\delta:\|\delta\|\le \epsilon}{\operatorname{argmin}} f(\theta+\delta) &\approx \underset{\delta:\|\delta\|\le \epsilon}{\operatorname{argmin}} f(\theta) + \nabla_\theta f^\top\delta\\ &=\underset{\delta:\|\delta\|\le \epsilon}{\operatorname{argmin}} \nabla_\theta f^\top\delta \end{align} $$
我们将约束条件稍加改写:
$$ \underset{\delta}{\min} \nabla_\theta f^\top\delta\quad\text{s.t. }\|\delta\|^2\le \epsilon^2 $$
定义拉格朗日函数
$$ \mathcal{L}(\delta, \lambda):= \nabla_\theta f^\top\delta + \lambda (\|\delta\|^2-\epsilon^2) $$
根据KKT条件:
$$ \begin{align} &\nabla_\delta\mathcal{L}(\delta, \lambda) = 0 &\triangleright\text{Stationarity}\\ &\lambda(\|\delta\|^2-\epsilon^2)=0 &\triangleright\text{Complementary slackness}\\ &\|\delta\|^2-\epsilon^2 \le0&\triangleright\text{Primal feasibility}\\ &\lambda\ge 0&\triangleright\text{Dual feasibility}\\ \end{align} $$
根据驻点条件得到
$$ \begin{align} &\nabla_\theta f+2\lambda\delta=0\\ &\delta = -\frac{1}{2\lambda}\nabla_\theta f\\ \end{align} $$
代入互补松弛条件(这里$\lambda=0$
可以排除,因为会造成$\delta$
为unbounded):
$$ \begin{align} \lambda&\left(\left\|-\frac{1}{2\lambda}\nabla_\theta f\right\|^2-\epsilon^2\right)=0\\ &\lambda = \frac{1}{2\epsilon}\|\nabla_\theta f\|\\ \end{align} $$
带回驻点条件:
$$ \delta^* = -\epsilon \frac{\nabla_\theta f}{\|\nabla_\theta f\|} $$
将$\delta^*$
代入原极限表达式:
$$ \lim_{\epsilon\to 0}\frac{1}{\epsilon}\left(\underset{\delta:\|\delta\|\le \epsilon}{\operatorname{argmin}} f(\theta+\delta)\right)=-\frac{\nabla_\theta f}{\|\nabla_\theta f\|} $$
注意到在上述的「steepest」的求解中,对步长的约束条件$\|\delta\|\le\epsilon$
基于欧几里得距离,这里隐含了如下假设:(1)参数空间是标准欧几里得空间;(2)参数构成一组正交归一(orthonormal)的坐标系统。当这个假设不能很好满足的时候,梯度下降的最速性质可能会大打折扣。
为了说明这个问题,我们考虑一个二维的二次型函数$f(\boldsymbol{x})=\boldsymbol{x}^\top\boldsymbol{Ax}$
,我们令
$$ \boldsymbol{A} = \left[\begin{matrix}1 & 0.5\\0.5 & 2\end{matrix}\right] $$
这个函数的等高线图是一个典型的椭圆型,在这个例子中,参数空间不是标准欧氏空间,而是被$\boldsymbol{A}$
定义的椭球几何所支配。如下图所示,使用标准的梯度下降时,由于梯度方向并不指向最低点,因此优化路径是一条曲线,或者(当学习率过大时)呈Z字形。图片右侧是使用自然梯度的优化路径,后面我们会推导自然梯度的形式。
自然梯度:黎曼距离下的最速下降
在揭示了梯度下降的可能问题之后,我们将公式$(1)$
中的约束做如下的泛化(差异处我们用蓝色做了区分):
$$ \lim_{\epsilon\to 0}\frac{1}{\epsilon}\left(\underset{\delta: {\color[rgb]{0, 0.5, 0.8}\delta^\top G(\theta) \delta \le \epsilon^2}}{\operatorname{argmin}} f(\theta+\delta)\right)=?\tag{2} $$
这里我们引入了局部的度量张量$G(\theta)$
,上篇文章已经简要介绍过黎曼度量和其定义的线元(局部微小变化的长度)
$$ |\delta|^2 = \delta^\top G(\theta) \delta $$
对$(2)$
式做推导,得到的结果就是自然梯度的方向.
$$ \lim_{\epsilon\to 0}\frac{1}{\epsilon}\left(\underset{\delta: {\color[rgb]{0, 0.5, 0.8}\delta^\top G(\theta) \delta \le \epsilon^2}}{\operatorname{argmin}} f(\theta+\delta)\right)=-CG(\theta)^{-1}\nabla_\theta f \tag{3} $$
其中$C$
是某个常数,可以被吸收到学习率中。这里的证明框架与梯度下降是基本一致的,只不过对约束条件做了一定修改(高亮为蓝色):
proof
对极限内的目标函数做一阶泰勒展开:
$$ \begin{align} \underset{\delta: {\color[rgb]{0, 0.5, 0.8}\delta^\top G(\theta) \delta \le \epsilon^2}}{\operatorname{argmin}} f(\theta+\delta) &\approx \underset{\delta: {\color[rgb]{0, 0.5, 0.8}\delta^\top G(\theta) \delta \le \epsilon^2}}{\operatorname{argmin}} f(\theta) + \nabla_\theta f^\top\delta\\ &=\underset{\delta: {\color[rgb]{0, 0.5, 0.8}\delta^\top G(\theta) \delta \le \epsilon^2}}{\operatorname{argmin}} \nabla_\theta f^\top\delta \end{align} $$
定义拉格朗日函数
$$ \mathcal{L}(\delta, \lambda):= \nabla_\theta f^\top\delta + \lambda ({\color[rgb]{0, 0.5, 0.8}\delta^\top G(\theta) \delta} -\epsilon^2) $$
根据KKT条件:
$$ \begin{align} &\nabla_\delta\mathcal{L}(\delta, \lambda) = 0 &\triangleright\text{Stationarity}\\ &\lambda ({\color[rgb]{0, 0.5, 0.8}\delta^\top G(\theta) \delta} -\epsilon^2)=0 &\triangleright\text{Complementary slackness}\\ &{\color[rgb]{0, 0.5, 0.8}\delta^\top G(\theta) \delta} -\epsilon^2 \le0&\triangleright\text{Primal feasibility}\\ &\lambda\ge 0&\triangleright\text{Dual feasibility}\\ \end{align} $$
根据驻点条件得到
$$ \begin{align} &\nabla_\theta f+2\lambda {\color[rgb]{0, 0.5, 0.8}G(\theta)}\delta=0\\ &\delta = -\frac{1}{2\lambda} {\color[rgb]{0, 0.5, 0.8}G(\theta)^{-1}}\nabla_\theta f\\ \end{align} $$
代入互补松弛条件($\lambda=0$
可以排除,因为会造成$\delta$
为unbounded):
$$ \begin{align} {\color[rgb]{0, 0.5, 0.8}\left(-\frac{1}{2\lambda} G(\theta)^{-1}\nabla_\theta f\right)^\top }&{\color[rgb]{0, 0.5, 0.8}G(\theta)\left(-\frac{1}{2\lambda} G(\theta)^{-1}\nabla_\theta f\right)}-\epsilon^2=0\\ \lambda &= \frac{1}{2\epsilon}{\color[rgb]{0, 0.5, 0.8}\sqrt{ \nabla_\theta f^\top G(\theta)^{-1}\nabla_\theta f }}\\ \end{align} $$
带回驻点条件:
$$ \delta^* = -\epsilon\frac{{\color[rgb]{0, 0.5, 0.8}G(\theta)^{-1}}\nabla_\theta f}{{\color[rgb]{0, 0.5, 0.8}\sqrt{ \nabla_\theta f^\top G(\theta)^{-1}\nabla_\theta f }}} $$
将$\delta^*$
代入原极限表达式:
$$ \begin{align} \lim_{\epsilon\to 0}\frac{1}{\epsilon}\left(\underset{\delta: {\color[rgb]{0, 0.5, 0.8}\delta^\top G(\theta) \delta \le \epsilon^2}}{\operatorname{argmin}} f(\theta+\delta)\right)&=-\frac{{\color[rgb]{0, 0.5, 0.8}G(\theta)^{-1}}\nabla_\theta f}{{\color[rgb]{0, 0.5, 0.8}\sqrt{ \nabla_\theta f^\top G(\theta)^{-1}\nabla_\theta f }}}\\ &=-C{\color[rgb]{0, 0.5, 0.8}G(\theta)^{-1}}\nabla_\theta f\\ \end{align} $$
在上述结论中,我们实质上是对标准的梯度下降方向应用了度量张量的逆$G(\theta)^{-1}$
,从而修正了梯度的方向(可以将这个矩阵叫做pre-conditioner)。
在机器学习中,我们常关注的是概率模型的最大似然优化问题,在自然梯度(一):Fisher信息矩阵作为黎曼度量中,我们已经建立了Fisher信息矩阵是给定概率分布族的参数空间的黎曼度量张量这一结论。如果我们需要优化的函数是一个概率似然函数$\ell(\theta):=\log p(x|\theta)$
,则自然梯度可以直接由Fisher信息矩阵作为pre-conditioner
$$ \begin{align} \lim_{\epsilon\to 0}\frac{1}{\epsilon}\left(\underset{\delta: {\color[rgb]{0, 0.5, 0.8}D_{\text{KL}}\left(p(x|\theta)\|p(x|\theta+\delta)\right) \le \epsilon^2}}{\operatorname{argmin}} \ell(\theta+\delta)\right)&\approx\lim_{\epsilon\to 0}\frac{1}{\epsilon}\left(\underset{\delta: {\color[rgb]{0, 0.5, 0.8}\delta^\top F(\theta) \delta \le \epsilon^2}}{\operatorname{argmin}} \ell(\theta+\delta)\right)\\ &=-CF(\theta)^{-1}\nabla_\theta \ell \end{align} $$
对应的参数更新公式为
$$ \theta^+=\theta-\eta F(\theta)^{-1}\nabla_\theta \ell(\theta|x) $$
拓展到判别模型
在常见的监督学习设定下,我们学习的是一个判别模型,优化目标是一系列条件概率的联合对数似然函数,其中每个输入$\boldsymbol{x}^{(i)}$
对应一个条件概率分布
$$ \ell(\theta)=-\frac{1}{n} \sum_i^n \left[\log p_\theta(y^{(i)}|\boldsymbol{x}^{(i)})\right] $$
相应地,约束条件需要更改为在每个条件概率的KL散度的期望
$$ \mathbb{E}_{x\sim\tilde{q}(x)}\left[D_{\text{KL}}\left(p(y|x;\theta)\|p(y|x;\theta+\delta)\right)\right] \le \epsilon^2 $$
这里的$\tilde{q}(x)$
是输入数据的真实分布或替代分布(与真实分布接近)。这个约束条件对应的Fisher信息矩阵的形式为
$$ F(\theta)=\mathbb{E}_{x\sim\tilde{q}(x)}\left[ \mathbb{E}_{y\sim p(y|x;\theta)}\left[ \nabla_\theta \log p(y|x;\theta)\nabla_\theta \log p(y|x;\theta)^\top \right] \right]\tag{4} $$
自然梯度的特性
与二阶优化的联系与区别
自然梯度的一般形式中,使用$G(\theta)^{-1}$
作为梯度的pre-conditioner。如果把自然梯度看做一个一般的框架(而不仅仅考虑概率模型),那么当优化目标满足一定条件(e.g.,凸函数)时,二阶优化可以看做是选取Hessian作为自然梯度的度量张量。
对于常见的概率模型框架(优化对数似然损失),我们选取FIM作为度量张量,可以带来与二阶优化类似的性质,例如,在函数流形的局部曲率比较小的时候(plateau),自然梯度会将更新步长拉得比较大,从而可能有助于快速离开plateau。不过也需要注意,这里的曲率定义在模型的函数流形上,而不是最终的损失函数定义的函数流形上。
FIM相比Hessian具有一些比较好的特性。一方面,FIM是一个协方差矩阵,它总是半正定的,而Hessian则不然(非正定矩阵的逆是不稳定的)。另一方面,我们观察$(4)$
中定义的FIM,注意到内层的期望是定义在模型分布$p(y|x;\theta)$
上的,也就是说,估计一个FIM只需要输入分布和模型分布,而不需要知道标签的真实分布,这在mini-batch特别小(e.g., online learning)的时候非常方便——我们可以在一个无标注的数据集上估计FIM,然后将得到的统计量与一个有标注的batch数据计算得到的梯度结合更新模型参数。
另外,在特定条件下,自然梯度可以等价于广义-高斯牛顿方法,而后者一般被认为是一个二阶优化方法,可以参考Martens 2020.。
模型KL约束
使用FIM的自然梯度通过约束模型在一步更新前后的KL散度得到的「最优」方向,这种约束与模型的参数化方式无关——无论什么样的模型,一步更新的结果都是恒定的KL散度变化约束。在模型的分布距离约束下,优化过程中每一步更新后,模型的分布都不会有非常剧烈的变化,这构成了一种「平滑」的效应,Pascanu & Bengio 2014.认为这一定程度上可以防止过拟合。
应用限制:复杂度考虑
到目前为止,自然梯度仍然没有在深度学习中得到广泛应用。自然梯度需要计算FIM,对于包含$M$
个参数的模型而言,FIM的空间复杂度为$\mathcal{O}(M^2)$
,对于现在的神经网络而言,这是一个不小的负担——一阶优化方法只需要$\mathcal{O}(M)$
的优化器状态。另外,对一个大矩阵求逆也需要比较大的计算复杂度。将自然梯度推广到大模型中需要引入FIM的结构假设(e.g., 分块对角)。
总结
在自然梯度的两篇文章中,我们从Fisher信息矩阵(FIM)的定义出发,将FIM与概率模型的参数空间的黎曼度量建立联系。在此基础上,我们推导了自然梯度中为何引入FIM来修正梯度方向,并讨论了自然梯度的特性、与二阶优化的联系与区别、以及应用的限制。