从约束视角看深度学习优化若干新进展

在深度学习中最常用的优化方法是梯度下降方法及其变体。在过去很长一段时间中,Adam优化器是NLP社区的默认选择,在ViT出现之后,CV方面的工作也逐渐开始使用Adam和Adam的变体(在ViT之前,一种常见的观点是Adam不适用于Vision任务)。 最近Muon优化器在Kimi的新工作带动下又火了一把,相较于Adam优化器需要同时维护一阶、二阶矩估计量,Muon只需要维护一份梯度的动量估计,因此在大规模训练中有很大优势。最近笔者顺着Muon的reference看了Jeremy Bernstein在优化的一些文章,觉得很有意思,因此写这篇文章梳理一下这一系列工作的部分精要。本文的核心论点是:使用诱导范数来约束梯度更新,可以推导出最近的一些新出现的优化方法,这也可能是未来深度学习优化的一个有潜力的探索方向。 梯度下降(Recap) 当前深度学习优化算法的基石是梯度下降。之前笔者写过一篇拙文(自然梯度(二):黎曼距离下的最速下降)整理过梯度下降的推导,核心的结论是:当我们假设参数空间是一个欧几里得空间、参数的距离可以用欧几里得距离来衡量时,我们在某个点约束$\|\Delta\theta\|_2\le\epsilon(\epsilon>0)$时,$\Delta\theta$取梯度的反方向时可以让目标函数下降最多(具体的证明请参阅上述引文)。 使用梯度下降最大的问题是,它实际上忽略了模型的结构。换句话说,梯度下降相当于将模型所有参数展平为1维向量,并且用向量2范数来衡量每次更新的「步长」。这种抽象是实用的,但是也存在一定的问题。两组参数有可能在欧几里得空间中距离很近,但是诱导的模型输出空间距离很远。造成的结果就是更新的方向实际上不是目标函数下降最快的方向。 这个问题要如何解决呢?在自然梯度(二):黎曼距离下的最速下降中,我们介绍了自然梯度方法,即使用Fisher信息矩阵的逆作为梯度的pre-conditioner来矫正梯度下降的方向,从原理上是使用参数更新前后引导的概率分布的KL散度作为每次更新的步长约束。但是对于常见的深度神经网络来说,这样做仍然是不切实际的,因为FIM是一个$N\times N$的大矩阵(其中$N$是参数量),对于这么大的矩阵存储或求逆都是很难做到的。 诱导范数作为步长约束 是否有一种更「廉价」的方法,可以考虑模型的参数结构,同时将参数的变化对于输出的影响作为约束呢? 幸运的是,对于当下最流行的神经网络(e.g., Transformer)而言,模型往往可以拆解为很多小模块,其中最常见的是Linear模块(线性映射,这里忽略bias term) $$ f(\boldsymbol{x};\boldsymbol{W})=\boldsymbol{Wx},\ \boldsymbol{W}\in\mathbb{R}^{n\times m},\boldsymbol{x}\in \mathbb{R}^{m}\\ $$ 在标准的Transformer中,Attention、FFN、LM分类器都是由Linear模块组成的,Embedding从数学原理上也是输入为one-hot encoding的线性映射。假设现在对于某个Linear模块的参数$\boldsymbol{W}$做$\Delta\boldsymbol{W}$的更新($\boldsymbol{W}'\leftarrow \boldsymbol{W}+\Delta\boldsymbol{W}$),我们需要衡量这个更新对于最终输出的影响是多少(从而可以约束这个影响)。由于神经网络比较复杂,衡量$\Delta\boldsymbol{W}$对于最终目标函数的影响是相对繁琐的,但我们可以退而求其次,衡量$\Delta\boldsymbol{W}$对于这个Linear模块的输出$\boldsymbol{Wx}$的影响。 考虑线性模块的输入与输出空间的距离都使用欧几里得范数$\|\cdot\|_{\ell_2}$衡量,那么这个约束可以通过如下不等式实现 $$ \|\Delta\boldsymbol{W}\boldsymbol{x}\|_{\ell_2} \le {\color[rgb]{0, 0.5, 0.8}{\|\Delta\boldsymbol{W}\|_{\ell_2\to\ell_2}}}\|\boldsymbol{x}\|_{\ell_2} $$ 这里的${\color[rgb]{0, 0.5, 0.8}{\|\Delta\boldsymbol{W}\|_{\ell_2\to\ell_2}}}$是矩阵2范数。这个不等式告诉我们,如果约束了参数更新量的谱范数(不等式右侧),也就约束了更新前后这个线性模块输出的变化量。 假设现在需要优化的神经网络是由一系列的线性模块堆叠组成(e.g., MLP),我们可以参照梯度下降的推导构造如下的更新1 $$ \underset{\Delta\boldsymbol{W}_1,\ldots,\Delta\boldsymbol{W}_L}{\text{arg min}}\left[ \sum_{l=1}^L {\langle \boldsymbol{G}_l, \Delta\boldsymbol{W}_l \rangle}_F + \frac{\lambda}{2}\max_{l=1}^L{\color[rgb]{0, 0.5, 0.8}{\|\Delta\boldsymbol{W}_l\|^2_{\ell_2\to\ell_2}}} \right]\\ $$ 这里$\boldsymbol{G}_l$表示$\boldsymbol{W}_l$对应的梯度矩阵(布局与原参数矩阵相同),${\langle \cdot, \cdot \rangle}_F$表示Frobenius内积(对矩阵而言,逐元素相乘求和)。这里之所以使用$\max_{l=1}^L$(而不是直接求和),是因为我们引入这个约束时希望目标函数在$\Delta\boldsymbol{W}_l$变化下,能够保持平滑的性质2,因此需要bound所有参数矩阵更新量的谱范数的最大值。 我们来逐步推导这个最小值成立时的$\Delta\boldsymbol{W}_1,\ldots,\Delta\boldsymbol{W}_L$取值3。为了方便,把每个$\Delta\boldsymbol{W}_l$拆解成大小和方向两部分:$\Delta\boldsymbol{W}_l=c_l\boldsymbol{T}_l(c_l\triangleq\|\Delta\boldsymbol{W}_l\|_{\ell_2\to\ell_2})$ (为了可读性,下面的$\|\cdots\|$均表示谱范数$\|\cdot\|_{\ell_2\to\ell_2}$) $$ \begin{align} &\underset{\Delta\boldsymbol{W}_1,\ldots,\Delta\boldsymbol{W}_L}{\text{min}}\left[ \sum_{l=1}^L {\langle \boldsymbol{G}_l, \Delta\boldsymbol{W}_l \rangle}_F + \frac{\lambda}{2}\max_{l=1}^L\|\Delta\boldsymbol{W}_l\|^2 \right]\\ &=\underset{c_1,\ldots,c_L\ge 0}{\text{min}}\left[ \sum_{l=1}^L c_l\min_{\|\boldsymbol{T}_l\|=1}{\langle \boldsymbol{G}_l, \boldsymbol{T}_l \rangle}_F + \frac{\lambda}{2}\max_{l=1}^Lc_l^2 \right]\\ &=\underset{c_1,\ldots,c_L\ge 0}{\text{min}}\left[ -\sum_{l=1}^L c_l\max_{\|\boldsymbol{T}_l\|=1}{\langle \boldsymbol{G}_l, \boldsymbol{T}_l \rangle}_F + \frac{\lambda}{2}\max_{l=1}^Lc_l^2 \right]\\ &=\underset{c_1,\ldots,c_L\ge 0}{\text{min}}\left[ -\sum_{l=1}^L c_l \|\boldsymbol{G}_l\|_* + \frac{\lambda}{2}\max_{l=1}^Lc_l^2 \right]\quad\triangleright\|\cdot\|_*\text{表示核范数}\\ &=\underset{\eta\ge 0}{\text{min}}\left[ -\sum_{l=1}^L \eta\|\boldsymbol{G}_l\|_* + \frac{\lambda}{2}\max_{l=1}^L \eta^2 \right]\tag{1}\\ \end{align} $$ ...

2025-03-05 · Tianyang Lin

为什么LLM一般使用较大的权重衰减系数?

最近在阅读Muon is Scalable for LLM Training这篇文章的时候注意到他们使用无权重衰减(weight decay)版本的Muon优化LLM的时候,优化器的收敛优势会随着训练过程逐渐消失,又看到@小明同学在评论区提到的一个细节,很多开源的LLM在技术报告中都提到了使用0.1作为权重衰减的系数,觉得是个比较有意思的发现。结合Kimi的文章中关于bf16的简单陈述,笔者在本文中稍微展开讲下,权重衰减对于LLM的低精度训练中有什么作用。 首先把结论放在前面:除了一般认知中的正则化作用,权重衰减也可能降低精度损失的风险——对于计算机的浮点数而言,绝对值越大,精度越低。对于低精度/混合精度训练而言,使用权重衰减可以控制参数的绝对值范围,从而保证模型参数不落入低精度的数值区间。 浮点数的存储与精度 上述结论主要与浮点数在计算机内的存储形式有关。学过计算机的一些基本课程的读者可能有印象,浮点数的存储是二进制的形式,分为符号位、指数位和尾数位三段。深度学习中常见的浮点数协议(fp32、fp16、bf16、tf32)的区别在于指数位和尾数位的比特数量不同。由于浮点数是一个「指数」的形式,因此它在实数空间的分布是不均匀的。 这里我们考虑规范数的情形(指数位非全0),做一点分析。假设符号位、指数位、尾数位(mantissa)的二进制编码分别是$S$、$E$、$M$,那么对应的浮点数为: $$ \text{value}=(-1)^S\times 1.M\times 2^{E-\text{bias}} $$ 在单精度fp32标准中,$\text{bias}$取${01111111}_{2}=127_{10}$ . fp32浮点数的一个例子 例如在图中的例子中,$S=0$,$E=\underbrace{00\cdots0}_{7\ 0's}1$,$M=\underbrace{00\cdots0}_{22\ 0's}1$,相应的值为 $$ \begin{align} &{-1}^0\times 1.\underbrace{00\cdots0}_{22\ 0's}1_2\times 2^{00000001_2-{01111111}_{2}}\\ &\approx [1.175494490952134\times 10^{-38}]_{10} \end{align} $$ 现在我们来考虑不同的数值范围内的浮点数精度。对于区间范围$[2^{x}, 2^{x+1}],\forall -126\le x\le127$(这里的x已经是经过-bias之后得到的最终指数),我们希望在给定任意浮点数$y\in[2^{x}, 2^{x+1}]$的基础上增加一个最小量$\varepsilon$(即区间内两个浮点数的最小间隔),这个增加的过程是通过操纵二进制编码实现的,那么最小间隔只能是通过在$y$的尾数部分加上 $$ 0.\underbrace{00\cdots0}_{22\ 0's}1 $$ 来实现,对应的最小间隔是 $$ \begin{align} \varepsilon &= 0.\underbrace{00\cdots0}_{22\text{ 0's}}1_2\times 2^{x}\\ &=2^{-23}\times2^{x}=2^{x-23}\\ \end{align} $$ 如果你将上面的公式带入不同的指数位,可以验证与Wikipedia中给出的这张表的Gap是吻合的: 不同指数位下的最小精度,来源:Wikipedia 这个计算可以拓展到其他的精度格式: 对于半精度格式fp16而言,其尾数位有10位,对应的最小间隔是$2^{x-10}$; 对于半精度格式bf16而言,其尾数位有7位,对应的最小间隔是$2^{x-7}$; 从这里也可以看到,bf16相比fp16虽然拓宽了表示范围,但是减少了精度(同样数值范围内的最小间隔更宽了)。 从这里我们得到了一个结论:计算机存储的浮点数之间的最小间隔随着浮点数绝对值数值增加,指数级地增大,换言之,浮点数(绝对值)数值越大,精度越低。并且这个问题对于fp16或bf16格式的浮点数,问题要更加显著。 这个结论的另一个引申的问题是舍入误差,假如一个较大的浮点数和一个较小的浮点数相加,由于浮点数的加法(减法过程相当于取补码后相加,结论是类似的)过程需要先将两个数的指数位对齐,因此绝对值较小的数字的尾数的最后几位数字可能会在加法中丢失。这里我们举一个极端的例子来说明。 假设浮点数存储为fp32格式(8位指数、23位尾数)。 $$ \begin{align} x=(-1)^0\times 1.0_2\times 2^{-1}\\ y=(-1)^0\times 1.0_2\times 2^{-25}\\ \end{align} $$ ...

2025-02-26 · Tianyang Lin

Tensor Product Attention (TPA) 导读

最近Deepseek比较出圈,连带着里面用到的MLA也被讨论得更多了。MLA无疑是一个非常出色的attention改进,但是由于它的KV cache设计,不能很好地兼容RoPE,因此作者们使用了decoupled RoPE这样的「补丁」来引入位置关系,这无疑也增加了实现的复杂度。 最近,修改注意力KV Cache这一线工作又增添了TPA这个新成员,笔者觉得这篇文章比较有趣,因此希望写一篇简短的导读。之所以叫「导读」,是因为本文不打算太深入文章的formulation和实验,而是从笔者自己的视角出发介绍文章的一些重点贡献。。 MHA的拆解 最标准的多头注意力(Vaswani的版本),大致可以拆解成3个步骤(这里默认讨论自注意力)1: $$ \begin{aligned} \text{step 1 }& \begin{cases} \boldsymbol{q}_i^{(h)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(h)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_q^{(h)}\in\mathbb{R}^{d\times d_k} \\ \boldsymbol{k}_i^{(h)} = \boldsymbol{x}_i\boldsymbol{W}_k^{(h)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_k^{(h)}\in\mathbb{R}^{d\times d_k}\\ \boldsymbol{v}_i^{(h)} = \boldsymbol{x}_i\boldsymbol{W}_v^{(h)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(h)}\in\mathbb{R}^{d\times d_v} \end{cases} \\ \text{step 2 }& \begin{cases} \boldsymbol{o}_t^{(h)} = \text{Attention}\left(\boldsymbol{q}_t^{(h)}, \boldsymbol{k}_{\leq t}^{(h)}, \boldsymbol{v}_{\leq t}^{(h)}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(h)} \boldsymbol{k}_i^{(h)}{}^{\top}\right)\boldsymbol{v}_i^{(h)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(h)} \boldsymbol{k}_i^{(h)}{}^{\top}\right)} \\ \end{cases} \\ \text{step 3 }& \begin{cases} \boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(H)}\right] \end{cases} \\ \end{aligned}\\ $$ 这里$\boldsymbol{x}_1,\boldsymbol{x}_2,\cdots,\boldsymbol{x}_l$是输入向量,上标$h\in \{1,\ldots,H\}$表示注意力头,$d,d_k,d_v$分别表示输入维度、key维度、value维度。在第一步中,我们将每个token的表征独立地线性投射到不同的自空间上,第二步是标准的点积注意力,第三步是拼接(后续再做一步线性变换)。 我们重点来审视一下第一个步骤,这里由于每个token的表征是独立操作的(类似FFN),因此我们可以将每个token的表征看做一个样本点,这里做的事情就是将一个$d$维度的向量,转化成3个矩阵$\tilde{\boldsymbol{Q}}\in\mathbb{R}^{H\times d_k},\tilde{\boldsymbol{K}}\in\mathbb{R}^{H\times d_k},\tilde{\boldsymbol{V}}\in\mathbb{R}^{H\times d_v}$,每一个头对应这个矩阵中的一行。接着我们逐行计算每行对应的三组向量的点积注意力,就构成了标准的多头注意力。在标准实现中,这个变换是通过参数矩阵的线性变换+ reshape实现的。 ...

2025-02-13 · Tianyang Lin

自然梯度(二):黎曼距离下的最速下降

上篇文章中,我们从Fisher信息矩阵(FIM)的定义出发,推导出Fisher矩阵与KL散度的关系,并建立如下结论:FIM可以作为概率模型的参数空间的一种黎曼度量。在本篇文章中,我们利用上篇得到的结论,推导自然梯度中为何引入FIM来修正梯度方向,并介绍自然梯度的一些性质。 梯度下降:欧氏距离下的最速下降 考虑一个最优化任务($f:\Theta\to\mathbb{R}$): $$ \underset{\theta}{\operatorname{min}} f(\theta) $$ 最常见的一阶优化方法是梯度下降/steepest descent: $$ \theta^+=\theta-\eta\nabla_\theta f $$ 其中$\eta$是学习率。这里的「steepest」指的是在约束欧氏距离定义下的步长在极小范围内时,选取梯度的负方向能最大化一步之内目标函数下降的程度。 $$ \lim_{\epsilon\to 0}\frac{1}{\epsilon}\left(\underset{\delta:\|\delta\|\le \epsilon}{\operatorname{argmin}} f(\theta+\delta)\right)=-\frac{\nabla_\theta f}{\|\nabla_\theta f\|}\tag{1} $$ proof ...

2025-02-06 · Tianyang Lin

自然梯度(一):Fisher信息矩阵作为黎曼度量

在一般的梯度下降中,我们认为目标函数梯度的负方向可以最小化一步更新后的目标函数值,这里隐含地假设了参数空间是欧氏空间,且参数构成了一组正交归一的坐标系统。在很多情况下,这一假设是不成立的,作为结果,优化过程的收敛效率可能受到影响。 作为解决这一问题的一种思路,自然梯度使用Fisher信息矩阵(的逆)作为梯度的pre-conditioner来矫正梯度的方向。本文将分为两篇,在第一篇中,我们从Fisher信息矩阵(FIM)的定义出发,推导出Fisher矩阵与KL散度的关系,并建立如下结论:FIM可以作为概率模型的参数空间的一种黎曼度量。在第二篇中,我们推导自然梯度中为何引入FIM来修正梯度方向,以及自然梯度的一些性质。 Score function与FIM 假设我们有一个由$\theta$参数化的概率模型,模型分布为$p(x|\theta)$,记对数似然函数为$\ell(\theta|x):=\log p(x|\theta)$。与对数似然函数相关的有两个定义,score function和fisher information。 定义1(score function):score function $s(\theta|x)$被定义为对数似然函数关于参数$\theta$的梯度 $$ s(\theta|x)=\nabla_\theta \ell(\theta|x) $$ 一些文章会提到score function是用来为参数的好坏打分(score),这是不严谨的。score function中的「score」其实不是为参数打分,而是在Fisher研究的遗传统计问题中给基因异常家庭的「打分」(参见:Interpretation of “score”)。因此,score function只是约定俗成的一种名称,其实质就是似然函数的梯度,描述的是似然函数对于参数变化的敏感程度。 性质1:Score function期望为0 $$ \mathbb{E}_{p(x|\theta)}[s(\theta|x)]=\boldsymbol{0} $$ proof ...

2025-02-05 · Tianyang Lin